use std::{
fmt::{self, Display, Formatter},
marker::PhantomData,
mem::{self, MaybeUninit},
ptr,
};
use num_enum::{IntoPrimitive, TryFromPrimitive};
use singe_core::impl_enum_conversion;
use singe_cuda_sys::{driver, runtime};
use crate::{
error::{Error, Result},
ipc::IpcMemoryHandle,
stream::{Stream, StreamScope},
try_cuda,
types::DevicePtr,
};
/// Direction of a memory copy, mirroring the CUDA runtime's `cudaMemcpyKind`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, IntoPrimitive, TryFromPrimitive)]
#[repr(u32)]
pub enum MemoryCopyKind {
    /// `cudaMemcpyHostToHost`.
    HostToHost = runtime::cudaMemcpyKind::cudaMemcpyHostToHost as u32,
    /// `cudaMemcpyHostToDevice`.
    HostToDevice = runtime::cudaMemcpyKind::cudaMemcpyHostToDevice as u32,
    /// `cudaMemcpyDeviceToHost`.
    DeviceToHost = runtime::cudaMemcpyKind::cudaMemcpyDeviceToHost as u32,
    /// `cudaMemcpyDeviceToDevice`.
    DeviceToDevice = runtime::cudaMemcpyKind::cudaMemcpyDeviceToDevice as u32,
    /// `cudaMemcpyDefault`: per the CUDA docs the direction is inferred from the pointer
    /// values (requires unified virtual addressing).
    Default = runtime::cudaMemcpyKind::cudaMemcpyDefault as u32,
}
impl_enum_conversion!(runtime::cudaMemcpyKind, MemoryCopyKind);
impl Display for MemoryCopyKind {
    /// Writes the name of the matching CUDA runtime `cudaMemcpyKind` constant.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let constant = match self {
            Self::HostToHost => "cudaMemcpyHostToHost",
            Self::HostToDevice => "cudaMemcpyHostToDevice",
            Self::DeviceToHost => "cudaMemcpyDeviceToHost",
            Self::DeviceToDevice => "cudaMemcpyDeviceToDevice",
            Self::Default => "cudaMemcpyDefault",
        };
        f.write_str(constant)
    }
}
bitflags::bitflags! {
    /// Attachment flags for managed allocations, mirroring the driver's `CUmemAttach_flags`.
    ///
    /// Passed to `cudaMallocManaged` by `DeviceMemory::alloc_managed`.
    #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
    pub struct MemoryAttachFlags: u32 {
        /// `CU_MEM_ATTACH_GLOBAL`: per the CUDA docs, accessible by any stream on any device.
        const GLOBAL = driver::CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL as u32;
        /// `CU_MEM_ATTACH_HOST`: per the CUDA docs, not accessible by any stream on any device.
        const HOST = driver::CUmemAttach_flags::CU_MEM_ATTACH_HOST as u32;
        /// `CU_MEM_ATTACH_SINGLE`: per the CUDA docs, accessible only by a single stream.
        const SINGLE = driver::CUmemAttach_flags::CU_MEM_ATTACH_SINGLE as u32;
    }
}
/// Allocation type of a memory pool, mirroring the driver's `CUmemAllocationType`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, IntoPrimitive, TryFromPrimitive)]
#[repr(u32)]
pub enum MemAllocationType {
    /// `CU_MEM_ALLOCATION_TYPE_INVALID`.
    Invalid = driver::CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_INVALID as u32,
    /// `CU_MEM_ALLOCATION_TYPE_PINNED`.
    Pinned = driver::CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED as u32,
    /// `CU_MEM_ALLOCATION_TYPE_MANAGED`.
    Managed = driver::CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_MANAGED as u32,
    /// `CU_MEM_ALLOCATION_TYPE_MAX`: upper-bound sentinel from the CUDA headers.
    Max = driver::CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_MAX as u32,
}
impl_enum_conversion!(u32, driver::CUmemAllocationType, MemAllocationType);
/// Shareable handle type for pool allocations, mirroring `CUmemAllocationHandleType`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, IntoPrimitive, TryFromPrimitive)]
#[repr(u32)]
pub enum MemAllocationHandleType {
    /// `CU_MEM_HANDLE_TYPE_NONE`.
    None = driver::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_NONE as u32,
    /// `CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR`.
    PosixFileDescriptor =
        driver::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR as u32,
    /// `CU_MEM_HANDLE_TYPE_WIN32`.
    Win32 = driver::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_WIN32 as u32,
    /// `CU_MEM_HANDLE_TYPE_WIN32_KMT`.
    Win32Kmt = driver::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_WIN32_KMT as u32,
    /// `CU_MEM_HANDLE_TYPE_FABRIC`.
    Fabric = driver::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_FABRIC as u32,
    /// `CU_MEM_HANDLE_TYPE_MAX`: upper-bound sentinel from the CUDA headers.
    Max = driver::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_MAX as u32,
}
impl_enum_conversion!(
    u32,
    driver::CUmemAllocationHandleType,
    MemAllocationHandleType
);
/// Access protection for a memory location, mirroring the driver's `CUmemAccess_flags`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, IntoPrimitive, TryFromPrimitive)]
#[repr(u32)]
pub enum MemAccessFlag {
    /// `CU_MEM_ACCESS_FLAGS_PROT_NONE`.
    None = driver::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_NONE as u32,
    /// `CU_MEM_ACCESS_FLAGS_PROT_READ`.
    Read = driver::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READ as u32,
    /// `CU_MEM_ACCESS_FLAGS_PROT_READWRITE`.
    ReadWrite = driver::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READWRITE as u32,
    /// `CU_MEM_ACCESS_FLAGS_PROT_MAX`: upper-bound sentinel from the CUDA headers.
    Max = driver::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_MAX as u32,
}
impl_enum_conversion!(u32, driver::CUmemAccess_flags, MemAccessFlag);
/// Memory-pool attribute selector, mirroring the driver's `CUmemPool_attribute`.
///
/// The `Reuse*` attributes are boolean; the rest are byte counts
/// (see `MemoryPool::attribute` / `MemoryPool::set_attribute`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, IntoPrimitive, TryFromPrimitive)]
#[repr(u32)]
pub enum MemoryPoolAttribute {
    /// `CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES`.
    ReuseFollowEventDependencies =
        driver::CUmemPool_attribute::CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES as u32,
    /// `CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC`.
    ReuseAllowOpportunistic =
        driver::CUmemPool_attribute::CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC as u32,
    /// `CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES`.
    ReuseAllowInternalDependencies =
        driver::CUmemPool_attribute::CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES as u32,
    /// `CU_MEMPOOL_ATTR_RELEASE_THRESHOLD`.
    ReleaseThreshold = driver::CUmemPool_attribute::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD as u32,
    /// `CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT`.
    ReservedMemoryCurrent =
        driver::CUmemPool_attribute::CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT as u32,
    /// `CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH`.
    ReservedMemoryHigh = driver::CUmemPool_attribute::CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH as u32,
    /// `CU_MEMPOOL_ATTR_USED_MEM_CURRENT`.
    UsedMemoryCurrent = driver::CUmemPool_attribute::CU_MEMPOOL_ATTR_USED_MEM_CURRENT as u32,
    /// `CU_MEMPOOL_ATTR_USED_MEM_HIGH`.
    UsedMemoryHigh = driver::CUmemPool_attribute::CU_MEMPOOL_ATTR_USED_MEM_HIGH as u32,
}
impl_enum_conversion!(u32, driver::CUmemPool_attribute, MemoryPoolAttribute);
/// A value read from or written to a `MemoryPool` attribute.
///
/// The `Reuse*` attributes use [`Bool`](Self::Bool); the release threshold and the
/// reserved/used memory counters use [`Bytes`](Self::Bytes).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MemoryPoolAttributeValue {
    /// Boolean attribute, exchanged with the driver as a 32-bit integer.
    Bool(bool),
    /// Byte-count attribute, exchanged with the driver as a `u64`.
    Bytes(u64),
}
/// Access rights granted to a location, convertible into the driver's `CUmemAccessDesc`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MemAccessDescriptor {
    /// Location (e.g. a device) the access rights apply to.
    pub location: MemoryLocation,
    /// Protection granted to `location`.
    pub flags: MemAccessFlag,
}
/// Properties used to create a `MemoryPool`; converted into the driver's `CUmemPoolProps`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MemoryPoolProps {
    /// Allocation type (`allocType`).
    pub alloc_type: MemAllocationType,
    /// Shareable handle type the pool's allocations may be exported as (`handleTypes`).
    pub handle_type: MemAllocationHandleType,
    /// Location the pool's memory resides on (`location`).
    pub location: MemoryLocation,
    /// Maximum pool size in bytes (`maxSize`); per the CUDA docs, 0 selects a
    /// system-dependent default.
    pub max_size: usize,
    /// Raw `usage` bitmask forwarded to the driver unchanged.
    pub usage: u16,
}
/// Owning wrapper around a CUDA driver memory pool (`CUmemoryPool`).
///
/// The pool is destroyed with `cuMemPoolDestroy` when the wrapper is dropped.
#[derive(Debug)]
pub struct MemoryPool {
    /// Raw driver handle; `MemoryPool::create` guarantees it is non-null.
    handle: driver::CUmemoryPool,
}
impl From<MemAccessDescriptor> for driver::CUmemAccessDesc {
fn from(value: MemAccessDescriptor) -> Self {
Self {
location: value.location.into(),
flags: value.flags.into(),
}
}
}
impl From<MemoryPoolProps> for driver::CUmemPoolProps {
fn from(value: MemoryPoolProps) -> Self {
Self {
allocType: value.alloc_type.into(),
handleTypes: value.handle_type.into(),
location: value.location.into(),
win32SecurityAttributes: ptr::null_mut(),
maxSize: value.max_size as _,
usage: value.usage,
reserved: [0; 54],
}
}
}
bitflags::bitflags! {
    /// Flags for page-locked host allocations made with `cudaHostAlloc`
    /// (see `DeviceMemory::alloc_pinned`).
    #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
    pub struct HostAllocationFlags: u32 {
        /// `cudaHostAllocDefault`.
        const DEFAULT = runtime::cudaHostAllocDefault;
        /// `cudaHostAllocPortable`.
        const PORTABLE = runtime::cudaHostAllocPortable;
        /// `cudaHostAllocMapped`.
        const MAPPED = runtime::cudaHostAllocMapped;
        /// `cudaHostAllocWriteCombined`.
        const WRITE_COMBINED = runtime::cudaHostAllocWriteCombined;
    }
}
bitflags::bitflags! {
    /// Flags for registering existing host memory with `cudaHostRegister`
    /// (see `DeviceMemory::register_host`).
    #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
    pub struct HostRegisterFlags: u32 {
        /// `cudaHostRegisterDefault`.
        const DEFAULT = runtime::cudaHostRegisterDefault;
        /// `cudaHostRegisterPortable`.
        const PORTABLE = runtime::cudaHostRegisterPortable;
        /// `cudaHostRegisterMapped`.
        const MAPPED = runtime::cudaHostRegisterMapped;
        /// `cudaHostRegisterIoMemory`.
        const IO_MEMORY = runtime::cudaHostRegisterIoMemory;
        /// `cudaHostRegisterReadOnly`.
        const READ_ONLY = runtime::cudaHostRegisterReadOnly;
    }
}
/// Kind of memory a pointer refers to, mirroring the runtime's `cudaMemoryType`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, IntoPrimitive, TryFromPrimitive)]
#[repr(u32)]
pub enum MemoryType {
    /// `cudaMemoryTypeUnregistered`.
    Unregistered = runtime::cudaMemoryType::cudaMemoryTypeUnregistered as u32,
    /// `cudaMemoryTypeHost`.
    Host = runtime::cudaMemoryType::cudaMemoryTypeHost as u32,
    /// `cudaMemoryTypeDevice`.
    Device = runtime::cudaMemoryType::cudaMemoryTypeDevice as u32,
    /// `cudaMemoryTypeManaged`.
    Managed = runtime::cudaMemoryType::cudaMemoryTypeManaged as u32,
}
impl_enum_conversion!(runtime::cudaMemoryType, MemoryType);
impl Display for MemoryType {
    /// Writes the name of the matching CUDA runtime `cudaMemoryType` constant.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let constant = match self {
            Self::Unregistered => "cudaMemoryTypeUnregistered",
            Self::Host => "cudaMemoryTypeHost",
            Self::Device => "cudaMemoryTypeDevice",
            Self::Managed => "cudaMemoryTypeManaged",
        };
        f.write_str(constant)
    }
}
/// The subset of `cudaPointerAttributes` exposed by this crate.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct PointerAttributes {
    /// Kind of memory the queried pointer refers to.
    pub memory_type: MemoryType,
    /// Device ordinal reported by CUDA for the allocation.
    pub device: i32,
    /// Device-side address reported by CUDA.
    pub device_pointer: DevicePtr,
    /// Host-side address reported by CUDA; per the CUDA docs, null when the memory is not
    /// host-accessible.
    pub host_pointer: *mut (),
}
impl From<runtime::cudaPointerAttributes> for PointerAttributes {
fn from(attr: runtime::cudaPointerAttributes) -> Self {
Self {
memory_type: attr.type_.into(),
device: attr.device,
device_pointer: DevicePtr::from(attr.devicePointer),
host_pointer: attr.hostPointer.cast(),
}
}
}
/// Kind of memory location, mirroring the driver's `CUmemLocationType`.
///
/// Variant order is significant: `PartialOrd`/`Ord` are derived from it.
#[derive(
    Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, IntoPrimitive, TryFromPrimitive,
)]
#[repr(u32)]
pub enum MemoryLocationKind {
    /// `CU_MEM_LOCATION_TYPE_INVALID`.
    Invalid = driver::CUmemLocationType_enum::CU_MEM_LOCATION_TYPE_INVALID as u32,
    /// `CU_MEM_LOCATION_TYPE_DEVICE`.
    Device = driver::CUmemLocationType_enum::CU_MEM_LOCATION_TYPE_DEVICE as u32,
    /// `CU_MEM_LOCATION_TYPE_HOST`.
    Host = driver::CUmemLocationType_enum::CU_MEM_LOCATION_TYPE_HOST as u32,
    /// `CU_MEM_LOCATION_TYPE_HOST_NUMA`.
    Numa = driver::CUmemLocationType_enum::CU_MEM_LOCATION_TYPE_HOST_NUMA as u32,
    /// `CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT`.
    NumaCurrent = driver::CUmemLocationType_enum::CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT as u32,
    /// `CU_MEM_LOCATION_TYPE_MAX`: upper-bound sentinel from the CUDA headers.
    Max = driver::CUmemLocationType_enum::CU_MEM_LOCATION_TYPE_MAX as u32,
}
impl_enum_conversion!(driver::CUmemLocationType_enum, MemoryLocationKind);
impl Display for MemoryLocationKind {
    /// Writes the name of the matching driver `CUmemLocationType` constant.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let constant = match self {
            Self::Invalid => "CU_MEM_LOCATION_TYPE_INVALID",
            Self::Device => "CU_MEM_LOCATION_TYPE_DEVICE",
            Self::Host => "CU_MEM_LOCATION_TYPE_HOST",
            Self::Numa => "CU_MEM_LOCATION_TYPE_HOST_NUMA",
            Self::NumaCurrent => "CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT",
            Self::Max => "CU_MEM_LOCATION_TYPE_MAX",
        };
        f.write_str(constant)
    }
}
/// A memory location (device, host, host NUMA node, ...) mirroring the driver's
/// `CUmemLocation`; `id` is interpreted according to `kind`.
///
/// Field order is significant: `PartialOrd`/`Ord` compare `kind` first, then `id`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct MemoryLocation {
    /// What kind of location `id` names.
    pub kind: MemoryLocationKind,
    /// Location identifier (e.g. a device ordinal), interpreted according to `kind`.
    pub id: i32,
}
impl From<driver::CUmemLocation_st> for MemoryLocation {
    /// Reads the location type and the `id` member of the anonymous union.
    fn from(raw: driver::CUmemLocation_st) -> Self {
        // SAFETY: the union is read through its `id` member, an `int`; any initialized
        // 32-bit pattern is a valid `i32`.
        let id = unsafe { raw.__bindgen_anon_1.id };
        Self {
            kind: raw.type_.into(),
            id,
        }
    }
}
impl From<MemoryLocation> for driver::CUmemLocation_st {
fn from(m: MemoryLocation) -> Self {
Self {
type_: m.kind.into(),
__bindgen_anon_1: driver::CUmemLocation_st__bindgen_ty_1 { id: m.id as _ },
}
}
}
impl Default for MemoryLocation {
    /// Converts the driver's `CUmemLocation_st::default()` value.
    ///
    /// NOTE(review): presumably that default is zeroed (i.e. `Invalid`, id 0); confirm
    /// against the bindgen-generated `Default`.
    fn default() -> Self {
        Self::from(driver::CUmemLocation_st::default())
    }
}
impl MemoryPool {
pub fn create(props: MemoryPoolProps) -> Result<Self> {
let mut handle = ptr::null_mut();
let props = driver::CUmemPoolProps::from(props);
unsafe {
try_cuda!(driver::cuMemPoolCreate(&raw mut handle, &raw const props))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(Self { handle })
}
pub fn set_attribute(
&mut self,
attribute: MemoryPoolAttribute,
value: MemoryPoolAttributeValue,
) -> Result<()> {
unsafe {
match (attribute, value) {
(
MemoryPoolAttribute::ReuseFollowEventDependencies
| MemoryPoolAttribute::ReuseAllowOpportunistic
| MemoryPoolAttribute::ReuseAllowInternalDependencies,
MemoryPoolAttributeValue::Bool(value),
) => {
let mut value = u32::from(value);
try_cuda!(driver::cuMemPoolSetAttribute(
self.handle,
attribute.into(),
ptr::from_mut(&mut value).cast(),
))?;
}
(
MemoryPoolAttribute::ReleaseThreshold
| MemoryPoolAttribute::ReservedMemoryCurrent
| MemoryPoolAttribute::ReservedMemoryHigh
| MemoryPoolAttribute::UsedMemoryCurrent
| MemoryPoolAttribute::UsedMemoryHigh,
MemoryPoolAttributeValue::Bytes(value),
) => {
let mut value = value;
try_cuda!(driver::cuMemPoolSetAttribute(
self.handle,
attribute.into(),
ptr::from_mut(&mut value).cast(),
))?;
}
_ => return Err(Error::InvalidValue),
}
}
Ok(())
}
pub fn attribute(&self, attribute: MemoryPoolAttribute) -> Result<MemoryPoolAttributeValue> {
unsafe {
match attribute {
MemoryPoolAttribute::ReuseFollowEventDependencies
| MemoryPoolAttribute::ReuseAllowOpportunistic
| MemoryPoolAttribute::ReuseAllowInternalDependencies => {
let mut value = 0u32;
try_cuda!(driver::cuMemPoolGetAttribute(
self.handle,
attribute.into(),
ptr::from_mut(&mut value).cast(),
))?;
Ok(MemoryPoolAttributeValue::Bool(value != 0))
}
MemoryPoolAttribute::ReleaseThreshold
| MemoryPoolAttribute::ReservedMemoryCurrent
| MemoryPoolAttribute::ReservedMemoryHigh
| MemoryPoolAttribute::UsedMemoryCurrent
| MemoryPoolAttribute::UsedMemoryHigh => {
let mut value = 0u64;
try_cuda!(driver::cuMemPoolGetAttribute(
self.handle,
attribute.into(),
ptr::from_mut(&mut value).cast(),
))?;
Ok(MemoryPoolAttributeValue::Bytes(value))
}
}
}
}
pub fn set_access(&mut self, access_descs: &[MemAccessDescriptor]) -> Result<()> {
let access_descs: Vec<_> = access_descs.iter().copied().map(Into::into).collect();
unsafe {
try_cuda!(driver::cuMemPoolSetAccess(
self.handle,
access_descs.as_ptr(),
access_descs.len() as _,
))?;
}
Ok(())
}
pub fn access(&self, location: MemoryLocation) -> Result<MemAccessFlag> {
let mut flags = driver::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_NONE;
let mut location = driver::CUmemLocation_st::from(location);
unsafe {
try_cuda!(driver::cuMemPoolGetAccess(
&raw mut flags,
self.handle,
&raw mut location,
))?;
}
Ok(flags.into())
}
pub fn trim_to(&mut self, min_bytes_to_keep: usize) -> Result<()> {
unsafe {
try_cuda!(driver::cuMemPoolTrimTo(self.handle, min_bytes_to_keep as _))?;
}
Ok(())
}
pub const unsafe fn as_raw(&self) -> driver::CUmemoryPool {
self.handle
}
}
impl Drop for MemoryPool {
    /// Destroys the pool with `cuMemPoolDestroy`.
    ///
    /// `drop` cannot return errors, so failures are only reported on stderr in debug builds.
    fn drop(&mut self) {
        // SAFETY: `self.handle` was produced by `cuMemPoolCreate` in `MemoryPool::create`
        // and is destroyed only here.
        let outcome = unsafe { try_cuda!(driver::cuMemPoolDestroy(self.handle)) };
        if let Err(err) = outcome {
            #[cfg(debug_assertions)]
            eprintln!("failed to destroy cuda memory pool: {err}");
        }
    }
}
/// A typed buffer of `length` elements of `T` in CUDA device memory.
///
/// A null `ptr` means nothing is allocated (empty buffers use this); a non-null `ptr` is
/// released with `cudaFree` when the buffer is dropped.
#[derive(Debug)]
pub struct DeviceMemory<T> {
    /// Start of the device allocation, or null when nothing is allocated.
    ptr: *mut T,
    /// Number of `T` elements (not bytes) in the buffer.
    length: usize,
    /// Records that the buffer logically owns values of type `T`.
    _phantom: PhantomData<T>,
}
impl<T> DeviceMemory<T> {
pub unsafe fn alloc(count: usize) -> Result<*mut T> {
let Some(bytes) = count.checked_mul(size_of::<T>()) else {
return Err(Error::InvalidMemoryAllocationRequest);
};
let mut p = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaMalloc(&raw mut p, bytes as _))?;
}
Ok(p.cast())
}
pub unsafe fn alloc_managed(count: usize, flags: MemoryAttachFlags) -> Result<*mut T> {
let Some(bytes) = count.checked_mul(size_of::<T>()) else {
return Err(Error::InvalidMemoryAllocationRequest);
};
if bytes == 0 {
return Ok(ptr::null_mut());
}
let mut p = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaMallocManaged(
&raw mut p,
bytes as _,
flags.bits(),
))?;
}
Ok(p.cast::<T>())
}
pub unsafe fn free(ptr: *mut T) -> Result<()> {
unsafe {
try_cuda!(runtime::cudaFree(ptr.cast()))?;
}
Ok(())
}
pub unsafe fn copy(
dst: *mut T,
src: *const T,
count: usize,
kind: MemoryCopyKind,
) -> Result<()> {
let Some(bytes) = count.checked_mul(size_of::<T>()) else {
return Err(Error::InvalidMemoryAllocationRequest);
};
unsafe {
try_cuda!(runtime::cudaMemcpy(
dst.cast(),
src.cast(),
bytes as _,
kind.into(),
))?;
}
Ok(())
}
pub unsafe fn set(dst: *mut T, value: u8, count: usize) -> Result<()> {
let Some(bytes) = count.checked_mul(size_of::<T>()) else {
return Err(Error::InvalidMemoryAllocationRequest);
};
unsafe {
try_cuda!(runtime::cudaMemset(dst.cast(), value.into(), bytes as _))?;
}
Ok(())
}
pub unsafe fn alloc_host(size: usize) -> Result<*mut ()> {
let mut ptr = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaMallocHost(
&raw mut ptr,
size as runtime::size_t
))?;
}
Ok(ptr.cast())
}
pub unsafe fn free_host(ptr: *mut ()) -> Result<()> {
unsafe { try_cuda!(runtime::cudaFreeHost(ptr.cast())) }
}
pub unsafe fn alloc_pinned(size: usize, flags: HostAllocationFlags) -> Result<*mut ()> {
let mut ptr = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaHostAlloc(
&raw mut ptr,
size as _,
flags.bits()
))?;
}
Ok(ptr.cast())
}
pub unsafe fn register_host(ptr: *mut (), size: usize, flags: HostRegisterFlags) -> Result<()> {
unsafe {
try_cuda!(runtime::cudaHostRegister(
ptr.cast(),
size as _,
flags.bits()
))?;
}
Ok(())
}
pub unsafe fn unregister_host(ptr: *mut ()) -> Result<()> {
unsafe { try_cuda!(runtime::cudaHostUnregister(ptr.cast())) }
}
pub fn memory_info() -> Result<(usize, usize)> {
let mut free: runtime::size_t = 0;
let mut total: runtime::size_t = 0;
unsafe {
try_cuda!(runtime::cudaMemGetInfo(&raw mut free, &raw mut total))?;
}
Ok((free as usize, total as usize))
}
pub fn pointer_attributes(ptr: *const T) -> Result<PointerAttributes> {
let mut attr_ffi = MaybeUninit::<runtime::cudaPointerAttributes>::uninit();
unsafe {
try_cuda!(runtime::cudaPointerGetAttributes(
attr_ffi.as_mut_ptr(),
ptr.cast(),
))?;
Ok(attr_ffi.assume_init().into())
}
}
pub unsafe fn alloc_async(count: usize, stream: &Stream) -> Result<*mut T> {
let Some(bytes) = count.checked_mul(size_of::<T>()) else {
return Err(Error::InvalidMemoryAllocationRequest);
};
if bytes == 0 {
return Ok(ptr::null_mut());
}
let mut p = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaMallocAsync(
&raw mut p,
bytes as _,
stream.as_raw()
))?;
}
Ok(p.cast::<T>())
}
pub unsafe fn free_async(ptr: *mut T, stream: &Stream) -> Result<()> {
if ptr.is_null() {
return Ok(());
}
unsafe { try_cuda!(runtime::cudaFreeAsync(ptr.cast(), stream.as_raw())) }
}
pub unsafe fn copy_async(
dst: *mut T,
src: *const T,
count: usize,
kind: MemoryCopyKind,
stream: &Stream,
) -> Result<()> {
if count == 0 {
return Ok(());
}
let Some(bytes) = count.checked_mul(size_of::<T>()) else {
return Err(Error::InvalidMemoryAllocationRequest);
};
unsafe {
try_cuda!(runtime::cudaMemcpyAsync(
dst.cast(),
src.cast(),
bytes as _,
kind.into(),
stream.as_raw(),
))?;
}
Ok(())
}
pub unsafe fn set_async(dst: *mut T, value: u8, count: usize, stream: &Stream) -> Result<()> {
if count == 0 {
return Ok(());
}
let Some(bytes) = count.checked_mul(size_of::<T>()) else {
return Err(Error::InvalidMemoryAllocationRequest);
};
unsafe {
try_cuda!(runtime::cudaMemsetAsync(
dst.cast(),
value.into(),
bytes as _,
stream.as_raw(),
))?;
}
Ok(())
}
pub fn prefetch_async(
ptr: DevicePtr,
count: usize,
location: MemoryLocation,
stream: &Stream,
) -> Result<()> {
if count == 0 {
return Ok(());
}
unsafe {
try_cuda!(runtime::cudaMemPrefetchAsync(
ptr.as_ptr() as _,
count as _,
location.into(),
0, stream.as_raw()
))?;
}
Ok(())
}
}
// SAFETY: the buffer only holds a pointer to device memory plus a length; moving the
// wrapper to another thread is as sound as moving the `T` values it logically owns.
unsafe impl<T> Send for DeviceMemory<T> where T: Send {}
// SAFETY: shared access only exposes raw pointers (dereferencing them is `unsafe`) and
// copies that read the buffer, so sharing is sound whenever `T` is `Sync`.
unsafe impl<T> Sync for DeviceMemory<T> where T: Sync {}
impl<T> DeviceMemory<T> {
    /// Wraps an existing device allocation of `length` elements without copying it.
    ///
    /// # Safety
    ///
    /// `ptr` must be null or point to a device allocation holding at least `length`
    /// elements of `T` that may be released with `cudaFree`, because dropping the returned
    /// value frees any non-null pointer. Ownership of the allocation moves into the result.
    pub unsafe fn from_raw_parts(ptr: *mut T, length: usize) -> Self {
        Self {
            ptr,
            length,
            _phantom: PhantomData,
        }
    }

    /// Gives up ownership of the allocation, returning its pointer and element count.
    ///
    /// The memory is not freed; the caller becomes responsible for it (for example by
    /// handing the parts back to [`Self::from_raw_parts`]).
    pub fn into_raw_parts(self) -> (*mut T, usize) {
        let ptr = self.ptr;
        let length = self.length;
        // Skip `Drop`, which would otherwise free the allocation being handed out.
        mem::forget(self);
        (ptr, length)
    }

    /// Allocates an uninitialized device buffer of `length` elements.
    ///
    /// A `length` of zero yields an empty buffer with a null pointer and makes no CUDA call.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidMemoryAllocationRequest`] if `T` is zero-sized and `length`
    /// is non-zero, or if the byte size overflows `usize`; otherwise propagates allocation
    /// errors.
    pub fn create(length: usize) -> Result<Self> {
        if length == 0 {
            return Ok(Self {
                ptr: ptr::null_mut(),
                length: 0,
                _phantom: PhantomData,
            });
        }
        let elem_size = size_of::<T>();
        // Non-empty buffers of a zero-sized `T` are rejected, as are byte-size overflows.
        if elem_size == 0 || length > usize::MAX / elem_size {
            return Err(Error::InvalidMemoryAllocationRequest);
        }
        // SAFETY: the allocation is owned by the returned value and freed in `Drop`.
        let ptr = unsafe { Self::alloc(length)? };
        Ok(Self {
            ptr,
            length,
            _phantom: PhantomData,
        })
    }

    /// Allocates a buffer of `length` elements with every byte set to zero.
    pub fn zeroes(length: usize) -> Result<Self> {
        let mut mem = Self::create(length)?;
        mem.set_zeroes()?;
        Ok(mem)
    }

    /// Allocates a buffer and synchronously copies `v` into it.
    pub fn from_slice(v: &[T]) -> Result<Self> {
        let mut mem = Self::create(v.len())?;
        mem.copy_from_host(v)?;
        Ok(mem)
    }

    /// Allocates a buffer and enqueues an asynchronous copy of `v` into it on `stream`.
    ///
    /// # Safety
    ///
    /// `v` must stay alive and unmodified until the copy enqueued on `stream` has completed.
    pub unsafe fn from_slice_async(v: &[T], stream: &Stream) -> Result<Self> {
        let mut mem = Self::create(v.len())?;
        unsafe {
            mem.copy_from_host_async_unchecked(v, stream)?;
        }
        Ok(mem)
    }

    /// Number of elements in the buffer.
    pub const fn len(&self) -> usize {
        self.length
    }

    /// Returns `true` when the buffer holds no elements.
    pub const fn is_empty(&self) -> bool {
        self.length == 0
    }

    /// Size of the buffer in bytes, saturating at `usize::MAX`.
    pub const fn size(&self) -> usize {
        self.length.saturating_mul(size_of::<T>())
    }

    /// Raw device pointer to the first element (null for empty buffers).
    pub const fn as_ptr(&self) -> *const T {
        self.ptr
    }

    /// Raw mutable device pointer to the first element (null for empty buffers).
    pub const fn as_mut_ptr(&self) -> *mut T {
        self.ptr
    }

    /// Synchronously copies `host_slice` into the buffer.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidMemoryAccess`] if the lengths differ, or the copy's CUDA error.
    pub fn copy_from_host(&mut self, host_slice: &[T]) -> Result<()> {
        if host_slice.len() != self.length {
            return Err(Error::InvalidMemoryAccess);
        }
        if self.length == 0 {
            return Ok(());
        }
        unsafe {
            Self::copy(
                self.ptr,
                host_slice.as_ptr(),
                self.length,
                MemoryCopyKind::HostToDevice,
            )
        }
    }

    /// Enqueues a copy of `host_slice` into the buffer on the scope's stream.
    ///
    /// The `'env` borrow ties `host_slice` to the stream scope, which is what allows this
    /// to be safe (presumably the scope waits for the stream before `'env` ends).
    pub fn copy_from_host_async<'scope, 'env>(
        &mut self,
        host_slice: &'env [T],
        stream: &StreamScope<'scope, 'env>,
    ) -> Result<()> {
        unsafe { self.copy_from_host_async_unchecked(host_slice, stream.stream()) }
    }

    /// Enqueues a copy of `host_slice` into the buffer on `stream`.
    ///
    /// # Safety
    ///
    /// `host_slice` must stay alive and unmodified until the copy has completed on `stream`.
    pub unsafe fn copy_from_host_async_unchecked(
        &mut self,
        host_slice: &[T],
        stream: &Stream,
    ) -> Result<()> {
        if host_slice.len() != self.len() {
            return Err(Error::InvalidMemoryAccess);
        }
        if self.is_empty() {
            return Ok(());
        }
        unsafe {
            Self::copy_async(
                self.as_mut_ptr(),
                host_slice.as_ptr(),
                self.len(),
                MemoryCopyKind::HostToDevice,
                stream,
            )
        }
    }

    /// Synchronously copies the buffer into `host_slice`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidMemoryAccess`] if the lengths differ, or the copy's CUDA error.
    pub fn copy_to_host(&self, host_slice: &mut [T]) -> Result<()> {
        if host_slice.len() != self.length {
            return Err(Error::InvalidMemoryAccess);
        }
        if self.length == 0 {
            return Ok(());
        }
        unsafe {
            Self::copy(
                host_slice.as_mut_ptr(),
                self.ptr,
                self.length,
                MemoryCopyKind::DeviceToHost,
            )
        }
    }

    /// Enqueues a copy of the buffer into `host_slice` on the scope's stream.
    pub fn copy_to_host_async<'scope, 'env>(
        &self,
        host_slice: &'env mut [T],
        stream: &StreamScope<'scope, 'env>,
    ) -> Result<()> {
        unsafe { self.copy_to_host_async_unchecked(host_slice, stream.stream()) }
    }

    /// Enqueues a copy of the buffer into `host_slice` on `stream`.
    ///
    /// # Safety
    ///
    /// `host_slice` must stay alive and must not be accessed until the copy has completed
    /// on `stream`.
    pub unsafe fn copy_to_host_async_unchecked(
        &self,
        host_slice: &mut [T],
        stream: &Stream,
    ) -> Result<()> {
        if host_slice.len() != self.len() {
            return Err(Error::InvalidMemoryAccess);
        }
        if self.is_empty() {
            return Ok(());
        }
        unsafe {
            Self::copy_async(
                host_slice.as_mut_ptr(),
                self.as_ptr(),
                self.len(),
                MemoryCopyKind::DeviceToHost,
                stream,
            )
        }
    }

    /// Synchronously copies the whole buffer into a newly allocated `Vec`.
    ///
    /// An empty buffer yields an empty `Vec` for every `T`, including zero-sized types.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidMemoryAllocationRequest`] for a non-empty buffer of a
    /// zero-sized `T`, or the copy's CUDA error.
    pub fn copy_to_host_vec(&self) -> Result<Vec<T>> {
        // Check emptiness before the zero-sized-type guard so that an empty buffer created
        // by `create(0)` (which accepts zero-sized `T`) round-trips to an empty `Vec`.
        if self.length == 0 {
            return Ok(Vec::new());
        }
        if size_of::<T>() == 0 {
            return Err(Error::InvalidMemoryAllocationRequest);
        }
        let mut host_vec = Vec::<T>::with_capacity(self.length);
        // SAFETY: `host_vec` has capacity for `self.length` elements and the copy writes all
        // of them before `set_len` exposes them; the device bytes are assumed to be valid
        // `T` values.
        unsafe {
            Self::copy(
                host_vec.as_mut_ptr(),
                self.ptr,
                self.length,
                MemoryCopyKind::DeviceToHost,
            )?;
            host_vec.set_len(self.length);
        }
        Ok(host_vec)
    }

    /// Synchronously copies `src` into this buffer (device to device).
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidMemoryAccess`] if the lengths differ, or the copy's CUDA error.
    pub fn copy_from_device(&mut self, src: &Self) -> Result<()> {
        if src.len() != self.length {
            return Err(Error::InvalidMemoryAccess);
        }
        if self.length == 0 {
            return Ok(());
        }
        unsafe {
            Self::copy(
                self.ptr,
                src.as_ptr(),
                self.length,
                MemoryCopyKind::DeviceToDevice,
            )
        }
    }

    /// Enqueues a device-to-device copy from `src` on the scope's stream.
    pub fn copy_from_device_async<'scope, 'env>(
        &mut self,
        src: &Self,
        stream: &StreamScope<'scope, 'env>,
    ) -> Result<()> {
        unsafe { self.copy_from_device_async_unchecked(src, stream.stream()) }
    }

    /// Enqueues a device-to-device copy from `src` on `stream`.
    ///
    /// # Safety
    ///
    /// Both buffers must stay alive until the copy has completed on `stream`.
    pub unsafe fn copy_from_device_async_unchecked(
        &mut self,
        src: &Self,
        stream: &Stream,
    ) -> Result<()> {
        if src.len() != self.len() {
            return Err(Error::InvalidMemoryAccess);
        }
        if self.is_empty() {
            return Ok(());
        }
        unsafe {
            Self::copy_async(
                self.as_mut_ptr(),
                src.as_ptr(),
                self.len(),
                MemoryCopyKind::DeviceToDevice,
                stream,
            )
        }
    }

    /// Synchronously sets every byte of the buffer to zero.
    pub fn set_zeroes(&mut self) -> Result<()> {
        if self.length == 0 {
            return Ok(());
        }
        unsafe { Self::set(self.ptr, 0, self.length) }
    }

    /// Synchronously sets every byte (not element) of the buffer to `value`.
    pub fn set_value(&mut self, value: u8) -> Result<()> {
        if self.length == 0 {
            return Ok(());
        }
        unsafe { Self::set(self.ptr, value, self.length) }
    }

    /// Enqueues a byte-wise fill of the buffer with `value` on the scope's stream.
    pub fn set_value_async<'scope, 'env>(
        &mut self,
        value: u8,
        stream: &StreamScope<'scope, 'env>,
    ) -> Result<()> {
        unsafe { self.set_value_async_unchecked(value, stream.stream()) }
    }

    /// Enqueues a byte-wise fill of the buffer with `value` on `stream`.
    ///
    /// # Safety
    ///
    /// The buffer must stay alive until the fill has completed on `stream`.
    pub unsafe fn set_value_async_unchecked(&mut self, value: u8, stream: &Stream) -> Result<()> {
        if self.is_empty() {
            return Ok(());
        }
        unsafe { Self::set_async(self.as_mut_ptr(), value, self.len(), stream) }
    }

    /// Creates an inter-process handle for the allocation with `cudaIpcGetMemHandle`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidMemoryAccess`] for an empty buffer, or the CUDA error.
    pub fn ipc_handle(&self) -> Result<IpcMemoryHandle> {
        if self.is_empty() {
            return Err(Error::InvalidMemoryAccess);
        }
        let mut handle = MaybeUninit::uninit();
        // SAFETY: the handle is only read after CUDA reported success.
        unsafe {
            try_cuda!(runtime::cudaIpcGetMemHandle(
                handle.as_mut_ptr(),
                self.as_ptr().cast_mut().cast(),
            ))?;
            Ok(IpcMemoryHandle::from_raw(handle.assume_init()))
        }
    }

    /// Allocates a new buffer and copies this one into it (device to device).
    ///
    /// Empty buffers and buffers of zero-sized `T` are cloned without any allocation,
    /// keeping the original length.
    pub fn try_clone(&self) -> Result<Self> {
        if self.length == 0 || size_of::<T>() == 0 {
            return Ok(Self {
                ptr: ptr::null_mut(),
                length: self.length,
                _phantom: PhantomData,
            });
        }
        let new_mem = Self::create(self.length)?;
        unsafe {
            Self::copy(
                new_mem.as_mut_ptr(),
                self.as_ptr(),
                self.length,
                MemoryCopyKind::DeviceToDevice,
            )?;
        }
        Ok(new_mem)
    }
}
impl<T> Clone for DeviceMemory<T> {
    /// Deep-copies the buffer via [`DeviceMemory::try_clone`].
    ///
    /// `Clone` cannot report errors, so on failure this falls back to an empty buffer (null
    /// pointer, length 0); debug builds log the error to stderr.
    fn clone(&self) -> Self {
        self.try_clone().unwrap_or_else(|err| {
            #[cfg(debug_assertions)]
            eprintln!("device memory clone failed: {err}");
            Self {
                ptr: ptr::null_mut(),
                length: 0,
                _phantom: PhantomData,
            }
        })
    }
}
impl<T> Drop for DeviceMemory<T> {
    /// Frees the device allocation, if any, with `cudaFree`.
    ///
    /// `drop` cannot return errors, so failures are only reported on stderr in debug builds.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        match unsafe { Self::free(self.ptr) } {
            Ok(()) => {
                self.ptr = ptr::null_mut();
                self.length = 0;
            }
            Err(err) => {
                #[cfg(debug_assertions)]
                eprintln!("failed to free device memory: {err}");
            }
        }
    }
}
#[cfg(all(test, feature = "testing"))]
mod tests {
    use super::*;
    use crate::{context::Context, testing};

    /// Round-trips three integers through a raw `DeviceMemory::alloc` allocation.
    ///
    /// Takes the same device lock as the other test in this module, and frees the
    /// allocation even when one of the copies fails so a failing run does not leak.
    #[test]
    fn it_works() -> Result<()> {
        let _lock = testing::device_lock(0)?;
        let host_in = [1, 2, 3];
        let device_ptr = match unsafe { DeviceMemory::alloc(host_in.len()) } {
            Ok(device_ptr) => device_ptr,
            Err(error) if testing::is_stub_library(&error) => return Ok(()),
            Err(error) => return Err(error),
        };
        let mut host_out = [0, 0, 0];
        // SAFETY: `device_ptr` holds `host_in.len()` elements and both host arrays have
        // exactly that many elements.
        let copied = unsafe {
            DeviceMemory::copy(
                device_ptr,
                host_in.as_ptr(),
                host_in.len(),
                MemoryCopyKind::HostToDevice,
            )
        }
        .and_then(|()| unsafe {
            DeviceMemory::copy(
                host_out.as_mut_ptr(),
                device_ptr,
                host_out.len(),
                MemoryCopyKind::DeviceToHost,
            )
        });
        // SAFETY: `device_ptr` came from `DeviceMemory::alloc` and is freed exactly once.
        let freed = unsafe { DeviceMemory::free(device_ptr) };
        copied?;
        freed?;
        assert_eq!(host_out, host_in);
        Ok(())
    }

    #[test]
    fn test_scoped_async_copy_round_trip() -> Result<()> {
        let _lock = testing::device_lock(0)?;
        let ctx = match Context::create() {
            Ok(ctx) => ctx,
            Err(error) if testing::is_stub_library(&error) => return Ok(()),
            Err(error) => return Err(error),
        };
        let stream = ctx.create_stream()?;
        let host_in = [4_i32, 5, 6];
        let mut device = DeviceMemory::create(host_in.len())?;
        let mut host_out = [0_i32; 3];
        stream.scope(|scope| {
            device.copy_from_host_async(&host_in, scope)?;
            device.copy_to_host_async(&mut host_out, scope)
        })?;
        assert_eq!(host_out, host_in);
        Ok(())
    }
}