async_cuda_npp/ffi/context.rs
1use cpp::cpp;
2
3use async_cuda_core::Stream;
4
5use crate::ffi::result;
6
/// NPP stream context structure.
///
/// Owns a heap-allocated C++ `NppStreamContext` (allocated by [`Context::from_stream`] or
/// [`Context::from_null_stream`] and deleted when the `Context` is dropped) together with the
/// [`Stream`] whose handle was written into that context's `hStream` member.
///
/// [NPP documentation](https://docs.nvidia.com/cuda/npp/struct_npp_stream_context.html)
pub struct Context {
    /// Owning pointer to the C++ `NppStreamContext`; set only when construction succeeded, so it
    /// is non-null for every constructed `Context`.
    raw: *mut std::ffi::c_void,
    /// Stream associated with this context. The `NppStreamContext` behind `raw` stores this
    /// stream's handle, so the stream must stay alive for as long as `raw` does.
    ///
    /// NOTE(review): because this field is `pub`, a caller can replace it (e.g. with
    /// `std::mem::replace`) and drop the original stream while `raw` still refers to its handle.
    /// Consider making the field private and exposing a getter instead.
    pub stream: Stream,
}
14
/// Implements [`Send`] for [`Context`].
///
/// # Safety
///
/// The `NppStreamContext` pointer is exclusively owned by its `Context`. It is allocated in the
/// constructors and exposed only as a read-only pointer through `as_ptr`. It is deleted exactly
/// once in `Drop`. Moving a `Context` to another thread therefore moves that ownership with it.
///
/// NOTE(review): this also assumes the contained `Stream` may be sent across threads — confirm
/// against `async_cuda_core`.
unsafe impl Send for Context {}
21
/// Implements [`Sync`] for [`Context`].
///
/// # Safety
///
/// Through a shared reference the context can only be read. `as_ptr` returns a `*const` pointer,
/// and nothing in this file writes to the `NppStreamContext` after it has been constructed.
///
/// NOTE(review): this relies on NPP functions treating the `NppStreamContext` they receive as
/// read-only when called concurrently from several threads. It also relies on `Stream` being safe
/// to share. Confirm both against the NPP and `async_cuda_core` documentation.
unsafe impl Sync for Context {}
28
29impl Context {
30 /// Create context on null stream.
31 ///
32 /// This creates a context that can be passed to NPP functions. Any functions using this context
33 /// will be executed on the null stream.
34 pub fn from_null_stream() -> Self {
35 let mut raw = std::ptr::null_mut();
36 let raw_ptr = std::ptr::addr_of_mut!(raw);
37 // SAFETY:
38 // * Must call this function on runtime since `nppGetStreamContext` needs the correct thread
39 // locals to determine current device and other context settings.
40 // * We can store a reference to the stream in `NppStreamContext` as long as we make sure
41 // `NppStreamContext` cannot outlive the stream, which we can guarantee because we take
42 // ownership of the stream.
43 let ret = cpp!(unsafe [
44 raw_ptr as "void**"
45 ] -> i32 as "std::int32_t" {
46 NppStreamContext* stream_context = new NppStreamContext();
47 NppStatus ret = nppGetStreamContext(stream_context);
48 if (ret == NPP_SUCCESS) {
49 stream_context->hStream = nullptr;
50 *raw_ptr = (void*) stream_context;
51 }
52 return ret;
53 });
54 match result!(ret) {
55 Ok(()) => Self {
56 raw,
57 stream: Stream::null(),
58 },
59 Err(err) => {
60 panic!("failed to get current NPP stream context: {err}")
61 }
62 }
63 }
64
65 /// Create context.
66 ///
67 /// This creates an NPP context object. It can be passed to NPP functions, and they will execute
68 /// on the associated stream.
69 ///
70 /// # Arguments
71 ///
72 /// * `stream` - Stream to associate with context.
73 pub fn from_stream(stream: Stream) -> Self {
74 let (ret, raw) = {
75 let mut raw = std::ptr::null_mut();
76 let raw_ptr = std::ptr::addr_of_mut!(raw);
77 let stream_ptr = stream.inner().as_internal().as_ptr();
78 // SAFETY:
79 // * Must call this function on runtime since `nppGetStreamContext` needs the correct
80 // thread locals to determine current device and other context settings.
81 // * We can store a reference to the stream in `NppStreamContext` as long as we make
82 // sure `NppStreamContext` cannot outlive the stream, which we can guarantee because
83 // we take ownership of the stream.
84 let ret = cpp!(unsafe [
85 raw_ptr as "void**",
86 stream_ptr as "void*"
87 ] -> i32 as "std::int32_t" {
88 NppStreamContext* stream_context = new NppStreamContext();
89 NppStatus ret = nppGetStreamContext(stream_context);
90 if (ret == NPP_SUCCESS) {
91 stream_context->hStream = (cudaStream_t) stream_ptr;
92 *raw_ptr = (void*) stream_context;
93 }
94 return ret;
95 });
96 (ret, raw)
97 };
98 match result!(ret) {
99 Ok(()) => Self { raw, stream },
100 Err(err) => {
101 panic!("failed to get current NPP stream context: {err}")
102 }
103 }
104 }
105
106 /// Get internal readonly pointer.
107 #[inline]
108 pub(crate) fn as_ptr(&self) -> *const std::ffi::c_void {
109 self.raw
110 }
111}
112
113impl Drop for Context {
114 fn drop(&mut self) {
115 let raw = self.raw;
116 cpp!(unsafe [raw as "void*"] {
117 delete ((NppStreamContext*) raw);
118 });
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// A context built from a freshly created stream has a live NPP context and a real handle.
    #[tokio::test]
    async fn test_from_stream() {
        let context = Context::from_stream(Stream::new().await.unwrap());

        let context_ptr = context.as_ptr();
        assert!(!context_ptr.is_null());

        let stream_handle = context.stream.inner().as_internal().as_ptr();
        assert!(!stream_handle.is_null());
    }

    /// A null-stream context still owns an NPP context, but its stream handle is null.
    #[test]
    fn test_from_null_stream() {
        let context = Context::from_null_stream();

        let context_ptr = context.as_ptr();
        assert!(!context_ptr.is_null());

        let stream_handle = context.stream.inner().as_internal().as_ptr();
        assert!(stream_handle.is_null());
    }
}