use crate::tensor::CudaBuffer;
use std::collections::HashMap;
use std::sync::{Arc, LazyLock, Mutex};
/// Cache of idle CUDA buffers, bucketed by the exact value of `CudaBuffer::size()`,
/// so buffers can be handed out again instead of being allocated anew.
///
/// Each size bucket keeps at most [`CudaBufferPool::MAX_BUFFERS_PER_SIZE`] buffers;
/// further releases of that size are dropped. Hit/miss counters record how often
/// [`CudaBufferPool::acquire`] could be served from the cache.
pub struct CudaBufferPool {
    /// Idle buffers, keyed by `CudaBuffer::size()`. Buffers are pushed by `release`
    /// and popped by `acquire`, so the most recently released one is reused first.
    pub free_buffers: HashMap<usize, Vec<Arc<CudaBuffer>>>,
    /// Number of `acquire` calls that returned a cached buffer.
    pub hits: u64,
    /// Number of `acquire` calls that found no cached buffer of the requested size.
    pub misses: u64,
}

impl CudaBufferPool {
    /// Maximum number of idle buffers retained for any single size.
    pub const MAX_BUFFERS_PER_SIZE: usize = 4;

    /// Creates an empty pool with zeroed statistics.
    pub fn new() -> Self {
        CudaBufferPool {
            free_buffers: HashMap::new(),
            hits: 0,
            misses: 0,
        }
    }

    /// Removes and returns a cached buffer whose `size()` equals `size`.
    ///
    /// Increments `hits` when a buffer is returned and `misses` when the bucket
    /// for `size` is absent or empty (in which case `None` is returned).
    pub fn acquire(&mut self, size: usize) -> Option<Arc<CudaBuffer>> {
        if let Some(buffer) = self.free_buffers.get_mut(&size).and_then(Vec::pop) {
            self.hits += 1;
            return Some(buffer);
        }
        self.misses += 1;
        None
    }

    /// Returns `buffer` to the bucket for `buffer.size()`.
    ///
    /// If that bucket already holds [`Self::MAX_BUFFERS_PER_SIZE`] buffers, `buffer`
    /// is dropped here instead of being cached.
    ///
    /// NOTE(review): no check is made that `buffer` is otherwise unreferenced; any
    /// other live clone will alias whatever `acquire` later hands it out to —
    /// confirm callers only release buffers they no longer use.
    pub fn release(&mut self, buffer: Arc<CudaBuffer>) {
        let list = self.free_buffers.entry(buffer.size()).or_default();
        if list.len() < Self::MAX_BUFFERS_PER_SIZE {
            list.push(buffer);
        }
    }

    /// Total number of idle buffers across all size buckets.
    pub fn free_count(&self) -> usize {
        self.free_buffers.values().map(Vec::len).sum()
    }

    /// Fraction of `acquire` calls served from the cache, in `[0.0, 1.0]`;
    /// `0.0` when `acquire` has never been called.
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }

    /// Writes the human-readable statistics report to `out`.
    fn write_stats<W: std::io::Write>(&self, out: &mut W) -> std::io::Result<()> {
        writeln!(out, "=== CudaBufferPool Stats ===")?;
        writeln!(out, "Hits: {}, Misses: {}", self.hits, self.misses)?;
        writeln!(out, "Hit rate: {:.2}%", self.hit_rate() * 100.0)?;
        writeln!(out, "Free buffers: {}", self.free_count())?;
        writeln!(out, "============================")?;
        Ok(())
    }

    /// Prints the pool statistics to stderr.
    pub fn dump_stats(&self) {
        // Hold the stderr lock for the whole report so output from other threads
        // cannot interleave with it. Write errors are ignored on purpose: this is
        // best-effort diagnostics and must not panic the caller.
        let _ = self.write_stats(&mut std::io::stderr().lock());
    }
}

impl Default for CudaBufferPool {
    fn default() -> Self {
        Self::new()
    }
}
/// Process-wide buffer pool used by [`pool_acquire`] and [`pool_release`].
///
/// Initialized on first access to an empty pool; all access goes through the `Mutex`.
pub static BUFFER_POOL: LazyLock<Mutex<CudaBufferPool>> =
    LazyLock::new(|| Mutex::new(CudaBufferPool::default()));
/// Takes a cached buffer whose `size()` equals `size` from the global [`BUFFER_POOL`].
///
/// Returns `None` when no buffer of that size is cached (counted as a pool miss).
///
/// `crate::stream::sync_stream()` is called first, before the pool lock is taken.
/// NOTE(review): presumably this guarantees that queued stream work has finished
/// with any buffer previously released to the pool — confirm before reordering.
///
/// A poisoned pool mutex is recovered instead of disabling the pool: every pool
/// mutation is a single push, pop, or counter increment, so a panic elsewhere while
/// the lock was held cannot leave the pool half-updated.
pub fn pool_acquire(size: usize) -> Option<Arc<CudaBuffer>> {
    crate::stream::sync_stream();
    BUFFER_POOL
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .acquire(size)
}
/// Returns `buffer` to the global [`BUFFER_POOL`].
///
/// The buffer is dropped instead if its size bucket is already full.
///
/// A poisoned pool mutex is recovered rather than silently discarding every future
/// release: the pool's state is only ever changed by single push/pop/increment
/// steps, so it remains consistent after a panic while the lock was held.
pub fn pool_release(buffer: Arc<CudaBuffer>) {
    BUFFER_POOL
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .release(buffer);
}