1use num_enum::TryFromPrimitive;
2use std::{ffi::c_void, marker::PhantomData, pin::Pin, ptr::null_mut, rc::Rc};
3
4use crate::*;
5
6pub struct Stream<'a> {
8 pub(crate) inner: *mut sys::CUstream_st,
9 pub(crate) pending_stores: Vec<Pin<Box<[u8]>>>,
10 _p: PhantomData<&'a ()>,
11}
12
13#[derive(Debug, Copy, Clone, TryFromPrimitive)]
15#[repr(u32)]
16pub enum WaitValueMode {
17 Geq = 0x0,
19 Eq = 0x1,
21 And = 0x2,
23 Nor = 0x3,
25}
26
/// Trampoline handed to `cuLaunchHostFunc`: reclaims one boxed closure and
/// runs it exactly once.
///
/// A null `arg` is ignored. If the closure panics, the process is aborted
/// instead of letting the unwind cross the `extern "C"` boundary into the
/// caller, which is not allowed to observe Rust unwinding.
///
/// # Safety
///
/// `arg` must be null, or a pointer obtained from `Box::into_raw` /
/// `Box::leak` on a `Box<Box<dyn FnOnce() + Send + Sync>>` that has not been
/// freed and is passed to this function at most once.
unsafe extern "C" fn host_callback(arg: *mut std::ffi::c_void) {
    if arg.is_null() {
        // Nothing was scheduled; turning null into a `Box` would be UB.
        return;
    }
    // SAFETY: per this function's contract, `arg` is the unique pointer to a
    // live `Box<Box<dyn FnOnce() + Send + Sync>>` allocation, so taking
    // ownership back here is sound and frees it after the call.
    let closure: Box<Box<dyn FnOnce() + Send + Sync>> = Box::from_raw(arg as *mut _);
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || closure()));
    if outcome.is_err() {
        // The panic hook has already reported the message; unwinding further
        // would leave Rust frames through foreign code.
        std::process::abort();
    }
}
31
32impl<'a> Stream<'a> {
33 pub fn new(_handle: &Rc<Handle<'a>>) -> CudaResult<Self> {
35 let mut out = null_mut();
36 cuda_error(unsafe {
37 sys::cuStreamCreate(
38 &mut out as *mut _,
39 sys::CUstream_flags_enum_CU_STREAM_NON_BLOCKING,
40 )
41 })?;
42 Ok(Self {
43 inner: out,
44 pending_stores: vec![],
45 _p: PhantomData,
46 })
47 }
48
49 pub fn sync(&mut self) -> CudaResult<()> {
51 cuda_error(unsafe { sys::cuStreamSynchronize(self.inner) })?;
52 self.pending_stores.clear();
53 Ok(())
54 }
55
56 pub fn is_synced(&self) -> CudaResult<bool> {
58 match cuda_error(unsafe { sys::cuStreamQuery(self.inner) }) {
59 Ok(()) => Ok(true),
60 Err(ErrorCode::NotReady) => Ok(false),
61 Err(e) => Err(e),
62 }
63 }
64
65 pub fn wait_32<'b>(
67 &'b mut self,
68 addr: &'b DevicePtr<'a>,
69 value: u32,
70 mode: WaitValueMode,
71 flush: bool,
72 ) -> CudaResult<()> {
73 if addr.len < 4 {
74 panic!("overflow in Stream::wait_32");
75 }
76 let flush = if flush { 1u32 << 30 } else { 0 };
77 cuda_error(unsafe {
78 sys::cuStreamWaitValue32(self.inner, addr.inner, value, mode as u32 | flush)
79 })
80 }
81
82 pub fn wait_64<'b>(
84 &mut self,
85 addr: &'b DevicePtr<'a>,
86 value: u64,
87 mode: WaitValueMode,
88 flush: bool,
89 ) -> CudaResult<()> {
90 if addr.len < 8 {
91 panic!("overflow in Stream::wait_64");
92 }
93 let flush = if flush { 1u32 << 30 } else { 0 };
94 cuda_error(unsafe {
95 sys::cuStreamWaitValue64(self.inner, addr.inner, value, mode as u32 | flush)
96 })
97 }
98
99 pub fn write_32<'b>(
101 &'b mut self,
102 addr: &'b DevicePtr<'a>,
103 value: u32,
104 no_memory_barrier: bool,
105 ) -> CudaResult<()> {
106 if addr.len < 4 {
107 panic!("overflow in Stream::write_32");
108 }
109 let no_memory_barrier = if no_memory_barrier { 1u32 } else { 0 };
110 cuda_error(unsafe {
111 sys::cuStreamWriteValue32(self.inner, addr.inner, value, no_memory_barrier)
112 })
113 }
114
115 pub fn write_64<'b>(
117 &'b mut self,
118 addr: &'b DevicePtr<'a>,
119 value: u64,
120 no_memory_barrier: bool,
121 ) -> CudaResult<()> {
122 if addr.len < 8 {
123 panic!("overflow in Stream::write_64");
124 }
125 let no_memory_barrier = if no_memory_barrier { 1u32 } else { 0 };
126 cuda_error(unsafe {
127 sys::cuStreamWriteValue64(self.inner, addr.inner, value, no_memory_barrier)
128 })
129 }
130
131 pub fn callback<F: FnOnce() + Send + Sync>(&mut self, callback: F) -> CudaResult<()> {
137 let callback: Box<Box<dyn FnOnce()>> = Box::new(Box::new(callback));
138 cuda_error(unsafe {
139 sys::cuLaunchHostFunc(
140 self.inner,
141 Some(host_callback),
142 Box::leak(callback) as *mut _ as *mut _,
143 )
144 })
145 }
146
147 pub unsafe fn launch<'b, D1: Into<Dim3>, D2: Into<Dim3>, K: KernelParameters>(
151 &mut self,
152 f: &Function<'a, 'b>,
153 grid_dim: D1,
154 block_dim: D2,
155 shared_mem_size: u32,
156 parameters: K,
157 ) -> CudaResult<()> {
158 let grid_dim = grid_dim.into().0;
159 let block_dim = block_dim.into().0;
160 let mut kernel_params = vec![];
161 parameters.params(&mut kernel_params);
162 let mut new_kernel_params = Vec::with_capacity(kernel_params.len());
163 for param in &kernel_params {
164 new_kernel_params.push(param.as_ptr() as *mut c_void);
165 }
166 cuda_error(sys::cuLaunchKernel(
167 f.inner,
168 grid_dim.0,
169 grid_dim.1,
170 grid_dim.2,
171 block_dim.0,
172 block_dim.1,
173 block_dim.2,
174 shared_mem_size,
175 self.inner,
176 new_kernel_params.as_mut_ptr(),
177 null_mut(),
178 ))
179 }
180}
181
182impl<'a> Drop for Stream<'a> {
183 fn drop(&mut self) {
184 if let Err(e) = cuda_error(unsafe { sys::cuStreamDestroy_v2(self.inner) }) {
185 eprintln!("CUDA: failed to drop stream: {:?}", e);
186 }
187 }
188}