use core::ffi::c_void;
use core::marker::PhantomData;
use core::mem::size_of;
use core::ops::Range;
use baracuda_cuda_sys::{driver, CUdeviceptr};
use baracuda_types::{DeviceRepr, KernelArg};
use crate::context::Context;
use crate::error::{check, Result};
use crate::stream::Stream;
/// Owning handle to a typed device allocation holding `len` elements of `T`.
///
/// The allocation is released by the `Drop` impl, or explicitly via `free_async`.
/// A buffer whose byte size is zero stores a null (`0`) device pointer and owns
/// no allocation.
pub struct DeviceBuffer<T: DeviceRepr> {
/// Raw device address of the allocation; `0` means "nothing to free".
ptr: CUdeviceptr,
/// Number of `T` elements (not bytes).
len: usize,
/// Clone of the context handle the buffer was created with.
context: Context,
/// Ties the buffer to its element type without storing a `T`.
_marker: PhantomData<T>,
}
// SAFETY: NOTE(review): this asserts that moving the raw device address and the
// `Context` handle to another thread is sound whenever `T: Send`. Presumably the
// driver permits using/freeing the allocation from any thread — confirm.
unsafe impl<T: DeviceRepr + Send> Send for DeviceBuffer<T> {}
impl<T: DeviceRepr> core::fmt::Debug for DeviceBuffer<T> {
    /// Prints the device address (hex), the element count and the element type
    /// name; device memory itself is never read.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let elem_type = core::any::type_name::<T>();
        let mut out = f.debug_struct("DeviceBuffer");
        out.field("ptr", &format_args!("{:#x}", self.ptr.0));
        out.field("len", &self.len);
        out.field("type", &elem_type);
        out.finish()
    }
}
impl<T: DeviceRepr> DeviceBuffer<T> {
    /// Allocates a buffer of `len` elements in `context`. This function does not
    /// initialize the contents.
    ///
    /// A request whose byte size is zero performs no driver call and yields a
    /// buffer with a null device pointer.
    ///
    /// # Errors
    /// Propagates failures from making `context` current, obtaining the driver
    /// entry point, or the allocation call itself.
    ///
    /// # Panics
    /// Panics if `len * size_of::<T>()` overflows `usize`.
    pub fn new(context: &Context, len: usize) -> Result<Self> {
        let bytes = len
            .checked_mul(size_of::<T>())
            .expect("overflow computing allocation size");
        if bytes == 0 {
            // Nothing to allocate: a null pointer marks the empty buffer, and
            // both `Drop` and `free_async` skip the free for it.
            return Ok(Self {
                ptr: CUdeviceptr(0),
                len,
                context: context.clone(),
                _marker: PhantomData,
            });
        }
        context.set_current()?;
        let d = driver()?;
        let cu = d.cu_mem_alloc()?;
        let mut ptr = CUdeviceptr(0);
        // SAFETY: `ptr` is a live local for the driver to write the address into;
        // `bytes` is non-zero and was computed without overflow.
        check(unsafe { cu(&mut ptr, bytes) })?;
        Ok(Self {
            ptr,
            len,
            context: context.clone(),
            _marker: PhantomData,
        })
    }

    /// Like [`DeviceBuffer::new`], but issues the allocation through the
    /// driver's asynchronous allocation entry point on `stream`.
    ///
    /// NOTE(review): presumably the memory is only usable in stream order after
    /// this call — confirm against the driver documentation.
    ///
    /// # Errors
    /// Same as [`DeviceBuffer::new`].
    ///
    /// # Panics
    /// Panics if `len * size_of::<T>()` overflows `usize`.
    pub fn new_async(context: &Context, len: usize, stream: &Stream) -> Result<Self> {
        let bytes = len
            .checked_mul(size_of::<T>())
            .expect("overflow computing allocation size");
        if bytes == 0 {
            return Ok(Self {
                ptr: CUdeviceptr(0),
                len,
                context: context.clone(),
                _marker: PhantomData,
            });
        }
        context.set_current()?;
        let d = driver()?;
        let cu = d.cu_mem_alloc_async()?;
        let mut ptr = CUdeviceptr(0);
        // SAFETY: `ptr` is a live local out-parameter; `bytes` is non-zero; the
        // stream handle comes from a live `Stream` borrow.
        check(unsafe { cu(&mut ptr, bytes, stream.as_raw()) })?;
        Ok(Self {
            ptr,
            len,
            context: context.clone(),
            _marker: PhantomData,
        })
    }

    /// Consumes the buffer and enqueues its release on `stream`.
    ///
    /// An empty (null-pointer) buffer is a no-op.
    ///
    /// # Errors
    /// Returns an error if the driver or its async-free entry point cannot be
    /// obtained, or if the free call fails. If the entry point cannot be
    /// obtained, the buffer is still released through `Drop`.
    pub fn free_async(mut self, stream: &Stream) -> Result<()> {
        if self.ptr.0 == 0 {
            return Ok(());
        }
        // Resolve the entry point BEFORE disarming `Drop`: if either lookup fails
        // we return early with `self.ptr` intact, so `Drop` still frees the
        // allocation instead of leaking it.
        let d = driver()?;
        let cu = d.cu_mem_free_async()?;
        // From here on the async free is attempted exactly once; null the pointer
        // so `Drop` does not free it a second time.
        let ptr = core::mem::replace(&mut self.ptr, CUdeviceptr(0));
        // SAFETY: `ptr` is the non-null address this buffer owned, and ownership
        // ends here because `self` is consumed.
        check(unsafe { cu(ptr, stream.as_raw()) })
    }

    /// Allocates a buffer via [`DeviceBuffer::new`] and fills it with zero bytes.
    ///
    /// # Errors
    /// Propagates allocation and memset failures.
    ///
    /// # Panics
    /// Panics if `len * size_of::<T>()` overflows `usize` (inside `new`).
    pub fn zeros(context: &Context, len: usize) -> Result<Self> {
        let buf = Self::new(context, len)?;
        // Cannot overflow: `new` already validated this product.
        let bytes = len * size_of::<T>();
        if bytes == 0 {
            return Ok(buf);
        }
        let d = driver()?;
        let cu = d.cu_memset_d8()?;
        // SAFETY: `buf.ptr` was just allocated with exactly `bytes` bytes.
        check(unsafe { cu(buf.ptr, 0, bytes) })?;
        Ok(buf)
    }

    /// Overwrites the whole buffer with zero bytes; a no-op for empty buffers.
    pub fn zero(&self) -> Result<()> {
        let bytes = self.len * size_of::<T>();
        if bytes == 0 {
            return Ok(());
        }
        let d = driver()?;
        let cu = d.cu_memset_d8()?;
        // SAFETY: `self.ptr` addresses this buffer's `bytes`-byte allocation.
        check(unsafe { cu(self.ptr, 0, bytes) })
    }

    /// Enqueues a zero-fill of the whole buffer on `stream`; a no-op for empty
    /// buffers.
    pub fn zero_async(&self, stream: &Stream) -> Result<()> {
        let bytes = self.len * size_of::<T>();
        if bytes == 0 {
            return Ok(());
        }
        let d = driver()?;
        let cu = d.cu_memset_d8_async()?;
        // SAFETY: `self.ptr` addresses this buffer's `bytes`-byte allocation.
        check(unsafe { cu(self.ptr, 0, bytes, stream.as_raw()) })
    }

    /// Allocates a buffer of `src.len()` elements and uploads `src` into it.
    ///
    /// # Errors
    /// Propagates allocation and copy failures.
    pub fn from_slice(context: &Context, src: &[T]) -> Result<Self> {
        let buf = Self::new(context, src.len())?;
        buf.copy_from_host(src)?;
        Ok(buf)
    }

    /// Copies `src` from host memory into the buffer.
    ///
    /// # Panics
    /// Panics if `src.len()` differs from the buffer length.
    pub fn copy_from_host(&self, src: &[T]) -> Result<()> {
        assert_eq!(
            src.len(),
            self.len,
            "copy_from_host: source length {} != buffer length {}",
            src.len(),
            self.len
        );
        let bytes = self.len * size_of::<T>();
        if bytes == 0 {
            return Ok(());
        }
        let d = driver()?;
        let cu = d.cu_memcpy_htod()?;
        // SAFETY: the length assertion guarantees `src` spans exactly `bytes`
        // bytes, matching this buffer's allocation size.
        check(unsafe { cu(self.ptr, src.as_ptr() as *const c_void, bytes) })
    }

    /// Copies the buffer's contents into the host slice `dst`.
    ///
    /// # Panics
    /// Panics if `dst.len()` differs from the buffer length.
    pub fn copy_to_host(&self, dst: &mut [T]) -> Result<()> {
        assert_eq!(
            dst.len(),
            self.len,
            "copy_to_host: destination length {} != buffer length {}",
            dst.len(),
            self.len
        );
        let bytes = self.len * size_of::<T>();
        if bytes == 0 {
            return Ok(());
        }
        let d = driver()?;
        let cu = d.cu_memcpy_dtoh()?;
        // SAFETY: the length assertion guarantees `dst` has room for `bytes` bytes.
        check(unsafe { cu(dst.as_mut_ptr() as *mut c_void, self.ptr, bytes) })
    }

    /// Enqueues a host-to-device copy of `src` on `stream`.
    ///
    /// NOTE(review): the signature does not tie `src`'s lifetime to the
    /// completion of the copy; presumably callers must keep `src` alive and
    /// unmodified until the stream has processed it — confirm.
    ///
    /// # Panics
    /// Panics if `src.len()` differs from the buffer length.
    pub fn copy_from_host_async(&self, src: &[T], stream: &Stream) -> Result<()> {
        assert_eq!(src.len(), self.len);
        let bytes = self.len * size_of::<T>();
        if bytes == 0 {
            return Ok(());
        }
        let d = driver()?;
        let cu = d.cu_memcpy_htod_async()?;
        // SAFETY: `src` spans exactly `bytes` bytes (asserted above).
        check(unsafe {
            cu(
                self.ptr,
                src.as_ptr() as *const c_void,
                bytes,
                stream.as_raw(),
            )
        })
    }

    /// Enqueues a device-to-host copy into `dst` on `stream`.
    ///
    /// NOTE(review): `dst` is presumably not safe to read until the stream has
    /// been synchronized — confirm.
    ///
    /// # Panics
    /// Panics if `dst.len()` differs from the buffer length.
    pub fn copy_to_host_async(&self, dst: &mut [T], stream: &Stream) -> Result<()> {
        assert_eq!(dst.len(), self.len);
        let bytes = self.len * size_of::<T>();
        if bytes == 0 {
            return Ok(());
        }
        let d = driver()?;
        let cu = d.cu_memcpy_dtoh_async()?;
        // SAFETY: `dst` has room for exactly `bytes` bytes (asserted above).
        check(unsafe {
            cu(
                dst.as_mut_ptr() as *mut c_void,
                self.ptr,
                bytes,
                stream.as_raw(),
            )
        })
    }

    /// Copies this buffer's contents into `dst` (device to device).
    ///
    /// # Panics
    /// Panics if the two buffers differ in length.
    pub fn copy_to_device(&self, dst: &DeviceBuffer<T>) -> Result<()> {
        assert_eq!(dst.len, self.len);
        let bytes = self.len * size_of::<T>();
        if bytes == 0 {
            return Ok(());
        }
        let d = driver()?;
        let cu = d.cu_memcpy_dtod()?;
        // SAFETY: both buffers hold `bytes` bytes (lengths asserted equal).
        check(unsafe { cu(dst.ptr, self.ptr, bytes) })
    }

    /// Enqueues a device-to-device copy into `dst` on `stream`.
    ///
    /// # Panics
    /// Panics if the two buffers differ in length.
    pub fn copy_to_device_async(&self, dst: &DeviceBuffer<T>, stream: &Stream) -> Result<()> {
        assert_eq!(dst.len, self.len);
        let bytes = self.len * size_of::<T>();
        if bytes == 0 {
            return Ok(());
        }
        let d = driver()?;
        let cu = d.cu_memcpy_dtod_async()?;
        // SAFETY: both buffers hold `bytes` bytes (lengths asserted equal).
        check(unsafe { cu(dst.ptr, self.ptr, bytes, stream.as_raw()) })
    }

    /// Number of `T` elements in the buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Size of the buffer in bytes (`len * size_of::<T>()`).
    #[inline]
    pub fn byte_size(&self) -> usize {
        self.len * size_of::<T>()
    }

    /// `true` when the buffer holds zero elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// The context handle this buffer was created with.
    #[inline]
    pub fn context(&self) -> &Context {
        &self.context
    }

    /// Raw device address (null for zero-byte buffers).
    #[inline]
    pub fn as_raw(&self) -> CUdeviceptr {
        self.ptr
    }

    /// Shared view over the whole buffer, borrowed from `self`.
    #[inline]
    pub fn as_slice(&self) -> DeviceSlice<'_, T> {
        DeviceSlice {
            ptr: self.ptr,
            len: self.len,
            _marker: PhantomData,
        }
    }

    /// Exclusive view over the whole buffer, borrowed mutably from `self`.
    #[inline]
    pub fn as_slice_mut(&mut self) -> DeviceSliceMut<'_, T> {
        DeviceSliceMut {
            ptr: self.ptr,
            len: self.len,
            _marker: PhantomData,
        }
    }

    /// Shared view over the elements in `range`.
    ///
    /// # Panics
    /// Panics if `range` is inverted or ends past the buffer length.
    #[inline]
    pub fn slice(&self, range: Range<usize>) -> DeviceSlice<'_, T> {
        assert!(
            range.start <= range.end && range.end <= self.len,
            "DeviceBuffer::slice({}..{}) out of bounds for len {}",
            range.start,
            range.end,
            self.len,
        );
        DeviceSlice {
            // Element index -> byte offset from the base address.
            ptr: CUdeviceptr(self.ptr.0 + (range.start * size_of::<T>()) as u64),
            len: range.end - range.start,
            _marker: PhantomData,
        }
    }

    /// Exclusive view over the elements in `range`.
    ///
    /// # Panics
    /// Panics if `range` is inverted or ends past the buffer length.
    #[inline]
    pub fn slice_mut(&mut self, range: Range<usize>) -> DeviceSliceMut<'_, T> {
        assert!(
            range.start <= range.end && range.end <= self.len,
            "DeviceBuffer::slice_mut({}..{}) out of bounds for len {}",
            range.start,
            range.end,
            self.len,
        );
        DeviceSliceMut {
            // Element index -> byte offset from the base address.
            ptr: CUdeviceptr(self.ptr.0 + (range.start * size_of::<T>()) as u64),
            len: range.end - range.start,
            _marker: PhantomData,
        }
    }
}
impl DeviceBuffer<u8> {
    /// Reinterprets the byte buffer as a shared slice of `U` elements.
    ///
    /// A zero-sized `U` yields an empty view at the buffer's base address.
    ///
    /// # Panics
    /// Panics if the byte length is not a multiple of `size_of::<U>()`.
    #[inline]
    pub fn view_as<U: DeviceRepr>(&self) -> DeviceSlice<'_, U> {
        let len = match size_of::<U>() {
            // Dividing by zero is meaningless; expose an empty view instead.
            0 => 0,
            elem => {
                assert!(
                    self.len % elem == 0,
                    "DeviceBuffer<u8>::view_as: byte length {} not divisible by size_of::<{}>() = {}",
                    self.len,
                    core::any::type_name::<U>(),
                    elem,
                );
                self.len / elem
            }
        };
        DeviceSlice {
            ptr: self.ptr,
            len,
            _marker: PhantomData,
        }
    }

    /// Reinterprets the byte buffer as an exclusive slice of `U` elements.
    ///
    /// A zero-sized `U` yields an empty view at the buffer's base address.
    ///
    /// # Panics
    /// Panics if the byte length is not a multiple of `size_of::<U>()`.
    #[inline]
    pub fn view_as_mut<U: DeviceRepr>(&mut self) -> DeviceSliceMut<'_, U> {
        let len = match size_of::<U>() {
            0 => 0,
            elem => {
                assert!(
                    self.len % elem == 0,
                    "DeviceBuffer<u8>::view_as_mut: byte length {} not divisible by size_of::<{}>() = {}",
                    self.len,
                    core::any::type_name::<U>(),
                    elem,
                );
                self.len / elem
            }
        };
        DeviceSliceMut {
            ptr: self.ptr,
            len,
            _marker: PhantomData,
        }
    }
}
/// Attachment flag passed when allocating a `ManagedBuffer`.
///
/// Each variant is translated by `raw` to the like-named `CUmemAttach_flags`
/// constant. `Global` is the `Default`.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub enum ManagedAttach {
/// Maps to `CUmemAttach_flags::GLOBAL`.
#[default]
Global,
/// Maps to `CUmemAttach_flags::HOST`.
Host,
/// Maps to `CUmemAttach_flags::SINGLE`.
Single,
}
impl ManagedAttach {
    /// Converts the variant into the raw `CUmemAttach_flags` value handed to the
    /// driver.
    #[inline]
    fn raw(self) -> u32 {
        use baracuda_cuda_sys::types::CUmemAttach_flags as Flags;
        match self {
            Self::Single => Flags::SINGLE,
            Self::Host => Flags::HOST,
            Self::Global => Flags::GLOBAL,
        }
    }
}
/// Memory-usage hint accepted by `ManagedBuffer::advise`.
///
/// Each variant is translated by `raw` to the like-named `CUmem_advise`
/// constant; the `Set*`/`Unset*` pairs map to the corresponding
/// `SET_*`/`UNSET_*` driver values.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum MemAdvise {
/// Maps to `CUmem_advise::SET_READ_MOSTLY`.
SetReadMostly,
/// Maps to `CUmem_advise::UNSET_READ_MOSTLY`.
UnsetReadMostly,
/// Maps to `CUmem_advise::SET_PREFERRED_LOCATION`.
SetPreferredLocation,
/// Maps to `CUmem_advise::UNSET_PREFERRED_LOCATION`.
UnsetPreferredLocation,
/// Maps to `CUmem_advise::SET_ACCESSED_BY`.
SetAccessedBy,
/// Maps to `CUmem_advise::UNSET_ACCESSED_BY`.
UnsetAccessedBy,
}
impl MemAdvise {
    /// Converts the variant into the raw `CUmem_advise` value handed to the
    /// driver.
    #[inline]
    fn raw(self) -> i32 {
        use baracuda_cuda_sys::types::CUmem_advise as Advice;
        match self {
            Self::SetReadMostly => Advice::SET_READ_MOSTLY,
            Self::SetPreferredLocation => Advice::SET_PREFERRED_LOCATION,
            Self::SetAccessedBy => Advice::SET_ACCESSED_BY,
            Self::UnsetReadMostly => Advice::UNSET_READ_MOSTLY,
            Self::UnsetPreferredLocation => Advice::UNSET_PREFERRED_LOCATION,
            Self::UnsetAccessedBy => Advice::UNSET_ACCESSED_BY,
        }
    }
}
/// Owning handle to an allocation of `len` elements of `T` obtained from the
/// driver's managed-memory allocator.
///
/// The same address is used both as a device pointer (`as_raw`) and, through the
/// unsafe `as_host_slice*` accessors, as a host pointer. The allocation is
/// released by the `Drop` impl.
pub struct ManagedBuffer<T: DeviceRepr> {
/// Raw address of the allocation; `Drop` skips the free when this is `0`.
ptr: CUdeviceptr,
/// Number of `T` elements (not bytes).
len: usize,
/// Clone of the context handle the buffer was created with.
context: Context,
/// Ties the buffer to its element type without storing a `T`.
_marker: PhantomData<T>,
}
// SAFETY: NOTE(review): this asserts that moving the raw managed address and the
// `Context` handle to another thread is sound whenever `T: Send`. Presumably the
// driver permits using/freeing the allocation from any thread — confirm.
unsafe impl<T: DeviceRepr + Send> Send for ManagedBuffer<T> {}
impl<T: DeviceRepr> core::fmt::Debug for ManagedBuffer<T> {
    /// Prints the address (hex), the element count and the element type name;
    /// the managed memory itself is never read.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let elem_type = core::any::type_name::<T>();
        let mut out = f.debug_struct("ManagedBuffer");
        out.field("ptr", &format_args!("{:#x}", self.ptr.0));
        out.field("len", &self.len);
        out.field("type", &elem_type);
        out.finish()
    }
}
impl<T: DeviceRepr> ManagedBuffer<T> {
    /// Allocates `len` elements of managed memory with [`ManagedAttach::Global`].
    ///
    /// See [`ManagedBuffer::new_with_flags`] for errors and panics.
    pub fn new(context: &Context, len: usize) -> Result<Self> {
        Self::new_with_flags(context, len, ManagedAttach::Global)
    }

    /// Allocates `len` elements of managed memory in `context` using the given
    /// attachment flag.
    ///
    /// A request whose byte size is zero performs no driver call and yields a
    /// buffer with a null pointer, mirroring `DeviceBuffer::new`. The CUDA
    /// documentation states that the managed allocator rejects a zero size, so
    /// without this guard an empty buffer could never be created.
    ///
    /// # Errors
    /// Propagates failures from making `context` current, obtaining the driver
    /// entry point, or the allocation call itself.
    ///
    /// # Panics
    /// Panics if `len * size_of::<T>()` overflows `usize`.
    pub fn new_with_flags(context: &Context, len: usize, attach: ManagedAttach) -> Result<Self> {
        let bytes = len
            .checked_mul(size_of::<T>())
            .expect("overflow computing allocation size");
        if bytes == 0 {
            // A null pointer marks "no allocation"; `Drop` skips the free for it.
            return Ok(Self {
                ptr: CUdeviceptr(0),
                len,
                context: context.clone(),
                _marker: PhantomData,
            });
        }
        context.set_current()?;
        let d = driver()?;
        let cu = d.cu_mem_alloc_managed()?;
        let mut ptr = CUdeviceptr(0);
        // SAFETY: `ptr` is a live local out-parameter; `bytes` is non-zero and
        // was computed without overflow.
        check(unsafe { cu(&mut ptr, bytes, attach.raw()) })?;
        Ok(Self {
            ptr,
            len,
            context: context.clone(),
            _marker: PhantomData,
        })
    }

    /// Applies the usage hint `advice` for the whole buffer with respect to
    /// `device`. A zero-byte buffer is a no-op.
    pub fn advise(&self, advice: MemAdvise, device: &crate::Device) -> Result<()> {
        let bytes = self.len * size_of::<T>();
        if bytes == 0 {
            // No range to advise on (and the pointer may be null).
            return Ok(());
        }
        let d = driver()?;
        let cu = d.cu_mem_advise()?;
        // SAFETY: `self.ptr` addresses this buffer's `bytes`-byte allocation.
        check(unsafe { cu(self.ptr, bytes, advice.raw(), device.as_raw()) })
    }

    /// Enqueues a prefetch of the whole buffer to `device` on `stream`. A
    /// zero-byte buffer is a no-op.
    pub fn prefetch_async(&self, device: &crate::Device, stream: &Stream) -> Result<()> {
        let bytes = self.len * size_of::<T>();
        if bytes == 0 {
            // No range to prefetch (and the pointer may be null).
            return Ok(());
        }
        let d = driver()?;
        let cu = d.cu_mem_prefetch_async()?;
        // SAFETY: `self.ptr` addresses this buffer's `bytes`-byte allocation.
        check(unsafe { cu(self.ptr, bytes, device.as_raw(), stream.as_raw()) })
    }

    /// Views the managed allocation as a host slice of `len` elements.
    ///
    /// Returns an empty slice when the buffer owns no allocation (null pointer).
    ///
    /// # Safety
    /// The allocation's address is dereferenced as host memory. NOTE(review):
    /// presumably the caller must ensure no device work is concurrently writing
    /// the memory and that the contents are valid `T` values — confirm.
    pub unsafe fn as_host_slice(&self) -> &[T] {
        if self.ptr.0 == 0 {
            // `slice::from_raw_parts` requires a non-null pointer even for
            // length zero, so never hand it the null sentinel.
            return &[];
        }
        // SAFETY: the pointer is non-null and was allocated for `self.len`
        // elements; remaining obligations are delegated to the caller.
        unsafe { core::slice::from_raw_parts(self.ptr.0 as *const T, self.len) }
    }

    /// Views the managed allocation as a mutable host slice of `len` elements.
    ///
    /// Returns an empty slice when the buffer owns no allocation (null pointer).
    ///
    /// # Safety
    /// The allocation's address is dereferenced as host memory. NOTE(review):
    /// presumably the caller must ensure no device work is concurrently
    /// accessing the memory and that the contents are valid `T` values — confirm.
    pub unsafe fn as_host_slice_mut(&mut self) -> &mut [T] {
        if self.ptr.0 == 0 {
            // See `as_host_slice`: null must not reach `from_raw_parts_mut`.
            return &mut [];
        }
        // SAFETY: the pointer is non-null and was allocated for `self.len`
        // elements; `&mut self` rules out other views created through this
        // handle. Remaining obligations are delegated to the caller.
        unsafe { core::slice::from_raw_parts_mut(self.ptr.0 as *mut T, self.len) }
    }

    /// Number of `T` elements in the buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the buffer holds zero elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Raw address of the allocation (null for zero-byte buffers).
    #[inline]
    pub fn as_raw(&self) -> CUdeviceptr {
        self.ptr
    }

    /// The context handle this buffer was created with.
    #[inline]
    pub fn context(&self) -> &Context {
        &self.context
    }
}
impl<T: DeviceRepr> Drop for ManagedBuffer<T> {
    /// Best-effort release of the allocation. Every failure is ignored: the
    /// driver or its free entry point being unavailable, or the free call
    /// returning an error.
    fn drop(&mut self) {
        // A null pointer means there is nothing to free.
        if self.ptr.0 == 0 {
            return;
        }
        let Ok(drv) = driver() else { return };
        let Ok(free) = drv.cu_mem_free() else { return };
        // SAFETY: `self.ptr` is the non-null address owned by this buffer, and
        // it is not used again after drop.
        let _ = unsafe { free(self.ptr) };
    }
}
/// Queries the driver's memory-info entry point and returns `(free, total)`,
/// each widened to `u64`.
///
/// NOTE(review): the values are presumably byte counts for the current
/// context's device — confirm against the driver documentation.
///
/// # Errors
/// Propagates driver-lookup failures and the status of the query itself.
pub fn mem_get_info() -> Result<(u64, u64)> {
    let drv = driver()?;
    let query = drv.cu_mem_get_info()?;
    let (mut free_bytes, mut total_bytes) = (0usize, 0usize);
    // SAFETY: both out-parameters are live locals.
    let status = unsafe { query(&mut free_bytes, &mut total_bytes) };
    check(status)?;
    Ok((free_bytes as u64, total_bytes as u64))
}
/// Copies the contents of `src` (belonging to `src_ctx`) into `dst` (belonging
/// to `dst_ctx`) through the driver's peer-copy entry point.
///
/// # Errors
/// Propagates driver-lookup failures and the status of the copy.
///
/// # Panics
/// Panics if the two buffers differ in length.
pub fn memcpy_peer<T: DeviceRepr>(
    dst: &DeviceBuffer<T>,
    dst_ctx: &Context,
    src: &DeviceBuffer<T>,
    src_ctx: &Context,
) -> Result<()> {
    assert_eq!(dst.len(), src.len());
    let drv = driver()?;
    let copy = drv.cu_memcpy_peer()?;
    let byte_count = src.len() * size_of::<T>();
    // SAFETY: both buffers hold `byte_count` bytes (lengths asserted equal).
    let status = unsafe {
        copy(
            dst.as_raw(),
            dst_ctx.as_raw(),
            src.as_raw(),
            src_ctx.as_raw(),
            byte_count,
        )
    };
    check(status)
}
/// Enqueues on `stream` a copy of `src` (belonging to `src_ctx`) into `dst`
/// (belonging to `dst_ctx`) through the driver's async peer-copy entry point.
///
/// # Errors
/// Propagates driver-lookup failures and the status of the enqueue call.
///
/// # Panics
/// Panics if the two buffers differ in length.
pub fn memcpy_peer_async<T: DeviceRepr>(
    dst: &DeviceBuffer<T>,
    dst_ctx: &Context,
    src: &DeviceBuffer<T>,
    src_ctx: &Context,
    stream: &Stream,
) -> Result<()> {
    assert_eq!(dst.len(), src.len());
    let drv = driver()?;
    let copy = drv.cu_memcpy_peer_async()?;
    let byte_count = src.len() * size_of::<T>();
    // SAFETY: both buffers hold `byte_count` bytes (lengths asserted equal).
    let status = unsafe {
        copy(
            dst.as_raw(),
            dst_ctx.as_raw(),
            src.as_raw(),
            src_ctx.as_raw(),
            byte_count,
            stream.as_raw(),
        )
    };
    check(status)
}
/// Forwards `dst`, the 16-bit `value` and `count` to the driver's
/// `cu_memset_d16` entry point.
///
/// NOTE(review): `count` presumably counts 16-bit elements, not bytes — confirm.
/// This safe function cannot verify that `dst` addresses enough writable device
/// memory; validity is left to the caller.
pub fn memset_u16(dst: CUdeviceptr, value: u16, count: usize) -> Result<()> {
    let drv = driver()?;
    let fill = drv.cu_memset_d16()?;
    // SAFETY: not checkable here; relies on the caller passing a valid range.
    let status = unsafe { fill(dst, value, count) };
    check(status)
}
/// Forwards `dst`, the 16-bit `value`, `count` and `stream` to the driver's
/// `cu_memset_d16_async` entry point.
///
/// NOTE(review): `count` presumably counts 16-bit elements, not bytes — confirm.
/// This safe function cannot verify that `dst` addresses enough writable device
/// memory; validity is left to the caller.
pub fn memset_u16_async(dst: CUdeviceptr, value: u16, count: usize, stream: &Stream) -> Result<()> {
    let drv = driver()?;
    let fill = drv.cu_memset_d16_async()?;
    // SAFETY: not checkable here; relies on the caller passing a valid range.
    let status = unsafe { fill(dst, value, count, stream.as_raw()) };
    check(status)
}
/// Forwards `dst`, the byte `value`, `count` and `stream` to the driver's
/// `cu_memset_d8_async` entry point.
///
/// This safe function cannot verify that `dst` addresses enough writable device
/// memory; validity is left to the caller.
pub fn memset_u8_async(dst: CUdeviceptr, value: u8, count: usize, stream: &Stream) -> Result<()> {
    let drv = driver()?;
    let fill = drv.cu_memset_d8_async()?;
    // SAFETY: not checkable here; relies on the caller passing a valid range.
    let status = unsafe { fill(dst, value, count, stream.as_raw()) };
    check(status)
}
/// Forwards `dst`, the 32-bit `value`, `count` and `stream` to the driver's
/// `cu_memset_d32_async` entry point.
///
/// NOTE(review): `count` presumably counts 32-bit elements, not bytes — confirm.
/// This safe function cannot verify that `dst` addresses enough writable device
/// memory; validity is left to the caller.
pub fn memset_u32_async(dst: CUdeviceptr, value: u32, count: usize, stream: &Stream) -> Result<()> {
    let drv = driver()?;
    let fill = drv.cu_memset_d32_async()?;
    // SAFETY: not checkable here; relies on the caller passing a valid range.
    let status = unsafe { fill(dst, value, count, stream.as_raw()) };
    check(status)
}
/// Forwards `dst`, the 32-bit `value` and `count` to the driver's
/// `cu_memset_d32` entry point.
///
/// NOTE(review): `count` presumably counts 32-bit elements, not bytes — confirm.
/// This safe function cannot verify that `dst` addresses enough writable device
/// memory; validity is left to the caller.
pub fn memset_u32(dst: CUdeviceptr, value: u32, count: usize) -> Result<()> {
    let drv = driver()?;
    let fill = drv.cu_memset_d32()?;
    // SAFETY: not checkable here; relies on the caller passing a valid range.
    let status = unsafe { fill(dst, value, count) };
    check(status)
}
/// Forwards a 2-D fill request (`dst`, `pitch`, byte `value`, `width`, `height`)
/// to the driver's `cu_memset_d2d8` entry point.
///
/// NOTE(review): units of `pitch`/`width` (bytes vs. elements) are defined by
/// the driver, not by this wrapper — confirm. This safe function cannot verify
/// that the described region is valid device memory.
pub fn memset_2d_u8(
    dst: CUdeviceptr,
    pitch: usize,
    value: u8,
    width: usize,
    height: usize,
) -> Result<()> {
    let drv = driver()?;
    let fill = drv.cu_memset_d2d8()?;
    // SAFETY: not checkable here; relies on the caller passing a valid region.
    let status = unsafe { fill(dst, pitch, value, width, height) };
    check(status)
}
/// Forwards a 2-D fill request (`dst`, `pitch`, 16-bit `value`, `width`,
/// `height`) to the driver's `cu_memset_d2d16` entry point.
///
/// NOTE(review): units of `pitch`/`width` (bytes vs. elements) are defined by
/// the driver, not by this wrapper — confirm. This safe function cannot verify
/// that the described region is valid device memory.
pub fn memset_2d_u16(
    dst: CUdeviceptr,
    pitch: usize,
    value: u16,
    width: usize,
    height: usize,
) -> Result<()> {
    let drv = driver()?;
    let fill = drv.cu_memset_d2d16()?;
    // SAFETY: not checkable here; relies on the caller passing a valid region.
    let status = unsafe { fill(dst, pitch, value, width, height) };
    check(status)
}
/// Forwards a 2-D fill request (`dst`, `pitch`, 32-bit `value`, `width`,
/// `height`) to the driver's `cu_memset_d2d32` entry point.
///
/// NOTE(review): units of `pitch`/`width` (bytes vs. elements) are defined by
/// the driver, not by this wrapper — confirm. This safe function cannot verify
/// that the described region is valid device memory.
pub fn memset_2d_u32(
    dst: CUdeviceptr,
    pitch: usize,
    value: u32,
    width: usize,
    height: usize,
) -> Result<()> {
    let drv = driver()?;
    let fill = drv.cu_memset_d2d32()?;
    // SAFETY: not checkable here; relies on the caller passing a valid region.
    let status = unsafe { fill(dst, pitch, value, width, height) };
    check(status)
}
/// Copies `bytes` bytes from `src` to `dst` through the driver's generic
/// `cu_memcpy` entry point.
///
/// # Safety
/// Both addresses are passed to the driver unchecked. NOTE(review): presumably
/// the caller must guarantee that `src` and `dst` each address at least `bytes`
/// valid bytes — confirm.
///
/// # Errors
/// Propagates driver-lookup failures and the status of the copy.
pub unsafe fn memcpy(dst: CUdeviceptr, src: CUdeviceptr, bytes: usize) -> Result<()> {
    let drv = driver()?;
    let copy = drv.cu_memcpy()?;
    // SAFETY: validity of `dst`/`src` for `bytes` bytes is the caller's
    // obligation under this function's contract.
    let status = unsafe { copy(dst, src, bytes) };
    check(status)
}
/// Enqueues on `stream` a copy of `bytes` bytes from `src` to `dst` through the
/// driver's generic `cu_memcpy_async` entry point.
///
/// # Safety
/// Both addresses are passed to the driver unchecked. NOTE(review): presumably
/// the caller must guarantee that `src` and `dst` each address at least `bytes`
/// valid bytes and stay valid until the stream has processed the copy — confirm.
///
/// # Errors
/// Propagates driver-lookup failures and the status of the enqueue call.
pub unsafe fn memcpy_async(
    dst: CUdeviceptr,
    src: CUdeviceptr,
    bytes: usize,
    stream: &Stream,
) -> Result<()> {
    let drv = driver()?;
    let copy = drv.cu_memcpy_async()?;
    // SAFETY: validity of `dst`/`src` for `bytes` bytes is the caller's
    // obligation under this function's contract.
    let status = unsafe { copy(dst, src, bytes, stream.as_raw()) };
    check(status)
}
/// Destination descriptor for the `*_v2` prefetch/advise wrappers.
///
/// `as_location` converts each variant into a `CUmemLocation` whose `type_` is
/// the like-named `CUmemLocationType` constant.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum PrefetchTarget {
/// `DEVICE` location; the payload becomes the location `id`.
Device(i32),
/// `HOST` location; `id` is set to `0`.
Host,
/// `HOST_NUMA` location; the payload becomes the location `id`.
HostNuma(i32),
/// `HOST_NUMA_CURRENT` location; `id` is set to `0`.
HostNumaCurrent,
}
impl PrefetchTarget {
    /// Builds the raw `CUmemLocation` for this target. Variants without a
    /// payload use `id = 0`.
    fn as_location(self) -> baracuda_cuda_sys::types::CUmemLocation {
        use baracuda_cuda_sys::types::{CUmemLocation, CUmemLocationType as Kind};
        match self {
            Self::Device(ordinal) => CUmemLocation {
                type_: Kind::DEVICE,
                id: ordinal,
            },
            Self::Host => CUmemLocation {
                type_: Kind::HOST,
                id: 0,
            },
            Self::HostNuma(node) => CUmemLocation {
                type_: Kind::HOST_NUMA,
                id: node,
            },
            Self::HostNumaCurrent => CUmemLocation {
                type_: Kind::HOST_NUMA_CURRENT,
                id: 0,
            },
        }
    }
}
/// Enqueues on `stream` a prefetch of `count` bytes starting at `dptr` towards
/// `target`, via the driver's `cu_mem_prefetch_async_v2` entry point. The flags
/// argument is always `0`.
///
/// NOTE(review): `count` is forwarded unchanged; presumably a byte count —
/// confirm. This safe function cannot verify that `dptr` is a valid range.
pub fn mem_prefetch_v2(
    dptr: CUdeviceptr,
    count: usize,
    target: PrefetchTarget,
    stream: &Stream,
) -> Result<()> {
    let drv = driver()?;
    let prefetch = drv.cu_mem_prefetch_async_v2()?;
    // SAFETY: not checkable here; relies on the caller passing a valid range.
    let status = unsafe { prefetch(dptr, count, target.as_location(), 0, stream.as_raw()) };
    check(status)
}
/// Applies the raw `advice` value to `count` bytes starting at `dptr` with
/// respect to `target`, via the driver's `cu_mem_advise_v2` entry point.
///
/// NOTE(review): `advice` is passed through unvalidated; presumably it should be
/// a `CUmem_advise` constant — confirm. This safe function cannot verify that
/// `dptr` is a valid range.
pub fn mem_advise_v2(
    dptr: CUdeviceptr,
    count: usize,
    advice: i32,
    target: PrefetchTarget,
) -> Result<()> {
    let drv = driver()?;
    let advise = drv.cu_mem_advise_v2()?;
    // SAFETY: not checkable here; relies on the caller passing a valid range.
    let status = unsafe { advise(dptr, count, advice, target.as_location()) };
    check(status)
}
/// Asks the driver for the generic allocation handle that backs `addr`.
///
/// NOTE(review): the name suggests the returned handle carries a reference that
/// the caller must release later — confirm against the driver documentation.
///
/// # Errors
/// Propagates driver-lookup failures and the status of the query.
pub fn retain_allocation_handle(
    addr: CUdeviceptr,
) -> Result<baracuda_cuda_sys::CUmemGenericAllocationHandle> {
    let drv = driver()?;
    let retain = drv.cu_mem_retain_allocation_handle()?;
    let mut handle: baracuda_cuda_sys::CUmemGenericAllocationHandle = 0;
    // SAFETY: `handle` is a live local out-parameter; the address is passed to
    // the driver as an opaque pointer value.
    let status = unsafe { retain(&mut handle, addr.0 as *mut core::ffi::c_void) };
    check(status)?;
    Ok(handle)
}
/// Fetches the allocation properties the driver associates with `handle`.
///
/// The result starts from `CUmemAllocationProp::default()` and is filled in by
/// the driver call.
///
/// # Errors
/// Propagates driver-lookup failures and the status of the query.
pub fn allocation_properties_from_handle(
    handle: baracuda_cuda_sys::CUmemGenericAllocationHandle,
) -> Result<baracuda_cuda_sys::types::CUmemAllocationProp> {
    let drv = driver()?;
    let query = drv.cu_mem_get_allocation_properties_from_handle()?;
    let mut props = baracuda_cuda_sys::types::CUmemAllocationProp::default();
    // SAFETY: `props` is a live, default-initialized out-parameter.
    let status = unsafe { query(&mut props, handle) };
    check(status)?;
    Ok(props)
}
/// Asks the driver to write a handle of kind `handle_type` for the range
/// `[dptr, dptr + size)` into `handle_out`. The flags argument is always `0`.
///
/// # Safety
/// `handle_out` is written through by the driver. NOTE(review): presumably it
/// must point to storage of the size and alignment the driver expects for
/// `handle_type` — confirm.
///
/// # Errors
/// Propagates driver-lookup failures and the status of the call.
pub unsafe fn get_handle_for_address_range(
    handle_out: *mut core::ffi::c_void,
    dptr: CUdeviceptr,
    size: usize,
    handle_type: i32,
) -> Result<()> {
    let drv = driver()?;
    let get_handle = drv.cu_mem_get_handle_for_address_range()?;
    // SAFETY: validity of `handle_out` and of the address range is the caller's
    // obligation under this function's contract.
    let status = unsafe { get_handle(handle_out, dptr, size, handle_type, 0) };
    check(status)
}
impl<T: DeviceRepr> Drop for DeviceBuffer<T> {
    /// Best-effort release of the allocation. Every failure is ignored: the
    /// driver or its free entry point being unavailable, or the free call
    /// returning an error.
    fn drop(&mut self) {
        // Null marks an empty buffer or one already handed to `free_async`.
        if self.ptr.0 == 0 {
            return;
        }
        let Ok(drv) = driver() else { return };
        let Ok(free) = drv.cu_mem_free() else { return };
        // SAFETY: `self.ptr` is the non-null address owned by this buffer, and
        // it is not used again after drop.
        let _ = unsafe { free(self.ptr) };
    }
}
/// Non-owning, copyable view over `len` elements of `T` in device memory.
///
/// The `'a` lifetime (carried by `PhantomData<&'a T>`) ties the view to whatever
/// it was borrowed from; the view never frees memory.
#[derive(Copy, Clone)]
pub struct DeviceSlice<'a, T: DeviceRepr> {
/// Device address of the first element.
pub(crate) ptr: CUdeviceptr,
/// Number of `T` elements (not bytes).
pub(crate) len: usize,
/// Models a shared borrow of `T` for lifetime `'a`.
pub(crate) _marker: PhantomData<&'a T>,
}
impl<'a, T: DeviceRepr> core::fmt::Debug for DeviceSlice<'a, T> {
    /// Prints the device address (hex) and the element count; device memory is
    /// never read.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("DeviceSlice");
        out.field("ptr", &format_args!("{:#x}", self.ptr.0));
        out.field("len", &self.len);
        out.finish()
    }
}
impl<'a, T: DeviceRepr> DeviceSlice<'a, T> {
    /// Number of `T` elements in the view.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the view covers zero elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Device address of the first element.
    #[inline]
    pub fn as_raw(&self) -> CUdeviceptr {
        self.ptr
    }

    /// Wraps a raw device address and element count as a view with a
    /// caller-chosen lifetime.
    ///
    /// # Safety
    /// Nothing is validated here. NOTE(review): presumably `ptr` must address at
    /// least `len` elements of `T` that stay allocated for `'b` — confirm.
    #[inline]
    pub unsafe fn from_raw_parts<'b>(ptr: CUdeviceptr, len: usize) -> DeviceSlice<'b, T> {
        DeviceSlice {
            ptr,
            len,
            _marker: PhantomData,
        }
    }

    /// Sub-view over the elements in `range`, borrowed from `self`.
    ///
    /// # Panics
    /// Panics if `range` is inverted or ends past the view's length.
    #[inline]
    pub fn slice(&self, range: Range<usize>) -> DeviceSlice<'_, T> {
        let Range { start, end } = range;
        assert!(
            start <= end && end <= self.len,
            "DeviceSlice::slice({}..{}) out of bounds for len {}",
            start,
            end,
            self.len,
        );
        // Element index -> byte offset from the view's base address.
        let byte_offset = (start * size_of::<T>()) as u64;
        DeviceSlice {
            ptr: CUdeviceptr(self.ptr.0 + byte_offset),
            len: end - start,
            _marker: PhantomData,
        }
    }
}
/// Non-owning, exclusive view over `len` elements of `T` in device memory.
///
/// The `'a` lifetime (carried by `PhantomData<&'a mut T>`) ties the view to
/// whatever it was mutably borrowed from; the view never frees memory.
pub struct DeviceSliceMut<'a, T: DeviceRepr> {
/// Device address of the first element.
pub(crate) ptr: CUdeviceptr,
/// Number of `T` elements (not bytes).
pub(crate) len: usize,
/// Models an exclusive borrow of `T` for lifetime `'a`.
pub(crate) _marker: PhantomData<&'a mut T>,
}
impl<'a, T: DeviceRepr> core::fmt::Debug for DeviceSliceMut<'a, T> {
    /// Prints the device address (hex) and the element count; device memory is
    /// never read.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("DeviceSliceMut");
        out.field("ptr", &format_args!("{:#x}", self.ptr.0));
        out.field("len", &self.len);
        out.finish()
    }
}
impl<'a, T: DeviceRepr> DeviceSliceMut<'a, T> {
    /// Number of `T` elements in the view.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the view covers zero elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Device address of the first element.
    #[inline]
    pub fn as_raw(&self) -> CUdeviceptr {
        self.ptr
    }

    /// Wraps a raw device address and element count as an exclusive view with a
    /// caller-chosen lifetime.
    ///
    /// # Safety
    /// Nothing is validated here. NOTE(review): presumably `ptr` must address at
    /// least `len` elements of `T` that stay allocated, and are not aliased by
    /// another view, for `'b` — confirm.
    #[inline]
    pub unsafe fn from_raw_parts<'b>(ptr: CUdeviceptr, len: usize) -> DeviceSliceMut<'b, T> {
        DeviceSliceMut {
            ptr,
            len,
            _marker: PhantomData,
        }
    }

    /// Shared sub-view over the elements in `range`, borrowed from `self`.
    ///
    /// # Panics
    /// Panics if `range` is inverted or ends past the view's length.
    #[inline]
    pub fn slice(&self, range: Range<usize>) -> DeviceSlice<'_, T> {
        let Range { start, end } = range;
        assert!(
            start <= end && end <= self.len,
            "DeviceSliceMut::slice({}..{}) out of bounds for len {}",
            start,
            end,
            self.len,
        );
        // Element index -> byte offset from the view's base address.
        let byte_offset = (start * size_of::<T>()) as u64;
        DeviceSlice {
            ptr: CUdeviceptr(self.ptr.0 + byte_offset),
            len: end - start,
            _marker: PhantomData,
        }
    }

    /// Exclusive sub-view over the elements in `range`, mutably borrowed from
    /// `self`.
    ///
    /// # Panics
    /// Panics if `range` is inverted or ends past the view's length.
    #[inline]
    pub fn slice_mut(&mut self, range: Range<usize>) -> DeviceSliceMut<'_, T> {
        let Range { start, end } = range;
        assert!(
            start <= end && end <= self.len,
            "DeviceSliceMut::slice_mut({}..{}) out of bounds for len {}",
            start,
            end,
            self.len,
        );
        // Element index -> byte offset from the view's base address.
        let byte_offset = (start * size_of::<T>()) as u64;
        DeviceSliceMut {
            ptr: CUdeviceptr(self.ptr.0 + byte_offset),
            len: end - start,
            _marker: PhantomData,
        }
    }

    /// Enqueues on `stream` a host-to-device copy of `src` into this view; a
    /// no-op when the byte size is zero.
    ///
    /// NOTE(review): the signature does not tie `src`'s lifetime to the
    /// completion of the copy; presumably callers must keep `src` alive and
    /// unmodified until the stream has processed it — confirm.
    ///
    /// # Panics
    /// Panics if `src.len()` differs from the view's length.
    pub fn copy_from_host_async(&self, src: &[T], stream: &Stream) -> Result<()> {
        assert_eq!(src.len(), self.len);
        let byte_count = self.len * size_of::<T>();
        if byte_count == 0 {
            return Ok(());
        }
        let drv = driver()?;
        let copy = drv.cu_memcpy_htod_async()?;
        // SAFETY: `src` spans exactly `byte_count` bytes (asserted above).
        let status = unsafe {
            copy(
                self.ptr,
                src.as_ptr() as *const c_void,
                byte_count,
                stream.as_raw(),
            )
        };
        check(status)
    }
}
/// Read access to a run of `T` elements in device memory: a device address plus
/// an element count.
///
/// # Safety
/// NOTE(review): the trait is `unsafe` to implement; presumably implementors
/// must guarantee that `device_ptr()` addresses at least `len()` valid elements
/// of `T` for as long as the value is borrowed — confirm.
pub unsafe trait DevicePtr<T: DeviceRepr> {
/// Device address of the first element.
fn device_ptr(&self) -> CUdeviceptr;
/// Number of `T` elements (not bytes).
fn len(&self) -> usize;
/// `true` when `len()` is zero.
#[inline]
fn is_empty(&self) -> bool {
self.len() == 0
}
/// Size in bytes: `len() * size_of::<T>()`.
#[inline]
fn byte_size(&self) -> usize {
self.len() * core::mem::size_of::<T>()
}
}
/// Extension of `DevicePtr` for values that can hand out their device address
/// through an exclusive borrow.
///
/// # Safety
/// NOTE(review): presumably implementors must guarantee the memory is writable
/// and not aliased while the `&mut` borrow is held — confirm.
pub unsafe trait DevicePtrMut<T: DeviceRepr>: DevicePtr<T> {
/// Device address of the first element, obtained via `&mut self`.
fn device_ptr_mut(&mut self) -> CUdeviceptr;
}
// SAFETY: reports the buffer's own `ptr`/`len` fields unchanged.
unsafe impl<T: DeviceRepr> DevicePtr<T> for DeviceBuffer<T> {
/// Returns the buffer's raw device address.
#[inline]
fn device_ptr(&self) -> CUdeviceptr {
self.ptr
}
/// Returns the buffer's element count.
#[inline]
fn len(&self) -> usize {
self.len
}
}
// SAFETY: reports the buffer's own `ptr` field; exclusivity comes from `&mut self`.
unsafe impl<T: DeviceRepr> DevicePtrMut<T> for DeviceBuffer<T> {
/// Returns the buffer's raw device address.
#[inline]
fn device_ptr_mut(&mut self) -> CUdeviceptr {
self.ptr
}
}
// SAFETY: reports the view's own `ptr`/`len` fields unchanged.
unsafe impl<'a, T: DeviceRepr> DevicePtr<T> for DeviceSlice<'a, T> {
/// Returns the view's device address.
#[inline]
fn device_ptr(&self) -> CUdeviceptr {
self.ptr
}
/// Returns the view's element count.
#[inline]
fn len(&self) -> usize {
self.len
}
}
// SAFETY: reports the view's own `ptr`/`len` fields unchanged.
unsafe impl<'a, T: DeviceRepr> DevicePtr<T> for DeviceSliceMut<'a, T> {
/// Returns the view's device address.
#[inline]
fn device_ptr(&self) -> CUdeviceptr {
self.ptr
}
/// Returns the view's element count.
#[inline]
fn len(&self) -> usize {
self.len
}
}
// SAFETY: reports the view's own `ptr` field; exclusivity comes from `&mut self`.
unsafe impl<'a, T: DeviceRepr> DevicePtrMut<T> for DeviceSliceMut<'a, T> {
/// Returns the view's device address.
#[inline]
fn device_ptr_mut(&mut self) -> CUdeviceptr {
self.ptr
}
}
// SAFETY: pure delegation — every answer comes from the referent's own
// `DevicePtr` implementation.
unsafe impl<T: DeviceRepr, P: DevicePtr<T> + ?Sized> DevicePtr<T> for &P {
    /// Forwards to the referent's `device_ptr`.
    #[inline]
    fn device_ptr(&self) -> CUdeviceptr {
        <P as DevicePtr<T>>::device_ptr(&**self)
    }

    /// Forwards to the referent's `len`.
    #[inline]
    fn len(&self) -> usize {
        <P as DevicePtr<T>>::len(&**self)
    }
}
// SAFETY: pure delegation — every answer comes from the referent's own
// `DevicePtr` implementation.
unsafe impl<T: DeviceRepr, P: DevicePtr<T> + ?Sized> DevicePtr<T> for &mut P {
    /// Forwards to the referent's `device_ptr`.
    #[inline]
    fn device_ptr(&self) -> CUdeviceptr {
        <P as DevicePtr<T>>::device_ptr(&**self)
    }

    /// Forwards to the referent's `len`.
    #[inline]
    fn len(&self) -> usize {
        <P as DevicePtr<T>>::len(&**self)
    }
}
// SAFETY: pure delegation — the answer comes from the referent's own
// `DevicePtrMut` implementation.
unsafe impl<T: DeviceRepr, P: DevicePtrMut<T> + ?Sized> DevicePtrMut<T> for &mut P {
    /// Forwards to the referent's `device_ptr_mut`.
    #[inline]
    fn device_ptr_mut(&mut self) -> CUdeviceptr {
        <P as DevicePtrMut<T>>::device_ptr_mut(&mut **self)
    }
}
// SAFETY: NOTE(review): the returned pointer targets the `ptr` field inside the
// referenced buffer, so it presumably stays valid only while that buffer is
// borrowed — confirm against `KernelArg`'s contract.
unsafe impl<T: DeviceRepr> KernelArg for &DeviceBuffer<T> {
    /// Returns the address of the buffer's `ptr` field, i.e. a pointer to the
    /// device-address value rather than the device address itself.
    #[inline]
    fn as_kernel_arg_ptr(&self) -> *mut c_void {
        let slot: *const CUdeviceptr = &self.ptr;
        slot as *mut c_void
    }
}
// SAFETY: NOTE(review): the returned pointer targets the `ptr` field inside the
// referenced buffer, so it presumably stays valid only while that buffer is
// borrowed — confirm against `KernelArg`'s contract.
unsafe impl<T: DeviceRepr> KernelArg for &mut DeviceBuffer<T> {
    /// Returns the address of the buffer's `ptr` field, i.e. a pointer to the
    /// device-address value rather than the device address itself.
    #[inline]
    fn as_kernel_arg_ptr(&self) -> *mut c_void {
        let slot: *const CUdeviceptr = &self.ptr;
        slot as *mut c_void
    }
}
// SAFETY: NOTE(review): the returned pointer targets the `ptr` field inside the
// referenced view, so it presumably stays valid only while that view is
// borrowed — confirm against `KernelArg`'s contract.
unsafe impl<'a, T: DeviceRepr> KernelArg for &DeviceSlice<'a, T> {
    /// Returns the address of the view's `ptr` field, i.e. a pointer to the
    /// device-address value rather than the device address itself.
    #[inline]
    fn as_kernel_arg_ptr(&self) -> *mut c_void {
        let slot: *const CUdeviceptr = &self.ptr;
        slot as *mut c_void
    }
}
// SAFETY: NOTE(review): the returned pointer targets the `ptr` field inside the
// referenced view, so it presumably stays valid only while that view is
// borrowed — confirm against `KernelArg`'s contract.
unsafe impl<'a, T: DeviceRepr> KernelArg for &DeviceSliceMut<'a, T> {
    /// Returns the address of the view's `ptr` field, i.e. a pointer to the
    /// device-address value rather than the device address itself.
    #[inline]
    fn as_kernel_arg_ptr(&self) -> *mut c_void {
        let slot: *const CUdeviceptr = &self.ptr;
        slot as *mut c_void
    }
}
// SAFETY: NOTE(review): the returned pointer targets the `ptr` field inside the
// referenced view, so it presumably stays valid only while that view is
// borrowed — confirm against `KernelArg`'s contract.
unsafe impl<'a, T: DeviceRepr> KernelArg for &mut DeviceSliceMut<'a, T> {
    /// Returns the address of the view's `ptr` field, i.e. a pointer to the
    /// device-address value rather than the device address itself.
    #[inline]
    fn as_kernel_arg_ptr(&self) -> *mut c_void {
        let slot: *const CUdeviceptr = &self.ptr;
        slot as *mut c_void
    }
}
// Host-only tests for the pointer/length arithmetic of the slice views. They
// fabricate device addresses and never call into the driver.
#[cfg(test)]
mod slice_tests {
use super::*;
// Builds a view over a made-up address; safe to use here because the tests only
// inspect `ptr`/`len` and never dereference device memory.
fn fake_slice<T: DeviceRepr>(ptr: u64, len: usize) -> DeviceSlice<'static, T> {
DeviceSlice {
ptr: CUdeviceptr(ptr),
len,
_marker: PhantomData,
}
}
// Sub-slicing advances the address by `start * size_of::<T>()` bytes.
#[test]
fn slice_offsets_ptr_by_element_bytes() {
let s: DeviceSlice<'_, f32> = fake_slice(0x1000, 16);
let sub = s.slice(4..12);
assert_eq!(sub.len(), 8);
assert_eq!(sub.as_raw().0, 0x1000 + 4 * 4); }
// Offsets compose: 10 elements, then 5 more, is 15 `f64`s from the base.
#[test]
fn slice_of_slice_stays_correct() {
let s: DeviceSlice<'_, f64> = fake_slice(0x2000, 100);
let mid = s.slice(10..90);
let inner = mid.slice(5..15);
assert_eq!(inner.len(), 10);
assert_eq!(inner.as_raw().0, 0x2000 + 15 * 8);
}
// A range ending past the view's length must panic.
#[test]
#[should_panic(expected = "out of bounds")]
fn slice_end_past_len_panics() {
let s: DeviceSlice<'_, u8> = fake_slice(0, 10);
let _ = s.slice(0..11);
}
// An inverted range (start > end) must panic; the clippy lint is silenced
// because the reversed range is the point of the test.
#[test]
#[should_panic(expected = "out of bounds")]
#[allow(clippy::reversed_empty_ranges)]
fn slice_inverted_range_panics() {
let s: DeviceSlice<'_, u8> = fake_slice(0, 10);
let _ = s.slice(5..3);
}
// `DeviceSlice::from_raw_parts` stores its arguments verbatim.
#[test]
fn from_raw_parts_preserves_ptr_and_len() {
let s: DeviceSlice<'static, f32> =
unsafe { DeviceSlice::from_raw_parts(CUdeviceptr(0x4000), 32) };
assert_eq!(s.as_raw().0, 0x4000);
assert_eq!(s.len(), 32);
}
// `DeviceSliceMut::from_raw_parts` stores its arguments verbatim.
#[test]
fn from_raw_parts_mut_preserves_ptr_and_len() {
let s: DeviceSliceMut<'static, u32> =
unsafe { DeviceSliceMut::from_raw_parts(CUdeviceptr(0x8000), 64) };
assert_eq!(s.as_raw().0, 0x8000);
assert_eq!(s.len(), 64);
}
}
// Host-only test for the `KernelArg` impl on `&DeviceSlice`; no driver calls.
#[cfg(test)]
mod kernel_arg_tests {
use super::*;
use core::mem::size_of;
// The kernel-argument pointer must point AT the `ptr` field (so reading a `u64`
// through it yields the device address) and must lie inside the slice value.
#[test]
fn slice_kernel_arg_points_at_ptr_field() {
let slice: DeviceSlice<'_, f32> = DeviceSlice {
ptr: CUdeviceptr(0xDEAD_BEEF_u64),
len: 42,
_marker: PhantomData,
};
let kernel_arg = (&slice).as_kernel_arg_ptr();
// Reading through the argument pointer yields the stored device address.
unsafe {
let as_u64 = *(kernel_arg as *const u64);
assert_eq!(as_u64, 0xDEAD_BEEF);
}
// The argument pointer falls within the bytes of `slice` itself.
let slice_start = &slice as *const _ as usize;
let slice_end = slice_start + size_of::<DeviceSlice<'_, f32>>();
let arg_addr = kernel_arg as usize;
assert!((slice_start..slice_end).contains(&arg_addr));
}
}