async_cuda_core/ffi/
stream.rs

1use cpp::cpp;
2
3use crate::ffi::ptr::DevicePtr;
4use crate::ffi::result;
5
6type Result<T> = std::result::Result<T, crate::error::Error>;
7
/// Synchronous implementation of [`crate::Stream`].
///
/// Refer to [`crate::Stream`] for documentation.
///
/// A value created with [`Stream::new`] wraps a stream handle that is synchronized and destroyed
/// when the value is dropped. A value created with [`Stream::null`] wraps a null handle, for
/// which dropping is a no-op.
pub struct Stream {
    /// Raw stream handle, passed to the CUDA API as a `cudaStream_t`. A null pointer denotes the
    /// default stream.
    internal: DevicePtr,
}
14
/// Implements [`Send`] for [`Stream`].
///
/// # Safety
///
/// This property is inherited from the CUDA API, which is thread-safe.
///
/// NOTE(review): the impl is written by hand, presumably because [`DevicePtr`] wraps a raw
/// pointer and is therefore not [`Send`] automatically — confirm against `DevicePtr`.
unsafe impl Send for Stream {}
21
/// Implements [`Sync`] for [`Stream`].
///
/// # Safety
///
/// This property is inherited from the CUDA API, which is thread-safe.
///
/// NOTE(review): the impl is written by hand, presumably because [`DevicePtr`] wraps a raw
/// pointer and is therefore not [`Sync`] automatically — confirm against `DevicePtr`.
unsafe impl Sync for Stream {}
28
29impl Stream {
30    pub fn null() -> Self {
31        Self {
32            // SAFETY: This is safe because a null pointer for stream indicates the default
33            // stream in CUDA and all functions accept this.
34            internal: unsafe { DevicePtr::null() },
35        }
36    }
37
38    pub fn new() -> Result<Self> {
39        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
40        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
41
42        let ret = cpp!(unsafe [
43            ptr_ptr as "void**"
44        ] -> i32 as "std::int32_t" {
45            return cudaStreamCreate((cudaStream_t*) ptr_ptr);
46        });
47        result!(
48            ret,
49            Stream {
50                internal: ptr.into()
51            }
52        )
53    }
54
55    pub fn synchronize(&self) -> Result<()> {
56        let ptr = self.internal.as_ptr();
57        let ret = cpp!(unsafe [
58            ptr as "void*"
59        ] -> i32 as "std::int32_t" {
60            return cudaStreamSynchronize((cudaStream_t) ptr);
61        });
62        result!(ret)
63    }
64
65    pub fn add_callback(&self, f: impl FnOnce() + Send) -> Result<()> {
66        let ptr = self.internal.as_ptr();
67
68        let f_boxed = Box::new(f) as Box<dyn FnOnce()>;
69        let f_boxed2 = Box::new(f_boxed);
70        let f_boxed2_ptr = Box::into_raw(f_boxed2);
71        let user_data = f_boxed2_ptr as *mut std::ffi::c_void;
72
73        let ret = cpp!(unsafe [
74            ptr as "void*",
75            user_data as "void*"
76        ] -> i32 as "std::int32_t" {
77            return cudaStreamAddCallback(
78                (cudaStream_t) ptr,
79                cuda_ffi_Callback,
80                user_data,
81                0
82            );
83        });
84
85        result!(ret)
86    }
87
88    /// Get readonly reference to internal [`DevicePtr`].
89    #[inline(always)]
90    pub fn as_internal(&self) -> &DevicePtr {
91        &self.internal
92    }
93
94    /// Get readonly reference to internal [`DevicePtr`].
95    #[inline(always)]
96    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
97        &mut self.internal
98    }
99}
100
101impl Drop for Stream {
102    fn drop(&mut self) {
103        if self.internal.is_null() {
104            return;
105        }
106
107        // SAFETY: This will cause `self` to hold a null pointer. It is safe here because we don't
108        // use the object after this.
109        let mut internal = unsafe { self.internal.take() };
110        let ptr = internal.as_mut_ptr();
111
112        // SAFETY: We must synchronize the stream before destroying it to make sure we are not
113        // dropping a stream that still has operations pending.
114        let _ret = cpp!(unsafe [
115            ptr as "void*"
116        ] -> i32 as "std::int32_t" {
117            return cudaStreamSynchronize((cudaStream_t) ptr);
118        });
119
120        let _ret = cpp!(unsafe [
121            ptr as "void*"
122        ] -> i32 as "std::int32_t" {
123            return cudaStreamDestroy((cudaStream_t) ptr);
124        });
125    }
126}
127
cpp! {{
    /// Holds the C++ code that makes up the native part required to get our CUDA callback to work
    /// over the FFI.
    ///
    /// The function is a trampoline: it forwards `status` and `user_data` into an embedded Rust
    /// block, which rebuilds the boxed closure from `user_data` and calls it.
    ///
    /// # Arguments
    ///
    /// * `stream` - The CUDA stream on which the callback was scheduled.
    /// * `status` - The CUDA status value (this could represent an error from an earlier async CUDA
    ///   call).
    /// * `user_data` - The user data pointer provided when adding the callback.
    ///
    /// # Example
    ///
    /// It can be used like so:
    ///
    /// ```cpp
    /// return cudaStreamAddCallback(
    ///     stream,
    ///     cuda_ffi_Callback,
    ///     user_data,
    ///     0
    /// );
    /// ```
    static void cuda_ffi_Callback(
      __attribute__((unused)) cudaStream_t stream,
      cudaError_t status,
      void* user_data
    ) {
        // `status` is captured by the Rust block below but never inspected there, so the
        // closure runs regardless of its value.
        rust!(cuda_ffi_Callback_internal [
            status : i32 as "std::int32_t",
            user_data : *mut std::ffi::c_void as "void*"
        ] {
            // SAFETY: We boxed the closure ourselves and did `Box::into_raw`, which allows us to
            // reinstate the box here and use it accordingly. It will be dropped here after use.
            unsafe {
                // Reinterpret the untyped pointer as the `*mut Box<dyn FnOnce()>` that
                // `Box::from_raw` expects; the target type is inferred from the next line.
                let user_data = std::mem::transmute(user_data);
                let function = Box::<Box<dyn FnOnce()>>::from_raw(user_data);
                // Calling the `FnOnce` consumes the box, so the closure runs at most once.
                function()
            }
        });
    }
}}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Creating a stream succeeds.
    #[test]
    fn test_new() {
        let created = Stream::new();
        assert!(created.is_ok());
    }

    /// Synchronizing a freshly created stream succeeds.
    #[test]
    fn test_synchronize() {
        let stream = Stream::new().unwrap();
        let outcome = stream.synchronize();
        assert!(outcome.is_ok());
    }

    /// Synchronizing the default (null) stream succeeds.
    #[test]
    fn test_synchronize_null_stream() {
        let default_stream = Stream::null();
        let outcome = default_stream.synchronize();
        assert!(outcome.is_ok());
    }
}