fil_rustacuda/stream.rs
//! Streams of work for the device to perform.
//!
//! In CUDA, most work is performed asynchronously. Even tasks such as memory copying can be
//! scheduled by the host and performed when ready. Scheduling this work is done using a Stream.
//!
//! A stream is required for all asynchronous tasks in CUDA, such as kernel launches and
//! asynchronous memory copying. Each task in a stream is performed in the order it was scheduled,
//! and tasks within a stream cannot overlap. Tasks scheduled in multiple streams may interleave or
//! execute concurrently. Sequencing between multiple streams can be achieved using events; see
//! `Stream::wait_event` and the `event` module. Finally, the host can wait for all work scheduled
//! in a stream to be completed.
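//!
//! As a minimal sketch (assuming a context created with `quick_init`, as in the examples below),
//! scheduling work on a stream and then waiting for it to finish looks like this:
//!
//! ```
//! # use rustacuda::*;
//! # use std::error::Error;
//! # fn main() -> Result<(), Box<dyn Error>> {
//! # let _ctx = quick_init()?;
//! use rustacuda::stream::{Stream, StreamFlags};
//!
//! let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
//!
//! // ... queue up kernel launches and asynchronous copies on the stream ...
//!
//! // Block the host until all work scheduled on the stream has completed.
//! stream.synchronize()?;
//! # Ok(())
//! # }
//! ```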

use crate::error::{CudaResult, DropResult, ToResult};
use crate::event::Event;
use crate::function::{BlockSize, Function, GridSize};
use cuda_driver_sys::{cudaError_enum, CUstream};
use std::ffi::c_void;
use std::mem;
use std::panic;
use std::ptr;

22bitflags! {
23 /// Bit flags for configuring a CUDA Stream.
24 pub struct StreamFlags: u32 {
25 /// No flags set.
26 const DEFAULT = 0x00;
27
28 /// This stream does not synchronize with the NULL stream.
29 ///
30 /// Note that the name is chosen to correspond to CUDA documentation, but is nevertheless
31 /// misleading. All work within a single stream is ordered and asynchronous regardless
32 /// of whether this flag is set. All streams in RustaCUDA may execute work concurrently,
33 /// regardless of the flag. However, for legacy reasons, CUDA has a notion of a NULL stream,
34 /// which is used as the default when no other stream is provided. Work on other streams
35 /// may not be executed concurrently with work on the NULL stream unless this flag is set.
36 /// Since RustaCUDA does not provide access to the NULL stream, this flag has no effect in
37 /// most circumstances. However, it is recommended to use it anyway, as some other crate
38 /// in this binary may be using the NULL stream directly.
39 const NON_BLOCKING = 0x01;
40 }
41}

bitflags! {
    /// Bit flags for configuring a CUDA Stream waiting on a CUDA Event.
    ///
    /// Current versions of CUDA support only the default flag.
    pub struct StreamWaitEventFlags: u32 {
        /// No flags set.
        const DEFAULT = 0x0;
    }
}

/// A stream of work for the device to perform.
///
/// See the module-level documentation for more information.
#[derive(Debug)]
pub struct Stream {
    inner: CUstream,
}
impl Stream {
    /// Create a new stream with the given flags and optional priority.
    ///
    /// Lower numbers represent higher priorities; that is, work in a stream with a lower priority
    /// number may pre-empt work in a stream with a higher priority number.
    /// `Context::get_stream_priority_range` can be used to get the range of valid priority values;
    /// if priority is set outside that range, it will be automatically clamped to the lowest or
    /// highest number in the range.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// // With default priority
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
    ///
    /// // With specific priority
    /// let priority = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(flags: StreamFlags, priority: Option<i32>) -> CudaResult<Self> {
        unsafe {
            let mut stream = Stream {
                inner: ptr::null_mut(),
            };
            cuda_driver_sys::cuStreamCreateWithPriority(
                &mut stream.inner as *mut CUstream,
                flags.bits(),
                priority.unwrap_or(0),
            )
            .to_result()?;
            Ok(stream)
        }
    }

    /// Return the flags which were used to create this stream.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
    /// assert_eq!(StreamFlags::NON_BLOCKING, stream.get_flags().unwrap());
    /// # Ok(())
    /// # }
    /// ```
    pub fn get_flags(&self) -> CudaResult<StreamFlags> {
        unsafe {
            let mut bits = 0u32;
            cuda_driver_sys::cuStreamGetFlags(self.inner, &mut bits as *mut u32).to_result()?;
            Ok(StreamFlags::from_bits_truncate(bits))
        }
    }

    /// Return the priority of this stream.
    ///
    /// If this stream was created without a priority, returns the default priority.
    /// If the stream was created with a priority outside the valid range, returns the clamped
    /// priority.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
    /// println!("{}", stream.get_priority()?);
    /// # Ok(())
    /// # }
    /// ```
    pub fn get_priority(&self) -> CudaResult<i32> {
        unsafe {
            let mut priority = 0i32;
            cuda_driver_sys::cuStreamGetPriority(self.inner, &mut priority as *mut i32)
                .to_result()?;
            Ok(priority)
        }
    }

    /// Add a callback to a stream.
    ///
    /// The callback will be executed after all previously queued
    /// items in the stream have been completed. Subsequently queued
    /// items will not execute until the callback is finished.
    ///
    /// Callbacks must not make any CUDA API calls.
    ///
    /// The callback will be passed a `CudaResult<()>` indicating the
    /// current state of the device, with `Ok(())` denoting normal operation.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
    ///
    /// // ... queue up some work on the stream
    ///
    /// stream.add_callback(Box::new(|status| {
    ///     println!("Device status is {:?}", status);
    /// }));
    ///
    /// // ... queue up some more work on the stream
    /// # Ok(())
    /// # }
    /// ```
    pub fn add_callback<T>(&self, callback: Box<T>) -> CudaResult<()>
    where
        T: FnOnce(CudaResult<()>) + Send,
    {
        unsafe {
            cuda_driver_sys::cuStreamAddCallback(
                self.inner,
                Some(callback_wrapper::<T>),
                Box::into_raw(callback) as *mut c_void,
                0,
            )
            .to_result()
        }
    }

    /// Wait until a stream's tasks are completed.
    ///
    /// Waits until the device has completed all operations scheduled for this stream.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
    ///
    /// // ... queue up some work on the stream
    ///
    /// // Wait for the work to be completed.
    /// stream.synchronize()?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn synchronize(&self) -> CudaResult<()> {
        unsafe { cuda_driver_sys::cuStreamSynchronize(self.inner).to_result() }
    }

    /// Make the stream wait on an event.
    ///
    /// All future work submitted to the stream will wait for the event to
    /// complete. Synchronization is performed on the device, if possible. The
    /// event may originate from a different context or device than the stream.
    ///
    /// # Example
    ///
    /// ```
    /// # use rustacuda::quick_init;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _context = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags, StreamWaitEventFlags};
    /// use rustacuda::event::{Event, EventFlags};
    ///
    /// let stream_0 = Stream::new(StreamFlags::NON_BLOCKING, None)?;
    /// let stream_1 = Stream::new(StreamFlags::NON_BLOCKING, None)?;
    /// let event = Event::new(EventFlags::DEFAULT)?;
    ///
    /// // do some work on stream_0 ...
    ///
    /// // record an event
    /// event.record(&stream_0)?;
    ///
    /// // wait until the work on stream_0 is finished before continuing stream_1
    /// stream_1.wait_event(event, StreamWaitEventFlags::DEFAULT)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn wait_event(&self, event: Event, flags: StreamWaitEventFlags) -> CudaResult<()> {
        unsafe {
            cuda_driver_sys::cuStreamWaitEvent(self.inner, event.as_inner(), flags.bits())
                .to_result()
        }
    }

    // Hidden implementation detail function. Highly unsafe. Use the `launch!` macro instead.
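    //
    // As a rough sketch of how callers typically reach this function: the `launch!` macro
    // expands to a call to `Stream::launch`, collecting the kernel arguments into the `args`
    // slice; because `launch` is unsafe, the macro invocation must sit inside an `unsafe`
    // block. The module, kernel name, and device pointers below are hypothetical:
    //
    //     let module = Module::load_from_string(&ptx)?;
    //     let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
    //     unsafe {
    //         launch!(module.sum<<<1, 1, 0, stream>>>(
    //             x.as_device_ptr(),
    //             y.as_device_ptr(),
    //             result.as_device_ptr()
    //         ))?;
    //     }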
    #[doc(hidden)]
    pub unsafe fn launch<G, B>(
        &self,
        func: &Function,
        grid_size: G,
        block_size: B,
        shared_mem_bytes: u32,
        args: &[*mut c_void],
    ) -> CudaResult<()>
    where
        G: Into<GridSize>,
        B: Into<BlockSize>,
    {
        let grid_size: GridSize = grid_size.into();
        let block_size: BlockSize = block_size.into();

        cuda_driver_sys::cuLaunchKernel(
            func.to_inner(),
            grid_size.x,
            grid_size.y,
            grid_size.z,
            block_size.x,
            block_size.y,
            block_size.z,
            shared_mem_bytes,
            self.inner,
            args.as_ptr() as *mut _,
            ptr::null_mut(),
        )
        .to_result()
    }

    // Get the inner `CUstream` from the `Stream`.
    //
    // Necessary for certain CUDA functions outside of this
    // module that expect a bare `CUstream`.
    pub(crate) fn as_inner(&self) -> CUstream {
        self.inner
    }

    /// Destroy a `Stream`, returning an error if it fails.
    ///
    /// Destroying a stream can return errors from previous asynchronous work. This function
    /// destroys the given stream and, on failure, returns both the error and the un-destroyed
    /// stream.
    ///
    /// # Example
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
    /// match Stream::drop(stream) {
    ///     Ok(()) => println!("Successfully destroyed"),
    ///     Err((e, stream)) => {
    ///         println!("Failed to destroy stream: {:?}", e);
    ///         // Do something with stream
    ///     },
    /// }
    /// # Ok(())
    /// # }
    /// ```
    pub fn drop(mut stream: Stream) -> DropResult<Stream> {
        if stream.inner.is_null() {
            return Ok(());
        }

        unsafe {
            let inner = mem::replace(&mut stream.inner, ptr::null_mut());
            match cuda_driver_sys::cuStreamDestroy_v2(inner).to_result() {
                Ok(()) => {
                    mem::forget(stream);
                    Ok(())
                }
                Err(e) => Err((e, Stream { inner })),
            }
        }
    }
}
impl Drop for Stream {
    fn drop(&mut self) {
        if self.inner.is_null() {
            return;
        }

        unsafe {
            let inner = mem::replace(&mut self.inner, ptr::null_mut());
            // No choice but to panic here.
            cuda_driver_sys::cuStreamDestroy_v2(inner)
                .to_result()
                .expect("Failed to destroy CUDA stream.");
        }
    }
}
unsafe extern "C" fn callback_wrapper<T>(
    _stream: CUstream,
    status: cudaError_enum,
    callback: *mut c_void,
) where
    T: FnOnce(CudaResult<()>) + Send,
{
    // Stop panics from unwinding across the FFI boundary.
    let _ = panic::catch_unwind(|| {
        // Reclaim ownership of the Box that `add_callback` released via `Box::into_raw`,
        // so the callback is dropped after it runs.
        let callback: Box<T> = Box::from_raw(callback as *mut T);
        callback(status.to_result());
    });
}