#[cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, DeviceSlice};
#[cfg(feature = "cuda")]
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use std::sync::Mutex;
#[cfg(feature = "cuda")]
use std::sync::OnceLock;
#[cfg(feature = "cuda")]
/// A recycled device allocation held in the pool's free lists.
struct PooledBlock {
    /// Raw CUDA device pointer, obtained by leaking a `CudaSlice<f32>`.
    ptr: u64,
    /// Capacity in f32 elements, as originally allocated.
    capacity: usize,
}
#[cfg(feature = "cuda")]
/// Mutex-protected pool state: per-bucket free lists plus usage counters.
struct MemoryPoolInner {
    /// Cached blocks keyed by bucket size (in elements).
    free_lists: HashMap<usize, Vec<PooledBlock>>,
    /// Total bytes currently cached across all free lists.
    pooled_bytes: usize,
    /// Acquisitions served from the pool.
    hits: usize,
    /// Acquisitions that fell through to a fresh driver allocation.
    misses: usize,
    /// Blocks handed back via `release`.
    returns: usize,
}
#[cfg(feature = "cuda")]
/// A size-bucketed recycling pool for CUDA device `f32` buffers.
///
/// Freed buffers are cached per bucket and handed back on later
/// allocations of the same bucket size, avoiding driver round-trips.
pub struct CudaMemoryPool {
    // All state lives behind one mutex; the pool itself is shared as a static.
    inner: Mutex<MemoryPoolInner>,
}
#[cfg(feature = "cuda")]
/// Lazily-initialized process-wide pool instance; access via `get_memory_pool()`.
static CUDA_MEMORY_POOL: OnceLock<CudaMemoryPool> = OnceLock::new();
#[cfg(feature = "cuda")]
impl CudaMemoryPool {
    /// Maximum number of cached blocks kept per bucket; excess returns are
    /// freed straight back to the driver.
    const MAX_BLOCKS_PER_BUCKET: usize = 64;
    /// Bytes per pooled element (the pool only manages f32 buffers).
    const ELEM_BYTES: usize = std::mem::size_of::<f32>();

    /// Creates an empty pool with zeroed statistics.
    fn new() -> Self {
        Self {
            inner: Mutex::new(MemoryPoolInner {
                free_lists: HashMap::new(),
                pooled_bytes: 0,
                hits: 0,
                misses: 0,
                returns: 0,
            }),
        }
    }

    /// Rounds a requested element count up to its bucket size:
    /// multiples of 64 up to 256 elements, powers of two above that.
    fn bucket_size(requested: usize) -> usize {
        if requested <= 256 {
            ((requested + 63) / 64) * 64
        } else {
            requested.next_power_of_two()
        }
    }

    /// Tries to pop a cached block whose bucket matches `requested_elements`.
    /// On a hit, returns the raw device pointer and its capacity in elements.
    fn try_acquire(&self, requested_elements: usize) -> Option<(u64, usize)> {
        let bucket = Self::bucket_size(requested_elements);
        let mut inner = self.inner.lock().unwrap();
        if let Some(blocks) = inner.free_lists.get_mut(&bucket) {
            if let Some(block) = blocks.pop() {
                inner.pooled_bytes -= block.capacity * Self::ELEM_BYTES;
                inner.hits += 1;
                return Some((block.ptr, block.capacity));
            }
        }
        inner.misses += 1;
        None
    }

    /// Returns a device allocation to the pool, or frees it immediately when
    /// it cannot be cached.
    ///
    /// `ptr` must be a pointer leaked from a `CudaSlice<f32>` of exactly
    /// `capacity` elements.
    fn release(&self, ptr: u64, capacity: usize) {
        let bucket = Self::bucket_size(capacity);
        let pooled = {
            let mut inner = self.inner.lock().unwrap();
            inner.returns += 1;
            // Only cache blocks whose capacity exactly equals their bucket:
            // bucket_size() rounds UP, so pooling an undersized block would
            // let a later try_acquire for this bucket hand out less memory
            // than the caller requested.
            let cached = capacity == bucket && {
                let blocks = inner.free_lists.entry(bucket).or_default();
                if blocks.len() < Self::MAX_BLOCKS_PER_BUCKET {
                    blocks.push(PooledBlock { ptr, capacity });
                    true
                } else {
                    false
                }
            };
            if cached {
                inner.pooled_bytes += capacity * Self::ELEM_BYTES;
            }
            cached
            // Lock is dropped here so the driver call below runs unlocked.
        };
        if !pooled {
            if let Some(backend) = super::cuda::get_cuda_backend() {
                // SAFETY: `ptr` was leaked from a CudaSlice<f32> of exactly
                // `capacity` elements; re-wrapping transfers ownership back so
                // dropping the slice frees the device memory.
                unsafe {
                    let slice: CudaSlice<f32> =
                        backend.stream().upgrade_device_ptr(ptr, capacity);
                    drop(slice);
                }
            }
            // NOTE(review): if no backend is available the block is leaked,
            // matching the original behavior — confirm this is acceptable.
        }
    }

    /// Returns `(hits, misses, returns, pooled_bytes)`.
    pub fn stats(&self) -> (usize, usize, usize, usize) {
        let inner = self.inner.lock().unwrap();
        (inner.hits, inner.misses, inner.returns, inner.pooled_bytes)
    }

    /// Frees every cached block and resets the pooled-byte counter.
    /// Hit/miss/return statistics are preserved.
    pub fn clear(&self) {
        // Drain under the lock, then free outside it so driver calls do not
        // block concurrent acquire/release.
        let drained: Vec<PooledBlock> = {
            let mut inner = self.inner.lock().unwrap();
            let blocks: Vec<PooledBlock> =
                inner.free_lists.drain().flat_map(|(_, v)| v).collect();
            inner.pooled_bytes = 0;
            blocks
        };
        if let Some(backend) = super::cuda::get_cuda_backend() {
            for block in drained {
                // SAFETY: same ownership-transfer argument as in `release`.
                unsafe {
                    let slice: CudaSlice<f32> =
                        backend.stream().upgrade_device_ptr(block.ptr, block.capacity);
                    drop(slice);
                }
            }
        }
    }
}
#[cfg(feature = "cuda")]
/// Returns the process-wide CUDA memory pool, creating it on first use.
pub fn get_memory_pool() -> &'static CudaMemoryPool {
    CUDA_MEMORY_POOL.get_or_init(|| CudaMemoryPool::new())
}
#[cfg(feature = "cuda")]
/// Allocates a zeroed device buffer of at least `len` f32 elements,
/// preferring a recycled pool block over a fresh driver allocation.
///
/// The returned slice's length is the bucket capacity, which may exceed
/// `len`.
///
/// # Errors
/// Returns `CudaError::DeviceNotFound` when no CUDA backend is available,
/// or the underlying driver error on allocation/memset failure.
pub fn pool_alloc(len: usize) -> Result<CudaSlice<f32>, super::cuda::CudaError> {
    // Resolve the backend BEFORE touching the pool. The original looked it up
    // after try_acquire, so a hit with no backend popped a block and then
    // dropped its raw pointer — leaking device memory and corrupting the
    // pool's byte accounting.
    let backend =
        super::cuda::get_cuda_backend().ok_or(super::cuda::CudaError::DeviceNotFound)?;
    let pool = get_memory_pool();
    if let Some((ptr, capacity)) = pool.try_acquire(len) {
        // SAFETY: `ptr` was stored by release() and originates from a
        // CudaSlice<f32> of exactly `capacity` elements; upgrading transfers
        // ownership back to the returned slice.
        unsafe {
            let mut slice: CudaSlice<f32> = backend.stream().upgrade_device_ptr(ptr, capacity);
            // Recycled memory holds stale data; zero it to match
            // alloc_zeros semantics on the miss path.
            backend
                .stream()
                .memset_zeros(&mut slice)
                .map_err(super::cuda::CudaError::from)?;
            Ok(slice)
        }
    } else {
        // Miss: allocate a full bucket so the block is reusable on free.
        let bucket = CudaMemoryPool::bucket_size(len);
        backend
            .stream()
            .alloc_zeros(bucket)
            .map_err(super::cuda::CudaError::from)
    }
}
#[cfg(feature = "cuda")]
/// Hands a device buffer back to the pool instead of freeing it.
///
/// NOTE(review): assumes `slice` came from `pool_alloc`, so its length equals
/// its full bucket capacity — TODO confirm no caller passes a sub-slice.
pub fn pool_free(slice: CudaSlice<f32>) {
    let pool = get_memory_pool();
    // The slice length is treated as the block's capacity in elements.
    let capacity = slice.len();
    // Leak the slice so its Drop impl does not free the device memory;
    // ownership of the raw pointer moves into the pool.
    let ptr = slice.leak();
    pool.release(ptr, capacity);
}
#[cfg(feature = "cuda")]
/// Prints the pool's hit/miss/return counters and pooled size (MB) to stderr.
pub fn print_pool_stats() {
    let (hits, misses, returns, pooled_bytes) = get_memory_pool().stats();
    let pooled_mb = pooled_bytes as f64 / (1024.0 * 1024.0);
    eprintln!(
        "[CudaPool] hits={}, misses={}, returns={}, pooled={:.1}MB",
        hits, misses, returns, pooled_mb
    );
}
#[cfg(feature = "cuda")]
/// Releases every block cached in the global pool back to the driver.
pub fn clear_pool() {
    let pool = get_memory_pool();
    pool.clear();
}
#[cfg(not(feature = "cuda"))]
/// No-op stub so callers compile when the `cuda` feature is disabled.
pub fn print_pool_stats() {}
#[cfg(not(feature = "cuda"))]
/// No-op stub so callers compile when the `cuda` feature is disabled.
pub fn clear_pool() {}