use std::marker::PhantomData;
use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::ffi::{CU_MEM_ATTACH_GLOBAL, CUdeviceptr};
use oxicuda_driver::loader::try_driver;
/// A typed buffer of `T` values living in CUDA managed (unified) memory.
///
/// Created by [`UnifiedBuffer::alloc`]. The same allocation is reachable from
/// device code through [`UnifiedBuffer::as_device_ptr`] and from the host
/// through [`UnifiedBuffer::as_slice`] / [`UnifiedBuffer::as_mut_slice`]. The
/// allocation is released with `cuMemFree_v2` when the buffer is dropped.
pub struct UnifiedBuffer<T: Copy> {
    /// Number of `T` elements in the allocation.
    len: usize,
    /// Address returned by `cuMemAllocManaged`; handed to kernels and freed on drop.
    ptr: CUdeviceptr,
    /// The same address as `ptr`, typed for host-side slice access.
    host_ptr: *mut T,
    /// Marks that the buffer logically owns values of type `T`.
    _phantom: PhantomData<T>,
}
// SAFETY: `UnifiedBuffer` owns its managed allocation (created in `alloc`,
// released in `Drop`). Sending it to another thread transfers ownership of the
// contained `T` values, which is sound when `T: Send`.
// NOTE(review): `Drop` calls `cuMemFree_v2` on whichever thread drops the
// buffer; confirm a suitable CUDA context is current on that thread.
unsafe impl<T> Send for UnifiedBuffer<T> where T: Copy + Send {}
// SAFETY: through `&UnifiedBuffer<T>` the host can only obtain `&[T]`
// (`as_slice`) or the raw device address (`as_device_ptr`). Host-side mutation
// requires `&mut self`, so sharing references across threads is sound when
// `T: Sync`.
unsafe impl<T> Sync for UnifiedBuffer<T> where T: Copy + Sync {}
impl<T: Copy> UnifiedBuffer<T> {
    /// Allocates CUDA managed (unified) memory for `n` elements of `T`.
    ///
    /// The allocation is made with `cuMemAllocManaged` using
    /// `CU_MEM_ATTACH_GLOBAL`. The returned address is used both as the device
    /// pointer and as the host pointer.
    ///
    /// The contents are **not initialized** by this function. Write every
    /// element (via [`as_mut_slice`](Self::as_mut_slice) or from device code)
    /// before reading it through [`as_slice`](Self::as_slice).
    ///
    /// # Errors
    ///
    /// * [`CudaError::InvalidValue`] in any of these cases:
    ///   * `n == 0`;
    ///   * `T` is zero-sized;
    ///   * the byte size overflows `usize` or exceeds `isize::MAX`;
    ///   * the driver returns an address that is not aligned for `T`.
    /// * Any error from loading the driver (`try_driver`).
    /// * Any error from the `cuMemAllocManaged` call itself.
    pub fn alloc(n: usize) -> CudaResult<Self> {
        if n == 0 {
            return Err(CudaError::InvalidValue);
        }
        let byte_size = n
            .checked_mul(std::mem::size_of::<T>())
            .ok_or(CudaError::InvalidValue)?;
        // A zero-sized `T` produces a 0-byte request; reject it here instead of
        // relying on the driver. The `isize::MAX` bound is a precondition of
        // `slice::from_raw_parts{,_mut}`, which the host accessors below use.
        if byte_size == 0 || byte_size > isize::MAX as usize {
            return Err(CudaError::InvalidValue);
        }
        let api = try_driver()?;
        let mut dev_ptr: CUdeviceptr = 0;
        // SAFETY: `dev_ptr` is a valid, writable out-location for the driver
        // to store the allocation address; `byte_size` is non-zero.
        let rc =
            unsafe { (api.cu_mem_alloc_managed)(&mut dev_ptr, byte_size, CU_MEM_ATTACH_GLOBAL) };
        oxicuda_driver::check(rc)?;
        // The host accessors build `&[T]` from this address, which requires
        // alignment for `T`. An over-aligned `T` could exceed what the driver
        // guarantees, so verify instead of assuming.
        if (dev_ptr as usize) % std::mem::align_of::<T>() != 0 {
            // SAFETY: `dev_ptr` was just returned by a successful
            // `cuMemAllocManaged` and has not been handed out anywhere.
            let _ = unsafe { (api.cu_mem_free_v2)(dev_ptr) };
            return Err(CudaError::InvalidValue);
        }
        // Managed memory is addressed from the host at the same address.
        let host_ptr = dev_ptr as *mut T;
        Ok(Self {
            ptr: dev_ptr,
            host_ptr,
            len: n,
            _phantom: PhantomData,
        })
    }

    /// Number of `T` elements in the buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the buffer holds no elements.
    ///
    /// Buffers produced by [`alloc`](Self::alloc) are never empty, because
    /// `alloc` rejects `n == 0`.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Size of the allocation in bytes, i.e. `len * size_of::<T>()`.
    ///
    /// This cannot overflow, because `alloc` already checked the product.
    #[inline]
    pub fn byte_size(&self) -> usize {
        self.len * std::mem::size_of::<T>()
    }

    /// Raw device address of the allocation, for kernel launches and driver calls.
    ///
    /// The buffer keeps ownership. The address must not be used after the
    /// buffer is dropped.
    #[inline]
    pub fn as_device_ptr(&self) -> CUdeviceptr {
        self.ptr
    }

    /// Host-side read-only view of the buffer contents.
    ///
    /// NOTE(review): elements read here must have been initialized first,
    /// because `alloc` does not initialize memory. Any outstanding device work
    /// touching this buffer should be synchronized before host access; confirm
    /// callers do so.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        // SAFETY: `host_ptr` meets the `from_raw_parts` preconditions:
        // - it is non-null and aligned for `T` (checked in `alloc`);
        // - it spans `len * size_of::<T>()` bytes, at most `isize::MAX`;
        // - it stays valid until `Drop` frees it;
        // - `&self` prevents host mutation for the slice's lifetime.
        unsafe { std::slice::from_raw_parts(self.host_ptr, self.len) }
    }

    /// Host-side mutable view of the buffer contents.
    ///
    /// NOTE(review): device work touching this buffer should be synchronized
    /// before host access; confirm callers do so.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        // SAFETY: same invariants as `as_slice`. Additionally, `&mut self`
        // guarantees no other host-side reference to the contents exists.
        unsafe { std::slice::from_raw_parts_mut(self.host_ptr, self.len) }
    }
}
impl<T: Copy> Drop for UnifiedBuffer<T> {
    /// Releases the managed allocation with `cuMemFree_v2`.
    ///
    /// `drop` cannot return errors, so failures are reported through
    /// `tracing::warn!` and the allocation is leaked rather than panicking.
    /// This includes the driver being unavailable.
    fn drop(&mut self) {
        if let Ok(api) = try_driver() {
            // SAFETY: `self.ptr` came from a successful `cuMemAllocManaged` in
            // `alloc`, and this is the only place it is freed.
            let rc = unsafe { (api.cu_mem_free_v2)(self.ptr) };
            if rc != 0 {
                tracing::warn!(
                    cuda_error = rc,
                    ptr = self.ptr,
                    len = self.len,
                    "cuMemFree_v2 failed during UnifiedBuffer drop"
                );
            }
        } else {
            // Previously this path leaked silently; surface it so the leak is
            // diagnosable.
            tracing::warn!(
                ptr = self.ptr,
                len = self.len,
                "CUDA driver unavailable during UnifiedBuffer drop; allocation leaked"
            );
        }
    }
}