use crate::error::DeviceError;
use cuda_core::{CudaContext, CudaStream};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
/// Strategy for choosing which CUDA stream the next piece of work is submitted on.
///
/// The `Send + Sync` supertraits allow one policy instance to be shared between threads
/// (for example behind an `Arc`).
pub trait SchedulingPolicy: Sync + Send {
    /// Returns the stream that the caller should use for its next submission.
    ///
    /// # Errors
    ///
    /// Implementations report failure through [`DeviceError`].
    fn next_stream(&self) -> Result<Arc<CudaStream>, DeviceError>;
}
/// Scheduling policy that cycles through a fixed set of streams in round-robin order.
pub struct StreamPoolRoundRobin {
    /// The streams handed out by this policy, created once in `new`.
    stream_pool: Vec<Arc<CudaStream>>,
    /// Counter bumped on every `next_stream` call; reduced modulo the pool size to
    /// select a stream.
    next_stream_idx: AtomicUsize,
}
impl StreamPoolRoundRobin {
    /// Creates a pool of `num_streams` new streams on `ctx`, to be handed out in
    /// round-robin order.
    ///
    /// # Errors
    ///
    /// Returns the first stream-creation failure, converted into a [`DeviceError`].
    /// Streams created before the failure are dropped along with the partially filled
    /// pool.
    ///
    /// # Panics
    ///
    /// Panics if `num_streams` is zero. An empty pool has no stream to hand out, and
    /// `next_stream` would otherwise panic later on a remainder-by-zero, far from the
    /// actual cause.
    pub fn new(ctx: &Arc<CudaContext>, num_streams: usize) -> Result<Self, DeviceError> {
        // Fail fast on the broken precondition instead of deferring a `% 0` panic to the
        // first scheduling call.
        assert!(
            num_streams > 0,
            "StreamPoolRoundRobin requires at least one stream"
        );
        let mut stream_pool = Vec::with_capacity(num_streams);
        for _ in 0..num_streams {
            stream_pool.push(ctx.new_stream()?);
        }
        Ok(Self {
            stream_pool,
            next_stream_idx: AtomicUsize::new(0),
        })
    }
}
impl SchedulingPolicy for StreamPoolRoundRobin {
    /// Hands out the pool's streams in turn, wrapping back to the first after the last.
    ///
    /// This implementation always returns `Ok`.
    fn next_stream(&self) -> Result<Arc<CudaStream>, DeviceError> {
        use std::sync::atomic::Ordering;

        // Only the atomicity of the increment matters here; the counter publishes no
        // other memory, so `Relaxed` is sufficient. `fetch_add` wraps on overflow.
        let ticket = self.next_stream_idx.fetch_add(1, Ordering::Relaxed);
        let slot = ticket % self.stream_pool.len();
        Ok(self.stream_pool[slot].clone())
    }
}
/// Scheduling policy that routes all work onto one stream.
pub struct SingleStream {
    /// The only stream this policy hands out.
    stream: Arc<CudaStream>,
}
impl SingleStream {
pub fn new(ctx: &Arc<CudaContext>) -> Result<Self, DeviceError> {
Ok(Self {
stream: ctx.new_stream()?,
})
}
pub fn stream(&self) -> &Arc<CudaStream> {
&self.stream
}
}
impl SchedulingPolicy for SingleStream {
    /// Returns a new handle to the same stream on every call; never fails.
    fn next_stream(&self) -> Result<Arc<CudaStream>, DeviceError> {
        let shared = self.stream.clone();
        Ok(shared)
    }
}