use std::sync::Arc;
use std::sync::Mutex;
use cudarc::driver::CudaStream;
use super::resource::StreamId;
use crate::CudaDevice;
/// Environment variable that overrides the per-stream pool size, in MiB.
pub const ENV_WCOJ_POOL_MB_PER_STREAM: &str = "XLOG_WCOJ_POOL_MB_PER_STREAM";
/// Per-stream pool size in MiB used when the override is unset or invalid.
pub const DEFAULT_POOL_MB_PER_STREAM: u64 = 256;
/// Stream cap applied by `StreamPool::with_defaults`.
pub const DEFAULT_MAX_STREAMS: usize = 16;
/// Per-stream pool size in MiB.
///
/// Reads `ENV_WCOJ_POOL_MB_PER_STREAM`, ignoring surrounding whitespace.
/// Falls back to `DEFAULT_POOL_MB_PER_STREAM` when the variable is unset,
/// is not valid Unicode, does not parse as a `u64`, or is zero.
pub fn configured_pool_mb_per_stream() -> u64 {
    match std::env::var(ENV_WCOJ_POOL_MB_PER_STREAM) {
        Ok(text) => match text.trim().parse::<u64>() {
            Ok(mib) if mib > 0 => mib,
            _ => DEFAULT_POOL_MB_PER_STREAM,
        },
        Err(_) => DEFAULT_POOL_MB_PER_STREAM,
    }
}
/// Per-stream pool size in bytes.
///
/// This is `configured_pool_mb_per_stream()` converted from MiB. The product
/// saturates at `u64::MAX` rather than overflowing.
pub fn configured_pool_bytes_per_stream() -> u64 {
    const BYTES_PER_MIB: u64 = 1 << 20;
    let mib = configured_pool_mb_per_stream();
    mib.saturating_mul(BYTES_PER_MIB)
}
/// Total pool budget in bytes for `arms * streams` streams.
///
/// Each stream is sized by `configured_pool_bytes_per_stream()`. Both
/// multiplications saturate at `u64::MAX`: the stream count is computed
/// first, then scaled by the per-stream size.
pub fn planned_pool_budget_bytes(arms: u64, streams: u64) -> u64 {
    let stream_count = arms.saturating_mul(streams);
    let bytes_each = configured_pool_bytes_per_stream();
    stream_count.saturating_mul(bytes_each)
}
/// Errors returned by `StreamPool::acquire`.
#[derive(Debug)]
pub enum StreamPoolError {
    /// The pool already holds its maximum number of streams.
    Capacity {
        /// The stream limit that was hit.
        max: usize,
    },
    /// Forking a new stream failed; carries the driver error's message.
    ForkFailed(String),
}

impl std::fmt::Display for StreamPoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            StreamPoolError::Capacity { max } => {
                write!(f, "stream pool at capacity (max={})", max)
            }
            StreamPoolError::ForkFailed(ref msg) => write!(f, "stream fork failed: {}", msg),
        }
    }
}

impl std::error::Error for StreamPoolError {}
/// A bounded set of CUDA streams forked from a device's default stream.
///
/// Streams are created lazily by [`StreamPool::acquire`], and no method of this
/// type removes them, so the pool only grows until `max_streams` is reached.
/// Ids returned by `acquire` are 1-based positions in the pool.
/// [`StreamId::DEFAULT`] is special-cased by [`StreamPool::resolve`] to the
/// device's own stream.
pub struct StreamPool {
    device: Arc<CudaDevice>,
    /// Maximum number of forked (non-default) streams; always at least 1.
    max_streams: usize,
    /// Per-stream pool size in bytes. It is read once via
    /// `configured_pool_bytes_per_stream()` when the pool is constructed.
    pool_bytes_per_stream: u64,
    /// Forked streams in acquisition order; id `n` lives at index `n - 1`.
    streams: Mutex<Vec<Arc<CudaStream>>>,
}

impl StreamPool {
    /// Creates an empty pool on `device` that forks at most `max_streams`
    /// streams. A value of 0 is treated as 1.
    ///
    /// The per-stream pool size is captured from the environment here.
    /// Later changes to the variable do not affect this pool.
    pub fn new(device: Arc<CudaDevice>, max_streams: usize) -> Self {
        Self {
            device,
            max_streams: max_streams.max(1),
            pool_bytes_per_stream: configured_pool_bytes_per_stream(),
            streams: Mutex::new(Vec::new()),
        }
    }

    /// Creates a pool capped at `DEFAULT_MAX_STREAMS` streams.
    pub fn with_defaults(device: Arc<CudaDevice>) -> Self {
        Self::new(device, DEFAULT_MAX_STREAMS)
    }

    /// Forks a new stream from the device's default stream and returns its id.
    ///
    /// # Errors
    ///
    /// * [`StreamPoolError::Capacity`] if `max_streams` streams already exist,
    ///   or if the next id would not fit in `StreamId`'s `u32` payload.
    /// * [`StreamPoolError::ForkFailed`] if forking the stream fails.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn acquire(&self) -> Result<StreamId, StreamPoolError> {
        // The lock is held across the fork so that the capacity check and the
        // push cannot interleave with another caller's `acquire`.
        let mut streams = self.streams.lock().expect("stream pool poisoned");
        if streams.len() >= self.max_streams {
            return Err(StreamPoolError::Capacity {
                max: self.max_streams,
            });
        }
        // Ids are 1-based. The id is derived with a checked conversion, before
        // forking: a plain `as u32` would silently wrap and alias an earlier
        // stream, and checking first avoids forking a stream we cannot name.
        let next_id: u32 = std::convert::TryFrom::try_from(streams.len() + 1).map_err(|_| {
            StreamPoolError::Capacity {
                max: self.max_streams,
            }
        })?;
        let handle = self
            .device
            .inner()
            .stream()
            .fork()
            .map_err(|e| StreamPoolError::ForkFailed(e.to_string()))?;
        streams.push(handle);
        Ok(StreamId(next_id))
    }

    /// Looks up the stream for `id`.
    ///
    /// [`StreamId::DEFAULT`] always resolves to the device's default stream.
    /// Any other id resolves only if `acquire` on this pool returned it.
    /// Unknown ids yield `None`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned while looking up a
    /// non-default id.
    pub fn resolve(&self, id: StreamId) -> Option<Arc<CudaStream>> {
        if id == StreamId::DEFAULT {
            return Some(Arc::clone(self.device.inner().stream()));
        }
        // Ids are 1-based, so 0 has no slot. Widening u32 -> usize is lossless
        // on any target with at least 32-bit pointers.
        let slot = (id.0 as usize).checked_sub(1)?;
        let streams = self.streams.lock().expect("stream pool poisoned");
        streams.get(slot).map(Arc::clone)
    }

    /// Number of streams forked so far (excludes the device default stream).
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn non_default_len(&self) -> usize {
        self.streams.lock().expect("stream pool poisoned").len()
    }

    /// The device this pool forks streams from.
    pub fn device(&self) -> &Arc<CudaDevice> {
        &self.device
    }

    /// Maximum number of non-default streams this pool will fork.
    pub fn max_streams(&self) -> usize {
        self.max_streams
    }

    /// Per-stream pool size in bytes, captured when the pool was constructed.
    pub fn pool_bytes_per_stream(&self) -> u64 {
        self.pool_bytes_per_stream
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the tests in this module that mutate the process environment.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Takes `ENV_LOCK`, recovering from poisoning. Without this, one failing
    /// env test would turn every later env test into a spurious "poisoned"
    /// failure.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Snapshot of one environment variable, restored when dropped. This keeps
    /// a failing assertion from leaking a test's override into other tests.
    struct EnvVarRestore {
        key: &'static str,
        previous: Option<String>,
    }

    impl EnvVarRestore {
        fn capture(key: &'static str) -> Self {
            Self {
                key,
                previous: std::env::var(key).ok(),
            }
        }
    }

    impl Drop for EnvVarRestore {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    /// Returns device 0 if it can be opened. Device-dependent tests return
    /// early (and therefore pass) when it cannot.
    fn try_device() -> Option<Arc<CudaDevice>> {
        CudaDevice::new(0).ok().map(Arc::new)
    }

    #[test]
    fn acquire_returns_distinct_non_default_ids() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::new(device, 4);
        let a = pool.acquire().expect("first acquire");
        let b = pool.acquire().expect("second acquire");
        assert_ne!(a, StreamId::DEFAULT);
        assert_ne!(b, StreamId::DEFAULT);
        assert_ne!(a, b, "consecutive acquire calls must yield distinct ids");
        assert_eq!(pool.non_default_len(), 2);
    }

    #[test]
    fn acquire_returns_capacity_error_at_max() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::new(device, 1);
        let _first = pool.acquire().expect("first acquire under cap");
        let err = pool.acquire();
        assert!(
            matches!(err, Err(StreamPoolError::Capacity { max: 1 })),
            "expected Capacity error once max_streams hit, got {:?}",
            err
        );
    }

    #[test]
    fn resolve_default_returns_device_default_stream() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::with_defaults(device);
        assert!(pool.resolve(StreamId::DEFAULT).is_some());
    }

    #[test]
    fn resolve_acquired_returns_owned_stream() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::new(device, 4);
        let id = pool.acquire().expect("acquire");
        assert_ne!(id, StreamId::DEFAULT);
        assert!(pool.resolve(id).is_some());
    }

    #[test]
    fn resolve_unknown_returns_none() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::with_defaults(device);
        assert!(pool.resolve(StreamId(99)).is_none());
    }

    #[test]
    fn pool_mb_per_stream_env_overrides_default() {
        let _lock = lock_env();
        // Declared after the lock so it drops first, restoring the variable
        // while the lock is still held.
        let _restore = EnvVarRestore::capture(ENV_WCOJ_POOL_MB_PER_STREAM);
        std::env::set_var(ENV_WCOJ_POOL_MB_PER_STREAM, "128");
        assert_eq!(configured_pool_mb_per_stream(), 128);
    }

    #[test]
    fn pool_mb_per_stream_invalid_values_fall_back_to_default() {
        let _lock = lock_env();
        let _restore = EnvVarRestore::capture(ENV_WCOJ_POOL_MB_PER_STREAM);
        for raw in ["0", "abc", "-1", ""] {
            std::env::set_var(ENV_WCOJ_POOL_MB_PER_STREAM, raw);
            assert_eq!(
                configured_pool_mb_per_stream(),
                DEFAULT_POOL_MB_PER_STREAM,
                "value {:?} should fall back to the default",
                raw
            );
        }
        std::env::set_var(ENV_WCOJ_POOL_MB_PER_STREAM, " 64 ");
        assert_eq!(configured_pool_mb_per_stream(), 64);
    }

    #[test]
    fn planned_pool_budget_uses_default_4_by_4_contract() {
        let _lock = lock_env();
        let _restore = EnvVarRestore::capture(ENV_WCOJ_POOL_MB_PER_STREAM);
        std::env::remove_var(ENV_WCOJ_POOL_MB_PER_STREAM);
        assert_eq!(configured_pool_mb_per_stream(), DEFAULT_POOL_MB_PER_STREAM);
        assert_eq!(
            planned_pool_budget_bytes(4, 4),
            4_u64 * 4 * 256 * 1024 * 1024
        );
    }
}