use std::ffi::c_void;
use std::marker::PhantomData;
use std::mem;
use std::ptr;
use crate::driver::context::{get_driver, CudaContext};
use crate::driver::sys::{CUdeviceptr, CudaDriver};
use crate::GpuError;
/// Handle to a device allocation holding `len` elements of type `T`.
///
/// The handle stores only the device pointer and the element count; no `T`
/// value is kept on the host side. Dropping the handle passes a non-zero
/// pointer to the driver's `cuMemFree`.
pub struct GpuBuffer<T> {
/// Device pointer. `0` marks a buffer with no allocation behind it
/// (this is what `new` produces for `len == 0`).
pub(super) ptr: CUdeviceptr,
/// Number of `T` elements (not bytes).
pub(super) len: usize,
/// Associates the handle with the element type without storing a `T`.
pub(super) _marker: PhantomData<T>,
}
// SAFETY: a `GpuBuffer<T>` holds only a device pointer value, a length and a
// `PhantomData<T>`; it contains no host-side reference or `T` value.
// NOTE(review): soundness additionally assumes the driver permits using and
// freeing the device pointer from any host thread -- confirm against how the
// CUDA context is managed.
unsafe impl<T> Send for GpuBuffer<T> where T: Send {}
unsafe impl<T> Sync for GpuBuffer<T> where T: Sync {}
impl<T> GpuBuffer<T> {
    /// Wraps an existing device pointer and element count in a `GpuBuffer`.
    ///
    /// No driver call is made here. The returned value takes over the
    /// pointer: when it is dropped, a non-zero `ptr` is handed to `cuMemFree`.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `ptr` is either `0` or a device pointer
    /// that may be released with `cuMemFree` exactly once by the returned
    /// value, and that `len` elements of `T` do not extend past the
    /// allocation behind `ptr`.
    #[must_use]
    pub unsafe fn from_raw_parts(ptr: CUdeviceptr, len: usize) -> Self {
        Self { ptr, len, _marker: PhantomData }
    }

    /// Allocates device memory for `len` elements of `T`.
    ///
    /// A `len` of zero performs no driver call and yields a buffer whose
    /// device pointer is `0`. The `_ctx` argument is not read by this
    /// function.
    ///
    /// # Errors
    ///
    /// * Propagates the error from `get_driver` when the driver is
    ///   unavailable.
    /// * Returns `GpuError::MemoryAllocation` when `len * size_of::<T>()`
    ///   does not fit in `usize`, or when `cuMemAlloc` reports a failure.
    pub fn new(_ctx: &CudaContext, len: usize) -> Result<Self, GpuError> {
        if len == 0 {
            return Ok(Self { ptr: 0, len: 0, _marker: PhantomData });
        }
        let driver = get_driver()?;

        // A wrapping multiplication would request fewer bytes than `len`
        // elements need, so later accesses would run past the allocation.
        // Reject the request instead.
        let elem_size = mem::size_of::<T>();
        let size = len.checked_mul(elem_size).ok_or_else(|| {
            GpuError::MemoryAllocation(format!(
                "requested buffer size overflows usize ({} elements of {} bytes)",
                len, elem_size
            ))
        })?;

        let mut ptr: CUdeviceptr = 0;
        // SAFETY: `ptr` is a live local the driver writes the resulting
        // device address into, and `size` is the checked byte count.
        let result = unsafe { (driver.cuMemAlloc)(&mut ptr, size) };
        CudaDriver::check(result).map_err(|e| GpuError::MemoryAllocation(e.to_string()))?;
        Ok(Self { ptr, len, _marker: PhantomData })
    }

    /// Returns the raw device pointer (`0` for a buffer with no allocation).
    #[must_use]
    pub fn as_ptr(&self) -> CUdeviceptr {
        self.ptr
    }

    /// Returns the number of `T` elements in the buffer.
    #[must_use]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the buffer holds zero elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the buffer size in bytes, i.e. `len * size_of::<T>()`.
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        self.len * mem::size_of::<T>()
    }

    /// Creates a `GpuBufferView` carrying this buffer's device pointer and
    /// element count.
    ///
    /// Only the two values are copied; no device memory is touched.
    ///
    /// NOTE(review): `GpuBufferView` has no lifetime parameter linking it to
    /// this buffer, so nothing here prevents a view from outliving the
    /// allocation -- callers must keep the buffer alive themselves.
    #[must_use]
    pub fn clone_metadata(&self) -> GpuBufferView<T> {
        GpuBufferView { ptr: self.ptr, len: self.len, _marker: PhantomData }
    }
}
/// Pointer-and-length description of a device buffer of `T` elements.
///
/// NOTE(review): the type has no lifetime parameter, so it is not tied by
/// the type system to whatever allocation `ptr` refers to.
pub struct GpuBufferView<T> {
/// Device pointer recorded when the view was created.
ptr: CUdeviceptr,
/// Number of `T` elements (not bytes).
len: usize,
/// Associates the view with the element type without storing a `T`.
_marker: PhantomData<T>,
}
impl<T> GpuBufferView<T> {
    /// Device pointer stored in this view.
    #[must_use]
    pub fn as_ptr(&self) -> CUdeviceptr {
        let Self { ptr, .. } = self;
        *ptr
    }

    /// Element count stored in this view.
    #[must_use]
    pub fn len(&self) -> usize {
        let Self { len, .. } = self;
        *len
    }

    /// Whether the view describes zero elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let element_count = self.len;
        element_count == 0
    }

    /// Size of the described region in bytes: element count times
    /// `size_of::<T>()`.
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        let element_size = std::mem::size_of::<T>();
        self.len * element_size
    }
}
impl<T> Drop for GpuBuffer<T> {
    /// Frees the device allocation on a best-effort basis.
    ///
    /// Nothing is reported to the caller: a missing driver or a failing
    /// `cuMemFree` is silently ignored.
    fn drop(&mut self) {
        // A zero pointer means there is no allocation to release.
        if self.ptr == 0 {
            return;
        }

        // Without a driver handle there is no way to free the memory.
        let driver = match get_driver() {
            Ok(driver) => driver,
            Err(_) => return,
        };

        // SAFETY: `self.ptr` is non-zero, and the value is being dropped, so
        // this handle passes the pointer to the driver at most once.
        // NOTE(review): validity of the pointer itself rests on how the
        // buffer was constructed (`new` / `from_raw_parts`).
        unsafe {
            // The status code is deliberately discarded; `drop` cannot
            // return an error.
            let _ = (driver.cuMemFree)(self.ptr);
        }
    }
}
impl<T> GpuBuffer<T> {
    /// Returns the host address of this buffer's device-pointer field,
    /// type-erased to `*mut c_void`.
    ///
    /// The result points at the field inside `self`, not at device memory,
    /// so it is only meaningful while this `GpuBuffer` stays alive and is not
    /// moved. It is derived from a shared borrow; despite the `*mut` type it
    /// must only be read through.
    ///
    /// NOTE(review): presumably intended as an entry in a kernel-launch
    /// parameter array -- confirm against call sites.
    #[must_use]
    pub fn as_kernel_arg(&self) -> *mut c_void {
        let device_ptr_slot = ptr::addr_of!(self.ptr);
        device_ptr_slot.cast::<c_void>() as *mut c_void
    }
}