// dbx_core/engine/stream.rs
//! GPU Stream Management for asynchronous operations
//!
//! Provides CUDA Streams for overlapping data transfer and kernel execution.

// `Arc` is required by the GPU-gated types below; gate the import so
// non-GPU builds do not get an unused-import warning.
#[cfg(feature = "gpu")]
use std::sync::Arc;

#[cfg(feature = "gpu")]
use cudarc::driver::{CudaContext, CudaStream};

use crate::error::{DbxError, DbxResult};

/// Priority level for GPU streams.
///
/// `Normal` is the default priority; `High` marks streams whose work
/// should be favored by the scheduler. Note that stream creation in
/// this module does not currently forward the priority to the driver
/// (see `GpuStreamContext::new`), so this is advisory metadata.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StreamPriority {
    High,
    Normal,
}

impl Default for StreamPriority {
    /// Streams are `Normal` priority unless explicitly requested otherwise.
    fn default() -> Self {
        StreamPriority::Normal
    }
}

/// GPU Stream Context - manages a single CUDA stream
#[cfg(feature = "gpu")]
pub struct GpuStreamContext {
    /// Identifier unique within the owning `StreamManager`.
    pub stream_id: usize,
    /// Priority requested for this stream (advisory; not forwarded to the driver).
    pub priority: StreamPriority,
    /// Handle to the underlying CUDA stream.
    stream: Arc<CudaStream>,
    /// Device context the stream was created from. Not read by the methods
    /// in this file; presumably retained to keep the device alive for the
    /// stream's lifetime — TODO(review): confirm.
    device: Arc<CudaContext>,
}

#[cfg(feature = "gpu")]
impl GpuStreamContext {
    /// Build a stream context on `device` with the given id and priority.
    ///
    /// cudarc 0.19.2 exposes `fork_default_stream` for creating a separate
    /// stream; it does not expose priority-based stream creation, so
    /// `priority` is recorded on the context but not passed to the driver.
    ///
    /// # Errors
    /// Returns `DbxError::Gpu` if the driver fails to create the stream.
    pub fn new(
        stream_id: usize,
        priority: StreamPriority,
        device: Arc<CudaContext>,
    ) -> DbxResult<Self> {
        let stream = match device.fork_default_stream() {
            Ok(s) => s,
            Err(e) => return Err(DbxError::Gpu(format!("Failed to create stream: {:?}", e))),
        };
        Ok(Self { stream_id, priority, stream, device })
    }

    /// Borrow the raw CUDA stream handle.
    pub fn stream(&self) -> &CudaStream {
        self.stream.as_ref()
    }

    /// Block until every operation queued on this stream has completed.
    ///
    /// # Errors
    /// Returns `DbxError::Gpu` if the driver reports a synchronization failure.
    pub fn synchronize(&self) -> DbxResult<()> {
        match self.stream.synchronize() {
            Ok(()) => Ok(()),
            Err(e) => Err(DbxError::Gpu(format!("Stream sync failed: {:?}", e))),
        }
    }
}

/// Stream Manager - manages multiple CUDA streams for async operations
#[cfg(feature = "gpu")]
pub struct StreamManager {
    /// Shared device context used to spawn new streams.
    device: Arc<CudaContext>,
    /// Every stream created so far, in creation order.
    streams: Vec<GpuStreamContext>,
    /// Monotonically increasing id handed out to the next stream.
    next_id: usize,
}

#[cfg(feature = "gpu")]
impl StreamManager {
    /// Create a new, empty stream manager for `device`.
    pub fn new(device: Arc<CudaContext>) -> DbxResult<Self> {
        Ok(Self { device, streams: Vec::new(), next_id: 0 })
    }

    /// Create a stream with `priority` and return its id.
    ///
    /// Ids are assigned sequentially starting at 0 and are never reused.
    ///
    /// # Errors
    /// Propagates `DbxError::Gpu` if the underlying stream cannot be created.
    pub fn create_stream(&mut self, priority: StreamPriority) -> DbxResult<usize> {
        let id = self.next_id;
        self.next_id += 1;

        let ctx = GpuStreamContext::new(id, priority, self.device.clone())?;
        self.streams.push(ctx);
        Ok(id)
    }

    /// Look up a stream by the id returned from `create_stream`.
    pub fn get_stream(&self, stream_id: usize) -> Option<&GpuStreamContext> {
        self.streams.iter().find(|ctx| ctx.stream_id == stream_id)
    }

    /// Synchronize every managed stream, stopping at the first failure.
    ///
    /// # Errors
    /// Returns the first `DbxError::Gpu` reported by a stream, if any.
    pub fn synchronize_all(&self) -> DbxResult<()> {
        self.streams.iter().try_for_each(|s| s.synchronize())
    }

    /// Number of streams currently managed.
    pub fn stream_count(&self) -> usize {
        self.streams.len()
    }
}

// Stub implementations for non-GPU builds
/// Stub stream context compiled when the `gpu` feature is disabled.
/// Carries no state; exists so downstream code can name the type.
#[cfg(not(feature = "gpu"))]
pub struct GpuStreamContext;

/// Stub stream manager compiled when the `gpu` feature is disabled.
/// Its constructor always fails (see `StreamManager::new` below).
#[cfg(not(feature = "gpu"))]
pub struct StreamManager;

123#[cfg(not(feature = "gpu"))]
124impl StreamManager {
125    pub fn new(_device: ()) -> DbxResult<Self> {
126        Err(DbxError::NotImplemented(
127            "GPU acceleration is not enabled".to_string(),
128        ))
129    }
130}