async_cuda_npp/ffi/
context.rs

1use cpp::cpp;
2
3use async_cuda_core::Stream;
4
5use crate::ffi::result;
6
/// NPP stream context structure.
///
/// Owns a heap-allocated `NppStreamContext` and the [`Stream`] whose handle was written into it.
/// The `NppStreamContext` is created by [`Context::from_stream`] or [`Context::from_null_stream`]
/// and is freed when the `Context` is dropped.
///
/// [NPP documentation](https://docs.nvidia.com/cuda/npp/struct_npp_stream_context.html)
pub struct Context {
    /// Opaque pointer to the owned `NppStreamContext`, allocated on the C++ side.
    raw: *mut std::ffi::c_void,
    /// Stream that NPP functions using this context execute on.
    ///
    /// NOTE(review): this field is `pub`, so callers can replace it, e.g. by assignment or
    /// `std::mem::replace`. The `hStream` stored in the underlying `NppStreamContext` is not
    /// updated when that happens, so it would keep referring to the old stream even after that
    /// stream is dropped.
    pub stream: Stream,
}
14
/// Implements [`Send`] for [`Context`].
///
/// # Safety
///
/// This is safe because the way we use the underlying `NppStreamContext` object is thread-safe.
/// The raw pointer is exclusively owned by the `Context`. It is written only by the constructors
/// and freed only on drop, so moving the `Context` to another thread moves that ownership with it.
///
/// NOTE(review): this also sends the owned [`Stream`] across threads; confirm that `Stream` is
/// itself safe to send.
unsafe impl Send for Context {}
21
/// Implements [`Sync`] for [`Context`].
///
/// # Safety
///
/// This is safe because the way we use the underlying `NppStreamContext` object is thread-safe.
/// Through a shared `&Context`, the context pointer is only handed out as a `*const` pointer by
/// `Context::as_ptr`, and nothing reachable through `&Context` in this module mutates the
/// pointed-to structure.
///
/// NOTE(review): `&Context` also exposes `&Stream` through the public `stream` field; this impl
/// assumes `Stream` is safe to share between threads. Confirm this.
unsafe impl Sync for Context {}
28
29impl Context {
30    /// Create context on null stream.
31    ///
32    /// This creates a context that can be passed to NPP functions. Any functions using this context
33    /// will be executed on the null stream.
34    pub fn from_null_stream() -> Self {
35        let mut raw = std::ptr::null_mut();
36        let raw_ptr = std::ptr::addr_of_mut!(raw);
37        // SAFETY:
38        // * Must call this function on runtime since `nppGetStreamContext` needs the correct thread
39        //   locals to determine current device and other context settings.
40        // * We can store a reference to the stream in `NppStreamContext` as long as we make sure
41        //   `NppStreamContext` cannot outlive the stream, which we can guarantee because we take
42        //   ownership of the stream.
43        let ret = cpp!(unsafe [
44            raw_ptr as "void**"
45        ] -> i32 as "std::int32_t" {
46            NppStreamContext* stream_context = new NppStreamContext();
47            NppStatus ret = nppGetStreamContext(stream_context);
48            if (ret == NPP_SUCCESS) {
49                stream_context->hStream = nullptr;
50                *raw_ptr = (void*) stream_context;
51            }
52            return ret;
53        });
54        match result!(ret) {
55            Ok(()) => Self {
56                raw,
57                stream: Stream::null(),
58            },
59            Err(err) => {
60                panic!("failed to get current NPP stream context: {err}")
61            }
62        }
63    }
64
65    /// Create context.
66    ///
67    /// This creates an NPP context object. It can be passed to NPP functions, and they will execute
68    /// on the associated stream.
69    ///
70    /// # Arguments
71    ///
72    /// * `stream` - Stream to associate with context.
73    pub fn from_stream(stream: Stream) -> Self {
74        let (ret, raw) = {
75            let mut raw = std::ptr::null_mut();
76            let raw_ptr = std::ptr::addr_of_mut!(raw);
77            let stream_ptr = stream.inner().as_internal().as_ptr();
78            // SAFETY:
79            // * Must call this function on runtime since `nppGetStreamContext` needs the correct
80            //   thread locals to determine current device and other context settings.
81            // * We can store a reference to the stream in `NppStreamContext` as long as we make
82            //   sure `NppStreamContext` cannot outlive the stream, which we can guarantee because
83            //   we take ownership of the stream.
84            let ret = cpp!(unsafe [
85                raw_ptr as "void**",
86                stream_ptr as "void*"
87            ] -> i32 as "std::int32_t" {
88                NppStreamContext* stream_context = new NppStreamContext();
89                NppStatus ret = nppGetStreamContext(stream_context);
90                if (ret == NPP_SUCCESS) {
91                    stream_context->hStream = (cudaStream_t) stream_ptr;
92                    *raw_ptr = (void*) stream_context;
93                }
94                return ret;
95            });
96            (ret, raw)
97        };
98        match result!(ret) {
99            Ok(()) => Self { raw, stream },
100            Err(err) => {
101                panic!("failed to get current NPP stream context: {err}")
102            }
103        }
104    }
105
106    /// Get internal readonly pointer.
107    #[inline]
108    pub(crate) fn as_ptr(&self) -> *const std::ffi::c_void {
109        self.raw
110    }
111}
112
113impl Drop for Context {
114    fn drop(&mut self) {
115        let raw = self.raw;
116        cpp!(unsafe [raw as "void*"] {
117            delete ((NppStreamContext*) raw);
118        });
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// A context built from a freshly created stream owns a live NPP context and a non-null
    /// stream handle.
    #[tokio::test]
    async fn test_from_stream() {
        let ctx = Context::from_stream(Stream::new().await.unwrap());
        let npp_ctx_ptr = ctx.as_ptr();
        let stream_handle = ctx.stream.inner().as_internal().as_ptr();
        assert!(!npp_ctx_ptr.is_null());
        assert!(!stream_handle.is_null());
    }

    /// A context built on the null stream still owns a live NPP context, but its stream handle
    /// is the null pointer.
    #[test]
    fn test_from_null_stream() {
        let ctx = Context::from_null_stream();
        let npp_ctx_ptr = ctx.as_ptr();
        let stream_handle = ctx.stream.inner().as_internal().as_ptr();
        assert!(!npp_ctx_ptr.is_null());
        assert!(stream_handle.is_null());
    }
}
140}