use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{LazyLock, Mutex};
/// Pool hit counter, reported as the first element of [`pool_stats`].
static POOL_HITS: AtomicUsize = AtomicUsize::new(0);
/// Pool miss counter, reported as the second element of [`pool_stats`].
static POOL_MISSES: AtomicUsize = AtomicUsize::new(0);
/// Pool return counter, reported as the third element of [`pool_stats`].
static POOL_RETURNS: AtomicUsize = AtomicUsize::new(0);

/// Reads the pool counters as `(hits, misses, returns)`.
///
/// Each counter is loaded independently with `Ordering::Relaxed`. While other threads
/// are updating the counters, the triple is therefore not a single consistent snapshot.
pub fn pool_stats() -> (usize, usize, usize) {
    let read = |counter: &AtomicUsize| counter.load(Ordering::Relaxed);
    let hits = read(&POOL_HITS);
    let misses = read(&POOL_MISSES);
    let returns = read(&POOL_RETURNS);
    (hits, misses, returns)
}

/// Sets all three pool counters back to zero.
pub fn reset_pool_stats() {
    for counter in [&POOL_HITS, &POOL_MISSES, &POOL_RETURNS] {
        counter.store(0, Ordering::Relaxed);
    }
}
type PoolKey = (usize, usize, TypeId);
struct PoolState {
free: HashMap<PoolKey, Vec<Box<dyn Any + Send + Sync>>>,
cached_bytes: usize,
}
impl PoolState {
fn new() -> Self {
Self {
free: HashMap::new(),
cached_bytes: 0,
}
}
}
static POOL: LazyLock<Mutex<PoolState>> = LazyLock::new(|| Mutex::new(PoolState::new()));
pub fn pool_take<T: Any + Send + Sync>(
device_ordinal: usize,
len: usize,
elem_size: usize,
) -> Option<T> {
let key = (device_ordinal, len, TypeId::of::<T>());
let mut pool = POOL.lock().ok()?;
let bucket = pool.free.get_mut(&key)?;
let boxed = bucket.pop()?;
let is_empty = bucket.is_empty();
if is_empty {
pool.free.remove(&key);
}
pool.cached_bytes = pool.cached_bytes.saturating_sub(len * elem_size);
POOL_HITS.fetch_add(1, Ordering::Relaxed);
Some(*boxed.downcast::<T>().expect("pool type mismatch"))
}
pub fn pool_return<T: Any + Send + Sync>(
device_ordinal: usize,
len: usize,
elem_size: usize,
value: T,
) {
let key = (device_ordinal, len, TypeId::of::<T>());
let Ok(mut pool) = POOL.lock() else { return };
POOL_RETURNS.fetch_add(1, Ordering::Relaxed);
pool.cached_bytes += len * elem_size;
pool.free.entry(key).or_default().push(Box::new(value));
}
pub fn empty_cache(device_ordinal: usize) {
let Ok(mut pool) = POOL.lock() else { return };
pool.free.retain(|&(dev, _, _), _| dev != device_ordinal);
let remaining: usize = pool.free.iter()
.map(|((_, len, _), bucket)| len * bucket.len() * 4) .sum();
pool.cached_bytes = remaining;
}
pub fn empty_cache_all() {
let Ok(mut pool) = POOL.lock() else { return };
pool.free.clear();
pool.cached_bytes = 0;
}
pub fn cached_bytes(_device_ordinal: usize) -> usize {
POOL.lock()
.ok()
.map(|p| p.cached_bytes)
.unwrap_or(0)
}