#[cfg(feature = "gpu")]
use std::sync::Arc;

#[cfg(feature = "gpu")]
use cudarc::driver::{CudaContext, CudaStream};

use crate::error::{DbxError, DbxResult};
/// Priority label attached to a GPU stream context.
///
/// Within this module the value is only stored on `GpuStreamContext`; nothing
/// here translates it into a driver-level stream priority.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamPriority {
    /// Stream tagged as high priority.
    High,
    /// Stream tagged as normal priority.
    Normal,
}
/// A CUDA stream bundled with the identifier and priority label it was
/// created with, plus the context it was created from.
#[cfg(feature = "gpu")]
pub struct GpuStreamContext {
    /// Identifier supplied by whoever created this context.
    pub stream_id: usize,
    /// Priority label supplied at creation time; stored as-is.
    pub priority: StreamPriority,
    /// Shared handle to the underlying CUDA stream.
    stream: Arc<CudaStream>,
    /// Shared handle to the CUDA context the stream was created from.
    device: Arc<CudaContext>,
}
#[cfg(feature = "gpu")]
impl GpuStreamContext {
    /// Creates a stream context by requesting a new stream from `device`.
    ///
    /// `stream_id` and `priority` are stored unchanged. `priority` is not
    /// handed to the driver here: the stream is created the same way whatever
    /// its value.
    ///
    /// # Errors
    ///
    /// Returns [`DbxError::Gpu`] containing the debug rendering of the driver
    /// error when the stream cannot be created.
    pub fn new(
        stream_id: usize,
        priority: StreamPriority,
        device: Arc<CudaContext>,
    ) -> DbxResult<Self> {
        // NOTE(review): confirm `fork_default_stream` is available on
        // `CudaContext` for the cudarc version this crate pins.
        let stream = match device.fork_default_stream() {
            Ok(forked) => forked,
            Err(cause) => {
                return Err(DbxError::Gpu(format!("Failed to create stream: {:?}", cause)));
            }
        };

        Ok(Self {
            stream_id,
            priority,
            stream,
            device,
        })
    }

    /// Borrows the underlying CUDA stream.
    pub fn stream(&self) -> &CudaStream {
        &*self.stream
    }

    /// Calls `synchronize` on the underlying CUDA stream.
    ///
    /// # Errors
    ///
    /// Returns [`DbxError::Gpu`] containing the debug rendering of the driver
    /// error when synchronization fails.
    pub fn synchronize(&self) -> DbxResult<()> {
        let outcome = self.stream.synchronize();
        outcome.map_err(|cause| DbxError::Gpu(format!("Stream sync failed: {:?}", cause)))
    }
}
/// Owns a set of [`GpuStreamContext`] values created against one CUDA context
/// and hands out their numeric identifiers.
#[cfg(feature = "gpu")]
pub struct StreamManager {
    /// CUDA context every managed stream is created from.
    device: Arc<CudaContext>,
    /// Stream contexts in the order they were created.
    streams: Vec<GpuStreamContext>,
    /// Identifier that will be assigned to the next created stream.
    next_id: usize,
}
#[cfg(feature = "gpu")]
impl StreamManager {
    /// Builds an empty manager bound to `device`.
    ///
    /// No stream is created up front, and the first identifier handed out is
    /// `0`. This constructor has no failing path; it always returns `Ok`.
    pub fn new(device: Arc<CudaContext>) -> DbxResult<Self> {
        let manager = Self {
            device,
            streams: Vec::new(),
            next_id: 0,
        };
        Ok(manager)
    }

    /// Creates a new stream tagged with `priority` and returns its identifier.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`GpuStreamContext::new`] when the stream
    /// cannot be created.
    pub fn create_stream(&mut self, priority: StreamPriority) -> DbxResult<usize> {
        // The identifier is reserved before the stream is created, so a failed
        // creation still consumes it and later ids skip over it.
        let assigned_id = self.next_id;
        self.next_id = assigned_id + 1;

        let created = GpuStreamContext::new(assigned_id, priority, Arc::clone(&self.device))?;
        self.streams.push(created);
        Ok(assigned_id)
    }

    /// Looks up a managed stream by identifier.
    ///
    /// Returns `None` when no managed stream carries `stream_id`.
    pub fn get_stream(&self, stream_id: usize) -> Option<&GpuStreamContext> {
        self.streams
            .iter()
            .find(|candidate| candidate.stream_id == stream_id)
    }

    /// Synchronizes every managed stream in creation order.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first synchronization failure; streams after
    /// the failing one are not synchronized.
    pub fn synchronize_all(&self) -> DbxResult<()> {
        self.streams
            .iter()
            .try_for_each(|managed| managed.synchronize())
    }

    /// Number of streams currently held by the manager.
    pub fn stream_count(&self) -> usize {
        self.streams.len()
    }
}
/// Placeholder type compiled when the `gpu` feature is disabled.
///
/// It has no fields and no methods in this build configuration.
#[cfg(not(feature = "gpu"))]
pub struct GpuStreamContext;
/// Placeholder stream manager compiled when the `gpu` feature is disabled.
///
/// Its only constructor, `new`, always fails, so no value of this type is
/// produced by this module in that configuration.
#[cfg(not(feature = "gpu"))]
pub struct StreamManager;
#[cfg(not(feature = "gpu"))]
impl StreamManager {
    /// Stub constructor used when the `gpu` feature is disabled.
    ///
    /// The unit argument takes the place of the device handle accepted by the
    /// GPU-enabled build.
    ///
    /// # Errors
    ///
    /// Always returns [`DbxError::NotImplemented`].
    pub fn new(_device: ()) -> DbxResult<Self> {
        let message = String::from("GPU acceleration is not enabled");
        Err(DbxError::NotImplemented(message))
    }
}