1use std::sync::Arc;
4
5use baracuda_cuda_sys::types::CUstream_flags;
6use baracuda_cuda_sys::{driver, CUstream};
7
8use crate::context::Context;
9use crate::error::{check, Result};
10
11#[derive(Clone)]
18pub struct Stream {
19 inner: Arc<StreamInner>,
20}
21
22struct StreamInner {
23 handle: CUstream,
24 context: Context,
26}
27
28unsafe impl Send for StreamInner {}
30unsafe impl Sync for StreamInner {}
31
32impl core::fmt::Debug for StreamInner {
33 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
34 f.debug_struct("Stream")
35 .field("handle", &self.handle)
36 .finish_non_exhaustive()
37 }
38}
39
40impl core::fmt::Debug for Stream {
41 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
42 self.inner.fmt(f)
43 }
44}
45
46impl Stream {
47 pub fn new(context: &Context) -> Result<Self> {
50 Self::with_flags(context, CUstream_flags::DEFAULT)
51 }
52
53 pub fn non_blocking(context: &Context) -> Result<Self> {
56 Self::with_flags(context, CUstream_flags::NON_BLOCKING)
57 }
58
59 pub fn with_flags(context: &Context, flags: u32) -> Result<Self> {
61 context.set_current()?;
62 let d = driver()?;
63 let cu = d.cu_stream_create()?;
64 let mut stream: CUstream = core::ptr::null_mut();
65 check(unsafe { cu(&mut stream, flags) })?;
67 Ok(Self {
68 inner: Arc::new(StreamInner {
69 handle: stream,
70 context: context.clone(),
71 }),
72 })
73 }
74
75 pub fn with_priority(context: &Context, flags: u32, priority: i32) -> Result<Self> {
79 context.set_current()?;
80 let d = driver()?;
81 let cu = d.cu_stream_create_with_priority()?;
82 let mut stream: CUstream = core::ptr::null_mut();
83 check(unsafe { cu(&mut stream, flags, priority) })?;
84 Ok(Self {
85 inner: Arc::new(StreamInner {
86 handle: stream,
87 context: context.clone(),
88 }),
89 })
90 }
91
92 pub fn priority(&self) -> Result<i32> {
94 let d = driver()?;
95 let cu = d.cu_stream_get_priority()?;
96 let mut p: core::ffi::c_int = 0;
97 check(unsafe { cu(self.inner.handle, &mut p) })?;
98 Ok(p)
99 }
100
101 pub fn flags(&self) -> Result<u32> {
103 let d = driver()?;
104 let cu = d.cu_stream_get_flags()?;
105 let mut f: core::ffi::c_uint = 0;
106 check(unsafe { cu(self.inner.handle, &mut f) })?;
107 Ok(f)
108 }
109
110 pub fn launch_host_func<F>(&self, f: F) -> Result<()>
117 where
118 F: FnOnce() + Send + 'static,
119 {
120 use core::ffi::c_void;
121
122 let boxed: Box<Box<dyn FnOnce() + Send>> = Box::new(Box::new(f));
124 let raw = Box::into_raw(boxed) as *mut c_void;
125
126 unsafe extern "C" fn trampoline(user_data: *mut c_void) {
127 let f: Box<Box<dyn FnOnce() + Send>> =
129 unsafe { Box::from_raw(user_data as *mut Box<dyn FnOnce() + Send>) };
130 (*f)();
131 }
132
133 let d = driver()?;
134 let cu = d.cu_launch_host_func()?;
135 let rc = unsafe { cu(self.inner.handle, Some(trampoline), raw) };
137 if rc != baracuda_cuda_sys::CUresult::SUCCESS {
138 drop(unsafe { Box::from_raw(raw as *mut Box<dyn FnOnce() + Send>) });
141 return Err(crate::error::Error::Status { status: rc });
142 }
143 Ok(())
144 }
145
146 pub fn synchronize(&self) -> Result<()> {
149 let d = driver()?;
150 let cu = d.cu_stream_synchronize()?;
151 check(unsafe { cu(self.inner.handle) })
152 }
153
154 pub fn is_complete(&self) -> Result<bool> {
157 use baracuda_cuda_sys::CUresult;
158 let d = driver()?;
159 let cu = d.cu_stream_query()?;
160 let res = unsafe { cu(self.inner.handle) };
161 match res {
162 CUresult::SUCCESS => Ok(true),
163 CUresult::ERROR_NOT_READY => Ok(false),
164 other => Err(crate::error::Error::Status { status: other }),
165 }
166 }
167
168 #[inline]
170 pub fn context(&self) -> &Context {
171 &self.inner.context
172 }
173
174 #[inline]
176 pub fn as_raw(&self) -> CUstream {
177 self.inner.handle
178 }
179
180 pub fn memcpy_dtod<T: baracuda_types::DeviceRepr>(
199 &self,
200 src: &crate::memory::DeviceBuffer<T>,
201 dst: &mut crate::memory::DeviceBuffer<T>,
202 ) -> Result<()> {
203 src.copy_to_device_async(dst, self)
204 }
205
206 pub fn id(&self) -> Result<u64> {
209 let d = driver()?;
210 let cu = d.cu_stream_get_id()?;
211 let mut out: u64 = 0;
212 check(unsafe { cu(self.inner.handle, &mut out) })?;
213 Ok(out)
214 }
215
216 pub fn copy_attributes_from(&self, src: &Stream) -> Result<()> {
220 let d = driver()?;
221 let cu = d.cu_stream_copy_attributes()?;
222 check(unsafe { cu(self.inner.handle, src.inner.handle) })
223 }
224
225 pub fn wait_event(&self, event: &crate::Event, flags: u32) -> Result<()> {
231 let d = driver()?;
232 let cu = d.cu_stream_wait_event()?;
233 check(unsafe { cu(self.inner.handle, event.as_raw(), flags) })
234 }
235
236 pub unsafe fn get_attribute(
246 &self,
247 attr: i32,
248 value_out: *mut core::ffi::c_void,
249 ) -> Result<()> { unsafe {
250 let d = driver()?;
251 let cu = d.cu_stream_get_attribute()?;
252 check(cu(self.inner.handle, attr, value_out))
253 }}
254
255 pub unsafe fn set_attribute(
263 &self,
264 attr: i32,
265 value: *const core::ffi::c_void,
266 ) -> Result<()> { unsafe {
267 let d = driver()?;
268 let cu = d.cu_stream_set_attribute()?;
269 check(cu(self.inner.handle, attr, value))
270 }}
271
272 pub fn attach_mem_async(
275 &self,
276 dptr: baracuda_cuda_sys::CUdeviceptr,
277 length: usize,
278 flags: u32,
279 ) -> Result<()> {
280 let d = driver()?;
281 let cu = d.cu_stream_attach_mem_async()?;
282 check(unsafe { cu(self.inner.handle, dptr, length, flags) })
283 }
284
285 pub fn write_value_32(
291 &self,
292 addr: baracuda_cuda_sys::CUdeviceptr,
293 value: u32,
294 flags: u32,
295 ) -> Result<()> {
296 let d = driver()?;
297 let cu = d.cu_stream_write_value_32()?;
298 check(unsafe { cu(self.inner.handle, addr, value, flags) })
299 }
300
301 pub fn write_value_64(
302 &self,
303 addr: baracuda_cuda_sys::CUdeviceptr,
304 value: u64,
305 flags: u32,
306 ) -> Result<()> {
307 let d = driver()?;
308 let cu = d.cu_stream_write_value_64()?;
309 check(unsafe { cu(self.inner.handle, addr, value, flags) })
310 }
311
312 pub fn wait_value_32(
317 &self,
318 addr: baracuda_cuda_sys::CUdeviceptr,
319 value: u32,
320 flags: u32,
321 ) -> Result<()> {
322 let d = driver()?;
323 let cu = d.cu_stream_wait_value_32()?;
324 check(unsafe { cu(self.inner.handle, addr, value, flags) })
325 }
326
327 pub fn wait_value_64(
328 &self,
329 addr: baracuda_cuda_sys::CUdeviceptr,
330 value: u64,
331 flags: u32,
332 ) -> Result<()> {
333 let d = driver()?;
334 let cu = d.cu_stream_wait_value_64()?;
335 check(unsafe { cu(self.inner.handle, addr, value, flags) })
336 }
337
338 pub fn batch_mem_op(
343 &self,
344 ops: &mut [baracuda_cuda_sys::types::CUstreamBatchMemOpParams],
345 flags: u32,
346 ) -> Result<()> {
347 let d = driver()?;
348 let cu = d.cu_stream_batch_mem_op()?;
349 check(unsafe {
350 cu(
351 self.inner.handle,
352 ops.len() as core::ffi::c_uint,
353 ops.as_mut_ptr(),
354 flags,
355 )
356 })
357 }
358
359 pub fn capture_info(&self) -> Result<(bool, u64, baracuda_cuda_sys::CUgraph)> {
363 let d = driver()?;
364 let cu = d.cu_stream_get_capture_info()?;
365 let mut status: core::ffi::c_int = 0;
366 let mut id: u64 = 0;
367 let mut graph: baracuda_cuda_sys::CUgraph = core::ptr::null_mut();
368 let mut deps_ptr: *const baracuda_cuda_sys::CUgraphNode = core::ptr::null();
369 let mut num_deps: usize = 0;
370 check(unsafe {
371 cu(
372 self.inner.handle,
373 &mut status,
374 &mut id,
375 &mut graph,
376 &mut deps_ptr,
377 &mut num_deps,
378 )
379 })?;
380 Ok((status == 1, id, graph))
382 }
383}
384
385impl Drop for StreamInner {
386 fn drop(&mut self) {
387 if let Ok(d) = driver() {
388 if let Ok(cu) = d.cu_stream_destroy() {
389 let _ = unsafe { cu(self.handle) };
391 }
392 }
393 }
394}