// async_cuda/ffi/stream.rs

1use cpp::cpp;
2
3use crate::device::DeviceId;
4use crate::ffi::device::Device;
5use crate::ffi::ptr::DevicePtr;
6use crate::ffi::result;
7
8type Result<T> = std::result::Result<T, crate::error::Error>;
9
10/// Synchronous implementation of [`crate::Stream`].
11///
12/// Refer to [`crate::Stream`] for documentation.
13pub struct Stream {
14    internal: DevicePtr,
15    device: DeviceId,
16}
17
18/// Implements [`Send`] for [`Stream`].
19///
20/// # Safety
21///
22/// This property is inherited from the CUDA API, which is thread-safe.
23unsafe impl Send for Stream {}
24
25/// Implements [`Sync`] for [`Stream`].
26///
27/// # Safety
28///
29/// This property is inherited from the CUDA API, which is thread-safe.
30unsafe impl Sync for Stream {}
31
32impl Stream {
33    pub fn null() -> Self {
34        Self {
35            // SAFETY: This is safe because a null pointer for stream indicates the default
36            // stream in CUDA and all functions accept this.
37            internal: unsafe { DevicePtr::null() },
38            device: Device::get_or_panic(),
39        }
40    }
41
42    pub fn new() -> Result<Self> {
43        let device = Device::get()?;
44        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
45        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
46        let ret = cpp!(unsafe [
47            ptr_ptr as "void**"
48        ] -> i32 as "std::int32_t" {
49            return cudaStreamCreate((cudaStream_t*) ptr_ptr);
50        });
51        result!(
52            ret,
53            Stream {
54                internal: DevicePtr::from_addr(ptr),
55                device,
56            }
57        )
58    }
59
60    pub fn synchronize(&self) -> Result<()> {
61        Device::set(self.device)?;
62        let ptr = self.internal.as_ptr();
63        let ret = cpp!(unsafe [
64            ptr as "void*"
65        ] -> i32 as "std::int32_t" {
66            return cudaStreamSynchronize((cudaStream_t) ptr);
67        });
68        result!(ret)
69    }
70
71    pub fn add_callback(&self, f: impl FnOnce() + Send) -> Result<()> {
72        Device::set(self.device)?;
73        let ptr = self.internal.as_ptr();
74        let f_boxed = Box::new(f) as Box<dyn FnOnce()>;
75        let f_boxed2 = Box::new(f_boxed);
76        let f_boxed2_ptr = Box::into_raw(f_boxed2);
77        let user_data = f_boxed2_ptr as *mut std::ffi::c_void;
78        let ret = cpp!(unsafe [
79            ptr as "void*",
80            user_data as "void*"
81        ] -> i32 as "std::int32_t" {
82            return cudaStreamAddCallback(
83                (cudaStream_t) ptr,
84                cuda_ffi_Callback,
85                user_data,
86                0
87            );
88        });
89        result!(ret)
90    }
91
92    /// Get readonly reference to internal [`DevicePtr`].
93    #[inline(always)]
94    pub fn as_internal(&self) -> &DevicePtr {
95        &self.internal
96    }
97
98    /// Get mutable reference to internal [`DevicePtr`].
99    #[inline(always)]
100    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
101        &mut self.internal
102    }
103
104    /// Get corresponding device as [`DeviceId`].
105    #[inline(always)]
106    pub fn device(&self) -> DeviceId {
107        self.device
108    }
109
110    /// Destroy stream.
111    ///
112    /// # Panics
113    ///
114    /// This function panics if binding to the corresponding device fails.
115    ///
116    /// # Safety
117    ///
118    /// The object may not be used after this function is called, except for being dropped.
119    pub unsafe fn destroy(&mut self) {
120        if self.internal.is_null() {
121            return;
122        }
123
124        Device::set_or_panic(self.device);
125
126        // SAFETY: This will cause `self` to hold a null pointer. It is safe here because we don't
127        // use the object after this.
128        let mut internal = unsafe { self.internal.take() };
129        let ptr = internal.as_mut_ptr();
130
131        // SAFETY: We must synchronize the stream before destroying it to make sure we are not
132        // dropping a stream that still has operations pending.
133        let _ret = cpp!(unsafe [
134            ptr as "void*"
135        ] -> i32 as "std::int32_t" {
136            return cudaStreamSynchronize((cudaStream_t) ptr);
137        });
138
139        let _ret = cpp!(unsafe [
140            ptr as "void*"
141        ] -> i32 as "std::int32_t" {
142            return cudaStreamDestroy((cudaStream_t) ptr);
143        });
144    }
145}
146
147impl Drop for Stream {
148    #[inline]
149    fn drop(&mut self) {
150        // SAFETY: This is safe since the object cannot be used after this.
151        unsafe {
152            self.destroy();
153        }
154    }
155}
156
157cpp! {{
158    /// Holds the C++ code that makes up the native part required to get our CUDA callback to work
159    /// over the FFI.
160    ///
161    /// # Arguments
162    ///
163    /// * `stream` - The CUDA stream on which the callback was scheduled.
164    /// * `status` - The CUDA status value (this could represent an error from an earlier async CUDA
165    ///   call).
166    /// * `user_data` - The user data pointer provided when adding the callback.
167    ///
168    /// # Example
169    ///
170    /// It can be used like so:
171    ///
172    /// ```cpp
173    /// return cudaStreamAddCallback(
174    ///     stream,
175    ///     cuda_ffi_Callback,
176    ///     user_data,
177    ///     0
178    /// );
179    /// ```
180    static void cuda_ffi_Callback(
181      __attribute__((unused)) cudaStream_t stream,
182      cudaError_t status,
183      void* user_data
184    ) {
185        rust!(cuda_ffi_Callback_internal [
186            status : i32 as "std::int32_t",
187            user_data : *mut std::ffi::c_void as "void*"
188        ] {
189            // SAFETY: We boxed the closure ourselves and did `Box::into_raw`, which allows us to
190            // reinstate the box here and use it accordingly. It will be dropped here after use.
191            unsafe {
192                let user_data = std::mem::transmute(user_data);
193                let function = Box::<Box<dyn FnOnce()>>::from_raw(user_data);
194                function()
195            }
196        });
197    }
198}}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Creating a fresh stream succeeds.
    #[test]
    fn test_new() {
        let created = Stream::new();
        assert!(created.is_ok());
    }

    /// A freshly created stream can be synchronized.
    #[test]
    fn test_synchronize() {
        let stream = Stream::new().unwrap();
        let outcome = stream.synchronize();
        assert!(outcome.is_ok());
    }

    /// The default (null) stream can be synchronized as well.
    #[test]
    fn test_synchronize_null_stream() {
        let stream = Stream::null();
        let outcome = stream.synchronize();
        assert!(outcome.is_ok());
    }
}
218        assert!(stream.synchronize().is_ok());
219    }
220}