Skip to main content

oxicuda_runtime/
stream.rs

1//! CUDA stream management.
2//!
3//! Implements the CUDA Runtime stream API:
4//! - `cudaStreamCreate` / `cudaStreamCreateWithFlags` / `cudaStreamCreateWithPriority`
5//! - `cudaStreamDestroy`
6//! - `cudaStreamSynchronize`
7//! - `cudaStreamQuery`
8//! - `cudaStreamWaitEvent`
9//! - `cudaStreamGetPriority`
10//! - `cudaStreamGetFlags`
//! - `cudaStreamGetDevice` (listed for completeness; no wrapper is defined in this module yet)
12//! - The default stream (`cudaStreamDefault` / `cudaStreamLegacy` / `cudaStreamPerThread`)
13
14use oxicuda_driver::ffi::CUstream;
15use oxicuda_driver::loader::try_driver;
16
17use crate::error::{CudaRtError, CudaRtResult};
18
19// ─── StreamFlags ─────────────────────────────────────────────────────────────
20
/// Flags for stream creation.
///
/// Mirrors `cudaStreamFlags`.  The wrapped `u32` is the raw bit pattern; it
/// stays public so existing callers can keep constructing and reading it
/// directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct StreamFlags(pub u32);

impl StreamFlags {
    /// Default stream flag: stream synchronises with the legacy default stream.
    pub const DEFAULT: Self = Self(0x0);
    /// Non-blocking stream: the stream does not implicitly synchronise with the
    /// legacy default stream (mirrors `cudaStreamNonBlocking`).
    pub const NON_BLOCKING: Self = Self(0x1);

    /// Returns the raw flag bits.
    #[must_use]
    pub const fn bits(self) -> u32 {
        self.0
    }

    /// Returns `true` if every bit set in `other` is also set in `self`.
    ///
    /// [`StreamFlags::DEFAULT`] has no bits set, so `contains(DEFAULT)` is
    /// always `true`; test for [`StreamFlags::NON_BLOCKING`] to tell the two
    /// apart.
    #[must_use]
    pub const fn contains(self, other: Self) -> bool {
        (self.0 & other.0) == other.0
    }
}

/// Combines two flag sets by OR-ing their raw bits.
impl std::ops::BitOr for StreamFlags {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self {
        Self(self.0 | rhs.0)
    }
}
34
35// ─── CudaStream ──────────────────────────────────────────────────────────────
36
37/// A CUDA stream handle.
38///
39/// Wraps the raw `CUstream` handle from the driver API.  The stream is
40/// **not** automatically destroyed when dropped — call [`stream_destroy`]
41/// explicitly or use the stream within its creating context lifetime.
42///
43/// Use [`CudaStream::DEFAULT`] to obtain the special legacy-default
44/// stream sentinel.
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
46pub struct CudaStream(CUstream);
47
48impl CudaStream {
49    /// The legacy default CUDA stream (`cudaStreamDefault` = 0).
50    ///
51    /// Operations on the default stream block all other streams in the context.
52    pub const DEFAULT: Self = Self(CUstream(std::ptr::null_mut()));
53
54    /// Per-thread default stream (`cudaStreamPerThread`).
55    ///
56    /// Equivalent to passing `cudaStreamPerThread` in the Runtime API.
57    /// The value `0x2` is the canonical sentinel used by the CUDA Runtime.
58    pub const PER_THREAD: Self = Self(CUstream(2 as *mut std::ffi::c_void));
59
60    /// Construct a `CudaStream` from a raw driver handle.
61    ///
62    /// # Safety
63    ///
64    /// The caller must ensure the handle is valid and not used after the
65    /// associated context is destroyed.
66    #[must_use]
67    pub const unsafe fn from_raw(raw: CUstream) -> Self {
68        Self(raw)
69    }
70
71    /// Returns the underlying raw `CUstream`.
72    #[must_use]
73    pub fn raw(self) -> CUstream {
74        self.0
75    }
76
77    /// Returns `true` if this is the legacy default stream.
78    #[must_use]
79    pub fn is_default(self) -> bool {
80        self.0.is_null()
81    }
82}
83
84// ─── Stream creation / destruction ────────────────────────────────────────────
85
86/// Create a new CUDA stream with default flags.
87///
88/// Mirrors `cudaStreamCreate`.
89///
90/// # Errors
91///
92/// Propagates driver errors.
93pub fn stream_create() -> CudaRtResult<CudaStream> {
94    stream_create_with_flags(StreamFlags::DEFAULT)
95}
96
97/// Create a new CUDA stream with the given flags.
98///
99/// Mirrors `cudaStreamCreateWithFlags`.
100///
101/// # Errors
102///
103/// Propagates driver errors.
104pub fn stream_create_with_flags(flags: StreamFlags) -> CudaRtResult<CudaStream> {
105    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
106    let mut stream = CUstream::default();
107    // SAFETY: FFI; stream is a valid stack-allocated opaque pointer.
108    let rc = unsafe { (api.cu_stream_create)(&raw mut stream, flags.0) };
109    if rc != 0 {
110        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
111    }
112    Ok(CudaStream(stream))
113}
114
115/// Create a new CUDA stream with the given flags and scheduling priority.
116///
117/// Mirrors `cudaStreamCreateWithPriority`.
118///
119/// `priority` is a signed integer where lower values indicate higher priority.
120/// The valid range can be queried with `cudaDeviceGetStreamPriorityRange`.
121///
122/// # Errors
123///
124/// Propagates driver errors.
125pub fn stream_create_with_priority(flags: StreamFlags, priority: i32) -> CudaRtResult<CudaStream> {
126    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
127    let mut stream = CUstream::default();
128    // SAFETY: FFI.
129    let rc = unsafe { (api.cu_stream_create_with_priority)(&raw mut stream, flags.0, priority) };
130    if rc != 0 {
131        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
132    }
133    Ok(CudaStream(stream))
134}
135
136/// Destroy a CUDA stream.
137///
138/// Mirrors `cudaStreamDestroy`.
139///
140/// # Errors
141///
142/// Propagates driver errors.
143pub fn stream_destroy(stream: CudaStream) -> CudaRtResult<()> {
144    if stream.is_default() {
145        return Ok(()); // default stream is never explicitly destroyed
146    }
147    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
148    // SAFETY: FFI; stream handle is valid.
149    let rc = unsafe { (api.cu_stream_destroy_v2)(stream.raw()) };
150    if rc != 0 {
151        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
152    }
153    Ok(())
154}
155
156// ─── Stream synchronisation / query ──────────────────────────────────────────
157
158/// Wait until all preceding operations in `stream` complete.
159///
160/// Mirrors `cudaStreamSynchronize`.
161///
162/// # Errors
163///
164/// Propagates driver errors.
165pub fn stream_synchronize(stream: CudaStream) -> CudaRtResult<()> {
166    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
167    // SAFETY: FFI.
168    let rc = unsafe { (api.cu_stream_synchronize)(stream.raw()) };
169    if rc != 0 {
170        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::Unknown));
171    }
172    Ok(())
173}
174
175/// Check whether all preceding operations in `stream` have completed.
176///
177/// Mirrors `cudaStreamQuery`.
178///
179/// Returns `Ok(true)` if complete, `Ok(false)` if still running.
180///
181/// # Errors
182///
183/// Propagates driver errors (other than `NotReady`).
184pub fn stream_query(stream: CudaStream) -> CudaRtResult<bool> {
185    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
186    // SAFETY: FFI.
187    let rc = unsafe { (api.cu_stream_query)(stream.raw()) };
188    match rc {
189        0 => Ok(true),    // CUDA_SUCCESS — complete
190        600 => Ok(false), // CUDA_ERROR_NOT_READY — still running
191        other => Err(CudaRtError::from_code(other).unwrap_or(CudaRtError::Unknown)),
192    }
193}
194
195/// Make all future work submitted to `stream` wait until `event` is recorded.
196///
197/// Mirrors `cudaStreamWaitEvent`.
198///
199/// # Errors
200///
201/// Propagates driver errors.
202pub fn stream_wait_event(
203    stream: CudaStream,
204    event: crate::event::CudaEvent,
205    flags: u32,
206) -> CudaRtResult<()> {
207    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
208    // SAFETY: FFI.
209    let rc = unsafe { (api.cu_stream_wait_event)(stream.raw(), event.raw(), flags) };
210    if rc != 0 {
211        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
212    }
213    Ok(())
214}
215
216/// Returns the priority of `stream`.
217///
218/// Mirrors `cudaStreamGetPriority`.
219///
220/// # Errors
221///
222/// Propagates driver errors.
223pub fn stream_get_priority(stream: CudaStream) -> CudaRtResult<i32> {
224    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
225    let mut priority: std::ffi::c_int = 0;
226    // SAFETY: FFI.
227    let rc = unsafe { (api.cu_stream_get_priority)(stream.raw(), &raw mut priority) };
228    if rc != 0 {
229        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
230    }
231    Ok(priority)
232}
233
234/// Returns the flags of `stream`.
235///
236/// Mirrors `cudaStreamGetFlags`.
237///
238/// # Errors
239///
240/// Propagates driver errors.
241pub fn stream_get_flags(stream: CudaStream) -> CudaRtResult<StreamFlags> {
242    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
243    let mut flags: u32 = 0;
244    // SAFETY: FFI.
245    let rc = unsafe { (api.cu_stream_get_flags)(stream.raw(), &raw mut flags) };
246    if rc != 0 {
247        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
248    }
249    Ok(StreamFlags(flags))
250}
251
252// ─── Tests ───────────────────────────────────────────────────────────────────
253
#[cfg(test)]
mod tests {
    use super::*;

    /// The null handle is the legacy default stream; the per-thread sentinel is not.
    #[test]
    fn default_stream_is_null() {
        assert!(CudaStream::DEFAULT.is_default());
        assert!(!CudaStream::PER_THREAD.is_default());
    }

    /// Flag constants carry the raw values expected by the driver.
    #[test]
    fn stream_flags_values() {
        assert_eq!(StreamFlags::DEFAULT.0, 0);
        assert_eq!(StreamFlags::NON_BLOCKING.0, 1);
        assert_eq!(StreamFlags::default(), StreamFlags::DEFAULT);
    }

    /// Destroying the default stream is a no-op that never reaches the driver.
    #[test]
    fn stream_destroy_default_is_noop() {
        // `stream_destroy` returns before the driver is even loaded for the
        // null handle, so this must succeed with or without a driver.
        assert!(stream_destroy(CudaStream::DEFAULT).is_ok());
    }

    /// Stream creation must not panic, whatever the environment looks like.
    #[test]
    fn stream_create_without_gpu_returns_error() {
        // Without a usable driver/context this returns `Err` (which is all we
        // require: no panic).  With a GPU the stream is real, so check it is
        // not the null sentinel and release it instead of leaking it.
        if let Ok(stream) = stream_create() {
            assert!(!stream.is_default());
            assert!(stream_destroy(stream).is_ok());
        }
    }
}
285}