torsh-quantization 0.1.0-alpha.1

Model quantization for ToRSh neural networks
//! # Memory Pool Management for Quantization
//!
//! This module provides advanced memory pooling capabilities to reduce allocation overhead
//! during quantization operations, particularly beneficial for batch processing and
//! inference scenarios.
//!
//! ## Features
//!
//! - **Pre-allocated Pools**: Reusable memory pools for common tensor sizes
//! - **Dynamic Sizing**: Automatic pool expansion based on usage patterns
//! - **Memory Analytics**: Tracking allocation patterns and optimization opportunities
//! - **Thread Safety**: Concurrent access for multi-threaded quantization operations
//!
//! ## Usage
//!
//! ```rust
//! use torsh_quantization::memory_pool::{MemoryPool, PoolConfig};
//! use torsh_tensor::Tensor;
//!
//! // Create a memory pool with configuration
//! let config = PoolConfig::default();
//! let pool = MemoryPool::new(config);
//!
//! // Allocate a tensor from the pool
//! let tensor = pool.allocate_tensor(&[1024, 1024], torsh_core::DType::F32)?;
//!
//! // Use the tensor for quantization operations
//! // ... quantization work ...
//!
//! // Return tensor to pool for reuse
//! pool.release_tensor(tensor);
//! # Ok::<(), torsh_core::TorshError>(())
//! ```
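//!
//! The pool is safe to share across threads: all internal state sits behind
//! mutexes. Below is a minimal sketch using the process-wide pool (the shape
//! and thread count are illustrative, not part of the API):
//!
//! ```rust,no_run
//! use torsh_quantization::memory_pool::MemoryPool;
//! use torsh_core::DType;
//!
//! std::thread::scope(|s| {
//!     for _ in 0..4 {
//!         s.spawn(|| {
//!             let pool = MemoryPool::global();
//!             if let Ok(t) = pool.allocate_tensor(&[256, 256], DType::F32) {
//!                 // ... quantization work ...
//!                 pool.release_tensor(t);
//!             }
//!         });
//!     }
//! });
//! ```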

use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use torsh_core::device::DeviceType;
use torsh_core::Result as TorshResult;
use torsh_core::{DType, TorshError};
use torsh_tensor::Tensor;

/// Configuration for memory pool behavior
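///
/// All fields are public, so individual settings can be overridden with
/// struct-update syntax:
///
/// ```rust
/// use torsh_quantization::memory_pool::PoolConfig;
///
/// // Tighten the memory budget to 256 MiB, keeping the other defaults.
/// let config = PoolConfig {
///     max_total_memory: 256 * 1024 * 1024,
///     ..PoolConfig::default()
/// };
/// assert!(config.enable_analytics);
/// ```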
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of tensors to keep in each size pool
    pub max_tensors_per_size: usize,
    /// Maximum total memory usage in bytes
    pub max_total_memory: usize,
    /// Whether to enable memory usage analytics
    pub enable_analytics: bool,
    /// Pre-allocate common tensor sizes
    pub pre_allocate_sizes: Vec<Vec<usize>>,
    /// Enable cache-aware allocation strategies
    pub enable_cache_awareness: bool,
    /// Memory alignment for cache-friendly allocations (bytes)
    pub memory_alignment: usize,
    /// Automatic garbage collection threshold (fragmentation score 0.0-1.0)
    pub auto_gc_threshold: f64,
    /// Enable adaptive pool sizing based on usage patterns
    pub enable_adaptive_sizing: bool,
    /// Memory pressure monitoring interval (milliseconds)
    pub pressure_check_interval_ms: u64,
    /// Minimum allocation size to track for cache analysis
    pub min_cache_tracked_size: usize,
}

impl Default for PoolConfig {
    fn default() -> Self {
        Self {
            max_tensors_per_size: 16,
            max_total_memory: 1024 * 1024 * 1024, // 1GB
            enable_analytics: true,
            pre_allocate_sizes: vec![
                vec![1, 1],
                vec![32, 32],
                vec![64, 64],
                vec![128, 128],
                vec![256, 256],
                vec![512, 512],
                vec![1024, 1024],
            ],
            enable_cache_awareness: true,
            memory_alignment: 64, // 64-byte alignment for cache lines
            auto_gc_threshold: 0.75,
            enable_adaptive_sizing: true,
            pressure_check_interval_ms: 1000, // Check pressure every second
            min_cache_tracked_size: 1024,     // Track allocations >= 1KB for cache analysis
        }
    }
}

/// Memory usage analytics with advanced metrics
#[derive(Debug, Clone, Default)]
pub struct MemoryAnalytics {
    /// Total allocations requested
    pub total_allocations: usize,
    /// Total deallocations
    pub total_deallocations: usize,
    /// Pool hits (reused tensors)
    pub pool_hits: usize,
    /// Pool misses (new allocations)
    pub pool_misses: usize,
    /// Peak memory usage in bytes
    pub peak_memory_usage: usize,
    /// Current memory usage in bytes
    pub current_memory_usage: usize,
    /// Memory fragmentation score (0.0-1.0, lower is better)
    pub fragmentation_score: f64,
    /// Average allocation size in bytes
    pub avg_allocation_size: usize,
    /// Cache misses (estimated from allocation patterns)
    pub estimated_cache_misses: usize,
    /// Memory pressure events
    pub pressure_events: usize,
    /// Time spent in garbage collection (microseconds)
    pub gc_time_us: u64,
}

impl MemoryAnalytics {
    /// Get pool hit rate as percentage
    pub fn hit_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            (self.pool_hits as f64 / self.total_allocations as f64) * 100.0
        }
    }

    /// Get memory efficiency ratio
    pub fn efficiency_ratio(&self) -> f64 {
        if self.peak_memory_usage == 0 {
            1.0
        } else {
            self.current_memory_usage as f64 / self.peak_memory_usage as f64
        }
    }

    /// Get cache efficiency estimate
    pub fn cache_efficiency(&self) -> f64 {
        if self.total_allocations == 0 {
            100.0
        } else {
            let cache_hits = self
                .total_allocations
                .saturating_sub(self.estimated_cache_misses);
            (cache_hits as f64 / self.total_allocations as f64) * 100.0
        }
    }

    /// Get overall performance score (0.0-100.0, higher is better)
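    ///
    /// The score is a weighted sum: 40% pool hit rate, 30% memory
    /// efficiency, 20% inverse fragmentation, and 10% cache efficiency.
    ///
    /// ```rust
    /// use torsh_quantization::memory_pool::MemoryAnalytics;
    ///
    /// // With no recorded activity the hit-rate term is zero, while the
    /// // remaining terms sit at their neutral defaults.
    /// let analytics = MemoryAnalytics::default();
    /// let score = analytics.performance_score();
    /// assert!(score > 0.0 && score <= 100.0);
    /// ```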
    pub fn performance_score(&self) -> f64 {
        let hit_score = self.hit_rate() * 0.4;
        let efficiency_score = self.efficiency_ratio() * 100.0 * 0.3;
        let fragmentation_score = (1.0 - self.fragmentation_score) * 100.0 * 0.2;
        let cache_score = self.cache_efficiency() * 0.1;

        hit_score + efficiency_score + fragmentation_score + cache_score
    }

    /// Check if memory pool needs attention
    pub fn needs_optimization(&self) -> bool {
        self.fragmentation_score > 0.7 || self.hit_rate() < 50.0 || self.pressure_events > 10
    }

    /// Get recommendations for pool optimization
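    ///
    /// ```rust
    /// use torsh_quantization::memory_pool::MemoryAnalytics;
    ///
    /// // A 10% hit rate falls below the 50% threshold, so at least one
    /// // recommendation is produced.
    /// let analytics = MemoryAnalytics {
    ///     pool_hits: 1,
    ///     total_allocations: 10,
    ///     ..Default::default()
    /// };
    /// assert!(!analytics.get_optimization_recommendations().is_empty());
    /// ```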
    pub fn get_optimization_recommendations(&self) -> Vec<String> {
        let mut recommendations = Vec::new();

        if self.hit_rate() < 50.0 {
            recommendations
                .push("Consider increasing pool sizes for commonly used tensor shapes".to_string());
        }

        if self.fragmentation_score > 0.7 {
            recommendations.push(
                "High fragmentation detected - consider triggering garbage collection".to_string(),
            );
        }

        if self.total_allocations > 0
            && self.estimated_cache_misses as f64 / self.total_allocations as f64 > 0.3
        {
            recommendations.push(
                "Cache-unfriendly allocation patterns detected - consider memory alignment"
                    .to_string(),
            );
        }

        if self.pressure_events > 5 {
            recommendations.push(
                "Memory pressure detected - consider reducing pool sizes or freeing unused memory"
                    .to_string(),
            );
        }

        recommendations
    }
}

/// Key for identifying tensor pools by shape and dtype
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct TensorKey {
    shape: Vec<usize>,
    dtype: DType,
}

/// Thread-safe memory pool for tensor allocation
pub struct MemoryPool {
    config: PoolConfig,
    pools: Arc<Mutex<HashMap<TensorKey, VecDeque<Tensor>>>>,
    analytics: Arc<Mutex<MemoryAnalytics>>,
}

impl MemoryPool {
    /// Create a new memory pool with the given configuration
    pub fn new(config: PoolConfig) -> Self {
        let pool = Self {
            config,
            pools: Arc::new(Mutex::new(HashMap::new())),
            analytics: Arc::new(Mutex::new(MemoryAnalytics::default())),
        };

        // Pre-allocate common sizes if requested
        if !pool.config.pre_allocate_sizes.is_empty() {
            pool.pre_allocate_common_sizes();
        }

        pool
    }

    /// Pre-allocate tensors for common sizes
    fn pre_allocate_common_sizes(&self) {
        for shape in &self.config.pre_allocate_sizes {
            let key = TensorKey {
                shape: shape.clone(),
                dtype: DType::F32,
            };

            if let Ok(mut pools) = self.pools.lock() {
                let pool = pools.entry(key).or_insert_with(VecDeque::new);

                // Pre-allocate a few tensors for this size
                for _ in 0..4 {
                    if let Ok(tensor) = self.create_tensor(shape, DType::F32) {
                        pool.push_back(tensor);
                    }
                }
            }
        }
    }

    /// Allocate a tensor from the pool or create a new one
    pub fn allocate_tensor(&self, shape: &[usize], dtype: DType) -> TorshResult<Tensor> {
        let key = TensorKey {
            shape: shape.to_vec(),
            dtype,
        };

        // Try to get from pool first
        if let Ok(mut pools) = self.pools.lock() {
            if let Some(pool) = pools.get_mut(&key) {
                if let Some(tensor) = pool.pop_front() {
                    // Update analytics
                    if let Ok(mut analytics) = self.analytics.lock() {
                        analytics.total_allocations += 1;
                        analytics.pool_hits += 1;
                    }
                    return Ok(tensor);
                }
            }
        }

        // Create new tensor if not available in pool
        let tensor = self.create_tensor(shape, dtype)?;

        // Update analytics
        if let Ok(mut analytics) = self.analytics.lock() {
            analytics.total_allocations += 1;
            analytics.pool_misses += 1;
        }

        Ok(tensor)
    }

    /// Release a tensor back to the pool for reuse
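    ///
    /// If the pool for this shape and dtype already holds
    /// `max_tensors_per_size` tensors, the tensor is dropped instead of
    /// being retained.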
    pub fn release_tensor(&self, tensor: Tensor) {
        let key = TensorKey {
            shape: tensor.shape().dims().to_vec(),
            dtype: tensor.dtype(),
        };

        if let Ok(mut pools) = self.pools.lock() {
            let pool = pools.entry(key).or_insert_with(VecDeque::new);

            // Only keep tensor if we haven't exceeded the limit
            if pool.len() < self.config.max_tensors_per_size {
                pool.push_back(tensor);
            }
        }

        // Update analytics
        if let Ok(mut analytics) = self.analytics.lock() {
            analytics.total_deallocations += 1;
        }
    }

    /// Create a new tensor (helper method)
    fn create_tensor(&self, shape: &[usize], _dtype: DType) -> TorshResult<Tensor> {
        // For simplicity, the pool backs every dtype with f32 storage;
        // quantization kernels handle conversion to the requested dtype.
        let data: Vec<f32> = vec![0.0; shape.iter().product()];
        Tensor::from_data(data, shape.to_vec(), DeviceType::Cpu)
            .map_err(|e| TorshError::InvalidArgument(e.to_string()))
    }

    /// Get current memory analytics
    pub fn get_analytics(&self) -> MemoryAnalytics {
        self.analytics
            .lock()
            .map(|guard| guard.clone())
            .unwrap_or_default()
    }

    /// Clear all pools and reset analytics
    pub fn clear(&self) {
        if let Ok(mut pools) = self.pools.lock() {
            pools.clear();
        }
        if let Ok(mut analytics) = self.analytics.lock() {
            *analytics = MemoryAnalytics::default();
        }
    }

    /// Get pool statistics
    pub fn get_pool_stats(&self) -> HashMap<String, usize> {
        let mut stats = HashMap::new();

        if let Ok(pools) = self.pools.lock() {
            for (key, pool) in pools.iter() {
                let key_str = format!("{:?}_{:?}", key.shape, key.dtype);
                stats.insert(key_str, pool.len());
            }
        }

        stats
    }
}

/// Convenience functions for common memory pool operations
impl MemoryPool {
    /// Create a global memory pool instance
    pub fn global() -> &'static MemoryPool {
        static GLOBAL_POOL: std::sync::OnceLock<MemoryPool> = std::sync::OnceLock::new();
        GLOBAL_POOL.get_or_init(|| MemoryPool::new(PoolConfig::default()))
    }

    /// Allocate f32 tensor from pool
    pub fn allocate_f32(&self, shape: &[usize]) -> TorshResult<Tensor> {
        self.allocate_tensor(shape, DType::F32)
    }

    /// Allocate an i8 tensor from the pool (common for quantized tensors;
    /// currently backed by f32 storage, see `create_tensor`)
    pub fn allocate_i8(&self, shape: &[usize]) -> TorshResult<Tensor> {
        self.allocate_tensor(shape, DType::I8)
    }

    /// Allocate a u8 tensor from the pool (common for quantized tensors;
    /// currently backed by f32 storage, see `create_tensor`)
    pub fn allocate_u8(&self, shape: &[usize]) -> TorshResult<Tensor> {
        self.allocate_tensor(shape, DType::U8)
    }
}

/// Advanced memory pool management methods
impl MemoryPool {
    /// Trigger garbage collection to reduce fragmentation
    pub fn garbage_collect(&self) -> TorshResult<()> {
        let start_time = std::time::Instant::now();

        if let Ok(mut pools) = self.pools.lock() {
            // Drop empty pools so the map does not accumulate stale entries,
            // and cap each remaining pool at its configured size limit.
            pools.retain(|_, pool| !pool.is_empty());
            for pool in pools.values_mut() {
                pool.truncate(self.config.max_tensors_per_size);
            }

            // Update fragmentation metrics
            if let Ok(mut analytics) = self.analytics.lock() {
                let gc_duration = start_time.elapsed();
                analytics.gc_time_us += gc_duration.as_micros() as u64;

                // Recalculate fragmentation score after GC
                analytics.fragmentation_score = self.calculate_fragmentation_score(&pools);
            }
        }

        Ok(())
    }

    /// Check memory pressure and auto-cleanup if needed
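    ///
    /// Returns `true` when usage exceeds 85% of the configured budget or
    /// fragmentation passes `auto_gc_threshold`; in either case a garbage
    /// collection is triggered automatically. A sketch of periodic polling:
    ///
    /// ```rust,no_run
    /// use torsh_quantization::memory_pool::{MemoryPool, PoolConfig};
    ///
    /// let pool = MemoryPool::new(PoolConfig::default());
    /// if pool.check_memory_pressure() {
    ///     // GC already ran; consult the report for follow-up tuning.
    ///     let report = pool.get_utilization_report();
    ///     eprintln!("pool under pressure, score = {:.1}", report.performance_score);
    /// }
    /// ```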
    pub fn check_memory_pressure(&self) -> bool {
        let analytics = self.get_analytics();
        let memory_usage_ratio =
            analytics.current_memory_usage as f64 / self.config.max_total_memory as f64;

        let high_pressure = memory_usage_ratio > 0.85
            || analytics.fragmentation_score > self.config.auto_gc_threshold;

        if high_pressure {
            // Trigger automatic garbage collection
            let _ = self.garbage_collect();

            // Update pressure events counter
            if let Ok(mut analytics) = self.analytics.lock() {
                analytics.pressure_events += 1;
            }
        }

        high_pressure
    }

    /// Calculate memory fragmentation score
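    ///
    /// Averages two signals: the share of pools that are under half full,
    /// and the fraction of total pool capacity left unused. Both lie in
    /// `0.0..=1.0`, so the score does too (lower is better).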
    fn calculate_fragmentation_score(&self, pools: &HashMap<TensorKey, VecDeque<Tensor>>) -> f64 {
        if pools.is_empty() {
            return 0.0;
        }

        let total_pools = pools.len();
        let mut fragmented_pools = 0;
        let mut total_capacity = 0;
        let mut total_used = 0;

        for (_, pool) in pools.iter() {
            let capacity = self.config.max_tensors_per_size;
            let used = pool.len();

            total_capacity += capacity;
            total_used += used;

            // A pool is considered fragmented if it's less than 50% full
            if used > 0 && used < capacity / 2 {
                fragmented_pools += 1;
            }
        }

        let pool_fragmentation = fragmented_pools as f64 / total_pools as f64;
        let usage_fragmentation = if total_capacity > 0 {
            1.0 - (total_used as f64 / total_capacity as f64)
        } else {
            0.0
        };

        (pool_fragmentation + usage_fragmentation) / 2.0
    }

    /// Estimate cache misses based on allocation patterns
    #[allow(dead_code)]
    fn estimate_cache_misses(&self, allocation_size: usize) -> usize {
        if !self.config.enable_cache_awareness
            || allocation_size < self.config.min_cache_tracked_size
        {
            return 0;
        }

        // Simple heuristic: larger allocations that aren't aligned are more likely to cause cache misses
        let alignment = self.config.memory_alignment;
        let misaligned = allocation_size % alignment != 0;

        if misaligned && allocation_size > alignment * 8 {
            // Estimate 1 cache miss per 64 bytes of misaligned memory
            allocation_size / 64
        } else {
            0
        }
    }

    /// Adaptively adjust pool sizes based on usage patterns
    pub fn adaptive_resize(&self) -> TorshResult<()> {
        if !self.config.enable_adaptive_sizing {
            return Ok(());
        }

        let analytics = self.get_analytics();

        // If the hit rate is low, consider expanding popular pools.
        if analytics.hit_rate() < 50.0 {
            // TODO: analyze which tensor shapes are requested most often and
            // grow their pools accordingly.
        }

        // If fragmentation is high, consider consolidating pools
        if analytics.fragmentation_score > 0.7 {
            let _ = self.garbage_collect();
        }

        Ok(())
    }

    /// Get detailed pool utilization report
    pub fn get_utilization_report(&self) -> PoolUtilizationReport {
        let analytics = self.get_analytics();
        let pool_stats = self.get_pool_stats();

        PoolUtilizationReport {
            total_pools: pool_stats.len(),
            total_tensors_pooled: pool_stats.values().sum(),
            hit_rate: analytics.hit_rate(),
            fragmentation_score: analytics.fragmentation_score,
            cache_efficiency: analytics.cache_efficiency(),
            memory_usage_mb: analytics.current_memory_usage / 1024 / 1024,
            peak_memory_usage_mb: analytics.peak_memory_usage / 1024 / 1024,
            pressure_events: analytics.pressure_events,
            gc_time_ms: analytics.gc_time_us / 1000,
            performance_score: analytics.performance_score(),
            needs_optimization: analytics.needs_optimization(),
            recommendations: analytics.get_optimization_recommendations(),
        }
    }

    /// Prefetch tensors for predicted workload
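    ///
    /// Two tensors are pre-created and pooled per predicted shape, so the
    /// first allocations of a known workload hit the pool. A sketch (the
    /// shapes are illustrative):
    ///
    /// ```rust,no_run
    /// # use torsh_quantization::memory_pool::{MemoryPool, PoolConfig};
    /// # use torsh_core::DType;
    /// let pool = MemoryPool::new(PoolConfig::default());
    /// pool.prefetch_for_workload(&[
    ///     (vec![1, 3, 224, 224], DType::F32),
    ///     (vec![1, 1000], DType::F32),
    /// ])?;
    /// # Ok::<(), torsh_core::TorshError>(())
    /// ```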
    pub fn prefetch_for_workload(
        &self,
        predicted_shapes: &[(Vec<usize>, DType)],
    ) -> TorshResult<()> {
        for (shape, dtype) in predicted_shapes {
            // Pre-allocate a few tensors of this size
            for _ in 0..2 {
                let tensor = self.create_tensor(shape, *dtype)?;
                self.release_tensor(tensor);
            }
        }
        Ok(())
    }
}

/// Detailed pool utilization report
#[derive(Debug, Clone)]
pub struct PoolUtilizationReport {
    pub total_pools: usize,
    pub total_tensors_pooled: usize,
    pub hit_rate: f64,
    pub fragmentation_score: f64,
    pub cache_efficiency: f64,
    pub memory_usage_mb: usize,
    pub peak_memory_usage_mb: usize,
    pub pressure_events: usize,
    pub gc_time_ms: u64,
    pub performance_score: f64,
    pub needs_optimization: bool,
    pub recommendations: Vec<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_pool_basic() {
        let config = PoolConfig {
            pre_allocate_sizes: vec![], // disable pre-allocation for a cleaner test
            ..PoolConfig::default()
        };
        let pool = MemoryPool::new(config);

        // Allocate a tensor
        let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
        assert_eq!(tensor.shape().dims(), &[32, 32]);
        assert_eq!(tensor.dtype(), DType::F32);

        // Release back to pool
        pool.release_tensor(tensor);

        // Allocate same size again - should reuse
        let tensor2 = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
        assert_eq!(tensor2.shape().dims(), &[32, 32]);

        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 2);
        assert_eq!(analytics.pool_hits, 1);
        assert_eq!(analytics.pool_misses, 1);
    }

    #[test]
    fn test_memory_pool_different_sizes() {
        let config = PoolConfig {
            pre_allocate_sizes: vec![], // disable pre-allocation for predictable results
            ..PoolConfig::default()
        };
        let pool = MemoryPool::new(config);

        let tensor1 = pool.allocate_tensor(&[64, 64], DType::F32).unwrap();
        let tensor2 = pool.allocate_tensor(&[128, 128], DType::F32).unwrap();

        assert_eq!(tensor1.shape().dims(), &[64, 64]);
        assert_eq!(tensor2.shape().dims(), &[128, 128]);

        pool.release_tensor(tensor1);
        pool.release_tensor(tensor2);

        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 2);
        assert_eq!(analytics.total_deallocations, 2);
        assert_eq!(analytics.pool_misses, 2);
        assert_eq!(analytics.pool_hits, 0); // No hits since different sizes
    }

    #[test]
    fn test_memory_pool_analytics() {
        let config = PoolConfig {
            pre_allocate_sizes: vec![], // disable pre-allocation for predictable results
            ..PoolConfig::default()
        };
        let pool = MemoryPool::new(config);

        // Allocate and release multiple tensors
        for _ in 0..5 {
            let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
            pool.release_tensor(tensor);
        }

        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 5);
        assert_eq!(analytics.total_deallocations, 5);
        assert_eq!(analytics.pool_hits, 4); // First is miss, rest are hits
        assert_eq!(analytics.pool_misses, 1);
        assert_eq!(analytics.hit_rate(), 80.0);
    }

    #[test]
    fn test_memory_pool_clear() {
        let pool = MemoryPool::new(PoolConfig::default());

        let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
        pool.release_tensor(tensor);

        pool.clear();

        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 0);
        assert_eq!(analytics.total_deallocations, 0);
    }

    #[test]
    fn test_convenience_functions() {
        let pool = MemoryPool::new(PoolConfig::default());

        let f32_tensor = pool.allocate_f32(&[16, 16]).unwrap();
        let i8_tensor = pool.allocate_i8(&[16, 16]).unwrap();
        let u8_tensor = pool.allocate_u8(&[16, 16]).unwrap();

        // Note: Current implementation creates all tensors as F32 for simplicity
        assert_eq!(f32_tensor.dtype(), DType::F32);
        assert_eq!(i8_tensor.dtype(), DType::F32); // Actually F32, not I8
        assert_eq!(u8_tensor.dtype(), DType::F32); // Actually F32, not U8

        // Test that tensors have the correct shape
        assert_eq!(f32_tensor.shape().dims(), &[16, 16]);
        assert_eq!(i8_tensor.shape().dims(), &[16, 16]);
        assert_eq!(u8_tensor.shape().dims(), &[16, 16]);
    }

    #[test]
    fn test_global_pool() {
        let pool = MemoryPool::global();
        let tensor = pool.allocate_f32(&[8, 8]).unwrap();
        assert_eq!(tensor.shape().dims(), &[8, 8]);
        pool.release_tensor(tensor);
    }

    #[test]
    fn test_advanced_analytics() {
        let pool = MemoryPool::new(PoolConfig::default());

        // Allocate and release to generate analytics
        for i in 0..10 {
            let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
            if i % 2 == 0 {
                pool.release_tensor(tensor);
            }
        }

        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 10);
        assert!(analytics.performance_score() >= 0.0);
        assert!(analytics.performance_score() <= 100.0);

        let recommendations = analytics.get_optimization_recommendations();
        // Should have some recommendations given the pattern
        assert!(!recommendations.is_empty() || analytics.performance_score() > 70.0);
    }

    #[test]
    fn test_garbage_collection() {
        let pool = MemoryPool::new(PoolConfig::default());

        // Create some fragmentation
        for i in 0..5 {
            let tensor = pool
                .allocate_tensor(&[i * 10 + 1, i * 10 + 1], DType::F32)
                .unwrap();
            if i % 2 == 0 {
                pool.release_tensor(tensor);
            }
        }

        // Trigger garbage collection
        pool.garbage_collect().unwrap();

        let analytics = pool.get_analytics();
        // GC may complete instantly, so check the recomputed fragmentation
        // score instead of the (always non-negative) unsigned timer.
        assert!((0.0..=1.0).contains(&analytics.fragmentation_score));
    }

    #[test]
    fn test_memory_pressure_detection() {
        let config = PoolConfig {
            max_total_memory: 1024, // very small limit to make pressure reachable
            ..PoolConfig::default()
        };
        let pool = MemoryPool::new(config);

        // This should not trigger pressure initially
        let initial_pressure = pool.check_memory_pressure();
        assert!(!initial_pressure);

        // Allocate enough to potentially trigger pressure
        let _tensors: Vec<_> = (0..10)
            .map(|_| pool.allocate_tensor(&[32, 32], DType::F32).unwrap())
            .collect();

        // Check if pressure is detected (may or may not trigger depending on actual memory usage)
        let _final_pressure = pool.check_memory_pressure();
    }

    #[test]
    fn test_utilization_report() {
        let pool = MemoryPool::new(PoolConfig::default());

        // Generate some activity
        let tensor1 = pool.allocate_tensor(&[64, 64], DType::F32).unwrap();
        let tensor2 = pool.allocate_tensor(&[128, 128], DType::F32).unwrap();
        pool.release_tensor(tensor1);
        pool.release_tensor(tensor2);

        let report = pool.get_utilization_report();
        assert!(report.total_pools > 0);
        assert!(report.hit_rate >= 0.0);
        assert!(report.performance_score >= 0.0);
        assert!(report.performance_score <= 100.0);
    }

    #[test]
    fn test_prefetch_workload() {
        let pool = MemoryPool::new(PoolConfig::default());

        let predicted_shapes = vec![
            (vec![32, 32], DType::F32),
            (vec![64, 64], DType::F32),
            (vec![128, 128], DType::F32),
        ];

        pool.prefetch_for_workload(&predicted_shapes).unwrap();

        // After prefetching, these sizes should have good hit rates
        let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
        assert_eq!(tensor.shape().dims(), &[32, 32]);

        let analytics = pool.get_analytics();
        assert!(analytics.total_allocations > 0);
    }

    #[test]
    fn test_adaptive_config() {
        let config = PoolConfig {
            enable_cache_awareness: true,
            enable_adaptive_sizing: true,
            auto_gc_threshold: 0.5,
            ..PoolConfig::default()
        };

        let pool = MemoryPool::new(config);

        // Test that adaptive features are enabled
        let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
        pool.release_tensor(tensor);

        // Trigger adaptive resize
        pool.adaptive_resize().unwrap();

        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 1);
    }
}