Skip to main content

baracuda_runtime/
stream.rs

//! Runtime-API streams: the reference-counted [`Stream`] wrapper around `cudaStream_t`.
2
3use std::sync::Arc;
4
5use baracuda_cuda_sys::runtime::{cudaStream_t, runtime, types::cudaStreamFlags};
6
7use crate::device::Device;
8use crate::error::{check, Result};
9
10/// An asynchronous work queue on the current CUDA device.
11#[derive(Clone)]
12pub struct Stream {
13    inner: Arc<StreamInner>,
14}
15
16struct StreamInner {
17    handle: cudaStream_t,
18    device: Device,
19}
20
21unsafe impl Send for StreamInner {}
22unsafe impl Sync for StreamInner {}
23
24impl core::fmt::Debug for StreamInner {
25    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
26        f.debug_struct("Stream")
27            .field("handle", &self.handle)
28            .field("device", &self.device)
29            .finish()
30    }
31}
32
33impl core::fmt::Debug for Stream {
34    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
35        self.inner.fmt(f)
36    }
37}
38
39impl Stream {
40    /// Create a stream with default (legacy-default-stream-synchronizing) flags
41    /// on the current device.
42    pub fn new() -> Result<Self> {
43        Self::with_flags(cudaStreamFlags::DEFAULT)
44    }
45
46    /// Create a non-blocking stream — does not synchronize with the legacy
47    /// default stream.
48    pub fn non_blocking() -> Result<Self> {
49        Self::with_flags(cudaStreamFlags::NON_BLOCKING)
50    }
51
52    /// Adopt a raw `cudaStream_t` handle. The wrapper will call
53    /// `cudaStreamDestroy` on drop.
54    ///
55    /// # Safety
56    ///
57    /// `handle` must be a live stream on the current device. Do not
58    /// destroy it externally.
59    pub unsafe fn from_raw(handle: cudaStream_t) -> Self {
60        let device = Device::current().unwrap_or(Device::from_ordinal(0));
61        Self {
62            inner: Arc::new(StreamInner { handle, device }),
63        }
64    }
65
66    /// Create a stream with raw flags (see [`cudaStreamFlags`]).
67    pub fn with_flags(flags: u32) -> Result<Self> {
68        let r = runtime()?;
69        let cu = r.cuda_stream_create_with_flags()?;
70        let mut stream: cudaStream_t = core::ptr::null_mut();
71        check(unsafe { cu(&mut stream, flags) })?;
72        let device = Device::current()?;
73        Ok(Self {
74            inner: Arc::new(StreamInner {
75                handle: stream,
76                device,
77            }),
78        })
79    }
80
81    /// Block the calling thread until all prior work on this stream is complete.
82    pub fn synchronize(&self) -> Result<()> {
83        let r = runtime()?;
84        let cu = r.cuda_stream_synchronize()?;
85        check(unsafe { cu(self.inner.handle) })
86    }
87
88    /// `Ok(true)` if all queued work has finished, `Ok(false)` if work remains.
89    pub fn is_complete(&self) -> Result<bool> {
90        use baracuda_cuda_sys::runtime::cudaError_t;
91        let r = runtime()?;
92        let cu = r.cuda_stream_query()?;
93        match unsafe { cu(self.inner.handle) } {
94            cudaError_t::Success => Ok(true),
95            cudaError_t::NotReady => Ok(false),
96            other => Err(crate::error::Error::Status { status: other }),
97        }
98    }
99
100    /// The device this stream belongs to.
101    #[inline]
102    pub fn device(&self) -> Device {
103        self.inner.device
104    }
105
106    /// Raw `cudaStream_t` handle. Use with care.
107    #[inline]
108    pub fn as_raw(&self) -> cudaStream_t {
109        self.inner.handle
110    }
111
112    /// Create a stream with a specific scheduling priority (lower = higher
113    /// priority). Use [`stream_priority_range`] to discover the legal
114    /// range on the current device.
115    pub fn with_priority(flags: u32, priority: i32) -> Result<Self> {
116        let r = runtime()?;
117        let cu = r.cuda_stream_create_with_priority()?;
118        let mut stream: cudaStream_t = core::ptr::null_mut();
119        check(unsafe { cu(&mut stream, flags, priority) })?;
120        let device = Device::current()?;
121        Ok(Self {
122            inner: Arc::new(StreamInner {
123                handle: stream,
124                device,
125            }),
126        })
127    }
128
129    /// This stream's scheduling priority.
130    pub fn priority(&self) -> Result<i32> {
131        let r = runtime()?;
132        let cu = r.cuda_stream_get_priority()?;
133        let mut p: core::ffi::c_int = 0;
134        check(unsafe { cu(self.inner.handle, &mut p) })?;
135        Ok(p)
136    }
137
138    /// This stream's flags bitmask.
139    pub fn flags(&self) -> Result<u32> {
140        let r = runtime()?;
141        let cu = r.cuda_stream_get_flags()?;
142        let mut f: core::ffi::c_uint = 0;
143        check(unsafe { cu(self.inner.handle, &mut f) })?;
144        Ok(f)
145    }
146
147    /// Wait for `event` on this stream — blocks future work on `self`
148    /// until the event has completed.
149    pub fn wait_event(&self, event: &crate::Event, flags: u32) -> Result<()> {
150        let r = runtime()?;
151        let cu = r.cuda_stream_wait_event()?;
152        check(unsafe { cu(self.inner.handle, event.as_raw(), flags) })
153    }
154}
155
156/// Return `(least_priority, greatest_priority)` supported on the current
157/// device. Lower numbers are higher priority.
158pub fn stream_priority_range() -> Result<(i32, i32)> {
159    let r = runtime()?;
160    let cu = r.cuda_device_get_stream_priority_range()?;
161    let mut low: core::ffi::c_int = 0;
162    let mut high: core::ffi::c_int = 0;
163    check(unsafe { cu(&mut low, &mut high) })?;
164    Ok((low, high))
165}
166
167impl Stream {
168    /// Enqueue a host-side callback on this stream. Runs on a
169    /// driver-owned thread after prior stream work completes.
170    ///
171    /// The closure is boxed and freed after it runs; a panic inside
172    /// aborts the process.
173    pub fn launch_host_func<F>(&self, f: F) -> Result<()>
174    where
175        F: FnOnce() + Send + 'static,
176    {
177        use core::ffi::c_void;
178
179        let boxed: Box<Box<dyn FnOnce() + Send>> = Box::new(Box::new(f));
180        let raw = Box::into_raw(boxed) as *mut c_void;
181
182        unsafe extern "C" fn trampoline(user_data: *mut c_void) {
183            let f: Box<Box<dyn FnOnce() + Send>> =
184                unsafe { Box::from_raw(user_data as *mut Box<dyn FnOnce() + Send>) };
185            (*f)();
186        }
187
188        let r = runtime()?;
189        let cu = r.cuda_launch_host_func()?;
190        let rc = unsafe { cu(self.inner.handle, Some(trampoline), raw) };
191        if rc != baracuda_cuda_sys::runtime::cudaError_t::Success {
192            // Reclaim the box — cudaLaunchHostFunc didn't take ownership on error.
193            drop(unsafe { Box::from_raw(raw as *mut Box<dyn FnOnce() + Send>) });
194            return Err(crate::error::Error::Status { status: rc });
195        }
196        Ok(())
197    }
198
199    /// Enqueue a 32-bit write of `value` to device memory `addr`.
200    ///
201    /// # Safety
202    ///
203    /// `addr` must be a live device-addressable pointer.
204    pub unsafe fn write_value_32(
205        &self,
206        addr: *mut core::ffi::c_void,
207        value: u32,
208        flags: u32,
209    ) -> Result<()> { unsafe {
210        let r = runtime()?;
211        let cu = r.cuda_stream_write_value_32()?;
212        check(cu(self.inner.handle, addr, value, flags))
213    }}
214
215    /// # Safety
216    ///
217    /// Same as [`write_value_32`].
218    pub unsafe fn write_value_64(
219        &self,
220        addr: *mut core::ffi::c_void,
221        value: u64,
222        flags: u32,
223    ) -> Result<()> { unsafe {
224        let r = runtime()?;
225        let cu = r.cuda_stream_write_value_64()?;
226        check(cu(self.inner.handle, addr, value, flags))
227    }}
228
229    /// Block the stream until the 32-bit device memory at `addr` satisfies
230    /// the condition selected by `flags` (GEQ / EQ / AND / NOR, optionally
231    /// OR'd with FLUSH).
232    ///
233    /// # Safety
234    ///
235    /// `addr` must be a live device-addressable pointer.
236    pub unsafe fn wait_value_32(
237        &self,
238        addr: *mut core::ffi::c_void,
239        value: u32,
240        flags: u32,
241    ) -> Result<()> { unsafe {
242        let r = runtime()?;
243        let cu = r.cuda_stream_wait_value_32()?;
244        check(cu(self.inner.handle, addr, value, flags))
245    }}
246
247    /// # Safety
248    ///
249    /// Same as [`wait_value_32`].
250    pub unsafe fn wait_value_64(
251        &self,
252        addr: *mut core::ffi::c_void,
253        value: u64,
254        flags: u32,
255    ) -> Result<()> { unsafe {
256        let r = runtime()?;
257        let cu = r.cuda_stream_wait_value_64()?;
258        check(cu(self.inner.handle, addr, value, flags))
259    }}
260
261    /// Associate a managed-memory region with this stream
262    /// (`cudaStreamAttachMemAsync`). Pass `flags = 0` for the default.
263    ///
264    /// # Safety
265    ///
266    /// `dev_ptr` must be a managed-memory allocation.
267    pub unsafe fn attach_mem_async(
268        &self,
269        dev_ptr: *mut core::ffi::c_void,
270        length: usize,
271        flags: u32,
272    ) -> Result<()> { unsafe {
273        let r = runtime()?;
274        let cu = r.cuda_stream_attach_mem_async()?;
275        check(cu(self.inner.handle, dev_ptr, length, flags))
276    }}
277
278    /// Copy CUDA-managed attributes (access-policy window, sync policy)
279    /// from `src` onto `self`.
280    pub fn copy_attributes_from(&self, src: &Stream) -> Result<()> {
281        let r = runtime()?;
282        let cu = r.cuda_stream_copy_attributes()?;
283        check(unsafe { cu(self.inner.handle, src.inner.handle) })
284    }
285
286    /// Enqueue a batch of stream mem-ops (`WAIT_VALUE_32/64`,
287    /// `WRITE_VALUE_32/64`) atomically. Much cheaper than issuing the
288    /// ops one at a time.
289    ///
290    /// Build entries with [`baracuda_cuda_sys::types::CUstreamBatchMemOpParams::write_value_32`]
291    /// etc. Pass `flags = 0` for the default.
292    ///
293    /// # Safety
294    ///
295    /// Every entry's `address` must be a live device-addressable pointer.
296    pub unsafe fn batch_mem_op(
297        &self,
298        params: &mut [baracuda_cuda_sys::types::CUstreamBatchMemOpParams],
299        flags: u32,
300    ) -> Result<()> { unsafe {
301        let r = runtime()?;
302        let cu = r.cuda_stream_batch_mem_op()?;
303        check(cu(
304            self.inner.handle,
305            params.len() as core::ffi::c_uint,
306            params.as_mut_ptr(),
307            flags,
308        ))
309    }}
310}
311
312impl Drop for StreamInner {
313    fn drop(&mut self) {
314        if let Ok(r) = runtime() {
315            if let Ok(cu) = r.cuda_stream_destroy() {
316                let _ = unsafe { cu(self.handle) };
317            }
318        }
319    }
320}