use cpp::cpp;
use crate::ffi::device::Device;
use crate::ffi::npp::result;
use crate::stream::Stream;
/// Owning handle to a heap-allocated NPP `NppStreamContext`, kept together with the
/// [`Stream`] whose handle was stored into it.
///
/// The underlying context is freed when the `Context` is dropped.
pub struct Context {
    /// Type-erased `NppStreamContext*`; set to null once the context has been deleted.
    raw: *mut std::ffi::c_void,
    /// Stream whose handle was written into the context's `hStream`
    /// (`Stream::null()` for contexts built by `Context::from_null_stream`).
    pub stream: Stream,
}
// SAFETY: `raw` is an owned heap pointer that is not tied to the creating thread. On the
// Rust side it is only freed in `delete`, which takes `&mut self`.
// NOTE(review): soundness presumes `Stream` is itself safe to send across threads — confirm.
unsafe impl Send for Context {}
// SAFETY: shared access only exposes the pointer read-only via `as_ptr` (`*const`).
// NOTE(review): this presumes NPP call sites never mutate the context through that
// pointer, and that `Stream` is safe to share — confirm at call sites.
unsafe impl Sync for Context {}
impl Context {
    /// Creates a context for the null (default) CUDA stream.
    ///
    /// A fresh `NppStreamContext` is filled by `nppGetStreamContext`, and its `hStream` is
    /// then overwritten with `nullptr`. The returned `Context` owns the allocation and
    /// frees it on drop.
    ///
    /// # Panics
    ///
    /// Panics if `nppGetStreamContext` reports an error.
    pub fn from_null_stream() -> Self {
        let mut raw = std::ptr::null_mut();
        let raw_ptr = std::ptr::addr_of_mut!(raw);
        // On failure the C++ side frees the allocation itself: ownership is never handed
        // to Rust in that case (`raw` stays null), so nothing else could release it.
        let ret = cpp!(unsafe [
            raw_ptr as "void**"
        ] -> i32 as "std::int32_t" {
            NppStreamContext* stream_context = new NppStreamContext();
            NppStatus ret = nppGetStreamContext(stream_context);
            if (ret == NPP_SUCCESS) {
                stream_context->hStream = nullptr;
                *raw_ptr = (void*) stream_context;
            } else {
                delete stream_context;
            }
            return ret;
        });
        match result!(ret) {
            Ok(()) => Self {
                raw,
                stream: Stream::null(),
            },
            Err(err) => {
                panic!("failed to get current NPP stream context: {err}")
            }
        }
    }

    /// Creates a context bound to `stream`, taking ownership of the stream.
    ///
    /// A fresh `NppStreamContext` is filled by `nppGetStreamContext`. Its `hStream` is set
    /// to the stream's handle, and its `nCudaDeviceId` is set to the stream's device.
    ///
    /// NOTE(review): the remaining fields come from `nppGetStreamContext`. They are not
    /// re-derived for `stream`'s device, and `nStreamFlags` is not refreshed from the new
    /// stream. Confirm this matters for the devices and streams used.
    ///
    /// # Panics
    ///
    /// Panics if `nppGetStreamContext` reports an error.
    pub fn from_stream(stream: Stream) -> Self {
        let mut raw = std::ptr::null_mut();
        let raw_ptr = std::ptr::addr_of_mut!(raw);
        let stream_ptr = stream.inner().as_internal().as_ptr();
        let device_id = stream.inner().device();
        // On failure the C++ side frees the allocation itself: ownership is never handed
        // to Rust in that case (`raw` stays null), so nothing else could release it.
        let ret = cpp!(unsafe [
            raw_ptr as "void**",
            stream_ptr as "void*",
            device_id as "int"
        ] -> i32 as "std::int32_t" {
            NppStreamContext* stream_context = new NppStreamContext();
            NppStatus ret = nppGetStreamContext(stream_context);
            if (ret == NPP_SUCCESS) {
                stream_context->hStream = (cudaStream_t) stream_ptr;
                stream_context->nCudaDeviceId = device_id;
                *raw_ptr = (void*) stream_context;
            } else {
                delete stream_context;
            }
            return ret;
        });
        match result!(ret) {
            Ok(()) => Self { raw, stream },
            Err(err) => {
                panic!("failed to get current NPP stream context: {err}")
            }
        }
    }

    /// Returns the type-erased `NppStreamContext*` owned by this handle.
    ///
    /// Returns null after [`Context::delete`] has been called.
    #[inline]
    pub(crate) fn as_ptr(&self) -> *const std::ffi::c_void {
        self.raw
    }

    /// Frees the owned `NppStreamContext`, if it has not been freed already.
    ///
    /// The internal pointer is nulled before freeing, so repeated calls (including the
    /// one made by `Drop`) are no-ops. Before freeing, this calls `Device::set_or_panic`
    /// with the stream's device.
    ///
    /// # Panics
    ///
    /// Panics if `Device::set_or_panic` panics. Presumably that happens when the device
    /// cannot be selected — confirm.
    ///
    /// # Safety
    ///
    /// Any pointer previously obtained from [`Context::as_ptr`] dangles after this call.
    /// The caller must ensure no such pointer is used afterwards.
    pub unsafe fn delete(&mut self) {
        if self.raw.is_null() {
            return;
        }
        Device::set_or_panic(self.stream.inner().device());
        // Null the field first so a later call (e.g. from Drop) cannot double-free.
        let raw = std::mem::replace(&mut self.raw, std::ptr::null_mut());
        cpp!(unsafe [raw as "void*"] {
            delete ((NppStreamContext*) raw);
        });
    }
}
impl Drop for Context {
    /// Releases the owned NPP stream context through `Context::delete`.
    #[inline]
    fn drop(&mut self) {
        // SAFETY: the handle is being destroyed, so nothing can use its context
        // pointer after this point.
        unsafe { Self::delete(self) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A context built from a freshly created stream carries non-null context and
    /// stream handles.
    #[tokio::test]
    async fn test_from_stream() {
        let ctx = Context::from_stream(Stream::new().await.unwrap());
        let context_handle = ctx.as_ptr();
        assert!(!context_handle.is_null());
        let stream_handle = ctx.stream.inner().as_internal().as_ptr();
        assert!(!stream_handle.is_null());
    }

    /// A context built for the null stream has a non-null context handle but a null
    /// stream handle.
    #[test]
    fn test_from_null_stream() {
        let ctx = Context::from_null_stream();
        let context_handle = ctx.as_ptr();
        assert!(!context_handle.is_null());
        let stream_handle = ctx.stream.inner().as_internal().as_ptr();
        assert!(stream_handle.is_null());
    }
}