use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{LazyLock, Mutex};
use crate::allocator::StreamId;
/// Number of `pool_take` / `pool_take_stream` calls that were served from the pool.
static POOL_HITS: AtomicUsize = AtomicUsize::new(0);
/// Number of `pool_take` / `pool_take_stream` calls that returned `None`.
static POOL_MISSES: AtomicUsize = AtomicUsize::new(0);
/// Number of values stored through `pool_return` / `pool_return_with_stream`.
static POOL_RETURNS: AtomicUsize = AtomicUsize::new(0);

/// Returns the pool counters as `(hits, misses, returns)`.
///
/// Each counter is read with its own relaxed load, so the tuple is not an
/// atomic snapshot when other threads are using the pool concurrently.
pub fn pool_stats() -> (usize, usize, usize) {
    (
        POOL_HITS.load(Ordering::Relaxed),
        POOL_MISSES.load(Ordering::Relaxed),
        POOL_RETURNS.load(Ordering::Relaxed),
    )
}

/// Resets the hit, miss and return counters to zero.
pub fn reset_pool_stats() {
    for counter in [&POOL_HITS, &POOL_MISSES, &POOL_RETURNS] {
        counter.store(0, Ordering::Relaxed);
    }
}

/// Bucket key: `(device ordinal, rounded element count, type of the pooled value)`.
type PoolKey = (usize, usize, TypeId);

/// One pooled value together with the stream bookkeeping used by
/// `pool_take_stream`.
struct CachedEntry {
    /// The pooled value, type-erased; its concrete type is the `TypeId` in the key.
    data: Box<dyn Any + Send + Sync>,
    /// Stream supplied when the value was returned to the pool.
    alloc_stream: StreamId,
    /// Other streams recorded via `record_stream*`. While this is non-empty the
    /// entry is skipped by `pool_take_stream`.
    stream_uses: Vec<StreamId>,
}

/// Everything protected by the global pool mutex.
struct PoolState {
    /// Free entries grouped by key. Buckets are removed as soon as they become
    /// empty, so every bucket present holds at least one entry.
    free: HashMap<PoolKey, Vec<CachedEntry>>,
    /// Bytes currently cached, tracked per device ordinal so that evicting one
    /// device does not disturb the totals of the others.
    cached_bytes: HashMap<usize, usize>,
}

impl PoolState {
    /// Creates an empty pool.
    fn new() -> Self {
        Self {
            free: HashMap::new(),
            cached_bytes: HashMap::new(),
        }
    }

    /// Adds `bytes` to the cached total of `device_ordinal` (saturating).
    fn credit(&mut self, device_ordinal: usize, bytes: usize) {
        let total = self.cached_bytes.entry(device_ordinal).or_default();
        *total = total.saturating_add(bytes);
    }

    /// Subtracts `bytes` from the cached total of `device_ordinal` (saturating).
    fn debit(&mut self, device_ordinal: usize, bytes: usize) {
        if let Some(total) = self.cached_bytes.get_mut(&device_ordinal) {
            *total = total.saturating_sub(bytes);
        }
    }

    /// Removes and returns the entry of bucket `key` whose index is chosen by
    /// `pick`, updating the byte accounting and dropping the bucket if it
    /// becomes empty. Returns `None` when the bucket does not exist or `pick`
    /// selects nothing.
    fn take_entry(
        &mut self,
        key: &PoolKey,
        bytes: usize,
        pick: impl FnOnce(&[CachedEntry]) -> Option<usize>,
    ) -> Option<CachedEntry> {
        let bucket = self.free.get_mut(key)?;
        let index = pick(bucket.as_slice())?;
        let entry = bucket.swap_remove(index);
        if bucket.is_empty() {
            self.free.remove(key);
        }
        self.debit(key.0, bytes);
        Some(entry)
    }

    /// Records `stream` as an additional user of every entry in bucket `key`.
    ///
    /// The entry's own allocation stream is skipped: recording it would make the
    /// entry permanently ineligible for `pool_take_stream` on that very stream.
    /// Streams already recorded are not added twice, so repeated calls do not
    /// grow the list.
    fn mark_stream_use(&mut self, key: &PoolKey, stream: StreamId) {
        let Some(bucket) = self.free.get_mut(key) else {
            return;
        };
        for entry in bucket.iter_mut() {
            if entry.alloc_stream != stream && !entry.stream_uses.contains(&stream) {
                entry.stream_uses.push(stream);
            }
        }
    }
}

/// The process-wide pool, created on first use.
static POOL: LazyLock<Mutex<PoolState>> = LazyLock::new(|| Mutex::new(PoolState::new()));

/// Granularity, in elements, to which requested lengths are rounded.
const ROUND_ELEMENTS: usize = 256;

/// Rounds `len` up to the next multiple of `ROUND_ELEMENTS`.
///
/// `0` and exact multiples are returned unchanged. If rounding up would
/// overflow, the result saturates at `usize::MAX`.
pub fn round_len(len: usize) -> usize {
    match len % ROUND_ELEMENTS {
        0 => len,
        remainder => len.saturating_add(ROUND_ELEMENTS - remainder),
    }
}

/// Byte size used for accounting; saturates instead of overflowing.
fn byte_size(rounded_len: usize, elem_size: usize) -> usize {
    rounded_len.saturating_mul(elem_size)
}

/// Shared implementation of `pool_take` and `pool_take_stream`.
///
/// Removes the entry selected by `pick` from bucket `key`, updates the hit or
/// miss counter, and downcasts the value to `T`. A poisoned pool lock is
/// treated as a miss.
fn take_with<T: Any + Send + Sync>(
    key: PoolKey,
    bytes: usize,
    pick: impl FnOnce(&[CachedEntry]) -> Option<usize>,
) -> Option<T> {
    // The guard lives only inside the closure, so the lock is released before
    // the downcast below; a panic there cannot poison the pool.
    let taken = POOL
        .lock()
        .ok()
        .and_then(|mut pool| pool.take_entry(&key, bytes, pick));
    match taken {
        Some(entry) => {
            POOL_HITS.fetch_add(1, Ordering::Relaxed);
            // The key contains `TypeId::of::<T>()`, so a mismatch is an internal bug.
            Some(*entry.data.downcast::<T>().expect("pool type mismatch"))
        }
        None => {
            POOL_MISSES.fetch_add(1, Ordering::Relaxed);
            None
        }
    }
}

/// Takes a pooled `T` for `(device_ordinal, rounded_len)`, ignoring stream
/// bookkeeping.
///
/// Removes the last entry of the matching bucket. Returns `None` (and counts a
/// miss) when no entry is available. `elem_size` is only used to update the
/// cached-byte total and should match the value passed when the entry was
/// returned.
pub fn pool_take<T: Any + Send + Sync>(
    device_ordinal: usize,
    rounded_len: usize,
    elem_size: usize,
) -> Option<T> {
    let key = (device_ordinal, rounded_len, TypeId::of::<T>());
    take_with(key, byte_size(rounded_len, elem_size), |bucket| {
        bucket.len().checked_sub(1)
    })
}

/// Takes a pooled `T` that was returned on `stream` and has no other recorded
/// stream uses.
///
/// The bucket is searched from the back. Returns `None` (and counts a miss)
/// when no entry qualifies. `elem_size` is only used to update the cached-byte
/// total.
pub fn pool_take_stream<T: Any + Send + Sync>(
    device_ordinal: usize,
    rounded_len: usize,
    elem_size: usize,
    stream: StreamId,
) -> Option<T> {
    let key = (device_ordinal, rounded_len, TypeId::of::<T>());
    take_with(key, byte_size(rounded_len, elem_size), |bucket| {
        bucket
            .iter()
            .rposition(|entry| entry.alloc_stream == stream && entry.stream_uses.is_empty())
    })
}

/// Returns `value` to the pool, tagging it with `StreamId(0)` as its
/// allocation stream.
pub fn pool_return<T: Any + Send + Sync>(
    device_ordinal: usize,
    rounded_len: usize,
    elem_size: usize,
    value: T,
) {
    pool_return_with_stream(device_ordinal, rounded_len, elem_size, value, StreamId(0))
}

/// Returns `value` to the pool, tagged with `alloc_stream`.
///
/// Adds `rounded_len * elem_size` (saturating) to the cached-byte total of
/// `device_ordinal` and increments the return counter. If the pool lock is
/// poisoned the value is dropped instead of being cached.
pub fn pool_return_with_stream<T: Any + Send + Sync>(
    device_ordinal: usize,
    rounded_len: usize,
    elem_size: usize,
    value: T,
    alloc_stream: StreamId,
) {
    let key = (device_ordinal, rounded_len, TypeId::of::<T>());
    let Ok(mut pool) = POOL.lock() else { return };
    pool.credit(device_ordinal, byte_size(rounded_len, elem_size));
    pool.free.entry(key).or_default().push(CachedEntry {
        data: Box::new(value),
        alloc_stream,
        stream_uses: Vec::new(),
    });
    POOL_RETURNS.fetch_add(1, Ordering::Relaxed);
}

/// Records `stream` as a user of every pooled `T` in the
/// `(device_ordinal, rounded_len)` bucket.
///
/// Entries with a recorded use are skipped by `pool_take_stream`; `pool_take`
/// still returns them. Recording an entry's own allocation stream is a no-op.
pub fn record_stream<T: Any + Send + Sync>(
    device_ordinal: usize,
    rounded_len: usize,
    stream: StreamId,
) {
    let key = (device_ordinal, rounded_len, TypeId::of::<T>());
    if let Ok(mut pool) = POOL.lock() {
        pool.mark_stream_use(&key, stream);
    }
}

/// Same as `record_stream`, for callers that hold a runtime `TypeId` instead
/// of a static type parameter.
#[cfg(feature = "cuda")]
pub fn record_stream_on_buffer(
    device_ordinal: usize,
    rounded_len: usize,
    type_id: TypeId,
    stream: StreamId,
) {
    let key = (device_ordinal, rounded_len, type_id);
    if let Ok(mut pool) = POOL.lock() {
        pool.mark_stream_use(&key, stream);
    }
}

/// Drops every pooled entry belonging to `device_ordinal` and clears that
/// device's cached-byte total. Other devices are left untouched.
pub fn empty_cache(device_ordinal: usize) {
    let Ok(mut pool) = POOL.lock() else { return };
    pool.free.retain(|key, _| key.0 != device_ordinal);
    pool.cached_bytes.remove(&device_ordinal);
}

/// Drops every pooled entry on every device and clears all byte totals.
pub fn empty_cache_all() {
    let Ok(mut pool) = POOL.lock() else { return };
    pool.free.clear();
    pool.cached_bytes.clear();
}

/// Returns the number of bytes currently cached for `device_ordinal`, as
/// accumulated from the `rounded_len * elem_size` values supplied on return.
/// Returns `0` for a device with nothing cached or if the pool lock is poisoned.
pub fn cached_bytes(device_ordinal: usize) -> usize {
    POOL.lock()
        .ok()
        .and_then(|pool| pool.cached_bytes.get(&device_ordinal).copied())
        .unwrap_or(0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the tests that touch the process-wide pool and counters.
    ///
    /// The default test harness runs tests on multiple threads. Without this
    /// guard, tests that share a key, call `empty_cache_all`, or reset the
    /// counters can interfere with one another and fail intermittently.
    static POOL_TEST_GUARD: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires the test guard, recovering it if an earlier test panicked
    /// while holding it so that one failure does not cascade.
    fn pool_guard() -> std::sync::MutexGuard<'static, ()> {
        POOL_TEST_GUARD
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn round_len_zero() {
        assert_eq!(round_len(0), 0);
    }

    #[test]
    fn round_len_exact_multiple() {
        assert_eq!(round_len(256), 256);
        assert_eq!(round_len(512), 512);
    }

    #[test]
    fn round_len_rounds_up() {
        assert_eq!(round_len(1), 256);
        assert_eq!(round_len(255), 256);
        assert_eq!(round_len(257), 512);
    }

    #[test]
    fn pool_take_miss_returns_none() {
        let _guard = pool_guard();
        let result = pool_take::<u64>(99, 256, 8);
        assert!(result.is_none());
    }

    #[test]
    fn pool_return_then_take() {
        let _guard = pool_guard();
        let value: u64 = 12345;
        pool_return::<u64>(99, 256, 8, value);
        let taken = pool_take::<u64>(99, 256, 8);
        assert_eq!(taken, Some(12345u64));
    }

    #[test]
    fn pool_stats_tracking() {
        let _guard = pool_guard();
        reset_pool_stats();
        let (h, _m, r) = pool_stats();
        assert_eq!(h, 0);
        assert_eq!(r, 0);
        pool_return::<u32>(98, 256, 4, 42u32);
        let (_, _, r) = pool_stats();
        assert!(r >= 1);
        let _ = pool_take::<u32>(98, 256, 4);
        let (h, _, _) = pool_stats();
        assert!(h >= 1);
    }

    #[test]
    fn stream_aware_take() {
        let _guard = pool_guard();
        let stream_a = StreamId(100);
        let stream_b = StreamId(200);
        pool_return_with_stream::<u64>(97, 256, 8, 777u64, stream_a);
        // A different stream must not receive the entry.
        let taken = pool_take_stream::<u64>(97, 256, 8, stream_b);
        assert!(taken.is_none());
        let taken = pool_take_stream::<u64>(97, 256, 8, stream_a);
        assert_eq!(taken, Some(777u64));
    }

    #[test]
    fn record_stream_prevents_reuse() {
        let _guard = pool_guard();
        let stream_a = StreamId(300);
        let stream_b = StreamId(400);
        pool_return_with_stream::<u64>(96, 256, 8, 888u64, stream_a);
        record_stream::<u64>(96, 256, stream_b);
        // A recorded use on another stream blocks stream-aware reuse...
        let taken = pool_take_stream::<u64>(96, 256, 8, stream_a);
        assert!(taken.is_none());
        // ...but the stream-agnostic take still returns the entry.
        let taken = pool_take::<u64>(96, 256, 8);
        assert_eq!(taken, Some(888u64));
    }

    #[test]
    fn empty_cache_clears_device() {
        let _guard = pool_guard();
        pool_return::<u32>(95, 256, 4, 11u32);
        pool_return::<u32>(94, 256, 4, 22u32);
        empty_cache(95);
        assert!(pool_take::<u32>(95, 256, 4).is_none());
        assert_eq!(pool_take::<u32>(94, 256, 4), Some(22u32));
    }

    #[test]
    fn empty_cache_all_clears_everything() {
        let _guard = pool_guard();
        pool_return::<u32>(93, 256, 4, 33u32);
        pool_return::<u32>(92, 256, 4, 44u32);
        empty_cache_all();
        assert!(pool_take::<u32>(93, 256, 4).is_none());
        assert!(pool_take::<u32>(92, 256, 4).is_none());
    }
}