use crate::device_context::with_deallocator_stream;
use cuda_core::free_async;
use cuda_core::sys::CUdeviceptr;
use std::marker::PhantomData;
/// A typed, copyable handle to a raw CUDA device address.
///
/// `T` is only a marker for the element type; no `T` is ever stored. Because the type is
/// `Copy`, it cannot implement `Drop`, so this handle never frees the memory it points to.
///
/// `Clone`, `Copy` and `Debug` are implemented by hand instead of derived: `#[derive]`
/// would add `T: Clone` / `T: Copy` / `T: Debug` bounds, making e.g.
/// `DevicePointer<String>` non-`Copy` even though the struct only holds an address.
pub struct DevicePointer<T> {
    /// Zero-sized marker tying the address to element type `T`.
    dtype: PhantomData<T>,
    /// Raw CUDA device address.
    pub dptr: CUdeviceptr,
}

impl<T> Clone for DevicePointer<T> {
    fn clone(&self) -> Self {
        *self
    }
}

// Every field (`PhantomData<T>` and the raw address) is `Copy` for any `T`.
impl<T> Copy for DevicePointer<T> {}

impl<T> std::fmt::Debug for DevicePointer<T> {
    // Same output as the previous `#[derive(Debug)]`, but without requiring `T: Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DevicePointer")
            .field("dtype", &self.dtype)
            .field("dptr", &self.dptr)
            .finish()
    }
}

// SAFETY: NOTE(review) — the struct carries only a device address plus a zero-sized marker,
// and this impl makes it `Send` for every `T` (even non-`Send` ones). Confirm that no API
// relies on `T: Send` when a `DevicePointer<T>` crosses threads.
unsafe impl<T> Send for DevicePointer<T> {}
impl<T> DevicePointer<T> {
    /// Wraps a raw CUDA device address in a typed pointer.
    ///
    /// The address is stored as-is; nothing is validated.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not spelled out in this file. Presumably the caller must
    /// guarantee that `dptr` addresses device memory holding elements of type `T` and that it
    /// stays valid for as long as the returned pointer is used — confirm with callers.
    pub unsafe fn from_cu_deviceptr(dptr: CUdeviceptr) -> Self {
        let dtype = PhantomData;
        Self { dptr, dtype }
    }

    /// Returns the raw CUDA device address, without the `T` type information.
    pub fn cu_deviceptr(&self) -> CUdeviceptr {
        let Self { dptr, .. } = self;
        *dptr
    }
}
/// An owned CUDA device allocation.
///
/// Its `Drop` impl releases `cudptr` by calling `free_async` on the stream obtained from
/// `with_deallocator_stream(device_id, ..)`.
#[derive(Debug)]
pub struct DeviceBuffer {
    /// Device id handed to `with_deallocator_stream` when the buffer is freed.
    device_id: usize,
    /// Raw CUDA device address of the allocation.
    cudptr: CUdeviceptr,
    /// Length of the allocation in bytes (exposed as `len_bytes()`).
    len: usize,
}

// SAFETY: NOTE(review) — the struct holds only an id, a device address and a length.
// Moving it to another thread means it may be freed from there; confirm that
// `with_deallocator_stream`/`free_async` may be called from any thread.
unsafe impl Send for DeviceBuffer {}

// SAFETY: NOTE(review) — shared references expose only read-only getters of plain values,
// so concurrent `&DeviceBuffer` access performs no mutation.
unsafe impl Sync for DeviceBuffer {}
impl Drop for DeviceBuffer {
    /// Releases the allocation by calling `free_async` on the deallocator stream of
    /// `self.device_id`.
    ///
    /// # Panics
    ///
    /// Panics if `with_deallocator_stream` returns an error, unless the current thread is
    /// already unwinding from another panic. In that case the failure is written to stderr
    /// and the allocation is leaked instead, because panicking inside `drop` while unwinding
    /// aborts the entire process.
    fn drop(&mut self) {
        // Copy the plain fields out so the closure does not borrow `self`.
        let device_id = self.device_id;
        let dptr = self.cudptr;

        // SAFETY: NOTE(review) — assumes `dptr` came from a device allocation on `device_id`
        // (the `from_raw_parts` caller's responsibility) and is released exactly once, here.
        let outcome = unsafe {
            with_deallocator_stream(device_id, |stream| {
                free_async(dptr, stream);
            })
        };

        if outcome.is_err() {
            if std::thread::panicking() {
                // A second panic here would abort the process; leak and report instead.
                eprintln!(
                    "Failed to free device pointer on device_id={} (leaked: thread already panicking)",
                    device_id
                );
            } else {
                panic!("Failed to free device pointer on device_id={}", device_id);
            }
        }
    }
}
impl DeviceBuffer {
    /// Takes ownership of an existing device allocation.
    ///
    /// The values are stored verbatim; nothing is validated or allocated here.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not documented in this file. Because the returned buffer
    /// frees `dptr` when dropped, the caller presumably must pass an allocation of
    /// `len_bytes` bytes made on `device_id` that nothing else will free — confirm.
    pub unsafe fn from_raw_parts(dptr: CUdeviceptr, len_bytes: usize, device_id: usize) -> Self {
        Self {
            device_id,
            cudptr: dptr,
            len: len_bytes,
        }
    }

    /// Device id this buffer was created with.
    pub fn device_id(&self) -> usize {
        self.device_id
    }

    /// Raw CUDA device address of the allocation.
    pub fn cu_deviceptr(&self) -> CUdeviceptr {
        self.cudptr
    }

    /// Size of the allocation in bytes.
    pub fn len_bytes(&self) -> usize {
        self.len
    }

    /// Returns `true` when the buffer spans zero bytes.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}