use std::fmt;
use std::ops::{BitAnd, BitOr, Deref, DerefMut};
use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::ffi::{
CU_MEMHOSTREGISTER_DEVICEMAP, CU_MEMHOSTREGISTER_IOMEMORY, CU_MEMHOSTREGISTER_PORTABLE,
CU_MEMHOSTREGISTER_READ_ONLY, CUdeviceptr,
};
#[cfg(not(target_os = "macos"))]
use oxicuda_driver::ffi;
#[cfg(not(target_os = "macos"))]
use oxicuda_driver::loader::try_driver;
#[cfg(not(target_os = "macos"))]
use std::ffi::c_void;
/// Typed bit set over the driver's `CU_MEMHOSTREGISTER_*` flags.
///
/// Values combine with `|`, intersect with `&`, and membership is tested with
/// [`RegisterFlags::contains`]. [`RegisterFlags::from_bits`] accepts any raw
/// value, including bits that have no named constant here.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct RegisterFlags(u32);

impl RegisterFlags {
    /// Wraps the driver's `CU_MEMHOSTREGISTER_PORTABLE` flag.
    pub const PORTABLE: Self = Self(CU_MEMHOSTREGISTER_PORTABLE);
    /// Wraps the driver's `CU_MEMHOSTREGISTER_DEVICEMAP` flag.
    pub const DEVICE_MAP: Self = Self(CU_MEMHOSTREGISTER_DEVICEMAP);
    /// Wraps the driver's `CU_MEMHOSTREGISTER_IOMEMORY` flag.
    pub const IO_MEMORY: Self = Self(CU_MEMHOSTREGISTER_IOMEMORY);
    /// Wraps the driver's `CU_MEMHOSTREGISTER_READ_ONLY` flag.
    pub const READ_ONLY: Self = Self(CU_MEMHOSTREGISTER_READ_ONLY);
    /// `PORTABLE | DEVICE_MAP`.
    pub const DEFAULT: Self = Self(Self::PORTABLE.0 | Self::DEVICE_MAP.0);
    /// The empty set (raw value `0`).
    pub const NONE: Self = Self(0);

    /// Returns the raw flag word passed to the driver.
    #[inline]
    pub const fn bits(self) -> u32 {
        let Self(raw) = self;
        raw
    }

    /// Wraps a raw flag word without validating it.
    #[inline]
    pub const fn from_bits(bits: u32) -> Self {
        Self(bits)
    }

    /// `true` when every bit of `other` is also set in `self`.
    ///
    /// An empty `other` is contained in every set.
    #[inline]
    pub const fn contains(self, other: Self) -> bool {
        // No bit of `other` may fall outside `self`.
        (other.0 & !self.0) == 0
    }
}
impl BitOr for RegisterFlags {
type Output = Self;
#[inline]
fn bitor(self, rhs: Self) -> Self {
Self(self.0 | rhs.0)
}
}
impl BitAnd for RegisterFlags {
type Output = Self;
#[inline]
fn bitand(self, rhs: Self) -> Self {
Self(self.0 & rhs.0)
}
}
impl fmt::Display for RegisterFlags {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut parts = Vec::new();
if self.contains(Self::PORTABLE) {
parts.push("PORTABLE");
}
if self.contains(Self::DEVICE_MAP) {
parts.push("DEVICE_MAP");
}
if self.contains(Self::IO_MEMORY) {
parts.push("IO_MEMORY");
}
if self.contains(Self::READ_ONLY) {
parts.push("READ_ONLY");
}
if parts.is_empty() {
write!(f, "NONE")
} else {
write!(f, "{}", parts.join(" | "))
}
}
}
/// A host buffer registered (page-locked) with the CUDA driver by [`register`].
///
/// The value does not own the allocation. It records the base pointer, the
/// element count, the flags used, and the device-side address obtained at
/// registration time. On non-macOS targets, dropping it calls
/// `cuMemHostUnregister` on the base pointer when the driver can be loaded.
///
/// NOTE(review): the buffer's lifetime is not tracked by the type system. The
/// caller must keep the underlying memory alive and unmoved for as long as this
/// value exists.
#[derive(Debug)]
pub struct RegisteredMemory<T: Copy> {
    /// Base host address handed to the driver.
    ptr: *mut T,
    /// Number of `T` elements covered by the registration.
    len: usize,
    /// Flags passed to `cuMemHostRegister`.
    flags: RegisterFlags,
    /// Device alias recorded at registration time (`0` when none was queried).
    device_ptr: CUdeviceptr,
}

// SAFETY: access to the `[T]` elements goes through `&self` / `&mut self`, like an
// exclusive borrow of the buffer, so moving the handle to another thread is
// sound when `T: Send`.
unsafe impl<T: Copy + Send> Send for RegisteredMemory<T> {}
// SAFETY: a shared `&RegisteredMemory<T>` only yields `&[T]` and raw/const data,
// so sharing it across threads is sound when `T: Sync`.
unsafe impl<T: Copy + Sync> Sync for RegisteredMemory<T> {}
impl<T: Copy> RegisteredMemory<T> {
    /// Base host pointer, viewed as read-only.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.ptr.cast_const()
    }

    /// Base host pointer, viewed as writable.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr
    }

    /// Device-side address recorded when the buffer was registered.
    #[inline]
    pub fn device_ptr(&self) -> CUdeviceptr {
        self.device_ptr
    }

    /// Number of `T` elements in the registered range.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the range holds no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Flags the buffer was registered with.
    #[inline]
    pub fn flags(&self) -> RegisterFlags {
        self.flags
    }

    /// Shared slice over the registered elements.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        // SAFETY: values are built by `register`, which rejects null pointers and
        // zero lengths. The caller of `register` promised `ptr` is valid for `len`
        // elements while this value lives (NOTE(review): not compiler-enforced).
        unsafe { std::slice::from_raw_parts(self.as_ptr(), self.len()) }
    }

    /// Mutable slice over the registered elements.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        let (base, count) = (self.ptr, self.len);
        // SAFETY: same invariants as `as_slice`; `&mut self` guarantees this is the
        // only slice handed out through this value.
        unsafe { std::slice::from_raw_parts_mut(base, count) }
    }
}
impl<T: Copy> Deref for RegisteredMemory<T> {
    type Target = [T];

    /// Lets a registered buffer be used wherever `&[T]` is expected.
    #[inline]
    fn deref(&self) -> &Self::Target {
        Self::as_slice(self)
    }
}

impl<T: Copy> DerefMut for RegisteredMemory<T> {
    /// Lets a registered buffer be used wherever `&mut [T]` is expected.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        Self::as_mut_slice(self)
    }
}
impl<T: Copy> Drop for RegisteredMemory<T> {
    /// Best-effort `cuMemHostUnregister` of the base pointer (non-macOS only).
    ///
    /// `drop` cannot return errors. If the driver cannot be loaded nothing
    /// happens; a failing unregister call is reported via `tracing::warn!`.
    fn drop(&mut self) {
        #[cfg(not(target_os = "macos"))]
        {
            let Ok(api) = try_driver() else {
                return;
            };
            // SAFETY: `self.ptr` is the base address `register` passed to
            // cuMemHostRegister; it has not been unregistered since.
            let status = unsafe { (api.cu_mem_host_unregister)(self.ptr.cast::<c_void>()) };
            if status == 0 {
                return;
            }
            tracing::warn!(
                cuda_error = status,
                len = self.len,
                "cuMemHostUnregister failed during RegisteredMemory drop"
            );
        }
    }
}
/// Page-locks `len` elements of `T` starting at `ptr` via `cuMemHostRegister`.
///
/// When `flags` contains [`RegisterFlags::DEVICE_MAP`], the device alias is
/// fetched with `cuMemHostGetDevicePointer` and exposed through
/// [`RegisteredMemory::device_ptr`]; otherwise it is `0`. On macOS no driver call
/// is made and the host address itself is recorded as the device pointer.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] in any of these cases:
///   * `len` is zero or `ptr` is null.
///   * `T` is zero-sized.
///   * `len * size_of::<T>()` overflows `usize`.
/// * Any error from loading the driver, from registration, or from the
///   device-pointer lookup. If the lookup fails after registration succeeded,
///   the registration is rolled back (unregistered) before the error is returned.
///
/// NOTE(review): this safe function accepts an arbitrary raw pointer, and the
/// result hands out `&[T]` / `&mut [T]` over it. Callers must ensure `ptr` is
/// valid for `len` elements and outlives the returned value; nothing enforces it.
pub fn register<T: Copy>(
    ptr: *mut T,
    len: usize,
    flags: RegisterFlags,
) -> CudaResult<RegisteredMemory<T>> {
    if len == 0 || ptr.is_null() {
        return Err(CudaError::InvalidValue);
    }
    let byte_size = len
        .checked_mul(std::mem::size_of::<T>())
        .ok_or(CudaError::InvalidValue)?;
    // Zero-sized element types would describe an empty byte range; reject them on
    // every platform instead of depending on how the driver treats size 0.
    if byte_size == 0 {
        return Err(CudaError::InvalidValue);
    }
    #[cfg(target_os = "macos")]
    {
        Ok(RegisteredMemory {
            ptr,
            len,
            flags,
            device_ptr: ptr as CUdeviceptr,
        })
    }
    #[cfg(not(target_os = "macos"))]
    {
        let api = try_driver()?;
        // SAFETY: `ptr` is non-null and, per this function's contract, valid for
        // `byte_size` bytes.
        let rc =
            unsafe { (api.cu_mem_host_register_v2)(ptr.cast::<c_void>(), byte_size, flags.bits()) };
        oxicuda_driver::check(rc)?;

        // Wrap the live registration immediately. If anything below fails, dropping
        // `registered` unregisters the range instead of leaking the pinned pages.
        let mut registered = RegisteredMemory {
            ptr,
            len,
            flags,
            device_ptr: 0,
        };
        if flags.contains(RegisterFlags::DEVICE_MAP) {
            let mut dptr: CUdeviceptr = 0;
            // SAFETY: `dptr` is a live local; `ptr` was registered just above. The
            // trailing flags argument is reserved and must be 0.
            let rc = unsafe {
                (api.cu_mem_host_get_device_pointer_v2)(&mut dptr, ptr.cast::<c_void>(), 0)
            };
            oxicuda_driver::check(rc)?;
            registered.device_ptr = dptr;
        }
        Ok(registered)
    }
}
/// Registers the memory behind `slice`; see [`register`] for flags and errors.
///
/// NOTE(review): the returned [`RegisteredMemory`] does not borrow `slice`, so the
/// compiler will not stop the buffer from being freed or moved while it is still
/// registered. Callers must keep it alive.
pub fn register_slice<T: Copy>(
    slice: &mut [T],
    flags: RegisterFlags,
) -> CudaResult<RegisteredMemory<T>> {
    let (base, count) = (slice.as_mut_ptr(), slice.len());
    register(base, count, flags)
}
pub fn register_vec<T: Copy>(
vec: &mut Vec<T>,
flags: RegisterFlags,
) -> CudaResult<RegisteredMemory<T>> {
register(vec.as_mut_ptr(), vec.len(), flags)
}
/// Memory kind reported for a pointer by [`query_registered_pointer_info`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RegisteredMemoryType {
    /// The driver reported `CU_MEMORYTYPE_HOST`. This is also the fixed answer
    /// on macOS, where no driver is queried.
    Host,
    /// The driver reported `CU_MEMORYTYPE_DEVICE`.
    Device,
    /// The driver reported `CU_MEMORYTYPE_UNIFIED`.
    Unified,
    /// The memory-type query failed, or returned a value not listed above.
    Unregistered,
}
impl fmt::Display for RegisteredMemoryType {
    /// Writes the variant name.
    ///
    /// Goes through `Formatter::pad` so width/fill/alignment flags work
    /// (e.g. `{:>12}`), as they do for `str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Host => "Host",
            Self::Device => "Device",
            Self::Unified => "Unified",
            Self::Unregistered => "Unregistered",
        };
        f.pad(name)
    }
}
/// What [`query_registered_pointer_info`] learned about a pointer.
///
/// Plain comparable data, so it derives `PartialEq`/`Eq` for use in
/// assertions and lookups.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RegisteredPointerInfo {
    /// Device-side address for the pointer. On non-macOS this is `0` when the
    /// driver reported none; on macOS it is the host address itself.
    pub device_ptr: CUdeviceptr,
    /// Whether the driver reported the pointer as managed memory.
    pub is_managed: bool,
    /// Memory kind reported by the driver.
    pub memory_type: RegisteredMemoryType,
}
/// Asks the CUDA driver what it knows about `ptr`.
///
/// Issues three `cuPointerGetAttribute` queries in order: memory type, managed
/// flag, and device pointer. A query the driver rejects falls back to a neutral
/// value (`Unregistered`, `false`, `0`) rather than failing the whole call.
/// On macOS no driver is consulted: the pointer is reported as unmanaged host
/// memory whose device address equals the host address.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] if `ptr` is null.
/// * On non-macOS, the error from `try_driver` if the driver cannot be loaded.
pub fn query_registered_pointer_info(ptr: *const u8) -> CudaResult<RegisteredPointerInfo> {
    if ptr.is_null() {
        return Err(CudaError::InvalidValue);
    }
    #[cfg(target_os = "macos")]
    {
        Ok(RegisteredPointerInfo {
            device_ptr: ptr as CUdeviceptr,
            is_managed: false,
            memory_type: RegisteredMemoryType::Host,
        })
    }
    #[cfg(not(target_os = "macos"))]
    {
        let api = try_driver()?;
        let address = ptr as CUdeviceptr;
        // One cuPointerGetAttribute call for `address`, writing into `out`; returns
        // the raw driver status code.
        let get_attribute = |out: *mut c_void, attribute| {
            // SAFETY: every call site below passes `out` pointing at a live local
            // sized for the requested attribute (u32 or CUdeviceptr).
            unsafe { (api.cu_pointer_get_attribute)(out, attribute, address) }
        };

        let mut raw_kind: u32 = 0;
        let kind_status = get_attribute(
            (&mut raw_kind as *mut u32).cast::<c_void>(),
            ffi::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
        );
        let memory_type = if kind_status == 0 {
            match raw_kind {
                ffi::CU_MEMORYTYPE_HOST => RegisteredMemoryType::Host,
                ffi::CU_MEMORYTYPE_DEVICE => RegisteredMemoryType::Device,
                ffi::CU_MEMORYTYPE_UNIFIED => RegisteredMemoryType::Unified,
                _ => RegisteredMemoryType::Unregistered,
            }
        } else {
            RegisteredMemoryType::Unregistered
        };

        // NOTE(review): the driver's output width for IS_MANAGED is assumed to fit
        // in a zero-initialised u32 — confirm against the driver headers.
        let mut managed_flag: u32 = 0;
        let managed_status = get_attribute(
            (&mut managed_flag as *mut u32).cast::<c_void>(),
            ffi::CU_POINTER_ATTRIBUTE_IS_MANAGED,
        );
        let is_managed = managed_status == 0 && managed_flag != 0;

        let mut mapped: CUdeviceptr = 0;
        let mapped_status = get_attribute(
            (&mut mapped as *mut CUdeviceptr).cast::<c_void>(),
            ffi::CU_POINTER_ATTRIBUTE_DEVICE_POINTER,
        );
        let device_ptr = if mapped_status == 0 { mapped } else { 0 };

        Ok(RegisteredPointerInfo {
            device_ptr,
            is_managed,
            memory_type,
        })
    }
}
// Unit tests for the registration helpers.
//
// The top-level tests only exercise flag/enum logic and argument validation.
// Validation returns before any driver call, so these run on every platform.
// `macos_tests` cover the driver-free macOS paths. `gpu_tests` need a real
// device and are compiled only with the `gpu-tests` feature.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- RegisterFlags ----

    #[test]
    fn flags_default_contains_portable_and_device_map() {
        assert!(RegisterFlags::DEFAULT.contains(RegisterFlags::PORTABLE));
        assert!(RegisterFlags::DEFAULT.contains(RegisterFlags::DEVICE_MAP));
        assert!(!RegisterFlags::DEFAULT.contains(RegisterFlags::IO_MEMORY));
        assert!(!RegisterFlags::DEFAULT.contains(RegisterFlags::READ_ONLY));
    }

    #[test]
    fn flags_bitor_combines() {
        let combined = RegisterFlags::PORTABLE | RegisterFlags::READ_ONLY;
        assert!(combined.contains(RegisterFlags::PORTABLE));
        assert!(combined.contains(RegisterFlags::READ_ONLY));
        assert!(!combined.contains(RegisterFlags::IO_MEMORY));
    }

    #[test]
    fn flags_bitand_intersects() {
        let a = RegisterFlags::PORTABLE | RegisterFlags::DEVICE_MAP;
        let b = RegisterFlags::PORTABLE | RegisterFlags::READ_ONLY;
        // PORTABLE is the only flag common to both operands.
        let intersected = a & b;
        assert!(intersected.contains(RegisterFlags::PORTABLE));
        assert!(!intersected.contains(RegisterFlags::DEVICE_MAP));
        assert!(!intersected.contains(RegisterFlags::READ_ONLY));
    }

    #[test]
    fn flags_display() {
        assert_eq!(RegisterFlags::NONE.to_string(), "NONE");
        assert_eq!(RegisterFlags::PORTABLE.to_string(), "PORTABLE");
        // Multi-flag output is checked for membership only, not ordering.
        let default_str = RegisterFlags::DEFAULT.to_string();
        assert!(default_str.contains("PORTABLE"));
        assert!(default_str.contains("DEVICE_MAP"));
    }

    #[test]
    fn flags_bits_roundtrip() {
        let flags = RegisterFlags::PORTABLE | RegisterFlags::IO_MEMORY;
        let bits = flags.bits();
        assert_eq!(RegisterFlags::from_bits(bits), flags);
    }

    #[test]
    fn flags_none_is_zero() {
        assert_eq!(RegisterFlags::NONE.bits(), 0);
    }

    // ---- RegisteredMemoryType ----

    #[test]
    fn memory_type_display() {
        assert_eq!(RegisteredMemoryType::Host.to_string(), "Host");
        assert_eq!(RegisteredMemoryType::Device.to_string(), "Device");
        assert_eq!(RegisteredMemoryType::Unified.to_string(), "Unified");
        assert_eq!(
            RegisteredMemoryType::Unregistered.to_string(),
            "Unregistered"
        );
    }

    #[test]
    fn memory_type_equality() {
        assert_eq!(RegisteredMemoryType::Host, RegisteredMemoryType::Host);
        assert_ne!(RegisteredMemoryType::Host, RegisteredMemoryType::Device);
    }

    // ---- Argument validation (no driver needed) ----

    #[test]
    fn register_zero_len_fails() {
        // A valid, non-null pointer is still rejected when len is 0.
        let mut buf = [0u8; 16];
        let result = register(buf.as_mut_ptr(), 0, RegisterFlags::DEFAULT);
        assert!(matches!(result, Err(CudaError::InvalidValue)));
    }

    #[test]
    fn register_null_ptr_fails() {
        let result = register::<u8>(std::ptr::null_mut(), 10, RegisterFlags::DEFAULT);
        assert!(matches!(result, Err(CudaError::InvalidValue)));
    }

    #[test]
    fn register_slice_zero_len_fails() {
        let mut empty: [f32; 0] = [];
        let result = register_slice(&mut empty, RegisterFlags::DEFAULT);
        assert!(matches!(result, Err(CudaError::InvalidValue)));
    }

    #[test]
    fn register_vec_zero_len_fails() {
        let mut v: Vec<i32> = Vec::new();
        let result = register_vec(&mut v, RegisterFlags::DEFAULT);
        assert!(matches!(result, Err(CudaError::InvalidValue)));
    }

    #[test]
    fn query_null_ptr_fails() {
        let result = query_registered_pointer_info(std::ptr::null());
        assert!(matches!(result, Err(CudaError::InvalidValue)));
    }

    // macOS fallback: `register` and `query_registered_pointer_info` make no driver
    // calls there and report the host address as the device pointer.
    #[cfg(target_os = "macos")]
    mod macos_tests {
        use super::*;

        #[test]
        fn register_slice_succeeds_on_macos() {
            let mut data = vec![1.0f32, 2.0, 3.0, 4.0];
            let reg = register_slice(data.as_mut_slice(), RegisterFlags::DEFAULT);
            let reg = reg.ok();
            assert!(reg.is_some());
            let reg = reg.inspect(|r| {
                assert_eq!(r.len(), 4);
                assert!(!r.is_empty());
                assert_eq!(r.flags(), RegisterFlags::DEFAULT);
                assert_eq!(r.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
            });
            drop(reg);
        }

        #[test]
        fn register_vec_succeeds_on_macos() {
            let mut v = vec![10u32, 20, 30];
            let reg = register_vec(&mut v, RegisterFlags::PORTABLE);
            assert!(reg.is_ok());
            if let Ok(r) = reg {
                assert_eq!(r.len(), 3);
                assert_eq!(r.flags(), RegisterFlags::PORTABLE);
                // On macOS the device pointer is the host address even without DEVICE_MAP.
                assert_ne!(r.device_ptr(), 0);
            }
        }

        #[test]
        fn registered_memory_deref_works() {
            let mut data = vec![100i64, 200, 300];
            let reg = register_vec(&mut data, RegisterFlags::DEFAULT);
            assert!(reg.is_ok());
            if let Ok(r) = reg {
                let slice: &[i64] = &r;
                assert_eq!(slice.len(), 3);
                assert_eq!(slice[0], 100);
                assert_eq!(slice[2], 300);
            }
        }

        #[test]
        fn registered_memory_deref_mut_works() {
            let mut data = vec![1u8, 2, 3, 4, 5];
            let reg = register_slice(&mut data, RegisterFlags::DEFAULT);
            assert!(reg.is_ok());
            if let Ok(mut r) = reg {
                // Index assignment goes through DerefMut to the underlying slice.
                r[0] = 99;
                assert_eq!(r[0], 99);
                let mslice: &mut [u8] = &mut r;
                mslice[4] = 88;
                assert_eq!(mslice[4], 88);
            }
        }

        #[test]
        fn query_pointer_info_on_macos() {
            let data = [42u8; 64];
            let info = query_registered_pointer_info(data.as_ptr());
            assert!(info.is_ok());
            if let Ok(info) = info {
                assert!(!info.is_managed);
                assert_eq!(info.memory_type, RegisteredMemoryType::Host);
                assert_ne!(info.device_ptr, 0);
            }
        }

        #[test]
        fn registered_memory_as_ptr_mut_ptr() {
            let mut data = vec![5.0f64; 10];
            let original_ptr = data.as_mut_ptr();
            let reg = register_vec(&mut data, RegisterFlags::DEFAULT);
            assert!(reg.is_ok());
            if let Ok(mut r) = reg {
                // The wrapper must expose exactly the pointer it was given.
                assert_eq!(r.as_ptr(), original_ptr as *const f64);
                assert_eq!(r.as_mut_ptr(), original_ptr);
            }
        }
    }

    // Needs a CUDA device. The test returns early (passing vacuously) if driver
    // init, device lookup or context creation fails.
    #[cfg(feature = "gpu-tests")]
    mod gpu_tests {
        use super::*;

        #[test]
        fn register_and_unregister_on_gpu() {
            if oxicuda_driver::init().is_err() || oxicuda_driver::Device::count().unwrap_or(0) == 0
            {
                return;
            }
            let Ok(dev) = oxicuda_driver::Device::get(0) else {
                return;
            };
            // The context must stay alive while registration and unregistration happen.
            let Ok(_ctx) = oxicuda_driver::Context::new(&dev) else {
                return;
            };
            let mut data = vec![0.0f32; 4096];
            let reg = register_vec(&mut data, RegisterFlags::DEFAULT);
            assert!(reg.is_ok(), "registration failed: {:?}", reg.err());
            if let Ok(r) = reg {
                assert_eq!(r.len(), 4096);
                // DEFAULT includes DEVICE_MAP, so a device alias is expected.
                assert!(r.device_ptr() != 0, "device_ptr should be non-zero");
            }
            // `r` is dropped at the end of the `if let`, exercising the unregister path.
        }
    }
}