use core::fmt;
use core::marker::PhantomData;
#[cfg(feature = "tensor-pool")]
use crate::tensor::error::TensorError;
#[cfg(feature = "tensor-pool")]
use crate::tensor::traits::{TensorBase, TensorOps};
#[cfg(feature = "tensor-pool")]
use crate::tensor::dense::DenseTensor;
#[cfg(feature = "tensor-pool")]
use smallvec::SmallVec;
/// Sizing knobs for a tensor pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Free-list capacity reserved up front, and the number of placeholder
    /// tensors created when `preallocate` is set.
    pub initial_capacity: usize,
    /// Maximum number of recycled tensors the free list keeps.
    pub max_capacity: usize,
    /// Whether `TensorPool::new` fills the free list immediately.
    pub preallocate: bool,
    /// Requested buffer alignment in bytes.
    /// NOTE(review): not read by `TensorPool` in this file — confirm whether anything honours it.
    pub alignment: usize,
}

impl Default for PoolConfig {
    /// 16 initial / 1024 max entries, no preallocation, 64-byte alignment.
    fn default() -> Self {
        PoolConfig {
            alignment: 64,
            preallocate: false,
            max_capacity: 1024,
            initial_capacity: 16,
        }
    }
}

impl PoolConfig {
    /// Builds a config with the given capacities and default values for everything else.
    pub fn new(initial_capacity: usize, max_capacity: usize) -> Self {
        let mut config = Self::default();
        config.initial_capacity = initial_capacity;
        config.max_capacity = max_capacity;
        config
    }

    /// Returns this config with `preallocate` replaced.
    pub fn with_preallocate(self, preallocate: bool) -> Self {
        Self { preallocate, ..self }
    }

    /// Returns this config with `alignment` replaced.
    pub fn with_alignment(self, alignment: usize) -> Self {
        Self { alignment, ..self }
    }
}
#[cfg(feature = "tensor-pool")]
/// Pool of reusable `DenseTensor`s, handed out as `PooledTensor`s that return
/// their tensor to `free_list` when dropped.
pub struct TensorPool {
    /// Idle tensors available for reuse (bounded by `config.max_capacity` on recycle).
    free_list: Vec<DenseTensor>,
    /// NOTE(review): in the visible code this is only initialised and cleared,
    /// never set or read — possibly vestigial.
    allocated: bitvec::vec::BitVec,
    config: PoolConfig,
    /// Hit/miss and usage counters updated by `acquire` and `recycle`.
    stats: PoolStats,
}
/// Counters describing how well a tensor pool satisfies requests.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Number of tensor requests (`TensorPool::acquire` calls).
    pub total_allocations: usize,
    /// Requests served from a pooled tensor.
    pub pool_hits: usize,
    /// Requests that needed a newly created tensor.
    pub pool_misses: usize,
    /// Tensors currently checked out.
    pub current_used: usize,
    /// Highest value `current_used` has reached.
    pub peak_used: usize,
}

impl PoolStats {
    /// `count / total_allocations`, or `0.0` before any request was made.
    fn share_of_requests(&self, count: usize) -> f64 {
        match self.total_allocations {
            0 => 0.0,
            total => count as f64 / total as f64,
        }
    }

    /// Fraction of requests served from the pool (`0.0..=1.0`).
    pub fn hit_rate(&self) -> f64 {
        self.share_of_requests(self.pool_hits)
    }

    /// Fraction of requests that required a new tensor.
    pub fn miss_rate(&self) -> f64 {
        self.share_of_requests(self.pool_misses)
    }

    /// The hit rate expressed as a percentage of avoided allocations.
    pub fn allocation_reduction(&self) -> f64 {
        self.hit_rate() * 100.0
    }
}
#[cfg(feature = "tensor-pool")]
impl TensorPool {
    /// Creates a pool from `config`, filling the free list immediately when
    /// `config.preallocate` is set.
    pub fn new(config: PoolConfig) -> Self {
        let preallocate = config.preallocate;
        let mut pool = Self {
            free_list: Vec::with_capacity(config.initial_capacity),
            allocated: bitvec::vec::BitVec::new(),
            config,
            stats: PoolStats::default(),
        };
        if preallocate {
            pool.preallocate();
        }
        pool
    }

    /// Pushes up to `initial_capacity` placeholder tensors (shape `[1]`) onto the
    /// free list, never growing it beyond `max_capacity` (the same bound
    /// `recycle` enforces).
    pub fn preallocate(&mut self) {
        let room = self.config.max_capacity.saturating_sub(self.free_list.len());
        let count = self.config.initial_capacity.min(room);
        self.free_list
            .extend((0..count).map(|_| DenseTensor::zeros(vec![1])));
    }

    /// Hands out a tensor of `shape`, reusing a pooled tensor when possible.
    ///
    /// The most recently returned tensor with exactly the requested element count
    /// is preferred; failing that, the most recent one with more elements is used.
    /// Pooled tensors too small for this request stay in the free list for later
    /// requests instead of being discarded.
    pub fn acquire(&mut self, shape: Vec<usize>) -> PooledTensor<'_> {
        self.stats.total_allocations += 1;
        let needed: usize = shape.iter().product();

        let reusable = self
            .free_list
            .iter()
            .rposition(|t| t.numel() == needed)
            .or_else(|| self.free_list.iter().rposition(|t| t.numel() >= needed));

        let tensor = match reusable {
            Some(idx) => {
                self.stats.pool_hits += 1;
                // NOTE(review): when the pooled tensor has more than `needed`
                // elements this assumes `reshape` accepts a smaller shape —
                // confirm against `DenseTensor::reshape`.
                self.free_list.swap_remove(idx).reshape(&shape)
            }
            None => {
                self.stats.pool_misses += 1;
                DenseTensor::zeros(shape)
            }
        };

        self.stats.current_used += 1;
        self.stats.peak_used = self.stats.peak_used.max(self.stats.current_used);
        PooledTensor::new(tensor, self)
    }

    /// Takes back a tensor released by a `PooledTensor`.
    ///
    /// The tensor is zeroed and kept while the free list holds fewer than
    /// `max_capacity` entries; otherwise it is dropped. Either way it no longer
    /// counts as in use.
    fn recycle(&mut self, mut tensor: DenseTensor) {
        if self.free_list.len() < self.config.max_capacity {
            for val in tensor.data_mut() {
                *val = 0.0;
            }
            self.free_list.push(tensor);
        }
        self.stats.current_used = self.stats.current_used.saturating_sub(1);
    }

    /// Usage counters accumulated since creation or the last `clear`.
    pub fn stats(&self) -> &PoolStats {
        &self.stats
    }

    /// Drops every pooled tensor and resets all statistics.
    pub fn clear(&mut self) {
        self.free_list.clear();
        self.allocated.clear();
        self.stats = PoolStats::default();
    }

    /// Fraction of `max_capacity` occupied by idle tensors in the free list
    /// (`0.0` when `max_capacity` is zero). Checked-out tensors are not counted.
    pub fn utilization(&self) -> f64 {
        if self.config.max_capacity == 0 {
            0.0
        } else {
            self.free_list.len() as f64 / self.config.max_capacity as f64
        }
    }
}
#[cfg(feature = "tensor-pool")]
impl fmt::Debug for TensorPool {
    /// Summarises the pool: idle tensor count, configuration and statistics.
    /// The pooled tensors themselves are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let idle = self.free_list.len();
        let mut builder = f.debug_struct("TensorPool");
        builder.field("free_count", &idle);
        builder.field("config", &self.config);
        builder.field("stats", &self.stats);
        builder.finish()
    }
}
#[cfg(feature = "tensor-pool")]
/// A tensor checked out of a `TensorPool`; dropping it hands the tensor back to
/// the pool via `TensorPool::recycle`.
pub struct PooledTensor<'pool> {
    tensor: DenseTensor,
    /// Owning pool, stored raw; `Drop` skips recycling when this is null.
    pool: *mut TensorPool,
    /// Keeps the `&'pool mut TensorPool` borrow taken in `PooledTensor::new`
    /// alive for as long as this handle exists.
    _marker: PhantomData<&'pool mut TensorPool>,
}
#[cfg(feature = "tensor-pool")]
// SAFETY (review): these impls are unconditional. They are sound only if
// `DenseTensor` and `TensorPool` are themselves `Send`/`Sync` (TODO confirm), and
// only if no two `PooledTensor`s holding the same non-null `pool` pointer can be
// used on different threads. `Drop` mutates the pool through that pointer without
// any synchronisation, so a `Clone` impl that copies `pool` violates this.
unsafe impl<'pool> Send for PooledTensor<'pool> {}
#[cfg(feature = "tensor-pool")]
// Same requirements as the `Send` impl above: shared access only exposes the
// tensor, but any method that reaches the pool pointer through `&self` must not
// hand out aliased mutable access to it.
unsafe impl<'pool> Sync for PooledTensor<'pool> {}
#[cfg(feature = "tensor-pool")]
impl<'pool> PooledTensor<'pool> {
    /// Wraps `tensor` so that dropping the handle recycles it into `pool`.
    fn new(tensor: DenseTensor, pool: &'pool mut TensorPool) -> Self {
        Self {
            tensor,
            pool: pool as *mut TensorPool,
            _marker: PhantomData,
        }
    }

    /// Borrows the wrapped tensor.
    pub fn tensor(&self) -> &DenseTensor {
        &self.tensor
    }

    /// Mutably borrows the wrapped tensor.
    pub fn tensor_mut(&mut self) -> &mut DenseTensor {
        &mut self.tensor
    }

    /// Detaches the tensor from the pool and returns it by value.
    ///
    /// The tensor is not recycled. The pool's `current_used` counter is still
    /// decremented, just as `recycle` does for a tensor it declines to keep.
    /// Otherwise the statistic would stay permanently inflated.
    pub fn into_inner(mut self) -> DenseTensor {
        // SAFETY: `pool` is either null or derived from the `&'pool mut TensorPool`
        // passed to `new`, which `_marker` keeps borrowed while `self` lives; the
        // reference is used only for this one counter update (same requirement as
        // `Drop`).
        if let Some(pool) = unsafe { self.pool.as_mut() } {
            pool.stats.current_used = pool.stats.current_used.saturating_sub(1);
        }
        let tensor = core::mem::take(&mut self.tensor);
        // Skip `Drop`, which would otherwise recycle the emptied placeholder.
        core::mem::forget(self);
        tensor
    }
}
#[cfg(feature = "tensor-pool")]
impl<'pool> core::ops::Deref for PooledTensor<'pool> {
    type Target = DenseTensor;

    /// Lets `DenseTensor` methods be called directly on the handle.
    fn deref(&self) -> &DenseTensor {
        let PooledTensor { tensor, .. } = self;
        tensor
    }
}
#[cfg(feature = "tensor-pool")]
impl<'pool> core::ops::DerefMut for PooledTensor<'pool> {
    /// Mutable counterpart of `deref`: exposes the wrapped tensor for in-place edits.
    fn deref_mut(&mut self) -> &mut DenseTensor {
        let PooledTensor { tensor, .. } = self;
        tensor
    }
}
#[cfg(feature = "tensor-pool")]
impl<'pool> Drop for PooledTensor<'pool> {
    /// Hands the tensor back to its pool. A handle with a null pool pointer
    /// simply drops its tensor.
    fn drop(&mut self) {
        // SAFETY: `pool` is null or was derived from the `&'pool mut TensorPool`
        // given to `PooledTensor::new`, which `_marker` keeps borrowed for 'pool.
        let Some(pool) = (unsafe { self.pool.as_mut() }) else {
            return;
        };
        let released = core::mem::take(&mut self.tensor);
        pool.recycle(released);
    }
}
#[cfg(feature = "tensor-pool")]
impl<'pool> Clone for PooledTensor<'pool> {
    /// Returns a *detached* deep copy: it holds a clone of the tensor and a null
    /// pool pointer, so dropping it never touches the pool.
    ///
    /// Copying the pool pointer instead causes two problems:
    /// - Two handles would mutate the same pool through aliased raw pointers.
    ///   Because `PooledTensor` is unconditionally `Send`/`Sync`, this becomes a
    ///   data race when the handles are dropped on different threads.
    /// - The clone's drop would decrement `current_used` for a tensor that
    ///   `acquire` never counted.
    fn clone(&self) -> Self {
        PooledTensor {
            tensor: self.tensor.clone(),
            pool: core::ptr::null_mut(),
            _marker: PhantomData,
        }
    }
}
#[cfg(feature = "tensor-autograd")]
/// Memory-budgeted store of tensors keyed by caller-chosen ids.
///
/// If saving a tensor would exceed `memory_budget` bytes or `max_saved` entries,
/// the oldest saved tensors (by insertion order) are evicted first.
///
/// NOTE(review): `DenseTensor` and `TensorError` are imported only under the
/// `tensor-pool` feature at the top of this file. This type therefore appears to
/// require that feature as well — confirm the feature dependency.
pub struct GradientCheckpoint {
    saved_tensors: std::collections::HashMap<usize, DenseTensor>,
    /// Saved ids in insertion order (front = oldest); kept in sync with `saved_tensors`.
    order: std::collections::VecDeque<usize>,
    max_saved: usize,
    /// Sum of `nbytes()` over every tensor in `saved_tensors`.
    memory_used: usize,
    memory_budget: usize,
}
#[cfg(feature = "tensor-autograd")]
impl GradientCheckpoint {
    /// Creates an empty store limited to `memory_budget` bytes and 100 entries.
    pub fn new(memory_budget: usize) -> Self {
        Self {
            saved_tensors: std::collections::HashMap::new(),
            order: std::collections::VecDeque::new(),
            max_saved: 100,
            memory_used: 0,
            memory_budget,
        }
    }

    /// Saves `tensor` under `id`, replacing any tensor already stored there.
    ///
    /// Oldest entries are evicted until the tensor fits within both the byte
    /// budget and the entry limit. A tensor larger than the whole budget is
    /// still stored once every other entry has been evicted.
    pub fn save(&mut self, id: usize, tensor: DenseTensor) {
        // Release the previous tensor under `id` first. Otherwise its bytes would
        // stay counted in `memory_used` after the map entry is overwritten.
        self.remove_entry(id);
        let size = tensor.nbytes();
        while self.exceeds_limits(size) {
            if !self.evict_oldest() {
                break;
            }
        }
        if self.saved_tensors.len() < self.max_saved {
            self.memory_used = self.memory_used.saturating_add(size);
            self.saved_tensors.insert(id, tensor);
            self.order.push_back(id);
        }
    }

    /// Borrows the tensor saved under `id`.
    ///
    /// # Errors
    /// `TensorError::MatrixError` if no tensor is stored under `id`.
    pub fn get(&self, id: usize) -> Result<&DenseTensor, TensorError> {
        self.saved_tensors.get(&id).ok_or_else(|| TensorError::MatrixError {
            message: format!("Tensor with id {} not found in pool", id),
        })
    }

    /// Removes and returns the tensor saved under `id`, releasing its bytes.
    ///
    /// # Errors
    /// `TensorError::MatrixError` if no tensor is stored under `id`.
    pub fn take(&mut self, id: usize) -> Result<DenseTensor, TensorError> {
        self.remove_entry(id).ok_or_else(|| TensorError::MatrixError {
            message: format!("Tensor with id {} not found in pool", id),
        })
    }

    /// Drops every saved tensor.
    pub fn clear(&mut self) {
        self.saved_tensors.clear();
        self.order.clear();
        self.memory_used = 0;
    }

    /// Bytes currently held by saved tensors.
    pub fn memory_used(&self) -> usize {
        self.memory_used
    }

    /// Number of saved tensors.
    pub fn len(&self) -> usize {
        self.saved_tensors.len()
    }

    /// Whether no tensors are saved.
    pub fn is_empty(&self) -> bool {
        self.saved_tensors.is_empty()
    }

    /// Whether adding `incoming` bytes would break the byte budget or entry limit.
    fn exceeds_limits(&self, incoming: usize) -> bool {
        self.memory_used.saturating_add(incoming) > self.memory_budget
            || self.saved_tensors.len() >= self.max_saved
    }

    /// Removes `id` from the map and the insertion order, updating `memory_used`.
    fn remove_entry(&mut self, id: usize) -> Option<DenseTensor> {
        let tensor = self.saved_tensors.remove(&id)?;
        self.memory_used = self.memory_used.saturating_sub(tensor.nbytes());
        if let Some(pos) = self.order.iter().position(|&saved| saved == id) {
            self.order.remove(pos);
        }
        Some(tensor)
    }

    /// Evicts the earliest-inserted tensor; returns `false` if nothing was stored.
    fn evict_oldest(&mut self) -> bool {
        while let Some(id) = self.order.pop_front() {
            if let Some(tensor) = self.saved_tensors.remove(&id) {
                self.memory_used = self.memory_used.saturating_sub(tensor.nbytes());
                return true;
            }
        }
        false
    }
}
#[cfg(all(feature = "tensor-pool", test, feature = "std"))]
mod tests {
    use super::*;

    /// Without preallocation a new pool holds no tensors and has served no requests.
    #[test]
    fn test_pool_creation() {
        let pool = TensorPool::new(PoolConfig::new(8, 64));
        assert!(pool.free_list.is_empty());
        assert_eq!(pool.stats.total_allocations, 0);
    }

    /// An acquired tensor has the requested shape and is recycled when dropped.
    #[test]
    fn test_pool_acquire() {
        let mut pool = TensorPool::new(PoolConfig::new(4, 16));
        let handle = pool.acquire(vec![10]);
        assert_eq!(handle.shape(), &[10]);
        drop(handle);
        assert_eq!(pool.free_list.len(), 1);
        assert_eq!(pool.stats.total_allocations, 1);
    }
}
// Gated because `SmallVec` is imported only when `tensor-pool` is enabled; an
// ungated definition fails to compile without that feature.
#[cfg(feature = "tensor-pool")]
/// Hash-map key that buckets arena buffers by their exact shape.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ShapeKey {
    shape: SmallVec<[usize; 4]>,
    /// Always equal to `shape.len()` (set in `new`).
    ndim: usize,
}
#[cfg(feature = "tensor-pool")]
impl ShapeKey {
    /// Builds the key for `shape`.
    fn new(shape: &[usize]) -> Self {
        Self {
            shape: shape.into(),
            ndim: shape.len(),
        }
    }
}
// Gated because `SmallVec` is imported only when `tensor-pool` is enabled.
#[cfg(feature = "tensor-pool")]
/// A previously handed-out arena buffer parked on a free list for reuse.
#[derive(Clone)]
struct ArenaSlice {
    ptr: *mut f64,
    // Recorded for bookkeeping only: reuse recomputes the length from the shape.
    #[allow(dead_code)]
    len: usize,
    shape: SmallVec<[usize; 4]>,
    // Written but never read back; kept as a record of the slice's state.
    #[allow(dead_code)]
    borrowed: bool,
}
// Gated because `SmallVec` is imported only when `tensor-pool` is enabled.
#[cfg(feature = "tensor-pool")]
/// Raw handle to an `f64` buffer, normally carved out of a `TensorArena`.
///
/// The handle does not free its memory when dropped. Arena buffers live until
/// the owning arena is reset or dropped.
pub struct ArenaTensor {
    /// Start of the buffer.
    ptr: *mut f64,
    /// Number of `f64` elements (the product of `shape`).
    len: usize,
    shape: SmallVec<[usize; 4]>,
    /// `true` for buffers handed out by a `TensorArena`, `false` for heap copies
    /// produced by `clone`.
    borrowed: bool,
}
#[cfg(feature = "tensor-pool")]
/// Bump-allocated source of `f64` buffers, with per-shape free lists so that
/// deallocated buffers can be handed out again.
pub struct TensorArena {
    /// Backing bump allocator. Buffers are never returned to it individually;
    /// `reset` clears it wholesale.
    arena: bumpalo::Bump,
    /// Deallocated buffers awaiting reuse, bucketed by exact shape.
    free_lists: std::collections::HashMap<ShapeKey, Vec<ArenaSlice>>,
    stats: ArenaStats,
    /// Initial capacity passed to `bumpalo::Bump::with_capacity`.
    capacity: usize,
}
/// Counters describing arena allocation and reuse.
#[derive(Debug, Clone, Default)]
pub struct ArenaStats {
    /// Buffers freshly carved from the arena.
    pub allocation_count: usize,
    /// Buffers handed back for reuse.
    pub deallocation_count: usize,
    /// Requests served from a free list instead of a fresh allocation.
    pub reuse_count: usize,
    /// Bytes carved from the arena over its lifetime (fresh allocations only).
    pub total_bytes_allocated: usize,
    /// Bytes of buffers currently handed out.
    pub bytes_in_use: usize,
    /// Highest value `bytes_in_use` has reached.
    pub peak_bytes_in_use: usize,
}

impl ArenaStats {
    /// Reuses per fresh allocation (`0.0` before any allocation; may exceed 1.0).
    pub fn reuse_ratio(&self) -> f64 {
        match self.allocation_count {
            0 => 0.0,
            fresh => self.reuse_count as f64 / fresh as f64,
        }
    }

    /// Peak bytes in use divided by total bytes ever allocated (`0.0` when nothing was allocated).
    pub fn memory_efficiency(&self) -> f64 {
        if self.total_bytes_allocated == 0 {
            return 0.0;
        }
        self.peak_bytes_in_use as f64 / self.total_bytes_allocated as f64
    }
}
#[cfg(feature = "tensor-pool")]
impl TensorArena {
    /// Alignment, in bytes, of every buffer carved out of the arena.
    const BUFFER_ALIGN: usize = 64;

    /// Creates an arena with a 16 MiB initial capacity.
    pub fn new() -> Self {
        Self::with_capacity(16 * 1024 * 1024)
    }

    /// Creates an arena whose bump allocator is created with `capacity` bytes.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            arena: bumpalo::Bump::with_capacity(capacity),
            free_lists: std::collections::HashMap::new(),
            stats: ArenaStats::default(),
            capacity,
        }
    }

    /// Returns a buffer for `shape`, reusing a deallocated buffer of the exact
    /// same shape when one is available and otherwise carving a fresh one.
    ///
    /// Reused buffers keep whatever values they last held; fresh buffers are
    /// zero-filled.
    ///
    /// # Errors
    /// `TensorError::AllocationError` when the element count or byte size of
    /// `shape` overflows `usize`, or no valid layout exists for it.
    pub fn allocate(&mut self, shape: &[usize]) -> Result<ArenaTensor, crate::tensor::error::TensorError>
    {
        let (len, bytes) = Self::buffer_size(shape)?;
        let key = ShapeKey::new(shape);
        if let Some(slice) = self.free_lists.get_mut(&key).and_then(Vec::pop) {
            self.stats.reuse_count += 1;
            self.stats.bytes_in_use += bytes;
            self.update_peak();
            return Ok(ArenaTensor {
                ptr: slice.ptr,
                len,
                shape: slice.shape,
                borrowed: true,
            });
        }
        self.allocate_fresh(shape)
    }

    /// Parks `tensor`'s buffer on the free list for its shape.
    ///
    /// Tensors with `borrowed == false` are heap copies made by
    /// `ArenaTensor::clone`. They are not arena memory and were never counted in
    /// `bytes_in_use`, so they are ignored rather than mixed into the free lists.
    ///
    /// NOTE(review): a tensor from a *different* arena, or one handed out before
    /// the last `reset`, cannot be detected here. It would place a dangling
    /// pointer on the free list.
    pub fn deallocate(&mut self, mut tensor: ArenaTensor) {
        if !tensor.borrowed {
            return;
        }
        let bytes = tensor.len * core::mem::size_of::<f64>();
        let key = ShapeKey::new(&tensor.shape);
        let slice = ArenaSlice {
            ptr: tensor.ptr,
            len: tensor.len,
            shape: core::mem::take(&mut tensor.shape),
            borrowed: false,
        };
        self.free_lists.entry(key).or_default().push(slice);
        self.stats.deallocation_count += 1;
        // Saturating: a tensor still held across `reset` (which zeroes the
        // counter) would otherwise underflow it.
        self.stats.bytes_in_use = self.stats.bytes_in_use.saturating_sub(bytes);
    }

    /// Reclaims every arena buffer at once and clears free lists and statistics.
    ///
    /// Any `ArenaTensor` still held by a caller points into reclaimed memory
    /// afterwards and must be neither accessed nor deallocated.
    pub fn reset(&mut self) {
        self.arena.reset();
        self.free_lists.clear();
        self.stats = ArenaStats::default();
    }

    /// Allocation and reuse counters since creation or the last `reset`.
    pub fn stats(&self) -> &ArenaStats {
        &self.stats
    }

    /// Capacity requested at construction.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Bytes of buffers currently handed out.
    pub fn bytes_in_use(&self) -> usize {
        self.stats.bytes_in_use
    }

    /// Raises `peak_bytes_in_use` to the current usage if it is higher.
    fn update_peak(&mut self) {
        self.stats.peak_bytes_in_use = self.stats.peak_bytes_in_use.max(self.stats.bytes_in_use);
    }

    /// Always carves a new zero-filled, 64-byte-aligned buffer for `shape`,
    /// bypassing the free lists.
    ///
    /// # Errors
    /// `TensorError::AllocationError` when the size of `shape` overflows `usize`
    /// or no valid layout exists for it.
    pub fn allocate_fresh(&mut self, shape: &[usize]) -> Result<ArenaTensor, crate::tensor::error::TensorError> {
        let (len, bytes) = Self::buffer_size(shape)?;
        let layout = std::alloc::Layout::from_size_align(bytes, Self::BUFFER_ALIGN).map_err(|e| {
            crate::tensor::error::TensorError::AllocationError {
                message: format!("Failed to create layout: {}", e),
            }
        })?;
        let ptr = self.arena.alloc_layout(layout).as_ptr().cast::<f64>();
        // SAFETY: `ptr` is a fresh allocation of `bytes = len * size_of::<f64>()`
        // bytes aligned to 64, so it is valid for `len` f64 writes. Zero-filling
        // gives every element a defined value before anyone can view it as a slice.
        unsafe { core::ptr::write_bytes(ptr, 0, len) };
        self.stats.allocation_count += 1;
        self.stats.total_bytes_allocated += bytes;
        self.stats.bytes_in_use += bytes;
        self.update_peak();
        Ok(ArenaTensor {
            ptr,
            len,
            shape: shape.into(),
            borrowed: true,
        })
    }

    /// Element count and byte size for `shape`, with an error on `usize` overflow.
    ///
    /// Without this check a wrapped size would produce a tiny buffer paired with
    /// a huge `len`.
    fn buffer_size(shape: &[usize]) -> Result<(usize, usize), crate::tensor::error::TensorError> {
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .and_then(|len| Some((len, len.checked_mul(core::mem::size_of::<f64>())?)))
            .ok_or_else(|| crate::tensor::error::TensorError::AllocationError {
                message: format!("Tensor shape {:?} is too large to allocate", shape),
            })
    }
}
#[cfg(feature = "tensor-pool")]
impl Default for TensorArena {
    /// Equivalent to `TensorArena::new()` (16 MiB initial capacity).
    fn default() -> Self {
        TensorArena::new()
    }
}
#[cfg(feature = "tensor-pool")]
impl fmt::Debug for TensorArena {
    /// Prints capacity, the number of shape buckets and the statistics.
    /// Buffer contents are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let buckets = self.free_lists.len();
        let mut builder = f.debug_struct("TensorArena");
        builder.field("capacity", &self.capacity);
        builder.field("free_lists_count", &buckets);
        builder.field("stats", &self.stats);
        builder.finish()
    }
}
// Gated alongside `ArenaTensor`, whose `SmallVec` field is only in scope with
// `tensor-pool`.
#[cfg(feature = "tensor-pool")]
impl ArenaTensor {
    /// Dimensions the buffer was allocated for.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Number of `f64` elements in the buffer.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Whether the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Raw read pointer to the first element.
    pub fn as_ptr(&self) -> *const f64 {
        self.ptr
    }

    /// Raw write pointer to the first element.
    pub fn as_mut_ptr(&mut self) -> *mut f64 {
        self.ptr
    }

    /// Views the buffer as a slice of `len` elements.
    ///
    /// # Safety
    /// The backing memory must still be live. For arena buffers, the
    /// `TensorArena` they came from must not have been reset or dropped. Every
    /// element must be initialized. No mutable access to the same memory (for
    /// example through another handle) may overlap the returned borrow.
    pub unsafe fn as_slice(&self) -> &[f64] {
        // SAFETY: liveness, initialization and aliasing are guaranteed by the
        // caller per the contract above; `ptr` spans `len` f64s.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }

    /// Views the buffer as a mutable slice of `len` elements.
    ///
    /// # Safety
    /// Same requirements as [`ArenaTensor::as_slice`], and no other access to
    /// the same memory may overlap the returned borrow.
    pub unsafe fn as_mut_slice(&mut self) -> &mut [f64] {
        // SAFETY: guaranteed by the caller per the contract above.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    /// Overwrites every element with `0.0`.
    ///
    /// # Safety
    /// The backing memory must still be live and must not be accessed through
    /// another handle during the call.
    pub unsafe fn zero(&mut self) {
        // SAFETY: guaranteed by the caller; `ptr` is valid for `len` f64 writes.
        unsafe { std::ptr::write_bytes(self.ptr, 0, self.len) };
    }
}
// Gated alongside `ArenaTensor`, whose `SmallVec` field is only in scope with
// `tensor-pool`.
#[cfg(feature = "tensor-pool")]
impl Clone for ArenaTensor {
    /// Deep-copies the buffer into a fresh 64-byte-aligned heap allocation (not
    /// the arena). The copy is marked `borrowed: false`.
    ///
    /// An empty tensor gets a dangling, non-null pointer instead of calling the
    /// allocator, because allocating a zero-size layout is undefined behaviour.
    /// An allocation failure aborts via `handle_alloc_error` instead of writing
    /// through a null pointer.
    ///
    /// NOTE(review): `ArenaTensor`'s `Drop` does not free memory, so this heap
    /// copy is never released.
    fn clone(&self) -> Self {
        let new_ptr = if self.len == 0 {
            core::ptr::NonNull::<f64>::dangling().as_ptr()
        } else {
            let layout = self
                .len
                .checked_mul(core::mem::size_of::<f64>())
                .and_then(|bytes| std::alloc::Layout::from_size_align(bytes, 64).ok())
                .expect("ArenaTensor length exceeds the maximum allocation size");
            // SAFETY: `layout` has a non-zero size because `len > 0`.
            let raw = unsafe { std::alloc::alloc(layout) }.cast::<f64>();
            if raw.is_null() {
                std::alloc::handle_alloc_error(layout);
            }
            // SAFETY: `self.ptr` is valid for `len` reads (handle invariant), and
            // `raw` is a distinct fresh allocation valid for `len` writes.
            unsafe { std::ptr::copy_nonoverlapping(self.ptr, raw, self.len) };
            raw
        };
        ArenaTensor {
            ptr: new_ptr,
            len: self.len,
            shape: self.shape.clone(),
            borrowed: false,
        }
    }
}
// Gated alongside `ArenaTensor`, whose `SmallVec` field is only in scope with
// `tensor-pool`.
#[cfg(feature = "tensor-pool")]
impl Drop for ArenaTensor {
    /// Intentionally a no-op. Arena buffers belong to the bump allocator and are
    /// reclaimed only when the arena is reset or dropped.
    ///
    /// NOTE(review): heap copies produced by `clone` (`borrowed == false`) are
    /// therefore leaked. Freeing them here would require the exact layout that
    /// `clone` used.
    fn drop(&mut self) {}
}
#[cfg(all(feature = "tensor-pool", test, feature = "std"))]
mod arena_tests {
    use super::*;

    /// A new arena reports the requested capacity and has nothing in use.
    #[test]
    fn test_arena_creation() {
        let arena = TensorArena::with_capacity(1024 * 1024);
        assert_eq!(arena.capacity(), 1024 * 1024);
        assert_eq!(arena.bytes_in_use(), 0);
    }

    /// A first allocation is fresh and has the requested shape and length.
    #[test]
    fn test_arena_allocate() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let shape = vec![10, 10];
        let tensor = arena.allocate(&shape).unwrap();
        assert_eq!(tensor.shape(), &[10, 10]);
        assert_eq!(tensor.len(), 100);
        let stats = arena.stats();
        assert_eq!(stats.allocation_count, 1);
        assert_eq!(stats.reuse_count, 0);
    }

    /// A deallocated buffer is handed out again for the same shape.
    #[test]
    fn test_arena_reuse() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let shape = vec![5, 5];
        let tensor1 = arena.allocate(&shape).unwrap();
        assert_eq!(arena.stats().allocation_count, 1);
        arena.deallocate(tensor1);
        assert_eq!(arena.bytes_in_use(), 0);
        let _tensor2 = arena.allocate(&shape).unwrap();
        let stats = arena.stats();
        assert_eq!(stats.allocation_count, 1);
        assert_eq!(stats.reuse_count, 1);
        assert_eq!(arena.bytes_in_use(), 25 * core::mem::size_of::<f64>());
    }

    /// Free lists are per shape: only the matching shape is reused.
    #[test]
    fn test_arena_different_shapes() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let t1 = arena.allocate(&[10]).unwrap();
        let t2 = arena.allocate(&[20]).unwrap();
        let shape1 = t1.shape().to_vec();
        let shape2 = t2.shape().to_vec();
        arena.deallocate(t1);
        arena.deallocate(t2);
        let t3 = arena.allocate(&[10]).unwrap();
        assert_eq!(shape1, vec![10]);
        assert_eq!(shape2, vec![20]);
        assert_eq!(t3.shape(), &[10]);
        let stats = arena.stats();
        assert_eq!(stats.allocation_count, 2);
        assert_eq!(stats.reuse_count, 1);
    }

    /// Matching element counts are not enough; the shape itself must match.
    #[test]
    fn test_arena_same_numel_different_shape_not_reused() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let t = arena.allocate(&[2, 3]).unwrap();
        arena.deallocate(t);
        let u = arena.allocate(&[3, 2]).unwrap();
        assert_eq!(u.shape(), &[3, 2]);
        assert_eq!(arena.stats().allocation_count, 2);
        assert_eq!(arena.stats().reuse_count, 0);
    }

    /// `reset` clears usage and statistics.
    #[test]
    fn test_arena_reset() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let _t1 = arena.allocate(&[100]).unwrap();
        let _t2 = arena.allocate(&[200]).unwrap();
        arena.reset();
        assert_eq!(arena.bytes_in_use(), 0);
        assert_eq!(arena.stats().allocation_count, 0);
        assert_eq!(arena.stats().reuse_count, 0);
    }

    /// Reuse does not add to the total bytes carved from the arena.
    #[test]
    fn test_arena_stats() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let shape = vec![10, 10];
        let size_bytes = 100 * core::mem::size_of::<f64>();
        let t1 = arena.allocate(&shape).unwrap();
        arena.deallocate(t1);
        let _t2 = arena.allocate(&shape).unwrap();
        let stats = arena.stats();
        assert_eq!(stats.total_bytes_allocated, size_bytes);
        assert_eq!(stats.allocation_count, 1);
        assert_eq!(stats.reuse_count, 1);
        assert_eq!(stats.reuse_ratio(), 1.0);
    }

    /// `zero` makes every element read back as `0.0`.
    #[test]
    fn test_arena_tensor_zero() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let mut tensor = arena.allocate(&[10]).unwrap();
        unsafe {
            tensor.zero();
            let slice = tensor.as_slice();
            for &val in slice {
                assert_eq!(val, 0.0);
            }
        }
    }
}