use std::{mem, ptr};
use singe_cuda_sys::{driver, runtime};
use crate::{
error::{Error, Result},
try_cuda,
};
bitflags::bitflags! {
    /// Flags for opening an IPC memory handle.
    ///
    /// The raw bits (`flags.bits()`) are passed as the `flags` argument of
    /// `cudaIpcOpenMemHandle` by [`IpcMemoryHandle::open`].
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct IpcMemoryFlags: u32 {
        /// Same bit value as the driver constant `CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS`.
        ///
        /// Per the CUDA documentation this asks the runtime to enable peer access to the
        /// exporting device on demand; consult the CUDA IPC docs for the exact semantics.
        const LAZY_ENABLE_PEER_ACCESS = driver::CUipcMem_flags::CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS as _;
    }
}
/// Opaque CUDA IPC event handle: a thin, copyable wrapper around `cudaIpcEventHandle_t`.
///
/// The wrapper performs no validation; it only carries the raw handle value around.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IpcEventHandle(runtime::cudaIpcEventHandle_t);

impl IpcEventHandle {
    /// Wraps a raw runtime event handle.
    ///
    /// # Safety
    ///
    /// `handle` is stored as-is, without any validation. The caller is responsible for it
    /// being a handle the CUDA runtime will accept wherever it is later used.
    pub const unsafe fn from_raw(handle: runtime::cudaIpcEventHandle_t) -> Self {
        IpcEventHandle(handle)
    }

    /// Returns a handle whose bytes are all zero.
    ///
    /// Presumably this does not name any real event; treat it as a placeholder value.
    pub const fn zeroed() -> Self {
        // SAFETY: the runtime handle type is a plain C struct (an opaque byte array in the
        // CUDA headers), so the all-zero bit pattern is a valid value for it.
        Self(unsafe { mem::zeroed() })
    }

    /// Returns a copy of the wrapped raw handle.
    ///
    /// # Safety
    ///
    /// The body is only a copy of the inner value. NOTE(review): the caller obligation that
    /// justifies this being `unsafe` is not visible here — confirm and document it.
    pub const unsafe fn as_raw(&self) -> runtime::cudaIpcEventHandle_t {
        let Self(raw) = *self;
        raw
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IpcMemoryHandle(runtime::cudaIpcMemHandle_t);
impl IpcMemoryHandle {
pub const unsafe fn from_raw(handle: runtime::cudaIpcMemHandle_t) -> Self {
Self(handle)
}
pub const fn zeroed() -> Self {
unsafe { mem::zeroed() }
}
pub const unsafe fn as_raw(&self) -> runtime::cudaIpcMemHandle_t {
self.0
}
pub fn open<T>(self, flags: IpcMemoryFlags) -> Result<OpenedIpcMemory<T>> {
let mut dev_ptr = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaIpcOpenMemHandle(
&raw mut dev_ptr,
self.as_raw(),
flags.bits()
))?;
}
if dev_ptr.is_null() {
return Err(Error::NullHandle);
}
Ok(OpenedIpcMemory {
ptr: dev_ptr.cast(),
})
}
}
/// Memory mapped into this process by [`IpcMemoryHandle::open`].
///
/// Holds the pointer produced by `cudaIpcOpenMemHandle`. The mapping is released with
/// `cudaIpcCloseMemHandle` either explicitly through `close` (which returns any error) or
/// implicitly on drop (which cannot return errors). Because it stores a raw pointer, the
/// type does not auto-implement `Send` or `Sync`.
#[derive(Debug)]
pub struct OpenedIpcMemory<T> {
    // Mapped pointer cast to `*mut T`; `open` never constructs this with a null pointer.
    ptr: *mut T,
}
impl<T> OpenedIpcMemory<T> {
    /// Returns the mapped pointer.
    ///
    /// Ownership stays with `self`: the mapping is closed when `self` is dropped or
    /// [`close`](Self::close)d, after which the pointer must not be used.
    pub const fn as_ptr(&self) -> *mut T {
        let Self { ptr } = self;
        *ptr
    }

    /// Unmaps the memory via `cudaIpcCloseMemHandle`, returning any error.
    ///
    /// Unlike dropping, this reports failures to the caller. A null pointer is treated as
    /// already closed and yields `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `try_cuda!` if `cudaIpcCloseMemHandle` fails.
    pub fn close(self) -> Result<()> {
        // Disarm `Drop` so the mapping is released exactly once — here, not again on drop.
        let this = mem::ManuallyDrop::new(self);
        let raw: *mut () = this.ptr.cast();
        if raw.is_null() {
            Ok(())
        } else {
            // SAFETY: `raw` is the non-null pointer obtained from `cudaIpcOpenMemHandle`, and
            // `Drop` has been suppressed above, so it is closed only once.
            unsafe { try_cuda!(runtime::cudaIpcCloseMemHandle(raw.cast())) }
        }
    }
}
impl<T> Drop for OpenedIpcMemory<T> {
    /// Unmaps the memory via `cudaIpcCloseMemHandle`.
    ///
    /// A destructor cannot return errors, so a failed close is only reported on stderr in
    /// debug builds and is silently ignored in release builds. Use `close` to observe it.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        // SAFETY: `self.ptr` is the non-null pointer obtained from `cudaIpcOpenMemHandle`;
        // `close` suppresses `Drop`, so this is the only place it gets closed.
        let result = unsafe { try_cuda!(runtime::cudaIpcCloseMemHandle(self.ptr.cast())) };
        if let Err(err) = result {
            // `cfg!` (instead of `#[cfg]` on the print statement) keeps `err` used in release
            // builds, which otherwise hit an `unused_variables` warning.
            if cfg!(debug_assertions) {
                // Write directly and ignore failures: `eprintln!` panics when stderr is
                // unwritable, and a panic in `drop` during unwinding aborts the process.
                use std::io::Write as _;
                let _ = writeln!(
                    std::io::stderr(),
                    "failed to close cuda ipc memory handle: {err}"
                );
            }
        }
    }
}