use core::ffi::c_void;
use core::marker::PhantomData;
use core::mem::size_of;
use baracuda_cuda_sys::runtime::{cudaMemcpyKind, runtime};
use baracuda_types::DeviceRepr;
use crate::error::{check, Result};
use crate::stream::Stream;
/// Owning handle to a buffer of `len` elements of `T` obtained from the CUDA runtime allocator.
///
/// The `Drop` impl for this type releases the allocation when the handle goes out of scope.
pub struct DeviceBuffer<T: DeviceRepr> {
/// Raw address written by the allocator; the `Drop` impl skips the free when it is null.
ptr: *mut c_void,
/// Number of `T` elements (not bytes) the buffer was sized for.
len: usize,
/// Ties the untyped pointer to the element type `T` without storing a `T`.
_marker: PhantomData<T>,
}
// SAFETY (NOTE(review)): the handle holds only a raw pointer and a length. This impl presumably
// relies on the CUDA runtime permitting an allocation to be used and freed from any host
// thread — confirm against the runtime's threading guarantees.
unsafe impl<T: DeviceRepr + Send> Send for DeviceBuffer<T> {}
/// Debug output lists the raw pointer, the element count and the element type name.
impl<T: DeviceRepr> core::fmt::Debug for DeviceBuffer<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let element_type = core::any::type_name::<T>();
        let mut out = f.debug_struct("DeviceBuffer");
        out.field("ptr", &self.ptr);
        out.field("len", &self.len);
        out.field("type", &element_type);
        out.finish()
    }
}
/// Allocation, host/device copies and accessors for [`DeviceBuffer`].
impl<T: DeviceRepr> DeviceBuffer<T> {
    /// Allocates room for `len` elements of `T` through the runtime's `cuda_malloc` entry point.
    ///
    /// This function does not write to the new allocation; use [`Self::zeros`] for a
    /// zero-filled buffer.
    ///
    /// # Errors
    /// Returns an error if the runtime or the entry point cannot be resolved, or if the
    /// allocation call reports a failure.
    ///
    /// # Panics
    /// Panics if `len * size_of::<T>()` overflows `usize`.
    pub fn new(len: usize) -> Result<Self> {
        let rt = runtime()?;
        let malloc = rt.cuda_malloc()?;
        let byte_count = len
            .checked_mul(size_of::<T>())
            .expect("overflow computing allocation size");
        let mut raw = core::ptr::null_mut::<c_void>();
        // SAFETY: `raw` is a valid out-pointer for the duration of the call.
        let status = unsafe { malloc(&mut raw, byte_count) };
        check(status)?;
        Ok(Self {
            ptr: raw,
            len,
            _marker: PhantomData,
        })
    }

    /// Allocates `len` elements and sets every byte of the allocation to zero.
    ///
    /// # Errors
    /// Propagates allocation errors from [`Self::new`] and any failure of the memset call;
    /// in the latter case the fresh allocation is dropped before returning.
    pub fn zeros(len: usize) -> Result<Self> {
        let buffer = Self::new(len)?;
        let rt = runtime()?;
        let memset = rt.cuda_memset()?;
        // `buffer.len == len`, and `new` already proved the byte count does not overflow.
        let byte_count = buffer.byte_size();
        // SAFETY: `buffer.ptr` was produced by `new` for exactly `byte_count` bytes.
        let status = unsafe { memset(buffer.ptr, 0, byte_count) };
        check(status)?;
        Ok(buffer)
    }

    /// Allocates a buffer of `src.len()` elements and copies `src` into it.
    ///
    /// # Errors
    /// Propagates errors from [`Self::new`] and [`Self::copy_from_host`].
    pub fn from_slice(src: &[T]) -> Result<Self> {
        let buffer = Self::new(src.len())?;
        buffer.copy_from_host(src)?;
        Ok(buffer)
    }

    /// Copies `src` into the buffer with a host-to-device `cuda_memcpy`.
    ///
    /// # Panics
    /// Panics if `src.len()` differs from the buffer's element count.
    pub fn copy_from_host(&self, src: &[T]) -> Result<()> {
        assert_eq!(src.len(), self.len);
        let rt = runtime()?;
        let memcpy = rt.cuda_memcpy()?;
        let byte_count = self.byte_size();
        // SAFETY: both sides span `byte_count` bytes; the lengths were asserted equal above.
        let status = unsafe {
            memcpy(
                self.ptr,
                src.as_ptr() as *const c_void,
                byte_count,
                cudaMemcpyKind::HostToDevice,
            )
        };
        check(status)
    }

    /// Copies the buffer's contents into `dst` with a device-to-host `cuda_memcpy`.
    ///
    /// # Panics
    /// Panics if `dst.len()` differs from the buffer's element count.
    pub fn copy_to_host(&self, dst: &mut [T]) -> Result<()> {
        assert_eq!(dst.len(), self.len);
        let rt = runtime()?;
        let memcpy = rt.cuda_memcpy()?;
        let byte_count = self.byte_size();
        // SAFETY: both sides span `byte_count` bytes; the lengths were asserted equal above.
        let status = unsafe {
            memcpy(
                dst.as_mut_ptr() as *mut c_void,
                self.ptr,
                byte_count,
                cudaMemcpyKind::DeviceToHost,
            )
        };
        check(status)
    }

    /// Issues a host-to-device `cuda_memcpy_async` of `src` on `stream`.
    ///
    /// NOTE(review): the borrow of `src` ends when this function returns. Presumably the
    /// caller must keep `src` alive and unmodified until `stream` has been synchronised —
    /// confirm against the CUDA documentation for `cudaMemcpyAsync`.
    ///
    /// # Panics
    /// Panics if `src.len()` differs from the buffer's element count.
    pub fn copy_from_host_async(&self, src: &[T], stream: &Stream) -> Result<()> {
        assert_eq!(src.len(), self.len);
        let rt = runtime()?;
        let memcpy_async = rt.cuda_memcpy_async()?;
        let byte_count = self.byte_size();
        // SAFETY: both sides span `byte_count` bytes; the lengths were asserted equal above.
        let status = unsafe {
            memcpy_async(
                self.ptr,
                src.as_ptr() as *const c_void,
                byte_count,
                cudaMemcpyKind::HostToDevice,
                stream.as_raw(),
            )
        };
        check(status)
    }

    /// Issues a device-to-host `cuda_memcpy_async` into `dst` on `stream`.
    ///
    /// NOTE(review): the borrow of `dst` ends when this function returns. Presumably the
    /// caller must not read or drop `dst` until `stream` has been synchronised — confirm
    /// against the CUDA documentation for `cudaMemcpyAsync`.
    ///
    /// # Panics
    /// Panics if `dst.len()` differs from the buffer's element count.
    pub fn copy_to_host_async(&self, dst: &mut [T], stream: &Stream) -> Result<()> {
        assert_eq!(dst.len(), self.len);
        let rt = runtime()?;
        let memcpy_async = rt.cuda_memcpy_async()?;
        let byte_count = self.byte_size();
        // SAFETY: both sides span `byte_count` bytes; the lengths were asserted equal above.
        let status = unsafe {
            memcpy_async(
                dst.as_mut_ptr() as *mut c_void,
                self.ptr,
                byte_count,
                cudaMemcpyKind::DeviceToHost,
                stream.as_raw(),
            )
        };
        check(status)
    }

    /// Number of `T` elements in the buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Size of the buffer in bytes (`len * size_of::<T>()`).
    #[inline]
    pub fn byte_size(&self) -> usize {
        self.len * size_of::<T>()
    }

    /// `true` when the buffer holds zero elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// The raw pointer stored in the handle.
    #[inline]
    pub fn as_raw(&self) -> *mut c_void {
        self.ptr
    }

    /// The raw pointer's address as a `u64`.
    #[inline]
    pub fn as_device_ptr(&self) -> u64 {
        self.ptr as u64
    }
}
/// Frees the allocation with `cuda_free`. This is best-effort: lookup failures and the status
/// returned by the free call are ignored because `drop` cannot report them.
impl<T: DeviceRepr> Drop for DeviceBuffer<T> {
    fn drop(&mut self) {
        // Nothing to release for a null handle.
        if self.ptr.is_null() {
            return;
        }
        let rt = match runtime() {
            Ok(rt) => rt,
            Err(_) => return,
        };
        let free = match rt.cuda_free() {
            Ok(free) => free,
            Err(_) => return,
        };
        // SAFETY: `self.ptr` is non-null and owned by this handle; it is not used afterwards.
        let _ = unsafe { free(self.ptr) };
    }
}
/// Queries the runtime's `cuda_mem_get_info` entry point and returns `(free, total)`.
///
/// Both values are widened from `usize` to `u64`. Per the CUDA documentation for
/// `cudaMemGetInfo` they are byte counts for the current device.
///
/// # Errors
/// Returns an error if the runtime or entry point cannot be resolved, or if the query fails.
pub fn mem_get_info() -> Result<(u64, u64)> {
    let rt = runtime()?;
    let query = rt.cuda_mem_get_info()?;
    let (mut free_amount, mut total_amount) = (0usize, 0usize);
    // SAFETY: both out-pointers reference live locals for the duration of the call.
    let status = unsafe { query(&mut free_amount, &mut total_amount) };
    check(status)?;
    Ok((free_amount as u64, total_amount as u64))
}
/// Destination passed to [`mem_prefetch_async`] and [`mem_advise`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum PrefetchTarget {
    /// A device, identified by the ordinal carried in the variant.
    Device(i32),
    /// The host side.
    Host,
}

impl PrefetchTarget {
    /// Converts the target to the integer handed to the runtime: the carried ordinal for
    /// `Device`, `-1` for `Host`.
    #[inline]
    fn as_raw(self) -> i32 {
        // -1 matches the value CUDA names `cudaCpuDeviceId`.
        const HOST_ORDINAL: i32 = -1;
        match self {
            Self::Host => HOST_ORDINAL,
            Self::Device(ordinal) => ordinal,
        }
    }
}
/// Forwards a prefetch request for `count` bytes starting at `dev_ptr` to the runtime's
/// `cuda_mem_prefetch_async` entry point, targeting `target` on `stream`.
///
/// # Errors
/// Returns an error if the runtime or entry point cannot be resolved, or if the call fails.
///
/// # Safety
/// `dev_ptr` and `count` are passed to the runtime unchanged and unchecked. The caller must
/// ensure they describe a range the underlying CUDA call accepts (presumably memory from a
/// managed allocation — confirm against the `cudaMemPrefetchAsync` documentation).
pub unsafe fn mem_prefetch_async(
    dev_ptr: *const core::ffi::c_void,
    count: usize,
    target: PrefetchTarget,
    stream: &Stream,
) -> Result<()> {
    let rt = runtime()?;
    let prefetch = rt.cuda_mem_prefetch_async()?;
    let status = prefetch(dev_ptr, count, target.as_raw(), stream.as_raw());
    check(status)
}
/// Forwards a memory-advice hint for `count` bytes starting at `dev_ptr` to the runtime's
/// `cuda_mem_advise` entry point.
///
/// `advice` is passed through as a raw integer and `target` is converted to its ordinal.
///
/// # Errors
/// Returns an error if the runtime or entry point cannot be resolved, or if the call fails.
///
/// # Safety
/// `dev_ptr`, `count` and `advice` are passed to the runtime unchanged and unchecked. The
/// caller must ensure they form a request the underlying CUDA call accepts (see the
/// `cudaMemAdvise` documentation for valid ranges and advice values).
pub unsafe fn mem_advise(
    dev_ptr: *const core::ffi::c_void,
    count: usize,
    advice: i32,
    target: PrefetchTarget,
) -> Result<()> {
    let rt = runtime()?;
    let advise = rt.cuda_mem_advise()?;
    let status = advise(dev_ptr, count, advice, target.as_raw());
    check(status)
}
/// Owning handle to `len` elements of `T` allocated with the runtime's `cuda_malloc_managed`.
///
/// The memory is exposed to the host as a slice via `as_slice` / `as_mut_slice` and is
/// released with `cuda_free` by the `Drop` impl.
pub struct ManagedBuffer<T: DeviceRepr> {
/// Typed base pointer of the allocation; the `Drop` impl skips the free when it is null.
ptr: *mut T,
/// Number of `T` elements (not bytes).
len: usize,
/// Marks logical ownership of `T` values.
_marker: PhantomData<T>,
}
// SAFETY (NOTE(review)): the handle is a unique owner of its allocation, and shared access only
// hands out `&[T]`. Presumably this is sound for the same reasons `Vec<T>` is Send/Sync —
// confirm there is no device-side aliasing that the host cannot see.
unsafe impl<T: DeviceRepr + Send> Send for ManagedBuffer<T> {}
unsafe impl<T: DeviceRepr + Sync> Sync for ManagedBuffer<T> {}
/// Debug output lists the raw pointer, the element count and the element type name.
impl<T: DeviceRepr> core::fmt::Debug for ManagedBuffer<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let element_type = core::any::type_name::<T>();
        let mut out = f.debug_struct("ManagedBuffer");
        out.field("ptr", &self.ptr);
        out.field("len", &self.len);
        out.field("type", &element_type);
        out.finish()
    }
}
/// Construction and host-side access for [`ManagedBuffer`].
impl<T: DeviceRepr> ManagedBuffer<T> {
    /// Allocates `len` elements using the `cudaMemAttach::GLOBAL` attach flag.
    ///
    /// # Errors
    /// # Panics
    /// Same as [`Self::with_flags`].
    pub fn new(len: usize) -> Result<Self> {
        Self::with_flags(len, baracuda_cuda_sys::runtime::types::cudaMemAttach::GLOBAL)
    }

    /// Allocates `len` elements through `cuda_malloc_managed`, passing `flags` unchanged.
    ///
    /// # Errors
    /// Returns an error if the runtime or entry point cannot be resolved, or if the
    /// allocation call reports a failure.
    ///
    /// # Panics
    /// Panics if `len * size_of::<T>()` overflows `usize`.
    pub fn with_flags(len: usize, flags: u32) -> Result<Self> {
        let rt = runtime()?;
        let malloc_managed = rt.cuda_malloc_managed()?;
        let byte_count = len
            .checked_mul(size_of::<T>())
            .expect("overflow computing allocation size");
        let mut raw = core::ptr::null_mut::<c_void>();
        // SAFETY: `raw` is a valid out-pointer for the duration of the call.
        let status = unsafe { malloc_managed(&mut raw, byte_count, flags) };
        check(status)?;
        Ok(Self {
            ptr: raw as *mut T,
            len,
            _marker: PhantomData,
        })
    }

    /// Number of `T` elements in the buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the buffer holds zero elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Base pointer of the allocation, for read-only use.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.ptr
    }

    /// Base pointer of the allocation, for mutable use.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr
    }

    /// Views the allocation as a shared slice of `len` elements.
    ///
    /// NOTE(review): this assumes `ptr` is non-null and that the memory holds valid `T`
    /// values; the constructors in this file do not initialise the allocation — confirm
    /// that `DeviceRepr` types tolerate arbitrary bit patterns.
    pub fn as_slice(&self) -> &[T] {
        let (base, count) = (self.ptr as *const T, self.len);
        // SAFETY: `base` was allocated by `with_flags` for `count` elements and is borrowed
        // for the lifetime of `&self`.
        unsafe { core::slice::from_raw_parts(base, count) }
    }

    /// Views the allocation as an exclusive slice of `len` elements.
    ///
    /// NOTE(review): same assumptions as [`Self::as_slice`].
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        let (base, count) = (self.ptr, self.len);
        // SAFETY: `base` was allocated by `with_flags` for `count` elements and is borrowed
        // exclusively for the lifetime of `&mut self`.
        unsafe { core::slice::from_raw_parts_mut(base, count) }
    }
}
/// Frees the managed allocation with `cuda_free`. This is best-effort: lookup failures and the
/// status returned by the free call are ignored because `drop` cannot report them.
impl<T: DeviceRepr> Drop for ManagedBuffer<T> {
    fn drop(&mut self) {
        // Nothing to release for a null handle.
        if self.ptr.is_null() {
            return;
        }
        let rt = match runtime() {
            Ok(rt) => rt,
            Err(_) => return,
        };
        let free = match rt.cuda_free() {
            Ok(free) => free,
            Err(_) => return,
        };
        // SAFETY: `self.ptr` is non-null and owned by this handle; it is not used afterwards.
        let _ = unsafe { free(self.ptr as *mut c_void) };
    }
}
/// Re-exports every item of `cudaHostAllocFlags` under a shorter path.
///
/// NOTE(review): presumably these are the values intended for the `flags` argument of
/// `PinnedHostBuffer::with_flags` — confirm against the `baracuda_cuda_sys` definitions.
pub mod pinned_flags {
pub use baracuda_cuda_sys::runtime::types::cudaHostAllocFlags::*;
}
/// Owning handle to `len` elements of `T` in host memory allocated with `cuda_host_alloc`.
///
/// The buffer dereferences to `[T]` and is released with `cuda_free_host` by the `Drop` impl.
pub struct PinnedHostBuffer<T: DeviceRepr> {
/// Typed base pointer of the allocation; the `Drop` impl skips the free when it is null.
ptr: *mut T,
/// Number of `T` elements (not bytes).
len: usize,
/// Marks logical ownership of `T` values.
_marker: PhantomData<T>,
}
// SAFETY (NOTE(review)): the handle is a unique owner of a host allocation, and shared access
// only hands out `&[T]`. Presumably this is sound for the same reasons `Box<[T]>` is
// Send/Sync — confirm the runtime allows freeing from a different thread.
unsafe impl<T: DeviceRepr + Send> Send for PinnedHostBuffer<T> {}
unsafe impl<T: DeviceRepr + Sync> Sync for PinnedHostBuffer<T> {}
/// Debug output lists the raw pointer and the element count.
impl<T: DeviceRepr> core::fmt::Debug for PinnedHostBuffer<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("PinnedHostBuffer");
        out.field("ptr", &self.ptr);
        out.field("len", &self.len);
        out.finish()
    }
}
/// Construction, runtime queries and accessors for [`PinnedHostBuffer`].
impl<T: DeviceRepr> PinnedHostBuffer<T> {
    /// Allocates `len` elements with a `flags` value of `0`.
    ///
    /// # Errors
    /// # Panics
    /// Same as [`Self::with_flags`].
    pub fn new(len: usize) -> Result<Self> {
        Self::with_flags(len, 0)
    }

    /// Allocates `len` elements through `cuda_host_alloc`, passing `flags` unchanged.
    ///
    /// # Errors
    /// Returns an error if the runtime or entry point cannot be resolved, or if the
    /// allocation call reports a failure.
    ///
    /// # Panics
    /// Panics if `len * size_of::<T>()` overflows `usize`.
    pub fn with_flags(len: usize, flags: u32) -> Result<Self> {
        let rt = runtime()?;
        let host_alloc = rt.cuda_host_alloc()?;
        let byte_count = len
            .checked_mul(size_of::<T>())
            .expect("overflow computing allocation size");
        let mut raw = core::ptr::null_mut::<c_void>();
        // SAFETY: `raw` is a valid out-pointer for the duration of the call.
        let status = unsafe { host_alloc(&mut raw, byte_count, flags) };
        check(status)?;
        Ok(Self {
            ptr: raw as *mut T,
            len,
            _marker: PhantomData,
        })
    }

    /// Asks `cuda_host_get_device_pointer` (with a flags argument of `0`) for the pointer
    /// the runtime associates with this host allocation.
    ///
    /// # Errors
    /// Returns an error if the runtime or entry point cannot be resolved, or if the query fails.
    pub fn device_ptr(&self) -> Result<*mut c_void> {
        let rt = runtime()?;
        let get_device_pointer = rt.cuda_host_get_device_pointer()?;
        let mut device_side = core::ptr::null_mut::<c_void>();
        // SAFETY: `device_side` is a valid out-pointer; `self.ptr` is the owned allocation.
        let status = unsafe { get_device_pointer(&mut device_side, self.ptr as *mut c_void, 0) };
        check(status)?;
        Ok(device_side)
    }

    /// Reads back the flags value that `cuda_host_get_flags` reports for this allocation.
    ///
    /// # Errors
    /// Returns an error if the runtime or entry point cannot be resolved, or if the query fails.
    pub fn flags(&self) -> Result<u32> {
        let rt = runtime()?;
        let get_flags = rt.cuda_host_get_flags()?;
        let mut bits: core::ffi::c_uint = 0;
        // SAFETY: `bits` is a valid out-pointer; `self.ptr` is the owned allocation.
        let status = unsafe { get_flags(&mut bits, self.ptr as *mut c_void) };
        check(status)?;
        Ok(bits)
    }

    /// Number of `T` elements in the buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the buffer holds zero elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Base pointer of the allocation, for read-only use.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.ptr
    }

    /// Base pointer of the allocation, for mutable use.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr
    }
}
/// Shared slice view over the pinned allocation.
impl<T: DeviceRepr> core::ops::Deref for PinnedHostBuffer<T> {
    type Target = [T];

    /// Returns the buffer's `len` elements as `&[T]`, or an empty slice when the buffer has
    /// no elements or no backing pointer.
    fn deref(&self) -> &[T] {
        // `slice::from_raw_parts` requires a non-null pointer even for length 0, and a
        // zero-byte allocation request may leave `ptr` null, so never pass that through.
        if self.ptr.is_null() || self.len == 0 {
            return &[];
        }
        // SAFETY: `ptr` is non-null, was allocated for `len` elements of `T`, and stays
        // alive and unaliased by `&mut` for as long as `&self` is borrowed.
        unsafe { core::slice::from_raw_parts(self.ptr, self.len) }
    }
}
/// Exclusive slice view over the pinned allocation.
impl<T: DeviceRepr> core::ops::DerefMut for PinnedHostBuffer<T> {
    /// Returns the buffer's `len` elements as `&mut [T]`, or an empty slice when the buffer
    /// has no elements or no backing pointer.
    fn deref_mut(&mut self) -> &mut [T] {
        // `slice::from_raw_parts_mut` requires a non-null pointer even for length 0, and a
        // zero-byte allocation request may leave `ptr` null, so never pass that through.
        if self.ptr.is_null() || self.len == 0 {
            return &mut [];
        }
        // SAFETY: `ptr` is non-null, was allocated for `len` elements of `T`, and is
        // borrowed exclusively for as long as `&mut self` is held.
        unsafe { core::slice::from_raw_parts_mut(self.ptr, self.len) }
    }
}
/// Frees the pinned allocation with `cuda_free_host`. This is best-effort: lookup failures and
/// the status returned by the free call are ignored because `drop` cannot report them.
impl<T: DeviceRepr> Drop for PinnedHostBuffer<T> {
    fn drop(&mut self) {
        // Nothing to release for a null handle.
        if self.ptr.is_null() {
            return;
        }
        let rt = match runtime() {
            Ok(rt) => rt,
            Err(_) => return,
        };
        let free_host = match rt.cuda_free_host() {
            Ok(free_host) => free_host,
            Err(_) => return,
        };
        // SAFETY: `self.ptr` is non-null and owned by this handle; it is not used afterwards.
        let _ = unsafe { free_host(self.ptr as *mut c_void) };
    }
}
/// Guard for a caller-owned host slice that has been passed to `cuda_host_register`.
///
/// The guard holds the slice's exclusive borrow for `'a` and calls `cuda_host_unregister`
/// on the same base pointer when dropped.
pub struct PinnedRegistration<'a, T: DeviceRepr> {
/// Base pointer of the registered slice.
ptr: *mut T,
/// Number of `T` elements in the registered slice.
len: usize,
/// Keeps the `&'a mut [T]` borrow alive without storing the reference itself.
_borrow: PhantomData<&'a mut [T]>,
}
// SAFETY (NOTE(review)): the guard stands in for a `&mut [T]`, which is Send when `T: Send`.
// Presumably unregistering from another thread is permitted by the runtime — confirm.
unsafe impl<T: DeviceRepr + Send> Send for PinnedRegistration<'_, T> {}
/// Debug output lists the raw pointer and the element count.
impl<T: DeviceRepr> core::fmt::Debug for PinnedRegistration<'_, T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("PinnedRegistration");
        out.field("ptr", &self.ptr);
        out.field("len", &self.len);
        out.finish()
    }
}
/// Registration constructors and accessors for [`PinnedRegistration`].
impl<'a, T: DeviceRepr> PinnedRegistration<'a, T> {
    /// Registers `slice` with a `flags` value of `0`.
    ///
    /// # Errors
    /// Same as [`Self::register_with_flags`].
    pub fn register(slice: &'a mut [T]) -> Result<Self> {
        Self::register_with_flags(slice, 0)
    }

    /// Passes the slice's base pointer and byte size to `cuda_host_register` with `flags`.
    ///
    /// The returned guard keeps `slice` exclusively borrowed for `'a`.
    ///
    /// # Errors
    /// Returns an error if the runtime or entry point cannot be resolved, or if the
    /// registration call reports a failure.
    pub fn register_with_flags(slice: &'a mut [T], flags: u32) -> Result<Self> {
        let rt = runtime()?;
        let host_register = rt.cuda_host_register()?;
        let byte_count = core::mem::size_of_val(slice);
        let element_count = slice.len();
        let base = slice.as_mut_ptr();
        // SAFETY: `base` and `byte_count` describe exactly the memory of `slice`, which
        // stays exclusively borrowed for `'a`.
        let status = unsafe { host_register(base as *mut c_void, byte_count, flags) };
        check(status)?;
        Ok(Self {
            ptr: base,
            len: element_count,
            _borrow: PhantomData,
        })
    }

    /// Number of `T` elements in the registered slice.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the registered slice has zero elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}
/// Undoes the registration with `cuda_host_unregister`. This is best-effort: lookup failures
/// and the status returned by the call are ignored because `drop` cannot report them.
impl<T: DeviceRepr> Drop for PinnedRegistration<'_, T> {
    fn drop(&mut self) {
        // Nothing to unregister for a null pointer.
        if self.ptr.is_null() {
            return;
        }
        let rt = match runtime() {
            Ok(rt) => rt,
            Err(_) => return,
        };
        let host_unregister = match rt.cuda_host_unregister() {
            Ok(host_unregister) => host_unregister,
            Err(_) => return,
        };
        // SAFETY: `self.ptr` is the base pointer that was handed to `cuda_host_register`.
        let _ = unsafe { host_unregister(self.ptr as *mut c_void) };
    }
}
/// Stream-ordered allocation, release and fill for [`DeviceBuffer`].
impl<T: DeviceRepr> DeviceBuffer<T> {
    /// Allocates room for `len` elements through `cuda_malloc_async` on `stream`.
    ///
    /// # Errors
    /// Returns an error if the runtime or entry point cannot be resolved, or if the
    /// allocation call reports a failure.
    ///
    /// # Panics
    /// Panics if `len * size_of::<T>()` overflows `usize`.
    pub fn new_async(len: usize, stream: &Stream) -> Result<Self> {
        let r = runtime()?;
        let cu = r.cuda_malloc_async()?;
        let bytes = len
            .checked_mul(size_of::<T>())
            .expect("overflow computing allocation size");
        let mut ptr: *mut c_void = core::ptr::null_mut();
        // SAFETY: `ptr` is a valid out-pointer for the duration of the call.
        check(unsafe { cu(&mut ptr, bytes, stream.as_raw()) })?;
        Ok(Self {
            ptr,
            len,
            _marker: PhantomData,
        })
    }

    /// Consumes the buffer and releases it through `cuda_free_async` on `stream`.
    ///
    /// A buffer whose pointer is null is a no-op and returns `Ok(())`.
    ///
    /// # Errors
    /// Returns an error if the runtime or entry point cannot be resolved, or if the free
    /// call reports a failure. When the *lookup* fails the buffer still owns its pointer,
    /// so the ordinary `Drop` impl releases it with `cuda_free` instead of leaking it.
    pub fn free_async(mut self, stream: &Stream) -> Result<()> {
        if self.ptr.is_null() {
            return Ok(());
        }
        // Resolve the entry point BEFORE giving up ownership of the pointer: an early `?`
        // return here must leave `self.ptr` intact so that `Drop` can still free it.
        let r = runtime()?;
        let cu = r.cuda_free_async()?;
        // From here on the stream-ordered free is responsible for the allocation; nulling
        // the field makes the `Drop` that runs at the end of this function a no-op.
        let ptr = core::mem::replace(&mut self.ptr, core::ptr::null_mut());
        // SAFETY: `ptr` is non-null, owned by this handle, and no longer reachable from it.
        check(unsafe { cu(ptr, stream.as_raw()) })
    }

    /// Enqueues a `cuda_memset_async` on `stream` that sets every byte of the buffer to
    /// `value`.
    ///
    /// # Errors
    /// Returns an error if the runtime or entry point cannot be resolved, or if the call fails.
    pub fn memset_async(&self, value: u8, stream: &Stream) -> Result<()> {
        let r = runtime()?;
        let cu = r.cuda_memset_async()?;
        let bytes = self.len * size_of::<T>();
        // SAFETY: `self.ptr` was allocated for `bytes` bytes.
        check(unsafe { cu(self.ptr, core::ffi::c_int::from(value), bytes, stream.as_raw()) })
    }
}
/// Copies `src` (resident on `src_device`) into `dst` (resident on `dst_device`) through the
/// runtime's `cuda_memcpy_peer` entry point.
///
/// # Errors
/// Returns an error if the runtime or entry point cannot be resolved, or if the copy fails.
///
/// # Panics
/// Panics if the two buffers have different element counts.
pub fn memcpy_peer<T: DeviceRepr>(
    dst: &DeviceBuffer<T>,
    dst_device: &crate::Device,
    src: &DeviceBuffer<T>,
    src_device: &crate::Device,
) -> Result<()> {
    assert_eq!(dst.len(), src.len());
    let rt = runtime()?;
    let copy_peer = rt.cuda_memcpy_peer()?;
    let byte_count = src.len() * size_of::<T>();
    // SAFETY: both buffers hold `byte_count` bytes; the lengths were asserted equal above.
    let status = unsafe {
        copy_peer(
            dst.as_raw(),
            dst_device.ordinal(),
            src.as_raw(),
            src_device.ordinal(),
            byte_count,
        )
    };
    check(status)
}
/// Enqueues on `stream` a copy of `src` (resident on `src_device`) into `dst` (resident on
/// `dst_device`) through the runtime's `cuda_memcpy_peer_async` entry point.
///
/// # Errors
/// Returns an error if the runtime or entry point cannot be resolved, or if the call fails.
///
/// # Panics
/// Panics if the two buffers have different element counts.
pub fn memcpy_peer_async<T: DeviceRepr>(
    dst: &DeviceBuffer<T>,
    dst_device: &crate::Device,
    src: &DeviceBuffer<T>,
    src_device: &crate::Device,
    stream: &Stream,
) -> Result<()> {
    assert_eq!(dst.len(), src.len());
    let rt = runtime()?;
    let copy_peer_async = rt.cuda_memcpy_peer_async()?;
    let byte_count = src.len() * size_of::<T>();
    // SAFETY: both buffers hold `byte_count` bytes; the lengths were asserted equal above.
    let status = unsafe {
        copy_peer_async(
            dst.as_raw(),
            dst_device.ordinal(),
            src.as_raw(),
            src_device.ordinal(),
            byte_count,
            stream.as_raw(),
        )
    };
    check(status)
}