// rustacuda/stream.rs

1//! Streams of work for the device to perform.
2//!
3//! In CUDA, most work is performed asynchronously. Even tasks such as memory copying can be
4//! scheduled by the host and performed when ready. Scheduling this work is done using a Stream.
5//!
6//! A stream is required for all asynchronous tasks in CUDA, such as kernel launches and
7//! asynchronous memory copying. Each task in a stream is performed in the order it was scheduled,
8//! and tasks within a stream cannot overlap. Tasks scheduled in multiple streams may interleave or
9//! execute concurrently. Sequencing between multiple streams can be achieved using events,
10//! which the host can record on one stream and other streams can wait on (see
11//! `Stream::wait_event`). Finally, the host can wait for all work scheduled in a stream to be
12//! completed.
12
13use crate::error::{CudaResult, DropResult, ToResult};
14use crate::event::Event;
15use crate::function::{BlockSize, Function, GridSize};
16use cuda_driver_sys::{cudaError_enum, CUstream};
17use std::ffi::c_void;
18use std::mem;
19use std::panic;
20use std::ptr;
21
bitflags! {
    /// Bit flags for configuring a CUDA Stream.
    ///
    /// These correspond to the flags accepted by the CUDA driver when creating a stream
    /// (see `Stream::new`).
    pub struct StreamFlags: u32 {
        /// No flags set.
        const DEFAULT = 0x00;

        /// This stream does not synchronize with the NULL stream.
        ///
        /// Note that the name is chosen to correspond to CUDA documentation, but is nevertheless
        /// misleading. All work within a single stream is ordered and asynchronous regardless
        /// of whether this flag is set. All streams in RustaCUDA may execute work concurrently,
        /// regardless of the flag. However, for legacy reasons, CUDA has a notion of a NULL stream,
        /// which is used as the default when no other stream is provided. Work on other streams
        /// may not be executed concurrently with work on the NULL stream unless this flag is set.
        /// Since RustaCUDA does not provide access to the NULL stream, this flag has no effect in
        /// most circumstances. However, it is recommended to use it anyway, as some other crate
        /// in this binary may be using the NULL stream directly.
        const NON_BLOCKING = 0x01;
    }
}
42
bitflags! {
    /// Bit flags for configuring a CUDA Stream waiting on an CUDA Event.
    ///
    /// Current versions of CUDA support only the default flag, so this type exists
    /// primarily for forward compatibility (see `Stream::wait_event`).
    pub struct StreamWaitEventFlags: u32 {
        /// No flags set.
        const DEFAULT = 0x0;
    }
}
52
/// A stream of work for the device to perform.
///
/// See the module-level documentation for more information.
#[derive(Debug)]
pub struct Stream {
    // Raw CUDA driver stream handle. Null when ownership of the handle has been
    // relinquished (during construction, or inside `Stream::drop`), which the
    // `Drop` impl uses to skip destruction.
    inner: CUstream,
}
60impl Stream {
61    /// Create a new stream with the given flags and optional priority.
62    ///
63    /// By convention, `priority` follows a convention where lower numbers represent greater
64    /// priorities. That is, work in a stream with a lower priority number may pre-empt work in
65    /// a stream with a higher priority number. `Context::get_stream_priority_range` can be used
66    /// to get the range of valid priority values; if priority is set outside that range, it will
67    /// be automatically clamped to the lowest or highest number in the range.
68    ///
69    /// # Examples
70    ///
71    /// ```
72    /// # use rustacuda::*;
73    /// # use std::error::Error;
74    /// # fn main() -> Result<(), Box<dyn Error>> {
75    /// # let _ctx = quick_init()?;
76    /// use rustacuda::stream::{Stream, StreamFlags};
77    ///
78    /// // With default priority
79    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
80    ///
81    /// // With specific priority
82    /// let priority = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
83    /// # Ok(())
84    /// # }
85    /// ```
86    pub fn new(flags: StreamFlags, priority: Option<i32>) -> CudaResult<Self> {
87        unsafe {
88            let mut stream = Stream {
89                inner: ptr::null_mut(),
90            };
91            cuda_driver_sys::cuStreamCreateWithPriority(
92                &mut stream.inner as *mut CUstream,
93                flags.bits(),
94                priority.unwrap_or(0),
95            )
96            .to_result()?;
97            Ok(stream)
98        }
99    }
100
101    /// Return the flags which were used to create this stream.
102    ///
103    /// # Examples
104    ///
105    /// ```
106    /// # use rustacuda::*;
107    /// # use std::error::Error;
108    /// # fn main() -> Result<(), Box<dyn Error>> {
109    /// # let _ctx = quick_init()?;
110    /// use rustacuda::stream::{Stream, StreamFlags};
111    ///
112    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
113    /// assert_eq!(StreamFlags::NON_BLOCKING, stream.get_flags().unwrap());
114    /// # Ok(())
115    /// # }
116    /// ```
117    pub fn get_flags(&self) -> CudaResult<StreamFlags> {
118        unsafe {
119            let mut bits = 0u32;
120            cuda_driver_sys::cuStreamGetFlags(self.inner, &mut bits as *mut u32).to_result()?;
121            Ok(StreamFlags::from_bits_truncate(bits))
122        }
123    }
124
125    /// Return the priority of this stream.
126    ///
127    /// If this stream was created without a priority, returns the default priority.
128    /// If the stream was created with a priority outside the valid range, returns the clamped
129    /// priority.
130    ///
131    /// # Examples
132    ///
133    /// ```
134    /// # use rustacuda::*;
135    /// # use std::error::Error;
136    /// # fn main() -> Result<(), Box<dyn Error>> {
137    /// # let _ctx = quick_init()?;
138    /// use rustacuda::stream::{Stream, StreamFlags};
139    ///
140    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
141    /// println!("{}", stream.get_priority()?);
142    /// # Ok(())
143    /// # }
144    /// ```
145    pub fn get_priority(&self) -> CudaResult<i32> {
146        unsafe {
147            let mut priority = 0i32;
148            cuda_driver_sys::cuStreamGetPriority(self.inner, &mut priority as *mut i32)
149                .to_result()?;
150            Ok(priority)
151        }
152    }
153
154    /// Add a callback to a stream.
155    ///
156    /// The callback will be executed after all previously queued
157    /// items in the stream have been completed. Subsequently queued
158    /// items will not execute until the callback is finished.
159    ///
160    /// Callbacks must not make any CUDA API calls.
161    ///
162    /// The callback will be passed a `CudaResult<()>` indicating the
163    /// current state of the device with `Ok(())` denoting normal operation.
164    ///
165    /// # Examples
166    ///
167    /// ```
168    /// # use rustacuda::*;
169    /// # use std::error::Error;
170    /// # fn main() -> Result<(), Box<dyn Error>> {
171    /// # let _ctx = quick_init()?;
172    /// use rustacuda::stream::{Stream, StreamFlags};
173    ///
174    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
175    ///
176    /// // ... queue up some work on the stream
177    ///
178    /// stream.add_callback(Box::new(|status| {
179    ///     println!("Device status is {:?}", status);
180    /// }));
181    ///
182    /// // ... queue up some more work on the stream
183    /// # Ok(())
184    /// # }
185    pub fn add_callback<T>(&self, callback: Box<T>) -> CudaResult<()>
186    where
187        T: FnOnce(CudaResult<()>) + Send,
188    {
189        unsafe {
190            cuda_driver_sys::cuStreamAddCallback(
191                self.inner,
192                Some(callback_wrapper::<T>),
193                Box::into_raw(callback) as *mut c_void,
194                0,
195            )
196            .to_result()
197        }
198    }
199
200    /// Wait until a stream's tasks are completed.
201    ///
202    /// Waits until the device has completed all operations scheduled for this stream.
203    ///
204    /// # Examples
205    ///
206    /// ```
207    /// # use rustacuda::*;
208    /// # use std::error::Error;
209    /// # fn main() -> Result<(), Box<dyn Error>> {
210    /// # let _ctx = quick_init()?;
211    /// use rustacuda::stream::{Stream, StreamFlags};
212    ///
213    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
214    ///
215    /// // ... queue up some work on the stream
216    ///
217    /// // Wait for the work to be completed.
218    /// stream.synchronize()?;
219    /// # Ok(())
220    /// # }
221    /// ```
222    pub fn synchronize(&self) -> CudaResult<()> {
223        unsafe { cuda_driver_sys::cuStreamSynchronize(self.inner).to_result() }
224    }
225
226    /// Make the stream wait on an event.
227    ///
228    /// All future work submitted to the stream will wait for the event to
229    /// complete. Synchronization is performed on the device, if possible. The
230    /// event may originate from different context or device than the stream.
231    ///
232    /// # Example
233    ///
234    /// ```
235    /// # use rustacuda::quick_init;
236    /// # use std::error::Error;
237    /// # fn main() -> Result<(), Box<dyn Error>> {
238    /// # let _context = quick_init()?;
239    /// use rustacuda::stream::{Stream, StreamFlags, StreamWaitEventFlags};
240    /// use rustacuda::event::{Event, EventFlags};
241    ///
242    /// let stream_0 = Stream::new(StreamFlags::NON_BLOCKING, None)?;
243    /// let stream_1 = Stream::new(StreamFlags::NON_BLOCKING, None)?;
244    /// let event = Event::new(EventFlags::DEFAULT)?;
245    ///
246    /// // do some work on stream_0 ...
247    ///
248    /// // record an event
249    /// event.record(&stream_0)?;
250    ///
251    /// // wait until the work on stream_0 is finished before continuing stream_1
252    /// stream_1.wait_event(event, StreamWaitEventFlags::DEFAULT)?;
253    /// # Ok(())
254    /// }
255    /// ```
256    pub fn wait_event(&self, event: Event, flags: StreamWaitEventFlags) -> CudaResult<()> {
257        unsafe {
258            cuda_driver_sys::cuStreamWaitEvent(self.inner, event.as_inner(), flags.bits())
259                .to_result()
260        }
261    }
262
263    // Hidden implementation detail function. Highly unsafe. Use the `launch!` macro instead.
264    #[doc(hidden)]
265    pub unsafe fn launch<G, B>(
266        &self,
267        func: &Function,
268        grid_size: G,
269        block_size: B,
270        shared_mem_bytes: u32,
271        args: &[*mut c_void],
272    ) -> CudaResult<()>
273    where
274        G: Into<GridSize>,
275        B: Into<BlockSize>,
276    {
277        let grid_size: GridSize = grid_size.into();
278        let block_size: BlockSize = block_size.into();
279
280        cuda_driver_sys::cuLaunchKernel(
281            func.to_inner(),
282            grid_size.x,
283            grid_size.y,
284            grid_size.z,
285            block_size.x,
286            block_size.y,
287            block_size.z,
288            shared_mem_bytes,
289            self.inner,
290            args.as_ptr() as *mut _,
291            ptr::null_mut(),
292        )
293        .to_result()
294    }
295
296    // Get the inner `CUstream` from the `Stream`.
297    //
298    // Necessary for certain CUDA functions outside of this
299    // module that expect a bare `CUstream`.
300    pub(crate) fn as_inner(&self) -> CUstream {
301        self.inner
302    }
303
304    /// Destroy a `Stream`, returning an error.
305    ///
306    /// Destroying a stream can return errors from previous asynchronous work. This function
307    /// destroys the given stream and returns the error and the un-destroyed stream on failure.
308    ///
309    /// # Example
310    ///
311    /// ```
312    /// # use rustacuda::*;
313    /// # use std::error::Error;
314    /// # fn main() -> Result<(), Box<dyn Error>> {
315    /// # let _ctx = quick_init()?;
316    /// use rustacuda::stream::{Stream, StreamFlags};
317    ///
318    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
319    /// match Stream::drop(stream) {
320    ///     Ok(()) => println!("Successfully destroyed"),
321    ///     Err((e, stream)) => {
322    ///         println!("Failed to destroy stream: {:?}", e);
323    ///         // Do something with stream
324    ///     },
325    /// }
326    /// # Ok(())
327    /// # }
328    /// ```
329    pub fn drop(mut stream: Stream) -> DropResult<Stream> {
330        if stream.inner.is_null() {
331            return Ok(());
332        }
333
334        unsafe {
335            let inner = mem::replace(&mut stream.inner, ptr::null_mut());
336            match cuda_driver_sys::cuStreamDestroy_v2(inner).to_result() {
337                Ok(()) => {
338                    mem::forget(stream);
339                    Ok(())
340                }
341                Err(e) => Err((e, Stream { inner })),
342            }
343        }
344    }
345}
346impl Drop for Stream {
347    fn drop(&mut self) {
348        if self.inner.is_null() {
349            return;
350        }
351
352        unsafe {
353            let inner = mem::replace(&mut self.inner, ptr::null_mut());
354            // No choice but to panic here.
355            cuda_driver_sys::cuStreamDestroy_v2(inner)
356                .to_result()
357                .expect("Failed to destroy CUDA stream.");
358        }
359    }
360}
361unsafe extern "C" fn callback_wrapper<T>(
362    _stream: CUstream,
363    status: cudaError_enum,
364    callback: *mut c_void,
365) where
366    T: FnOnce(CudaResult<()>) + Send,
367{
368    // Stop panics from unwinding across the FFI
369    let _ = panic::catch_unwind(|| {
370        let callback: Box<T> = Box::from_raw(callback as *mut T);
371        callback(status.to_result());
372    });
373}