oxicuda_runtime/stream.rs
1//! CUDA stream management.
2//!
3//! Implements the CUDA Runtime stream API:
4//! - `cudaStreamCreate` / `cudaStreamCreateWithFlags` / `cudaStreamCreateWithPriority`
5//! - `cudaStreamDestroy`
6//! - `cudaStreamSynchronize`
7//! - `cudaStreamQuery`
8//! - `cudaStreamWaitEvent`
9//! - `cudaStreamGetPriority`
10//! - `cudaStreamGetFlags`
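//! - `cudaStreamGetDevice` (listed for API parity; no wrapper is defined in this module yet)
//! - The default-stream sentinels (`cudaStreamDefault` / `cudaStreamPerThread`); a dedicated
//!   `cudaStreamLegacy` constant is not provided yet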
13
14use oxicuda_driver::ffi::CUstream;
15use oxicuda_driver::loader::try_driver;
16
17use crate::error::{CudaRtError, CudaRtResult};
18
19// ─── StreamFlags ─────────────────────────────────────────────────────────────
20
/// Bit flags accepted when a stream is created.
///
/// Counterpart of the runtime's `cudaStreamFlags`. The wrapped `u32` is the
/// raw flag value; `Default` yields the same value as [`StreamFlags::DEFAULT`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct StreamFlags(pub u32);

impl StreamFlags {
    /// No flags set: the stream synchronises with the legacy default stream.
    pub const DEFAULT: StreamFlags = StreamFlags(0);
    /// The stream skips implicit synchronisation with the legacy default
    /// stream (mirrors `cudaStreamNonBlocking`).
    pub const NON_BLOCKING: StreamFlags = StreamFlags(1);
}
34
35// ─── CudaStream ──────────────────────────────────────────────────────────────
36
37/// A CUDA stream handle.
38///
39/// Wraps the raw `CUstream` handle from the driver API. The stream is
40/// **not** automatically destroyed when dropped — call [`stream_destroy`]
41/// explicitly or use the stream within its creating context lifetime.
42///
43/// Use [`CudaStream::DEFAULT`] to obtain the special legacy-default
44/// stream sentinel.
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
46pub struct CudaStream(CUstream);
47
48impl CudaStream {
49 /// The legacy default CUDA stream (`cudaStreamDefault` = 0).
50 ///
51 /// Operations on the default stream block all other streams in the context.
52 pub const DEFAULT: Self = Self(CUstream(std::ptr::null_mut()));
53
54 /// Per-thread default stream (`cudaStreamPerThread`).
55 ///
56 /// Equivalent to passing `cudaStreamPerThread` in the Runtime API.
57 /// The value `0x2` is the canonical sentinel used by the CUDA Runtime.
58 pub const PER_THREAD: Self = Self(CUstream(2 as *mut std::ffi::c_void));
59
60 /// Construct a `CudaStream` from a raw driver handle.
61 ///
62 /// # Safety
63 ///
64 /// The caller must ensure the handle is valid and not used after the
65 /// associated context is destroyed.
66 #[must_use]
67 pub const unsafe fn from_raw(raw: CUstream) -> Self {
68 Self(raw)
69 }
70
71 /// Returns the underlying raw `CUstream`.
72 #[must_use]
73 pub fn raw(self) -> CUstream {
74 self.0
75 }
76
77 /// Returns `true` if this is the legacy default stream.
78 #[must_use]
79 pub fn is_default(self) -> bool {
80 self.0.is_null()
81 }
82}
83
84// ─── Stream creation / destruction ────────────────────────────────────────────
85
86/// Create a new CUDA stream with default flags.
87///
88/// Mirrors `cudaStreamCreate`.
89///
90/// # Errors
91///
92/// Propagates driver errors.
93pub fn stream_create() -> CudaRtResult<CudaStream> {
94 stream_create_with_flags(StreamFlags::DEFAULT)
95}
96
97/// Create a new CUDA stream with the given flags.
98///
99/// Mirrors `cudaStreamCreateWithFlags`.
100///
101/// # Errors
102///
103/// Propagates driver errors.
104pub fn stream_create_with_flags(flags: StreamFlags) -> CudaRtResult<CudaStream> {
105 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
106 let mut stream = CUstream::default();
107 // SAFETY: FFI; stream is a valid stack-allocated opaque pointer.
108 let rc = unsafe { (api.cu_stream_create)(&raw mut stream, flags.0) };
109 if rc != 0 {
110 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
111 }
112 Ok(CudaStream(stream))
113}
114
115/// Create a new CUDA stream with the given flags and scheduling priority.
116///
117/// Mirrors `cudaStreamCreateWithPriority`.
118///
119/// `priority` is a signed integer where lower values indicate higher priority.
120/// The valid range can be queried with `cudaDeviceGetStreamPriorityRange`.
121///
122/// # Errors
123///
124/// Propagates driver errors.
125pub fn stream_create_with_priority(flags: StreamFlags, priority: i32) -> CudaRtResult<CudaStream> {
126 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
127 let mut stream = CUstream::default();
128 // SAFETY: FFI.
129 let rc = unsafe { (api.cu_stream_create_with_priority)(&raw mut stream, flags.0, priority) };
130 if rc != 0 {
131 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
132 }
133 Ok(CudaStream(stream))
134}
135
136/// Destroy a CUDA stream.
137///
138/// Mirrors `cudaStreamDestroy`.
139///
140/// # Errors
141///
142/// Propagates driver errors.
143pub fn stream_destroy(stream: CudaStream) -> CudaRtResult<()> {
144 if stream.is_default() {
145 return Ok(()); // default stream is never explicitly destroyed
146 }
147 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
148 // SAFETY: FFI; stream handle is valid.
149 let rc = unsafe { (api.cu_stream_destroy_v2)(stream.raw()) };
150 if rc != 0 {
151 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
152 }
153 Ok(())
154}
155
156// ─── Stream synchronisation / query ──────────────────────────────────────────
157
158/// Wait until all preceding operations in `stream` complete.
159///
160/// Mirrors `cudaStreamSynchronize`.
161///
162/// # Errors
163///
164/// Propagates driver errors.
165pub fn stream_synchronize(stream: CudaStream) -> CudaRtResult<()> {
166 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
167 // SAFETY: FFI.
168 let rc = unsafe { (api.cu_stream_synchronize)(stream.raw()) };
169 if rc != 0 {
170 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::Unknown));
171 }
172 Ok(())
173}
174
175/// Check whether all preceding operations in `stream` have completed.
176///
177/// Mirrors `cudaStreamQuery`.
178///
179/// Returns `Ok(true)` if complete, `Ok(false)` if still running.
180///
181/// # Errors
182///
183/// Propagates driver errors (other than `NotReady`).
184pub fn stream_query(stream: CudaStream) -> CudaRtResult<bool> {
185 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
186 // SAFETY: FFI.
187 let rc = unsafe { (api.cu_stream_query)(stream.raw()) };
188 match rc {
189 0 => Ok(true), // CUDA_SUCCESS — complete
190 600 => Ok(false), // CUDA_ERROR_NOT_READY — still running
191 other => Err(CudaRtError::from_code(other).unwrap_or(CudaRtError::Unknown)),
192 }
193}
194
195/// Make all future work submitted to `stream` wait until `event` is recorded.
196///
197/// Mirrors `cudaStreamWaitEvent`.
198///
199/// # Errors
200///
201/// Propagates driver errors.
202pub fn stream_wait_event(
203 stream: CudaStream,
204 event: crate::event::CudaEvent,
205 flags: u32,
206) -> CudaRtResult<()> {
207 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
208 // SAFETY: FFI.
209 let rc = unsafe { (api.cu_stream_wait_event)(stream.raw(), event.raw(), flags) };
210 if rc != 0 {
211 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
212 }
213 Ok(())
214}
215
216/// Returns the priority of `stream`.
217///
218/// Mirrors `cudaStreamGetPriority`.
219///
220/// # Errors
221///
222/// Propagates driver errors.
223pub fn stream_get_priority(stream: CudaStream) -> CudaRtResult<i32> {
224 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
225 let mut priority: std::ffi::c_int = 0;
226 // SAFETY: FFI.
227 let rc = unsafe { (api.cu_stream_get_priority)(stream.raw(), &raw mut priority) };
228 if rc != 0 {
229 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
230 }
231 Ok(priority)
232}
233
234/// Returns the flags of `stream`.
235///
236/// Mirrors `cudaStreamGetFlags`.
237///
238/// # Errors
239///
240/// Propagates driver errors.
241pub fn stream_get_flags(stream: CudaStream) -> CudaRtResult<StreamFlags> {
242 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
243 let mut flags: u32 = 0;
244 // SAFETY: FFI.
245 let rc = unsafe { (api.cu_stream_get_flags)(stream.raw(), &raw mut flags) };
246 if rc != 0 {
247 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
248 }
249 Ok(StreamFlags(flags))
250}
251
252// ─── Tests ───────────────────────────────────────────────────────────────────
253
#[cfg(test)]
mod tests {
    use super::*;

    /// The null handle is the default stream; the per-thread sentinel is not.
    #[test]
    fn default_stream_is_null() {
        assert!(CudaStream::DEFAULT.is_default());
        assert!(!CudaStream::PER_THREAD.is_default());
    }

    /// The flag constants carry the raw values 0 and 1.
    #[test]
    fn stream_flags_values() {
        assert_eq!(StreamFlags::DEFAULT.0, 0);
        assert_eq!(StreamFlags::NON_BLOCKING.0, 1);
    }

    /// Destroying the default stream succeeds without a driver.
    #[test]
    fn stream_destroy_default_is_noop() {
        // `stream_destroy` returns before looking up the driver when given the
        // default stream, so this must succeed even on machines without CUDA.
        assert!(stream_destroy(CudaStream::DEFAULT).is_ok());
    }

    /// Creating a stream either fails cleanly or yields a destroyable handle.
    #[test]
    fn stream_create_without_gpu_returns_error() {
        // Without a usable driver or context the call returns an error, which
        // is acceptable here: the requirement is only that it does not panic.
        // When creation does succeed, check the handle and release it so the
        // test does not leak a stream.
        if let Ok(stream) = stream_create() {
            assert!(!stream.is_default());
            assert!(stream_destroy(stream).is_ok());
        }
    }
}
285}