// god_graph/tensor/pool.rs
//! Tensor memory pool: optimizes tensor allocation performance.
//!
//! Reuses previously allocated memory to reduce allocation overhead,
//! which is especially useful for temporary tensors in iterative
//! algorithms (e.g. PageRank, GNN training).
6use core::fmt;
7use core::marker::PhantomData;
8
9#[cfg(feature = "tensor-pool")]
10use crate::tensor::error::TensorError;
11
12#[cfg(feature = "tensor-pool")]
13use crate::tensor::traits::{TensorBase, TensorOps};
14
15#[cfg(feature = "tensor-pool")]
16use crate::tensor::dense::DenseTensor;
17
18#[cfg(feature = "tensor-pool")]
19use smallvec::SmallVec;
20
/// Configuration for a tensor memory pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Number of slots reserved up front.
    pub initial_capacity: usize,
    /// Upper bound on the number of pooled tensors.
    pub max_capacity: usize,
    /// Whether to pre-allocate `initial_capacity` tensors eagerly.
    pub preallocate: bool,
    /// Requested alignment in bytes.
    pub alignment: usize,
}

impl Default for PoolConfig {
    /// Defaults: 16 initial slots, 1024 max, lazy allocation, 64-byte alignment.
    fn default() -> Self {
        PoolConfig {
            initial_capacity: 16,
            max_capacity: 1024,
            preallocate: false,
            alignment: 64,
        }
    }
}

impl PoolConfig {
    /// Builds a configuration with the given capacities; every other
    /// field keeps its `Default` value.
    pub fn new(initial_capacity: usize, max_capacity: usize) -> Self {
        let base = Self::default();
        Self {
            initial_capacity,
            max_capacity,
            ..base
        }
    }

    /// Builder-style setter for the preallocation flag.
    pub fn with_preallocate(mut self, preallocate: bool) -> Self {
        self.preallocate = preallocate;
        self
    }

    /// Builder-style setter for the alignment (bytes).
    pub fn with_alignment(mut self, alignment: usize) -> Self {
        self.alignment = alignment;
        self
    }
}
67
/// Tensor memory pool.
///
/// Provides an efficient tensor allocation and recycling mechanism.
#[cfg(feature = "tensor-pool")]
pub struct TensorPool {
    /// Tensors available for reuse.
    free_list: Vec<DenseTensor>,
    /// Allocation bitmap.
    /// NOTE(review): within this file the bitmap is only ever cleared
    /// (`clear()`), never set or queried — looks vestigial; confirm
    /// before removing.
    allocated: bitvec::vec::BitVec,
    /// Pool configuration.
    config: PoolConfig,
    /// Usage statistics.
    stats: PoolStats,
}
82
/// Pool usage statistics.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Total number of acquisition requests.
    pub total_allocations: usize,
    /// Requests served by reusing a pooled tensor.
    pub pool_hits: usize,
    /// Requests that required a fresh allocation.
    pub pool_misses: usize,
    /// Tensors currently checked out of the pool.
    pub current_used: usize,
    /// High-water mark of `current_used`.
    pub peak_used: usize,
}

impl PoolStats {
    /// Compute the pool hit rate (ratio of reused allocations).
    /// Returns 0.0 when no allocations have been recorded.
    pub fn hit_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.pool_hits as f64 / self.total_allocations as f64
        }
    }

    /// Compute the pool miss rate (ratio of new allocations).
    /// Returns 0.0 when no allocations have been recorded.
    pub fn miss_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.pool_misses as f64 / self.total_allocations as f64
        }
    }

    /// Compute the allocation reduction percentage.
    ///
    /// Defined as `hit_rate() * 100`; delegating keeps the two metrics
    /// consistent by construction instead of duplicating the formula.
    pub fn allocation_reduction(&self) -> f64 {
        self.hit_rate() * 100.0
    }
}
126
#[cfg(feature = "tensor-pool")]
impl TensorPool {
    /// Creates a new tensor pool with the given configuration.
    ///
    /// When `config.preallocate` is set, `initial_capacity` placeholder
    /// tensors are allocated immediately.
    pub fn new(config: PoolConfig) -> Self {
        let preallocate = config.preallocate;
        let mut pool = Self {
            free_list: Vec::with_capacity(config.initial_capacity),
            allocated: bitvec::vec::BitVec::new(),
            config,
            stats: PoolStats::default(),
        };

        if preallocate {
            pool.preallocate();
        }

        pool
    }

    /// Pre-populates the free list with `initial_capacity` placeholder
    /// tensors of shape `[1]`; they are replaced on acquisition unless
    /// the requested shape also has exactly one element.
    pub fn preallocate(&mut self) {
        for _ in 0..self.config.initial_capacity {
            self.free_list.push(DenseTensor::zeros(vec![1]));
        }
    }

    /// Acquires a tensor of the given shape, reusing pooled memory when
    /// possible. The returned guard recycles the tensor on drop.
    pub fn acquire(&mut self, shape: Vec<usize>) -> PooledTensor<'_> {
        self.stats.total_allocations += 1;

        let tensor = match self.free_list.pop() {
            // Reuse only on an exact element-count match: `reshape`
            // requires the same number of elements, so the previous `>=`
            // check would hand an oversized buffer to `reshape` and fail.
            Some(tensor) if tensor.numel() == shape.iter().product::<usize>() => {
                self.stats.pool_hits += 1;
                tensor.reshape(&shape)
            }
            // Popped tensor has the wrong size: drop it, allocate fresh.
            Some(_) => {
                self.stats.pool_misses += 1;
                DenseTensor::zeros(shape)
            }
            // Pool empty: allocate fresh.
            None => {
                self.stats.pool_misses += 1;
                DenseTensor::zeros(shape)
            }
        };

        self.mark_checked_out();
        PooledTensor::new(tensor, self)
    }

    /// Records one tensor leaving the pool and updates the peak counter.
    /// (The original duplicated this bookkeeping in both acquire paths.)
    fn mark_checked_out(&mut self) {
        self.stats.current_used += 1;
        if self.stats.current_used > self.stats.peak_used {
            self.stats.peak_used = self.stats.current_used;
        }
    }

    /// Returns a tensor to the pool (called from `PooledTensor::drop`).
    ///
    /// The tensor's data is zeroed before reuse; if the pool is already
    /// at `max_capacity` the tensor is simply dropped.
    fn recycle(&mut self, mut tensor: DenseTensor) {
        if self.free_list.len() < self.config.max_capacity {
            // Zero the data so a reused tensor starts from a clean state.
            for val in tensor.data_mut() {
                *val = 0.0;
            }
            self.free_list.push(tensor);
        }
        // Otherwise the tensor drops here and its memory is released.

        self.stats.current_used = self.stats.current_used.saturating_sub(1);
    }

    /// Returns the pool's usage statistics.
    pub fn stats(&self) -> &PoolStats {
        &self.stats
    }

    /// Empties the pool and resets all statistics.
    pub fn clear(&mut self) {
        self.free_list.clear();
        self.allocated.clear();
        self.stats = PoolStats::default();
    }

    /// Fraction of `max_capacity` currently held in the free list.
    pub fn utilization(&self) -> f64 {
        if self.config.max_capacity == 0 {
            0.0
        } else {
            self.free_list.len() as f64 / self.config.max_capacity as f64
        }
    }
}
222
#[cfg(feature = "tensor-pool")]
impl fmt::Debug for TensorPool {
    /// Summarizes the pool without dumping tensor contents: free-slot
    /// count, configuration, and statistics.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TensorPool")
            .field("free_count", &self.free_list.len())
            .field("config", &self.config)
            .field("stats", &self.stats)
            .finish()
    }
}
233
/// Pooled tensor: a tensor wrapper that recycles itself into its pool
/// when dropped.
#[cfg(feature = "tensor-pool")]
pub struct PooledTensor<'pool> {
    /// The wrapped tensor.
    tensor: DenseTensor,
    /// Raw pointer back to the owning pool.
    ///
    /// # Safety
    ///
    /// This raw pointer aliases a mutable borrow of the pool but does
    /// not own it. The pool lifetime `'pool` must outlive the
    /// `PooledTensor` so the pointer never dangles.
    /// NOTE(review): the pointer is dereferenced mutably in `Drop` and
    /// `Clone`; see the soundness notes on the `Send`/`Sync` impls.
    pool: *mut TensorPool,
    /// Lifetime marker tying this guard to the pool borrow.
    _marker: PhantomData<&'pool mut TensorPool>,
}
250
/// # Safety
///
/// Claimed sound because:
/// 1. the inner `tensor: DenseTensor` is `Send`;
/// 2. the `pool` pointer is only used to recycle the tensor on `Drop`;
/// 3. the `'pool` lifetime keeps the pointer valid.
///
/// NOTE(review): if two guards referencing the same pool (possible via
/// `Clone`) are dropped on different threads, both call
/// `TensorPool::recycle` through the raw pointer without any
/// synchronization — a data race. Confirm clones never cross threads
/// before relying on this impl.
#[cfg(feature = "tensor-pool")]
unsafe impl<'pool> Send for PooledTensor<'pool> {}

/// # Safety
///
/// Claimed sound because:
/// 1. the inner `tensor: DenseTensor` is `Sync`;
/// 2. the `pool` pointer is only accessed at `Drop` time;
/// 3. mutation goes through `&mut self` methods, enforced by borrowck.
///
/// NOTE(review): premise 2 is contradicted by the `Clone` impl below,
/// which dereferences `pool` from `&self`; concurrent `clone()` calls
/// from two threads could race on the pool. Treat this impl as suspect.
#[cfg(feature = "tensor-pool")]
unsafe impl<'pool> Sync for PooledTensor<'pool> {}
268
#[cfg(feature = "tensor-pool")]
impl<'pool> PooledTensor<'pool> {
    /// Wraps `tensor` in a guard that recycles it into `pool` on drop.
    fn new(tensor: DenseTensor, pool: &'pool mut TensorPool) -> Self {
        Self {
            tensor,
            pool: pool as *mut TensorPool,
            _marker: PhantomData,
        }
    }

    /// Shared reference to the wrapped tensor.
    pub fn tensor(&self) -> &DenseTensor {
        &self.tensor
    }

    /// Mutable reference to the wrapped tensor.
    pub fn tensor_mut(&mut self) -> &mut DenseTensor {
        &mut self.tensor
    }

    /// Consumes the guard and returns the inner tensor without
    /// recycling it into the pool.
    pub fn into_inner(mut self) -> DenseTensor {
        let tensor = core::mem::take(&mut self.tensor);
        // SAFETY: `pool` was created from a `&'pool mut TensorPool` in
        // `new`, and `'pool` guarantees the pool is still alive here.
        // Without this adjustment, skipping `Drop` below would leave
        // `stats.current_used` permanently inflated (the original
        // leaked one count per `into_inner` call).
        unsafe {
            if let Some(pool) = self.pool.as_mut() {
                pool.stats.current_used = pool.stats.current_used.saturating_sub(1);
            }
        }
        core::mem::forget(self); // suppress Drop: must not recycle
        tensor
    }
}
297
#[cfg(feature = "tensor-pool")]
impl<'pool> core::ops::Deref for PooledTensor<'pool> {
    type Target = DenseTensor;

    /// Lets a `PooledTensor` be used wherever `&DenseTensor` is expected.
    fn deref(&self) -> &Self::Target {
        &self.tensor
    }
}

#[cfg(feature = "tensor-pool")]
impl<'pool> core::ops::DerefMut for PooledTensor<'pool> {
    /// Mutable counterpart of `Deref`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.tensor
    }
}
313
#[cfg(feature = "tensor-pool")]
impl<'pool> Drop for PooledTensor<'pool> {
    /// Returns the wrapped tensor to the pool instead of freeing it.
    fn drop(&mut self) {
        // SAFETY: `pool` was derived from a valid `&'pool mut TensorPool`
        // at construction, and the `'pool` lifetime guarantees the pool
        // outlives this guard, so the pointer is still valid here.
        unsafe {
            if let Some(pool) = self.pool.as_mut() {
                // `take` leaves a default tensor behind so the recycled
                // one can be moved out of `&mut self`.
                pool.recycle(core::mem::take(&mut self.tensor));
            }
        }
    }
}
326
#[cfg(feature = "tensor-pool")]
impl<'pool> Clone for PooledTensor<'pool> {
    /// Clones the inner tensor; the clone recycles into the same pool.
    fn clone(&self) -> Self {
        // Build the guard directly instead of calling `new`: the
        // original code materialized a `&mut TensorPool` from
        // `self.pool` while `self` (created from that same borrow) was
        // still alive — undefined behavior under Rust's aliasing rules.
        // Copying the raw pointer requires no reference at all.
        //
        // NOTE(review): the pool's `current_used` statistic is not
        // bumped for clones, so a clone's drop can skew the stats;
        // fixing that properly needs a synchronized pool API.
        PooledTensor {
            tensor: self.tensor.clone(),
            pool: self.pool,
            _marker: PhantomData,
        }
    }
}
334
/// Gradient checkpoint: bounds the memory used by saved activations
/// during backpropagation.
///
/// NOTE(review): this type is gated on `tensor-autograd`, but the
/// `DenseTensor`/`TensorError` imports at the top of the file are gated
/// on `tensor-pool`; building with autograd alone likely fails to
/// compile — confirm the intended feature coupling.
#[cfg(feature = "tensor-autograd")]
pub struct GradientCheckpoint {
    /// Saved tensors, keyed by a caller-assigned id.
    saved_tensors: std::collections::HashMap<usize, DenseTensor>,
    /// Maximum number of tensors kept at once.
    max_saved: usize,
    /// Current memory use in bytes (sum of saved tensors' `nbytes()`).
    memory_used: usize,
    /// Memory budget in bytes.
    memory_budget: usize,
}
347
#[cfg(feature = "tensor-autograd")]
impl GradientCheckpoint {
    /// Creates a checkpoint with the given memory budget in bytes.
    /// At most 100 tensors are retained regardless of budget.
    pub fn new(memory_budget: usize) -> Self {
        Self {
            saved_tensors: std::collections::HashMap::new(),
            max_saved: 100,
            memory_used: 0,
            memory_budget,
        }
    }

    /// Saves `tensor` under `id`, evicting entries while the memory
    /// budget would be exceeded.
    ///
    /// If the checkpoint is already holding `max_saved` entries the
    /// tensor is silently dropped (a later `get(id)` returns an error).
    pub fn save(&mut self, id: usize, tensor: DenseTensor) {
        let size = tensor.nbytes();

        // Evict until the new tensor fits, bailing out when nothing is
        // left to evict. The original evicted exactly once, which could
        // still leave the budget exceeded.
        while self.memory_used + size > self.memory_budget && !self.saved_tensors.is_empty() {
            self.evict_oldest();
        }

        if self.saved_tensors.len() < self.max_saved {
            self.memory_used += size;
            // If `id` was already present, subtract the replaced
            // tensor's size — the original code left it counted forever.
            if let Some(old) = self.saved_tensors.insert(id, tensor) {
                self.memory_used = self.memory_used.saturating_sub(old.nbytes());
            }
        }
    }

    /// Returns a reference to the tensor saved under `id`.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::MatrixError` when no tensor with that id
    /// is currently saved.
    pub fn get(&self, id: usize) -> Result<&DenseTensor, TensorError> {
        self.saved_tensors.get(&id).ok_or_else(|| TensorError::MatrixError {
            message: format!("Tensor with id {} not found in pool", id),
        })
    }

    /// Removes and returns the tensor saved under `id`, updating the
    /// memory accounting.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::MatrixError` when no tensor with that id
    /// is currently saved.
    pub fn take(&mut self, id: usize) -> Result<DenseTensor, TensorError> {
        self.saved_tensors
            .remove(&id)
            .ok_or_else(|| TensorError::MatrixError {
                message: format!("Tensor with id {} not found in pool", id),
            })
            .inspect(|tensor| {
                // saturating_sub guards against accounting drift driving
                // the counter below zero.
                self.memory_used = self.memory_used.saturating_sub(tensor.nbytes());
            })
    }

    /// Drops all saved tensors and resets the memory counter.
    pub fn clear(&mut self) {
        self.saved_tensors.clear();
        self.memory_used = 0;
    }

    /// Current memory use in bytes.
    pub fn memory_used(&self) -> usize {
        self.memory_used
    }

    /// Number of saved tensors.
    pub fn len(&self) -> usize {
        self.saved_tensors.len()
    }

    /// Whether no tensors are saved.
    pub fn is_empty(&self) -> bool {
        self.saved_tensors.is_empty()
    }

    /// Evicts one entry. Despite the name this removes an *arbitrary*
    /// entry (`HashMap` iteration order is unspecified); a true LRU
    /// would require insertion-order tracking.
    fn evict_oldest(&mut self) {
        if let Some((&id, _)) = self.saved_tensors.iter().next() {
            let _ = self.take(id);
        }
    }
}
436
#[cfg(all(feature = "tensor-pool", test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn test_pool_creation() {
        let config = PoolConfig::new(8, 64);
        let pool = TensorPool::new(config);

        // Preallocation is off by default, so the free list starts empty.
        assert_eq!(pool.free_list.len(), 0);
        assert_eq!(pool.stats.total_allocations, 0);
    }

    #[test]
    fn test_pool_acquire() {
        let config = PoolConfig::new(4, 16);
        let mut pool = TensorPool::new(config);

        // Acquire a tensor from the (initially empty) pool.
        {
            let tensor = pool.acquire(vec![10]);
            assert_eq!(tensor.shape(), &[10]);
        } // tensor is dropped here and recycled into the pool

        // The pool should now hold exactly one recycled tensor.
        assert_eq!(pool.free_list.len(), 1);
        assert_eq!(pool.stats.total_allocations, 1);
    }
}
466
// ============================================================================
// TensorArena: Bump Allocator for High-Performance Tensor Allocation
// ============================================================================
470
471/// Shape key for memory reuse matching
472#[derive(Debug, Clone, PartialEq, Eq, Hash)]
473struct ShapeKey {
474    shape: SmallVec<[usize; 4]>,
475    ndim: usize,
476}
477
478impl ShapeKey {
479    fn new(shape: &[usize]) -> Self {
480        Self {
481            shape: shape.into(),
482            ndim: shape.len(),
483        }
484    }
485}
486
/// Memory slice parked in one of the arena's free lists.
#[derive(Clone)]
struct ArenaSlice {
    /// Start pointer (raw pointer into the arena's bump storage).
    ptr: *mut f64,
    /// Number of elements.
    #[allow(dead_code)]
    len: usize,
    /// Shape of the tensor this slice last backed.
    shape: SmallVec<[usize; 4]>,
    /// Whether borrowed (intended to prevent double-free).
    /// NOTE(review): nothing in this file ever reads this flag —
    /// confirm it is still needed.
    borrowed: bool,
}
500
/// Arena-allocated tensor wrapper.
///
/// Holds a raw pointer into a `TensorArena`'s bump storage (or, for
/// clones, into a standalone heap allocation). Dropping it does NOT
/// free the memory — see the `Drop` impl below.
pub struct ArenaTensor {
    /// Data pointer.
    ptr: *mut f64,
    /// Number of elements.
    len: usize,
    /// Shape (inline storage for up to 4 dimensions).
    shape: SmallVec<[usize; 4]>,
    /// Whether the memory belongs to an arena (prevents double-free on
    /// drop). Clones set this to `false`.
    borrowed: bool,
}
512
/// Tensor arena allocator using bumpalo.
///
/// Provides shape-aware memory reuse with a bump allocation strategy.
/// Memory is allocated from the arena and can be reused for tensors
/// with the same shape, avoiding repeated allocations.
#[cfg(feature = "tensor-pool")]
pub struct TensorArena {
    /// Underlying bump arena backing all allocations.
    arena: bumpalo::Bump,
    /// Shape-keyed free lists for memory reuse.
    free_lists: std::collections::HashMap<ShapeKey, Vec<ArenaSlice>>,
    /// Allocation statistics.
    stats: ArenaStats,
    /// Initially requested capacity in bytes (the bump arena itself may
    /// grow beyond this).
    capacity: usize,
}
529
/// Arena allocation statistics.
#[derive(Debug, Clone, Default)]
pub struct ArenaStats {
    /// Fresh allocations taken from the bump arena.
    pub allocation_count: usize,
    /// Tensors returned to a free list.
    pub deallocation_count: usize,
    /// Allocations satisfied from a free list.
    pub reuse_count: usize,
    /// Total bytes ever taken from the bump arena.
    pub total_bytes_allocated: usize,
    /// Bytes currently handed out to live tensors.
    pub bytes_in_use: usize,
    /// High-water mark of `bytes_in_use`.
    pub peak_bytes_in_use: usize,
}

impl ArenaStats {
    /// Reuse ratio: free-list hits divided by fresh allocations
    /// (0.0 when nothing has been allocated yet).
    pub fn reuse_ratio(&self) -> f64 {
        match self.allocation_count {
            0 => 0.0,
            n => self.reuse_count as f64 / n as f64,
        }
    }

    /// Memory efficiency: peak bytes in use divided by total bytes
    /// allocated (0.0 when nothing has been allocated yet).
    pub fn memory_efficiency(&self) -> f64 {
        match self.total_bytes_allocated {
            0 => 0.0,
            total => self.peak_bytes_in_use as f64 / total as f64,
        }
    }
}
566
#[cfg(feature = "tensor-pool")]
impl TensorArena {
    /// Creates a new tensor arena with the default capacity (16 MB).
    pub fn new() -> Self {
        Self::with_capacity(16 * 1024 * 1024)
    }

    /// Creates a new tensor arena with the specified capacity in bytes.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            arena: bumpalo::Bump::with_capacity(capacity),
            free_lists: std::collections::HashMap::new(),
            stats: ArenaStats::default(),
            capacity,
        }
    }

    /// Allocates a zero-initialized tensor with the given shape.
    ///
    /// Tries to reuse memory from the free list if a matching shape
    /// exists; otherwise bump-allocates fresh memory.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::AllocationError` if the memory layout
    /// cannot be constructed.
    pub fn allocate(&mut self, shape: &[usize]) -> Result<ArenaTensor, crate::tensor::error::TensorError> {
        let key = ShapeKey::new(shape);
        let size = shape.iter().product::<usize>();
        let bytes = size * core::mem::size_of::<f64>();

        // Fast path: reuse a parked slice with exactly this shape.
        if let Some(slice) = self.free_lists.get_mut(&key).and_then(|slices| slices.pop()) {
            self.stats.reuse_count += 1;
            self.stats.bytes_in_use += bytes;
            self.update_peak();

            // SAFETY: the slice was produced by a previous allocation of
            // exactly `size` elements and arena memory stays alive until
            // `reset`. Zero it so reused tensors don't expose stale data.
            unsafe { core::ptr::write_bytes(slice.ptr, 0, size) };

            return Ok(ArenaTensor {
                ptr: slice.ptr,
                len: size,
                shape: slice.shape,
                borrowed: true,
            });
        }

        // Slow path: fresh bump allocation.
        let ptr = self.bump_alloc(size)?;
        Ok(ArenaTensor {
            ptr,
            len: size,
            shape: key.shape,
            borrowed: true,
        })
    }

    /// Bump-allocates `size` zero-initialized `f64` elements (64-byte
    /// aligned for SIMD) and updates the fresh-allocation statistics.
    /// Shared by `allocate` and `allocate_fresh`, which previously
    /// duplicated this code.
    fn bump_alloc(&mut self, size: usize) -> Result<*mut f64, crate::tensor::error::TensorError> {
        let bytes = size * core::mem::size_of::<f64>();
        let layout = std::alloc::Layout::from_size_align(bytes, 64).map_err(|e| {
            crate::tensor::error::TensorError::AllocationError {
                message: format!("Failed to create layout: {}", e),
            }
        })?;

        let ptr = self.arena.alloc_layout(layout).as_ptr() as *mut f64;
        // SAFETY: `ptr` points at `size` freshly bump-allocated f64
        // slots. Bump memory is uninitialized; reading it through
        // `as_slice` would be undefined behavior, so zero it before
        // handing it out (the original returned uninitialized memory).
        unsafe { core::ptr::write_bytes(ptr, 0, size) };

        self.stats.allocation_count += 1;
        self.stats.total_bytes_allocated += bytes;
        self.stats.bytes_in_use += bytes;
        self.update_peak();
        Ok(ptr)
    }

    /// Returns a tensor's memory to the shape-keyed free list for reuse.
    pub fn deallocate(&mut self, mut tensor: ArenaTensor) {
        // Clear the flag so the (no-op) Drop impl sees a consistent view.
        tensor.borrowed = false;

        let key = ShapeKey::new(&tensor.shape);
        self.free_lists.entry(key).or_default().push(ArenaSlice {
            ptr: tensor.ptr,
            len: tensor.len,
            shape: tensor.shape.clone(),
            borrowed: false,
        });

        self.stats.deallocation_count += 1;
        // saturating_sub guards against a foreign (e.g. cloned) tensor
        // being handed back that was never counted in `bytes_in_use`.
        self.stats.bytes_in_use = self
            .stats
            .bytes_in_use
            .saturating_sub(tensor.len * core::mem::size_of::<f64>());
    }

    /// Resets the arena, clearing all free lists and statistics.
    ///
    /// This releases all memory back to the system. Any `ArenaTensor`
    /// still alive afterwards dangles; callers must drop all tensors
    /// first (same contract as before).
    pub fn reset(&mut self) {
        self.arena.reset();
        self.free_lists.clear();
        // `Default` already zeroes every counter, including
        // `bytes_in_use` (the original reset it a second time).
        self.stats = ArenaStats::default();
    }

    /// Returns the allocation statistics.
    pub fn stats(&self) -> &ArenaStats {
        &self.stats
    }

    /// Returns the initially requested capacity in bytes.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Returns the bytes currently handed out to live tensors.
    pub fn bytes_in_use(&self) -> usize {
        self.stats.bytes_in_use
    }

    /// Raises the peak-usage high-water mark if needed.
    fn update_peak(&mut self) {
        if self.stats.bytes_in_use > self.stats.peak_bytes_in_use {
            self.stats.peak_bytes_in_use = self.stats.bytes_in_use;
        }
    }

    /// Force-allocates without consulting the free list (for
    /// benchmarking baselines).
    pub fn allocate_fresh(&mut self, shape: &[usize]) -> Result<ArenaTensor, crate::tensor::error::TensorError> {
        let size = shape.iter().product::<usize>();
        let ptr = self.bump_alloc(size)?;
        Ok(ArenaTensor {
            ptr,
            len: size,
            shape: shape.into(),
            borrowed: true,
        })
    }
}
714
#[cfg(feature = "tensor-pool")]
impl Default for TensorArena {
    /// Equivalent to [`TensorArena::new`] (16 MB default capacity).
    fn default() -> Self {
        Self::new()
    }
}
721
#[cfg(feature = "tensor-pool")]
impl fmt::Debug for TensorArena {
    /// Summarizes the arena without touching raw memory: capacity,
    /// number of shape buckets, and statistics.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TensorArena")
            .field("capacity", &self.capacity)
            .field("free_lists_count", &self.free_lists.len())
            .field("stats", &self.stats)
            .finish()
    }
}
732
733// ArenaTensor implementation
734impl ArenaTensor {
735    /// Get the shape
736    pub fn shape(&self) -> &[usize] {
737        &self.shape
738    }
739
740    /// Get the number of elements
741    pub fn len(&self) -> usize {
742        self.len
743    }
744
745    /// Check if empty
746    pub fn is_empty(&self) -> bool {
747        self.len == 0
748    }
749
750    /// Get raw pointer (unsafe)
751    pub fn as_ptr(&self) -> *const f64 {
752        self.ptr
753    }
754
755    /// Get mutable raw pointer (unsafe)
756    pub fn as_mut_ptr(&mut self) -> *mut f64 {
757        self.ptr
758    }
759
760    /// Get slice (unsafe, for reading)
761    ///
762    /// # Safety
763    /// Caller must ensure no other mutable references exist
764    pub unsafe fn as_slice(&self) -> &[f64] {
765        std::slice::from_raw_parts(self.ptr, self.len)
766    }
767
768    /// Get mutable slice (unsafe, for writing)
769    ///
770    /// # Safety
771    /// Caller must ensure no other references exist
772    pub unsafe fn as_mut_slice(&mut self) -> &mut [f64] {
773        std::slice::from_raw_parts_mut(self.ptr, self.len)
774    }
775
776    /// Zero out the data
777    ///
778    /// # Safety
779    /// Caller must ensure exclusive access
780    pub unsafe fn zero(&mut self) {
781        std::ptr::write_bytes(self.ptr, 0, self.len);
782    }
783}
784
785impl Clone for ArenaTensor {
786    fn clone(&self) -> Self {
787        // Clone creates a new allocation, not a reference
788        unsafe {
789            let layout = std::alloc::Layout::from_size_align(
790                self.len * core::mem::size_of::<f64>(),
791                64,
792            ).unwrap();
793            let new_ptr = std::alloc::alloc(layout) as *mut f64;
794            std::ptr::copy_nonoverlapping(self.ptr, new_ptr, self.len);
795            
796            ArenaTensor {
797                ptr: new_ptr,
798                len: self.len,
799                shape: self.shape.clone(),
800                borrowed: false, // Not managed by arena
801            }
802        }
803    }
804}
805
impl Drop for ArenaTensor {
    /// Intentionally a no-op: arena-owned memory is reclaimed in bulk
    /// by `TensorArena::reset`, not per tensor.
    ///
    /// NOTE(review): clones allocate from the global allocator
    /// (`borrowed == false`) and are therefore leaked by this no-op.
    /// Freeing them here is not safe as-is, because `deallocate` also
    /// clears `borrowed` on arena-owned tensors before dropping them,
    /// so the flag cannot distinguish the two cases — needs a design fix.
    fn drop(&mut self) {
        // Memory is managed by the arena; do not free here.
    }
}
812
// Unit tests for TensorArena; exercise allocation, reuse, reset, and
// statistics bookkeeping.
#[cfg(all(feature = "tensor-pool", test, feature = "std"))]
mod arena_tests {
    use super::*;

    #[test]
    fn test_arena_creation() {
        let arena = TensorArena::with_capacity(1024 * 1024);
        assert_eq!(arena.capacity(), 1024 * 1024);
        assert_eq!(arena.bytes_in_use(), 0);
    }

    #[test]
    fn test_arena_allocate() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let shape = vec![10, 10];
        
        let tensor = arena.allocate(&shape).unwrap();
        assert_eq!(tensor.shape(), &[10, 10]);
        assert_eq!(tensor.len(), 100);
        
        // A first allocation is always fresh, never a free-list hit.
        let stats = arena.stats();
        assert_eq!(stats.allocation_count, 1);
        assert_eq!(stats.reuse_count, 0);
    }

    #[test]
    fn test_arena_reuse() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let shape = vec![5, 5];
        
        // Allocate
        let tensor1 = arena.allocate(&shape).unwrap();
        let stats_after_alloc = arena.stats().allocation_count;
        
        // Explicitly deallocate to return to free list
        arena.deallocate(tensor1);
        
        // Allocate again with same shape - should reuse from free list
        let _tensor2 = arena.allocate(&shape).unwrap();
        
        let stats = arena.stats();
        // Should have 1 allocation (first one) and 1 reuse (second one)
        assert_eq!(stats.allocation_count, 1);
        assert_eq!(stats.reuse_count, 1);
    }

    #[test]
    fn test_arena_different_shapes() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        
        let t1 = arena.allocate(&[10]).unwrap();
        let t2 = arena.allocate(&[20]).unwrap();
        let shape1 = t1.shape().to_vec();
        let shape2 = t2.shape().to_vec();
        arena.deallocate(t1);
        arena.deallocate(t2);
        let t3 = arena.allocate(&[10]).unwrap();
        
        // t3 should reuse t1's memory from free list
        assert_eq!(shape1, vec![10]);
        assert_eq!(shape2, vec![20]);
        assert_eq!(t3.shape(), &[10]);
        
        let stats = arena.stats();
        assert_eq!(stats.allocation_count, 2); // t1 and t2
        assert_eq!(stats.reuse_count, 1); // t3 reused t1
    }

    #[test]
    fn test_arena_reset() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        
        let _t1 = arena.allocate(&[100]).unwrap();
        let _t2 = arena.allocate(&[200]).unwrap();
        
        arena.reset();
        
        // Reset clears all statistics, not just the byte counters.
        assert_eq!(arena.bytes_in_use(), 0);
        assert_eq!(arena.stats().allocation_count, 0);
        assert_eq!(arena.stats().reuse_count, 0);
    }

    #[test]
    fn test_arena_stats() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        
        let shape = vec![10, 10];
        let size_bytes = 100 * core::mem::size_of::<f64>();
        
        let t1 = arena.allocate(&shape).unwrap();
        arena.deallocate(t1);
        let _t2 = arena.allocate(&shape).unwrap();
        
        let stats = arena.stats();
        // First allocation + first reuse
        assert_eq!(stats.total_bytes_allocated, size_bytes);
        assert_eq!(stats.allocation_count, 1);
        assert_eq!(stats.reuse_count, 1);
        // reuse_ratio = reuse_count / allocation_count = 1/1 = 1.0 (100% reuse)
        assert_eq!(stats.reuse_ratio(), 1.0);
    }

    #[test]
    fn test_arena_tensor_zero() {
        let mut arena = TensorArena::with_capacity(1024 * 1024);
        let mut tensor = arena.allocate(&[10]).unwrap();
        
        // `zero` and `as_slice` require exclusive access / no aliasing,
        // which a freshly allocated local tensor satisfies.
        unsafe {
            tensor.zero();
            let slice = tensor.as_slice();
            for &val in slice {
                assert_eq!(val, 0.0);
            }
        }
    }
}