1use core::fmt;
7use core::marker::PhantomData;
8
9#[cfg(feature = "tensor-pool")]
10use crate::tensor::error::TensorError;
11
12#[cfg(feature = "tensor-pool")]
13use crate::tensor::traits::{TensorBase, TensorOps};
14
15#[cfg(feature = "tensor-pool")]
16use crate::tensor::dense::DenseTensor;
17
18#[cfg(feature = "tensor-pool")]
19use smallvec::SmallVec;
20
/// Tuning knobs for a `TensorPool`.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Number of free-list slots reserved up front (and, when
    /// `preallocate` is set, the number of tensors created eagerly).
    pub initial_capacity: usize,
    /// Maximum number of recycled tensors kept in the free list.
    pub max_capacity: usize,
    /// When true, `TensorPool::new` pre-fills the free list.
    pub preallocate: bool,
    /// Requested buffer alignment in bytes.
    /// NOTE(review): not read anywhere in this file — confirm it is
    /// consumed elsewhere or wire it into allocation.
    pub alignment: usize,
}
33
34impl Default for PoolConfig {
35 fn default() -> Self {
36 Self {
37 initial_capacity: 16,
38 max_capacity: 1024,
39 preallocate: false,
40 alignment: 64,
41 }
42 }
43}
44
45impl PoolConfig {
46 pub fn new(initial_capacity: usize, max_capacity: usize) -> Self {
48 Self {
49 initial_capacity,
50 max_capacity,
51 ..Default::default()
52 }
53 }
54
55 pub fn with_preallocate(mut self, preallocate: bool) -> Self {
57 self.preallocate = preallocate;
58 self
59 }
60
61 pub fn with_alignment(mut self, alignment: usize) -> Self {
63 self.alignment = alignment;
64 self
65 }
66}
67
/// A reuse pool for `DenseTensor` buffers: trades retained memory for
/// fewer heap allocations on hot acquire/release paths.
#[cfg(feature = "tensor-pool")]
pub struct TensorPool {
    /// Recycled tensors ready to be handed out again.
    free_list: Vec<DenseTensor>,
    /// NOTE(review): only ever cleared, never set or read — appears to be
    /// dead state; confirm before removing the field.
    allocated: bitvec::vec::BitVec,
    /// Tuning knobs captured at construction.
    config: PoolConfig,
    /// Hit/miss/usage counters.
    stats: PoolStats,
}
82
/// Counters tracked by a `TensorPool` over its lifetime.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Total number of acquisitions recorded (hits + misses).
    pub total_allocations: usize,
    /// Acquisitions satisfied by reusing a pooled tensor.
    pub pool_hits: usize,
    /// Acquisitions that required a fresh allocation.
    pub pool_misses: usize,
    /// Tensors currently checked out of the pool.
    pub current_used: usize,
    /// High-water mark of `current_used`.
    pub peak_used: usize,
}

impl PoolStats {
    /// Fraction of acquisitions served from the pool, in `[0.0, 1.0]`.
    /// Returns `0.0` when nothing has been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.pool_hits as f64 / self.total_allocations as f64
        }
    }

    /// Fraction of acquisitions that missed the pool, in `[0.0, 1.0]`.
    /// Returns `0.0` when nothing has been recorded yet.
    pub fn miss_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.pool_misses as f64 / self.total_allocations as f64
        }
    }

    /// Percentage of allocations avoided thanks to pooling.
    ///
    /// Delegates to [`Self::hit_rate`] instead of re-deriving the same
    /// ratio (the original duplicated the division-by-zero guard and the
    /// hits/total quotient verbatim).
    pub fn allocation_reduction(&self) -> f64 {
        self.hit_rate() * 100.0
    }
}
126
#[cfg(feature = "tensor-pool")]
impl TensorPool {
    /// Creates a pool with the given configuration, optionally pre-filling
    /// the free list when `config.preallocate` is set.
    pub fn new(config: PoolConfig) -> Self {
        // Read the flag before `config` is moved into the struct below.
        let preallocate = config.preallocate;
        let mut pool = Self {
            free_list: Vec::with_capacity(config.initial_capacity),
            allocated: bitvec::vec::BitVec::new(),
            config,
            stats: PoolStats::default(),
        };

        if preallocate {
            pool.preallocate();
        }

        pool
    }

    /// Fills the free list with `initial_capacity` single-element tensors.
    ///
    /// NOTE(review): 1-element tensors will almost never satisfy the
    /// `numel() >= requested` check in `acquire`, so this mostly warms the
    /// `Vec`, not the tensor storage — confirm that is the intent.
    pub fn preallocate(&mut self) {
        for _ in 0..self.config.initial_capacity {
            self.free_list.push(DenseTensor::zeros(vec![1]));
        }
    }

    /// Hands out a tensor of `shape`, reusing a pooled tensor when its
    /// element count is large enough, otherwise allocating fresh.
    /// Updates hit/miss counters and the usage high-water mark.
    pub fn acquire(&mut self, shape: Vec<usize>) -> PooledTensor<'_> {
        self.stats.total_allocations += 1;

        if let Some(mut tensor) = self.free_list.pop() {
            // NOTE(review): this reshapes whenever the recycled tensor has
            // *at least* as many elements as requested; if
            // `DenseTensor::reshape` requires an exact element-count match
            // (the common contract), a strictly larger tensor will fail or
            // misbehave here — confirm the reshape semantics.
            if tensor.numel() >= shape.iter().product::<usize>() {
                tensor = tensor.reshape(&shape);
                self.stats.pool_hits += 1;
            } else {
                self.stats.pool_misses += 1;
                // The too-small recycled tensor is dropped and replaced.
                tensor = DenseTensor::zeros(shape);
            }

            self.stats.current_used += 1;
            if self.stats.current_used > self.stats.peak_used {
                self.stats.peak_used = self.stats.current_used;
            }

            PooledTensor::new(tensor, self)
        } else {
            // Empty free list: unconditional miss, allocate fresh.
            self.stats.pool_misses += 1;
            self.stats.current_used += 1;
            if self.stats.current_used > self.stats.peak_used {
                self.stats.peak_used = self.stats.current_used;
            }

            PooledTensor::new(DenseTensor::zeros(shape), self)
        }
    }

    /// Returns a tensor to the pool; called from `PooledTensor`'s `Drop`.
    /// The buffer is zeroed before reuse; tensors beyond `max_capacity`
    /// are simply dropped.
    fn recycle(&mut self, mut tensor: DenseTensor) {
        if self.free_list.len() < self.config.max_capacity {
            // Scrub old contents so a reused tensor starts zeroed, matching
            // the `DenseTensor::zeros` fresh-allocation path.
            for val in tensor.data_mut() {
                *val = 0.0;
            }
            self.free_list.push(tensor);
        }
        // Saturating: `into_inner` forgets handles without a recycle, so
        // this counter may already be out of sync.
        self.stats.current_used = self.stats.current_used.saturating_sub(1);
    }

    /// Read-only view of the pool's counters.
    pub fn stats(&self) -> &PoolStats {
        &self.stats
    }

    /// Drops all pooled tensors and resets statistics.
    pub fn clear(&mut self) {
        self.free_list.clear();
        self.allocated.clear();
        self.stats = PoolStats::default();
    }

    /// Fraction of `max_capacity` currently sitting in the free list.
    /// Returns `0.0` when `max_capacity` is zero.
    pub fn utilization(&self) -> f64 {
        if self.config.max_capacity == 0 {
            0.0
        } else {
            self.free_list.len() as f64 / self.config.max_capacity as f64
        }
    }
}
222
#[cfg(feature = "tensor-pool")]
impl fmt::Debug for TensorPool {
    /// Summarizes the pool (free count, config, stats) without dumping
    /// tensor contents; `DenseTensor` need not implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TensorPool")
            .field("free_count", &self.free_list.len())
            .field("config", &self.config)
            .field("stats", &self.stats)
            .finish()
    }
}
233
/// RAII handle for a tensor checked out of a `TensorPool`; the tensor is
/// returned to the pool when the handle is dropped.
#[cfg(feature = "tensor-pool")]
pub struct PooledTensor<'pool> {
    /// The tensor currently owned by this handle.
    tensor: DenseTensor,
    /// Raw pointer back to the owning pool, dereferenced in `Drop`/`Clone`.
    /// Kept raw (rather than `&'pool mut`) so the borrow does not pin the
    /// pool for the handle's whole lifetime.
    pool: *mut TensorPool,
    /// Ties the handle's lifetime to the pool borrow it was created from,
    /// so the pool cannot be dropped while handles are live.
    _marker: PhantomData<&'pool mut TensorPool>,
}
250
#[cfg(feature = "tensor-pool")]
// SAFETY(review): `PooledTensor` holds a raw `*mut TensorPool` that is
// dereferenced and *mutated* in `Drop` and `Clone` with no synchronization.
// Moving a handle to another thread can therefore race with the pool's
// owner or with sibling handles. This impl looks unsound unless every pool
// is strictly confined to one thread — TODO confirm, and consider removing
// it or putting the pool behind a lock.
unsafe impl<'pool> Send for PooledTensor<'pool> {}

#[cfg(feature = "tensor-pool")]
// SAFETY(review): same concern as `Send` directly above, and `Sync`
// (shared references across threads) is the stronger claim of the two.
unsafe impl<'pool> Sync for PooledTensor<'pool> {}
268
#[cfg(feature = "tensor-pool")]
impl<'pool> PooledTensor<'pool> {
    /// Wraps `tensor` so it is recycled into `pool` on drop.
    fn new(tensor: DenseTensor, pool: &'pool mut TensorPool) -> Self {
        Self {
            tensor,
            pool: pool as *mut TensorPool,
            _marker: PhantomData,
        }
    }

    /// Shared access to the wrapped tensor (also available via `Deref`).
    pub fn tensor(&self) -> &DenseTensor {
        &self.tensor
    }

    /// Mutable access to the wrapped tensor (also via `DerefMut`).
    pub fn tensor_mut(&mut self) -> &mut DenseTensor {
        &mut self.tensor
    }

    /// Detaches the tensor from the pool: the caller takes ownership and
    /// the tensor will NOT be recycled.
    ///
    /// NOTE(review): because `Drop` (and thus `recycle`) never runs, the
    /// pool's `current_used` counter is not decremented for this handle —
    /// usage stats drift high after `into_inner`; confirm if intended.
    pub fn into_inner(mut self) -> DenseTensor {
        // Move the tensor out, then forget `self` so `Drop` does not run
        // and recycle the defaulted placeholder into the pool.
        let tensor = core::mem::take(&mut self.tensor);
        core::mem::forget(self);
        tensor
    }
}
297
#[cfg(feature = "tensor-pool")]
impl<'pool> core::ops::Deref for PooledTensor<'pool> {
    type Target = DenseTensor;

    /// Lets a handle be used wherever a `&DenseTensor` is expected.
    fn deref(&self) -> &Self::Target {
        &self.tensor
    }
}
306
#[cfg(feature = "tensor-pool")]
impl<'pool> core::ops::DerefMut for PooledTensor<'pool> {
    /// Lets a handle be used wherever a `&mut DenseTensor` is expected.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.tensor
    }
}
313
#[cfg(feature = "tensor-pool")]
impl<'pool> Drop for PooledTensor<'pool> {
    /// Hands the tensor back to the pool for reuse.
    fn drop(&mut self) {
        // SAFETY: `_marker` ties this handle to a `&'pool mut TensorPool`
        // borrow, so the pool outlives the handle and the pointer is valid
        // here — assuming single-threaded use (the raw pointer carries no
        // synchronization).
        unsafe {
            if let Some(pool) = self.pool.as_mut() {
                // Swap in a default placeholder and recycle the real buffer.
                pool.recycle(core::mem::take(&mut self.tensor));
            }
        }
    }
}
326
#[cfg(feature = "tensor-pool")]
impl<'pool> Clone for PooledTensor<'pool> {
    /// Deep-copies the tensor into a new handle bound to the same pool.
    fn clone(&self) -> Self {
        // NOTE(review): `&mut *self.pool` manufactures a mutable reference
        // to the pool while `self` (and possibly other handles) hold the
        // same pointer — aliased `&mut` is undefined behavior territory;
        // verify with Miri. Also: the clone does not increment
        // `stats.current_used`, yet its drop will decrement it, so the
        // usage counter can drift low.
        PooledTensor::new(self.tensor.clone(), unsafe { &mut *self.pool })
    }
}
334
/// Bounded cache of tensors saved during the forward pass for later
/// gradient recomputation.
#[cfg(feature = "tensor-autograd")]
pub struct GradientCheckpoint {
    /// Saved tensors keyed by a caller-chosen id.
    saved_tensors: std::collections::HashMap<usize, DenseTensor>,
    /// Hard cap on the number of saved tensors.
    max_saved: usize,
    /// Bytes currently accounted to saved tensors.
    memory_used: usize,
    /// Byte budget; `save` evicts entries to stay within it.
    memory_budget: usize,
}
347
#[cfg(feature = "tensor-autograd")]
impl GradientCheckpoint {
    /// Creates a checkpoint store limited to `memory_budget` bytes and at
    /// most 100 saved tensors.
    pub fn new(memory_budget: usize) -> Self {
        Self {
            saved_tensors: std::collections::HashMap::new(),
            max_saved: 100,
            memory_used: 0,
            memory_budget,
        }
    }

    /// Saves `tensor` under `id`, evicting saved entries until the new
    /// tensor fits inside the budget (or nothing is left to evict).
    ///
    /// The tensor is silently dropped when the store is still at
    /// `max_saved` entries after eviction. Replacing an existing `id`
    /// releases the old entry's bytes from the accounting.
    pub fn save(&mut self, id: usize, tensor: DenseTensor) {
        let size = tensor.nbytes();

        // Evict repeatedly: a single eviction (as the original did) could
        // leave the store far over budget when the evicted entry was small.
        while self.memory_used + size > self.memory_budget && !self.saved_tensors.is_empty() {
            self.evict_oldest();
        }

        if self.saved_tensors.len() < self.max_saved {
            self.memory_used += size;
            // If an entry with this id already existed, its bytes must be
            // released, otherwise `memory_used` drifts upward forever.
            if let Some(old) = self.saved_tensors.insert(id, tensor) {
                self.memory_used = self.memory_used.saturating_sub(old.nbytes());
            }
        }
    }

    /// Borrows the tensor saved under `id`.
    ///
    /// # Errors
    /// `TensorError::MatrixError` when `id` is not present.
    pub fn get(&self, id: usize) -> Result<&DenseTensor, TensorError> {
        self.saved_tensors.get(&id).ok_or_else(|| TensorError::MatrixError {
            message: format!("Tensor with id {} not found in pool", id),
        })
    }

    /// Removes and returns the tensor saved under `id`, releasing its
    /// bytes from the accounting.
    ///
    /// # Errors
    /// `TensorError::MatrixError` when `id` is not present.
    pub fn take(&mut self, id: usize) -> Result<DenseTensor, TensorError> {
        self.saved_tensors
            .remove(&id)
            .ok_or_else(|| TensorError::MatrixError {
                message: format!("Tensor with id {} not found in pool", id),
            })
            .inspect(|tensor| {
                // Saturating: never underflow if the accounting has drifted.
                self.memory_used = self.memory_used.saturating_sub(tensor.nbytes());
            })
    }

    /// Drops every saved tensor and resets the byte accounting.
    pub fn clear(&mut self) {
        self.saved_tensors.clear();
        self.memory_used = 0;
    }

    /// Bytes currently accounted to saved tensors.
    pub fn memory_used(&self) -> usize {
        self.memory_used
    }

    /// Number of saved tensors.
    pub fn len(&self) -> usize {
        self.saved_tensors.len()
    }

    /// True when no tensors are saved.
    pub fn is_empty(&self) -> bool {
        self.saved_tensors.is_empty()
    }

    /// Evicts one saved tensor to reclaim budget.
    ///
    /// NOTE(review): `HashMap` iteration order is arbitrary, so despite
    /// the name this evicts an *unspecified* entry, not the oldest. A true
    /// LRU would need insertion-order tracking (e.g. an id queue).
    fn evict_oldest(&mut self) {
        if let Some((&id, _)) = self.saved_tensors.iter().next() {
            let _ = self.take(id);
        }
    }
}
436
#[cfg(all(feature = "tensor-pool", test, feature = "std"))]
mod tests {
    use super::*;

    /// A freshly built pool starts with an empty free list and zeroed
    /// counters.
    #[test]
    fn test_pool_creation() {
        let pool = TensorPool::new(PoolConfig::new(8, 64));

        assert_eq!(pool.free_list.len(), 0);
        assert_eq!(pool.stats.total_allocations, 0);
    }

    /// `acquire` hands out a tensor of the requested shape, and dropping
    /// the handle recycles the tensor into the free list.
    #[test]
    fn test_pool_acquire() {
        let mut pool = TensorPool::new(PoolConfig::new(4, 16));

        {
            let handle = pool.acquire(vec![10]);
            assert_eq!(handle.shape(), &[10]);
        }

        assert_eq!(pool.free_list.len(), 1);
        assert_eq!(pool.stats.total_allocations, 1);
    }
}
466
/// Hash-map key identifying a tensor shape for free-list bucketing.
///
/// NOTE(review): this item is not gated on `feature = "tensor-pool"`, yet
/// `SmallVec` is only imported under that feature — the file likely fails
/// to compile with the feature disabled; confirm and gate accordingly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ShapeKey {
    // Inline storage for up to 4 dims avoids heap allocation for the
    // common tensor ranks.
    shape: SmallVec<[usize; 4]>,
    // Cached dimension count; redundant with `shape.len()` but also
    // participates in Eq/Hash.
    ndim: usize,
}
477
478impl ShapeKey {
479 fn new(shape: &[usize]) -> Self {
480 Self {
481 shape: shape.into(),
482 ndim: shape.len(),
483 }
484 }
485}
486
/// Internal free-list entry: a recycled arena allocation awaiting reuse.
#[derive(Clone)]
struct ArenaSlice {
    // Raw pointer into the bump arena; valid until the arena is reset.
    ptr: *mut f64,
    // Element count of the allocation (currently unread, kept for parity
    // with `ArenaTensor`).
    #[allow(dead_code)]
    len: usize,
    // Shape this slice was allocated for; used to rebuild keys.
    shape: SmallVec<[usize; 4]>,
    // Free-list entries are always `false`; set while handed out.
    borrowed: bool,
}
500
/// A tensor backed either by a `TensorArena` allocation (`borrowed ==
/// true`) or by a private heap buffer created via `Clone` (`borrowed ==
/// false`).
///
/// NOTE(review): arena-backed pointers are invalidated by
/// `TensorArena::reset`, and nothing prevents an `ArenaTensor` from
/// outliving a reset — confirm the intended usage contract.
pub struct ArenaTensor {
    // Raw pointer to `len` consecutive f64 elements.
    ptr: *mut f64,
    // Number of f64 elements.
    len: usize,
    // Logical shape; the product of dims equals `len`.
    shape: SmallVec<[usize; 4]>,
    // True when the buffer belongs to an arena rather than this value.
    borrowed: bool,
}
512
/// Bump-allocator-backed tensor arena with per-shape free lists.
///
/// Memory is only truly released on `reset`; `deallocate` merely makes a
/// buffer reusable for the same shape.
#[cfg(feature = "tensor-pool")]
pub struct TensorArena {
    /// Backing bump allocator; all arena buffers live here.
    arena: bumpalo::Bump,
    /// Recycled buffers, bucketed by exact shape.
    free_lists: std::collections::HashMap<ShapeKey, Vec<ArenaSlice>>,
    /// Allocation/reuse counters.
    stats: ArenaStats,
    /// Bytes pre-reserved at construction (informational only).
    capacity: usize,
}
529
/// Allocation metrics for a `TensorArena`.
#[derive(Debug, Clone, Default)]
pub struct ArenaStats {
    /// Fresh allocations made from the bump arena.
    pub allocation_count: usize,
    /// Buffers returned via `deallocate`.
    pub deallocation_count: usize,
    /// Requests satisfied from a free list instead of a fresh allocation.
    pub reuse_count: usize,
    /// Total bytes ever carved out of the arena.
    pub total_bytes_allocated: usize,
    /// Bytes currently handed out to live tensors.
    pub bytes_in_use: usize,
    /// High-water mark of `bytes_in_use`.
    pub peak_bytes_in_use: usize,
}

impl ArenaStats {
    /// Fraction of requests satisfied by reuse; `0.0` before any
    /// allocation has been recorded.
    pub fn reuse_ratio(&self) -> f64 {
        match self.allocation_count {
            0 => 0.0,
            n => self.reuse_count as f64 / n as f64,
        }
    }

    /// Ratio of peak live bytes to total bytes ever allocated; `0.0`
    /// before anything has been allocated.
    pub fn memory_efficiency(&self) -> f64 {
        match self.total_bytes_allocated {
            0 => 0.0,
            n => self.peak_bytes_in_use as f64 / n as f64,
        }
    }
}
566
#[cfg(feature = "tensor-pool")]
impl TensorArena {
    /// Creates an arena with a 16 MiB default capacity.
    pub fn new() -> Self {
        Self::with_capacity(16 * 1024 * 1024)
    }

    /// Creates an arena whose bump allocator pre-reserves `capacity` bytes.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            arena: bumpalo::Bump::with_capacity(capacity),
            free_lists: std::collections::HashMap::new(),
            stats: ArenaStats::default(),
            capacity,
        }
    }

    /// Allocates (or reuses) storage for a tensor of `shape`.
    ///
    /// A recycled buffer with the exact same shape is preferred; otherwise
    /// fresh storage is carved from the bump arena.
    ///
    /// NOTE(review): neither path zeroes the buffer — fresh allocations
    /// are uninitialized, so callers must write (or `zero()`) before
    /// reading through `as_slice`.
    ///
    /// # Errors
    /// Returns `TensorError::AllocationError` when a valid layout cannot
    /// be constructed for the requested size.
    pub fn allocate(&mut self, shape: &[usize]) -> Result<ArenaTensor, crate::tensor::error::TensorError>
    {
        let key = ShapeKey::new(shape);
        let size = shape.iter().product::<usize>();

        // Fast path: reuse a previously deallocated buffer of this shape.
        if let Some(slice) = self.free_lists.get_mut(&key).and_then(|v| v.pop()) {
            self.stats.reuse_count += 1;
            self.stats.bytes_in_use += size * core::mem::size_of::<f64>();
            self.update_peak();

            return Ok(ArenaTensor {
                ptr: slice.ptr,
                len: size,
                // `slice` is owned here, so its shape can be moved out.
                shape: slice.shape,
                borrowed: true,
            });
        }

        // Slow path was a verbatim copy of `allocate_fresh` (same layout,
        // same stats, same error message) — delegate instead of duplicating.
        self.allocate_fresh(shape)
    }

    /// Returns `tensor`'s storage to the per-shape free list for reuse.
    ///
    /// The bytes are NOT returned to the bump arena (bump allocators only
    /// free on `reset`); they merely become reusable for future
    /// allocations of the same shape.
    pub fn deallocate(&mut self, tensor: ArenaTensor) {
        let key = ShapeKey::new(&tensor.shape);
        let slice = ArenaSlice {
            ptr: tensor.ptr,
            len: tensor.len,
            shape: tensor.shape.clone(),
            borrowed: false,
        };

        self.free_lists.entry(key).or_default().push(slice);

        self.stats.deallocation_count += 1;
        // Saturating: guards against debug-mode underflow panics if a
        // tensor is deallocated that was never counted (e.g. a clone).
        self.stats.bytes_in_use = self
            .stats
            .bytes_in_use
            .saturating_sub(tensor.len * core::mem::size_of::<f64>());
    }

    /// Frees every allocation and clears the free lists and statistics.
    /// Any outstanding `ArenaTensor` pointers are dangling after this.
    pub fn reset(&mut self) {
        self.arena.reset();
        self.free_lists.clear();
        // `Default` already zeroes every counter, including bytes_in_use.
        self.stats = ArenaStats::default();
    }

    /// Read-only view of the arena's counters.
    pub fn stats(&self) -> &ArenaStats {
        &self.stats
    }

    /// Bytes pre-reserved at construction.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Bytes currently handed out to live tensors.
    pub fn bytes_in_use(&self) -> usize {
        self.stats.bytes_in_use
    }

    /// Records a new high-water mark for `bytes_in_use` if reached.
    fn update_peak(&mut self) {
        if self.stats.bytes_in_use > self.stats.peak_bytes_in_use {
            self.stats.peak_bytes_in_use = self.stats.bytes_in_use;
        }
    }

    /// Always carves fresh storage from the bump arena, bypassing the
    /// free lists. The returned buffer is uninitialized.
    ///
    /// # Errors
    /// Returns `TensorError::AllocationError` when a valid layout cannot
    /// be constructed for the requested size.
    pub fn allocate_fresh(&mut self, shape: &[usize]) -> Result<ArenaTensor, crate::tensor::error::TensorError> {
        let size = shape.iter().product::<usize>();

        // 64-byte alignment keeps buffers cache-line / SIMD friendly.
        let layout = std::alloc::Layout::from_size_align(
            size * core::mem::size_of::<f64>(),
            64,
        ).map_err(|e| crate::tensor::error::TensorError::AllocationError {
            message: format!("Failed to create layout: {}", e),
        })?;

        let ptr = self.arena.alloc_layout(layout).as_ptr() as *mut f64;

        self.stats.allocation_count += 1;
        self.stats.total_bytes_allocated += size * core::mem::size_of::<f64>();
        self.stats.bytes_in_use += size * core::mem::size_of::<f64>();
        self.update_peak();

        Ok(ArenaTensor {
            ptr,
            len: size,
            shape: shape.into(),
            borrowed: true,
        })
    }
}
714
#[cfg(feature = "tensor-pool")]
impl Default for TensorArena {
    /// Equivalent to `TensorArena::new()` (16 MiB initial capacity).
    fn default() -> Self {
        Self::new()
    }
}
721
#[cfg(feature = "tensor-pool")]
impl fmt::Debug for TensorArena {
    /// Summarizes the arena (capacity, bucket count, stats) without
    /// walking the bump allocator or free lists.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TensorArena")
            .field("capacity", &self.capacity)
            .field("free_lists_count", &self.free_lists.len())
            .field("stats", &self.stats)
            .finish()
    }
}
732
impl ArenaTensor {
    /// The tensor's logical shape.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Total number of f64 elements.
    pub fn len(&self) -> usize {
        self.len
    }

    /// True when the tensor has zero elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Raw const pointer to the first element.
    pub fn as_ptr(&self) -> *const f64 {
        self.ptr
    }

    /// Raw mutable pointer to the first element.
    pub fn as_mut_ptr(&mut self) -> *mut f64 {
        self.ptr
    }

    /// Views the buffer as a slice.
    ///
    /// # Safety
    /// `ptr` must still be valid (the owning arena has not been reset) and
    /// the buffer must have been initialized — fresh arena allocations are
    /// uninitialized until written or `zero()`ed.
    pub unsafe fn as_slice(&self) -> &[f64] {
        std::slice::from_raw_parts(self.ptr, self.len)
    }

    /// Views the buffer as a mutable slice.
    ///
    /// # Safety
    /// Same validity requirement as `as_slice`, and the caller must ensure
    /// no other reference to this buffer is alive for the duration.
    pub unsafe fn as_mut_slice(&mut self) -> &mut [f64] {
        std::slice::from_raw_parts_mut(self.ptr, self.len)
    }

    /// Fills the buffer with zeros (all-zero bits are 0.0 for f64, so a
    /// byte-wise fill is valid initialization).
    ///
    /// # Safety
    /// `ptr` must still point to `len` writable f64 elements.
    pub unsafe fn zero(&mut self) {
        std::ptr::write_bytes(self.ptr, 0, self.len);
    }
}
784
785impl Clone for ArenaTensor {
786 fn clone(&self) -> Self {
787 unsafe {
789 let layout = std::alloc::Layout::from_size_align(
790 self.len * core::mem::size_of::<f64>(),
791 64,
792 ).unwrap();
793 let new_ptr = std::alloc::alloc(layout) as *mut f64;
794 std::ptr::copy_nonoverlapping(self.ptr, new_ptr, self.len);
795
796 ArenaTensor {
797 ptr: new_ptr,
798 len: self.len,
799 shape: self.shape.clone(),
800 borrowed: false, }
802 }
803 }
804}
805
impl Drop for ArenaTensor {
    /// Intentionally a no-op: arena-backed buffers (`borrowed == true`)
    /// belong to the `bumpalo` arena and are reclaimed on `reset`.
    ///
    /// NOTE(review): buffers created by `Clone` (`borrowed == false`) live
    /// on the global heap and are leaked here. Freeing `!borrowed` buffers
    /// naively is not safe either, because `TensorArena::deallocate`
    /// consumes arena-backed tensors by value (clearing `borrowed`) and
    /// then drops them — fixing the leak needs a coordinated change in
    /// both places.
    fn drop(&mut self) {
    }
}
812
#[cfg(all(feature = "tensor-pool", test, feature = "std"))]
mod arena_tests {
    use super::*;

    /// A new arena reports its configured capacity and no live bytes.
    #[test]
    fn test_arena_creation() {
        let arena = TensorArena::with_capacity(1024 * 1024);
        assert_eq!(arena.capacity(), 1024 * 1024);
        assert_eq!(arena.bytes_in_use(), 0);
    }

    /// A first allocation is fresh (no reuse) and reports shape/len.
    #[test]
    fn test_arena_allocate() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let shape = vec![10, 10];

        let tensor = arena.allocate(&shape).unwrap();
        assert_eq!(tensor.shape(), &[10, 10]);
        assert_eq!(tensor.len(), 100);

        let stats = arena.stats();
        assert_eq!(stats.allocation_count, 1);
        assert_eq!(stats.reuse_count, 0);
    }

    /// Deallocating and re-allocating the same shape hits the free list.
    #[test]
    fn test_arena_reuse() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let shape = vec![5, 5];

        let tensor1 = arena.allocate(&shape).unwrap();
        // Was captured but never used (unused-variable warning); assert it.
        let stats_after_alloc = arena.stats().allocation_count;
        assert_eq!(stats_after_alloc, 1);

        arena.deallocate(tensor1);

        let _tensor2 = arena.allocate(&shape).unwrap();

        let stats = arena.stats();
        // Reuse does not count as a fresh allocation.
        assert_eq!(stats.allocation_count, 1);
        assert_eq!(stats.reuse_count, 1);
    }

    /// Free lists are bucketed by shape: only exact matches are reused.
    #[test]
    fn test_arena_different_shapes() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);

        let t1 = arena.allocate(&[10]).unwrap();
        let t2 = arena.allocate(&[20]).unwrap();
        let shape1 = t1.shape().to_vec();
        let shape2 = t2.shape().to_vec();
        arena.deallocate(t1);
        arena.deallocate(t2);
        let t3 = arena.allocate(&[10]).unwrap();

        assert_eq!(shape1, vec![10]);
        assert_eq!(shape2, vec![20]);
        assert_eq!(t3.shape(), &[10]);

        let stats = arena.stats();
        // Two fresh allocations up front; the third request reused [10].
        assert_eq!(stats.allocation_count, 2);
        assert_eq!(stats.reuse_count, 1);
    }

    /// `reset` releases everything and zeroes the statistics.
    #[test]
    fn test_arena_reset() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);

        let _t1 = arena.allocate(&[100]).unwrap();
        let _t2 = arena.allocate(&[200]).unwrap();

        arena.reset();

        assert_eq!(arena.bytes_in_use(), 0);
        assert_eq!(arena.stats().allocation_count, 0);
        assert_eq!(arena.stats().reuse_count, 0);
    }

    /// Byte accounting counts fresh allocations once; reuse adds no new
    /// total bytes.
    #[test]
    fn test_arena_stats() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);

        let shape = vec![10, 10];
        let size_bytes = 100 * core::mem::size_of::<f64>();

        let t1 = arena.allocate(&shape).unwrap();
        arena.deallocate(t1);
        let _t2 = arena.allocate(&shape).unwrap();

        let stats = arena.stats();
        assert_eq!(stats.total_bytes_allocated, size_bytes);
        assert_eq!(stats.allocation_count, 1);
        assert_eq!(stats.reuse_count, 1);
        assert_eq!(stats.reuse_ratio(), 1.0);
    }

    /// `zero()` initializes the (otherwise uninitialized) buffer.
    #[test]
    fn test_arena_tensor_zero() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let mut tensor = arena.allocate(&[10]).unwrap();

        unsafe {
            tensor.zero();
            let slice = tensor.as_slice();
            for &val in slice {
                assert_eq!(val, 0.0);
            }
        }
    }
}