use core::ffi::c_void;
use baracuda_cuda_sys::runtime::runtime;
use baracuda_cuda_sys::runtime::types::{
cudaMemAccessDesc, cudaMemAllocationHandleType, cudaMemAllocationProp, cudaMemAllocationType,
cudaMemGenericAllocationHandle_t, cudaMemLocation, cudaMemLocationType,
};
use crate::device::Device;
use crate::error::{check, Result};
use crate::mempool::AccessFlags;
/// Reserves `size` bytes of virtual address space and returns the base pointer.
///
/// `alignment` and `flags` are forwarded unchanged to the runtime's reserve
/// entry point. No fixed-address hint is supplied: a null pointer is passed in
/// that argument slot. The result is presumably meant to be backed later via
/// [`map`] and [`set_access`] — confirm against the CUDA VMM documentation.
///
/// # Errors
///
/// Propagates any error from `runtime()`, from resolving the reserve entry
/// point, or from `check` on the status the call returns.
pub fn address_reserve(size: usize, alignment: usize, flags: u64) -> Result<*mut c_void> {
    let rt = runtime()?;
    let reserve = rt.cuda_mem_address_reserve()?;

    let mut base: *mut c_void = core::ptr::null_mut();
    // SAFETY: `base` is a live, writable out-slot for the whole call; every other
    // argument is a plain value (the address hint is null).
    let status = unsafe { reserve(&mut base, size, alignment, core::ptr::null_mut(), flags) };
    check(status)?;

    Ok(base)
}
/// Releases a virtual address range, presumably one obtained from [`address_reserve`].
///
/// # Safety
///
/// `ptr` and `size` are handed to the runtime without any validation. The caller
/// must ensure they describe a range the runtime will accept for release
/// (presumably the exact base/size pair returned by [`address_reserve`], with
/// nothing still mapped into it — confirm against the CUDA docs). The caller must
/// also ensure that no outstanding pointers into the range are used afterwards.
///
/// # Errors
///
/// Propagates any error from `runtime()`, from resolving the free entry point,
/// or from `check` on the status the call returns.
pub unsafe fn address_free(ptr: *mut c_void, size: usize) -> Result<()> {
    let rt = runtime()?;
    let free_range = rt.cuda_mem_address_free()?;
    // SAFETY: the validity of `ptr`/`size` is part of this function's contract,
    // which the caller has promised to uphold.
    let status = unsafe { free_range(ptr, size) };
    check(status)
}
/// Builds an allocation descriptor for `PINNED` memory located on `device`.
///
/// The descriptor:
/// - targets `cudaMemLocationType::DEVICE` with the device's ordinal as the id;
/// - requests no exportable handle types (`cudaMemAllocationHandleType::NONE`);
/// - carries a null Win32 handle-metadata pointer;
/// - leaves all 32 `allocation_flags` bytes zeroed.
pub fn device_alloc_prop(device: &Device) -> cudaMemAllocationProp {
    let location = cudaMemLocation {
        type_: cudaMemLocationType::DEVICE,
        id: device.ordinal(),
    };

    cudaMemAllocationProp {
        location,
        alloc_type: cudaMemAllocationType::PINNED,
        requested_handle_types: cudaMemAllocationHandleType::NONE,
        allocation_flags: [0; 32],
        win32_handle_meta_data: core::ptr::null_mut(),
    }
}
/// Queries the allocation granularity, in the units the runtime reports, for
/// allocations described by `prop`.
///
/// `option` is forwarded unchanged as the query selector (presumably choosing
/// between minimum and recommended granularity — confirm against the CUDA docs).
///
/// # Errors
///
/// Propagates any error from `runtime()`, from resolving the query entry point,
/// or from `check` on the status the call returns.
pub fn allocation_granularity(prop: &cudaMemAllocationProp, option: i32) -> Result<usize> {
    let rt = runtime()?;
    let query = rt.cuda_mem_get_allocation_granularity()?;

    let prop_ptr = (prop as *const cudaMemAllocationProp).cast::<c_void>();
    let mut granularity: usize = 0;
    // SAFETY: `granularity` is a writable out-slot and `prop_ptr` points at a
    // descriptor that stays borrowed for the duration of the call.
    check(unsafe { query(&mut granularity, prop_ptr, option) })?;

    Ok(granularity)
}
/// Owning wrapper around a generic physical-allocation handle.
///
/// The wrapped handle is passed to the runtime's release entry point when the
/// wrapper is dropped, so each `MemHandle` accounts for exactly one release.
/// Use [`MemHandle::into_raw`] to give up that responsibility without releasing.
#[derive(Debug)]
pub struct MemHandle {
    // Raw handle value, as produced by the create/retain entry points or `from_raw`.
    handle: cudaMemGenericAllocationHandle_t,
}

impl MemHandle {
    /// Creates a new physical allocation of `size` bytes described by `prop`.
    ///
    /// `flags` is forwarded unchanged to the runtime's create entry point.
    ///
    /// # Errors
    ///
    /// Propagates any error from `runtime()`, from resolving the create entry
    /// point, or from `check` on the status the call returns. On error no handle
    /// is owned, so nothing will be released.
    pub fn new(size: usize, prop: &cudaMemAllocationProp, flags: u64) -> Result<Self> {
        let r = runtime()?;
        let cu = r.cuda_mem_create()?;
        let mut h: cudaMemGenericAllocationHandle_t = 0;
        // SAFETY: `h` is a writable out-slot and `prop` is a live borrow for the
        // entire call.
        check(unsafe {
            cu(
                &mut h,
                size,
                prop as *const cudaMemAllocationProp as *const c_void,
                flags,
            )
        })?;
        Ok(Self { handle: h })
    }

    /// Takes ownership of an existing raw handle.
    ///
    /// # Safety
    ///
    /// `handle` must be a valid handle whose release the caller is handing over:
    /// the returned `MemHandle` releases it on drop, so no other owner may also
    /// release this same reference.
    pub unsafe fn from_raw(handle: cudaMemGenericAllocationHandle_t) -> Self {
        Self { handle }
    }

    /// Returns the raw handle while keeping ownership (it is still released on drop).
    #[inline]
    pub fn as_raw(&self) -> cudaMemGenericAllocationHandle_t {
        self.handle
    }

    /// Consumes the wrapper and returns the raw handle *without* releasing it.
    ///
    /// This is the inverse of [`MemHandle::from_raw`]. Responsibility for the
    /// eventual release moves to the caller, for example by re-wrapping the
    /// value with `from_raw`.
    #[inline]
    #[must_use = "discarding the returned raw handle means it is never released"]
    pub fn into_raw(self) -> cudaMemGenericAllocationHandle_t {
        let handle = self.handle;
        // Skip `Drop`, which would otherwise release the handle being returned.
        core::mem::forget(self);
        handle
    }

    /// Obtains an owned handle for the allocation associated with `addr`.
    ///
    /// # Safety
    ///
    /// `addr` is passed unchecked to the runtime's retain entry point. The caller
    /// must ensure it is an address the runtime accepts there (presumably one
    /// inside a range currently mapped with [`map`] — confirm against the CUDA docs).
    ///
    /// # Errors
    ///
    /// Propagates any error from `runtime()`, from resolving the retain entry
    /// point, or from `check` on the status the call returns.
    pub unsafe fn retain(addr: *mut c_void) -> Result<Self> {
        let r = runtime()?;
        let cu = r.cuda_mem_retain_allocation_handle()?;
        let mut h: cudaMemGenericAllocationHandle_t = 0;
        // SAFETY: `h` is a writable out-slot; the validity of `addr` is part of
        // this function's caller contract.
        check(unsafe { cu(&mut h, addr) })?;
        Ok(Self { handle: h })
    }

    /// Reads back the allocation properties associated with this handle.
    ///
    /// # Errors
    ///
    /// Propagates any error from `runtime()`, from resolving the query entry
    /// point, or from `check` on the status the call returns.
    pub fn properties(&self) -> Result<cudaMemAllocationProp> {
        let r = runtime()?;
        let cu = r.cuda_mem_get_allocation_properties_from_handle()?;
        let mut prop = cudaMemAllocationProp::default();
        // SAFETY: `prop` is a live, writable descriptor for the whole call and
        // `self.handle` is owned by this wrapper.
        check(unsafe {
            cu(
                &mut prop as *mut cudaMemAllocationProp as *mut c_void,
                self.handle,
            )
        })?;
        Ok(prop)
    }
}

impl Drop for MemHandle {
    fn drop(&mut self) {
        // Best effort: `drop` cannot report errors, so failures to load the
        // runtime, to resolve the release entry point, or of the release call
        // itself are deliberately ignored.
        if let Ok(r) = runtime() {
            if let Ok(cu) = r.cuda_mem_release() {
                // SAFETY: `self.handle` is owned by this wrapper and is released
                // only here (`into_raw` forgets the wrapper instead of dropping it).
                let _ = unsafe { cu(self.handle) };
            }
        }
    }
}
/// Maps `size` bytes of the physical allocation behind `handle` at virtual
/// address `ptr`.
///
/// `offset` (presumably a byte offset into the allocation) and `flags` are
/// forwarded unchanged to the runtime's map entry point.
///
/// # Safety
///
/// `ptr`, `size` and `offset` are passed without validation. The caller must
/// ensure that they describe a range the runtime accepts for mapping. For
/// example, the range is presumably inside a reservation from
/// [`address_reserve`], suitably aligned, and not already mapped — confirm
/// against the CUDA docs.
///
/// # Errors
///
/// Propagates any error from `runtime()`, from resolving the map entry point,
/// or from `check` on the status the call returns.
pub unsafe fn map(
    ptr: *mut c_void,
    size: usize,
    offset: usize,
    handle: &MemHandle,
    flags: u64,
) -> Result<()> {
    let rt = runtime()?;
    let map_range = rt.cuda_mem_map()?;
    let raw_handle = handle.as_raw();
    // SAFETY: the address arguments are covered by the caller contract above;
    // `raw_handle` stays valid because `handle` is borrowed across the call.
    let status = unsafe { map_range(ptr, size, offset, raw_handle, flags) };
    check(status)
}
/// Unmaps `size` bytes starting at virtual address `ptr`.
///
/// # Safety
///
/// `ptr` and `size` are handed to the runtime without validation. The caller
/// must ensure they describe a range the runtime accepts for unmapping
/// (presumably one previously set up with [`map`] — confirm against the CUDA
/// docs). The caller must also ensure that no pointers into the range are used
/// afterwards.
///
/// # Errors
///
/// Propagates any error from `runtime()`, from resolving the unmap entry point,
/// or from `check` on the status the call returns.
pub unsafe fn unmap(ptr: *mut c_void, size: usize) -> Result<()> {
    let rt = runtime()?;
    let unmap_range = rt.cuda_mem_unmap()?;
    // SAFETY: the validity of `ptr`/`size` is part of this function's contract,
    // which the caller has promised to uphold.
    let status = unsafe { unmap_range(ptr, size) };
    check(status)
}
/// Sets the access protection that `device` has for the `size`-byte range at `ptr`.
///
/// A single access descriptor is submitted. `flags` is translated to raw
/// protection bits: `None` → 0, `Read` → 1, `ReadWrite` → 3.
///
/// # Safety
///
/// `ptr` and `size` are passed to the runtime without validation. The caller
/// must ensure they describe a range the runtime accepts here (presumably one
/// that is currently mapped with [`map`] — confirm against the CUDA docs).
///
/// # Errors
///
/// Propagates any error from `runtime()`, from resolving the set-access entry
/// point, or from `check` on the status the call returns.
pub unsafe fn set_access(
    ptr: *mut c_void,
    size: usize,
    device: &Device,
    flags: AccessFlags,
) -> Result<()> {
    let rt = runtime()?;
    let apply_access = rt.cuda_mem_set_access()?;

    let desc = cudaMemAccessDesc {
        location: cudaMemLocation {
            type_: cudaMemLocationType::DEVICE,
            id: device.ordinal(),
        },
        // Exhaustive on purpose: a new `AccessFlags` variant should fail to compile here.
        flags: match flags {
            AccessFlags::None => 0,
            AccessFlags::Read => 1,
            AccessFlags::ReadWrite => 3,
        },
    };
    let desc_ptr = (&desc as *const cudaMemAccessDesc).cast::<c_void>();

    // SAFETY: `desc_ptr` points at exactly one live descriptor (count = 1) that
    // outlives the call; `ptr`/`size` are covered by the caller contract.
    let status = unsafe { apply_access(ptr, size, desc_ptr, 1) };
    check(status)
}
/// Returns the raw access-protection bits that `device` has at address `ptr`.
///
/// The value is the runtime's raw flag word, not decoded into `AccessFlags`.
///
/// # Safety
///
/// `ptr` is passed to the runtime without validation. The caller must ensure
/// it is an address the runtime accepts here (presumably one inside a range
/// mapped with [`map`] — confirm against the CUDA docs).
///
/// # Errors
///
/// Propagates any error from `runtime()`, from resolving the get-access entry
/// point, or from `check` on the status the call returns.
pub unsafe fn get_access(ptr: *mut c_void, device: &Device) -> Result<u64> {
    let rt = runtime()?;
    let query_access = rt.cuda_mem_get_access()?;

    let location = cudaMemLocation {
        type_: cudaMemLocationType::DEVICE,
        id: device.ordinal(),
    };
    let location_ptr = (&location as *const cudaMemLocation).cast::<c_void>();

    let mut raw_flags: u64 = 0;
    // SAFETY: `raw_flags` is a writable out-slot and `location_ptr` points at a
    // live location for the whole call; `ptr` is covered by the caller contract.
    check(unsafe { query_access(&mut raw_flags, location_ptr, ptr) })?;

    Ok(raw_flags)
}