1use cpp::cpp;
2
3use crate::device::DeviceId;
4use crate::ffi::device::Device;
5use crate::ffi::ptr::DevicePtr;
6use crate::ffi::result;
7
/// Result alias used by this module, with `crate::error::Error` as the error type.
type Result<T> = std::result::Result<T, crate::error::Error>;
9
/// Wrapper around a raw CUDA stream handle together with the device it is
/// associated with.
///
/// The handle is released by `Stream::destroy`, which the `Drop` impl calls;
/// a null handle is never passed to `cudaStreamDestroy`.
pub struct Stream {
    /// Raw `cudaStream_t`, stored as a `DevicePtr`; null for `Stream::null()`.
    internal: DevicePtr,
    /// Device made current before the CUDA calls that operate on this stream.
    device: DeviceId,
}
17
// SAFETY: NOTE(review) — this asserts that the stream handle may be moved to and
// used from another host thread. Every method that calls CUDA on an existing
// stream (`synchronize`, `add_callback`, `destroy`) first makes `self.device`
// the current device, so it does not rely on the creating thread's (presumably
// thread-local) current-device state. Confirm against the CUDA runtime
// thread-safety documentation.
unsafe impl Send for Stream {}
24
// SAFETY: NOTE(review) — the `&self` methods (`synchronize`, `add_callback`,
// `as_internal`, `device`) only read the fields and issue CUDA runtime calls;
// replacing or releasing the handle requires `&mut self` (`as_mut_internal`,
// `destroy`). Soundness therefore relies on concurrent CUDA runtime calls on
// one stream handle being allowed — confirm against the CUDA documentation.
unsafe impl Sync for Stream {}
31
impl Stream {
    /// Returns a `Stream` that wraps a null stream handle, associated with the
    /// current device as reported by `Device::get_or_panic`.
    ///
    /// NOTE(review): a null `cudaStream_t` is presumably meant to select the
    /// CUDA default stream — confirm this is the intended use.
    ///
    /// Dropping the result does not call into CUDA, because `destroy` returns
    /// early for a null handle.
    ///
    /// # Panics
    ///
    /// Presumably panics if `Device::get_or_panic` cannot determine the current
    /// device (TODO confirm against its definition).
    pub fn null() -> Self {
        Self {
            // SAFETY: NOTE(review) — the contract of `DevicePtr::null` is defined
            // elsewhere; TODO document why it holds here.
            internal: unsafe { DevicePtr::null() },
            device: Device::get_or_panic(),
        }
    }

    /// Creates a new stream with `cudaStreamCreate` and records the current
    /// device (from `Device::get`) as the stream's device.
    ///
    /// # Errors
    ///
    /// Returns an error if `Device::get` fails, or — via the `result!` macro,
    /// presumably for any non-success return code — if `cudaStreamCreate` fails.
    pub fn new() -> Result<Self> {
        let device = Device::get()?;
        // Out-parameter that `cudaStreamCreate` fills with the new handle.
        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
        let ret = cpp!(unsafe [
            ptr_ptr as "void**"
        ] -> i32 as "std::int32_t" {
            return cudaStreamCreate((cudaStream_t*) ptr_ptr);
        });
        result!(
            ret,
            Stream {
                internal: DevicePtr::from_addr(ptr),
                device,
            }
        )
    }

    /// Waits on this stream with `cudaStreamSynchronize`, after first making
    /// `self.device` the current device.
    ///
    /// # Errors
    ///
    /// Returns an error if `Device::set` fails or if `cudaStreamSynchronize`
    /// reports a failure (mapped through `result!`).
    pub fn synchronize(&self) -> Result<()> {
        Device::set(self.device)?;
        let ptr = self.internal.as_ptr();
        let ret = cpp!(unsafe [
            ptr as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaStreamSynchronize((cudaStream_t) ptr);
        });
        result!(ret)
    }

    /// Enqueues `f` on this stream with `cudaStreamAddCallback`. The closure is
    /// later invoked by the `cuda_ffi_Callback` trampoline defined in this file.
    ///
    /// The closure is boxed twice. `Box<dyn FnOnce()>` is a fat pointer, so it is
    /// wrapped in a second `Box` to get a thin pointer that fits CUDA's `void*`
    /// user-data argument. Ownership of that allocation passes to the
    /// trampoline, which reclaims it with `Box::from_raw`. The `Send` bound is
    /// presumably there because CUDA may run the callback on another host thread.
    ///
    /// NOTE(review): `f` is not required to be `'static`, yet it may run after
    /// this method returns. A closure that borrows local data can therefore
    /// observe a dangling reference. Consider adding a `'static` bound (this is
    /// an API change).
    ///
    /// NOTE(review): if `cudaStreamAddCallback` returns an error, the boxed
    /// closure is never reclaimed and leaks. Freeing it here is only safe if
    /// CUDA guarantees the callback is not registered on every error path —
    /// confirm before changing.
    ///
    /// # Errors
    ///
    /// Returns an error if `Device::set` fails or if `cudaStreamAddCallback`
    /// reports a failure (mapped through `result!`).
    pub fn add_callback(&self, f: impl FnOnce() + Send) -> Result<()> {
        Device::set(self.device)?;
        let ptr = self.internal.as_ptr();
        let f_boxed = Box::new(f) as Box<dyn FnOnce()>;
        let f_boxed2 = Box::new(f_boxed);
        // From here on the allocation is owned by the pending CUDA callback.
        let f_boxed2_ptr = Box::into_raw(f_boxed2);
        let user_data = f_boxed2_ptr as *mut std::ffi::c_void;
        let ret = cpp!(unsafe [
            ptr as "void*",
            user_data as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaStreamAddCallback(
                (cudaStream_t) ptr,
                cuda_ffi_Callback,
                user_data,
                0
            );
        });
        result!(ret)
    }

    /// Borrows the raw stream handle.
    #[inline(always)]
    pub fn as_internal(&self) -> &DevicePtr {
        &self.internal
    }

    /// Mutably borrows the raw stream handle.
    #[inline(always)]
    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
        &mut self.internal
    }

    /// Returns the device this stream is associated with.
    #[inline(always)]
    pub fn device(&self) -> DeviceId {
        self.device
    }

    /// Synchronizes and then destroys the underlying CUDA stream.
    ///
    /// Does nothing if the handle is null. Errors returned by
    /// `cudaStreamSynchronize` and `cudaStreamDestroy` are ignored.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not stated in this file. Presumably
    /// the caller must ensure that no copy of the raw handle (e.g. one obtained
    /// through `as_internal`) is used after this call. If `DevicePtr::take`
    /// leaves a null handle behind, repeated calls — including the one from
    /// `Drop` — are no-ops.
    ///
    /// # Panics
    ///
    /// Presumably panics if `Device::set_or_panic` cannot select `self.device`.
    pub unsafe fn destroy(&mut self) {
        if self.internal.is_null() {
            return;
        }

        Device::set_or_panic(self.device);

        // Move the handle out of `self` so it is released at most once.
        let mut internal = unsafe { self.internal.take() };
        let ptr = internal.as_mut_ptr();

        // Presumably waits for queued work, including callbacks registered with
        // `add_callback`, before the handle is destroyed. The result is ignored.
        let _ret = cpp!(unsafe [
            ptr as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaStreamSynchronize((cudaStream_t) ptr);
        });

        // Best-effort release; the return code is ignored.
        let _ret = cpp!(unsafe [
            ptr as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaStreamDestroy((cudaStream_t) ptr);
        });
    }
}
146
147impl Drop for Stream {
148 #[inline]
149 fn drop(&mut self) {
150 unsafe {
152 self.destroy();
153 }
154 }
155}
156
157cpp! {{
158 static void cuda_ffi_Callback(
181 __attribute__((unused)) cudaStream_t stream,
182 cudaError_t status,
183 void* user_data
184 ) {
185 rust!(cuda_ffi_Callback_internal [
186 status : i32 as "std::int32_t",
187 user_data : *mut std::ffi::c_void as "void*"
188 ] {
189 unsafe {
192 let user_data = std::mem::transmute(user_data);
193 let function = Box::<Box<dyn FnOnce()>>::from_raw(user_data);
194 function()
195 }
196 });
197 }
198}}
199
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::mpsc;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a stream callback to fire.
    const CALLBACK_TIMEOUT: Duration = Duration::from_secs(5);

    #[test]
    fn test_new() {
        assert!(Stream::new().is_ok());
    }

    #[test]
    fn test_synchronize() {
        let stream = Stream::new().unwrap();
        assert!(stream.synchronize().is_ok());
    }

    #[test]
    fn test_synchronize_null_stream() {
        let stream = Stream::null();
        assert!(stream.synchronize().is_ok());
    }

    /// A callback enqueued with `add_callback` must run and hand back its value
    /// once the stream has been synchronized.
    #[test]
    fn test_add_callback_runs_after_synchronize() {
        let stream = Stream::new().unwrap();
        let (tx, rx) = mpsc::channel();
        assert!(stream.add_callback(move || tx.send(42).unwrap()).is_ok());
        assert!(stream.synchronize().is_ok());
        assert_eq!(rx.recv_timeout(CALLBACK_TIMEOUT), Ok(42));
    }

    /// Dropping a stream with a pending callback must still run that callback,
    /// since `destroy` synchronizes before destroying.
    #[test]
    fn test_drop_runs_pending_callback() {
        let stream = Stream::new().unwrap();
        let (tx, rx) = mpsc::channel();
        assert!(stream.add_callback(move || tx.send(()).unwrap()).is_ok());
        drop(stream);
        assert_eq!(rx.recv_timeout(CALLBACK_TIMEOUT), Ok(()));
    }
}