//! Size-bucketed pool for backend GPU buffers. Released buffers are cached per
//! `PoolKey` (backend, rounded size, kind) and reused by later acquisitions,
//! with a global byte budget enforced through LRU eviction.

use std::any::Any;
use std::collections::{BTreeMap, VecDeque};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use std::time::Instant;
/// Bucket key: a buffer is only reused when backend, rounded size, and kind all match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PoolKey {
    /// Identifier of the backend (device) that owns the buffer.
    pub backend_id: u32,
    /// Bucketed allocation size, typically produced by `round_pool_size`.
    pub rounded_size: usize,
    /// Usage category of the buffer.
    pub kind: BufferKind,
}
/// Usage category; buffers of different kinds live in separate buckets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum BufferKind {
    Input,
    Output,
    Scratch,
}
/// Type-erased backend buffer plus the bookkeeping the pool needs.
pub struct PooledBuffer {
    raw: Box<dyn Any + Send>,
    /// Size in bytes charged against the pool budget.
    pub size: usize,
    /// Set on creation and refreshed on release; drives LRU eviction.
    pub last_touched: Instant,
}
impl std::fmt::Debug for PooledBuffer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PooledBuffer")
.field("size", &self.size)
.field("last_touched", &self.last_touched)
.finish_non_exhaustive()
}
}
impl PooledBuffer {
    /// Wraps an already type-erased backend buffer of `size` bytes.
    #[must_use]
    pub fn new(raw: Box<dyn Any + Send>, size: usize) -> Self {
        Self {
            raw,
            size,
            last_touched: Instant::now(),
        }
    }
    /// Recovers the concrete buffer type, handing `self` back unchanged on a
    /// type mismatch so the caller can return it to the pool.
    pub fn downcast<T: Any + Send>(self) -> Result<Box<T>, Self> {
        let Self { raw, size, last_touched } = self;
        match raw.downcast::<T>() {
            Ok(typed) => Ok(typed),
            Err(raw) => Err(Self { raw, size, last_touched }),
        }
    }
    /// Borrows the concrete buffer without consuming the wrapper.
    #[must_use]
    pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
        self.raw.downcast_ref::<T>()
    }
}
/// Size-bucketed buffer pool with a global byte budget and LRU eviction.
/// `Send + Sync` hold automatically: `Mutex` only requires its contents to be
/// `Send`, and `PooledBuffer` stores a `Box<dyn Any + Send>`.
pub struct GpuBufferPool {
    inner: Mutex<BTreeMap<PoolKey, VecDeque<PooledBuffer>>>,
    max_bytes: usize,
    /// Bytes currently cached in the pool; in-flight buffers are not counted.
    current_bytes: AtomicUsize,
}
impl GpuBufferPool {
    /// Creates a pool that caches at most `max_bytes` of released buffers.
    /// A budget of zero disables caching: released buffers are simply dropped.
    #[must_use]
    pub fn new(max_bytes: usize) -> Self {
        Self {
            inner: Mutex::new(BTreeMap::new()),
            max_bytes,
            current_bytes: AtomicUsize::new(0),
        }
    }
    /// Hands out a cached buffer for `key` if one is available; otherwise calls
    /// `allocate` with `key.rounded_size`. Returns `None` only when the pool is
    /// empty for this key and `allocate` fails. Buffers in flight do not count
    /// against the byte budget.
    #[must_use]
    pub fn acquire(
        &self,
        key: PoolKey,
        allocate: impl FnOnce(usize) -> Option<PooledBuffer>,
    ) -> Option<PooledBuffer> {
        {
            // Recover from lock poisoning; the map holds only plain data.
            let mut pool = self.inner.lock().unwrap_or_else(|e| e.into_inner());
            if let Some(queue) = pool.get_mut(&key) {
                if let Some(buf) = queue.pop_front() {
                    self.current_bytes.fetch_sub(buf.size, Ordering::Relaxed);
                    return Some(buf);
                }
            }
        }
        // The lock is dropped before running the caller's allocation closure.
        allocate(key.rounded_size)
    }
    /// Returns a buffer to the pool, evicting least recently used buffers first
    /// if the byte budget would otherwise be exceeded. Buffers larger than the
    /// entire budget (and every buffer when the budget is zero) are dropped.
    pub fn release(&self, key: PoolKey, mut buf: PooledBuffer) {
        buf.last_touched = Instant::now();
        let size = buf.size;
        if self.max_bytes == 0 || size > self.max_bytes {
            return;
        }
        let mut pool = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        let current = self.current_bytes.load(Ordering::Relaxed);
        if current + size > self.max_bytes {
            // Evict just enough to make room for the incoming buffer.
            self.evict_lru_locked(&mut pool, current + size - self.max_bytes);
        }
        pool.entry(key).or_default().push_back(buf);
        self.current_bytes.fetch_add(size, Ordering::Relaxed);
    }
    /// Pops buffers, oldest first, until at least `needed` bytes are freed or
    /// the pool is empty. The caller must already hold the pool lock.
    fn evict_lru_locked(
        &self,
        pool: &mut BTreeMap<PoolKey, VecDeque<PooledBuffer>>,
        needed: usize,
    ) {
        let mut freed = 0usize;
        while freed < needed {
            // Each queue's front is its oldest entry (release pushes to the back),
            // so comparing fronts yields the globally least recently used buffer.
            let oldest_key = pool
                .iter()
                .filter(|(_, q)| !q.is_empty())
                .min_by_key(|(_, q)| q.front().map(|b| b.last_touched))
                .map(|(k, _)| *k);
            match oldest_key {
                Some(key) => {
                    if let Some(q) = pool.get_mut(&key) {
                        if let Some(buf) = q.pop_front() {
                            freed += buf.size;
                            self.current_bytes.fetch_sub(buf.size, Ordering::Relaxed);
                        }
                    }
                }
                None => break,
            }
        }
    }
    /// Drops every cached buffer and resets the byte accounting.
    pub fn clear(&self) {
        let mut pool = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        pool.clear();
        self.current_bytes.store(0, Ordering::Relaxed);
    }
    /// Bytes currently cached in the pool (buffers handed out are excluded).
    #[must_use]
    pub fn current_bytes(&self) -> usize {
        self.current_bytes.load(Ordering::Relaxed)
    }
    /// The configured byte budget.
    #[must_use]
    pub fn max_bytes(&self) -> usize {
        self.max_bytes
    }
}
impl std::fmt::Debug for GpuBufferPool {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GpuBufferPool")
.field("max_bytes", &self.max_bytes)
.field("current_bytes", &self.current_bytes.load(Ordering::Relaxed))
.finish_non_exhaustive()
}
}
/// Rounds a requested size up to a pool bucket: the next power of two for
/// sizes up to 64 MiB, then the next multiple of 64 MiB above that (zero stays
/// zero). For example, 1000 rounds to 1024 and `64 MiB + 1` rounds to 128 MiB.
#[must_use]
pub fn round_pool_size(size: usize) -> usize {
    const STEP_64MIB: usize = 64 * 1024 * 1024;
    if size == 0 {
        return 0;
    }
    if size <= STEP_64MIB {
        size.next_power_of_two()
    } else {
        size.div_ceil(STEP_64MIB) * STEP_64MIB
    }
}
/// Process-wide pool with a 256 MiB budget, initialized lazily on first use.
#[must_use]
pub fn global_pool() -> &'static GpuBufferPool {
    static POOL: std::sync::OnceLock<GpuBufferPool> = std::sync::OnceLock::new();
    POOL.get_or_init(|| GpuBufferPool::new(256 * 1024 * 1024))
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Arc, Barrier};
use std::thread;
struct MockBuffer(Vec<u8>);
fn mock_pooled(size: usize) -> PooledBuffer {
PooledBuffer::new(Box::new(MockBuffer(vec![0u8; size])), size)
}
#[test]
fn pool_reuse_same_size() {
let pool = GpuBufferPool::new(64 * 1024 * 1024);
let key = PoolKey {
backend_id: 0,
rounded_size: 4096,
kind: BufferKind::Scratch,
};
let mut alloc_count = 0usize;
let buf = pool
.acquire(key, |size| {
alloc_count += 1;
Some(mock_pooled(size))
})
.expect("allocation failed");
pool.release(key, buf);
        let _buf2 = pool.acquire(key, |_size| {
            alloc_count += 1;
            None
        });
assert_eq!(alloc_count, 1, "second acquire should reuse from pool");
}
#[test]
fn pool_current_bytes_tracks_release_and_acquire() {
let pool = GpuBufferPool::new(64 * 1024 * 1024);
let key = PoolKey {
backend_id: 0,
rounded_size: 1024,
kind: BufferKind::Input,
};
assert_eq!(pool.current_bytes(), 0);
let buf = pool
.acquire(key, |size| Some(mock_pooled(size)))
.expect("allocation failed");
assert_eq!(pool.current_bytes(), 0, "in-flight buffer not counted");
pool.release(key, buf);
assert_eq!(
pool.current_bytes(),
1024,
"released buffer should be counted"
);
let buf2 = pool.acquire(key, |_| None).expect("should reuse from pool");
assert_eq!(pool.current_bytes(), 0, "acquired buffer leaves pool");
drop(buf2);
}
#[test]
fn pool_lru_eviction() {
let pool = GpuBufferPool::new(1536);
let key_a = PoolKey {
backend_id: 0,
rounded_size: 512,
kind: BufferKind::Input,
};
let key_b = PoolKey {
backend_id: 0,
rounded_size: 512,
kind: BufferKind::Output,
};
let key_c = PoolKey {
backend_id: 0,
rounded_size: 512,
kind: BufferKind::Scratch,
};
pool.release(key_a, mock_pooled(512));
pool.release(key_b, mock_pooled(512));
pool.release(key_c, mock_pooled(512));
assert_eq!(pool.current_bytes(), 1536);
pool.release(key_a, mock_pooled(512));
assert!(
pool.current_bytes() <= 1536,
"pool exceeded budget after release: {} > 1536",
pool.current_bytes()
);
}
#[test]
fn pool_clear_resets_budget() {
let pool = GpuBufferPool::new(64 * 1024 * 1024);
let key = PoolKey {
backend_id: 0,
rounded_size: 4096,
kind: BufferKind::Output,
};
pool.release(key, mock_pooled(4096));
assert_eq!(pool.current_bytes(), 4096);
pool.clear();
assert_eq!(pool.current_bytes(), 0);
let mut alloc_count = 0usize;
let _buf = pool.acquire(key, |size| {
alloc_count += 1;
Some(mock_pooled(size))
});
assert_eq!(alloc_count, 1, "pool should be empty after clear");
}
#[test]
fn pool_thread_safety() {
let pool = Arc::new(GpuBufferPool::new(64 * 1024 * 1024));
let barrier = Arc::new(Barrier::new(8));
let mut handles = vec![];
for i in 0..8_u32 {
let pool = Arc::clone(&pool);
let barrier = Arc::clone(&barrier);
handles.push(thread::spawn(move || {
barrier.wait();
let key = PoolKey {
backend_id: i % 2,
rounded_size: 4096,
kind: BufferKind::Scratch,
};
for _ in 0..100 {
let buf = pool.acquire(key, |size| Some(mock_pooled(size)));
if let Some(b) = buf {
pool.release(key, b);
}
}
}));
}
for h in handles {
h.join().expect("thread panicked");
}
}
#[test]
fn pooled_buffer_downcast_roundtrip() {
let buf = PooledBuffer::new(Box::new(MockBuffer(vec![42u8; 8])), 8);
let inner = buf.downcast::<MockBuffer>().expect("downcast failed");
assert_eq!(inner.0[0], 42);
}
#[test]
fn pooled_buffer_downcast_wrong_type_returns_err() {
let buf = PooledBuffer::new(Box::new(MockBuffer(vec![0u8; 8])), 8);
let result = buf.downcast::<Vec<u8>>();
assert!(result.is_err(), "downcast to wrong type must fail");
}
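    // A minimal sketch of the intended end-to-end flow, built only on this
    // module's API and the test helpers above: acquire, downcast to the
    // concrete type, re-wrap, release, then confirm the next acquire reuses it.
    #[test]
    fn pool_downcast_and_release_roundtrip() {
        let pool = GpuBufferPool::new(64 * 1024 * 1024);
        let key = PoolKey {
            backend_id: 0,
            rounded_size: round_pool_size(4000),
            kind: BufferKind::Scratch,
        };
        let buf = pool
            .acquire(key, |size| Some(mock_pooled(size)))
            .expect("allocation failed");
        // Recover the concrete buffer type to do the actual work with it.
        let typed: Box<MockBuffer> = buf.downcast().expect("type matches");
        assert_eq!(typed.0.len(), 4096, "allocated at the rounded size");
        // Re-wrap the typed buffer and hand it back for reuse.
        pool.release(key, PooledBuffer::new(typed, key.rounded_size));
        assert_eq!(pool.current_bytes(), key.rounded_size);
        let reused = pool.acquire(key, |_| None);
        assert!(reused.is_some(), "released buffer should be reused");
    }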
#[test]
fn round_pool_size_zero() {
assert_eq!(round_pool_size(0), 0);
}
#[test]
fn round_pool_size_power_of_two_passthrough() {
assert_eq!(round_pool_size(1024), 1024);
assert_eq!(round_pool_size(4096), 4096);
}
#[test]
fn round_pool_size_rounds_up_to_next_pow2() {
assert_eq!(round_pool_size(1000), 1024);
assert_eq!(round_pool_size(3000), 4096);
}
#[test]
fn round_pool_size_above_64mib_aligned_to_64mib() {
const MIB64: usize = 64 * 1024 * 1024;
assert_eq!(round_pool_size(MIB64 + 1), 2 * MIB64);
assert_eq!(round_pool_size(2 * MIB64), 2 * MIB64);
assert_eq!(round_pool_size(2 * MIB64 + 1), 3 * MIB64);
}
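    // A small sketch of why sizes are bucketed through `round_pool_size`: two
    // requests with different raw sizes that round to the same bucket share a
    // single cached buffer.
    #[test]
    fn rounded_sizes_share_a_bucket() {
        let pool = GpuBufferPool::new(64 * 1024 * 1024);
        let key_for = |raw: usize| PoolKey {
            backend_id: 0,
            rounded_size: round_pool_size(raw),
            kind: BufferKind::Scratch,
        };
        assert_eq!(key_for(3000), key_for(4096), "both round to the 4096 bucket");
        let buf = pool
            .acquire(key_for(3000), |size| Some(mock_pooled(size)))
            .expect("allocation failed");
        pool.release(key_for(3000), buf);
        // A different raw size that maps to the same bucket reuses the buffer.
        let reused = pool.acquire(key_for(4096), |_| None);
        assert!(reused.is_some(), "same bucket should reuse the cached buffer");
    }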
#[test]
fn pool_lru_eviction_under_budget() {
let pool = GpuBufferPool::new(8192);
let key = PoolKey {
backend_id: 0,
rounded_size: 8192,
kind: BufferKind::Scratch,
};
let buf = pool
.acquire(key, |size| Some(mock_pooled(size)))
.expect("first alloc");
pool.release(key, buf);
assert_eq!(pool.current_bytes(), 8192);
let key2 = PoolKey {
backend_id: 0,
rounded_size: 4096,
kind: BufferKind::Input,
};
let buf2 = pool
.acquire(key2, |size| Some(mock_pooled(size)))
.expect("second alloc");
pool.release(key2, buf2);
assert!(
pool.current_bytes() <= 8192 + 4096,
"pool exceeded budget: {} > {}",
pool.current_bytes(),
8192 + 4096,
);
}
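    // Sketch of the budget guard in `release`: a buffer larger than the entire
    // budget is dropped rather than cached, so the pool never exceeds max_bytes.
    #[test]
    fn pool_drops_buffer_larger_than_budget() {
        let pool = GpuBufferPool::new(1024);
        let key = PoolKey {
            backend_id: 0,
            rounded_size: 4096,
            kind: BufferKind::Scratch,
        };
        pool.release(key, mock_pooled(4096));
        assert_eq!(pool.current_bytes(), 0, "oversized buffer must not be cached");
        let reused = pool.acquire(key, |_| None);
        assert!(reused.is_none(), "nothing should be cached for this key");
    }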
}