use std::cell::RefCell;
/// Maximum number of cached buffers kept for any single capacity.
const MAX_CACHED_PER_SIZE: usize = 4;
/// Maximum number of cached buffers kept in the pool across all capacities.
const MAX_TOTAL_CACHED: usize = 64;

/// A bounded free-list of `Vec<f64>` buffers, keyed by exact capacity.
///
/// Buffers are matched on `capacity()` rather than `len()`, so a buffer handed
/// out by `acquire` can always be resized to the requested length without
/// reallocating.
struct TensorPool {
    /// Cached buffers. Order is not significant (removal uses `swap_remove`).
    buffers: Vec<Vec<f64>>,
}

impl TensorPool {
    /// Creates an empty pool.
    fn new() -> Self {
        TensorPool {
            buffers: Vec::new(),
        }
    }

    /// Returns a zero-filled buffer of length `size`.
    ///
    /// Reuses a cached buffer whose capacity is exactly `size` when one is
    /// available; otherwise allocates a fresh one.
    fn acquire(&mut self, size: usize) -> Vec<f64> {
        match self.buffers.iter().position(|b| b.capacity() == size) {
            Some(pos) => {
                let mut buf = self.buffers.swap_remove(pos);
                // Discard stale contents, then zero-fill up to `size`. Because
                // the matched capacity equals `size`, this never reallocates.
                buf.clear();
                buf.resize(size, 0.0);
                buf
            }
            None => vec![0.0f64; size],
        }
    }

    /// Offers `buf` back to the pool for later reuse.
    ///
    /// The buffer is dropped instead of cached when it owns no heap allocation
    /// (capacity 0), when the pool already holds `MAX_TOTAL_CACHED` buffers, or
    /// when `MAX_CACHED_PER_SIZE` buffers of the same capacity are already cached.
    fn recycle(&mut self, buf: Vec<f64>) {
        let cap = buf.capacity();
        // A zero-capacity Vec owns no allocation, so caching it saves nothing
        // and would only occupy a slot that a real buffer could use.
        if cap == 0 || self.buffers.len() >= MAX_TOTAL_CACHED {
            return;
        }
        let same_size_count = self
            .buffers
            .iter()
            .filter(|b| b.capacity() == cap)
            .count();
        if same_size_count < MAX_CACHED_PER_SIZE {
            self.buffers.push(buf);
        }
    }
}
thread_local! {
    /// Per-thread buffer pool used by the free functions `acquire`, `recycle`
    /// and `pool_size`. Each thread gets its own independent pool, so buffers
    /// are only ever reused on the thread that recycled them.
    static POOL: RefCell<TensorPool> = RefCell::new(TensorPool::new());
}
/// Returns a zero-filled buffer of length `size`, reusing one from this
/// thread's pool when a cached buffer of matching capacity is available.
///
/// If the thread-local pool can no longer be accessed (it is being or has been
/// destroyed during thread shutdown), a freshly allocated buffer is returned
/// instead of panicking.
pub fn acquire(size: usize) -> Vec<f64> {
    POOL.try_with(|pool| pool.borrow_mut().acquire(size))
        .unwrap_or_else(|_| vec![0.0f64; size])
}
/// Returns `buf` to this thread's pool so a later `acquire` of the same
/// capacity can reuse its allocation. The pool may decline and drop it.
///
/// If the thread-local pool can no longer be accessed (it is being or has been
/// destroyed during thread shutdown), `buf` is simply dropped instead of
/// panicking. This can happen, for example, when a buffer is released from
/// another thread-local's destructor.
pub fn recycle(buf: Vec<f64>) {
    // `buf` is moved into the closure. If the pool is gone, the closure never
    // runs and `buf` is dropped with it, which is the intended fallback.
    let _ = POOL.try_with(|pool| pool.borrow_mut().recycle(buf));
}
/// Reports how many buffers this thread's pool currently holds.
// Within this file it is only called from the tests, so non-test builds may
// otherwise report it as unused.
#[allow(dead_code)]
pub fn pool_size() -> usize {
    POOL.with(|cell| {
        let pool = cell.borrow();
        pool.buffers.len()
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_acquire_returns_correct_size() {
        let buf = acquire(100);
        assert_eq!(buf.len(), 100);
        assert!(buf.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_recycle_and_reuse() {
        let buf = acquire(256);
        let ptr = buf.as_ptr();
        assert_eq!(pool_size(), 0);
        recycle(buf);
        assert_eq!(pool_size(), 1);
        let buf2 = acquire(256);
        assert_eq!(buf2.len(), 256);
        // The cached allocation itself must be handed back, not a new one.
        assert_eq!(buf2.as_ptr(), ptr);
        assert_eq!(pool_size(), 0);
    }

    #[test]
    fn test_pool_max_per_size() {
        // Hold every buffer at once so the pool is offered more same-capacity
        // buffers than it may keep. Acquiring and recycling one at a time would
        // keep the pool at a single entry and never reach the cap.
        let held: Vec<Vec<f64>> = (0..MAX_CACHED_PER_SIZE * 2 + 1)
            .map(|_| acquire(64))
            .collect();
        for buf in held {
            recycle(buf);
        }
        assert_eq!(pool_size(), MAX_CACHED_PER_SIZE);
    }

    #[test]
    fn test_pool_total_limit() {
        // Twice as many distinct (non-zero) sizes as the pool may hold, so it
        // must saturate at exactly the total limit.
        for size in 1..=MAX_TOTAL_CACHED * 2 {
            let buf = acquire(size);
            recycle(buf);
        }
        assert_eq!(pool_size(), MAX_TOTAL_CACHED);
    }

    #[test]
    fn test_acquired_buffer_is_zeroed() {
        let mut buf = acquire(10);
        for x in buf.iter_mut() {
            *x = 42.0;
        }
        let ptr = buf.as_ptr();
        recycle(buf);
        let buf2 = acquire(10);
        // A fresh allocation would be zeroed trivially; make sure the dirty
        // buffer was actually reused so the zeroing check is meaningful.
        assert_eq!(buf2.as_ptr(), ptr);
        assert!(buf2.iter().all(|&x| x == 0.0), "Recycled buffer must be zeroed");
    }
}