use std::any::{Any, TypeId};
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
/// At most this many cached buffers are kept per (len, type) bucket.
const MAX_PER_BUCKET: usize = 8;
/// Soft cap on total cached payload bytes per thread (64 MiB).
const MAX_CACHED_BYTES: usize = 64 * 1024 * 1024;
// Process-wide counters, shared across every thread's pool; read via
// `cpu_pool_stats` and cleared via `reset_cpu_pool_stats`.
static POOL_HITS: AtomicUsize = AtomicUsize::new(0);
static POOL_MISSES: AtomicUsize = AtomicUsize::new(0);
static POOL_RETURNS: AtomicUsize = AtomicUsize::new(0);
/// Snapshot of the global pool counters as `(hits, misses, returns)`.
///
/// Counters are process-wide (shared by every thread's pool) and read
/// with relaxed ordering, so the triple is not an atomic snapshot.
pub fn cpu_pool_stats() -> (usize, usize, usize) {
    let hits = POOL_HITS.load(Ordering::Relaxed);
    let misses = POOL_MISSES.load(Ordering::Relaxed);
    let returns = POOL_RETURNS.load(Ordering::Relaxed);
    (hits, misses, returns)
}
/// Zero all three global pool counters (hits, misses, returns).
pub fn reset_cpu_pool_stats() {
    for counter in [&POOL_HITS, &POOL_MISSES, &POOL_RETURNS] {
        counter.store(0, Ordering::Relaxed);
    }
}
/// Bucket key: (element count, element type). Buffers are reused only on an
/// exact (len, type) match — there is no size-class rounding.
type PoolKey = (usize, TypeId);
/// Per-thread pool state (see `CPU_POOL`).
struct CpuPoolState {
    // Free buffers per bucket; each entry is a `Box<Vec<T>>` erased to `dyn Any`.
    free: HashMap<PoolKey, Vec<Box<dyn Any>>>,
    // Total cached payload bytes (len * elem_size) across all buckets,
    // checked against MAX_CACHED_BYTES on return.
    cached_bytes: usize,
}
// One pool per thread: no locking needed, but buffers never migrate between
// threads and each thread's cache counts against MAX_CACHED_BYTES separately.
thread_local! {
    static CPU_POOL: RefCell<CpuPoolState> = RefCell::new(CpuPoolState {
        free: HashMap::new(),
        cached_bytes: 0,
    });
}
/// Allocate a `Vec<T>` of exactly `len` elements, all set to `T::default()`.
///
/// Tries to recycle a buffer from the calling thread's pool first (counted
/// as a hit, and re-filled with defaults before returning); otherwise
/// allocates fresh (counted as a miss). `len == 0` returns an empty vector
/// without touching the pool or the counters.
#[inline]
pub fn pool_alloc_cpu<T: Default + Clone + 'static>(len: usize) -> Vec<T> {
    if len == 0 {
        return Vec::new();
    }
    let elem_size = std::mem::size_of::<T>();
    let recycled = CPU_POOL.with(|pool| {
        let mut state = pool.borrow_mut();
        // `?` exits with None when the bucket is absent or empty.
        let boxed = state.free.get_mut(&(len, TypeId::of::<T>()))?.pop()?;
        state.cached_bytes -= len * elem_size;
        Some(boxed)
    });
    match recycled {
        Some(boxed) => {
            POOL_HITS.fetch_add(1, Ordering::Relaxed);
            let mut v = *boxed.downcast::<Vec<T>>().unwrap();
            debug_assert_eq!(v.len(), len);
            // Recycled buffers hold stale data; restore the default-filled contract.
            v.fill(T::default());
            v
        }
        None => {
            POOL_MISSES.fetch_add(1, Ordering::Relaxed);
            vec![T::default(); len]
        }
    }
}
/// Allocate a `Vec<f32>` of exactly `len` elements without clearing.
///
/// A recycled buffer keeps whatever values it held when returned ("uninit"
/// in the sense of unscrubbed, not uninitialized memory); a fresh
/// allocation is zero-filled. `len == 0` returns an empty vector without
/// touching the pool or the counters.
#[inline]
pub fn pool_alloc_cpu_uninit_f32(len: usize) -> Vec<f32> {
    if len == 0 {
        return Vec::new();
    }
    let reused = CPU_POOL.with(|pool| {
        let mut state = pool.borrow_mut();
        let popped = state
            .free
            .get_mut(&(len, TypeId::of::<f32>()))
            .and_then(|bucket| bucket.pop());
        if popped.is_some() {
            state.cached_bytes -= len * std::mem::size_of::<f32>();
        }
        popped
    });
    match reused {
        Some(boxed) => {
            POOL_HITS.fetch_add(1, Ordering::Relaxed);
            let v = *boxed.downcast::<Vec<f32>>().unwrap();
            debug_assert_eq!(v.len(), len);
            v
        }
        None => {
            POOL_MISSES.fetch_add(1, Ordering::Relaxed);
            vec![0.0f32; len]
        }
    }
}
/// Allocate a `Vec<f64>` of exactly `len` elements without clearing.
///
/// A recycled buffer keeps whatever values it held when returned ("uninit"
/// in the sense of unscrubbed, not uninitialized memory); a fresh
/// allocation is zero-filled. `len == 0` returns an empty vector without
/// touching the pool or the counters.
#[inline]
pub fn pool_alloc_cpu_uninit_f64(len: usize) -> Vec<f64> {
    if len == 0 {
        return Vec::new();
    }
    let reused = CPU_POOL.with(|pool| {
        let mut state = pool.borrow_mut();
        let popped = state
            .free
            .get_mut(&(len, TypeId::of::<f64>()))
            .and_then(|bucket| bucket.pop());
        if popped.is_some() {
            state.cached_bytes -= len * std::mem::size_of::<f64>();
        }
        popped
    });
    match reused {
        Some(boxed) => {
            POOL_HITS.fetch_add(1, Ordering::Relaxed);
            let v = *boxed.downcast::<Vec<f64>>().unwrap();
            debug_assert_eq!(v.len(), len);
            v
        }
        None => {
            POOL_MISSES.fetch_add(1, Ordering::Relaxed);
            vec![0.0f64; len]
        }
    }
}
/// Return a buffer to the calling thread's pool for later reuse.
///
/// The vector is cached under its current `(len, type)` key. It is simply
/// dropped (no counter bump) when it is empty, when caching it would push
/// the thread's cached payload past `MAX_CACHED_BYTES`, or when its bucket
/// already holds `MAX_PER_BUCKET` buffers.
///
/// Fix: removed `unsafe { v.set_len(len) }` — `len` was read from `v.len()`
/// two lines earlier, so the call was a guaranteed no-op, and the `unsafe`
/// block carried no SAFETY justification.
pub fn pool_return_cpu<T: 'static>(v: Vec<T>) {
    let len = v.len();
    if len == 0 {
        return;
    }
    // NOTE(review): accounting is by payload (len * elem_size); any spare
    // capacity beyond `len` is uncounted, so real cached memory can exceed
    // MAX_CACHED_BYTES — confirm this slack is acceptable.
    let byte_size = len * std::mem::size_of::<T>();
    CPU_POOL.with(|pool| {
        let mut pool = pool.borrow_mut();
        if pool.cached_bytes + byte_size > MAX_CACHED_BYTES {
            return; // over budget: let `v` drop normally
        }
        let bucket = pool.free.entry((len, TypeId::of::<T>())).or_default();
        if bucket.len() >= MAX_PER_BUCKET {
            return; // bucket full: drop instead of growing the cache
        }
        bucket.push(Box::new(v));
        pool.cached_bytes += byte_size;
        POOL_RETURNS.fetch_add(1, Ordering::Relaxed);
    });
}
/// Drop every buffer cached by the calling thread and reset its byte count.
/// Other threads' pools and the global counters are untouched.
pub fn empty_cpu_pool() {
    CPU_POOL.with(|pool| {
        let mut state = pool.borrow_mut();
        state.free.clear();
        state.cached_bytes = 0;
    });
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: each #[test] runs on its own thread, so the thread-local pool is
    // fresh per test. The POOL_* counters are global, however, so assertions
    // on them are kept loose (>=) to tolerate concurrent tests.

    #[test]
    fn test_pool_miss_then_hit() {
        reset_cpu_pool_stats();
        let v: Vec<f32> = pool_alloc_cpu(1000);
        assert_eq!(v.len(), 1000);
        assert!(v.iter().all(|&x| x == 0.0));
        pool_return_cpu(v);
        // Same (len, type) key on the same thread must recycle the buffer.
        let v2: Vec<f32> = pool_alloc_cpu(1000);
        assert_eq!(v2.len(), 1000);
        // Hit path must re-fill with defaults.
        assert!(v2.iter().all(|&x| x == 0.0));
        let (hits_after, _, _) = cpu_pool_stats();
        assert!(hits_after > 0, "expected at least 1 pool hit");
        drop(v2);
    }

    #[test]
    fn test_uninit_alloc() {
        let v = pool_alloc_cpu_uninit_f32(500);
        assert_eq!(v.len(), 500);
        // First allocation on this thread is a miss, so it is zero-filled.
        assert!(v.iter().all(|&x| x == 0.0));
        let mut v = v;
        v[0] = 42.0;
        pool_return_cpu(v);
        // Recycled uninit buffers may carry stale data; only the length is guaranteed.
        let v2 = pool_alloc_cpu_uninit_f32(500);
        assert_eq!(v2.len(), 500);
    }

    #[test]
    fn test_bucket_limit() {
        // Fix: the original asserted `v.len() == 100` per allocation, which
        // holds for hits AND misses, so the test could never fail. Detect
        // recycling instead via a sentinel through the uninit path, which
        // does not scrub recycled contents.
        empty_cpu_pool();
        for _ in 0..MAX_PER_BUCKET + 5 {
            pool_return_cpu(vec![7.0f32; 100]);
        }
        let mut recycled = 0;
        for _ in 0..MAX_PER_BUCKET + 5 {
            let v = pool_alloc_cpu_uninit_f32(100);
            assert_eq!(v.len(), 100);
            if v[0] == 7.0 {
                recycled += 1; // sentinel survived => buffer came from the pool
            }
        }
        // Exactly MAX_PER_BUCKET returns fit in the bucket; the rest were dropped.
        assert_eq!(recycled, MAX_PER_BUCKET);
    }

    #[test]
    fn test_different_sizes_different_buckets() {
        empty_cpu_pool();
        let v1: Vec<f32> = vec![0.0; 100];
        pool_return_cpu(v1);
        // A different length is a different key: must not reuse the len-100 buffer.
        let v2: Vec<f32> = pool_alloc_cpu(200);
        assert_eq!(v2.len(), 200);
        let v3: Vec<f32> = pool_alloc_cpu(100);
        assert_eq!(v3.len(), 100);
    }
}