#[cfg(feature = "cuda")]
use std::cell::RefCell;
#[cfg(feature = "cuda")]
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use std::sync::atomic::{AtomicUsize, Ordering};
#[cfg(feature = "cuda")]
use std::sync::{Arc, OnceLock};
#[cfg(feature = "cuda")]
use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
use crate::error::{GpuError, GpuResult};
/// Number of streams created per device for the default (no-priority) pool.
#[cfg(feature = "cuda")]
const STREAMS_PER_DEVICE: usize = 8;
/// Number of streams created per (device, priority) pair in the priority pool.
#[cfg(feature = "cuda")]
const STREAMS_PER_PRIORITY: usize = 4;
/// Upper bound on device ordinals the pools will track; higher ordinals are
/// rejected with `GpuError::InvalidDevice`.
#[cfg(feature = "cuda")]
const MAX_DEVICES: usize = 64;
/// Logical stream priority, mapped onto the device's concrete CUDA priority
/// range by [`StreamPriority::to_cuda_priority`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StreamPriority {
    /// Highest scheduling priority (the numerically *smallest* CUDA value).
    High,
    /// Midpoint of the device's priority range.
    Normal,
    /// Lowest scheduling priority (the numerically *largest* CUDA value).
    Low,
}
impl StreamPriority {
    /// Resolves this logical priority to a concrete CUDA stream priority.
    ///
    /// `range` is `(least, greatest)` exactly as reported by
    /// `cuCtxGetStreamPriorityRange`: `least` is the numerically largest value
    /// (lowest priority) and `greatest` the numerically smallest (highest
    /// priority), so `greatest <= least`.
    pub fn to_cuda_priority(self, range: (i32, i32)) -> i32 {
        let (least, greatest) = range;
        match self {
            StreamPriority::High => greatest,
            // Midpoint computed in i64 so extreme ranges cannot overflow i32
            // (the old `(least + greatest) / 2` could). Truncation toward zero
            // matches the previous behavior, and a collapsed range (least ==
            // greatest) naturally resolves to that single value.
            StreamPriority::Normal => ((least as i64 + greatest as i64) / 2) as i32,
            StreamPriority::Low => least,
        }
    }
}
#[cfg(feature = "cuda")]
/// Queries the current context's stream priority range.
///
/// Returns `(least, greatest)` as defined by `cuCtxGetStreamPriorityRange`:
/// `least` is the numerically largest (lowest) priority and `greatest` the
/// numerically smallest (highest) priority.
///
/// # Errors
/// Propagates driver errors from binding the context or from the range query.
pub fn get_stream_priority_range(ctx: &Arc<CudaContext>) -> GpuResult<(i32, i32)> {
    use cudarc::driver::sys;
    ctx.bind_to_thread()?;
    let (mut least_prio, mut greatest_prio): (std::ffi::c_int, std::ffi::c_int) = (0, 0);
    // SAFETY: both pointers refer to live, writable stack locals that outlive
    // the call; the driver only writes two ints through them.
    unsafe {
        sys::cuCtxGetStreamPriorityRange(
            &mut least_prio as *mut _,
            &mut greatest_prio as *mut _,
        )
        .result()?;
    }
    Ok((least_prio, greatest_prio))
}
#[cfg(feature = "cuda")]
/// Field-for-field mirror of the assumed private layout of
/// `cudarc::driver::CudaStream`, used to fabricate a `CudaStream` from a raw
/// handle via `transmute` in `new_stream_with_priority`.
/// NOTE(review): this depends on cudarc internals. The compile-time guard
/// below checks only size and alignment, not field order or meaning — verify
/// against the exact cudarc version pinned by this crate.
struct CudaStreamMirror {
    _cu_stream: cudarc::driver::sys::CUstream,
    _ctx: Arc<CudaContext>,
}
#[cfg(feature = "cuda")]
/// Compile-time guard: fails the build if `CudaStream`'s size or alignment
/// ever diverges from `CudaStreamMirror`, which would make the `transmute`
/// in `new_stream_with_priority` unsound. (Field order is NOT checked.)
const _CUDA_STREAM_LAYOUT_GUARD: () = {
    assert!(
        std::mem::size_of::<CudaStreamMirror>() == std::mem::size_of::<CudaStream>(),
        "cudarc::driver::CudaStream layout has changed; update CudaStreamMirror"
    );
    assert!(
        std::mem::align_of::<CudaStreamMirror>() == std::mem::align_of::<CudaStream>(),
        "cudarc::driver::CudaStream alignment has changed; update CudaStreamMirror"
    );
};
#[cfg(feature = "cuda")]
/// Creates a non-blocking CUDA stream with the given logical priority on
/// `ctx`'s device.
///
/// cudarc (at the version this was written against) does not expose
/// `cuStreamCreateWithPriority`, so the raw driver call is made directly and
/// the resulting handle is wrapped into a `CudaStream` by transmuting a
/// layout-compatible mirror struct (see `_CUDA_STREAM_LAYOUT_GUARD`).
///
/// # Errors
/// Propagates driver errors from binding the context, querying the priority
/// range, or creating the stream.
pub fn new_stream_with_priority(
    ctx: &Arc<CudaContext>,
    priority: StreamPriority,
) -> GpuResult<Arc<CudaStream>> {
    use cudarc::driver::sys;
    use std::mem::MaybeUninit;
    ctx.bind_to_thread()?;
    let range = get_stream_priority_range(ctx)?;
    let cuda_prio = priority.to_cuda_priority(range);
    let mut raw_stream: MaybeUninit<sys::CUstream> = MaybeUninit::uninit();
    // SAFETY: the out-pointer targets a live MaybeUninit slot; the driver
    // writes a valid handle into it only on CUDA_SUCCESS.
    let res = unsafe {
        sys::cuStreamCreateWithPriority(
            raw_stream.as_mut_ptr(),
            sys::CUstream_flags::CU_STREAM_NON_BLOCKING as u32,
            cuda_prio,
        )
    };
    res.result()?;
    // SAFETY: `res.result()?` above returned Ok, so the driver initialized
    // the handle.
    let raw_stream = unsafe { raw_stream.assume_init() };
    // Holding an Arc to the context keeps it alive as long as the stream.
    let mirror = CudaStreamMirror {
        _cu_stream: raw_stream,
        _ctx: ctx.clone(),
    };
    // SAFETY-NOTE(review): relies on CudaStreamMirror matching CudaStream's
    // private layout. Only size/alignment are statically checked; field order
    // is assumed — re-verify whenever the cudarc dependency is bumped.
    let cuda_stream: CudaStream = unsafe { std::mem::transmute(mirror) };
    Ok(Arc::new(cuda_stream))
}
#[cfg(feature = "cuda")]
/// Thin wrapper around `cudarc`'s `CudaEvent` exposing a `GpuResult`-flavored
/// API (record / synchronize / query / timing helpers).
pub struct CudaEventWrapper {
    // Underlying cudarc event; exposed read-only via `inner()`.
    inner: CudaEvent,
}
#[cfg(feature = "cuda")]
impl CudaEventWrapper {
    /// Creates an event with the context's default event flags.
    pub fn new(ctx: &Arc<CudaContext>) -> GpuResult<Self> {
        Ok(Self {
            inner: ctx.new_event(None)?,
        })
    }
    /// Creates an event with `CU_EVENT_DEFAULT` flags, suitable for timing.
    pub fn new_with_timing(ctx: &Arc<CudaContext>) -> GpuResult<Self> {
        let timing_flags = cudarc::driver::sys::CUevent_flags::CU_EVENT_DEFAULT;
        Ok(Self {
            inner: ctx.new_event(Some(timing_flags))?,
        })
    }
    /// Records this event on `stream`.
    pub fn record(&self, stream: &CudaStream) -> GpuResult<()> {
        Ok(self.inner.record(stream)?)
    }
    /// Blocks the calling thread until the event has completed.
    pub fn synchronize(&self) -> GpuResult<()> {
        Ok(self.inner.synchronize()?)
    }
    /// Non-blocking completion check; `Ok(true)` once the event has fired.
    pub fn query(&self) -> GpuResult<bool> {
        let done = self.inner.is_complete();
        Ok(done)
    }
    /// Makes `stream` wait (on-device) until this event has completed.
    pub fn wait_on(&self, stream: &CudaStream) -> GpuResult<()> {
        Ok(stream.wait(&self.inner)?)
    }
    /// Elapsed milliseconds from `self` (start) to `end`.
    pub fn elapsed_ms(&self, end: &Self) -> GpuResult<f32> {
        let ms = self.inner.elapsed_ms(&end.inner)?;
        Ok(ms)
    }
    /// Elapsed microseconds from `self` to `end`, rounded; non-positive
    /// (or NaN) measurements clamp to zero.
    pub fn elapsed_us(&self, end: &Self) -> GpuResult<u64> {
        let ms = self.elapsed_ms(end)?;
        let us = if ms > 0.0 {
            (ms * 1000.0).round() as u64
        } else {
            0
        };
        Ok(us)
    }
    /// Borrow of the wrapped `CudaEvent`.
    #[inline]
    pub fn inner(&self) -> &CudaEvent {
        &self.inner
    }
}
#[cfg(feature = "cuda")]
impl std::fmt::Debug for CudaEventWrapper {
    /// The raw CUDA event handle is opaque, so only the type name is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("CudaEventWrapper");
        dbg.finish_non_exhaustive()
    }
}
#[cfg(feature = "cuda")]
/// Per-device pool entry: the pooled streams plus a round-robin cursor.
struct DeviceStreams {
    streams: Vec<Arc<CudaStream>>,
    // Monotonically increasing pick counter; next index = counter % len.
    counter: AtomicUsize,
}
#[cfg(feature = "cuda")]
/// One lazily-initialized pool slot per possible device ordinal.
static STREAM_POOL: OnceLock<Vec<OnceLock<DeviceStreams>>> = OnceLock::new();
#[cfg(feature = "cuda")]
/// Returns the process-wide slot table, creating `MAX_DEVICES` empty slots
/// on first use.
fn pool_slots() -> &'static Vec<OnceLock<DeviceStreams>> {
    STREAM_POOL.get_or_init(|| {
        let mut slots = Vec::with_capacity(MAX_DEVICES);
        slots.resize_with(MAX_DEVICES, OnceLock::new);
        slots
    })
}
pub struct StreamPool;
#[cfg(feature = "cuda")]
impl StreamPool {
    /// Returns a stream for `device_ordinal`, rotating round-robin over a
    /// lazily created per-device pool of up to `STREAMS_PER_DEVICE` streams.
    ///
    /// # Errors
    /// - `GpuError::InvalidDevice` if `device_ordinal >= MAX_DEVICES`.
    /// - `GpuError::Driver` if no stream at all could be created.
    pub fn get_stream(
        ctx: &Arc<CudaContext>,
        device_ordinal: usize,
    ) -> GpuResult<Arc<CudaStream>> {
        if device_ordinal >= MAX_DEVICES {
            return Err(GpuError::InvalidDevice {
                ordinal: device_ordinal,
                count: MAX_DEVICES,
            });
        }
        let slots = pool_slots();
        // `ctx` is consulted only by the first initializer for this ordinal;
        // NOTE(review): callers are assumed to pass a context that matches
        // `device_ordinal`, otherwise the pool would be populated on the
        // wrong device — confirm at the call sites.
        let device_streams = slots[device_ordinal].get_or_init(|| {
            let mut streams = Vec::with_capacity(STREAMS_PER_DEVICE);
            for _ in 0..STREAMS_PER_DEVICE {
                match ctx.new_stream() {
                    Ok(s) => streams.push(s),
                    // Stop at the first failure and keep whatever succeeded.
                    Err(_) => break,
                }
            }
            // Last resort: fork the default stream so the pool is non-empty
            // unless even that fails.
            if streams.is_empty() {
                if let Ok(s) = ctx.default_stream().fork() {
                    streams.push(s);
                }
            }
            DeviceStreams {
                streams,
                counter: AtomicUsize::new(0),
            }
        });
        if device_streams.streams.is_empty() {
            // NOTE(review): the real creation error was discarded above, so
            // OUT_OF_MEMORY here is a stand-in, not necessarily the actual
            // failure cause.
            return Err(GpuError::Driver(cudarc::driver::DriverError(
                cudarc::driver::sys::cudaError_enum::CUDA_ERROR_OUT_OF_MEMORY,
            )));
        }
        // Relaxed suffices: the counter only spreads load across streams and
        // needs no ordering with other memory operations.
        let idx = device_streams.counter.fetch_add(1, Ordering::Relaxed)
            % device_streams.streams.len();
        Ok(Arc::clone(&device_streams.streams[idx]))
    }
    /// Number of streams currently pooled for `device_ordinal`; 0 if the
    /// pool has not been initialized or the ordinal is out of range.
    pub fn pool_size(device_ordinal: usize) -> usize {
        if device_ordinal >= MAX_DEVICES {
            return 0;
        }
        let slots = pool_slots();
        slots[device_ordinal]
            .get()
            .map(|ds| ds.streams.len())
            .unwrap_or(0)
    }
    /// Like [`StreamPool::get_stream`], but draws from a pool keyed by
    /// `(device_ordinal, priority)` whose streams are created with
    /// `new_stream_with_priority`.
    ///
    /// # Errors
    /// - `GpuError::InvalidDevice` if `device_ordinal >= MAX_DEVICES`.
    /// - `GpuError::Driver` if no priority stream could be created.
    pub fn get_priority_stream(
        ctx: &Arc<CudaContext>,
        device_ordinal: usize,
        priority: StreamPriority,
    ) -> GpuResult<Arc<CudaStream>> {
        if device_ordinal >= MAX_DEVICES {
            return Err(GpuError::InvalidDevice {
                ordinal: device_ordinal,
                count: MAX_DEVICES,
            });
        }
        let slots = priority_pool_slots();
        let key = (device_ordinal, priority);
        // Snapshot the current pool, recovering from a poisoned mutex (the
        // guarded data is just Vec<Arc<_>>, safe to reuse after a panic).
        // The clone releases the lock before any stream creation below.
        let priority_streams = slots
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .entry(key)
            .or_default()
            .clone();
        if priority_streams.is_empty() {
            // Create streams without holding the lock.
            let mut new_streams = Vec::with_capacity(STREAMS_PER_PRIORITY);
            for _ in 0..STREAMS_PER_PRIORITY {
                match new_stream_with_priority(ctx, priority) {
                    Ok(s) => new_streams.push(s),
                    Err(_) => break,
                }
            }
            if new_streams.is_empty() {
                // NOTE(review): as in `get_stream`, the original driver error
                // is lost; OUT_OF_MEMORY is a placeholder.
                return Err(GpuError::Driver(cudarc::driver::DriverError(
                    cudarc::driver::sys::cudaError_enum::CUDA_ERROR_OUT_OF_MEMORY,
                )));
            }
            // Double-checked insert: another thread may have filled the pool
            // while we were creating streams; if so, ours are dropped and the
            // winner's streams are used.
            let mut guard = slots.lock().unwrap_or_else(|p| p.into_inner());
            let entry = guard.entry(key).or_default();
            if entry.is_empty() {
                *entry = new_streams.clone();
            }
            let snapshot = entry.clone();
            drop(guard);
            let idx = priority_pool_counter(key)
                .fetch_add(1, Ordering::Relaxed)
                % snapshot.len();
            return Ok(Arc::clone(&snapshot[idx]));
        }
        let idx = priority_pool_counter(key).fetch_add(1, Ordering::Relaxed)
            % priority_streams.len();
        Ok(Arc::clone(&priority_streams[idx]))
    }
    /// Number of streams pooled for `(device_ordinal, priority)`; 0 if that
    /// pool does not exist or the ordinal is out of range.
    pub fn priority_pool_size(device_ordinal: usize, priority: StreamPriority) -> usize {
        if device_ordinal >= MAX_DEVICES {
            return 0;
        }
        let slots = priority_pool_slots();
        slots
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .get(&(device_ordinal, priority))
            .map(|v| v.len())
            .unwrap_or(0)
    }
}
#[cfg(feature = "cuda")]
/// Map from `(device ordinal, priority)` to that pool's streams.
type PriorityPoolMap = std::sync::Mutex<HashMap<(usize, StreamPriority), Vec<Arc<CudaStream>>>>;
#[cfg(feature = "cuda")]
/// Map from `(device ordinal, priority)` to that pool's round-robin counter.
type PriorityCounterMap = std::sync::Mutex<HashMap<(usize, StreamPriority), Arc<AtomicUsize>>>;
#[cfg(feature = "cuda")]
/// Process-wide storage for the priority stream pools.
static PRIORITY_POOL: OnceLock<PriorityPoolMap> = OnceLock::new();
#[cfg(feature = "cuda")]
/// Returns the global priority-pool map, creating an empty one on first use.
fn priority_pool_slots() -> &'static PriorityPoolMap {
    PRIORITY_POOL.get_or_init(Default::default)
}
#[cfg(feature = "cuda")]
/// Round-robin counters for the priority pools, keyed like `PRIORITY_POOL`.
static PRIORITY_POOL_COUNTERS: OnceLock<PriorityCounterMap> = OnceLock::new();
#[cfg(feature = "cuda")]
/// Returns the shared round-robin counter for `key`, creating a fresh
/// zeroed counter on first use. A poisoned mutex is recovered rather than
/// propagated, since the data (plain counters) stays valid after a panic.
fn priority_pool_counter(key: (usize, StreamPriority)) -> Arc<AtomicUsize> {
    let counters = PRIORITY_POOL_COUNTERS.get_or_init(Default::default);
    let mut table = match counters.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    table
        .entry(key)
        .or_insert_with(|| Arc::new(AtomicUsize::new(0)))
        .clone()
}
#[cfg(feature = "cuda")]
thread_local! {
    /// Per-thread map of device ordinal -> explicitly selected current stream.
    /// Consulted by `current_stream_or_default` and managed by `StreamGuard`.
    static CURRENT_STREAMS: RefCell<HashMap<usize, Arc<CudaStream>>> =
        RefCell::new(HashMap::new());
}
#[cfg(feature = "cuda")]
/// Returns this thread's currently selected stream for `device`, if any.
pub fn get_current_stream(device: usize) -> Option<Arc<CudaStream>> {
    CURRENT_STREAMS.with(|slot| slot.borrow().get(&device).map(Arc::clone))
}
#[cfg(feature = "cuda")]
/// Installs `stream` as this thread's current stream for `device`,
/// replacing any previous selection.
pub fn set_current_stream(device: usize, stream: Arc<CudaStream>) {
    CURRENT_STREAMS.with(|slot| {
        let _previous = slot.borrow_mut().insert(device, stream);
    });
}
#[cfg(feature = "cuda")]
/// Removes this thread's current-stream selection for `device`, if any.
pub fn clear_current_stream(device: usize) {
    CURRENT_STREAMS.with(|slot| {
        let _removed = slot.borrow_mut().remove(&device);
    });
}
#[cfg(feature = "cuda")]
/// Returns this thread's selected stream for the device, falling back to
/// the device's default stream when none has been set.
pub fn current_stream_or_default(device: &crate::device::GpuDevice) -> Arc<CudaStream> {
    match get_current_stream(device.ordinal()) {
        Some(stream) => stream,
        None => Arc::clone(device.default_stream()),
    }
}
#[cfg(feature = "cuda")]
/// RAII guard that installs a thread-local current stream for a device and
/// restores the previous selection (or clears it) when dropped.
pub struct StreamGuard {
    device: usize,
    // Stream that was current when the guard was created, if any.
    previous: Option<Arc<CudaStream>>,
}
#[cfg(feature = "cuda")]
impl StreamGuard {
    /// Saves the thread's current selection for `device`, then installs
    /// `stream` as the new current stream until the guard is dropped.
    pub fn new(device: usize, stream: Arc<CudaStream>) -> Self {
        let saved = get_current_stream(device);
        set_current_stream(device, stream);
        StreamGuard {
            device,
            previous: saved,
        }
    }
}
#[cfg(feature = "cuda")]
impl Drop for StreamGuard {
    /// Restores the selection captured at construction: reinstate the saved
    /// stream, or clear the slot if there was none.
    fn drop(&mut self) {
        if let Some(prev) = self.previous.take() {
            set_current_stream(self.device, prev);
        } else {
            clear_current_stream(self.device);
        }
    }
}
#[cfg(feature = "cuda")]
impl std::fmt::Debug for StreamGuard {
    /// Streams themselves are opaque; report the device and whether a
    /// previous selection will be restored on drop.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("StreamGuard");
        dbg.field("device", &self.device);
        dbg.field("has_previous", &self.previous.is_some());
        dbg.finish()
    }
}
#[cfg(not(feature = "cuda"))]
/// Stub used when the `cuda` feature is disabled; carries no event.
#[derive(Debug)]
pub struct CudaEventWrapper;
#[cfg(not(feature = "cuda"))]
impl StreamPool {
pub fn get_stream(_device_ordinal: usize) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
pub fn pool_size(_device_ordinal: usize) -> usize {
0
}
}
#[cfg(not(feature = "cuda"))]
/// Stub guard for builds without the `cuda` feature; holds no state and
/// restores nothing on drop.
#[derive(Debug)]
pub struct StreamGuard;
#[cfg(not(feature = "cuda"))]
/// Stub: there is never a current stream without the `cuda` feature.
pub fn get_current_stream(_device: usize) -> Option<()> {
    None
}
#[cfg(not(feature = "cuda"))]
/// Stub: setting a current stream is a no-op without the `cuda` feature.
pub fn set_current_stream(_device: usize, _stream: ()) {
}
#[cfg(not(feature = "cuda"))]
/// Stub: clearing a current stream is a no-op without the `cuda` feature.
pub fn clear_current_stream(_device: usize) {
}
// GPU-dependent tests: every test silently passes when no CUDA device is
// available (test_ctx() returns None), so they only exercise the code on
// machines with a working driver and device 0.
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;
    use cudarc::driver::CudaContext;
    /// Best-effort context on device 0; `None` when no GPU is present.
    fn test_ctx() -> Option<Arc<CudaContext>> {
        CudaContext::new(0).ok()
    }
    // Record + synchronize round trip; event must report complete afterwards.
    #[test]
    fn event_record_sync() {
        let Some(ctx) = test_ctx() else { return };
        let stream = ctx.default_stream();
        let event = CudaEventWrapper::new(&ctx)
            .expect("event creation should succeed");
        event.record(&stream).expect("record should succeed");
        event.synchronize().expect("synchronize should succeed");
        assert!(
            event.query().expect("query should succeed"),
            "event should be complete after synchronize"
        );
    }
    // CUDA semantics: an event that was never recorded counts as complete.
    #[test]
    fn event_query_before_record() {
        let Some(ctx) = test_ctx() else { return };
        let event = CudaEventWrapper::new(&ctx)
            .expect("event creation should succeed");
        let complete = event.query().expect("query should not error");
        assert!(complete, "unrecorded event should report complete");
    }
    // Drains the default pool once around and checks the counter wraps back
    // to the first stream. NOTE(review): relies on no other test advancing
    // this pool's counter concurrently — currently only this test uses it.
    #[test]
    fn stream_pool_round_robin() {
        let Some(ctx) = test_ctx() else { return };
        let dev = 0;
        let s1 = StreamPool::get_stream(&ctx, dev)
            .expect("first get_stream should succeed");
        let s2 = StreamPool::get_stream(&ctx, dev)
            .expect("second get_stream should succeed");
        let pool_size = StreamPool::pool_size(dev);
        assert!(pool_size > 0, "pool should have streams");
        assert!(pool_size <= STREAMS_PER_DEVICE, "pool should not exceed configured size");
        let mut streams = vec![s1, s2];
        for _ in 2..pool_size {
            streams.push(
                StreamPool::get_stream(&ctx, dev).expect("get_stream should succeed"),
            );
        }
        let wrap = StreamPool::get_stream(&ctx, dev)
            .expect("wrapped get_stream should succeed");
        assert_eq!(
            Arc::as_ptr(&wrap),
            Arc::as_ptr(&streams[0]),
            "round-robin should wrap back to the first stream"
        );
    }
    // Ordinal bounds check fires before any pool work happens.
    #[test]
    fn stream_pool_invalid_device() {
        let Some(ctx) = test_ctx() else { return };
        let result = StreamPool::get_stream(&ctx, MAX_DEVICES + 1);
        assert!(result.is_err(), "should reject ordinal >= MAX_DEVICES");
    }
    // Guard must restore the stream that was current before it was created.
    #[test]
    fn stream_guard_restores_previous() {
        let Some(ctx) = test_ctx() else { return };
        let dev = 0;
        assert!(
            get_current_stream(dev).is_none(),
            "should start with no current stream"
        );
        let s1 = ctx.new_stream().expect("new_stream should succeed");
        let s2 = ctx.new_stream().expect("new_stream should succeed");
        let s1_ptr = Arc::as_ptr(&s1);
        let s2_ptr = Arc::as_ptr(&s2);
        set_current_stream(dev, Arc::clone(&s1));
        assert_eq!(
            Arc::as_ptr(&get_current_stream(dev).unwrap()),
            s1_ptr,
            "current stream should be s1"
        );
        {
            let _guard = StreamGuard::new(dev, Arc::clone(&s2));
            assert_eq!(
                Arc::as_ptr(&get_current_stream(dev).unwrap()),
                s2_ptr,
                "current stream should be s2 inside guard"
            );
        }
        assert_eq!(
            Arc::as_ptr(&get_current_stream(dev).unwrap()),
            s1_ptr,
            "current stream should be restored to s1 after guard drop"
        );
        clear_current_stream(dev);
        assert!(
            get_current_stream(dev).is_none(),
            "should be cleared after explicit clear"
        );
    }
    // With no prior selection, dropping the guard clears the slot entirely.
    #[test]
    fn stream_guard_clears_when_no_previous() {
        let Some(ctx) = test_ctx() else { return };
        let dev = 0;
        clear_current_stream(dev);
        assert!(get_current_stream(dev).is_none());
        let s1 = ctx.new_stream().expect("new_stream should succeed");
        {
            let _guard = StreamGuard::new(dev, Arc::clone(&s1));
            assert!(
                get_current_stream(dev).is_some(),
                "guard should set current stream"
            );
        }
        assert!(
            get_current_stream(dev).is_none(),
            "guard with no previous should clear current stream on drop"
        );
    }
    // Fallback path: no thread-local selection -> device default stream;
    // after set_current_stream, the thread-local one wins.
    #[test]
    fn current_stream_or_default_fallback() {
        let Some(ctx) = test_ctx() else { return };
        let dev_ordinal = 0;
        clear_current_stream(dev_ordinal);
        let device = crate::device::GpuDevice::new(dev_ordinal)
            .expect("GpuDevice::new should succeed");
        let default_ptr = Arc::as_ptr(device.default_stream());
        let stream = current_stream_or_default(&device);
        assert_eq!(
            Arc::as_ptr(&stream),
            default_ptr,
            "should fall back to device default stream"
        );
        let custom = ctx.new_stream().expect("new_stream should succeed");
        let custom_ptr = Arc::as_ptr(&custom);
        set_current_stream(dev_ordinal, custom);
        let stream = current_stream_or_default(&device);
        assert_eq!(
            Arc::as_ptr(&stream),
            custom_ptr,
            "should use thread-local current stream"
        );
        clear_current_stream(dev_ordinal);
    }
    // Cross-stream dependency: stream2 waits on an event recorded on stream1.
    #[test]
    fn event_wait_on_stream() {
        let Some(ctx) = test_ctx() else { return };
        let stream1 = ctx.default_stream();
        let stream2 = ctx.new_stream().expect("new_stream should succeed");
        let event = CudaEventWrapper::new(&ctx)
            .expect("event creation should succeed");
        event.record(&stream1).expect("record should succeed");
        event.wait_on(&stream2).expect("wait_on should succeed");
        stream2.synchronize().expect("synchronize should succeed");
    }
    // CUDA convention: greatest (highest) priority is numerically <= least.
    #[test]
    fn priority_range_returns_sane_values() {
        let Some(ctx) = test_ctx() else { return };
        let (least, greatest) = get_stream_priority_range(&ctx)
            .expect("priority range should query successfully");
        assert!(
            greatest <= least,
            "priority range invariant violated: greatest={greatest} > least={least}"
        );
    }
    // Pure mapping test (no GPU needed): range is (least=5, greatest=-5).
    #[test]
    fn stream_priority_resolves_within_range() {
        let range = (5, -5);
        assert_eq!(StreamPriority::High.to_cuda_priority(range), -5);
        assert_eq!(StreamPriority::Low.to_cuda_priority(range), 5);
        let normal = StreamPriority::Normal.to_cuda_priority(range);
        assert!((-5..=5).contains(&normal));
    }
    // Devices without priority support report a collapsed (0, 0) range.
    #[test]
    fn stream_priority_collapsed_range_resolves_to_zero() {
        let range = (0, 0);
        assert_eq!(StreamPriority::High.to_cuda_priority(range), 0);
        assert_eq!(StreamPriority::Normal.to_cuda_priority(range), 0);
        assert_eq!(StreamPriority::Low.to_cuda_priority(range), 0);
    }
    // Each call must create a distinct stream object.
    #[test]
    fn new_stream_with_priority_succeeds_for_all_three_levels() {
        let Some(ctx) = test_ctx() else { return };
        let high = new_stream_with_priority(&ctx, StreamPriority::High)
            .expect("high-priority stream creation should succeed");
        let normal = new_stream_with_priority(&ctx, StreamPriority::Normal)
            .expect("normal-priority stream creation should succeed");
        let low = new_stream_with_priority(&ctx, StreamPriority::Low)
            .expect("low-priority stream creation should succeed");
        assert_ne!(Arc::as_ptr(&high), Arc::as_ptr(&normal));
        assert_ne!(Arc::as_ptr(&normal), Arc::as_ptr(&low));
        assert_ne!(Arc::as_ptr(&high), Arc::as_ptr(&low));
    }
    // Smoke test that the transmuted stream is usable (synchronize works).
    #[test]
    fn new_stream_with_priority_actually_runs_kernels() {
        let Some(ctx) = test_ctx() else { return };
        let stream = new_stream_with_priority(&ctx, StreamPriority::High)
            .expect("high-priority stream creation should succeed");
        stream.synchronize().expect("synchronize should succeed");
    }
    // High and Low pools are keyed independently and each stays within the
    // configured per-priority size.
    #[test]
    fn priority_pool_caches_streams_per_device_and_priority() {
        let Some(ctx) = test_ctx() else { return };
        let dev = 0;
        let _h1 = StreamPool::get_priority_stream(&ctx, dev, StreamPriority::High)
            .expect("get_priority_stream High should succeed");
        let _l1 = StreamPool::get_priority_stream(&ctx, dev, StreamPriority::Low)
            .expect("get_priority_stream Low should succeed");
        let high_size = StreamPool::priority_pool_size(dev, StreamPriority::High);
        let low_size = StreamPool::priority_pool_size(dev, StreamPriority::Low);
        assert!(high_size > 0, "high-priority pool should have streams");
        assert!(low_size > 0, "low-priority pool should have streams");
        assert!(high_size <= STREAMS_PER_PRIORITY);
        assert!(low_size <= STREAMS_PER_PRIORITY);
    }
    // Out-of-range ordinal maps to the specific InvalidDevice variant.
    #[test]
    fn priority_pool_invalid_device() {
        let Some(ctx) = test_ctx() else { return };
        let result = StreamPool::get_priority_stream(&ctx, 9999, StreamPriority::High);
        assert!(matches!(result, Err(GpuError::InvalidDevice { .. })));
    }
}