#[cfg(feature = "cuda")]
use std::cell::RefCell;
#[cfg(feature = "cuda")]
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use std::sync::atomic::{AtomicUsize, Ordering};
#[cfg(feature = "cuda")]
use std::sync::{Arc, OnceLock};
#[cfg(feature = "cuda")]
use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
use crate::error::{GpuError, GpuResult};
// Number of streams created for a device when its pool slot is first
// initialized (see `StreamPool::get_stream`). Fewer may exist if creation
// fails partway through.
#[cfg(feature = "cuda")]
const STREAMS_PER_DEVICE: usize = 8;
// Upper bound on supported device ordinals; sizes the global slot vector
// and rejects out-of-range ordinals with `GpuError::InvalidDevice`.
#[cfg(feature = "cuda")]
const MAX_DEVICES: usize = 64;
/// Thin wrapper around cudarc's [`CudaEvent`] exposing the small
/// record / query / synchronize / wait surface this crate needs, with
/// driver errors converted into `GpuResult` via `?`.
#[cfg(feature = "cuda")]
pub struct CudaEventWrapper {
    // The underlying driver event, created through `CudaContext::new_event`.
    inner: CudaEvent,
}
#[cfg(feature = "cuda")]
impl CudaEventWrapper {
    /// Creates an event using the context's default flags (`new_event(None)`).
    pub fn new(ctx: &Arc<CudaContext>) -> GpuResult<Self> {
        Ok(Self {
            inner: ctx.new_event(None)?,
        })
    }

    /// Creates an event explicitly requesting `CU_EVENT_DEFAULT` flags.
    ///
    /// NOTE(review): `CU_EVENT_DEFAULT` is also the flag value for
    /// timing-capable events; if `new_event(None)` maps to the same flags,
    /// this constructor is identical to [`Self::new`] — confirm against the
    /// cudarc version in use.
    pub fn new_with_timing(ctx: &Arc<CudaContext>) -> GpuResult<Self> {
        let timing_flags = cudarc::driver::sys::CUevent_flags::CU_EVENT_DEFAULT;
        Ok(Self {
            inner: ctx.new_event(Some(timing_flags))?,
        })
    }

    /// Records this event on `stream`.
    pub fn record(&self, stream: &CudaStream) -> GpuResult<()> {
        Ok(self.inner.record(stream)?)
    }

    /// Blocks the calling host thread until the event has completed.
    pub fn synchronize(&self) -> GpuResult<()> {
        Ok(self.inner.synchronize()?)
    }

    /// Non-blocking completion check.
    pub fn query(&self) -> GpuResult<bool> {
        Ok(self.inner.is_complete())
    }

    /// Makes `stream` wait (device-side) until this event completes.
    pub fn wait_on(&self, stream: &CudaStream) -> GpuResult<()> {
        Ok(stream.wait(&self.inner)?)
    }

    /// Borrows the raw underlying [`CudaEvent`].
    #[inline]
    pub fn inner(&self) -> &CudaEvent {
        &self.inner
    }
}
#[cfg(feature = "cuda")]
impl std::fmt::Debug for CudaEventWrapper {
    // Opaque output ("CudaEventWrapper { .. }"): the raw driver handle has
    // no meaningful Debug representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CudaEventWrapper").finish_non_exhaustive()
    }
}
/// The set of pooled streams for a single device, handed out round-robin.
#[cfg(feature = "cuda")]
struct DeviceStreams {
    // Streams created once when the device's pool slot is initialized.
    streams: Vec<Arc<CudaStream>>,
    // Monotonically increasing ticket counter; the stream index is
    // `counter % streams.len()`.
    counter: AtomicUsize,
}
/// Process-wide pool: one lazily-initialized slot per possible device ordinal.
#[cfg(feature = "cuda")]
static STREAM_POOL: OnceLock<Vec<OnceLock<DeviceStreams>>> = OnceLock::new();

/// Returns the fixed vector of `MAX_DEVICES` per-device slots, creating it
/// (empty slots only) on first use.
#[cfg(feature = "cuda")]
fn pool_slots() -> &'static Vec<OnceLock<DeviceStreams>> {
    STREAM_POOL.get_or_init(|| {
        let mut slots = Vec::with_capacity(MAX_DEVICES);
        slots.resize_with(MAX_DEVICES, OnceLock::new);
        slots
    })
}
/// Round-robin dispenser of pooled CUDA streams, keyed by device ordinal.
pub struct StreamPool;

#[cfg(feature = "cuda")]
impl StreamPool {
    /// Returns the next pooled stream for `device_ordinal` in round-robin
    /// order, creating the device's pool on first call.
    ///
    /// Up to `STREAMS_PER_DEVICE` streams are created; if none can be made,
    /// a fork of the default stream is tried as a fallback.
    ///
    /// NOTE(review): the pool is built from whichever `ctx` arrives first
    /// for this ordinal and reused for all later calls — assumes `ctx`
    /// always belongs to `device_ordinal`; confirm at call sites.
    ///
    /// # Errors
    /// `GpuError::InvalidDevice` for ordinals >= `MAX_DEVICES`, or a driver
    /// error if no stream could be created at all.
    pub fn get_stream(
        ctx: &Arc<CudaContext>,
        device_ordinal: usize,
    ) -> GpuResult<Arc<CudaStream>> {
        if device_ordinal >= MAX_DEVICES {
            return Err(GpuError::InvalidDevice {
                ordinal: device_ordinal,
                count: MAX_DEVICES,
            });
        }
        let pool = pool_slots()[device_ordinal].get_or_init(|| {
            // Create streams until the target count or the first failure.
            let mut streams: Vec<Arc<CudaStream>> = (0..STREAMS_PER_DEVICE)
                .map_while(|_| ctx.new_stream().ok())
                .collect();
            // Last resort: try forking the default stream.
            if streams.is_empty() {
                streams.extend(ctx.default_stream().fork().ok());
            }
            DeviceStreams {
                streams,
                counter: AtomicUsize::new(0),
            }
        });
        if pool.streams.is_empty() {
            // Even the fallback failed; report it as an out-of-memory
            // driver error.
            return Err(GpuError::Driver(cudarc::driver::DriverError(
                cudarc::driver::sys::cudaError_enum::CUDA_ERROR_OUT_OF_MEMORY,
            )));
        }
        let ticket = pool.counter.fetch_add(1, Ordering::Relaxed);
        Ok(Arc::clone(&pool.streams[ticket % pool.streams.len()]))
    }

    /// Number of streams pooled for `device_ordinal`; 0 when the ordinal is
    /// out of range or the pool has not been initialized yet.
    pub fn pool_size(device_ordinal: usize) -> usize {
        pool_slots()
            .get(device_ordinal)
            .and_then(OnceLock::get)
            .map_or(0, |ds| ds.streams.len())
    }
}
#[cfg(feature = "cuda")]
thread_local! {
    // Per-thread override map: device ordinal -> currently "active" stream.
    // Read by `current_stream_or_default` and managed by the
    // get/set/clear helpers and `StreamGuard`.
    static CURRENT_STREAMS: RefCell<HashMap<usize, Arc<CudaStream>>> =
        RefCell::new(HashMap::new());
}
/// Returns this thread's current-stream override for `device`, if one is set.
#[cfg(feature = "cuda")]
pub fn get_current_stream(device: usize) -> Option<Arc<CudaStream>> {
    CURRENT_STREAMS.with(|map| {
        let streams = map.borrow();
        streams.get(&device).map(Arc::clone)
    })
}
/// Installs `stream` as this thread's current stream for `device`,
/// silently replacing any previous override.
#[cfg(feature = "cuda")]
pub fn set_current_stream(device: usize, stream: Arc<CudaStream>) {
    // The returned previous value (if any) is intentionally discarded.
    let _ = CURRENT_STREAMS.with(|map| map.borrow_mut().insert(device, stream));
}
/// Removes this thread's current-stream override for `device`; no-op when
/// none is set.
#[cfg(feature = "cuda")]
pub fn clear_current_stream(device: usize) {
    let _ = CURRENT_STREAMS.with(|map| map.borrow_mut().remove(&device));
}
/// Resolves the stream to use for `device`: the thread-local override when
/// one is installed, otherwise the device's default stream.
#[cfg(feature = "cuda")]
pub fn current_stream_or_default(device: &crate::device::GpuDevice) -> Arc<CudaStream> {
    match get_current_stream(device.ordinal()) {
        Some(stream) => stream,
        None => Arc::clone(device.default_stream()),
    }
}
/// RAII guard that installs a thread-local current stream for one device
/// and restores the previous state when dropped.
#[cfg(feature = "cuda")]
pub struct StreamGuard {
    // Device ordinal whose thread-local entry this guard manages.
    device: usize,
    // Override that was active when the guard was created; restored on drop
    // (the entry is cleared instead when this is `None`).
    previous: Option<Arc<CudaStream>>,
}
#[cfg(feature = "cuda")]
impl StreamGuard {
    /// Installs `stream` as the current stream for `device`, first capturing
    /// whatever override (if any) was active so `Drop` can restore it.
    pub fn new(device: usize, stream: Arc<CudaStream>) -> Self {
        // Capture the prior state BEFORE overwriting it.
        let guard = Self {
            device,
            previous: get_current_stream(device),
        };
        set_current_stream(device, stream);
        guard
    }
}
#[cfg(feature = "cuda")]
impl Drop for StreamGuard {
    /// Restores the override captured at construction, or clears the
    /// thread-local entry entirely when there was none.
    fn drop(&mut self) {
        if let Some(prev) = self.previous.take() {
            set_current_stream(self.device, prev);
        } else {
            clear_current_stream(self.device);
        }
    }
}
#[cfg(feature = "cuda")]
impl std::fmt::Debug for StreamGuard {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("StreamGuard");
        dbg.field("device", &self.device);
        // Streams have no useful Debug output; report presence only.
        dbg.field("has_previous", &self.previous.is_some());
        dbg.finish()
    }
}
/// Stateless stub compiled when the `cuda` feature is disabled, so the type
/// name still resolves for downstream code.
#[cfg(not(feature = "cuda"))]
#[derive(Debug)]
pub struct CudaEventWrapper;
#[cfg(not(feature = "cuda"))]
impl StreamPool {
    /// Always fails: stream pooling requires the `cuda` feature.
    ///
    /// NOTE(review): this stub's signature differs from the cuda build
    /// (no context argument, `()` success type), so callers cannot be
    /// written feature-agnostically — confirm the asymmetry is intentional.
    pub fn get_stream(_device_ordinal: usize) -> GpuResult<()> {
        Err(GpuError::NoCudaFeature)
    }
    /// No streams can exist without CUDA support.
    pub fn pool_size(_device_ordinal: usize) -> usize {
        0
    }
}
/// Stub guard for non-CUDA builds; constructing or dropping it does nothing.
#[cfg(not(feature = "cuda"))]
#[derive(Debug)]
pub struct StreamGuard;
/// Non-CUDA stub: there is never a current stream.
#[cfg(not(feature = "cuda"))]
pub fn get_current_stream(_device: usize) -> Option<()> {
    None
}
/// Non-CUDA stub: no-op.
#[cfg(not(feature = "cuda"))]
pub fn set_current_stream(_device: usize, _stream: ()) {}
/// Non-CUDA stub: no-op.
#[cfg(not(feature = "cuda"))]
pub fn clear_current_stream(_device: usize) {}
#[cfg(all(test, feature = "cuda"))]
mod tests {
    //! Hardware-dependent tests: each test returns early (silently passing)
    //! when no CUDA device is available.
    use super::*;
    use cudarc::driver::CudaContext;
    // Best-effort context for device 0; `None` lets tests skip gracefully.
    fn test_ctx() -> Option<Arc<CudaContext>> {
        CudaContext::new(0).ok()
    }
    // Record + synchronize round-trip: the event must report complete after.
    #[test]
    fn event_record_sync() {
        let Some(ctx) = test_ctx() else { return };
        let stream = ctx.default_stream();
        let event = CudaEventWrapper::new(&ctx)
            .expect("event creation should succeed");
        event.record(&stream).expect("record should succeed");
        event.synchronize().expect("synchronize should succeed");
        assert!(
            event.query().expect("query should succeed"),
            "event should be complete after synchronize"
        );
    }
    // Pins the driver behavior that a never-recorded event reads as complete.
    #[test]
    fn event_query_before_record() {
        let Some(ctx) = test_ctx() else { return };
        let event = CudaEventWrapper::new(&ctx)
            .expect("event creation should succeed");
        let complete = event.query().expect("query should not error");
        assert!(complete, "unrecorded event should report complete");
    }
    // Round-robin dispatch: after exactly pool_size fetches, the next fetch
    // must wrap back to the first stream.
    // NOTE(review): the pool counter is a process-global; this wrap-around
    // assertion is fragile if another test in this binary ever calls
    // get_stream for device 0 concurrently — confirm test isolation.
    #[test]
    fn stream_pool_round_robin() {
        let Some(ctx) = test_ctx() else { return };
        let dev = 0;
        let s1 = StreamPool::get_stream(&ctx, dev)
            .expect("first get_stream should succeed");
        let s2 = StreamPool::get_stream(&ctx, dev)
            .expect("second get_stream should succeed");
        let pool_size = StreamPool::pool_size(dev);
        assert!(pool_size > 0, "pool should have streams");
        assert!(pool_size <= STREAMS_PER_DEVICE, "pool should not exceed configured size");
        let mut streams = vec![s1, s2];
        // Drain the remainder of one full rotation.
        for _ in 2..pool_size {
            streams.push(
                StreamPool::get_stream(&ctx, dev).expect("get_stream should succeed"),
            );
        }
        let wrap = StreamPool::get_stream(&ctx, dev)
            .expect("wrapped get_stream should succeed");
        assert_eq!(
            Arc::as_ptr(&wrap),
            Arc::as_ptr(&streams[0]),
            "round-robin should wrap back to the first stream"
        );
    }
    // Out-of-range ordinals are rejected before touching the pool.
    // NOTE(review): MAX_DEVICES itself is also invalid but untested here.
    #[test]
    fn stream_pool_invalid_device() {
        let Some(ctx) = test_ctx() else { return };
        let result = StreamPool::get_stream(&ctx, MAX_DEVICES + 1);
        assert!(result.is_err(), "should reject ordinal >= MAX_DEVICES");
    }
    // Guard semantics: inside the scope the new stream is current; on drop
    // the previously-set stream is restored.
    #[test]
    fn stream_guard_restores_previous() {
        let Some(ctx) = test_ctx() else { return };
        let dev = 0;
        assert!(
            get_current_stream(dev).is_none(),
            "should start with no current stream"
        );
        let s1 = ctx.new_stream().expect("new_stream should succeed");
        let s2 = ctx.new_stream().expect("new_stream should succeed");
        let s1_ptr = Arc::as_ptr(&s1);
        let s2_ptr = Arc::as_ptr(&s2);
        set_current_stream(dev, Arc::clone(&s1));
        assert_eq!(
            Arc::as_ptr(&get_current_stream(dev).unwrap()),
            s1_ptr,
            "current stream should be s1"
        );
        {
            let _guard = StreamGuard::new(dev, Arc::clone(&s2));
            assert_eq!(
                Arc::as_ptr(&get_current_stream(dev).unwrap()),
                s2_ptr,
                "current stream should be s2 inside guard"
            );
        }
        assert_eq!(
            Arc::as_ptr(&get_current_stream(dev).unwrap()),
            s1_ptr,
            "current stream should be restored to s1 after guard drop"
        );
        clear_current_stream(dev);
        assert!(
            get_current_stream(dev).is_none(),
            "should be cleared after explicit clear"
        );
    }
    // Guard semantics when nothing was set before: drop clears the entry.
    #[test]
    fn stream_guard_clears_when_no_previous() {
        let Some(ctx) = test_ctx() else { return };
        let dev = 0;
        clear_current_stream(dev);
        assert!(get_current_stream(dev).is_none());
        let s1 = ctx.new_stream().expect("new_stream should succeed");
        {
            let _guard = StreamGuard::new(dev, Arc::clone(&s1));
            assert!(
                get_current_stream(dev).is_some(),
                "guard should set current stream"
            );
        }
        assert!(
            get_current_stream(dev).is_none(),
            "guard with no previous should clear current stream on drop"
        );
    }
    // current_stream_or_default: default stream when no override, the
    // thread-local stream once one is installed.
    #[test]
    fn current_stream_or_default_fallback() {
        let Some(ctx) = test_ctx() else { return };
        let dev_ordinal = 0;
        clear_current_stream(dev_ordinal);
        let device = crate::device::GpuDevice::new(dev_ordinal)
            .expect("GpuDevice::new should succeed");
        let default_ptr = Arc::as_ptr(device.default_stream());
        let stream = current_stream_or_default(&device);
        assert_eq!(
            Arc::as_ptr(&stream),
            default_ptr,
            "should fall back to device default stream"
        );
        let custom = ctx.new_stream().expect("new_stream should succeed");
        let custom_ptr = Arc::as_ptr(&custom);
        set_current_stream(dev_ordinal, custom);
        let stream = current_stream_or_default(&device);
        assert_eq!(
            Arc::as_ptr(&stream),
            custom_ptr,
            "should use thread-local current stream"
        );
        clear_current_stream(dev_ordinal);
    }
    // Cross-stream ordering: stream2 waits on an event recorded on stream1;
    // the sequence must complete without error.
    #[test]
    fn event_wait_on_stream() {
        let Some(ctx) = test_ctx() else { return };
        let stream1 = ctx.default_stream();
        let stream2 = ctx.new_stream().expect("new_stream should succeed");
        let event = CudaEventWrapper::new(&ctx)
            .expect("event creation should succeed");
        event.record(&stream1).expect("record should succeed");
        event.wait_on(&stream2).expect("wait_on should succeed");
        stream2.synchronize().expect("synchronize should succeed");
    }
}