// ghostflow_cuda/stream.rs

//! CUDA streams, events, and timers for asynchronous execution (real implementation).
2
#[cfg(feature = "cuda")]
use crate::error::CudaError;
use crate::error::CudaResult;
use crate::ffi;
5
6/// CUDA stream for asynchronous operations
7#[derive(Debug)]
8pub struct CudaStream {
9    handle: ffi::cudaStream_t,
10    is_default: bool,
11}
12
13impl CudaStream {
14    /// Create new CUDA stream
15    pub fn new() -> CudaResult<Self> {
16        #[cfg(feature = "cuda")]
17        {
18            let mut handle: ffi::cudaStream_t = std::ptr::null_mut();
19            
20            unsafe {
21                let err = ffi::cudaStreamCreate(&mut handle);
22                if err != 0 {
23                    return Err(CudaError::DriverError(err));
24                }
25            }
26            
27            Ok(CudaStream {
28                handle,
29                is_default: false,
30            })
31        }
32        
33        #[cfg(not(feature = "cuda"))]
34        {
35            Ok(CudaStream {
36                handle: std::ptr::null_mut(),
37                is_default: false,
38            })
39        }
40    }
41
42    /// Get default stream (stream 0)
43    pub fn default_stream() -> Self {
44        CudaStream {
45            handle: std::ptr::null_mut(),
46            is_default: true,
47        }
48    }
49
50    /// Get raw handle
51    pub fn handle(&self) -> ffi::cudaStream_t {
52        self.handle
53    }
54
55    /// Synchronize stream - wait for all operations to complete
56    pub fn synchronize(&self) -> CudaResult<()> {
57        #[cfg(feature = "cuda")]
58        unsafe {
59            let err = ffi::cudaStreamSynchronize(self.handle);
60            if err != 0 {
61                return Err(CudaError::SyncError);
62            }
63        }
64        Ok(())
65    }
66
67    /// Check if stream is complete (non-blocking)
68    pub fn is_complete(&self) -> CudaResult<bool> {
69        #[cfg(feature = "cuda")]
70        unsafe {
71            let err = ffi::cudaStreamQuery(self.handle);
72            if err == 0 {
73                return Ok(true);
74            } else if err == 600 { // cudaErrorNotReady
75                return Ok(false);
76            } else {
77                return Err(CudaError::DriverError(err));
78            }
79        }
80        
81        #[cfg(not(feature = "cuda"))]
82        Ok(true)
83    }
84
85    /// Wait for event on this stream
86    pub fn wait_event(&self, _event: &CudaEvent) -> CudaResult<()> {
87        #[cfg(feature = "cuda")]
88        unsafe {
89            // cudaStreamWaitEvent would be called here
90            // For now, just synchronize
91            self.synchronize()?;
92        }
93        Ok(())
94    }
95}
96
97impl Default for CudaStream {
98    fn default() -> Self {
99        Self::default_stream()
100    }
101}
102
103impl Drop for CudaStream {
104    fn drop(&mut self) {
105        if !self.is_default && !self.handle.is_null() {
106            #[cfg(feature = "cuda")]
107            unsafe {
108                let _ = ffi::cudaStreamDestroy(self.handle);
109            }
110        }
111    }
112}
113
114unsafe impl Send for CudaStream {}
115unsafe impl Sync for CudaStream {}
116
117/// CUDA event for synchronization and timing
118#[derive(Debug)]
119pub struct CudaEvent {
120    handle: ffi::cudaEvent_t,
121}
122
123impl CudaEvent {
124    /// Create new event
125    pub fn new() -> CudaResult<Self> {
126        #[cfg(feature = "cuda")]
127        {
128            let mut handle: ffi::cudaEvent_t = std::ptr::null_mut();
129            
130            unsafe {
131                let err = ffi::cudaEventCreate(&mut handle);
132                if err != 0 {
133                    return Err(CudaError::DriverError(err));
134                }
135            }
136            
137            Ok(CudaEvent { handle })
138        }
139        
140        #[cfg(not(feature = "cuda"))]
141        {
142            Ok(CudaEvent {
143                handle: std::ptr::null_mut(),
144            })
145        }
146    }
147
148    /// Record event on stream
149    pub fn record(&self, _stream: &CudaStream) -> CudaResult<()> {
150        #[cfg(feature = "cuda")]
151        unsafe {
152            let err = ffi::cudaEventRecord(self.handle, _stream.handle());
153            if err != 0 {
154                return Err(CudaError::DriverError(err));
155            }
156        }
157        Ok(())
158    }
159
160    /// Synchronize on event - wait until event is recorded
161    pub fn synchronize(&self) -> CudaResult<()> {
162        #[cfg(feature = "cuda")]
163        unsafe {
164            let err = ffi::cudaEventSynchronize(self.handle);
165            if err != 0 {
166                return Err(CudaError::SyncError);
167            }
168        }
169        Ok(())
170    }
171
172    /// Get elapsed time between two events (in milliseconds)
173    pub fn elapsed_time(_start: &CudaEvent, _end: &CudaEvent) -> CudaResult<f32> {
174        #[cfg(feature = "cuda")]
175        {
176            let mut ms: f32 = 0.0;
177            
178            unsafe {
179                let err = ffi::cudaEventElapsedTime(&mut ms, _start.handle, _end.handle);
180                if err != 0 {
181                    return Err(CudaError::DriverError(err));
182                }
183            }
184            
185            Ok(ms)
186        }
187        
188        #[cfg(not(feature = "cuda"))]
189        Ok(0.0)
190    }
191}
192
193impl Default for CudaEvent {
194    fn default() -> Self {
195        Self::new().unwrap_or(CudaEvent {
196            handle: std::ptr::null_mut(),
197        })
198    }
199}
200
201impl Drop for CudaEvent {
202    fn drop(&mut self) {
203        if !self.handle.is_null() {
204            #[cfg(feature = "cuda")]
205            unsafe {
206                let _ = ffi::cudaEventDestroy(self.handle);
207            }
208        }
209    }
210}
211
212unsafe impl Send for CudaEvent {}
213unsafe impl Sync for CudaEvent {}
214
215/// Timer utility using CUDA events
216pub struct CudaTimer {
217    start: CudaEvent,
218    stop: CudaEvent,
219    stream: CudaStream,
220}
221
222impl CudaTimer {
223    pub fn new(stream: CudaStream) -> CudaResult<Self> {
224        Ok(CudaTimer {
225            start: CudaEvent::new()?,
226            stop: CudaEvent::new()?,
227            stream,
228        })
229    }
230
231    pub fn start(&self) -> CudaResult<()> {
232        self.start.record(&self.stream)
233    }
234
235    pub fn stop(&self) -> CudaResult<()> {
236        self.stop.record(&self.stream)
237    }
238
239    pub fn elapsed_ms(&self) -> CudaResult<f32> {
240        self.stop.synchronize()?;
241        CudaEvent::elapsed_time(&self.start, &self.stop)
242    }
243}