use scirs2_core::ndarray::Array1;
use std::cell::RefCell;
use std::collections::VecDeque;
/// A pool of reusable `Array1<f32>` buffers that all share the same length.
///
/// State lives in `RefCell`s so `acquire`/`release` can take `&self`; because `RefCell` is
/// `!Sync`, a pool cannot be shared between threads by reference.
#[derive(Debug)]
pub struct ArrayPool {
    /// Length of every array this pool hands out and accepts back.
    array_size: usize,
    /// Maximum number of idle arrays kept; releases beyond this are not stored.
    max_capacity: usize,
    /// Idle arrays waiting to be handed out again.
    pool: RefCell<VecDeque<Array1<f32>>>,
    /// Usage counters (hits, misses, returns, drops).
    stats: RefCell<PoolStats>,
}
/// Usage counters for an [`ArrayPool`] (or the sum over several pools).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct PoolStats {
    /// Acquisitions served by reusing an idle pooled array.
    pub hits: u64,
    /// Acquisitions that had to allocate a fresh array.
    pub misses: u64,
    /// Released arrays that were stored back in the pool.
    pub returns: u64,
    /// Released arrays that were not stored, e.g. because the pool was already full.
    pub drops: u64,
}

impl PoolStats {
    /// Percentage (`0.0..=100.0`) of acquisitions that were served from the pool.
    ///
    /// Returns `0.0` when no acquisition has been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        if self.hits == 0 && self.misses == 0 {
            return 0.0;
        }
        // Sum in floating point: `hits + misses` in `u64` could overflow (and panic in debug).
        let total = self.hits as f64 + self.misses as f64;
        self.hits as f64 / total * 100.0
    }
}
impl ArrayPool {
    /// Creates an empty pool that hands out arrays of length `array_size` and keeps at most
    /// `max_capacity` idle arrays.
    ///
    /// Queue storage for `max_capacity` entries is reserved up front; no arrays are allocated
    /// until [`ArrayPool::acquire`] (or a sibling) or [`ArrayPool::warm`] is called.
    pub fn new(array_size: usize, max_capacity: usize) -> Self {
        Self {
            array_size,
            max_capacity,
            pool: RefCell::new(VecDeque::with_capacity(max_capacity)),
            stats: RefCell::new(PoolStats::default()),
        }
    }

    /// Length of the arrays managed by this pool.
    pub fn array_size(&self) -> usize {
        self.array_size
    }

    /// Number of idle arrays currently held.
    pub fn size(&self) -> usize {
        self.pool.borrow().len()
    }

    /// Snapshot of the usage counters.
    pub fn stats(&self) -> PoolStats {
        *self.stats.borrow()
    }

    /// Pops the oldest idle array, if any, recording a hit or a miss accordingly.
    fn take_pooled(&self) -> Option<Array1<f32>> {
        let reused = self.pool.borrow_mut().pop_front();
        let mut stats = self.stats.borrow_mut();
        if reused.is_some() {
            stats.hits += 1;
        } else {
            stats.misses += 1;
        }
        reused
    }

    /// Returns an array of length `array_size`.
    ///
    /// A reused array still holds whatever values it had when it was released; only freshly
    /// allocated arrays are zero-initialised. Use [`ArrayPool::acquire_zeros`] or
    /// [`ArrayPool::acquire_filled`] when the contents matter.
    pub fn acquire(&self) -> Array1<f32> {
        self.take_pooled()
            .unwrap_or_else(|| Array1::zeros(self.array_size))
    }

    /// Returns an array of length `array_size` with every element set to `value`.
    pub fn acquire_filled(&self, value: f32) -> Array1<f32> {
        match self.take_pooled() {
            Some(mut array) => {
                array.fill(value);
                array
            }
            // Build fresh arrays with the final value directly instead of zeroing first.
            None => Array1::from_elem(self.array_size, value),
        }
    }

    /// Returns an array of length `array_size` with every element set to `0.0`.
    pub fn acquire_zeros(&self) -> Array1<f32> {
        match self.take_pooled() {
            Some(mut array) => {
                array.fill(0.0);
                array
            }
            // A fresh `zeros` allocation is already cleared; no second pass needed.
            None => Array1::zeros(self.array_size),
        }
    }

    /// Hands `array` back for reuse.
    ///
    /// The array is stored (counted in `returns`) only if its length equals `array_size` and
    /// fewer than `max_capacity` arrays are idle; otherwise it is dropped (counted in `drops`).
    pub fn release(&self, array: Array1<f32>) {
        let mut stats = self.stats.borrow_mut();
        if array.len() != self.array_size {
            // A wrong-sized array can never be handed out again; count it so that every
            // release shows up as either a return or a drop.
            stats.drops += 1;
            return;
        }
        let mut pool = self.pool.borrow_mut();
        if pool.len() < self.max_capacity {
            pool.push_back(array);
            stats.returns += 1;
        } else {
            stats.drops += 1;
        }
    }

    /// Drops every idle array. Usage counters are left untouched.
    pub fn clear(&self) {
        self.pool.borrow_mut().clear();
    }

    /// Tops the pool up to `max_capacity` idle arrays with freshly zeroed allocations.
    ///
    /// Usage counters are left untouched.
    pub fn warm(&self) {
        let mut pool = self.pool.borrow_mut();
        let missing = self.max_capacity.saturating_sub(pool.len());
        pool.extend((0..missing).map(|_| Array1::zeros(self.array_size)));
    }
}
/// Guard that owns an array acquired from an [`ArrayPool`] and releases it back to that pool
/// when dropped.
///
/// Dereferences to `Array1<f32>`; call [`PooledArray::take`] to keep the array instead.
#[derive(Debug)]
#[must_use = "dropping a PooledArray immediately returns its array to the pool"]
pub struct PooledArray<'a> {
    /// `Some` while the guard owns the array; emptied by `take` or on drop.
    array: Option<Array1<f32>>,
    /// Pool the array is released to on drop.
    pool: &'a ArrayPool,
}
impl<'a> PooledArray<'a> {
    /// Acquires an array from `pool`; a reused array may hold stale values.
    pub fn new(pool: &'a ArrayPool) -> Self {
        Self {
            array: Some(pool.acquire()),
            pool,
        }
    }

    /// Acquires an array from `pool` with every element set to `0.0`.
    pub fn zeros(pool: &'a ArrayPool) -> Self {
        Self {
            array: Some(pool.acquire_zeros()),
            pool,
        }
    }

    /// Acquires an array from `pool` with every element set to `value`.
    pub fn filled(pool: &'a ArrayPool, value: f32) -> Self {
        Self {
            array: Some(pool.acquire_filled(value)),
            pool,
        }
    }

    /// Borrows the guarded array.
    pub fn as_array(&self) -> &Array1<f32> {
        self.array
            .as_ref()
            .expect("PooledArray owns its array until take() or drop")
    }

    /// Mutably borrows the guarded array.
    pub fn as_array_mut(&mut self) -> &mut Array1<f32> {
        self.array
            .as_mut()
            .expect("PooledArray owns its array until take() or drop")
    }

    /// Moves the array out of the guard so it is NOT returned to the pool on drop.
    #[must_use = "ignoring the taken array drops it instead of returning it to the pool"]
    pub fn take(mut self) -> Array1<f32> {
        // `take` consumes `self`, so this runs at most once; `Drop` then sees `None`.
        self.array
            .take()
            .expect("PooledArray owns its array until take() or drop")
    }
}
impl Drop for PooledArray<'_> {
    /// Gives the array back to its pool, unless `take` already moved it out.
    fn drop(&mut self) {
        if let Some(owned) = std::mem::take(&mut self.array) {
            self.pool.release(owned);
        }
    }
}
impl std::ops::Deref for PooledArray<'_> {
type Target = Array1<f32>;
fn deref(&self) -> &Self::Target {
self.as_array()
}
}
impl std::ops::DerefMut for PooledArray<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.as_array_mut()
}
}
/// A collection of [`ArrayPool`]s with different array lengths ("buckets").
///
/// Requests are served from the smallest bucket whose arrays are at least as long as asked for.
#[derive(Debug)]
pub struct MultiArrayPool {
    /// One pool per entry of `sizes`, in the same order.
    pools: Vec<ArrayPool>,
    /// Array length of each pool, sorted ascending.
    sizes: Vec<usize>,
}
impl MultiArrayPool {
pub fn new() -> Self {
Self::with_sizes(&[32, 64, 128, 256, 512, 1024, 2048, 4096], 8)
}
pub fn with_sizes(sizes: &[usize], capacity_per_size: usize) -> Self {
let mut sorted_sizes: Vec<usize> = sizes.to_vec();
sorted_sizes.sort_unstable();
let pools = sorted_sizes
.iter()
.map(|&size| ArrayPool::new(size, capacity_per_size))
.collect();
Self {
pools,
sizes: sorted_sizes,
}
}
pub fn acquire(&self, min_size: usize) -> Array1<f32> {
if let Some(idx) = self.sizes.iter().position(|&s| s >= min_size) {
self.pools[idx].acquire()
} else {
Array1::zeros(min_size)
}
}
pub fn acquire_zeros(&self, min_size: usize) -> Array1<f32> {
let mut arr = self.acquire(min_size);
arr.fill(0.0);
arr
}
pub fn release(&self, array: Array1<f32>) {
let size = array.len();
if let Some(idx) = self.sizes.iter().position(|&s| s == size) {
self.pools[idx].release(array);
}
}
pub fn stats(&self) -> PoolStats {
let mut total = PoolStats::default();
for pool in &self.pools {
let s = pool.stats();
total.hits += s.hits;
total.misses += s.misses;
total.returns += s.returns;
total.drops += s.drops;
}
total
}
pub fn warm(&self) {
for pool in &self.pools {
pool.warm();
}
}
pub fn clear(&self) {
for pool in &self.pools {
pool.clear();
}
}
}
impl Default for MultiArrayPool {
fn default() -> Self {
Self::new()
}
}
thread_local! {
static LOCAL_POOL: MultiArrayPool = MultiArrayPool::new();
}
pub fn tl_acquire(min_size: usize) -> Array1<f32> {
LOCAL_POOL.with(|pool| pool.acquire(min_size))
}
pub fn tl_acquire_zeros(min_size: usize) -> Array1<f32> {
LOCAL_POOL.with(|pool| pool.acquire_zeros(min_size))
}
pub fn tl_release(array: Array1<f32>) {
LOCAL_POOL.with(|pool| pool.release(array));
}
pub fn tl_stats() -> PoolStats {
LOCAL_POOL.with(|pool| pool.stats())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_pool_basic() {
        let pool = ArrayPool::new(64, 4);
        let arr1 = pool.acquire();
        assert_eq!(arr1.len(), 64);
        pool.release(arr1);
        assert_eq!(pool.size(), 1);
        let arr2 = pool.acquire();
        assert_eq!(arr2.len(), 64);
        assert_eq!(pool.size(), 0);
        let stats = pool.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_array_pool_capacity() {
        let pool = ArrayPool::new(32, 2);
        let a1 = pool.acquire();
        let a2 = pool.acquire();
        let a3 = pool.acquire();
        pool.release(a1);
        pool.release(a2);
        pool.release(a3);
        let stats = pool.stats();
        assert_eq!(stats.returns, 2);
        assert_eq!(stats.drops, 1);
    }

    #[test]
    fn test_release_wrong_size_is_not_pooled() {
        let pool = ArrayPool::new(16, 4);
        pool.release(Array1::zeros(8));
        assert_eq!(pool.size(), 0);
        assert_eq!(pool.stats().returns, 0);
    }

    #[test]
    fn test_acquire_zeros_clears_stale_values() {
        let pool = ArrayPool::new(8, 2);
        let mut arr = pool.acquire();
        arr.fill(3.0);
        pool.release(arr);
        let reused = pool.acquire_zeros();
        assert_eq!(pool.stats().hits, 1);
        assert!(reused.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_acquire_filled() {
        let pool = ArrayPool::new(8, 2);
        let fresh = pool.acquire_filled(2.5);
        assert_eq!(fresh.len(), 8);
        assert!(fresh.iter().all(|&x| x == 2.5));
        pool.release(fresh);
        let reused = pool.acquire_filled(-1.0);
        assert_eq!(reused.len(), 8);
        assert!(reused.iter().all(|&x| x == -1.0));
    }

    #[test]
    fn test_pooled_array_scope() {
        let pool = ArrayPool::new(32, 4);
        {
            let mut arr = PooledArray::zeros(&pool);
            arr[0] = 1.0;
            assert_eq!(arr.len(), 32);
        }
        assert_eq!(pool.size(), 1);
    }

    #[test]
    fn test_pooled_array_take() {
        let pool = ArrayPool::new(32, 4);
        let owned = {
            let arr = PooledArray::zeros(&pool);
            arr.take()
        };
        assert_eq!(owned.len(), 32);
        // The array was moved out of the guard, so nothing went back to the pool.
        assert_eq!(pool.size(), 0);
    }

    #[test]
    fn test_multi_pool() {
        let pool = MultiArrayPool::with_sizes(&[32, 64, 128], 4);
        let arr = pool.acquire(50);
        assert_eq!(arr.len(), 64);
        pool.release(arr);
        assert_eq!(pool.stats().returns, 1);
    }

    #[test]
    fn test_multi_pool_unsorted_and_duplicate_sizes() {
        let pool = MultiArrayPool::with_sizes(&[64, 128, 64, 32], 2);
        let arr = pool.acquire(40);
        assert_eq!(arr.len(), 64);
        pool.release(arr);
        assert_eq!(pool.stats().returns, 1);
        assert_eq!(pool.acquire(40).len(), 64);
        assert_eq!(pool.stats().hits, 1);
    }

    #[test]
    fn test_multi_pool_oversized_request() {
        let pool = MultiArrayPool::with_sizes(&[32, 64], 2);
        let arr = pool.acquire_zeros(100);
        assert_eq!(arr.len(), 100);
        assert!(arr.iter().all(|&x| x == 0.0));
        pool.release(arr);
        let stats = pool.stats();
        assert_eq!(stats.returns, 0);
        assert_eq!(stats.hits + stats.misses, 0);
    }

    #[test]
    fn test_pool_warm() {
        let pool = ArrayPool::new(64, 4);
        pool.warm();
        assert_eq!(pool.size(), 4);
        for _ in 0..4 {
            let _ = pool.acquire();
        }
        let stats = pool.stats();
        assert_eq!(stats.hits, 4);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_hit_rate() {
        let stats = PoolStats {
            hits: 80,
            misses: 20,
            returns: 0,
            drops: 0,
        };
        assert!((stats.hit_rate() - 80.0).abs() < 0.01);
        assert_eq!(PoolStats::default().hit_rate(), 0.0);
    }

    #[test]
    fn test_thread_local_pool() {
        let arr = tl_acquire_zeros(100);
        assert!(arr.len() >= 100);
        tl_release(arr);
        let stats = tl_stats();
        assert!(stats.misses >= 1);
    }
}