use std::ffi::c_void;
use std::marker::PhantomData;
use std::mem::size_of;
use oxicuda_driver::error::CudaResult;
use oxicuda_driver::ffi::CUdeviceptr;
use oxicuda_driver::loader::try_driver;
/// Pinned host allocation of `len` elements of `T` together with the device address the
/// driver reports for the same memory.
///
/// The allocation is obtained in `MappedBuffer::alloc` and handed back to the driver when
/// the buffer is dropped.
pub struct MappedBuffer<T: Copy> {
    /// Number of `T` elements the allocation was sized for.
    len: usize,
    /// Host-side address of the allocation (checked for null before freeing).
    host_ptr: *mut T,
    /// Device-side address resolved for the same allocation.
    device_ptr: CUdeviceptr,
    /// Records that the buffer logically holds values of type `T`.
    _phantom: PhantomData<T>,
}
// SAFETY: a `MappedBuffer` is the sole owner of its host allocation (it is freed exactly
// once, in `Drop`), and its raw pointers are only dereferenced through `&self` / `&mut self`
// methods, so transferring the buffer to another thread is sound whenever `T: Send`.
// NOTE(review): this reasoning covers host-side access only; it does not account for a
// kernel writing through `as_device_ptr()` while host references are live — confirm that
// callers synchronise with the device before touching the host slice.
unsafe impl<T: Copy + Send> Send for MappedBuffer<T> {}
// SAFETY: shared access (`&MappedBuffer<T>`) only yields `&[T]`, `*const T` and the device
// address, so sharing the buffer across threads is sound whenever `T: Sync`.
unsafe impl<T: Copy + Sync> Sync for MappedBuffer<T> {}
impl<T: Copy> MappedBuffer<T> {
pub fn alloc(n: usize) -> CudaResult<Self> {
let api = try_driver()?;
let byte_size = n.saturating_mul(size_of::<T>());
let mut raw_ptr: *mut c_void = std::ptr::null_mut();
oxicuda_driver::error::check(unsafe {
(api.cu_mem_alloc_host_v2)(&mut raw_ptr, byte_size)
})?;
let host_ptr = raw_ptr.cast::<T>();
let mut device_ptr: CUdeviceptr = 0;
let result = oxicuda_driver::error::check(unsafe {
(api.cu_mem_host_get_device_pointer_v2)(&mut device_ptr, raw_ptr, 0)
});
if let Err(e) = result {
unsafe { (api.cu_mem_free_host)(raw_ptr) };
return Err(e);
}
Ok(Self {
host_ptr,
device_ptr,
len: n,
_phantom: PhantomData,
})
}
#[inline]
pub fn len(&self) -> usize {
self.len
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[inline]
pub fn byte_size(&self) -> usize {
self.len * size_of::<T>()
}
#[inline]
pub fn as_device_ptr(&self) -> CUdeviceptr {
self.device_ptr
}
#[inline]
pub fn as_host_ptr(&self) -> *const T {
self.host_ptr
}
#[inline]
pub fn as_host_ptr_mut(&mut self) -> *mut T {
self.host_ptr
}
pub fn as_host_slice(&self) -> &[T] {
unsafe { std::slice::from_raw_parts(self.host_ptr, self.len) }
}
pub fn as_host_slice_mut(&mut self) -> &mut [T] {
unsafe { std::slice::from_raw_parts_mut(self.host_ptr, self.len) }
}
}
impl<T: Copy> Drop for MappedBuffer<T> {
    /// Hands the pinned host allocation back to the driver.
    ///
    /// This is best-effort:
    /// - a null pointer means there is nothing to free;
    /// - if the driver cannot be loaded, the allocation is left in place;
    /// - the status returned by `cu_mem_free_host` is discarded because a destructor has
    ///   no way to report it.
    fn drop(&mut self) {
        let raw = self.host_ptr.cast::<c_void>();
        if raw.is_null() {
            return;
        }
        let api = match try_driver() {
            Ok(api) => api,
            Err(_) => return,
        };
        // SAFETY: `raw` is the non-null pointer obtained from `cu_mem_alloc_host_v2` in
        // `alloc`, and `drop` runs at most once, so it is freed exactly once.
        unsafe { (api.cu_mem_free_host)(raw) };
    }
}