use std::fmt;
use oxicuda_driver::error::{CudaError, CudaResult, check};
use oxicuda_driver::ffi::{
CUdeviceptr, CUmemAccessDesc, CUmemAllocationHandleType, CUmemAllocationProp,
CUmemAllocationType, CUmemGenericAllocationHandle, CUmemLocation, CUmemLocationType,
};
/// Access permission that can be granted to a device over a virtual-address
/// range (see `VirtualMemoryManager::set_access`).
///
/// The default is [`AccessFlags::None`], i.e. no access.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum AccessFlags {
    /// No access.
    #[default]
    None,
    /// Read-only access.
    Read,
    /// Read and write access.
    ReadWrite,
}

impl fmt::Display for AccessFlags {
    /// Writes the bare variant name (`None`, `Read` or `ReadWrite`).
    ///
    /// Width/fill flags on the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            AccessFlags::None => "None",
            AccessFlags::Read => "Read",
            AccessFlags::ReadWrite => "ReadWrite",
        };
        f.write_str(label)
    }
}
/// A reserved range of device virtual addresses.
///
/// The range covers `[base, base + size)`. `alignment` is the alignment that
/// was requested when the range was created. Fields are private, so outside
/// this module a range can only be obtained from `VirtualMemoryManager::reserve`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VirtualAddressRange {
    base: u64,
    size: usize,
    alignment: usize,
}

impl VirtualAddressRange {
    /// First address of the range.
    #[inline]
    pub fn base(&self) -> u64 {
        self.base
    }

    /// Length of the range in bytes.
    #[inline]
    pub fn size(&self) -> usize {
        self.size
    }

    /// Alignment requested at reservation time, in bytes.
    #[inline]
    pub fn alignment(&self) -> usize {
        self.alignment
    }

    /// Returns `true` if `addr` lies in the half-open interval `[base, end())`.
    #[inline]
    pub fn contains(&self, addr: u64) -> bool {
        (self.base..self.end()).contains(&addr)
    }

    /// One past the last address of the range.
    ///
    /// Saturates at `u64::MAX` instead of wrapping on overflow.
    #[inline]
    pub fn end(&self) -> u64 {
        let len = self.size as u64;
        self.base.saturating_add(len)
    }
}

impl fmt::Display for VirtualAddressRange {
    /// Formats as `VA[0x<base>..0x<end>, <size> bytes, align=<alignment>]`,
    /// with both addresses zero-padded to 16 hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let start = self.base;
        let stop = self.end();
        write!(
            f,
            "VA[0x{:016x}..0x{:016x}, {} bytes, align={}]",
            start, stop, self.size, self.alignment,
        )
    }
}
/// A physical memory allocation handle together with its size and the device
/// ordinal it was requested on.
///
/// Fields are private, so outside this module an allocation can only be
/// obtained from `VirtualMemoryManager::alloc_physical`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PhysicalAllocation {
    handle: u64,
    size: usize,
    device_ordinal: i32,
}

impl PhysicalAllocation {
    /// Raw driver allocation handle.
    #[inline]
    pub fn handle(&self) -> u64 {
        self.handle
    }

    /// Size of the allocation in bytes.
    #[inline]
    pub fn size(&self) -> usize {
        self.size
    }

    /// Ordinal of the device the allocation was requested on.
    #[inline]
    pub fn device_ordinal(&self) -> i32 {
        self.device_ordinal
    }
}

impl fmt::Display for PhysicalAllocation {
    /// Formats as `PhysAlloc[handle=0x<handle>, <size> bytes, dev=<ordinal>]`,
    /// with the handle zero-padded to 16 hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let PhysicalAllocation {
            handle,
            size,
            device_ordinal,
        } = self;
        write!(
            f,
            "PhysAlloc[handle=0x{:016x}, {} bytes, dev={}]",
            handle, size, device_ordinal,
        )
    }
}
/// Bookkeeping record describing one mapping.
///
/// NOTE(review): nothing in this file constructs or consumes this type. The
/// field descriptions below are inferred from the names; confirm them against
/// the callers.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MappingRecord {
    /// Presumably the byte offset of the mapping inside its virtual-address range.
    pub va_offset: usize,
    /// Presumably the mapped length in bytes.
    pub size: usize,
    /// Presumably the handle of the physical allocation backing the mapping.
    pub phys_handle: u64,
    /// Access permission associated with the mapping.
    pub access: AccessFlags,
}
/// Stateless front-end to the CUDA driver's virtual memory management entry
/// points (address reserve/free, physical create/release, map/unmap, set access).
///
/// Every method first validates its arguments and returns
/// [`CudaError::InvalidValue`] for bad input without touching the driver.
/// It then loads the driver through `oxicuda_driver::loader::try_driver`,
/// propagating any loader error. If the required entry point is missing, it
/// returns [`CudaError::NotSupported`]. Driver status codes are converted
/// with `check`.
pub struct VirtualMemoryManager;

impl VirtualMemoryManager {
    /// Reserves `size` bytes of device virtual address space aligned to `alignment`.
    ///
    /// # Errors
    ///
    /// * [`CudaError::InvalidValue`] if `size` is zero, `alignment` is not a
    ///   power of two (this includes zero), or `size` is not a multiple of
    ///   `alignment`.
    /// * [`CudaError::NotSupported`] if the driver lacks the reserve entry point.
    /// * Any error from loading the driver or returned by the driver call.
    pub fn reserve(size: usize, alignment: usize) -> CudaResult<VirtualAddressRange> {
        if size == 0 {
            return Err(CudaError::InvalidValue);
        }
        // `0.is_power_of_two()` is false, so this also rejects a zero alignment.
        if !alignment.is_power_of_two() {
            return Err(CudaError::InvalidValue);
        }
        if size % alignment != 0 {
            return Err(CudaError::InvalidValue);
        }
        let api = oxicuda_driver::loader::try_driver()?;
        let f = api.cu_mem_address_reserve.ok_or(CudaError::NotSupported)?;
        let mut base: CUdeviceptr = 0;
        // SAFETY: `f` comes from the loaded driver's function table for the
        // address-reserve call. `base` is a live, writable local for the driver
        // to fill in. All other arguments are plain values (`0` for the
        // requested-address hint and for the flags).
        check(unsafe { f(&mut base, size, alignment, 0, 0) })?;
        Ok(VirtualAddressRange {
            base,
            size,
            alignment,
        })
    }

    /// Returns a reservation obtained from [`Self::reserve`] to the driver,
    /// consuming the range.
    ///
    /// # Errors
    ///
    /// * [`CudaError::NotSupported`] if the driver lacks the address-free entry point.
    /// * Any error from loading the driver or returned by the driver call.
    pub fn release(va: VirtualAddressRange) -> CudaResult<()> {
        let api = oxicuda_driver::loader::try_driver()?;
        let f = api.cu_mem_address_free.ok_or(CudaError::NotSupported)?;
        // SAFETY: `f` comes from the loaded driver's function table. Only plain
        // values (the range's base address and size) are passed; no pointers.
        check(unsafe { f(va.base, va.size) })
    }

    /// Creates a `size`-byte physical allocation on device `device_ordinal`.
    ///
    /// The allocation is requested as pinned, located on the given device, and
    /// with no exportable handle type.
    ///
    /// # Errors
    ///
    /// * [`CudaError::InvalidValue`] if `size` is zero or `device_ordinal` is negative.
    /// * [`CudaError::NotSupported`] if the driver lacks the create entry point.
    /// * Any error from loading the driver or returned by the driver call.
    pub fn alloc_physical(size: usize, device_ordinal: i32) -> CudaResult<PhysicalAllocation> {
        if size == 0 {
            return Err(CudaError::InvalidValue);
        }
        if device_ordinal < 0 {
            return Err(CudaError::InvalidValue);
        }
        let api = oxicuda_driver::loader::try_driver()?;
        let f = api.cu_mem_create.ok_or(CudaError::NotSupported)?;
        let prop = CUmemAllocationProp {
            alloc_type: CUmemAllocationType::Pinned as u32,
            requested_handle_types: CUmemAllocationHandleType::None as u32,
            location: CUmemLocation {
                loc_type: CUmemLocationType::Device as u32,
                id: device_ordinal,
            },
            ..CUmemAllocationProp::default()
        };
        let mut handle: CUmemGenericAllocationHandle = 0;
        // SAFETY: `f` comes from the loaded driver's function table. `handle` is
        // a live, writable local. `prop` is a fully initialised local that
        // outlives the call. The flags are `0`.
        check(unsafe { f(&mut handle, size, &prop, 0) })?;
        Ok(PhysicalAllocation {
            handle,
            size,
            device_ordinal,
        })
    }

    /// Releases a physical allocation handle, consuming it.
    ///
    /// # Errors
    ///
    /// * [`CudaError::NotSupported`] if the driver lacks the release entry point.
    /// * Any error from loading the driver or returned by the driver call.
    pub fn free_physical(phys: PhysicalAllocation) -> CudaResult<()> {
        let api = oxicuda_driver::loader::try_driver()?;
        let f = api.cu_mem_release.ok_or(CudaError::NotSupported)?;
        // SAFETY: `f` comes from the loaded driver's function table. Only the
        // handle value is passed.
        check(unsafe { f(phys.handle) })
    }

    /// Maps the whole of `phys` into `va`, starting `offset` bytes past `va.base()`.
    ///
    /// # Errors
    ///
    /// * [`CudaError::InvalidValue`] if `offset` is not a multiple of
    ///   `va.alignment()`, or if `[offset, offset + phys.size())` does not fit in `va`.
    /// * [`CudaError::NotSupported`] if the driver lacks the map entry point.
    /// * Any error from loading the driver or returned by the driver call.
    pub fn map(
        va: &VirtualAddressRange,
        phys: &PhysicalAllocation,
        offset: usize,
    ) -> CudaResult<()> {
        if va.alignment > 0 && offset % va.alignment != 0 {
            return Err(CudaError::InvalidValue);
        }
        let end = offset
            .checked_add(phys.size)
            .ok_or(CudaError::InvalidValue)?;
        if end > va.size {
            return Err(CudaError::InvalidValue);
        }
        let api = oxicuda_driver::loader::try_driver()?;
        let f = api.cu_mem_map.ok_or(CudaError::NotSupported)?;
        // Cannot overflow in practice: `offset + phys.size <= va.size` was checked above.
        let target_va: CUdeviceptr = va.base.saturating_add(offset as u64);
        // SAFETY: `f` comes from the loaded driver's function table. Only plain
        // values are passed: the target address, the mapping length, `0` for the
        // offset into the physical allocation, the handle, and `0` for the flags.
        check(unsafe { f(target_va, phys.size, 0, phys.handle, 0) })
    }

    /// Unmaps `size` bytes of `va`, starting `offset` bytes past `va.base()`.
    ///
    /// The arguments are validated the same way as in [`Self::map`]. A
    /// zero-length or misaligned request cannot describe a mapping created
    /// through this API, so it is rejected before the driver is loaded.
    ///
    /// # Errors
    ///
    /// * [`CudaError::InvalidValue`] if `size` is zero, if `offset` is not a
    ///   multiple of `va.alignment()`, or if `[offset, offset + size)` does not
    ///   fit in `va`.
    /// * [`CudaError::NotSupported`] if the driver lacks the unmap entry point.
    /// * Any error from loading the driver or returned by the driver call.
    pub fn unmap(va: &VirtualAddressRange, offset: usize, size: usize) -> CudaResult<()> {
        if size == 0 {
            return Err(CudaError::InvalidValue);
        }
        // Mirrors the offset-alignment check in `map`.
        if va.alignment > 0 && offset % va.alignment != 0 {
            return Err(CudaError::InvalidValue);
        }
        let end = offset.checked_add(size).ok_or(CudaError::InvalidValue)?;
        if end > va.size {
            return Err(CudaError::InvalidValue);
        }
        let api = oxicuda_driver::loader::try_driver()?;
        let f = api.cu_mem_unmap.ok_or(CudaError::NotSupported)?;
        // Cannot overflow in practice: `offset + size <= va.size` was checked above.
        let target_va: CUdeviceptr = va.base.saturating_add(offset as u64);
        // SAFETY: `f` comes from the loaded driver's function table. Only plain
        // values (address and length) are passed.
        check(unsafe { f(target_va, size) })
    }

    /// Grants `flags` access on device `device_ordinal` to the whole range
    /// `[va.base(), va.end())`.
    ///
    /// The CUDA driver documents that this range must be fully mapped.
    ///
    /// # Errors
    ///
    /// * [`CudaError::InvalidValue`] if `device_ordinal` is negative, matching
    ///   the validation in [`Self::alloc_physical`].
    /// * [`CudaError::NotSupported`] if the driver lacks the set-access entry point.
    /// * Any error from loading the driver or returned by the driver call.
    pub fn set_access(
        va: &VirtualAddressRange,
        device_ordinal: i32,
        flags: AccessFlags,
    ) -> CudaResult<()> {
        if device_ordinal < 0 {
            return Err(CudaError::InvalidValue);
        }
        let api = oxicuda_driver::loader::try_driver()?;
        let f = api.cu_mem_set_access.ok_or(CudaError::NotSupported)?;
        let desc = CUmemAccessDesc {
            location: CUmemLocation {
                loc_type: CUmemLocationType::Device as u32,
                id: device_ordinal,
            },
            // CUDA `CUmemAccess_flags`: PROT_NONE = 0x0, PROT_READ = 0x1,
            // PROT_READWRITE = 0x3.
            flags: match flags {
                AccessFlags::None => 0,
                AccessFlags::Read => 1,
                AccessFlags::ReadWrite => 3,
            },
        };
        // SAFETY: `f` comes from the loaded driver's function table. `desc` is a
        // fully initialised local that outlives the call, and the count `1`
        // matches the single descriptor it points to.
        check(unsafe { f(va.base, va.size, &desc, 1) })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Base address for synthetic ranges. These tests never reserve it.
    const SYNTHETIC_BASE: u64 = 0x1_0000_0000;

    /// Builds a `size`-byte range at `SYNTHETIC_BASE` with 4096-byte alignment,
    /// without calling the driver.
    fn synthetic_va(size: usize) -> VirtualAddressRange {
        VirtualAddressRange {
            base: SYNTHETIC_BASE,
            size,
            alignment: 4096,
        }
    }

    /// Builds a `size`-byte physical allocation with dummy handle `1` on
    /// device 0, without calling the driver.
    fn synthetic_phys(size: usize) -> PhysicalAllocation {
        PhysicalAllocation {
            handle: 1,
            size,
            device_ordinal: 0,
        }
    }

    /// Errors accepted from tests that may reach the driver.
    ///
    /// NOTE(review): `InvalidValue` is presumably accepted because a live
    /// driver rejects the synthetic addresses and handles used by these tests.
    fn is_driver_unavailable(err: &CudaError) -> bool {
        matches!(
            err,
            CudaError::NotInitialized
                | CudaError::NotSupported
                | CudaError::InvalidValue
                | CudaError::InvalidDevice
                | CudaError::NoDevice
                | CudaError::InvalidContext
        )
    }

    /// Passes on `Ok`; on `Err`, asserts that the error is a driver-unavailable one.
    #[track_caller]
    fn assert_unavailable_if_err(result: CudaResult<()>) {
        if let Err(e) = result {
            assert!(
                is_driver_unavailable(&e),
                "expected driver-unavailable error, got {e:?}"
            );
        }
    }

    #[test]
    fn reserve_zero_size_fails() {
        let result = VirtualMemoryManager::reserve(0, 4096);
        assert_eq!(result, Err(CudaError::InvalidValue));
    }

    #[test]
    fn reserve_zero_alignment_fails() {
        let result = VirtualMemoryManager::reserve(4096, 0);
        assert_eq!(result, Err(CudaError::InvalidValue));
    }

    #[test]
    fn reserve_non_power_of_two_alignment_fails() {
        let result = VirtualMemoryManager::reserve(4096, 3);
        assert_eq!(result, Err(CudaError::InvalidValue));
    }

    #[test]
    fn reserve_misaligned_size_fails() {
        let result = VirtualMemoryManager::reserve(4097, 4096);
        assert_eq!(result, Err(CudaError::InvalidValue));
    }

    #[test]
    fn reserve_no_driver_returns_driver_unavailable() {
        match VirtualMemoryManager::reserve(4096, 4096) {
            Ok(va) => {
                assert_eq!(va.size(), 4096);
                assert_eq!(va.alignment(), 4096);
                // A live driver really reserved address space: give it back so
                // the test does not leak the reservation.
                VirtualMemoryManager::release(va)
                    .expect("releasing a freshly reserved range failed");
            }
            Err(e) => assert!(
                is_driver_unavailable(&e),
                "unexpected error from reserve: {e:?}"
            ),
        }
    }

    #[test]
    fn virtual_address_range_contains_synthetic() {
        let va = synthetic_va(8192);
        assert!(va.contains(va.base()));
        assert!(va.contains(va.base() + 1));
        assert!(va.contains(va.base() + 8191));
        assert!(!va.contains(va.end()));
        assert!(!va.contains(va.base().wrapping_sub(1)));
    }

    #[test]
    fn virtual_address_range_end_synthetic() {
        let va = synthetic_va(4096);
        assert_eq!(va.end(), va.base() + 4096);
    }

    #[test]
    fn virtual_address_range_display_synthetic() {
        let disp = format!("{}", synthetic_va(4096));
        assert!(disp.contains("VA["));
        assert!(disp.contains("4096 bytes"));
    }

    #[test]
    fn alloc_physical_zero_size_fails() {
        let result = VirtualMemoryManager::alloc_physical(0, 0);
        assert_eq!(result, Err(CudaError::InvalidValue));
    }

    #[test]
    fn alloc_physical_negative_device_fails() {
        let result = VirtualMemoryManager::alloc_physical(4096, -1);
        assert_eq!(result, Err(CudaError::InvalidValue));
    }

    #[test]
    fn alloc_physical_no_driver_returns_driver_unavailable() {
        match VirtualMemoryManager::alloc_physical(4096, 0) {
            Ok(phys) => {
                assert_eq!(phys.size(), 4096);
                assert_eq!(phys.device_ordinal(), 0);
                // A live driver really created a physical allocation: release
                // it so the test does not leak device memory.
                VirtualMemoryManager::free_physical(phys)
                    .expect("freeing a freshly created physical allocation failed");
            }
            Err(e) => assert!(
                is_driver_unavailable(&e),
                "expected driver-unavailable error, got {e:?}"
            ),
        }
    }

    #[test]
    fn release_no_driver_returns_driver_unavailable() {
        assert_unavailable_if_err(VirtualMemoryManager::release(synthetic_va(4096)));
    }

    #[test]
    fn free_physical_no_driver_returns_driver_unavailable() {
        // Never hand a fabricated handle to a live driver.
        if oxicuda_driver::loader::try_driver().is_ok() {
            return;
        }
        assert_unavailable_if_err(VirtualMemoryManager::free_physical(synthetic_phys(4096)));
    }

    #[test]
    fn map_validates_alignment() {
        let result = VirtualMemoryManager::map(&synthetic_va(8192), &synthetic_phys(4096), 1);
        assert_eq!(result, Err(CudaError::InvalidValue));
    }

    #[test]
    fn map_validates_bounds() {
        let result = VirtualMemoryManager::map(&synthetic_va(4096), &synthetic_phys(8192), 0);
        assert_eq!(result, Err(CudaError::InvalidValue));
    }

    #[test]
    fn map_no_driver_returns_driver_unavailable() {
        // Never map a fabricated handle into a fabricated range on a live driver.
        if oxicuda_driver::loader::try_driver().is_ok() {
            return;
        }
        assert_unavailable_if_err(VirtualMemoryManager::map(
            &synthetic_va(8192),
            &synthetic_phys(4096),
            0,
        ));
    }

    #[test]
    fn unmap_validates_bounds() {
        let result = VirtualMemoryManager::unmap(&synthetic_va(4096), 0, 8192);
        assert_eq!(result, Err(CudaError::InvalidValue));
    }

    #[test]
    fn unmap_no_driver_returns_driver_unavailable() {
        assert_unavailable_if_err(VirtualMemoryManager::unmap(&synthetic_va(4096), 0, 4096));
    }

    #[test]
    fn set_access_no_driver_returns_driver_unavailable() {
        assert_unavailable_if_err(VirtualMemoryManager::set_access(
            &synthetic_va(4096),
            0,
            AccessFlags::ReadWrite,
        ));
    }

    #[test]
    fn access_flags_default() {
        assert_eq!(AccessFlags::default(), AccessFlags::None);
    }

    #[test]
    fn access_flags_display() {
        assert_eq!(format!("{}", AccessFlags::None), "None");
        assert_eq!(format!("{}", AccessFlags::Read), "Read");
        assert_eq!(format!("{}", AccessFlags::ReadWrite), "ReadWrite");
    }

    #[test]
    fn physical_allocation_display() {
        let phys = PhysicalAllocation {
            handle: 0x1234,
            ..synthetic_phys(4096)
        };
        let disp = format!("{phys}");
        assert!(disp.contains("4096 bytes"));
        assert!(disp.contains("dev=0"));
    }

    #[test]
    fn mapping_record_fields() {
        let record = MappingRecord {
            va_offset: 0,
            size: 4096,
            phys_handle: 42,
            access: AccessFlags::ReadWrite,
        };
        assert_eq!(record.va_offset, 0);
        assert_eq!(record.size, 4096);
        assert_eq!(record.phys_handle, 42);
        assert_eq!(record.access, AccessFlags::ReadWrite);
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn macos_paths_return_not_initialized() {
        assert_eq!(
            VirtualMemoryManager::reserve(4096, 4096),
            Err(CudaError::NotInitialized)
        );
        assert_eq!(
            VirtualMemoryManager::alloc_physical(4096, 0),
            Err(CudaError::NotInitialized)
        );
        let phys = synthetic_phys(4096);
        assert_eq!(
            VirtualMemoryManager::free_physical(phys.clone()),
            Err(CudaError::NotInitialized)
        );
        let va = synthetic_va(4096);
        assert_eq!(
            VirtualMemoryManager::release(va.clone()),
            Err(CudaError::NotInitialized)
        );
        assert_eq!(
            VirtualMemoryManager::map(&va, &phys, 0),
            Err(CudaError::NotInitialized)
        );
        assert_eq!(
            VirtualMemoryManager::unmap(&va, 0, 4096),
            Err(CudaError::NotInitialized)
        );
        assert_eq!(
            VirtualMemoryManager::set_access(&va, 0, AccessFlags::ReadWrite),
            Err(CudaError::NotInitialized)
        );
    }
}