#![cfg(feature = "cuda-ipc")]
use cudarc::driver::sys as driver_sys;
use crate::error::GpuError;
use crate::sys::cuda_driver;
/// An opaque CUDA IPC memory handle.
///
/// Wraps the driver's 64-byte `CUipcMemHandle`. Use [`IpcMemHandle::as_bytes`] and
/// [`IpcMemHandle::from_bytes`] to move it across any byte transport (for example, to
/// another process) and rebuild it on the other side.
#[derive(Copy, Clone)]
pub struct IpcMemHandle {
    /// The raw driver handle; crate-private.
    pub(crate) raw: driver_sys::CUipcMemHandle,
}
impl IpcMemHandle {
    /// Returns the 64 raw handle bytes, suitable for any byte transport.
    ///
    /// Each `c_char` is reinterpreted bit-for-bit as a `u8`. The result therefore
    /// round-trips exactly through [`IpcMemHandle::from_bytes`], whether `c_char` is
    /// signed or unsigned on the target platform.
    pub fn as_bytes(&self) -> [u8; 64] {
        // `as` between same-width integers is a lossless bit reinterpretation, so no
        // `unsafe` transmute is needed.
        self.raw.reserved.map(|c: std::ffi::c_char| c as u8)
    }

    /// Rebuilds a handle from bytes previously produced by [`IpcMemHandle::as_bytes`].
    ///
    /// The bytes are stored as-is; no validation is performed here.
    pub fn from_bytes(bytes: [u8; 64]) -> Self {
        Self {
            raw: driver_sys::CUipcMemHandle_st {
                reserved: bytes.map(|b| b as std::ffi::c_char),
            },
        }
    }
}
impl std::fmt::Debug for IpcMemHandle {
    /// Writes only the type name; the handle bytes are never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("IpcMemHandle")
    }
}
unsafe impl Send for IpcMemHandle {}
unsafe impl Sync for IpcMemHandle {}
/// An IPC allocation opened in this process by [`open_mem_handle`].
///
/// When its device pointer is non-zero, dropping it closes the mapping through
/// `cuda_driver::ipc_close_mem_handle`.
#[derive(Debug)]
pub struct OpenedMem {
    /// Device address returned by the driver when the handle was opened.
    dev_ptr: driver_sys::CUdeviceptr,
    /// Caller-supplied size in bytes; recorded as given, not validated.
    bytes: usize,
}
impl OpenedMem {
    /// Size in bytes that the caller passed to [`open_mem_handle`].
    ///
    /// The value is recorded verbatim; it is not checked against the exporting
    /// allocation.
    pub fn bytes(&self) -> usize {
        self.bytes
    }

    /// Device address of the opened mapping, as returned by the driver.
    pub fn dev_ptr(&self) -> driver_sys::CUdeviceptr {
        self.dev_ptr
    }
}
impl Drop for OpenedMem {
    /// Closes the IPC mapping on a best-effort basis.
    ///
    /// A zero device pointer means there is nothing to close. Any error from the
    /// driver is discarded, because `drop` has no way to report it.
    fn drop(&mut self) {
        if self.dev_ptr == 0 {
            return;
        }
        // NOTE(review): the close presumably needs the CUDA context the handle was
        // opened in; confirm how `cuda_driver` handles a drop on another thread.
        let _ = cuda_driver::ipc_close_mem_handle(self.dev_ptr);
    }
}
unsafe impl Send for OpenedMem {}
unsafe impl Sync for OpenedMem {}
/// Obtains an [`IpcMemHandle`] for the device allocation at `dev_ptr` from the CUDA
/// driver.
///
/// # Errors
///
/// Propagates any error reported by `cuda_driver::ipc_get_mem_handle`.
pub fn get_mem_handle(dev_ptr: driver_sys::CUdeviceptr) -> Result<IpcMemHandle, GpuError> {
    let raw = cuda_driver::ipc_get_mem_handle(dev_ptr)?;
    Ok(IpcMemHandle { raw })
}
/// Opens an [`IpcMemHandle`] and maps the referenced allocation into this process.
///
/// `bytes` is stored verbatim in the returned [`OpenedMem`] for the caller's
/// bookkeeping; it is not validated. The mapping is closed when the returned value
/// is dropped.
///
/// # Errors
///
/// Propagates any error reported by `cuda_driver::ipc_open_mem_handle_v2` (for
/// example, when no CUDA driver is available).
pub fn open_mem_handle(handle: IpcMemHandle, bytes: usize) -> Result<OpenedMem, GpuError> {
    // `1` is the CUDA driver's `CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS`, the flag that
    // `cuIpcOpenMemHandle` requires.
    let flags = 1;
    let dev_ptr = cuda_driver::ipc_open_mem_handle_v2(handle.raw, flags)?;
    Ok(OpenedMem { dev_ptr, bytes })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Arbitrary bytes survive a `from_bytes` -> `as_bytes` round trip unchanged.
    #[test]
    fn handle_round_trip() {
        let bytes: [u8; 64] = std::array::from_fn(|i| (i * 3) as u8 ^ 0x55);
        let h = IpcMemHandle::from_bytes(bytes);
        assert_eq!(h.as_bytes(), bytes);
    }

    /// Bytes with the high bit set must also round-trip. `c_char` is signed on some
    /// targets, so this exercises the `u8` <-> `c_char` reinterpretation.
    #[test]
    fn handle_round_trip_high_bytes() {
        let ones = [0xFFu8; 64];
        assert_eq!(IpcMemHandle::from_bytes(ones).as_bytes(), ones);
        let high: [u8; 64] = std::array::from_fn(|i| 0x80 + i as u8);
        assert_eq!(IpcMemHandle::from_bytes(high).as_bytes(), high);
    }

    /// Both public types can be moved and shared across threads.
    #[test]
    fn types_are_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<IpcMemHandle>();
        assert_send_sync::<OpenedMem>();
    }

    /// `Debug` prints only the type name, never the handle bytes.
    #[test]
    fn debug_hides_contents() {
        let h = IpcMemHandle::from_bytes([0xAB; 64]);
        assert_eq!(format!("{h:?}"), "IpcMemHandle");
    }

    /// Opening an all-zero handle must not panic and must surface a typed error.
    /// If a driver happens to accept it, dropping the `OpenedMem` closes it again.
    #[test]
    fn open_returns_typed_error_on_no_driver() {
        let h = IpcMemHandle::from_bytes([0u8; 64]);
        let r = open_mem_handle(h, 0);
        match r {
            Ok(_) => {}
            Err(GpuError::Unrecoverable(_)) => {}
            Err(GpuError::LibraryError { .. }) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }
}