async_cuda_core/ffi/
stream.rs1use cpp::cpp;
2
3use crate::ffi::ptr::DevicePtr;
4use crate::ffi::result;
5
/// Module-local shorthand for a `Result` whose error type is the crate's `Error`.
type Result<T> = std::result::Result<T, crate::error::Error>;
7
/// Wrapper around a CUDA stream handle.
///
/// The handle is stored as a [`DevicePtr`]. When a `Stream` holding a non-null
/// handle is dropped, the stream is synchronized and then destroyed; a `Stream`
/// holding the null handle is dropped without any CUDA call.
pub struct Stream {
    /// Raw `cudaStream_t` handle, or null for a stream made with [`Stream::null`].
    internal: DevicePtr,
}
14
// SAFETY: NOTE(review): the justification is not visible in this file. This
// presumably relies on the CUDA runtime allowing a stream handle to be used
// from a host thread other than the one that created it -- confirm.
unsafe impl Send for Stream {}

// SAFETY: NOTE(review): the `&self` methods only pass the handle by value to
// CUDA runtime calls and never mutate `internal`. This presumably relies on
// those runtime calls being safe to issue concurrently -- confirm.
unsafe impl Sync for Stream {}
28
29impl Stream {
30 pub fn null() -> Self {
31 Self {
32 internal: unsafe { DevicePtr::null() },
35 }
36 }
37
38 pub fn new() -> Result<Self> {
39 let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
40 let ptr_ptr = std::ptr::addr_of_mut!(ptr);
41
42 let ret = cpp!(unsafe [
43 ptr_ptr as "void**"
44 ] -> i32 as "std::int32_t" {
45 return cudaStreamCreate((cudaStream_t*) ptr_ptr);
46 });
47 result!(
48 ret,
49 Stream {
50 internal: ptr.into()
51 }
52 )
53 }
54
55 pub fn synchronize(&self) -> Result<()> {
56 let ptr = self.internal.as_ptr();
57 let ret = cpp!(unsafe [
58 ptr as "void*"
59 ] -> i32 as "std::int32_t" {
60 return cudaStreamSynchronize((cudaStream_t) ptr);
61 });
62 result!(ret)
63 }
64
65 pub fn add_callback(&self, f: impl FnOnce() + Send) -> Result<()> {
66 let ptr = self.internal.as_ptr();
67
68 let f_boxed = Box::new(f) as Box<dyn FnOnce()>;
69 let f_boxed2 = Box::new(f_boxed);
70 let f_boxed2_ptr = Box::into_raw(f_boxed2);
71 let user_data = f_boxed2_ptr as *mut std::ffi::c_void;
72
73 let ret = cpp!(unsafe [
74 ptr as "void*",
75 user_data as "void*"
76 ] -> i32 as "std::int32_t" {
77 return cudaStreamAddCallback(
78 (cudaStream_t) ptr,
79 cuda_ffi_Callback,
80 user_data,
81 0
82 );
83 });
84
85 result!(ret)
86 }
87
88 #[inline(always)]
90 pub fn as_internal(&self) -> &DevicePtr {
91 &self.internal
92 }
93
94 #[inline(always)]
96 pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
97 &mut self.internal
98 }
99}
100
101impl Drop for Stream {
102 fn drop(&mut self) {
103 if self.internal.is_null() {
104 return;
105 }
106
107 let mut internal = unsafe { self.internal.take() };
110 let ptr = internal.as_mut_ptr();
111
112 let _ret = cpp!(unsafe [
115 ptr as "void*"
116 ] -> i32 as "std::int32_t" {
117 return cudaStreamSynchronize((cudaStream_t) ptr);
118 });
119
120 let _ret = cpp!(unsafe [
121 ptr as "void*"
122 ] -> i32 as "std::int32_t" {
123 return cudaStreamDestroy((cudaStream_t) ptr);
124 });
125 }
126}
127
128cpp! {{
129 static void cuda_ffi_Callback(
152 __attribute__((unused)) cudaStream_t stream,
153 cudaError_t status,
154 void* user_data
155 ) {
156 rust!(cuda_ffi_Callback_internal [
157 status : i32 as "std::int32_t",
158 user_data : *mut std::ffi::c_void as "void*"
159 ] {
160 unsafe {
163 let user_data = std::mem::transmute(user_data);
164 let function = Box::<Box<dyn FnOnce()>>::from_raw(user_data);
165 function()
166 }
167 });
168 }
169}}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Creating a stream succeeds.
    #[test]
    fn test_new() {
        let created = Stream::new();
        assert!(created.is_ok());
    }

    /// A freshly created stream can be synchronized.
    #[test]
    fn test_synchronize() {
        let stream = Stream::new().unwrap();
        let outcome = stream.synchronize();
        assert!(outcome.is_ok());
    }

    /// Synchronizing the null-handle stream succeeds as well.
    #[test]
    fn test_synchronize_null_stream() {
        let stream = Stream::null();
        let outcome = stream.synchronize();
        assert!(outcome.is_ok());
    }
}