use std::ffi::c_void;
use std::marker::PhantomData;
use std::mem;
use std::ptr;
use crate::driver::context::{get_driver, CudaContext};
use crate::driver::sys::{CUcontext, CUdeviceptr, CudaDriver, CUDA_SUCCESS};
use crate::GpuError;
/// Typed handle to device-accessible memory holding `len` elements of `T`.
///
/// Depending on the constructor, the memory is a `cuMemAlloc` /
/// `cuMemAllocManaged` allocation, a device-mapped view of registered host
/// memory (in which case `host_ptr` is set), or a caller-supplied pointer.
pub struct GpuBuffer<T> {
    /// Device address of the first element; `0` for empty buffers.
    pub(super) ptr: CUdeviceptr,
    /// Number of `T` elements (not bytes).
    pub(super) len: usize,
    /// Context made current before device operations; `None` leaves the
    /// calling thread's current context untouched.
    pub(crate) ctx: Option<CUcontext>,
    /// Registered host pointer; only present for host-registered buffers.
    host_ptr: Option<*mut c_void>,
    /// Carries the element type without storing any `T` on the host.
    pub(super) _marker: PhantomData<T>,
}

// SAFETY: the raw device address, host pointer and context handle suppress the
// auto traits. NOTE(review): this presumes the CUDA objects they reference may
// be used from any thread (the buffer re-binds its context via
// `ensure_context`) — confirm against the driver's threading guarantees.
unsafe impl<T: Send> Send for GpuBuffer<T> {}
unsafe impl<T: Sync> Sync for GpuBuffer<T> {}
impl<T> GpuBuffer<T> {
#[must_use]
pub unsafe fn from_raw_parts(ptr: CUdeviceptr, len: usize) -> Self {
Self {
ptr,
len,
host_ptr: None,
ctx: None,
_marker: PhantomData,
}
}
pub fn new(ctx: &CudaContext, len: usize) -> Result<Self, GpuError> {
let ctx_handle = Some(ctx.raw());
if len == 0 {
return Ok(Self {
ptr: 0,
len: 0,
host_ptr: None,
ctx: ctx_handle,
_marker: PhantomData,
});
}
static USE_MANAGED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
let managed =
*USE_MANAGED.get_or_init(|| std::env::var("MANAGED_MEMORY").as_deref() == Ok("1"));
if managed {
return Self::new_managed(ctx, len);
}
let driver = get_driver()?;
let size = len * mem::size_of::<T>();
let mut ptr: CUdeviceptr = 0;
let result = unsafe { (driver.cuMemAlloc)(&mut ptr, size) };
CudaDriver::check(result).map_err(|e| GpuError::MemoryAllocation(e.to_string()))?;
Ok(Self {
ptr,
len,
host_ptr: None,
ctx: ctx_handle,
_marker: PhantomData,
})
}
pub fn new_managed(ctx: &CudaContext, len: usize) -> Result<Self, GpuError> {
let ctx_handle = Some(ctx.raw());
if len == 0 {
return Ok(Self {
ptr: 0,
len: 0,
host_ptr: None,
ctx: ctx_handle,
_marker: PhantomData,
});
}
let driver = get_driver()?;
let size = len * mem::size_of::<T>();
let mut ptr: CUdeviceptr = 0;
const CU_MEM_ATTACH_GLOBAL: u32 = 1;
let result = unsafe { (driver.cuMemAllocManaged)(&mut ptr, size, CU_MEM_ATTACH_GLOBAL) };
CudaDriver::check(result).map_err(|e| {
GpuError::MemoryAllocation(format!("cuMemAllocManaged({} bytes): {}", size, e))
})?;
Ok(Self {
ptr,
len,
host_ptr: None,
ctx: ctx_handle,
_marker: PhantomData,
})
}
pub unsafe fn from_host_registered(host_ptr: *mut T, len: usize) -> Result<Self, GpuError> {
if len == 0 {
return Ok(Self {
ptr: 0,
len: 0,
host_ptr: None,
ctx: None,
_marker: PhantomData,
});
}
let driver = get_driver()?;
let size = len * mem::size_of::<T>();
const CU_MEMHOSTREGISTER_DEVICEMAP: u32 = 0x02;
let result = unsafe {
(driver.cuMemHostRegister)(host_ptr as *mut c_void, size, CU_MEMHOSTREGISTER_DEVICEMAP)
};
CudaDriver::check(result).map_err(|e| {
GpuError::MemoryAllocation(format!("cuMemHostRegister({} bytes): {}", size, e))
})?;
let mut dev_ptr: CUdeviceptr = 0;
let result =
unsafe { (driver.cuMemHostGetDevicePointer)(&mut dev_ptr, host_ptr as *mut c_void, 0) };
CudaDriver::check(result)
.map_err(|e| GpuError::MemoryAllocation(format!("cuMemHostGetDevicePointer: {}", e)))?;
Ok(Self {
ptr: dev_ptr,
len,
host_ptr: Some(host_ptr as *mut c_void),
ctx: None,
_marker: PhantomData,
})
}
pub fn zero_async(&mut self, stream: &crate::driver::CudaStream) -> Result<(), GpuError> {
if self.len == 0 {
return Ok(());
}
self.ensure_context()?;
let driver = get_driver()?;
let result = unsafe { (driver.cuMemsetD32Async)(self.ptr, 0, self.len, stream.raw()) };
if result != CUDA_SUCCESS {
return Err(GpuError::Transfer(format!(
"cuMemsetD32Async failed: {result}"
)));
}
Ok(())
}
#[must_use]
pub fn as_ptr(&self) -> CUdeviceptr {
self.ptr
}
#[must_use]
pub fn len(&self) -> usize {
self.len
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn set_context(&mut self, ctx: &CudaContext) {
self.ctx = Some(ctx.raw());
}
pub(crate) fn ensure_context(&self) -> Result<(), GpuError> {
if let Some(ctx_handle) = self.ctx {
let driver = get_driver()?;
let result = unsafe { (driver.cuCtxSetCurrent)(ctx_handle) };
if result != CUDA_SUCCESS {
return Err(GpuError::DeviceInit(format!(
"PMAT-420: cuCtxSetCurrent failed with code {} — \
context may have been destroyed",
result
)));
}
}
Ok(())
}
#[must_use]
pub fn size_bytes(&self) -> usize {
self.len * mem::size_of::<T>()
}
#[must_use]
pub fn clone_metadata(&self) -> GpuBufferView<T> {
GpuBufferView {
ptr: self.ptr,
len: self.len,
_marker: PhantomData,
}
}
}
/// Non-owning snapshot of a [`GpuBuffer`]'s device pointer and length.
///
/// It holds no lifetime linking it to the buffer it came from, so it must not
/// be used after that buffer is dropped.
pub struct GpuBufferView<T> {
    /// Number of `T` elements in the viewed buffer.
    len: usize,
    /// Device address of the first element.
    ptr: CUdeviceptr,
    /// Carries the element type without storing any `T`.
    _marker: PhantomData<T>,
}
impl<T> GpuBufferView<T> {
#[must_use]
pub fn as_ptr(&self) -> CUdeviceptr {
self.ptr
}
#[must_use]
pub fn len(&self) -> usize {
self.len
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[must_use]
pub fn size_bytes(&self) -> usize {
self.len * std::mem::size_of::<T>()
}
}
impl<T> Drop for GpuBuffer<T> {
    /// Releases what this buffer owns.
    ///
    /// Host-registered buffers are unregistered with `cuMemHostUnregister`
    /// (the host memory itself is not freed). All other non-null buffers are
    /// freed with `cuMemFree`. Errors are ignored because `drop` cannot report
    /// them.
    fn drop(&mut self) {
        if self.ptr == 0 {
            return;
        }
        let Ok(driver) = get_driver() else {
            return;
        };
        // The buffer is `Send`, so it may be dropped on a thread whose current
        // context is not the one it was allocated in. Driver calls there can fail
        // (e.g. CUDA_ERROR_INVALID_CONTEXT) and silently leak. Bind the recorded
        // context first. If that fails, the context has presumably been destroyed
        // (taking its allocations with it), so skip the release rather than pass
        // the driver a stale pointer.
        if self.ensure_context().is_err() {
            return;
        }
        // SAFETY: `ptr`/`host_ptr` were produced by the matching allocation or
        // registration call and are released exactly once, here.
        unsafe {
            match self.host_ptr {
                Some(host_ptr) => {
                    let _ = (driver.cuMemHostUnregister)(host_ptr);
                }
                None => {
                    let _ = (driver.cuMemFree)(self.ptr);
                }
            }
        }
    }
}
impl<T> GpuBuffer<T> {
#[must_use]
pub fn as_kernel_arg(&self) -> *mut c_void {
ptr::addr_of!(self.ptr) as *mut c_void
}
}