use std::ffi::c_void;
use std::ops::{Deref, DerefMut};
use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::loader::try_driver;
/// Owning handle to a contiguous host allocation of `len` elements of `T`.
///
/// Within this file the allocation is obtained in `alloc` through the driver's
/// `cu_mem_alloc_host_v2` entry point and handed back to `cu_mem_free_host`
/// when the value is dropped.
pub struct PinnedBuffer<T>
where
    T: Copy,
{
    /// Start of the allocation, as returned by the driver and cast to `*mut T`.
    ptr: *mut T,
    /// Number of `T` elements (not bytes) the allocation was sized for.
    len: usize,
}
// SAFETY: `PinnedBuffer` is the only holder of the pointer it stores (the
// fields are private and the value is only built in `alloc`), so moving it to
// another thread moves ownership of the whole allocation along with it.
// NOTE(review): this assumes the driver permits the allocation to be accessed
// and freed from a thread other than the one that created it — confirm.
unsafe impl<T> Send for PinnedBuffer<T> where T: Copy + Send {}

// SAFETY: shared access only hands out `&[T]` / `*const T`; mutation requires
// `&mut self`, so sharing `&PinnedBuffer<T>` is no more permissive than
// sharing `&[T]`, which is `Sync` when `T: Sync`.
unsafe impl<T> Sync for PinnedBuffer<T> where T: Copy + Sync {}
impl<T: Copy> PinnedBuffer<T> {
    /// Allocates a host buffer with room for `n` elements of `T`.
    ///
    /// The memory is requested from the driver's `cu_mem_alloc_host_v2` entry
    /// point. Nothing is written to it here, so the initial contents are
    /// whatever the driver hands back.
    ///
    /// # Errors
    ///
    /// * `CudaError::InvalidValue` when `n` is zero, or when
    ///   `n * size_of::<T>()` does not fit in a `usize`.
    /// * Whatever `try_driver` returns when the driver table is unavailable.
    /// * Whatever `oxicuda_driver::check` reports for the allocation's
    ///   return code.
    pub fn alloc(n: usize) -> CudaResult<Self> {
        // Validate the request before touching the driver at all.
        let byte_size = match n {
            0 => return Err(CudaError::InvalidValue),
            count => count
                .checked_mul(std::mem::size_of::<T>())
                .ok_or(CudaError::InvalidValue)?,
        };

        let driver = try_driver()?;
        let mut host_ptr = std::ptr::null_mut::<c_void>();
        // SAFETY: `host_ptr` is a live local, so the out-pointer passed to the
        // driver is valid and writable for the duration of the call.
        let status = unsafe { (driver.cu_mem_alloc_host_v2)(&mut host_ptr, byte_size) };
        oxicuda_driver::check(status)?;

        Ok(Self {
            ptr: host_ptr.cast::<T>(),
            len: n,
        })
    }

    /// Allocates a buffer of `data.len()` elements and copies `data` into it.
    ///
    /// # Errors
    ///
    /// Propagates every error of [`Self::alloc`]; in particular an empty
    /// slice is rejected with `CudaError::InvalidValue`.
    pub fn from_slice(data: &[T]) -> CudaResult<Self> {
        let count = data.len();
        let pinned = Self::alloc(count)?;
        // SAFETY: `data` is valid for reading `count` elements, and `alloc`
        // sized the destination for exactly `count` elements of `T`. The
        // destination was just obtained from the driver, so it is assumed not
        // to overlap the borrowed `data`.
        unsafe { pinned.ptr.copy_from_nonoverlapping(data.as_ptr(), count) };
        Ok(pinned)
    }

    /// Number of `T` elements the buffer holds.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the buffer holds no elements.
    ///
    /// Buffers built through `alloc` / `from_slice` always have `len > 0`,
    /// because `alloc` rejects `n == 0`.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Raw read-only pointer to the first element.
    ///
    /// The pointer is only meaningful while `self` is alive.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.ptr
    }

    /// Raw mutable pointer to the first element.
    ///
    /// The pointer is only meaningful while `self` is alive.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr
    }

    /// Views the whole buffer as an immutable slice of `len` elements.
    ///
    /// NOTE(review): `alloc` does not initialise the memory, so reading
    /// through this slice before every element has been written may observe
    /// uninitialised data — confirm callers always fill the buffer first.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        // SAFETY: `ptr` and `len` were set together in `alloc`, which sized
        // the allocation for `len` elements; the borrow of `self` keeps the
        // allocation alive for the lifetime of the returned slice.
        unsafe { std::slice::from_raw_parts(self.as_ptr(), self.len) }
    }

    /// Views the whole buffer as a mutable slice of `len` elements.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        // SAFETY: same size/liveness argument as `as_slice`; `&mut self`
        // guarantees no other reference into the buffer exists meanwhile.
        unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.len) }
    }
}
/// Lets a `PinnedBuffer<T>` be used wherever a `&[T]` is expected.
impl<T: Copy> Deref for PinnedBuffer<T> {
    type Target = [T];

    /// Borrows the buffer contents; identical to calling `as_slice`.
    #[inline]
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}
/// Lets a `PinnedBuffer<T>` be used wherever a `&mut [T]` is expected.
impl<T: Copy> DerefMut for PinnedBuffer<T> {
    /// Mutably borrows the buffer contents; identical to calling
    /// `as_mut_slice`.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}
impl<T: Copy> Drop for PinnedBuffer<T> {
    /// Hands the allocation back to the driver via `cu_mem_free_host`.
    ///
    /// Cleanup is best effort because `drop` cannot report errors: when
    /// `try_driver` fails nothing is freed, and a non-zero return code from
    /// the free call is only logged through `tracing::warn!`.
    fn drop(&mut self) {
        let api = match try_driver() {
            Ok(api) => api,
            // Without a driver table there is no free function to call.
            Err(_) => return,
        };

        let host_ptr = self.ptr.cast::<c_void>();
        // SAFETY: `ptr` is only ever set in `alloc` from the driver's host
        // allocator, and `drop` runs at most once per value, so the pointer
        // is released exactly once.
        let status = unsafe { (api.cu_mem_free_host)(host_ptr) };
        if status == 0 {
            return;
        }

        tracing::warn!(
            cuda_error = status,
            len = self.len,
            "cuMemFreeHost failed during PinnedBuffer drop"
        );
    }
}