use std::sync::Arc;
use cudarc::driver::{CudaContext, CudaStream};
use super::{ActorHints, StreamAllocator};
/// A [`StreamAllocator`] backed by a fixed set of CUDA streams.
///
/// Streams are supplied (or created) once at construction time and are
/// handed out again and again by `acquire`; the pool never grows or shrinks.
pub struct PooledAllocator {
/// The streams handed out by `acquire`. Constructors reject an empty pool.
pool: Vec<Arc<CudaStream>>,
/// Position state used by `acquire` to pick the next stream. Kept behind a
/// mutex so `acquire` can update it through `&self`.
cursor: parking_lot::Mutex<usize>,
}
impl PooledAllocator {
    /// Builds an allocator over an existing set of streams.
    ///
    /// The cursor starts at the first stream in `streams`.
    ///
    /// # Panics
    ///
    /// Panics if `streams` is empty.
    pub fn new(streams: Vec<Arc<CudaStream>>) -> Self {
        assert!(
            !streams.is_empty(),
            "PooledAllocator requires at least one stream"
        );
        let cursor = parking_lot::Mutex::new(0);
        Self {
            pool: streams,
            cursor,
        }
    }

    /// Creates `count` fresh streams on `ctx` and pools them.
    ///
    /// # Panics
    ///
    /// Panics if `count` is zero, or if any call to `ctx.new_stream()`
    /// returns an error.
    pub fn with_size(ctx: &Arc<CudaContext>, count: usize) -> Self {
        assert!(count > 0, "PooledAllocator requires count >= 1");
        // Stream creation failure is treated as fatal here; the first error
        // aborts construction via panic, exactly one stream at a time.
        let created: Vec<_> = (0..count)
            .map(|_| match ctx.new_stream() {
                Ok(stream) => stream,
                Err(e) => panic!("ContextPoisoned: new_stream: {e}"),
            })
            .collect();
        Self::new(created)
    }

    /// Number of streams held by the pool.
    pub fn size(&self) -> usize {
        self.pool.len()
    }
}
impl StreamAllocator for PooledAllocator {
    /// Returns the next stream from the pool in round-robin order.
    ///
    /// `_hints` is ignored: every caller draws from the same rotation. The
    /// returned value is a new `Arc` handle to a pooled stream, so the same
    /// underlying stream is shared by every caller that lands on that slot.
    fn acquire(&self, _hints: ActorHints) -> Arc<CudaStream> {
        let len = self.pool.len();
        // Hold the lock only for the index arithmetic, not for the Arc clone.
        let idx = {
            let mut cursor = self.cursor.lock();
            // Reduce on read as well, so any stored value maps to a valid slot.
            let idx = *cursor % len;
            // Store the cursor already reduced modulo `len`. A free-running
            // counter would break the rotation when `usize` wraps around
            // (for any `len` that is not a power of two, one slot would be
            // served twice in a row). `idx + 1` cannot overflow: `idx < len`.
            *cursor = (idx + 1) % len;
            idx
        };
        Arc::clone(&self.pool[idx])
    }
}