async_cuda_npp/
stream.rs

1use std::sync::Arc;
2
3use async_cuda_core::runtime::Future;
4
5use crate::ffi::context::Context;
6
/// Represents an NPP stream.
///
/// An NPP stream is a thin wrapper around a normal CUDA stream ([`async_cuda_core::Stream`]). It
/// carries the additional context information that NPP requires to execute statelessly on a
/// user-provided stream.
///
/// This struct implements `Deref` so that it can also be used wherever a normal
/// [`async_cuda_core::Stream`] is expected.
///
/// # Usage
///
/// If the caller wants to use one stream for mixed NPP and non-NPP operations, they should create
/// an NPP stream and pass it as the CUDA stream where needed. This works out-of-the-box since
/// [`Stream`] dereferences to [`async_cuda_core::Stream`].
pub struct Stream {
    // Shared handle to the NPP stream context. The context holds the wrapped CUDA stream in its
    // `stream` field, which is what the `Deref` implementation exposes.
    context: Arc<Context>,
}
24
25impl Stream {
26    /// Create an NPP [`Stream`] that represent the default stream, also known as the null stream.
27    ///
28    /// This type is a wrapper around the actual CUDA stream type: [`async_cuda_core::Stream`].
29    #[inline]
30    pub async fn null() -> Self {
31        let context = Future::new(Context::from_null_stream).await;
32        Self {
33            context: Arc::new(context),
34        }
35    }
36
37    /// Create a new [`Stream`] for use with NPP.
38    ///
39    /// This type is a wrapper around the actual CUDA stream type: [`async_cuda_core::Stream`].
40    #[inline]
41    pub async fn new() -> std::result::Result<Self, async_cuda_core::Error> {
42        let stream = async_cuda_core::Stream::new().await?;
43        let context = Future::new(move || Context::from_stream(stream)).await;
44        Ok(Self {
45            context: Arc::new(context),
46        })
47    }
48
49    /// Acquire shared access to the underlying NPP context object.
50    ///
51    /// This NPP object can be safetly sent to the runtime thread so it can be used as a context.
52    ///
53    /// # Safety
54    ///
55    /// The [`Context`] object may only be *used* from the runtime thread.
56    pub(crate) fn to_context(&self) -> Arc<Context> {
57        self.context.clone()
58    }
59}
60
61impl std::ops::Deref for Stream {
62    type Target = async_cuda_core::Stream;
63
64    fn deref(&self) -> &Self::Target {
65        &self.context.stream
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created NPP stream must have a non-null context whose stream handle is the
    /// wrapped CUDA stream.
    #[tokio::test]
    async fn test_new() {
        let stream = Stream::new().await.unwrap();
        let context = stream.to_context();
        let context_ptr = context.as_ptr();
        assert!(!context_ptr.is_null());
        // SAFETY: This works because we know that the first field of the underlying
        // `NppStreamContext` struct used internally is `hStream`, so reading a pointer-sized value
        // at the start of the context yields the stream handle. It should refer to the wrapped
        // stream, or the context was not initialized correctly.
        let stream_handle = unsafe { *(context_ptr as *const *const std::ffi::c_void) };
        assert_eq!(stream_handle, stream.inner().as_internal().as_ptr());
    }

    /// The null NPP stream must have a non-null context whose stream handle is null.
    #[tokio::test]
    async fn test_null() {
        let stream = Stream::null().await;
        let context = stream.to_context();
        let context_ptr = context.as_ptr();
        assert!(!context_ptr.is_null());
        // SAFETY: This works because we know that the first field of the underlying
        // `NppStreamContext` struct used internally is `hStream`, which should refer to the
        // wrapped stream; that is the null stream in this case.
        let stream_handle = unsafe { *(context_ptr as *const *const std::ffi::c_void) };
        assert!(stream_handle.is_null());
    }
}