// cuda_oxide/stream.rs

1use num_enum::TryFromPrimitive;
2use std::{ffi::c_void, marker::PhantomData, pin::Pin, ptr::null_mut, rc::Rc};
3
4use crate::*;
5
6/// A stream of asynchronous operations operating in a [`Context`]
/// A stream of asynchronous operations operating in a [`Context`].
///
/// Created by [`Stream::new`]; the underlying CUDA stream is destroyed with
/// `cuStreamDestroy_v2` when the value is dropped.
pub struct Stream<'a> {
    /// Raw CUDA stream handle produced by `cuStreamCreate` in [`Stream::new`].
    pub(crate) inner: *mut sys::CUstream_st,
    /// Buffers kept alive at least until the next successful [`Stream::sync`], which clears
    /// this list after `cuStreamSynchronize` succeeds.
    /// NOTE(review): presumably host buffers backing in-flight async operations queued
    /// elsewhere in the crate — confirm against the code that pushes into this list.
    pub(crate) pending_stores: Vec<Pin<Box<[u8]>>>,
    /// Ties the stream to the lifetime `'a` of the `Handle` passed to [`Stream::new`].
    _p: PhantomData<&'a ()>,
}
12
13/// Wait comparison type for waiting on some condition in [`Stream::wait_32`]/etc
14#[derive(Debug, Copy, Clone, TryFromPrimitive)]
15#[repr(u32)]
16pub enum WaitValueMode {
17    /// Wait until (int32_t)(*addr - value) >= 0 (or int64_t for 64 bit values). Note this is a cyclic comparison which ignores wraparound. (Default behavior.)
18    Geq = 0x0,
19    /// Wait until *addr == value.
20    Eq = 0x1,
21    /// Wait until (*addr & value) != 0.
22    And = 0x2,
23    /// Wait until ~(*addr | value) != 0. Support for this operation can be queried with cuDeviceGetAttribute() and CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR.
24    Nor = 0x3,
25}
26
/// C-ABI trampoline handed to `cuLaunchHostFunc` by [`Stream::callback`]: takes ownership of
/// the boxed closure behind `arg`, runs it exactly once and then frees it.
///
/// A panic inside the closure is caught and the process is aborted, because letting a panic
/// unwind out of an `extern "C"` function is undefined behavior on older Rust compilers (and
/// an abort on newer ones) — either way it must not escape into CUDA's calling thread.
///
/// # Safety
///
/// `arg` must be a pointer obtained by leaking a `Box<Box<dyn FnOnce() + Send + Sync>>`
/// (e.g. via `Box::into_raw`), and it must not be used again after this call.
unsafe extern "C" fn host_callback(arg: *mut std::ffi::c_void) {
    // SAFETY: the caller guarantees `arg` is an owned, leaked `Box<Box<dyn FnOnce() + Send + Sync>>`
    // that nobody else will touch again, so reconstructing the Box takes unique ownership.
    let closure: Box<Box<dyn FnOnce() + Send + Sync>> = unsafe { Box::from_raw(arg as *mut _) };
    // `AssertUnwindSafe` is fine: the closure is consumed, so no state it captured can be
    // observed in a broken state after a panic — we abort immediately instead.
    if std::panic::catch_unwind(std::panic::AssertUnwindSafe(closure)).is_err() {
        eprintln!("CUDA: host callback panicked; aborting");
        std::process::abort();
    }
}
31
32impl<'a> Stream<'a> {
33    /// Creates a new stream for a handle
34    pub fn new(_handle: &Rc<Handle<'a>>) -> CudaResult<Self> {
35        let mut out = null_mut();
36        cuda_error(unsafe {
37            sys::cuStreamCreate(
38                &mut out as *mut _,
39                sys::CUstream_flags_enum_CU_STREAM_NON_BLOCKING,
40            )
41        })?;
42        Ok(Self {
43            inner: out,
44            pending_stores: vec![],
45            _p: PhantomData,
46        })
47    }
48
49    /// Drives all pending tasks on the stream to completion
50    pub fn sync(&mut self) -> CudaResult<()> {
51        cuda_error(unsafe { sys::cuStreamSynchronize(self.inner) })?;
52        self.pending_stores.clear();
53        Ok(())
54    }
55
56    /// Returns `Ok(true)` if the stream has finished processing all queued tasks.
57    pub fn is_synced(&self) -> CudaResult<bool> {
58        match cuda_error(unsafe { sys::cuStreamQuery(self.inner) }) {
59            Ok(()) => Ok(true),
60            Err(ErrorCode::NotReady) => Ok(false),
61            Err(e) => Err(e),
62        }
63    }
64
65    /// Wait for a 4-byte value in a specific location to compare to `value` by `mode`.
66    pub fn wait_32<'b>(
67        &'b mut self,
68        addr: &'b DevicePtr<'a>,
69        value: u32,
70        mode: WaitValueMode,
71        flush: bool,
72    ) -> CudaResult<()> {
73        if addr.len < 4 {
74            panic!("overflow in Stream::wait_32");
75        }
76        let flush = if flush { 1u32 << 30 } else { 0 };
77        cuda_error(unsafe {
78            sys::cuStreamWaitValue32(self.inner, addr.inner, value, mode as u32 | flush)
79        })
80    }
81
82    /// Wait for a 8-byte value in a specific location to compare to `value` by `mode`.
83    pub fn wait_64<'b>(
84        &mut self,
85        addr: &'b DevicePtr<'a>,
86        value: u64,
87        mode: WaitValueMode,
88        flush: bool,
89    ) -> CudaResult<()> {
90        if addr.len < 8 {
91            panic!("overflow in Stream::wait_64");
92        }
93        let flush = if flush { 1u32 << 30 } else { 0 };
94        cuda_error(unsafe {
95            sys::cuStreamWaitValue64(self.inner, addr.inner, value, mode as u32 | flush)
96        })
97    }
98
99    /// Writes a 4-byte value to device memory asynchronously
100    pub fn write_32<'b>(
101        &'b mut self,
102        addr: &'b DevicePtr<'a>,
103        value: u32,
104        no_memory_barrier: bool,
105    ) -> CudaResult<()> {
106        if addr.len < 4 {
107            panic!("overflow in Stream::write_32");
108        }
109        let no_memory_barrier = if no_memory_barrier { 1u32 } else { 0 };
110        cuda_error(unsafe {
111            sys::cuStreamWriteValue32(self.inner, addr.inner, value, no_memory_barrier)
112        })
113    }
114
115    /// Writes a 8-byte value to device memory asynchronously
116    pub fn write_64<'b>(
117        &'b mut self,
118        addr: &'b DevicePtr<'a>,
119        value: u64,
120        no_memory_barrier: bool,
121    ) -> CudaResult<()> {
122        if addr.len < 8 {
123            panic!("overflow in Stream::write_64");
124        }
125        let no_memory_barrier = if no_memory_barrier { 1u32 } else { 0 };
126        cuda_error(unsafe {
127            sys::cuStreamWriteValue64(self.inner, addr.inner, value, no_memory_barrier)
128        })
129    }
130
131    /// Calls a callback closure function `callback` once all prior tasks in the Stream have been driven to completion.
132    /// Note that it is a memory leak to drop the stream before this callback is called.
133    /// The callback is not guaranteed to be called if the stream errors out.
134    /// Also note that it is erroneous in `libcuda` to make any calls to `libcuda` from this callback.
135    /// The callback is called from a CUDA internal thread, however this is an implementation detail of `libcuda` and not guaranteed.
136    pub fn callback<F: FnOnce() + Send + Sync>(&mut self, callback: F) -> CudaResult<()> {
137        let callback: Box<Box<dyn FnOnce()>> = Box::new(Box::new(callback));
138        cuda_error(unsafe {
139            sys::cuLaunchHostFunc(
140                self.inner,
141                Some(host_callback),
142                Box::leak(callback) as *mut _ as *mut _,
143            )
144        })
145    }
146
147    /// Launch a CUDA kernel on this [`Stream`] with the given `grid_dim` grid dimensions, `block_dim` block dimensions, `shared_mem_size` allocated shared memory pool, and `parameters` kernel parameters.
148    /// It is undefined behavior to pass in `parameters` that do not conform to the passes CUDA kernel. If the argument count is wrong, CUDA will generally throw an error.
149    /// If your `parameters` is accurate to the kernel definition, then this function is otherwise safe.
150    pub unsafe fn launch<'b, D1: Into<Dim3>, D2: Into<Dim3>, K: KernelParameters>(
151        &mut self,
152        f: &Function<'a, 'b>,
153        grid_dim: D1,
154        block_dim: D2,
155        shared_mem_size: u32,
156        parameters: K,
157    ) -> CudaResult<()> {
158        let grid_dim = grid_dim.into().0;
159        let block_dim = block_dim.into().0;
160        let mut kernel_params = vec![];
161        parameters.params(&mut kernel_params);
162        let mut new_kernel_params = Vec::with_capacity(kernel_params.len());
163        for param in &kernel_params {
164            new_kernel_params.push(param.as_ptr() as *mut c_void);
165        }
166        cuda_error(sys::cuLaunchKernel(
167            f.inner,
168            grid_dim.0,
169            grid_dim.1,
170            grid_dim.2,
171            block_dim.0,
172            block_dim.1,
173            block_dim.2,
174            shared_mem_size,
175            self.inner,
176            new_kernel_params.as_mut_ptr(),
177            null_mut(),
178        ))
179    }
180}
181
182impl<'a> Drop for Stream<'a> {
183    fn drop(&mut self) {
184        if let Err(e) = cuda_error(unsafe { sys::cuStreamDestroy_v2(self.inner) }) {
185            eprintln!("CUDA: failed to drop stream: {:?}", e);
186        }
187    }
188}