use core::ffi::c_void;
use std::sync::Arc;
use baracuda_cuda_sys::types::{
CUmemAccessDesc, CUmemAllocationHandleType, CUmemAllocationType, CUmemLocation,
CUmemLocationType, CUmemPoolProps, CUmemPoolPtrExportData, CUmemPool_attribute,
};
use baracuda_cuda_sys::{driver, CUdeviceptr, CUmemoryPool};
use crate::context::Context;
use crate::device::Device;
use crate::error::{check, Result};
use crate::stream::Stream;
use crate::vmm::AccessFlags;
/// Handle to a CUDA driver memory pool.
///
/// Cloning only clones the inner `Arc`, so every clone refers to the same
/// underlying pool handle.
#[derive(Clone)]
pub struct MemoryPool {
// Shared state; its `Drop` impl runs once the last clone has been dropped.
inner: Arc<MemoryPoolInner>,
}
/// Shared state behind every clone of `MemoryPool`.
struct MemoryPoolInner {
/// Raw driver handle of the pool.
handle: CUmemoryPool,
/// `true` when the pool was created by `MemoryPool::new` (and is therefore
/// destroyed on drop); `false` for handles wrapped by `MemoryPool::from_borrowed`.
owned: bool,
// Not read by any code visible here, hence the lint allowance.
// NOTE(review): presumably retained so the context outlives the pool — confirm.
#[allow(dead_code)]
context: Context,
}
// NOTE(review): these impls assert that `MemoryPoolInner`, which holds a raw
// `CUmemoryPool` pointer, may be moved to and shared between threads. Nothing
// visible in this file establishes why that is sound — confirm against the
// driver's thread-safety guarantees and record the reasoning here as `SAFETY:`.
unsafe impl Send for MemoryPoolInner {}
unsafe impl Sync for MemoryPoolInner {}
impl core::fmt::Debug for MemoryPoolInner {
    /// Formats the pool state under the struct name `MemoryPool`, printing the
    /// `handle` and `owned` fields and marking the output as non-exhaustive
    /// (the stored context is intentionally left out).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { handle, owned, .. } = self;
        let mut builder = f.debug_struct("MemoryPool");
        builder.field("handle", handle);
        builder.field("owned", owned);
        builder.finish_non_exhaustive()
    }
}
impl core::fmt::Debug for MemoryPool {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
self.inner.fmt(f)
}
}
impl MemoryPool {
    /// Creates a new driver memory pool located on `device`.
    ///
    /// `context` is made current first, and a clone of it is stored in the
    /// returned value. The pool is requested with the `PINNED` allocation
    /// type, the `NONE` handle type and a `DEVICE` location whose id is the
    /// inner value of `device.as_raw()`; every other property keeps its
    /// `Default` value. The returned wrapper is marked as owning the pool, so
    /// the pool is destroyed once the last clone is dropped.
    ///
    /// # Errors
    ///
    /// Propagates failures from making the context current, from obtaining
    /// the driver or its pool-creation entry point, and from the creation
    /// call itself.
    pub fn new(context: &Context, device: &Device) -> Result<Self> {
        context.set_current()?;
        let api = driver()?;
        let create_pool = api.cu_mem_pool_create()?;

        let location = CUmemLocation {
            type_: CUmemLocationType::DEVICE,
            id: device.as_raw().0,
        };
        let props = CUmemPoolProps {
            alloc_type: CUmemAllocationType::PINNED,
            handle_types: CUmemAllocationHandleType::NONE,
            location,
            ..Default::default()
        };

        let mut raw_pool: CUmemoryPool = core::ptr::null_mut();
        // SAFETY: both arguments point at locals that stay alive for the whole call.
        let status = unsafe { create_pool(&mut raw_pool, &props) };
        check(status)?;

        let inner = MemoryPoolInner {
            handle: raw_pool,
            owned: true,
            context: context.clone(),
        };
        Ok(Self {
            inner: Arc::new(inner),
        })
    }

    /// Wraps an existing pool handle without taking ownership of it.
    ///
    /// The result is marked as not owned, so dropping it never destroys
    /// `handle`. A clone of `context` is stored alongside the handle.
    ///
    /// # Safety
    ///
    /// `handle` is passed to driver calls by the other methods without any
    /// validation. NOTE(review): callers presumably must supply a valid pool
    /// handle that stays alive for as long as the returned value and its
    /// clones are in use — confirm the exact contract.
    pub unsafe fn from_borrowed(context: &Context, handle: CUmemoryPool) -> Self {
        let inner = MemoryPoolInner {
            handle,
            owned: false,
            context: context.clone(),
        };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Returns the raw driver handle of this pool.
    #[inline]
    pub fn as_raw(&self) -> CUmemoryPool {
        self.inner.handle
    }

    /// Sets the pool's `RELEASE_THRESHOLD` attribute to `bytes`.
    ///
    /// # Errors
    ///
    /// Propagates failures from obtaining the driver or the set-attribute
    /// entry point, and from the driver call itself.
    pub fn set_release_threshold(&self, bytes: u64) -> Result<()> {
        let api = driver()?;
        let set_attribute = api.cu_mem_pool_set_attribute()?;
        let mut threshold: u64 = bytes;
        // SAFETY: the value pointer refers to a live local `u64` for the duration of the call.
        let status = unsafe {
            set_attribute(
                self.as_raw(),
                CUmemPool_attribute::RELEASE_THRESHOLD,
                (&mut threshold as *mut u64).cast::<c_void>(),
            )
        };
        check(status)
    }

    /// Reads the pool's `RELEASE_THRESHOLD` attribute.
    ///
    /// # Errors
    ///
    /// Propagates failures from obtaining the driver or the get-attribute
    /// entry point, and from the driver call itself.
    pub fn release_threshold(&self) -> Result<u64> {
        let api = driver()?;
        let get_attribute = api.cu_mem_pool_get_attribute()?;
        let mut threshold: u64 = 0;
        // SAFETY: the output pointer refers to a live local `u64` for the duration of the call.
        let status = unsafe {
            get_attribute(
                self.as_raw(),
                CUmemPool_attribute::RELEASE_THRESHOLD,
                (&mut threshold as *mut u64).cast::<c_void>(),
            )
        };
        check(status)?;
        Ok(threshold)
    }

    /// Reads the pool's `USED_MEM_CURRENT` attribute.
    ///
    /// # Errors
    ///
    /// Propagates failures from obtaining the driver or the get-attribute
    /// entry point, and from the driver call itself.
    pub fn used_bytes(&self) -> Result<u64> {
        let api = driver()?;
        let get_attribute = api.cu_mem_pool_get_attribute()?;
        let mut used: u64 = 0;
        // SAFETY: the output pointer refers to a live local `u64` for the duration of the call.
        let status = unsafe {
            get_attribute(
                self.as_raw(),
                CUmemPool_attribute::USED_MEM_CURRENT,
                (&mut used as *mut u64).cast::<c_void>(),
            )
        };
        check(status)?;
        Ok(used)
    }

    /// Reads the pool's `RESERVED_MEM_CURRENT` attribute.
    ///
    /// # Errors
    ///
    /// Propagates failures from obtaining the driver or the get-attribute
    /// entry point, and from the driver call itself.
    pub fn reserved_bytes(&self) -> Result<u64> {
        let api = driver()?;
        let get_attribute = api.cu_mem_pool_get_attribute()?;
        let mut reserved: u64 = 0;
        // SAFETY: the output pointer refers to a live local `u64` for the duration of the call.
        let status = unsafe {
            get_attribute(
                self.as_raw(),
                CUmemPool_attribute::RESERVED_MEM_CURRENT,
                (&mut reserved as *mut u64).cast::<c_void>(),
            )
        };
        check(status)?;
        Ok(reserved)
    }

    /// Forwards `min_bytes_to_keep` to the driver's pool-trim entry point for
    /// this pool.
    ///
    /// # Errors
    ///
    /// Propagates failures from obtaining the driver or the trim entry point,
    /// and from the driver call itself.
    pub fn trim_to(&self, min_bytes_to_keep: usize) -> Result<()> {
        let api = driver()?;
        let trim = api.cu_mem_pool_trim_to()?;
        // SAFETY: only the stored pool handle and a plain integer are passed.
        let status = unsafe { trim(self.as_raw(), min_bytes_to_keep) };
        check(status)
    }

    /// Applies `flags` as this pool's access setting for `device`.
    ///
    /// A single access descriptor is submitted, with a `DEVICE` location whose
    /// id is the inner value of `device.as_raw()`.
    ///
    /// # Errors
    ///
    /// Propagates failures from obtaining the driver or the set-access entry
    /// point, and from the driver call itself.
    pub fn set_access(&self, device: &Device, flags: AccessFlags) -> Result<()> {
        let api = driver()?;
        let set_access = api.cu_mem_pool_set_access()?;
        let location = CUmemLocation {
            type_: CUmemLocationType::DEVICE,
            id: device.as_raw().0,
        };
        let desc = CUmemAccessDesc {
            location,
            flags: flags.raw(),
        };
        // SAFETY: `desc` is a live local and the count of 1 matches the single descriptor.
        let status = unsafe { set_access(self.as_raw(), &desc, 1) };
        check(status)
    }

    /// Queries this pool's access setting for `device`.
    ///
    /// The raw flag value reported by the driver is converted with
    /// `AccessFlags::from_raw`.
    ///
    /// # Errors
    ///
    /// Propagates failures from obtaining the driver or the get-access entry
    /// point, and from the driver call itself.
    pub fn access(&self, device: &Device) -> Result<AccessFlags> {
        let api = driver()?;
        let get_access = api.cu_mem_pool_get_access()?;
        let mut loc = CUmemLocation {
            type_: CUmemLocationType::DEVICE,
            id: device.as_raw().0,
        };
        let mut raw_flags: core::ffi::c_int = 0;
        // SAFETY: both pointers refer to live locals for the duration of the call.
        let status = unsafe { get_access(&mut raw_flags, self.as_raw(), &mut loc) };
        check(status)?;
        Ok(AccessFlags::from_raw(raw_flags))
    }

    /// Requests `bytes` bytes from this pool on `stream` and returns the
    /// resulting device pointer.
    ///
    /// The pointer is returned raw; this method attaches no ownership or
    /// lifetime tracking to it.
    ///
    /// # Errors
    ///
    /// Propagates failures from obtaining the driver or the allocation entry
    /// point, and from the driver call itself.
    pub fn alloc_async(&self, bytes: usize, stream: &Stream) -> Result<CUdeviceptr> {
        let api = driver()?;
        let alloc_from_pool = api.cu_mem_alloc_from_pool_async()?;
        let mut device_ptr = CUdeviceptr(0);
        // SAFETY: the output pointer refers to a live local for the duration of the call.
        let status =
            unsafe { alloc_from_pool(&mut device_ptr, bytes, self.as_raw(), stream.as_raw()) };
        check(status)?;
        Ok(device_ptr)
    }

    /// Produces export data for the device pointer `ptr`.
    ///
    /// Only `ptr` is handed to the driver; this pool's own handle is not part
    /// of the call.
    ///
    /// # Errors
    ///
    /// Propagates failures from obtaining the driver or the export entry
    /// point, and from the driver call itself.
    pub fn export_pointer(&self, ptr: CUdeviceptr) -> Result<CUmemPoolPtrExportData> {
        let api = driver()?;
        let export = api.cu_mem_pool_export_pointer()?;
        let mut export_data = CUmemPoolPtrExportData::default();
        // SAFETY: the output pointer refers to a live local for the duration of the call.
        let status = unsafe { export(&mut export_data, ptr) };
        check(status)?;
        Ok(export_data)
    }

    /// Imports a pointer described by `data` into this pool and returns the
    /// resulting device pointer.
    ///
    /// # Errors
    ///
    /// Propagates failures from obtaining the driver or the import entry
    /// point, and from the driver call itself.
    pub fn import_pointer(&self, mut data: CUmemPoolPtrExportData) -> Result<CUdeviceptr> {
        let api = driver()?;
        let import = api.cu_mem_pool_import_pointer()?;
        let mut device_ptr = CUdeviceptr(0);
        // SAFETY: both pointers refer to live locals for the duration of the call.
        let status = unsafe { import(&mut device_ptr, self.as_raw(), &mut data) };
        check(status)?;
        Ok(device_ptr)
    }
}
impl AccessFlags {
    /// Converts a raw driver access-flag value into `AccessFlags`.
    ///
    /// `READ` maps to `Read` and `READWRITE` to `ReadWrite`; every other
    /// value, recognised or not, is reported as `None`.
    #[inline]
    fn from_raw(raw: core::ffi::c_int) -> Self {
        use baracuda_cuda_sys::types::CUmemAccess_flags;

        if raw == CUmemAccess_flags::READ {
            AccessFlags::Read
        } else if raw == CUmemAccess_flags::READWRITE {
            AccessFlags::ReadWrite
        } else {
            AccessFlags::None
        }
    }
}
impl Drop for MemoryPoolInner {
    /// Destroys the pool if this wrapper owns it and the handle is non-null.
    ///
    /// Teardown is best-effort: if the driver or its destroy entry point is
    /// unavailable nothing happens, and the status of the destroy call is
    /// discarded.
    fn drop(&mut self) {
        let should_destroy = self.owned && !self.handle.is_null();
        if should_destroy {
            if let Ok(api) = driver() {
                if let Ok(destroy_pool) = api.cu_mem_pool_destroy() {
                    // SAFETY: the handle is owned by this value and non-null; this is
                    // the only place in this file where it is destroyed.
                    let _ = unsafe { destroy_pool(self.handle) };
                }
            }
        }
    }
}
/// Returns the pool reported by the driver's `cu_device_get_default_mem_pool`
/// entry point for `device`, wrapped as a non-owning [`MemoryPool`].
///
/// `context` is made current before the query, and a clone of it is stored
/// in the returned value. Dropping the result never destroys the pool.
///
/// # Errors
///
/// Propagates failures from making the context current, from obtaining the
/// driver or the entry point, and from the driver call itself.
pub fn default_pool(context: &Context, device: &Device) -> Result<MemoryPool> {
    context.set_current()?;
    let api = driver()?;
    let query_pool = api.cu_device_get_default_mem_pool()?;

    let mut raw_pool: CUmemoryPool = core::ptr::null_mut();
    // SAFETY: the output pointer refers to a live local for the duration of the call.
    let status = unsafe { query_pool(&mut raw_pool, device.as_raw()) };
    check(status)?;

    // SAFETY: the handle was just reported by the driver via a successful call.
    let pool = unsafe { MemoryPool::from_borrowed(context, raw_pool) };
    Ok(pool)
}
/// Returns the pool reported by the driver's `cu_device_get_mem_pool` entry
/// point for `device`, wrapped as a non-owning [`MemoryPool`].
///
/// `context` is made current before the query, and a clone of it is stored
/// in the returned value. Dropping the result never destroys the pool.
///
/// # Errors
///
/// Propagates failures from making the context current, from obtaining the
/// driver or the entry point, and from the driver call itself.
pub fn current_pool(context: &Context, device: &Device) -> Result<MemoryPool> {
    context.set_current()?;
    let api = driver()?;
    let query_pool = api.cu_device_get_mem_pool()?;

    let mut raw_pool: CUmemoryPool = core::ptr::null_mut();
    // SAFETY: the output pointer refers to a live local for the duration of the call.
    let status = unsafe { query_pool(&mut raw_pool, device.as_raw()) };
    check(status)?;

    // SAFETY: the handle was just reported by the driver via a successful call.
    let pool = unsafe { MemoryPool::from_borrowed(context, raw_pool) };
    Ok(pool)
}
/// Passes `pool` to the driver's `cu_device_set_mem_pool` entry point for
/// `device`.
///
/// Unlike `default_pool` and `current_pool`, no context is made current here.
///
/// # Errors
///
/// Propagates failures from obtaining the driver or the entry point, and from
/// the driver call itself.
pub fn set_current_pool(device: &Device, pool: &MemoryPool) -> Result<()> {
    let api = driver()?;
    let assign_pool = api.cu_device_set_mem_pool()?;
    // SAFETY: only the raw device and pool handles held by the wrappers are passed.
    let status = unsafe { assign_pool(device.as_raw(), pool.as_raw()) };
    check(status)
}