use crate::error::DeviceError;
use cuda_core::{Device, Stream};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
/// Policy that decides which stream a caller is handed next.
///
/// The `Send + Sync` supertraits require every implementor to be safe to
/// share between threads.
pub trait SchedulingPolicy: Send + Sync {
    /// Returns a shared handle to the stream selected by this policy.
    ///
    /// # Errors
    ///
    /// The signature lets implementors report a [`DeviceError`] when no stream
    /// can be provided.
    fn next_stream(&self) -> Result<Arc<Stream>, DeviceError>;
}
/// Scheduling policy backed by a fixed set of streams that are handed out in
/// rotating order.
pub struct StreamPoolRoundRobin {
    /// Streams owned by the pool, indexed by the rotating counter.
    stream_pool: Vec<Arc<Stream>>,
    /// Counter incremented on every request; reduced modulo the pool length to
    /// pick the stream to return.
    next_stream_idx: AtomicUsize,
}
impl StreamPoolRoundRobin {
    /// Builds a pool of streams created on `device`, with the rotation counter
    /// starting at zero.
    ///
    /// `num_streams` is the requested pool size. A request for zero streams is
    /// raised to one: an empty pool would make the round-robin index
    /// computation (`counter % pool_len`) divide by zero and panic the first
    /// time a stream is requested.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `device.new_stream()`; streams
    /// created before the failure are dropped.
    pub fn new(device: &Arc<Device>, num_streams: usize) -> Result<Self, DeviceError> {
        // Never build an empty pool; see the doc comment above.
        let pool_size = num_streams.max(1);
        let mut stream_pool = Vec::with_capacity(pool_size);
        for _ in 0..pool_size {
            stream_pool.push(device.new_stream()?);
        }
        Ok(Self {
            stream_pool,
            next_stream_idx: AtomicUsize::new(0),
        })
    }
}
impl SchedulingPolicy for StreamPoolRoundRobin {
    /// Returns the pool's streams one after another, wrapping back to the
    /// first stream after the last.
    ///
    /// Always returns `Ok`.
    ///
    /// # Panics
    ///
    /// Panics if the pool contains no streams, because the index is computed
    /// as a remainder by the pool length.
    fn next_stream(&self) -> Result<Arc<Stream>, DeviceError> {
        use std::sync::atomic::Ordering;

        // Take a ticket; the counter's value is only used to derive an index.
        let ticket = self.next_stream_idx.fetch_add(1, Ordering::Relaxed);
        let slot = ticket % self.stream_pool.len();
        let chosen = &self.stream_pool[slot];
        Ok(Arc::clone(chosen))
    }
}
/// Scheduling policy that owns exactly one stream and always hands out that
/// same stream.
pub struct SingleStream {
    /// The one stream shared by every request.
    stream: Arc<Stream>,
}
impl SingleStream {
    /// Creates the policy by opening one new stream on `device`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `device.new_stream()`.
    pub fn new(device: &Arc<Device>) -> Result<Self, DeviceError> {
        let stream = device.new_stream()?;
        Ok(Self { stream })
    }

    /// Borrows the underlying stream handle without touching its reference
    /// count.
    pub fn stream(&self) -> &Arc<Stream> {
        &self.stream
    }
}
impl SchedulingPolicy for SingleStream {
    /// Returns a new shared handle to the single owned stream.
    ///
    /// Always returns `Ok`.
    fn next_stream(&self) -> Result<Arc<Stream>, DeviceError> {
        let shared = Arc::clone(&self.stream);
        Ok(shared)
    }
}