use core::ffi::c_void;
use core::mem::size_of;
use core::ops::{Deref, DerefMut};
use baracuda_cuda_sys::{driver, CUdeviceptr};
use baracuda_types::DeviceRepr;
use crate::context::Context;
use crate::error::{check, Result};
/// Allocation flag bits for `PinnedBuffer::with_flags`, mirroring the CUDA
/// driver's `CU_MEMHOSTALLOC_*` constants. Bits may be combined with `|`.
///
/// Caution when reusing these with `PinnedRegistration::register_with_flags`:
/// `cuMemHostRegister` takes `CU_MEMHOSTREGISTER_*` bits instead. `PORTABLE`
/// (0x01) and `DEVICEMAP` (0x02) have the same meaning there, but bit 0x04 means
/// `CU_MEMHOSTREGISTER_IOMEMORY`, not write-combining.
pub mod flags {
    /// `CU_MEMHOSTALLOC_PORTABLE`: the allocation is treated as pinned by all
    /// CUDA contexts, not only the one that allocated it.
    pub const PORTABLE: u32 = 0x01;
    /// `CU_MEMHOSTALLOC_DEVICEMAP`: maps the allocation into the device address
    /// space so a device pointer can be obtained for it.
    pub const DEVICEMAP: u32 = 0x02;
    /// `CU_MEMHOSTALLOC_WRITECOMBINED`: allocates write-combined memory, which
    /// is fast for host writes / device reads but very slow for host reads.
    pub const WRITECOMBINED: u32 = 0x04;
}
/// A host buffer of `len` elements of `T` allocated with the driver's
/// `cuMemHostAlloc` and released with `cuMemFreeHost` when dropped.
///
/// Dereferences to `[T]`. Zero-byte requests are not sent to the driver; such
/// buffers store a dangling (non-null, aligned) pointer instead.
pub struct PinnedBuffer<T: DeviceRepr> {
    // Start of the allocation, or `NonNull::dangling()` for a zero-byte request.
    ptr: *mut T,
    // Number of `T` elements in the buffer.
    len: usize,
    // Not read by every code path (hence the lint allowance); presumably held so
    // the owning context outlives the allocation — NOTE(review): confirm intent.
    #[allow(dead_code)]
    context: Context,
    // Marks logical ownership of `T` values for drop-check and variance.
    _marker: core::marker::PhantomData<T>,
}
// SAFETY: a `PinnedBuffer` exclusively owns its allocation (much like
// `Box<[T]>`), so handing it to another thread is sound when `T: Send`.
// NOTE(review): this also relies on the stored `Context` being safe to move
// across threads — confirm against `Context`'s own guarantees.
unsafe impl<T> Send for PinnedBuffer<T> where T: DeviceRepr + Send {}
// SAFETY: through `&PinnedBuffer` the elements are reachable only as `&[T]` or
// `*const T`, which is sound to share when `T: Sync`. The same `Context`
// caveat applies.
unsafe impl<T> Sync for PinnedBuffer<T> where T: DeviceRepr + Sync {}
impl<T: DeviceRepr> core::fmt::Debug for PinnedBuffer<T> {
    /// Prints the buffer's address, element count and element type name;
    /// the contents themselves are not printed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let element_type = core::any::type_name::<T>();
        let mut out = f.debug_struct("PinnedBuffer");
        out.field("ptr", &self.ptr);
        out.field("len", &self.len);
        out.field("type", &element_type);
        out.finish()
    }
}
impl<T: DeviceRepr> PinnedBuffer<T> {
    /// Allocates a pinned host buffer of `len` elements in `context` with no
    /// extra allocation flags.
    ///
    /// Equivalent to [`Self::with_flags`] with `flags == 0`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::with_flags`].
    ///
    /// # Panics
    ///
    /// Panics if `len * size_of::<T>()` overflows `usize`.
    pub fn new(context: &Context, len: usize) -> Result<Self> {
        Self::with_flags(context, len, 0)
    }

    /// Allocates a pinned host buffer of `len` elements with `cuMemHostAlloc`,
    /// passing `flags` (bits from the `flags` module) straight to the driver.
    ///
    /// `context` is made current on the calling thread before allocating, and a
    /// clone of it is stored in the buffer. The new memory is zero-filled, so
    /// `Deref`/`DerefMut` never expose uninitialized bytes. When the request is
    /// zero bytes (`len == 0` or a zero-sized `T`), no driver call is made and
    /// the buffer holds a dangling, well-aligned pointer.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the context cannot be made current;
    /// - the driver or its `cuMemHostAlloc` entry point cannot be loaded;
    /// - the allocation itself fails.
    ///
    /// # Panics
    ///
    /// Panics if `len * size_of::<T>()` overflows `usize`.
    pub fn with_flags(context: &Context, len: usize, flags: u32) -> Result<Self> {
        let bytes = len
            .checked_mul(size_of::<T>())
            .expect("overflow in PinnedBuffer byte count");
        if bytes == 0 {
            // Nothing to allocate: a dangling pointer is valid for zero-byte slices.
            return Ok(Self {
                ptr: core::ptr::NonNull::<T>::dangling().as_ptr(),
                len,
                context: context.clone(),
                _marker: core::marker::PhantomData,
            });
        }
        context.set_current()?;
        let d = driver()?;
        let cu = d.cu_mem_host_alloc()?;
        let mut p: *mut c_void = core::ptr::null_mut();
        // SAFETY: `p` is a valid out-pointer and `bytes` is non-zero.
        check(unsafe { cu(&mut p, bytes, flags) })?;
        // `cuMemHostAlloc` makes no promise about the initial contents, yet
        // `Deref` hands out `&[T]` from safe code. Reading uninitialized memory
        // through that slice would be undefined behaviour, so zero it once here.
        // NOTE(review): assumes the all-zero bit pattern is a valid `T` for every
        // `DeviceRepr` type — confirm against the `DeviceRepr` contract.
        // SAFETY: on success `p` is non-null and valid for `bytes` byte writes.
        unsafe { core::ptr::write_bytes(p.cast::<u8>(), 0, bytes) };
        Ok(Self {
            ptr: p.cast::<T>(),
            len,
            context: context.clone(),
            _marker: core::marker::PhantomData,
        })
    }

    /// Returns `true` when this buffer owns a real driver allocation, i.e.
    /// `with_flags` requested a non-zero number of bytes.
    ///
    /// `len != 0` alone is not enough: a non-empty buffer of zero-sized `T`
    /// holds only a dangling pointer.
    #[inline]
    fn is_allocated(&self) -> bool {
        self.len != 0 && size_of::<T>() != 0
    }

    /// Returns the device address mapped to this buffer, via
    /// `cuMemHostGetDevicePointer`.
    ///
    /// The buffer's context is made current first, because the driver resolves
    /// the mapping in the current context. A buffer with no backing allocation
    /// returns `CUdeviceptr(0)` without calling the driver.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the context cannot be made current;
    /// - the driver entry point cannot be loaded;
    /// - the driver rejects the query (e.g. the memory is not mapped; see
    ///   `flags::DEVICEMAP`).
    pub fn device_ptr(&self) -> Result<CUdeviceptr> {
        if !self.is_allocated() {
            return Ok(CUdeviceptr(0));
        }
        self.context.set_current()?;
        let d = driver()?;
        let cu = d.cu_mem_host_get_device_pointer()?;
        let mut dptr = CUdeviceptr(0);
        // SAFETY: `self.ptr` is a live `cuMemHostAlloc` allocation, `dptr` is a
        // valid out-pointer, and the driver requires the trailing flags to be 0.
        check(unsafe { cu(&mut dptr, self.ptr as *mut c_void, 0) })?;
        Ok(dptr)
    }

    /// Returns the flags the allocation was created with, via `cuMemHostGetFlags`.
    ///
    /// The buffer's context is made current first. A buffer with no backing
    /// allocation returns `0` without calling the driver.
    ///
    /// # Errors
    ///
    /// Returns an error if the context cannot be made current, the driver entry
    /// point cannot be loaded, or the driver call fails.
    pub fn flags(&self) -> Result<u32> {
        if !self.is_allocated() {
            return Ok(0);
        }
        self.context.set_current()?;
        let d = driver()?;
        let cu = d.cu_mem_host_get_flags()?;
        let mut flags: core::ffi::c_uint = 0;
        // SAFETY: `flags` is a valid out-pointer and `self.ptr` is a live allocation.
        check(unsafe { cu(&mut flags, self.ptr as *mut c_void) })?;
        Ok(flags)
    }

    /// Number of `T` elements in the buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the buffer holds no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Raw pointer to the first element (dangling for zero-byte buffers).
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.ptr
    }

    /// Mutable raw pointer to the first element (dangling for zero-byte buffers).
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr
    }
}
impl<T: DeviceRepr> Deref for PinnedBuffer<T> {
    type Target = [T];

    /// Views the buffer as a shared slice of `len` elements.
    fn deref(&self) -> &[T] {
        let (data, count) = (self.ptr as *const T, self.len);
        // SAFETY: `data` is either the start of an allocation holding `count`
        // elements or a dangling, aligned pointer for a zero-byte buffer. The
        // returned borrow is tied to `&self`.
        // NOTE(review): `from_raw_parts` also requires the elements to be
        // initialized; confirm the allocation path guarantees this.
        unsafe { core::slice::from_raw_parts(data, count) }
    }
}
impl<T: DeviceRepr> DerefMut for PinnedBuffer<T> {
    /// Views the buffer as an exclusive slice of `len` elements.
    fn deref_mut(&mut self) -> &mut [T] {
        let count = self.len;
        let data = self.ptr;
        // SAFETY: `data` is either the start of an allocation holding `count`
        // elements or a dangling, aligned pointer for a zero-byte buffer.
        // `&mut self` guarantees exclusivity for the returned borrow.
        // NOTE(review): elements must also be initialized; confirm the
        // allocation path guarantees this.
        unsafe { core::slice::from_raw_parts_mut(data, count) }
    }
}
impl<T: DeviceRepr> Drop for PinnedBuffer<T> {
fn drop(&mut self) {
if self.len == 0 || self.ptr.is_null() {
return;
}
if let Ok(d) = driver() {
if let Ok(cu) = d.cu_mem_free_host() {
let _ = unsafe { cu(self.ptr as *mut c_void) };
}
}
}
}
/// A host slice registered with the driver via `cuMemHostRegister` for the
/// lifetime `'a`, and unregistered with `cuMemHostUnregister` when dropped.
///
/// Construction takes `&'a mut [T]`, so the slice stays mutably borrowed for as
/// long as the registration exists and safe code cannot move, free or access it
/// in the meantime.
pub struct PinnedRegistration<'a, T: DeviceRepr> {
    // Start of the registered slice (Drop skips a null pointer).
    ptr: *mut T,
    // Number of `T` elements in the registered slice.
    len: usize,
    // Ties this value to the `&'a mut [T]` it was created from.
    _borrow: core::marker::PhantomData<&'a mut [T]>,
}
// SAFETY: a registration stands in for the `&'a mut [T]` it was built from
// (see its `PhantomData` field), and `&mut [T]` is `Send` when `T: Send`.
unsafe impl<T> Send for PinnedRegistration<'_, T> where T: DeviceRepr + Send {}
// SAFETY: `&PinnedRegistration` exposes no access to the elements at all (only
// the length and a device-pointer query), and `&&mut [T]` is `Sync` when
// `T: Sync`.
unsafe impl<T> Sync for PinnedRegistration<'_, T> where T: DeviceRepr + Sync {}
impl<'a, T: DeviceRepr> PinnedRegistration<'a, T> {
    /// Registers `slice` with the driver in place, with no extra flags.
    ///
    /// Equivalent to [`Self::register_with_flags`] with `flags == 0`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::register_with_flags`].
    pub fn register(slice: &'a mut [T]) -> Result<Self> {
        Self::register_with_flags(slice, 0)
    }

    /// Registers `slice` in place via `cuMemHostRegister`, passing `flags`
    /// straight to the driver.
    ///
    /// `flags` are interpreted as `CU_MEMHOSTREGISTER_*` bits, which only
    /// partially coincide with the allocation flags in the `flags` module.
    ///
    /// A slice that occupies zero bytes (empty, or of zero-sized `T`) has only a
    /// dangling pointer and nothing to pin. It is not sent to the driver and
    /// yields a registration with a null pointer, which is a no-op on drop and
    /// reports a null device pointer.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver or its `cuMemHostRegister` entry point
    /// cannot be loaded, or if the registration fails.
    pub fn register_with_flags(slice: &'a mut [T], flags: u32) -> Result<Self> {
        let len = slice.len();
        let bytes = core::mem::size_of_val(slice);
        if bytes == 0 {
            return Ok(Self {
                ptr: core::ptr::null_mut(),
                len,
                _borrow: core::marker::PhantomData,
            });
        }
        let ptr = slice.as_mut_ptr();
        let d = driver()?;
        let cu = d.cu_mem_host_register()?;
        // SAFETY: `ptr` addresses `bytes` bytes owned by `slice`, which stays
        // mutably borrowed (and therefore alive and unmoved) for `'a`.
        check(unsafe { cu(ptr as *mut c_void, bytes, flags) })?;
        Ok(Self {
            ptr,
            len,
            _borrow: core::marker::PhantomData,
        })
    }

    /// Returns the device address mapped to the registered slice, via
    /// `cuMemHostGetDevicePointer`.
    ///
    /// A zero-byte registration returns `CUdeviceptr(0)` without calling the
    /// driver.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver entry point cannot be loaded or the
    /// driver rejects the query.
    pub fn device_ptr(&self) -> Result<CUdeviceptr> {
        if self.ptr.is_null() {
            return Ok(CUdeviceptr(0));
        }
        let d = driver()?;
        let cu = d.cu_mem_host_get_device_pointer()?;
        let mut dptr = CUdeviceptr(0);
        // SAFETY: `self.ptr` is a registered host range and `dptr` is a valid
        // out-pointer; the driver requires the trailing flags to be 0.
        check(unsafe { cu(&mut dptr, self.ptr as *mut c_void, 0) })?;
        Ok(dptr)
    }

    /// Number of `T` elements in the registered slice.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the registered slice holds no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}
impl<T: DeviceRepr> core::fmt::Debug for PinnedRegistration<'_, T> {
    /// Prints the registered address and element count, not the contents.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("PinnedRegistration");
        out.field("ptr", &self.ptr).field("len", &self.len);
        out.finish()
    }
}
impl<T: DeviceRepr> Drop for PinnedRegistration<'_, T> {
    /// Unregisters the slice with `cuMemHostUnregister`, on a best-effort basis.
    ///
    /// A null pointer means nothing is registered. Driver errors are ignored
    /// because `drop` cannot report them.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        let d = match driver() {
            Ok(d) => d,
            Err(_) => return,
        };
        let unregister = match d.cu_mem_host_unregister() {
            Ok(f) => f,
            Err(_) => return,
        };
        // SAFETY: `self.ptr` is the start of the range registered in
        // `register_with_flags` and is unregistered exactly once, here.
        let _ = unsafe { unregister(self.ptr as *mut c_void) };
    }
}