Skip to main content

baracuda_driver/
stream.rs

1//! CUDA streams — ordered queues of work on a device.
2
3use std::sync::Arc;
4
5use baracuda_cuda_sys::types::CUstream_flags;
6use baracuda_cuda_sys::{driver, CUstream};
7
8use crate::context::Context;
9use crate::error::{check, Result};
10
11/// An asynchronous work queue on a CUDA device.
12///
13/// Work submitted to the same stream executes in order; work on different
14/// streams may run concurrently, subject to device scheduling. Streams are
15/// `Send + Sync` — CUDA explicitly permits concurrent submission from
16/// multiple host threads.
17#[derive(Clone)]
18pub struct Stream {
19    inner: Arc<StreamInner>,
20}
21
22struct StreamInner {
23    handle: CUstream,
24    // Hold the owning context so it outlives the stream.
25    context: Context,
26}
27
28// SAFETY: NVIDIA documents that a CUstream may be used from any thread.
29unsafe impl Send for StreamInner {}
30unsafe impl Sync for StreamInner {}
31
32impl core::fmt::Debug for StreamInner {
33    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
34        f.debug_struct("Stream")
35            .field("handle", &self.handle)
36            .finish_non_exhaustive()
37    }
38}
39
40impl core::fmt::Debug for Stream {
41    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
42        self.inner.fmt(f)
43    }
44}
45
46impl Stream {
47    /// Create a new stream on `context` with default flags (blocking wrt the
48    /// legacy default stream).
49    pub fn new(context: &Context) -> Result<Self> {
50        Self::with_flags(context, CUstream_flags::DEFAULT)
51    }
52
53    /// Create a non-blocking stream — work on this stream does not
54    /// synchronize with the legacy null stream.
55    pub fn non_blocking(context: &Context) -> Result<Self> {
56        Self::with_flags(context, CUstream_flags::NON_BLOCKING)
57    }
58
59    /// Create a stream with a raw flag bitmask (see [`CUstream_flags`]).
60    pub fn with_flags(context: &Context, flags: u32) -> Result<Self> {
61        context.set_current()?;
62        let d = driver()?;
63        let cu = d.cu_stream_create()?;
64        let mut stream: CUstream = core::ptr::null_mut();
65        // SAFETY: writable pointer; flags are from a known module.
66        check(unsafe { cu(&mut stream, flags) })?;
67        Ok(Self {
68            inner: Arc::new(StreamInner {
69                handle: stream,
70                context: context.clone(),
71            }),
72        })
73    }
74
75    /// Create a stream with a specific priority. Use
76    /// [`Context::stream_priority_range`] to discover the legal range on
77    /// this device (lower = higher priority).
78    pub fn with_priority(context: &Context, flags: u32, priority: i32) -> Result<Self> {
79        context.set_current()?;
80        let d = driver()?;
81        let cu = d.cu_stream_create_with_priority()?;
82        let mut stream: CUstream = core::ptr::null_mut();
83        check(unsafe { cu(&mut stream, flags, priority) })?;
84        Ok(Self {
85            inner: Arc::new(StreamInner {
86                handle: stream,
87                context: context.clone(),
88            }),
89        })
90    }
91
92    /// This stream's scheduling priority.
93    pub fn priority(&self) -> Result<i32> {
94        let d = driver()?;
95        let cu = d.cu_stream_get_priority()?;
96        let mut p: core::ffi::c_int = 0;
97        check(unsafe { cu(self.inner.handle, &mut p) })?;
98        Ok(p)
99    }
100
101    /// This stream's flags bitmask.
102    pub fn flags(&self) -> Result<u32> {
103        let d = driver()?;
104        let cu = d.cu_stream_get_flags()?;
105        let mut f: core::ffi::c_uint = 0;
106        check(unsafe { cu(self.inner.handle, &mut f) })?;
107        Ok(f)
108    }
109
110    /// Enqueue a host-side callback on this stream. The callback runs on
111    /// a driver-owned thread after all prior stream work completes.
112    ///
113    /// The closure is boxed and freed after it runs; a panic inside will
114    /// abort the process (there's no way to propagate it through the C
115    /// callback). Keep the closure simple.
116    pub fn launch_host_func<F>(&self, f: F) -> Result<()>
117    where
118        F: FnOnce() + Send + 'static,
119    {
120        use core::ffi::c_void;
121
122        // Box up the closure and hand the raw pointer to CUDA.
123        let boxed: Box<Box<dyn FnOnce() + Send>> = Box::new(Box::new(f));
124        let raw = Box::into_raw(boxed) as *mut c_void;
125
126        unsafe extern "C" fn trampoline(user_data: *mut c_void) {
127            // SAFETY: user_data was `Box::into_raw`'d just above.
128            let f: Box<Box<dyn FnOnce() + Send>> =
129                unsafe { Box::from_raw(user_data as *mut Box<dyn FnOnce() + Send>) };
130            (*f)();
131        }
132
133        let d = driver()?;
134        let cu = d.cu_launch_host_func()?;
135        // SAFETY: trampoline owns and frees the boxed closure; stream handle is live.
136        let rc = unsafe { cu(self.inner.handle, Some(trampoline), raw) };
137        if rc != baracuda_cuda_sys::CUresult::SUCCESS {
138            // Reclaim the box so we don't leak on submission failure.
139            // SAFETY: cuLaunchHostFunc didn't take ownership on error.
140            drop(unsafe { Box::from_raw(raw as *mut Box<dyn FnOnce() + Send>) });
141            return Err(crate::error::Error::Status { status: rc });
142        }
143        Ok(())
144    }
145
146    /// Block the calling thread until all work previously enqueued on this
147    /// stream has completed.
148    pub fn synchronize(&self) -> Result<()> {
149        let d = driver()?;
150        let cu = d.cu_stream_synchronize()?;
151        check(unsafe { cu(self.inner.handle) })
152    }
153
154    /// `Ok(true)` if the stream has completed all queued work, `Ok(false)`
155    /// if work is still outstanding.
156    pub fn is_complete(&self) -> Result<bool> {
157        use baracuda_cuda_sys::CUresult;
158        let d = driver()?;
159        let cu = d.cu_stream_query()?;
160        let res = unsafe { cu(self.inner.handle) };
161        match res {
162            CUresult::SUCCESS => Ok(true),
163            CUresult::ERROR_NOT_READY => Ok(false),
164            other => Err(crate::error::Error::Status { status: other }),
165        }
166    }
167
168    /// The [`Context`] this stream lives in.
169    #[inline]
170    pub fn context(&self) -> &Context {
171        &self.inner.context
172    }
173
174    /// Raw `CUstream` handle. Use with care.
175    #[inline]
176    pub fn as_raw(&self) -> CUstream {
177        self.inner.handle
178    }
179
180    /// Device-to-device async copy scheduled on this stream.
181    ///
182    /// Sugar over [`DeviceBuffer::copy_to_device_async`] with the borrows
183    /// flipped the way the call-site usually wants them: the destination
184    /// buffer is taken by `&mut`, so the borrow checker will catch
185    /// aliasing bugs at compile time. `src.len()` must equal `dst.len()`.
186    ///
187    /// ```no_run
188    /// # use baracuda_driver::{Context, Device, DeviceBuffer, Stream};
189    /// # use baracuda_types::DeviceRepr;
190    /// # fn demo() -> baracuda_driver::Result<()> {
191    /// let ctx = Context::new(&Device::get(0)?)?;
192    /// let stream = Stream::new(&ctx)?;
193    /// let src: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, 1024)?;
194    /// let mut dst: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, 1024)?;
195    /// stream.memcpy_dtod(&src, &mut dst)?;
196    /// # Ok(()) }
197    /// ```
198    pub fn memcpy_dtod<T: baracuda_types::DeviceRepr>(
199        &self,
200        src: &crate::memory::DeviceBuffer<T>,
201        dst: &mut crate::memory::DeviceBuffer<T>,
202    ) -> Result<()> {
203        src.copy_to_device_async(dst, self)
204    }
205
206    /// Return the driver-assigned 64-bit ID for this stream. Useful for
207    /// correlating CUPTI traces against baracuda streams.
208    pub fn id(&self) -> Result<u64> {
209        let d = driver()?;
210        let cu = d.cu_stream_get_id()?;
211        let mut out: u64 = 0;
212        check(unsafe { cu(self.inner.handle, &mut out) })?;
213        Ok(out)
214    }
215
216    /// Copy all CUDA-managed attributes (access policy window, sync
217    /// policy) from `src` onto `self`. Does not copy priority or flags
218    /// (those are set at stream creation time).
219    pub fn copy_attributes_from(&self, src: &Stream) -> Result<()> {
220        let d = driver()?;
221        let cu = d.cu_stream_copy_attributes()?;
222        check(unsafe { cu(self.inner.handle, src.inner.handle) })
223    }
224
225    /// Make this stream wait for `event` to complete before processing
226    /// any subsequent work. `flags` is typically `0`
227    /// (`CU_EVENT_WAIT_DEFAULT`). Use this for cross-stream
228    /// dependencies — record an event on stream A, then have stream B
229    /// wait on it.
230    pub fn wait_event(&self, event: &crate::Event, flags: u32) -> Result<()> {
231        let d = driver()?;
232        let cu = d.cu_stream_wait_event()?;
233        check(unsafe { cu(self.inner.handle, event.as_raw(), flags) })
234    }
235
236    /// Read a `CUstreamAttrValue` for `attr` from this stream. The
237    /// caller passes a writable buffer big enough for the largest
238    /// attribute value (`CUstreamAttrValue` is up to 48 bytes).
239    /// Use the `CU_STREAM_ATTRIBUTE_*` constants for `attr`.
240    ///
241    /// # Safety
242    ///
243    /// `value_out` must be a writable region matching the layout of the
244    /// `CUstreamAttrValue` variant for `attr`.
245    pub unsafe fn get_attribute(
246        &self,
247        attr: i32,
248        value_out: *mut core::ffi::c_void,
249    ) -> Result<()> { unsafe {
250        let d = driver()?;
251        let cu = d.cu_stream_get_attribute()?;
252        check(cu(self.inner.handle, attr, value_out))
253    }}
254
255    /// Set a `CUstreamAttrValue` on this stream. See [`Self::get_attribute`]
256    /// for the value layout.
257    ///
258    /// # Safety
259    ///
260    /// `value` must point at a properly-initialized `CUstreamAttrValue`
261    /// variant for `attr`.
262    pub unsafe fn set_attribute(
263        &self,
264        attr: i32,
265        value: *const core::ffi::c_void,
266    ) -> Result<()> { unsafe {
267        let d = driver()?;
268        let cu = d.cu_stream_set_attribute()?;
269        check(cu(self.inner.handle, attr, value))
270    }}
271
272    /// Associate a managed-memory region with this stream. Pass
273    /// `flags = 0` for the default ("one thread").
274    pub fn attach_mem_async(
275        &self,
276        dptr: baracuda_cuda_sys::CUdeviceptr,
277        length: usize,
278        flags: u32,
279    ) -> Result<()> {
280        let d = driver()?;
281        let cu = d.cu_stream_attach_mem_async()?;
282        check(unsafe { cu(self.inner.handle, dptr, length, flags) })
283    }
284
285    /// Enqueue a 32-bit write of `value` to device memory `addr` on this
286    /// stream, ordered like any other stream op.
287    ///
288    /// `flags` is a bitmask of
289    /// [`baracuda_cuda_sys::types::CUstreamWriteValue_flags`].
290    pub fn write_value_32(
291        &self,
292        addr: baracuda_cuda_sys::CUdeviceptr,
293        value: u32,
294        flags: u32,
295    ) -> Result<()> {
296        let d = driver()?;
297        let cu = d.cu_stream_write_value_32()?;
298        check(unsafe { cu(self.inner.handle, addr, value, flags) })
299    }
300
301    pub fn write_value_64(
302        &self,
303        addr: baracuda_cuda_sys::CUdeviceptr,
304        value: u64,
305        flags: u32,
306    ) -> Result<()> {
307        let d = driver()?;
308        let cu = d.cu_stream_write_value_64()?;
309        check(unsafe { cu(self.inner.handle, addr, value, flags) })
310    }
311
312    /// Block the stream until the device memory at `addr` satisfies the
313    /// condition specified by `flags` (see
314    /// [`baracuda_cuda_sys::types::CUstreamWaitValue_flags`] —
315    /// GEQ / EQ / AND / NOR, optionally OR'd with FLUSH).
316    pub fn wait_value_32(
317        &self,
318        addr: baracuda_cuda_sys::CUdeviceptr,
319        value: u32,
320        flags: u32,
321    ) -> Result<()> {
322        let d = driver()?;
323        let cu = d.cu_stream_wait_value_32()?;
324        check(unsafe { cu(self.inner.handle, addr, value, flags) })
325    }
326
327    pub fn wait_value_64(
328        &self,
329        addr: baracuda_cuda_sys::CUdeviceptr,
330        value: u64,
331        flags: u32,
332    ) -> Result<()> {
333        let d = driver()?;
334        let cu = d.cu_stream_wait_value_64()?;
335        check(unsafe { cu(self.inner.handle, addr, value, flags) })
336    }
337
338    /// Submit a batch of wait/write value ops atomically on this stream.
339    /// `ops` is typically a small array built via
340    /// [`baracuda_cuda_sys::types::CUstreamBatchMemOpParams::wait_value_32`]
341    /// etc.
342    pub fn batch_mem_op(
343        &self,
344        ops: &mut [baracuda_cuda_sys::types::CUstreamBatchMemOpParams],
345        flags: u32,
346    ) -> Result<()> {
347        let d = driver()?;
348        let cu = d.cu_stream_batch_mem_op()?;
349        check(unsafe {
350            cu(
351                self.inner.handle,
352                ops.len() as core::ffi::c_uint,
353                ops.as_mut_ptr(),
354                flags,
355            )
356        })
357    }
358
359    /// Query stream-capture state. Returns `(active, capture_id, graph_handle)`
360    /// where `active` is `true` if the stream is currently capturing. The
361    /// graph handle is only meaningful while capturing.
362    pub fn capture_info(&self) -> Result<(bool, u64, baracuda_cuda_sys::CUgraph)> {
363        let d = driver()?;
364        let cu = d.cu_stream_get_capture_info()?;
365        let mut status: core::ffi::c_int = 0;
366        let mut id: u64 = 0;
367        let mut graph: baracuda_cuda_sys::CUgraph = core::ptr::null_mut();
368        let mut deps_ptr: *const baracuda_cuda_sys::CUgraphNode = core::ptr::null();
369        let mut num_deps: usize = 0;
370        check(unsafe {
371            cu(
372                self.inner.handle,
373                &mut status,
374                &mut id,
375                &mut graph,
376                &mut deps_ptr,
377                &mut num_deps,
378            )
379        })?;
380        // CUstreamCaptureStatus: NONE=0, ACTIVE=1, INVALIDATED=2.
381        Ok((status == 1, id, graph))
382    }
383}
384
385impl Drop for StreamInner {
386    fn drop(&mut self) {
387        if let Ok(d) = driver() {
388            if let Ok(cu) = d.cu_stream_destroy() {
389                // SAFETY: last Arc drop; handle is unique.
390                let _ = unsafe { cu(self.handle) };
391            }
392        }
393    }
394}