use std::sync::{Arc, Weak};
use cudarc::driver::{CudaContext, CudaStream};
use super::{ActorHints, StreamAllocator};
/// Stream allocator that either hands out one shared stream or creates a new stream on
/// every `acquire`, depending on the constructor used (`new` vs `with_context`).
///
/// Cloning only bumps the reference count of the shared inner state. All clones therefore
/// see the same mode and the same registry of minted streams.
#[derive(Clone)]
pub struct PerActorAllocator {
    inner: Arc<PerActorInner>,
}
/// State shared (through an `Arc`) by every clone of a `PerActorAllocator`.
enum PerActorInner {
    /// `acquire` always returns a clone of this single stream handle.
    Shared { stream: Arc<CudaStream> },
    /// `acquire` creates a brand-new stream from `ctx` on every call.
    Fresh {
        /// Context used to create new streams.
        ctx: Arc<CudaContext>,
        /// Weak handles to streams handed out by `acquire`. `live_streams` uses them to
        /// count the streams still alive, and it drops dead entries when it runs.
        minted: parking_lot::Mutex<Vec<Weak<CudaStream>>>,
    },
}
impl PerActorAllocator {
    /// Creates an allocator in "shared" mode: every `acquire` hands back a clone of
    /// `stream`.
    pub fn new(stream: Arc<CudaStream>) -> Self {
        let inner = PerActorInner::Shared { stream };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Creates an allocator in "fresh" mode: every `acquire` creates a new stream from
    /// `ctx` and records a weak handle to it.
    pub fn with_context(ctx: Arc<CudaContext>) -> Self {
        let minted = parking_lot::Mutex::new(Vec::new());
        let inner = PerActorInner::Fresh { ctx, minted };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Reports how many streams this allocator (and its clones) currently keeps track of.
    ///
    /// In shared mode this is always `1`, because the allocator itself holds the one
    /// stream. In fresh mode it counts the minted streams whose strong count is still
    /// non-zero. As a side effect, it removes the entries for streams that have been
    /// dropped.
    pub fn live_streams(&self) -> usize {
        let registry = match &*self.inner {
            PerActorInner::Shared { .. } => return 1,
            PerActorInner::Fresh { minted, .. } => minted,
        };
        let mut handles = registry.lock();
        handles.retain(|handle| handle.strong_count() != 0);
        handles.len()
    }
}
impl StreamAllocator for PerActorAllocator {
    /// Returns the stream an actor should run on.
    ///
    /// * Shared mode: a clone of the single stream handle.
    /// * Fresh mode: a newly created stream from the stored context. A weak handle to it
    ///   is recorded so `live_streams` can count it while it stays alive.
    ///
    /// `_hints` is currently ignored in both modes.
    ///
    /// # Panics
    ///
    /// In fresh mode, panics with a `ContextPoisoned: new_stream: …` message if stream
    /// creation fails. The trait returns a bare `Arc<CudaStream>`, so the error cannot be
    /// propagated.
    fn acquire(&self, _hints: ActorHints) -> Arc<CudaStream> {
        match self.inner.as_ref() {
            PerActorInner::Shared { stream } => Arc::clone(stream),
            PerActorInner::Fresh { ctx, minted } => {
                // Create the stream before locking so the mutex only guards bookkeeping.
                let s = ctx
                    .new_stream()
                    .unwrap_or_else(|e| panic!("ContextPoisoned: new_stream: {e}"));
                let mut g = minted.lock();
                // Drop handles to streams that have already been released. Without this,
                // the registry grows by one entry per call for the allocator's whole
                // lifetime unless `live_streams` happens to be called.
                g.retain(|w| w.strong_count() > 0);
                g.push(Arc::downgrade(&s));
                s
            }
        }
    }
}