use crate::error::{Result, VmmError};
use cudarc::driver::sys;
/// Kind of memory location recorded in an [`AllocationProp`].
///
/// `Device` is the only variant defined here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryLocationType {
/// The location is a device, identified elsewhere by its ordinal.
Device,
}
/// Allocation properties accepted by [`mem_create`] and
/// [`mem_get_allocation_granularity`].
///
/// The value is converted to the driver's `CUmemAllocationProp` before each call.
#[derive(Debug, Clone)]
pub struct AllocationProp {
/// Kind of location the allocation targets.
pub location_type: MemoryLocationType,
/// Device ordinal; copied into the `id` field of the driver's `CUmemLocation`.
pub device_ordinal: i32,
}
impl AllocationProp {
    /// Builds properties whose location is the device with ordinal `device_ordinal`.
    pub fn device(device_ordinal: i32) -> Self {
        let location_type = MemoryLocationType::Device;
        Self {
            location_type,
            device_ordinal,
        }
    }

    /// Converts these properties into the driver's `CUmemAllocationProp`.
    ///
    /// Only `device_ordinal` is read: the allocation type is always
    /// `CU_MEM_ALLOCATION_TYPE_PINNED`, the requested handle type is always
    /// `CU_MEM_HANDLE_TYPE_NONE`, and the location type is always
    /// `CU_MEM_LOCATION_TYPE_DEVICE` regardless of `location_type`.
    fn to_cuda_prop(&self) -> sys::CUmemAllocationProp {
        let location = sys::CUmemLocation {
            type_: sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
            id: self.device_ordinal,
        };

        sys::CUmemAllocationProp {
            type_: sys::CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED,
            requestedHandleTypes: sys::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_NONE,
            location,
            win32HandleMetaData: std::ptr::null_mut(),
            // SAFETY: NOTE(review) — presumably `allocFlags` is a plain-data bindgen struct
            // for which the all-zero bit pattern is valid; confirm against the cudarc
            // `sys` definition.
            allocFlags: unsafe { std::mem::zeroed() },
        }
    }
}
/// Access protection requested through [`mem_set_access`].
///
/// The discriminants are `0`, `1` and `3`; each variant is translated to the
/// corresponding `CUmemAccess_flags` driver constant before use.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AccessFlags {
    /// No access (discriminant `0`).
    None = 0b00,
    /// Read access (discriminant `1`).
    Read = 0b01,
    /// Read and write access (discriminant `3`).
    ReadWrite = 0b11,
}
impl AccessFlags {
    /// Maps this flag onto the matching `CUmemAccess_flags` driver constant.
    ///
    /// Takes `self` by value: `AccessFlags` is `Copy`, and a `to_*` conversion on a
    /// `Copy` type conventionally consumes the value instead of borrowing it. Existing
    /// call sites of the form `flags.to_cuda_flags()` are unaffected.
    fn to_cuda_flags(self) -> sys::CUmemAccess_flags {
        match self {
            Self::None => sys::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_NONE,
            Self::Read => sys::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READ,
            Self::ReadWrite => sys::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READWRITE,
        }
    }
}
/// Alias for the driver's `CUmemGenericAllocationHandle`: the handle type produced by
/// [`mem_create`] and accepted by [`mem_release`] and [`mem_map`].
pub type MemGenericAllocationHandle = sys::CUmemGenericAllocationHandle;
/// Alias for the driver's `CUdeviceptr`, used for the address arguments and results of
/// the wrappers in this module.
pub type DevicePtr = sys::CUdeviceptr;
/// Creates an allocation handle by calling the driver's `cuMemCreate`.
///
/// `size` is forwarded unchanged and `prop` is converted to the driver's
/// `CUmemAllocationProp` first. The trailing driver argument is always passed as `0`.
///
/// # Errors
///
/// Returns the error built by `VmmError::cuda`, with the driver status code in its
/// message, whenever the call reports anything other than `CUDA_SUCCESS`.
///
/// # Safety
///
/// This is a direct FFI call into the CUDA driver and the wrapper validates none of
/// its arguments, so the caller must satisfy the driver's preconditions for
/// `cuMemCreate`. NOTE(review): presumably that includes an initialized driver and a
/// valid device ordinal in `prop` — confirm against the CUDA driver API documentation.
pub unsafe fn mem_create(
    size: usize,
    prop: &AllocationProp,
) -> Result<MemGenericAllocationHandle> {
    let driver_prop = prop.to_cuda_prop();
    let mut allocation: MemGenericAllocationHandle = 0;

    let status = unsafe { sys::cuMemCreate(&mut allocation, size, &driver_prop, 0) };

    if status == sys::cudaError_enum::CUDA_SUCCESS {
        Ok(allocation)
    } else {
        Err(VmmError::cuda(format!(
            "cuMemCreate failed with code {:?}",
            status
        )))
    }
}
/// Releases an allocation handle by calling the driver's `cuMemRelease`.
///
/// # Errors
///
/// Returns the error built by `VmmError::cuda`, with the driver status code in its
/// message, whenever the call reports anything other than `CUDA_SUCCESS`.
///
/// # Safety
///
/// This is a direct FFI call into the CUDA driver and `handle` is passed through
/// unchecked. NOTE(review): presumably `handle` must be a live handle that has not
/// already been released — confirm against the CUDA driver API documentation.
pub unsafe fn mem_release(handle: MemGenericAllocationHandle) -> Result<()> {
    let status = unsafe { sys::cuMemRelease(handle) };

    if status == sys::cudaError_enum::CUDA_SUCCESS {
        Ok(())
    } else {
        Err(VmmError::cuda(format!(
            "cuMemRelease failed with code {:?}",
            status
        )))
    }
}
/// Calls the driver's `cuMemAddressReserve` and returns the device pointer it writes.
///
/// `size`, `alignment` and `addr` are forwarded unchanged; the trailing driver argument
/// is always passed as `0`.
///
/// # Errors
///
/// Returns the error built by `VmmError::cuda`, with the driver status code in its
/// message, whenever the call reports anything other than `CUDA_SUCCESS`.
///
/// # Safety
///
/// This is a direct FFI call into the CUDA driver and the wrapper validates none of
/// its arguments, so the caller must satisfy the driver's preconditions for
/// `cuMemAddressReserve` (see the CUDA driver API documentation).
pub unsafe fn mem_address_reserve(
    size: usize,
    alignment: usize,
    addr: DevicePtr,
) -> Result<DevicePtr> {
    let mut reserved: DevicePtr = 0;

    let status = unsafe { sys::cuMemAddressReserve(&mut reserved, size, alignment, addr, 0) };

    if status == sys::cudaError_enum::CUDA_SUCCESS {
        Ok(reserved)
    } else {
        Err(VmmError::cuda(format!(
            "cuMemAddressReserve failed with code {:?}",
            status
        )))
    }
}
/// Calls the driver's `cuMemAddressFree` with `ptr` and `size` forwarded unchanged.
///
/// # Errors
///
/// Returns the error built by `VmmError::cuda`, with the driver status code in its
/// message, whenever the call reports anything other than `CUDA_SUCCESS`.
///
/// # Safety
///
/// This is a direct FFI call into the CUDA driver with unchecked arguments.
/// NOTE(review): presumably `ptr` and `size` must describe a range previously obtained
/// from [`mem_address_reserve`] with nothing still mapped in it — confirm against the
/// CUDA driver API documentation.
pub unsafe fn mem_address_free(ptr: DevicePtr, size: usize) -> Result<()> {
    let status = unsafe { sys::cuMemAddressFree(ptr, size) };

    if status == sys::cudaError_enum::CUDA_SUCCESS {
        Ok(())
    } else {
        Err(VmmError::cuda(format!(
            "cuMemAddressFree failed with code {:?}",
            status
        )))
    }
}
/// Calls the driver's `cuMemMap` for `handle` at `ptr`.
///
/// `ptr`, `size`, `offset` and `handle` are forwarded unchanged; the trailing driver
/// argument is always passed as `0`.
///
/// # Errors
///
/// Returns [`VmmError::MappingFailed`], with the driver status code in its message,
/// whenever the call reports anything other than `CUDA_SUCCESS`.
///
/// # Safety
///
/// This is a direct FFI call into the CUDA driver with unchecked arguments.
/// NOTE(review): presumably `ptr`/`size` must lie inside a reserved address range and
/// `handle` must be a live allocation handle — confirm against the CUDA driver API
/// documentation.
pub unsafe fn mem_map(
    ptr: DevicePtr,
    size: usize,
    offset: usize,
    handle: MemGenericAllocationHandle,
) -> Result<()> {
    let status = unsafe { sys::cuMemMap(ptr, size, offset, handle, 0) };

    if status == sys::cudaError_enum::CUDA_SUCCESS {
        Ok(())
    } else {
        Err(VmmError::MappingFailed(format!(
            "cuMemMap failed with code {:?}",
            status
        )))
    }
}
/// Calls the driver's `cuMemUnmap` with `ptr` and `size` forwarded unchanged.
///
/// # Errors
///
/// Returns [`VmmError::UnmappingFailed`], with the driver status code in its message,
/// whenever the call reports anything other than `CUDA_SUCCESS`.
///
/// # Safety
///
/// This is a direct FFI call into the CUDA driver with unchecked arguments.
/// NOTE(review): presumably the range must currently be mapped and no longer in use by
/// the device — confirm against the CUDA driver API documentation.
pub unsafe fn mem_unmap(ptr: DevicePtr, size: usize) -> Result<()> {
    let status = unsafe { sys::cuMemUnmap(ptr, size) };

    if status == sys::cudaError_enum::CUDA_SUCCESS {
        Ok(())
    } else {
        Err(VmmError::UnmappingFailed(format!(
            "cuMemUnmap failed with code {:?}",
            status
        )))
    }
}
/// Calls the driver's `cuMemSetAccess` for the range given by `ptr` and `size`.
///
/// A single `CUmemAccessDesc` is built whose location is the device `device_ordinal`
/// and whose flags are the driver equivalent of `flags`; it is passed with a
/// descriptor count of `1`.
///
/// # Errors
///
/// Returns the error built by `VmmError::cuda`, with the driver status code in its
/// message, whenever the call reports anything other than `CUDA_SUCCESS`.
///
/// # Safety
///
/// This is a direct FFI call into the CUDA driver with unchecked arguments.
/// NOTE(review): presumably the range must be mapped and `device_ordinal` must name a
/// valid device — confirm against the CUDA driver API documentation.
pub unsafe fn mem_set_access(
    ptr: DevicePtr,
    size: usize,
    device_ordinal: i32,
    flags: AccessFlags,
) -> Result<()> {
    let location = sys::CUmemLocation {
        type_: sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
        id: device_ordinal,
    };
    let descriptor = sys::CUmemAccessDesc {
        location,
        flags: flags.to_cuda_flags(),
    };

    let status = unsafe { sys::cuMemSetAccess(ptr, size, &descriptor, 1) };

    if status == sys::cudaError_enum::CUDA_SUCCESS {
        Ok(())
    } else {
        Err(VmmError::cuda(format!(
            "cuMemSetAccess failed with code {:?}",
            status
        )))
    }
}
/// Calls the driver's `cuMemGetAllocationGranularity` and returns the value it writes.
///
/// `prop` is converted to the driver's `CUmemAllocationProp` and `option` is forwarded
/// unchanged.
///
/// # Errors
///
/// Returns the error built by `VmmError::cuda`, with the driver status code in its
/// message, whenever the call reports anything other than `CUDA_SUCCESS`.
///
/// # Safety
///
/// This is a direct FFI call into the CUDA driver and the wrapper validates none of
/// its arguments, so the caller must satisfy the driver's preconditions for
/// `cuMemGetAllocationGranularity` (see the CUDA driver API documentation).
pub unsafe fn mem_get_allocation_granularity(
    prop: &AllocationProp,
    option: sys::CUmemAllocationGranularity_flags,
) -> Result<usize> {
    let driver_prop = prop.to_cuda_prop();
    let mut bytes: usize = 0;

    let status = unsafe { sys::cuMemGetAllocationGranularity(&mut bytes, &driver_prop, option) };

    if status == sys::cudaError_enum::CUDA_SUCCESS {
        Ok(bytes)
    } else {
        Err(VmmError::cuda(format!(
            "cuMemGetAllocationGranularity failed with code {:?}",
            status
        )))
    }
}
/// Queries the allocation granularity for device `device_ordinal` using the
/// `CU_MEM_ALLOC_GRANULARITY_RECOMMENDED` option.
///
/// # Errors
///
/// Propagates any error returned by [`mem_get_allocation_granularity`].
pub fn get_recommended_granularity(device_ordinal: i32) -> Result<usize> {
    let option = sys::CUmemAllocationGranularity_flags::CU_MEM_ALLOC_GRANULARITY_RECOMMENDED;
    let device_prop = AllocationProp::device(device_ordinal);

    // SAFETY: the property value is built locally and only borrowed for the duration of
    // the call. NOTE(review): soundness presumably relies on the driver reporting an
    // error code for an invalid ordinal or an uninitialized driver — confirm.
    unsafe { mem_get_allocation_granularity(&device_prop, option) }
}
/// Queries the allocation granularity for device `device_ordinal` using the
/// `CU_MEM_ALLOC_GRANULARITY_MINIMUM` option.
///
/// # Errors
///
/// Propagates any error returned by [`mem_get_allocation_granularity`].
pub fn get_minimum_granularity(device_ordinal: i32) -> Result<usize> {
    let option = sys::CUmemAllocationGranularity_flags::CU_MEM_ALLOC_GRANULARITY_MINIMUM;
    let device_prop = AllocationProp::device(device_ordinal);

    // SAFETY: the property value is built locally and only borrowed for the duration of
    // the call. NOTE(review): soundness presumably relies on the driver reporting an
    // error code for an invalid ordinal or an uninitialized driver — confirm.
    unsafe { mem_get_allocation_granularity(&device_prop, option) }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `AllocationProp::device` records the device location type and the given ordinal.
    #[test]
    fn test_allocation_prop_creation() {
        let AllocationProp {
            location_type,
            device_ordinal,
        } = AllocationProp::device(0);

        assert_eq!(location_type, MemoryLocationType::Device);
        assert_eq!(device_ordinal, 0);
    }

    /// Each `AccessFlags` variant casts to its declared discriminant.
    #[test]
    fn test_access_flags() {
        let expected = [
            (AccessFlags::None, 0),
            (AccessFlags::Read, 1),
            (AccessFlags::ReadWrite, 3),
        ];

        for &(flag, value) in expected.iter() {
            assert_eq!(flag as i32, value);
        }
    }
}