use std::sync::{Arc, Mutex};
/// Thread-safe free list of `Vec<f64>` buffers that can be handed out and
/// recycled instead of allocating fresh vectors each time.
///
/// At most `max_size` buffers are retained; buffers returned beyond that
/// limit are dropped.
pub struct TensorPool {
    /// Buffers available for reuse.
    pool: Arc<Mutex<Vec<Vec<f64>>>>,
    /// Upper bound on how many buffers `pool` may hold.
    max_size: usize,
    /// Usage counters, guarded by their own mutex.
    stats: Arc<Mutex<PoolStats>>,
}

/// Snapshot of [`TensorPool`] usage counters.
#[derive(Debug, Default, Clone)]
pub struct PoolStats {
    /// `get` calls served from a recycled buffer.
    pub hits: usize,
    /// `get` calls that had to allocate a new buffer.
    pub misses: usize,
    /// Buffers accepted back by `return_buffer` (rejected ones are not counted).
    pub returns: usize,
    /// Pool length recorded at the last hit, accepted return, or `clear`.
    pub current_size: usize,
}

impl TensorPool {
    /// Creates an empty pool that retains at most `max_size` buffers.
    pub fn new(max_size: usize) -> Self {
        let pool = Arc::new(Mutex::new(Vec::with_capacity(max_size)));
        let stats = Arc::new(Mutex::new(PoolStats::default()));
        Self {
            pool,
            max_size,
            stats,
        }
    }

    /// Returns a zero-filled buffer of length `size`, reusing a pooled
    /// allocation when one is available.
    ///
    /// # Panics
    /// Panics if either internal mutex is poisoned.
    pub fn get(&self, size: usize) -> Vec<f64> {
        // Locks are taken pool -> stats, the same order as `return_buffer`
        // and `clear`, so the methods cannot deadlock against each other.
        let recycled = {
            let mut free = self.pool.lock().unwrap();
            let mut counters = self.stats.lock().unwrap();
            let popped = free.pop();
            if popped.is_some() {
                counters.hits += 1;
                counters.current_size = free.len();
            } else {
                counters.misses += 1;
            }
            popped
        };

        // Zeroing happens after both locks are released.
        match recycled {
            Some(mut reused) => {
                reused.clear();
                reused.resize(size, 0.0);
                reused
            }
            None => vec![0.0; size],
        }
    }

    /// Offers `buffer` back to the pool; it is dropped if the pool is full.
    ///
    /// # Panics
    /// Panics if either internal mutex is poisoned.
    pub fn return_buffer(&self, buffer: Vec<f64>) {
        let mut free = self.pool.lock().unwrap();
        let mut counters = self.stats.lock().unwrap();
        if free.len() >= self.max_size {
            // Pool already at capacity: discard the buffer, leave stats untouched.
            return;
        }
        free.push(buffer);
        counters.returns += 1;
        counters.current_size = free.len();
    }

    /// Returns a copy of the current usage counters.
    pub fn stats(&self) -> PoolStats {
        let counters = self.stats.lock().unwrap();
        PoolStats::clone(&counters)
    }

    /// Drops every pooled buffer; hit/miss/return counters are preserved.
    pub fn clear(&self) {
        let mut free = self.pool.lock().unwrap();
        let mut counters = self.stats.lock().unwrap();
        free.clear();
        counters.current_size = 0;
    }

    /// Fraction of `get` calls that were hits, or `0.0` before any `get`.
    pub fn hit_rate(&self) -> f64 {
        let counters = self.stats.lock().unwrap();
        match counters.hits + counters.misses {
            0 => 0.0,
            total => counters.hits as f64 / total as f64,
        }
    }

    /// Number of buffers currently held in the pool.
    pub fn size(&self) -> usize {
        let free = self.pool.lock().unwrap();
        free.len()
    }
}

impl Default for TensorPool {
    /// Creates a pool retaining at most 32 buffers.
    fn default() -> Self {
        const DEFAULT_MAX_SIZE: usize = 32;
        Self::new(DEFAULT_MAX_SIZE)
    }
}
/// A `Vec<f64>` that pushes itself back onto a shared free list when dropped.
///
/// NOTE(review): the drop path pushes into the shared `Vec` directly, so it
/// neither enforces `TensorPool`'s `max_size` nor updates its `PoolStats`.
/// Repeatedly dropping `PooledBuffer`s can therefore grow the list past the
/// pool's configured limit.
pub struct PooledBuffer {
    /// The owned buffer; becomes `None` only when `take` moves it out.
    buffer: Option<Vec<f64>>,
    /// Free list that receives the buffer on drop.
    pool: Arc<Mutex<Vec<Vec<f64>>>>,
}

impl PooledBuffer {
    /// Panic message for the (unreachable) case of a missing buffer: the
    /// buffer is only removed by `take`, which consumes `self`.
    const MISSING_BUFFER: &'static str =
        "PooledBuffer invariant violated: buffer is present until take() consumes self";

    /// Wraps `buffer` so that it is returned to `pool` when this value is dropped.
    pub fn new(buffer: Vec<f64>, pool: Arc<Mutex<Vec<Vec<f64>>>>) -> Self {
        Self {
            buffer: Some(buffer),
            pool,
        }
    }

    /// Borrows the buffer contents.
    pub fn as_slice(&self) -> &[f64] {
        self.buffer.as_deref().expect(Self::MISSING_BUFFER)
    }

    /// Mutably borrows the buffer contents.
    pub fn as_mut_slice(&mut self) -> &mut [f64] {
        self.buffer.as_deref_mut().expect(Self::MISSING_BUFFER)
    }

    /// Moves the buffer out without returning it to the pool.
    pub fn take(mut self) -> Vec<f64> {
        // Leaving `None` behind makes the subsequent `Drop` a no-op.
        self.buffer.take().expect(Self::MISSING_BUFFER)
    }
}

impl Drop for PooledBuffer {
    /// Pushes the buffer (if it was not taken) onto the shared free list.
    fn drop(&mut self) {
        if let Some(buffer) = self.buffer.take() {
            // A poisoned mutex must not make `drop` panic: panicking while the
            // thread is already unwinding aborts the process. The free list is
            // just a `Vec` of buffers, so recovering the guard is safe.
            let mut pool = self
                .pool
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            pool.push(buffer);
        }
    }
}

impl std::ops::Deref for PooledBuffer {
    type Target = [f64];

    /// Derefs to the buffer contents as a slice.
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl std::ops::DerefMut for PooledBuffer {
    /// Derefs to the buffer contents as a mutable slice.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}
/// Growable `f64` buffer backed by `smallvec::SmallVec`, storing up to `N`
/// elements inline before spilling to a heap allocation.
pub struct SmallBuffer<const N: usize> {
    data: smallvec::SmallVec<[f64; N]>,
}

impl<const N: usize> SmallBuffer<N> {
    /// Creates an empty buffer able to hold `capacity` elements without
    /// reallocating; allocates on the heap only if `capacity > N`.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            data: smallvec::SmallVec::with_capacity(capacity),
        }
    }

    /// Creates a buffer holding a copy of `slice`.
    pub fn from_slice(slice: &[f64]) -> Self {
        Self {
            data: smallvec::SmallVec::from_slice(slice),
        }
    }

    /// Borrows the contents as a slice.
    pub fn as_slice(&self) -> &[f64] {
        &self.data
    }

    /// Mutably borrows the contents as a slice.
    pub fn as_mut_slice(&mut self) -> &mut [f64] {
        &mut self.data
    }

    /// Consumes the buffer and returns its contents as a `Vec<f64>`.
    ///
    /// Reuses the existing heap allocation when the buffer has already spilled.
    // The name predates the `into_*` convention; kept for API compatibility.
    #[allow(clippy::wrong_self_convention)]
    pub fn to_vec(self) -> Vec<f64> {
        self.data.into_vec()
    }

    /// Appends `value`, spilling to the heap if the inline capacity is exceeded.
    pub fn push(&mut self, value: f64) {
        self.data.push(value);
    }

    /// Number of stored elements.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if no elements are stored.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Returns `true` while the elements are still stored inline, i.e. the
    /// buffer has not spilled to the heap.
    pub fn is_inline(&self) -> bool {
        // `spilled()` is true once the data lives on the heap; inline is its negation.
        !self.data.spilled()
    }
}

impl<const N: usize> Default for SmallBuffer<N> {
    /// Creates an empty, inline buffer.
    fn default() -> Self {
        Self {
            data: smallvec::SmallVec::new(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_reuse() {
        let pool = TensorPool::new(10);
        let buf1 = pool.get(100);
        assert_eq!(buf1.len(), 100);
        let stats = pool.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
        pool.return_buffer(buf1);
        assert_eq!(pool.size(), 1);
        let buf2 = pool.get(100);
        assert_eq!(buf2.len(), 100);
        let stats = pool.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_pool_max_size() {
        let pool = TensorPool::new(2);
        pool.return_buffer(vec![0.0; 10]);
        pool.return_buffer(vec![0.0; 10]);
        pool.return_buffer(vec![0.0; 10]);
        assert_eq!(pool.size(), 2);
    }

    #[test]
    fn test_pool_hit_rate() {
        let pool = TensorPool::new(10);
        pool.return_buffer(vec![0.0; 100]);
        pool.get(100);
        pool.get(100);
        let hit_rate = pool.hit_rate();
        assert!((hit_rate - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_pooled_buffer_auto_return() {
        let pool = TensorPool::new(10);
        let pool_ref = pool.pool.clone();
        {
            let buffer = pool.get(100);
            let _pooled = PooledBuffer::new(buffer, pool_ref.clone());
        }
        assert_eq!(pool_ref.lock().unwrap().len(), 1);
    }

    #[test]
    fn test_small_buffer_inline() {
        let mut buf: SmallBuffer<32> = SmallBuffer::with_capacity(16);
        for i in 0..16 {
            buf.push(i as f64);
        }
        assert_eq!(buf.len(), 16);
        // 16 elements fit within the 32-element inline capacity.
        assert!(buf.is_inline());
    }

    #[test]
    fn test_small_buffer_spills() {
        let mut buf: SmallBuffer<8> = SmallBuffer::default();
        for i in 0..16 {
            buf.push(i as f64);
        }
        assert_eq!(buf.len(), 16);
        // 16 elements exceed the 8-element inline capacity.
        assert!(!buf.is_inline());
    }

    #[test]
    fn test_small_buffer_round_trip() {
        let buf: SmallBuffer<4> = SmallBuffer::from_slice(&[1.0, 2.0]);
        assert!(buf.is_inline());
        assert_eq!(buf.as_slice(), &[1.0, 2.0]);
        assert_eq!(buf.to_vec(), vec![1.0, 2.0]);

        let empty: SmallBuffer<4> = SmallBuffer::default();
        assert!(empty.is_empty());
        assert!(empty.is_inline());
    }

    #[test]
    fn test_pool_clear() {
        let pool = TensorPool::new(10);
        pool.return_buffer(vec![0.0; 100]);
        pool.return_buffer(vec![0.0; 100]);
        assert_eq!(pool.size(), 2);
        pool.clear();
        assert_eq!(pool.size(), 0);
    }
}