#[cfg(feature = "cuda")]
use cudarc::driver::CudaSlice;
#[cfg(feature = "cuda")]
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use std::sync::Mutex;
#[cfg(feature = "cuda")]
use std::sync::OnceLock;
#[cfg(feature = "cuda")]
/// A single cached device allocation held in the pool's free lists.
struct PooledBlock {
    // Raw CUDA device pointer, obtained by leaking a `CudaSlice<f32>`.
    ptr: u64,
    // Allocation size in f32 elements (not bytes).
    capacity: usize,
}
#[cfg(feature = "cuda")]
/// Mutex-protected pool state: per-bucket free lists plus usage counters.
struct MemoryPoolInner {
    // Cached blocks keyed by bucket size (in elements); see `bucket_size`.
    free_lists: HashMap<usize, Vec<PooledBlock>>,
    // Total bytes currently cached across all free lists (capacity * 4 each).
    pooled_bytes: usize,
    // Allocations served from the pool.
    hits: usize,
    // Allocations that fell through to a fresh CUDA allocation.
    misses: usize,
    // Blocks handed back via `release`.
    returns: usize,
}
#[cfg(feature = "cuda")]
/// Thread-safe pool of reusable CUDA device buffers for f32 data.
///
/// All state lives behind a single `Mutex`; access it through the
/// process-wide instance returned by `get_memory_pool`.
pub struct CudaMemoryPool {
    inner: Mutex<MemoryPoolInner>,
}
#[cfg(feature = "cuda")]
// Lazily-initialized, process-wide pool instance (see `get_memory_pool`).
static CUDA_MEMORY_POOL: OnceLock<CudaMemoryPool> = OnceLock::new();
#[cfg(feature = "cuda")]
impl CudaMemoryPool {
    /// Creates an empty pool with zeroed statistics.
    fn new() -> Self {
        Self {
            inner: Mutex::new(MemoryPoolInner {
                free_lists: HashMap::new(),
                pooled_bytes: 0,
                hits: 0,
                misses: 0,
                returns: 0,
            }),
        }
    }

    /// Rounds `requested` (in f32 elements) up to an allocation bucket:
    /// multiples of 64 up to 256 elements, powers of two above that.
    /// `bucket_size(0)` is 0.
    fn bucket_size(requested: usize) -> usize {
        if requested <= 256 {
            // Round up to the next multiple of 64.
            ((requested + 63) / 64) * 64
        } else {
            requested.next_power_of_two()
        }
    }

    /// Tries to pop a cached block for `requested_elements`.
    ///
    /// Returns `(device_ptr, capacity_in_elements)` on a hit; records a miss
    /// and returns `None` otherwise.
    fn try_acquire(&self, requested_elements: usize) -> Option<(u64, usize)> {
        let bucket = Self::bucket_size(requested_elements);
        let mut inner = self.inner.lock().unwrap();
        if let Some(blocks) = inner.free_lists.get_mut(&bucket) {
            if let Some(block) = blocks.pop() {
                // f32 elements -> bytes.
                inner.pooled_bytes -= block.capacity * 4;
                inner.hits += 1;
                return Some((block.ptr, block.capacity));
            }
        }
        inner.misses += 1;
        None
    }

    /// Returns a device block to the pool.
    ///
    /// Only blocks whose capacity is *exactly* a bucket size are cached:
    /// `bucket_size` rounds up, so pooling a smaller block (e.g. capacity 300
    /// into the 512 bucket) would let `try_acquire` later hand out an
    /// under-sized allocation. Non-bucket-sized blocks and blocks that would
    /// overflow the per-bucket cap (64) are freed immediately instead.
    fn release(&self, ptr: u64, capacity: usize) {
        let bucket = Self::bucket_size(capacity);
        let mut inner = self.inner.lock().unwrap();
        inner.returns += 1;
        let pooled = bucket == capacity && {
            let blocks = inner.free_lists.entry(bucket).or_default();
            if blocks.len() < 64 {
                blocks.push(PooledBlock { ptr, capacity });
                true
            } else {
                false
            }
        };
        if pooled {
            inner.pooled_bytes += capacity * 4;
        } else if let Some(backend) = super::cuda::get_cuda_backend() {
            // SAFETY: `ptr`/`capacity` describe a device allocation that was
            // leaked from a `CudaSlice<f32>`; rebuilding the slice and
            // dropping it frees that allocation exactly once.
            unsafe {
                let slice: CudaSlice<f32> = backend.stream().upgrade_device_ptr(ptr, capacity);
                drop(slice);
            }
        }
        // NOTE(review): if no backend is available the pointer is leaked —
        // same as the original behavior; there is no way to free it here.
    }

    /// Returns `(hits, misses, returns, pooled_bytes)`.
    pub fn stats(&self) -> (usize, usize, usize, usize) {
        let inner = self.inner.lock().unwrap();
        (inner.hits, inner.misses, inner.returns, inner.pooled_bytes)
    }

    /// Frees every cached block and resets the pooled-byte counter.
    /// Hit/miss/return counters are intentionally preserved.
    pub fn clear(&self) {
        let mut inner = self.inner.lock().unwrap();
        let backend = super::cuda::get_cuda_backend();
        for (_bucket, blocks) in inner.free_lists.drain() {
            for block in blocks {
                if let Some(ref be) = backend {
                    // SAFETY: each pooled pointer came from a leaked
                    // `CudaSlice<f32>` of exactly `block.capacity` elements.
                    unsafe {
                        let slice: CudaSlice<f32> =
                            be.stream().upgrade_device_ptr(block.ptr, block.capacity);
                        drop(slice);
                    }
                }
            }
        }
        inner.pooled_bytes = 0;
    }
}
#[cfg(feature = "cuda")]
/// Returns the process-wide CUDA memory pool, initializing it on first use.
pub fn get_memory_pool() -> &'static CudaMemoryPool {
    CUDA_MEMORY_POOL.get_or_init(|| CudaMemoryPool::new())
}
#[cfg(feature = "cuda")]
/// Allocates a zero-initialized `CudaSlice<f32>` of at least `len` elements,
/// reusing a pooled block when one is available.
///
/// The returned slice's length is the bucket capacity, which may exceed `len`.
///
/// # Errors
/// Returns `CudaError::DeviceNotFound` when no CUDA backend is available, or
/// a propagated driver error from zeroing/allocating device memory.
pub fn pool_alloc(len: usize) -> Result<CudaSlice<f32>, super::cuda::CudaError> {
    // Resolve the backend once, up front. The original looked it up in each
    // branch, and on the hit path only AFTER popping a block from the pool —
    // so a missing backend leaked that block from the pool's accounting.
    let backend =
        super::cuda::get_cuda_backend().ok_or(super::cuda::CudaError::DeviceNotFound)?;
    let pool = get_memory_pool();
    if let Some((ptr, capacity)) = pool.try_acquire(len) {
        // SAFETY: the pointer was leaked from a `CudaSlice<f32>` of exactly
        // `capacity` elements when it entered the pool.
        unsafe {
            let mut slice: CudaSlice<f32> = backend.stream().upgrade_device_ptr(ptr, capacity);
            // Reused memory holds stale data; zero it to match the
            // `alloc_zeros` semantics of the miss path.
            backend
                .stream()
                .memset_zeros(&mut slice)
                .map_err(super::cuda::CudaError::from)?;
            Ok(slice)
        }
    } else {
        // Allocate a full bucket so the block is reusable after `pool_free`.
        let bucket = CudaMemoryPool::bucket_size(len);
        backend
            .stream()
            .alloc_zeros(bucket)
            .map_err(super::cuda::CudaError::from)
    }
}
#[cfg(feature = "cuda")]
/// Hands a slice back to the global pool instead of freeing it.
///
/// NOTE(review): `slice.len()` is treated as the block's capacity, which
/// holds for slices produced by `pool_alloc` (they are sized to a full
/// bucket) — confirm before passing slices from other sources.
pub fn pool_free(slice: CudaSlice<f32>) {
    let capacity = slice.len();
    // `leak` releases ownership without freeing; the pool now owns the ptr.
    let ptr = slice.leak();
    get_memory_pool().release(ptr, capacity);
}
#[cfg(feature = "cuda")]
/// Prints the pool's hit/miss/return counters and cached bytes (as MB)
/// to stderr.
pub fn print_pool_stats() {
    let (hits, misses, returns, pooled) = get_memory_pool().stats();
    let pooled_mb = pooled as f64 / (1024.0 * 1024.0);
    eprintln!(
        "[CudaPool] hits={}, misses={}, returns={}, pooled={:.1}MB",
        hits, misses, returns, pooled_mb
    );
}
#[cfg(feature = "cuda")]
/// Drops every block cached in the global pool, returning the memory to CUDA.
pub fn clear_pool() {
    let pool = get_memory_pool();
    pool.clear();
}
#[cfg(not(feature = "cuda"))]
/// No-op stub so callers compile when the `cuda` feature is disabled.
pub fn print_pool_stats() {}
#[cfg(not(feature = "cuda"))]
/// No-op stub so callers compile when the `cuda` feature is disabled.
pub fn clear_pool() {}
#[cfg(test)]
mod tests {
    #[cfg(feature = "cuda")]
    use super::*;
    // NOTE(review): the CUDA-backed tests below share ONE global pool and
    // Rust runs tests in parallel by default, so cross-test interference is
    // possible (e.g. `clear_pool` racing another test's alloc). Consider
    // `cargo test -- --test-threads=1` for this module if flakes appear.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_bucket_size_small() {
        // <= 256 elements: round up to the next multiple of 64.
        assert_eq!(CudaMemoryPool::bucket_size(1), 64);
        assert_eq!(CudaMemoryPool::bucket_size(63), 64);
        assert_eq!(CudaMemoryPool::bucket_size(64), 64);
        assert_eq!(CudaMemoryPool::bucket_size(65), 128);
        assert_eq!(CudaMemoryPool::bucket_size(128), 128);
        assert_eq!(CudaMemoryPool::bucket_size(200), 256);
        assert_eq!(CudaMemoryPool::bucket_size(256), 256);
    }
    #[test]
    #[cfg(feature = "cuda")]
    fn test_bucket_size_large() {
        // > 256 elements: round up to the next power of two.
        assert_eq!(CudaMemoryPool::bucket_size(257), 512);
        assert_eq!(CudaMemoryPool::bucket_size(500), 512);
        assert_eq!(CudaMemoryPool::bucket_size(512), 512);
        assert_eq!(CudaMemoryPool::bucket_size(513), 1024);
        assert_eq!(CudaMemoryPool::bucket_size(1_000_000), 1_048_576);
    }
    #[test]
    #[cfg(feature = "cuda")]
    fn test_bucket_size_zero() {
        assert_eq!(CudaMemoryPool::bucket_size(0), 0);
    }
    #[test]
    #[cfg(feature = "cuda")]
    fn test_pool_alloc_and_free() {
        if super::super::cuda::get_cuda_backend().is_none() {
            return;
        }
        let slice = pool_alloc(1024).expect("pool_alloc failed");
        assert!(slice.len() >= 1024);
        pool_free(slice);
        let pool = get_memory_pool();
        let (hits_before, _, _, _) = pool.stats();
        let slice2 = pool_alloc(1024).expect("second pool_alloc failed");
        let (hits_after, _, _, _) = pool.stats();
        assert!(
            hits_after > hits_before,
            "Expected pool hit on second alloc"
        );
        pool_free(slice2);
    }
    #[test]
    #[cfg(feature = "cuda")]
    fn test_pool_stats() {
        if super::super::cuda::get_cuda_backend().is_none() {
            return;
        }
        let pool = get_memory_pool();
        let (_hits, _misses, _returns, pooled) = pool.stats();
        // The original asserted `hits + misses + returns >= 0`, which is a
        // tautology for usize (and an `unused_comparisons` lint). Instead,
        // check a real invariant: pooled_bytes is a sum of `capacity * 4`
        // terms, so it must be a multiple of the f32 element size.
        assert_eq!(pooled % 4, 0, "pooled_bytes must be a multiple of 4");
    }
    #[test]
    #[cfg(feature = "cuda")]
    fn test_pool_clear() {
        if super::super::cuda::get_cuda_backend().is_none() {
            return;
        }
        let slice = pool_alloc(512).expect("alloc failed");
        pool_free(slice);
        clear_pool();
        let pool = get_memory_pool();
        let (_, _, _, pooled_bytes) = pool.stats();
        assert_eq!(pooled_bytes, 0, "Pool should be empty after clear");
    }
    #[test]
    #[cfg(feature = "cuda")]
    fn test_pool_different_sizes() {
        if super::super::cuda::get_cuda_backend().is_none() {
            return;
        }
        let s1 = pool_alloc(100).expect("alloc 100 failed");
        let s2 = pool_alloc(1000).expect("alloc 1000 failed");
        let s3 = pool_alloc(10000).expect("alloc 10000 failed");
        pool_free(s1);
        pool_free(s2);
        pool_free(s3);
        let pool = get_memory_pool();
        let (hits_before, _, _, _) = pool.stats();
        let s4 = pool_alloc(100).expect("re-alloc 100 failed");
        let (hits_after, _, _, _) = pool.stats();
        assert!(hits_after > hits_before);
        pool_free(s4);
    }
    #[test]
    #[cfg(feature = "cuda")]
    fn test_pool_zeroed_on_reuse() {
        if super::super::cuda::get_cuda_backend().is_none() {
            return;
        }
        let slice = pool_alloc(64).expect("alloc failed");
        pool_free(slice);
        let slice2 = pool_alloc(64).expect("re-alloc failed");
        let host_data = super::super::cuda::get_cuda_backend()
            .unwrap()
            .stream()
            .memcpy_dtoh(&slice2);
        if let Ok(data) = host_data {
            for &val in &data {
                assert_eq!(val, 0.0, "Pool-reused memory should be zeroed");
            }
        }
        pool_free(slice2);
    }
}