baracuda_runtime/
stream.rs1use std::sync::Arc;
4
5use baracuda_cuda_sys::runtime::{cudaStream_t, runtime, types::cudaStreamFlags};
6
7use crate::device::Device;
8use crate::error::{check, Result};
9
10#[derive(Clone)]
12pub struct Stream {
13 inner: Arc<StreamInner>,
14}
15
16struct StreamInner {
17 handle: cudaStream_t,
18 device: Device,
19}
20
21unsafe impl Send for StreamInner {}
22unsafe impl Sync for StreamInner {}
23
24impl core::fmt::Debug for StreamInner {
25 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
26 f.debug_struct("Stream")
27 .field("handle", &self.handle)
28 .field("device", &self.device)
29 .finish()
30 }
31}
32
33impl core::fmt::Debug for Stream {
34 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
35 self.inner.fmt(f)
36 }
37}
38
39impl Stream {
40 pub fn new() -> Result<Self> {
43 Self::with_flags(cudaStreamFlags::DEFAULT)
44 }
45
46 pub fn non_blocking() -> Result<Self> {
49 Self::with_flags(cudaStreamFlags::NON_BLOCKING)
50 }
51
52 pub unsafe fn from_raw(handle: cudaStream_t) -> Self {
60 let device = Device::current().unwrap_or(Device::from_ordinal(0));
61 Self {
62 inner: Arc::new(StreamInner { handle, device }),
63 }
64 }
65
66 pub fn with_flags(flags: u32) -> Result<Self> {
68 let r = runtime()?;
69 let cu = r.cuda_stream_create_with_flags()?;
70 let mut stream: cudaStream_t = core::ptr::null_mut();
71 check(unsafe { cu(&mut stream, flags) })?;
72 let device = Device::current()?;
73 Ok(Self {
74 inner: Arc::new(StreamInner {
75 handle: stream,
76 device,
77 }),
78 })
79 }
80
81 pub fn synchronize(&self) -> Result<()> {
83 let r = runtime()?;
84 let cu = r.cuda_stream_synchronize()?;
85 check(unsafe { cu(self.inner.handle) })
86 }
87
88 pub fn is_complete(&self) -> Result<bool> {
90 use baracuda_cuda_sys::runtime::cudaError_t;
91 let r = runtime()?;
92 let cu = r.cuda_stream_query()?;
93 match unsafe { cu(self.inner.handle) } {
94 cudaError_t::Success => Ok(true),
95 cudaError_t::NotReady => Ok(false),
96 other => Err(crate::error::Error::Status { status: other }),
97 }
98 }
99
100 #[inline]
102 pub fn device(&self) -> Device {
103 self.inner.device
104 }
105
106 #[inline]
108 pub fn as_raw(&self) -> cudaStream_t {
109 self.inner.handle
110 }
111
112 pub fn with_priority(flags: u32, priority: i32) -> Result<Self> {
116 let r = runtime()?;
117 let cu = r.cuda_stream_create_with_priority()?;
118 let mut stream: cudaStream_t = core::ptr::null_mut();
119 check(unsafe { cu(&mut stream, flags, priority) })?;
120 let device = Device::current()?;
121 Ok(Self {
122 inner: Arc::new(StreamInner {
123 handle: stream,
124 device,
125 }),
126 })
127 }
128
129 pub fn priority(&self) -> Result<i32> {
131 let r = runtime()?;
132 let cu = r.cuda_stream_get_priority()?;
133 let mut p: core::ffi::c_int = 0;
134 check(unsafe { cu(self.inner.handle, &mut p) })?;
135 Ok(p)
136 }
137
138 pub fn flags(&self) -> Result<u32> {
140 let r = runtime()?;
141 let cu = r.cuda_stream_get_flags()?;
142 let mut f: core::ffi::c_uint = 0;
143 check(unsafe { cu(self.inner.handle, &mut f) })?;
144 Ok(f)
145 }
146
147 pub fn wait_event(&self, event: &crate::Event, flags: u32) -> Result<()> {
150 let r = runtime()?;
151 let cu = r.cuda_stream_wait_event()?;
152 check(unsafe { cu(self.inner.handle, event.as_raw(), flags) })
153 }
154}
155
156pub fn stream_priority_range() -> Result<(i32, i32)> {
159 let r = runtime()?;
160 let cu = r.cuda_device_get_stream_priority_range()?;
161 let mut low: core::ffi::c_int = 0;
162 let mut high: core::ffi::c_int = 0;
163 check(unsafe { cu(&mut low, &mut high) })?;
164 Ok((low, high))
165}
166
167impl Stream {
168 pub fn launch_host_func<F>(&self, f: F) -> Result<()>
174 where
175 F: FnOnce() + Send + 'static,
176 {
177 use core::ffi::c_void;
178
179 let boxed: Box<Box<dyn FnOnce() + Send>> = Box::new(Box::new(f));
180 let raw = Box::into_raw(boxed) as *mut c_void;
181
182 unsafe extern "C" fn trampoline(user_data: *mut c_void) {
183 let f: Box<Box<dyn FnOnce() + Send>> =
184 unsafe { Box::from_raw(user_data as *mut Box<dyn FnOnce() + Send>) };
185 (*f)();
186 }
187
188 let r = runtime()?;
189 let cu = r.cuda_launch_host_func()?;
190 let rc = unsafe { cu(self.inner.handle, Some(trampoline), raw) };
191 if rc != baracuda_cuda_sys::runtime::cudaError_t::Success {
192 drop(unsafe { Box::from_raw(raw as *mut Box<dyn FnOnce() + Send>) });
194 return Err(crate::error::Error::Status { status: rc });
195 }
196 Ok(())
197 }
198
199 pub unsafe fn write_value_32(
205 &self,
206 addr: *mut core::ffi::c_void,
207 value: u32,
208 flags: u32,
209 ) -> Result<()> { unsafe {
210 let r = runtime()?;
211 let cu = r.cuda_stream_write_value_32()?;
212 check(cu(self.inner.handle, addr, value, flags))
213 }}
214
215 pub unsafe fn write_value_64(
219 &self,
220 addr: *mut core::ffi::c_void,
221 value: u64,
222 flags: u32,
223 ) -> Result<()> { unsafe {
224 let r = runtime()?;
225 let cu = r.cuda_stream_write_value_64()?;
226 check(cu(self.inner.handle, addr, value, flags))
227 }}
228
229 pub unsafe fn wait_value_32(
237 &self,
238 addr: *mut core::ffi::c_void,
239 value: u32,
240 flags: u32,
241 ) -> Result<()> { unsafe {
242 let r = runtime()?;
243 let cu = r.cuda_stream_wait_value_32()?;
244 check(cu(self.inner.handle, addr, value, flags))
245 }}
246
247 pub unsafe fn wait_value_64(
251 &self,
252 addr: *mut core::ffi::c_void,
253 value: u64,
254 flags: u32,
255 ) -> Result<()> { unsafe {
256 let r = runtime()?;
257 let cu = r.cuda_stream_wait_value_64()?;
258 check(cu(self.inner.handle, addr, value, flags))
259 }}
260
261 pub unsafe fn attach_mem_async(
268 &self,
269 dev_ptr: *mut core::ffi::c_void,
270 length: usize,
271 flags: u32,
272 ) -> Result<()> { unsafe {
273 let r = runtime()?;
274 let cu = r.cuda_stream_attach_mem_async()?;
275 check(cu(self.inner.handle, dev_ptr, length, flags))
276 }}
277
278 pub fn copy_attributes_from(&self, src: &Stream) -> Result<()> {
281 let r = runtime()?;
282 let cu = r.cuda_stream_copy_attributes()?;
283 check(unsafe { cu(self.inner.handle, src.inner.handle) })
284 }
285
286 pub unsafe fn batch_mem_op(
297 &self,
298 params: &mut [baracuda_cuda_sys::types::CUstreamBatchMemOpParams],
299 flags: u32,
300 ) -> Result<()> { unsafe {
301 let r = runtime()?;
302 let cu = r.cuda_stream_batch_mem_op()?;
303 check(cu(
304 self.inner.handle,
305 params.len() as core::ffi::c_uint,
306 params.as_mut_ptr(),
307 flags,
308 ))
309 }}
310}
311
312impl Drop for StreamInner {
313 fn drop(&mut self) {
314 if let Ok(r) = runtime() {
315 if let Ok(cu) = r.cuda_stream_destroy() {
316 let _ = unsafe { cu(self.handle) };
317 }
318 }
319 }
320}