use crate::Scirs2Tensor;
use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::collections::HashMap;
#[derive(Default)]
pub struct TensorPool {
available: HashMap<Vec<usize>, Vec<Scirs2Tensor>>,
pub(crate) allocations: usize,
pub(crate) reuses: usize,
}
impl TensorPool {
pub fn new() -> Self {
TensorPool {
available: HashMap::new(),
allocations: 0,
reuses: 0,
}
}
pub fn get(&mut self, shape: &[usize]) -> Scirs2Tensor {
let shape_key = shape.to_vec();
if let Some(tensors) = self.available.get_mut(&shape_key) {
if let Some(tensor) = tensors.pop() {
self.reuses += 1;
let mut tensor = tensor;
tensor.fill(0.0);
return tensor;
}
}
self.allocations += 1;
ArrayD::zeros(IxDyn(shape))
}
pub fn get_ones(&mut self, shape: &[usize]) -> Scirs2Tensor {
let mut tensor = self.get(shape);
tensor.fill(1.0);
tensor
}
pub fn return_tensor(&mut self, tensor: Scirs2Tensor) {
let shape = tensor.shape().to_vec();
self.available.entry(shape).or_default().push(tensor);
}
pub fn clear(&mut self) {
self.available.clear();
self.allocations = 0;
self.reuses = 0;
}
pub fn available_count(&self) -> usize {
self.available.values().map(|v| v.len()).sum()
}
pub fn stats(&self) -> PoolStats {
PoolStats {
allocations: self.allocations,
reuses: self.reuses,
available: self.available_count(),
reuse_rate: if self.allocations + self.reuses > 0 {
self.reuses as f64 / (self.allocations + self.reuses) as f64
} else {
0.0
},
}
}
}
#[derive(Debug, Clone)]
pub struct PoolStats {
pub allocations: usize,
pub reuses: usize,
pub available: usize,
pub reuse_rate: f64,
}
impl std::fmt::Display for PoolStats {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"PoolStats {{ allocations: {}, reuses: {}, available: {}, reuse_rate: {:.2}% }}",
self.allocations,
self.reuses,
self.available,
self.reuse_rate * 100.0
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pool_basic() {
let mut pool = TensorPool::new();
let t1 = pool.get(&[2, 3]);
assert_eq!(t1.shape(), &[2, 3]);
assert_eq!(pool.allocations, 1);
assert_eq!(pool.reuses, 0);
pool.return_tensor(t1);
assert_eq!(pool.available_count(), 1);
let t2 = pool.get(&[2, 3]);
assert_eq!(t2.shape(), &[2, 3]);
assert_eq!(pool.allocations, 1);
assert_eq!(pool.reuses, 1);
}
#[test]
fn test_pool_different_shapes() {
let mut pool = TensorPool::new();
let t1 = pool.get(&[2, 3]);
let t2 = pool.get(&[4, 5]);
pool.return_tensor(t1);
pool.return_tensor(t2);
assert_eq!(pool.available_count(), 2);
let t3 = pool.get(&[2, 3]);
assert_eq!(t3.shape(), &[2, 3]);
assert_eq!(pool.reuses, 1);
let t4 = pool.get(&[4, 5]);
assert_eq!(t4.shape(), &[4, 5]);
assert_eq!(pool.reuses, 2);
}
#[test]
fn test_pool_stats() {
let mut pool = TensorPool::new();
let t1 = pool.get(&[2, 2]);
let t2 = pool.get(&[2, 2]);
let t3 = pool.get(&[2, 2]);
pool.return_tensor(t1);
pool.return_tensor(t2);
pool.return_tensor(t3);
let _t4 = pool.get(&[2, 2]);
let _t5 = pool.get(&[2, 2]);
let stats = pool.stats();
assert_eq!(stats.allocations, 3);
assert_eq!(stats.reuses, 2);
assert_eq!(stats.available, 1);
assert!((stats.reuse_rate - 0.4).abs() < 1e-6); }
#[test]
fn test_get_ones() {
let mut pool = TensorPool::new();
let t = pool.get_ones(&[2, 2]);
assert_eq!(t.shape(), &[2, 2]);
assert_eq!(t[[0, 0]], 1.0);
assert_eq!(t[[1, 1]], 1.0);
}
#[test]
fn test_pool_clear() {
let mut pool = TensorPool::new();
let t1 = pool.get(&[2, 2]);
pool.return_tensor(t1);
assert_eq!(pool.available_count(), 1);
pool.clear();
assert_eq!(pool.available_count(), 0);
assert_eq!(pool.allocations, 0);
assert_eq!(pool.reuses, 0);
}
#[test]
fn test_pool_zeroing() {
let mut pool = TensorPool::new();
let mut t1 = pool.get(&[2, 2]);
t1.fill(5.0);
pool.return_tensor(t1);
let t2 = pool.get(&[2, 2]);
assert_eq!(t2[[0, 0]], 0.0);
assert_eq!(t2[[1, 1]], 0.0);
}
}