use std::ffi::c_void;
use std::marker::PhantomData;
use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::ffi::CUdeviceptr;
use oxicuda_driver::loader::try_driver;
use oxicuda_driver::stream::Stream;
/// Owning handle to a contiguous device allocation holding `len` elements
/// of `T`.
///
/// Instances are produced by [`DeviceBuffer::alloc`], [`DeviceBuffer::zeroed`]
/// and [`DeviceBuffer::from_host`]; the allocation is handed back to the
/// driver when the value is dropped.
pub struct DeviceBuffer<T>
where
    T: Copy,
{
    // Raw device address returned by the driver allocation call.
    ptr: CUdeviceptr,
    // Number of `T` elements (not bytes) the allocation was sized for.
    len: usize,
    // Records the element type; no `T` value is stored on the host side.
    _phantom: PhantomData<T>,
}
// SAFETY: the struct holds only a raw device address, an element count and a
// `PhantomData<T>` marker.
// NOTE(review): this relies on the driver permitting an allocation to be used
// and freed from a thread other than the one that created it — confirm
// against the CUDA driver threading/context rules.
unsafe impl<T> Send for DeviceBuffer<T> where T: Copy + Send {}
// SAFETY: every method taking `&self` either reads the plain `ptr`/`len`
// fields or passes them to a driver call; no host-side interior mutability
// is involved.
// NOTE(review): this assumes concurrent driver calls on the same allocation
// from several threads are acceptable — confirm against the CUDA driver docs.
unsafe impl<T> Sync for DeviceBuffer<T> where T: Copy + Sync {}
impl<T: Copy> DeviceBuffer<T> {
pub fn alloc(n: usize) -> CudaResult<Self> {
if n == 0 {
return Err(CudaError::InvalidValue);
}
let byte_size = n
.checked_mul(std::mem::size_of::<T>())
.ok_or(CudaError::InvalidValue)?;
let api = try_driver()?;
let mut ptr: CUdeviceptr = 0;
let rc = unsafe { (api.cu_mem_alloc_v2)(&mut ptr, byte_size) };
oxicuda_driver::check(rc)?;
Ok(Self {
ptr,
len: n,
_phantom: PhantomData,
})
}
pub fn zeroed(n: usize) -> CudaResult<Self> {
let buf = Self::alloc(n)?;
let api = try_driver()?;
let rc = unsafe { (api.cu_memset_d8_v2)(buf.ptr, 0, buf.byte_size()) };
oxicuda_driver::check(rc)?;
Ok(buf)
}
pub fn from_host(data: &[T]) -> CudaResult<Self> {
let mut buf = Self::alloc(data.len())?;
buf.copy_from_host(data)?;
Ok(buf)
}
pub fn copy_from_host(&mut self, src: &[T]) -> CudaResult<()> {
if src.len() != self.len {
return Err(CudaError::InvalidValue);
}
let api = try_driver()?;
let rc = unsafe {
(api.cu_memcpy_htod_v2)(self.ptr, src.as_ptr().cast::<c_void>(), self.byte_size())
};
oxicuda_driver::check(rc)
}
pub fn copy_to_host(&self, dst: &mut [T]) -> CudaResult<()> {
if dst.len() != self.len {
return Err(CudaError::InvalidValue);
}
let api = try_driver()?;
let rc = unsafe {
(api.cu_memcpy_dtoh_v2)(
dst.as_mut_ptr().cast::<c_void>(),
self.ptr,
self.byte_size(),
)
};
oxicuda_driver::check(rc)
}
pub fn copy_from_device(&mut self, src: &DeviceBuffer<T>) -> CudaResult<()> {
if src.len != self.len {
return Err(CudaError::InvalidValue);
}
let api = try_driver()?;
let rc = unsafe { (api.cu_memcpy_dtod_v2)(self.ptr, src.ptr, self.byte_size()) };
oxicuda_driver::check(rc)
}
pub fn copy_from_host_async(&mut self, src: &[T], stream: &Stream) -> CudaResult<()> {
if src.len() != self.len {
return Err(CudaError::InvalidValue);
}
let api = try_driver()?;
let rc = unsafe {
(api.cu_memcpy_htod_async_v2)(
self.ptr,
src.as_ptr().cast::<c_void>(),
self.byte_size(),
stream.raw(),
)
};
oxicuda_driver::check(rc)
}
pub fn copy_to_host_async(&self, dst: &mut [T], stream: &Stream) -> CudaResult<()> {
if dst.len() != self.len {
return Err(CudaError::InvalidValue);
}
let api = try_driver()?;
let rc = unsafe {
(api.cu_memcpy_dtoh_async_v2)(
dst.as_mut_ptr().cast::<c_void>(),
self.ptr,
self.byte_size(),
stream.raw(),
)
};
oxicuda_driver::check(rc)
}
#[inline]
pub fn len(&self) -> usize {
self.len
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[inline]
pub fn byte_size(&self) -> usize {
self.len * std::mem::size_of::<T>()
}
#[inline]
pub fn as_device_ptr(&self) -> CUdeviceptr {
self.ptr
}
pub fn slice(&self, offset: usize, len: usize) -> CudaResult<DeviceSlice<'_, T>> {
let end = offset.checked_add(len).ok_or(CudaError::InvalidValue)?;
if end > self.len {
return Err(CudaError::InvalidValue);
}
let byte_offset = offset
.checked_mul(std::mem::size_of::<T>())
.ok_or(CudaError::InvalidValue)?;
Ok(DeviceSlice {
ptr: self.ptr + byte_offset as u64,
len,
_phantom: PhantomData,
})
}
}
impl<T: Copy> Drop for DeviceBuffer<T> {
    /// Hands the allocation back to the driver.
    ///
    /// Nothing is done when the driver cannot be obtained. A non-zero return
    /// code from the free call is reported as a `tracing` warning; `drop`
    /// never panics because of it.
    fn drop(&mut self) {
        let api = match try_driver() {
            Ok(api) => api,
            // No driver available: there is no way to issue the free call.
            Err(_) => return,
        };
        // SAFETY: `self.ptr` is the address stored when this buffer was
        // created, and the value is being dropped so it is not used again.
        let rc = unsafe { (api.cu_mem_free_v2)(self.ptr) };
        if rc == 0 {
            return;
        }
        tracing::warn!(
            cuda_error = rc,
            ptr = self.ptr,
            len = self.len,
            "cuMemFree_v2 failed during DeviceBuffer drop"
        );
    }
}
/// View of a contiguous sub-range of a [`DeviceBuffer`], as produced by
/// [`DeviceBuffer::slice`].
///
/// The lifetime `'a` is carried through `PhantomData<&'a T>`, which ties the
/// view to the borrow it was created from. The view has no `Drop` logic of
/// its own in this file.
pub struct DeviceSlice<'a, T>
where
    T: Copy,
{
    // Device address of the first element covered by the view.
    ptr: CUdeviceptr,
    // Number of `T` elements (not bytes) covered by the view.
    len: usize,
    // Marker carrying the borrow lifetime and element type.
    _phantom: PhantomData<&'a T>,
}
impl<'a, T: Copy> DeviceSlice<'a, T> {
    /// Number of `T` elements covered by the view.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the view covers zero elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Size of the viewed range in bytes (`len * size_of::<T>()`).
    #[inline]
    pub fn byte_size(&self) -> usize {
        std::mem::size_of::<T>() * self.len
    }

    /// Device address of the first element in the view.
    #[inline]
    pub fn as_device_ptr(&self) -> CUdeviceptr {
        self.ptr
    }
}