use crate::ndarray_ext::NdArray;
use crate::Float;
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use std::any::TypeId;
use std::collections::HashMap;
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Point-in-time snapshot of a `TensorPool`'s activity counters and of the
/// buffers it currently holds. Produced by `TensorPool::stats`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStats {
    /// Total number of `acquire` calls (fresh allocations + reuses).
    pub n_acquired: u64,
    /// Total number of buffers returned to the pool, whether via an explicit
    /// `release` or a `PooledArray` drop.
    pub n_released: u64,
    /// `acquire` calls that had to allocate a fresh array.
    pub n_allocated: u64,
    /// `acquire` calls served by recycling a pooled buffer.
    pub n_reused: u64,
    /// Total bytes currently parked in pooled buffers.
    pub pool_bytes: u64,
    /// Number of (shape, element-type) buckets; buckets are never removed
    /// when emptied, so this may count empty buckets too.
    pub n_buckets: u64,
    /// Number of buffers currently parked across all buckets.
    pub n_pooled_buffers: u64,
}
impl fmt::Display for PoolStats {
    /// Renders the snapshot as a single human-readable line, e.g.
    /// `PoolStats { acquired: 1, released: 0, ... }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "PoolStats {{ acquired: {acq}, released: {rel}, allocated: {alloc}, \
             reused: {reuse}, pool_bytes: {bytes}, buckets: {buckets}, \
             pooled_buffers: {pooled} }}",
            acq = self.n_acquired,
            rel = self.n_released,
            alloc = self.n_allocated,
            reuse = self.n_reused,
            bytes = self.pool_bytes,
            buckets = self.n_buckets,
            pooled = self.n_pooled_buffers,
        )
    }
}
/// Identity of a pool bucket: a buffer is only reused for a request with
/// exactly the same shape and the same element type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct BucketKey {
    /// Full dimension list, e.g. `[3, 4]`; `[]` is a valid scalar shape.
    shape: Vec<usize>,
    /// `TypeId` of the element type `F` the buffer was created for.
    type_id: TypeId,
}
/// A reuse pool for `NdArray` buffers, bucketed by shape and element type.
///
/// `TensorPool` is a cheap handle: cloning it yields another handle to the
/// same shared state (see the `Clone` impl), so handles can be passed to
/// threads or stored in guards freely.
pub struct TensorPool {
    // Shared state; `Arc` makes all clones observe the same buckets/counters.
    inner: Arc<TensorPoolInner>,
}
/// State shared by every `TensorPool` handle.
struct TensorPoolInner {
    /// Free buffers grouped by (shape, element type).
    buckets: Mutex<HashMap<BucketKey, Vec<ErasedArray>>>,
    /// Count of `acquire` calls.
    n_acquired: AtomicU64,
    /// Count of buffers returned (explicit `release` or guard drop).
    n_released: AtomicU64,
    /// Count of `acquire` calls that allocated fresh.
    n_allocated: AtomicU64,
    /// Count of `acquire` calls served from the pool.
    n_reused: AtomicU64,
    /// Maximum buffers retained per bucket; `0` means unlimited.
    max_per_bucket: usize,
}
/// A type-erased buffer parked in the pool between uses.
///
/// `data` is the raw byte view of a `Vec<F>`, produced without copying by
/// `ndarray_to_erased` via `Vec::from_raw_parts`; `erased_to_ndarray`
/// performs the reverse conversion.
///
/// NOTE(review): because the allocation behind `data` was originally made
/// for a `Vec<F>` (alignment `align_of::<F>()`), dropping it as a `Vec<u8>`
/// — which happens on `TensorPool::clear`, when a bucket is at capacity,
/// or when the pool itself is dropped — deallocates with a `u8` layout.
/// That violates the global allocator's size/alignment contract and is
/// technically undefined behavior. Consider storing the raw parts together
/// with a typed drop function instead of a `Vec<u8>`.
struct ErasedArray {
    // Raw bytes of the original `Vec<F>` allocation (len/cap scaled by
    // `elem_size`).
    data: Vec<u8>,
    // Dimensions the buffer was released with.
    shape: Vec<usize>,
    // `size_of::<F>()` of the original element type; cross-checked (debug
    // builds only) in `erased_to_ndarray`.
    elem_size: usize,
}
impl ErasedArray {
    /// Size of the parked buffer in bytes.
    fn byte_size(&self) -> usize {
        self.data.len()
    }
}
impl Clone for TensorPool {
    /// Cheap handle clone: the new handle shares buckets and counters with
    /// the original (only the `Arc` refcount is bumped).
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
// `TensorPoolInner` is `Send + Sync` automatically: every field is itself
// `Send + Sync` (`parking_lot::Mutex<HashMap<..>>`, `AtomicU64`, `usize`,
// and the map's contents are plain owned `Vec`s/`TypeId`s). The previous
// hand-written `unsafe impl Send/Sync` were therefore redundant, and worse,
// they would have silently suppressed the compile error if a future field
// made the type genuinely non-thread-safe — so they have been removed.
impl fmt::Debug for TensorPool {
    /// Debug output shows the live stats snapshot rather than raw internals.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TensorPool")
            .field("stats", &self.stats())
            .finish()
    }
}
impl Default for TensorPool {
fn default() -> Self {
Self::new()
}
}
impl TensorPool {
    /// Creates an unbounded pool (no per-bucket retention limit).
    pub fn new() -> Self {
        Self::with_max_per_bucket(0)
    }

    /// Creates a pool retaining at most `max` buffers per
    /// (shape, element-type) bucket. A `max` of zero disables the limit.
    pub fn with_max_per_bucket(max: usize) -> Self {
        let inner = TensorPoolInner {
            buckets: Mutex::new(HashMap::new()),
            n_acquired: AtomicU64::new(0),
            n_released: AtomicU64::new(0),
            n_allocated: AtomicU64::new(0),
            n_reused: AtomicU64::new(0),
            max_per_bucket: max,
        };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Hands out a zeroed array of the requested shape, recycling a pooled
    /// buffer when one with the same shape and element type is available.
    /// The returned guard sends the buffer back to this pool when dropped.
    pub fn acquire<F: Float>(&self, shape: &[usize]) -> PooledArray<F> {
        self.inner.n_acquired.fetch_add(1, Ordering::Relaxed);
        let key = BucketKey {
            shape: shape.to_vec(),
            type_id: TypeId::of::<F>(),
        };
        // The lock is held only for the bucket lookup; reconstruction or a
        // fresh allocation happens after the guard is released.
        let recycled = self.inner.buckets.lock().get_mut(&key).and_then(|b| b.pop());
        let array = match recycled {
            Some(erased) => {
                self.inner.n_reused.fetch_add(1, Ordering::Relaxed);
                erased_to_ndarray::<F>(erased)
            }
            None => {
                self.inner.n_allocated.fetch_add(1, Ordering::Relaxed);
                NdArray::<F>::zeros(scirs2_core::ndarray::IxDyn(shape))
            }
        };
        PooledArray {
            array: Some(array),
            pool: self.clone(),
        }
    }

    /// Donates an externally created array to the pool and counts it as a
    /// release.
    pub fn release<F: Float>(&self, array: NdArray<F>) {
        self.inner.n_released.fetch_add(1, Ordering::Relaxed);
        self.release_inner(array);
    }

    /// Shared return path for `release` and `PooledArray::drop`: erases the
    /// array and parks it in its bucket, unless the bucket is at capacity
    /// (in which case the buffer is simply dropped).
    fn release_inner<F: Float>(&self, array: NdArray<F>) {
        let key = BucketKey {
            shape: array.shape().to_vec(),
            type_id: TypeId::of::<F>(),
        };
        let erased = ndarray_to_erased(array);
        let mut buckets = self.inner.buckets.lock();
        let bucket = buckets.entry(key).or_default();
        let unlimited = self.inner.max_per_bucket == 0;
        if unlimited || bucket.len() < self.inner.max_per_bucket {
            bucket.push(erased);
        }
    }

    /// Discards every pooled buffer. Activity counters are left untouched.
    pub fn clear(&self) {
        self.inner.buckets.lock().clear();
    }

    /// Takes a snapshot of the activity counters plus the pool's current
    /// contents (byte total, bucket count, parked-buffer count).
    pub fn stats(&self) -> PoolStats {
        let buckets = self.inner.buckets.lock();
        let pool_bytes = buckets.values().flatten().fold(0u64, |total, erased| {
            total.saturating_add(erased.byte_size() as u64)
        });
        let n_pooled_buffers = buckets
            .values()
            .fold(0u64, |total, bucket| total.saturating_add(bucket.len() as u64));
        PoolStats {
            n_acquired: self.inner.n_acquired.load(Ordering::Relaxed),
            n_released: self.inner.n_released.load(Ordering::Relaxed),
            n_allocated: self.inner.n_allocated.load(Ordering::Relaxed),
            n_reused: self.inner.n_reused.load(Ordering::Relaxed),
            pool_bytes,
            n_buckets: buckets.len() as u64,
            n_pooled_buffers,
        }
    }

    /// Zeroes the four activity counters without touching pooled buffers.
    pub fn reset_stats(&self) {
        for counter in [
            &self.inner.n_acquired,
            &self.inner.n_released,
            &self.inner.n_allocated,
            &self.inner.n_reused,
        ] {
            counter.store(0, Ordering::Relaxed);
        }
    }
}
/// Converts an owned `NdArray<F>` into a type-erased byte buffer, without
/// copying, so it can be stored in a bucket with no generic parameter.
fn ndarray_to_erased<F: Float>(array: NdArray<F>) -> ErasedArray {
    let shape = array.shape().to_vec();
    let elem_size = std::mem::size_of::<F>();
    // NOTE(review): for a non-standard-layout array the raw vec's length may
    // not equal the product of `shape` — TODO confirm callers only release
    // contiguous arrays. `erased_to_ndarray` falls back to a fresh zeroed
    // array when the lengths disagree, so this is wasteful but not unsound.
    let vec_f: Vec<F> = array.into_raw_vec_and_offset().0;
    let len = vec_f.len();
    let cap = vec_f.capacity();
    let ptr = vec_f.as_ptr();
    // Leak the Vec<F> so we can adopt its allocation manually below.
    std::mem::forget(vec_f);
    // SAFETY: `ptr` points to `len * elem_size` initialized bytes with a
    // capacity of `cap * elem_size` bytes, taken from the leaked `Vec<F>`.
    // CAVEAT: the allocation was made with `align_of::<F>()`; if this
    // `Vec<u8>` is ever dropped directly instead of being round-tripped
    // through `erased_to_ndarray` (pool `clear()`, bucket overflow, pool
    // drop), deallocation uses a `u8` layout, which violates the allocator
    // contract — see the NOTE on `ErasedArray`.
    let data = unsafe { Vec::from_raw_parts(ptr as *mut u8, len * elem_size, cap * elem_size) };
    ErasedArray {
        data,
        shape,
        elem_size,
    }
}
/// Rebuilds a typed `NdArray<F>` from a pooled `ErasedArray`, zero-filling
/// the contents so a reused buffer is indistinguishable from a fresh one.
fn erased_to_ndarray<F: Float>(erased: ErasedArray) -> NdArray<F> {
    let elem_size = std::mem::size_of::<F>();
    // Buckets are keyed by `TypeId`, so the stored element size should always
    // match `F`; this is only verified in debug builds.
    debug_assert_eq!(erased.elem_size, elem_size);
    let byte_len = erased.data.len();
    let byte_cap = erased.data.capacity();
    let ptr = erased.data.as_ptr();
    // Leak the byte vec so its allocation can be re-adopted as a Vec<F>.
    std::mem::forget(erased.data);
    let f_len = byte_len / elem_size;
    let f_cap = byte_cap / elem_size;
    // SAFETY: relies on `erased.data` having been produced by
    // `ndarray_to_erased::<F>` from a real `Vec<F>`, so `ptr` is aligned for
    // `F` and `byte_len`/`byte_cap` are exact multiples of `elem_size` (the
    // byte vec is never grown while parked in the pool).
    let mut vec_f: Vec<F> = unsafe { Vec::from_raw_parts(ptr as *mut F, f_len, f_cap) };
    // Zero-fill: `acquire` promises callers an all-zero buffer.
    for v in vec_f.iter_mut() {
        *v = F::zero();
    }
    // Fall back to a fresh zeroed array if the element count does not match
    // the recorded shape (e.g. a buffer released from a non-contiguous
    // array — TODO confirm whether that can occur in practice).
    NdArray::<F>::from_shape_vec(scirs2_core::ndarray::IxDyn(&erased.shape), vec_f).unwrap_or_else(
        |_| {
            NdArray::<F>::zeros(scirs2_core::ndarray::IxDyn(&erased.shape))
        },
    )
}
/// RAII guard around a pooled `NdArray`: when dropped, the array is returned
/// to its originating pool, unless `into_inner` detached it first.
pub struct PooledArray<F: Float> {
    // `None` once `into_inner` (or `drop`) has taken the array.
    array: Option<NdArray<F>>,
    // Handle used to return the buffer on drop; shares state with the
    // pool that created this guard.
    pool: TensorPool,
}
impl<F: Float> PooledArray<F> {
    /// Detaches the inner array from the pool; it will NOT be returned to
    /// the pool when this guard is dropped.
    ///
    /// # Panics
    /// Panics if the inner array was already taken (not reachable through
    /// the public API, since this consumes `self`).
    pub fn into_inner(mut self) -> NdArray<F> {
        self.array
            .take()
            .expect("PooledArray inner array already taken")
    }

    /// Shape of the wrapped array, or an empty slice once the array has
    /// been taken (note a scalar array also reports `&[]`).
    pub fn shape(&self) -> &[usize] {
        self.array.as_ref().map_or(&[] as &[usize], |a| a.shape())
    }
}
impl<F: Float> Deref for PooledArray<F> {
    type Target = NdArray<F>;
    /// Borrows the wrapped array.
    ///
    /// The `expect` is a defensive invariant check: the only code paths that
    /// empty `array` (`into_inner`, `drop`) consume or finalize `self`, so
    /// this panic should be unreachable through the public API.
    fn deref(&self) -> &Self::Target {
        self.array
            .as_ref()
            .expect("PooledArray inner array already taken")
    }
}
impl<F: Float> DerefMut for PooledArray<F> {
    /// Mutably borrows the wrapped array.
    ///
    /// As with `Deref`, the `expect` guards an invariant that the public API
    /// cannot violate (`into_inner` consumes `self`).
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.array
            .as_mut()
            .expect("PooledArray inner array already taken")
    }
}
impl<F: Float> Drop for PooledArray<F> {
    /// Returns the buffer to the pool — unless `into_inner` already took it,
    /// in which case `array` is `None` and nothing happens.
    fn drop(&mut self) {
        if let Some(array) = self.array.take() {
            // Mirrors `TensorPool::release`, which also bumps `n_released`
            // before delegating to `release_inner`.
            self.pool.inner.n_released.fetch_add(1, Ordering::Relaxed);
            self.pool.release_inner::<F>(array);
        }
    }
}
impl<F: Float> fmt::Debug for PooledArray<F> {
    /// Shows the wrapped array's shape, or `<taken>` after `into_inner`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(a) = &self.array {
            write!(f, "PooledArray(shape={:?})", a.shape())
        } else {
            write!(f, "PooledArray(<taken>)")
        }
    }
}
// Process-wide shared pool, lazily initialized on first use.
static GLOBAL_POOL: Lazy<TensorPool> = Lazy::new(TensorPool::new);
/// Returns a handle to the process-wide shared `TensorPool` (unbounded;
/// created via `TensorPool::new`).
pub fn global_pool() -> &'static TensorPool {
    &GLOBAL_POOL
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_acquire_returns_zero_array() {
let pool = TensorPool::new();
let buf: PooledArray<f64> = pool.acquire(&[3, 4]);
assert_eq!(buf.shape(), &[3, 4]);
for &v in buf.iter() {
assert!((v - 0.0).abs() < f64::EPSILON);
}
}
#[test]
fn test_acquire_release_reuse_cycle() {
let pool = TensorPool::new();
let buf1: PooledArray<f64> = pool.acquire(&[8, 16]);
let stats1 = pool.stats();
assert_eq!(stats1.n_acquired, 1);
assert_eq!(stats1.n_allocated, 1);
assert_eq!(stats1.n_reused, 0);
drop(buf1);
let stats2 = pool.stats();
assert_eq!(stats2.n_released, 1);
assert_eq!(stats2.n_pooled_buffers, 1);
let buf2: PooledArray<f64> = pool.acquire(&[8, 16]);
let stats3 = pool.stats();
assert_eq!(stats3.n_acquired, 2);
assert_eq!(stats3.n_allocated, 1); assert_eq!(stats3.n_reused, 1);
for &v in buf2.iter() {
assert!((v - 0.0).abs() < f64::EPSILON);
}
}
#[test]
fn test_different_shapes_get_different_buckets() {
let pool = TensorPool::new();
let a: PooledArray<f64> = pool.acquire(&[2, 3]);
let b: PooledArray<f64> = pool.acquire(&[3, 2]);
drop(a);
drop(b);
let stats = pool.stats();
assert_eq!(stats.n_buckets, 2);
assert_eq!(stats.n_pooled_buffers, 2);
}
#[test]
fn test_different_types_get_different_buckets() {
let pool = TensorPool::new();
let a: PooledArray<f32> = pool.acquire(&[4, 4]);
let b: PooledArray<f64> = pool.acquire(&[4, 4]);
drop(a);
drop(b);
let stats = pool.stats();
assert_eq!(stats.n_buckets, 2);
}
#[test]
fn test_manual_release() {
let pool = TensorPool::new();
let arr: NdArray<f64> = NdArray::zeros(scirs2_core::ndarray::IxDyn(&[5, 5]));
pool.release(arr);
let stats = pool.stats();
assert_eq!(stats.n_released, 1);
assert_eq!(stats.n_pooled_buffers, 1);
let buf: PooledArray<f64> = pool.acquire(&[5, 5]);
let stats2 = pool.stats();
assert_eq!(stats2.n_reused, 1);
assert_eq!(stats2.n_allocated, 0);
drop(buf);
}
#[test]
fn test_clear_empties_pool() {
let pool = TensorPool::new();
let a: PooledArray<f64> = pool.acquire(&[10, 10]);
drop(a);
assert_eq!(pool.stats().n_pooled_buffers, 1);
pool.clear();
assert_eq!(pool.stats().n_pooled_buffers, 0);
assert_eq!(pool.stats().n_buckets, 0);
}
#[test]
fn test_into_inner_does_not_return_to_pool() {
let pool = TensorPool::new();
let buf: PooledArray<f64> = pool.acquire(&[3, 3]);
let _arr: NdArray<f64> = buf.into_inner();
let stats = pool.stats();
assert_eq!(stats.n_released, 0);
assert_eq!(stats.n_pooled_buffers, 0);
}
#[test]
fn test_stats_display() {
let pool = TensorPool::new();
let _a: PooledArray<f64> = pool.acquire(&[2]);
let display = format!("{}", pool.stats());
assert!(display.contains("acquired: 1"));
}
#[test]
fn test_pool_stats_pool_bytes() {
let pool = TensorPool::new();
let buf: PooledArray<f64> = pool.acquire(&[100]);
drop(buf);
let stats = pool.stats();
assert_eq!(stats.pool_bytes, 800);
}
#[test]
fn test_reset_stats() {
let pool = TensorPool::new();
let buf: PooledArray<f64> = pool.acquire(&[4]);
drop(buf);
pool.reset_stats();
let stats = pool.stats();
assert_eq!(stats.n_acquired, 0);
assert_eq!(stats.n_released, 0);
assert_eq!(stats.n_allocated, 0);
assert_eq!(stats.n_reused, 0);
assert_eq!(stats.n_pooled_buffers, 1);
}
#[test]
fn test_max_per_bucket() {
let pool = TensorPool::with_max_per_bucket(2);
for _ in 0..5 {
let buf: PooledArray<f64> = pool.acquire(&[10]);
drop(buf);
}
assert!(pool.stats().n_pooled_buffers <= 2);
}
#[test]
fn test_global_pool_accessible() {
let pool = global_pool();
let _buf: PooledArray<f64> = pool.acquire(&[1]);
}
#[test]
fn test_deref_mut() {
let pool = TensorPool::new();
let mut buf: PooledArray<f64> = pool.acquire(&[3]);
buf[[0]] = 42.0;
assert!((buf[[0]] - 42.0).abs() < f64::EPSILON);
}
#[test]
fn test_debug_format() {
let pool = TensorPool::new();
let buf: PooledArray<f64> = pool.acquire(&[2, 3]);
let dbg = format!("{:?}", buf);
assert!(dbg.contains("PooledArray"));
assert!(dbg.contains("[2, 3]"));
}
#[test]
fn test_pool_debug_format() {
let pool = TensorPool::new();
let dbg = format!("{:?}", pool);
assert!(dbg.contains("TensorPool"));
}
#[test]
fn test_pool_clone_shares_state() {
let pool1 = TensorPool::new();
let pool2 = pool1.clone();
let buf: PooledArray<f64> = pool1.acquire(&[4]);
drop(buf);
let stats = pool2.stats();
assert_eq!(stats.n_acquired, 1);
assert_eq!(stats.n_released, 1);
assert_eq!(stats.n_pooled_buffers, 1);
}
#[test]
fn test_scalar_shape() {
let pool = TensorPool::new();
let buf: PooledArray<f64> = pool.acquire(&[]);
assert_eq!(buf.shape(), &[] as &[usize]);
drop(buf);
let buf2: PooledArray<f64> = pool.acquire(&[]);
assert_eq!(pool.stats().n_reused, 1);
drop(buf2);
}
#[test]
fn test_f32_pool() {
let pool = TensorPool::new();
let buf: PooledArray<f32> = pool.acquire(&[5, 5]);
assert_eq!(buf.shape(), &[5, 5]);
for &v in buf.iter() {
assert!((v - 0.0f32).abs() < f32::EPSILON);
}
drop(buf);
let stats = pool.stats();
assert_eq!(stats.pool_bytes, 100); }
#[test]
fn test_concurrent_access() {
use std::sync::Arc;
use std::thread;
let pool = Arc::new(TensorPool::new());
let n_threads = 8;
let n_ops_per_thread = 100;
let mut handles = Vec::with_capacity(n_threads);
for _ in 0..n_threads {
let pool = Arc::clone(&pool);
handles.push(thread::spawn(move || {
for i in 0..n_ops_per_thread {
let shape = match i % 3 {
0 => vec![16, 32],
1 => vec![32, 16],
_ => vec![64],
};
let mut buf: PooledArray<f64> = pool.acquire(&shape);
if let Some(v) = buf.iter_mut().next() {
*v = 1.0;
}
drop(buf);
}
}));
}
for h in handles {
h.join().expect("thread panicked");
}
let stats = pool.stats();
assert_eq!(stats.n_acquired, (n_threads * n_ops_per_thread) as u64,);
assert_eq!(stats.n_acquired, stats.n_allocated + stats.n_reused);
assert_eq!(stats.n_released, stats.n_acquired);
}
#[test]
fn test_concurrent_mixed_types() {
use std::sync::Arc;
use std::thread;
let pool = Arc::new(TensorPool::new());
let n_threads = 4;
let n_ops = 50;
let mut handles = Vec::with_capacity(n_threads * 2);
for _ in 0..n_threads {
let pool = Arc::clone(&pool);
handles.push(thread::spawn(move || {
for _ in 0..n_ops {
let buf: PooledArray<f64> = pool.acquire(&[8, 8]);
drop(buf);
}
}));
}
for _ in 0..n_threads {
let pool = Arc::clone(&pool);
handles.push(thread::spawn(move || {
for _ in 0..n_ops {
let buf: PooledArray<f32> = pool.acquire(&[8, 8]);
drop(buf);
}
}));
}
for h in handles {
h.join().expect("thread panicked");
}
let stats = pool.stats();
let total_ops = (n_threads * 2 * n_ops) as u64;
assert_eq!(stats.n_acquired, total_ops);
}
#[test]
fn test_large_shape() {
let pool = TensorPool::new();
let buf: PooledArray<f64> = pool.acquire(&[256, 256]);
assert_eq!(buf.shape(), &[256, 256]);
assert_eq!(buf.len(), 256 * 256);
drop(buf);
let stats = pool.stats();
assert_eq!(stats.pool_bytes, (256 * 256 * 8) as u64);
}
#[test]
fn test_reused_buffer_is_zeroed() {
let pool = TensorPool::new();
let mut buf: PooledArray<f64> = pool.acquire(&[4]);
buf[[0]] = 99.0;
buf[[1]] = 88.0;
buf[[2]] = 77.0;
buf[[3]] = 66.0;
drop(buf);
let buf2: PooledArray<f64> = pool.acquire(&[4]);
for &v in buf2.iter() {
assert!((v - 0.0).abs() < f64::EPSILON, "expected zero, got {}", v);
}
}
#[test]
fn test_multiple_buffers_same_shape() {
let pool = TensorPool::new();
for _ in 0..5 {
let arr: NdArray<f64> = NdArray::zeros(scirs2_core::ndarray::IxDyn(&[3]));
pool.release(arr);
}
assert_eq!(pool.stats().n_pooled_buffers, 5);
let mut held: Vec<PooledArray<f64>> = Vec::with_capacity(5);
for i in 0..5 {
held.push(pool.acquire(&[3]));
assert_eq!(pool.stats().n_pooled_buffers, 4 - i as u64);
}
drop(held);
assert_eq!(pool.stats().n_pooled_buffers, 5);
}
}