use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use torsh_core::device::DeviceType;
use torsh_core::Result as TorshResult;
use torsh_core::{DType, TorshError};
use torsh_tensor::Tensor;
/// Tuning knobs for the tensor memory pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
/// Maximum pooled tensors kept per (shape, dtype) bucket; extras are dropped on release.
pub max_tensors_per_size: usize,
/// Memory budget in bytes used by `check_memory_pressure`.
pub max_total_memory: usize,
/// Whether analytics collection is desired.
/// NOTE(review): not consulted anywhere in this file — counters are always updated; confirm intended use.
pub enable_analytics: bool,
/// Shapes pre-populated with F32 tensors at pool construction.
pub pre_allocate_sizes: Vec<Vec<usize>>,
/// Enables the `estimate_cache_misses` heuristic.
pub enable_cache_awareness: bool,
/// Alignment in bytes used by the cache-miss heuristic.
pub memory_alignment: usize,
/// Fragmentation score above which memory pressure is reported.
pub auto_gc_threshold: f64,
/// Enables `adaptive_resize` behavior.
pub enable_adaptive_sizing: bool,
/// Suggested interval between pressure checks, in milliseconds.
/// NOTE(review): not read in this file — presumably consumed by a caller; verify.
pub pressure_check_interval_ms: u64,
/// Minimum allocation size in bytes tracked by the cache-miss heuristic.
pub min_cache_tracked_size: usize,
}
impl Default for PoolConfig {
    /// Default tuning: 1 GiB budget, 16 tensors per bucket, warm-up of
    /// common square shapes, cache awareness and adaptive sizing enabled.
    fn default() -> Self {
        // Square warm-up shapes: 1x1 plus powers of two from 32x32 to 1024x1024.
        let warm_shapes: Vec<Vec<usize>> = std::iter::once(1usize)
            .chain((5..=10).map(|p| 1usize << p))
            .map(|n| vec![n, n])
            .collect();
        Self {
            max_tensors_per_size: 16,
            max_total_memory: 1024 * 1024 * 1024, // 1 GiB
            enable_analytics: true,
            pre_allocate_sizes: warm_shapes,
            enable_cache_awareness: true,
            memory_alignment: 64, // common cache-line size
            auto_gc_threshold: 0.75,
            enable_adaptive_sizing: true,
            pressure_check_interval_ms: 1000,
            min_cache_tracked_size: 1024,
        }
    }
}
/// Counters and derived metrics describing pool behavior.
///
/// All counters are plain public fields updated by the pool; the methods
/// compute derived scores without mutating state.
#[derive(Debug, Clone, Default)]
pub struct MemoryAnalytics {
    /// Total allocation requests served (hits + misses).
    pub total_allocations: usize,
    /// Tensors handed back via `release_tensor`.
    pub total_deallocations: usize,
    /// Allocations satisfied from an existing pooled tensor.
    pub pool_hits: usize,
    /// Allocations that required creating a fresh tensor.
    pub pool_misses: usize,
    /// High-water mark of tracked memory usage, in bytes.
    pub peak_memory_usage: usize,
    /// Currently tracked memory usage, in bytes.
    pub current_memory_usage: usize,
    /// Fragmentation estimate in `[0.0, 1.0]` (higher is worse).
    pub fragmentation_score: f64,
    /// Mean allocation size, in bytes.
    pub avg_allocation_size: usize,
    /// Heuristic count of cache-unfriendly allocations.
    pub estimated_cache_misses: usize,
    /// Number of times memory pressure was detected.
    pub pressure_events: usize,
    /// Cumulative time spent in garbage collection, in microseconds.
    pub gc_time_us: u64,
}

impl MemoryAnalytics {
    /// Percentage of allocations served from the pool, in `[0, 100]`.
    /// Returns 0.0 when nothing has been allocated yet.
    pub fn hit_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            (self.pool_hits as f64 / self.total_allocations as f64) * 100.0
        }
    }

    /// Ratio of current to peak memory usage; 1.0 when no peak is recorded.
    pub fn efficiency_ratio(&self) -> f64 {
        if self.peak_memory_usage == 0 {
            1.0
        } else {
            self.current_memory_usage as f64 / self.peak_memory_usage as f64
        }
    }

    /// Percentage of allocations estimated to be cache-friendly, in `[0, 100]`.
    /// Returns 100.0 when nothing has been allocated yet.
    pub fn cache_efficiency(&self) -> f64 {
        if self.total_allocations == 0 {
            100.0
        } else {
            let cache_hits = self
                .total_allocations
                .saturating_sub(self.estimated_cache_misses);
            (cache_hits as f64 / self.total_allocations as f64) * 100.0
        }
    }

    /// Weighted composite score: 40% hit rate, 30% memory efficiency,
    /// 20% inverse fragmentation, 10% cache efficiency.
    pub fn performance_score(&self) -> f64 {
        let hit_score = self.hit_rate() * 0.4;
        let efficiency_score = self.efficiency_ratio() * 100.0 * 0.3;
        let fragmentation_score = (1.0 - self.fragmentation_score) * 100.0 * 0.2;
        let cache_score = self.cache_efficiency() * 0.1;
        hit_score + efficiency_score + fragmentation_score + cache_score
    }

    /// True when any metric crosses its "needs attention" threshold.
    pub fn needs_optimization(&self) -> bool {
        self.fragmentation_score > 0.7 || self.hit_rate() < 50.0 || self.pressure_events > 10
    }

    /// Human-readable tuning suggestions derived from the current metrics.
    pub fn get_optimization_recommendations(&self) -> Vec<String> {
        let mut recommendations = Vec::new();
        if self.hit_rate() < 50.0 {
            recommendations
                .push("Consider increasing pool sizes for commonly used tensor shapes".to_string());
        }
        if self.fragmentation_score > 0.7 {
            recommendations.push(
                "High fragmentation detected - consider triggering garbage collection".to_string(),
            );
        }
        // Fix: guard the division. With zero allocations the original
        // expression was `x / 0.0` — `inf` (or NaN), which could emit a
        // bogus cache recommendation when no allocation ever happened.
        if self.total_allocations > 0
            && self.estimated_cache_misses as f64 / self.total_allocations as f64 > 0.3
        {
            recommendations.push(
                "Cache-unfriendly allocation patterns detected - consider memory alignment"
                    .to_string(),
            );
        }
        if self.pressure_events > 5 {
            recommendations.push(
                "Memory pressure detected - consider reducing pool sizes or freeing unused memory"
                    .to_string(),
            );
        }
        recommendations
    }
}
/// Map key for a pool bucket: tensors are reusable only when both the exact
/// shape and the dtype match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct TensorKey {
shape: Vec<usize>,
dtype: DType,
}
/// Size-bucketed tensor pool with usage analytics.
///
/// Both maps sit behind `Arc<Mutex<..>>` so a pool can be shared across
/// threads; poisoned locks are skipped best-effort by the methods.
pub struct MemoryPool {
config: PoolConfig,
// Free tensors grouped by exact (shape, dtype); FIFO reuse via VecDeque.
pools: Arc<Mutex<HashMap<TensorKey, VecDeque<Tensor>>>>,
// Counters updated on allocate/release/gc.
analytics: Arc<Mutex<MemoryAnalytics>>,
}
impl MemoryPool {
pub fn new(config: PoolConfig) -> Self {
let pool = Self {
config,
pools: Arc::new(Mutex::new(HashMap::new())),
analytics: Arc::new(Mutex::new(MemoryAnalytics::default())),
};
if !pool.config.pre_allocate_sizes.is_empty() {
pool.pre_allocate_common_sizes();
}
pool
}
fn pre_allocate_common_sizes(&self) {
for shape in &self.config.pre_allocate_sizes {
let key = TensorKey {
shape: shape.clone(),
dtype: DType::F32,
};
if let Ok(mut pools) = self.pools.lock() {
let pool = pools.entry(key).or_insert_with(VecDeque::new);
for _ in 0..4 {
if let Ok(tensor) = self.create_tensor(shape, DType::F32) {
pool.push_back(tensor);
}
}
}
}
}
pub fn allocate_tensor(&self, shape: &[usize], dtype: DType) -> TorshResult<Tensor> {
let key = TensorKey {
shape: shape.to_vec(),
dtype,
};
if let Ok(mut pools) = self.pools.lock() {
if let Some(pool) = pools.get_mut(&key) {
if let Some(tensor) = pool.pop_front() {
if let Ok(mut analytics) = self.analytics.lock() {
analytics.total_allocations += 1;
analytics.pool_hits += 1;
}
return Ok(tensor);
}
}
}
let tensor = self.create_tensor(shape, dtype)?;
if let Ok(mut analytics) = self.analytics.lock() {
analytics.total_allocations += 1;
analytics.pool_misses += 1;
}
Ok(tensor)
}
pub fn release_tensor(&self, tensor: Tensor) {
let key = TensorKey {
shape: tensor.shape().dims().to_vec(),
dtype: tensor.dtype(),
};
if let Ok(mut pools) = self.pools.lock() {
let pool = pools.entry(key).or_insert_with(VecDeque::new);
if pool.len() < self.config.max_tensors_per_size {
pool.push_back(tensor);
}
}
if let Ok(mut analytics) = self.analytics.lock() {
analytics.total_deallocations += 1;
}
}
fn create_tensor(&self, shape: &[usize], dtype: DType) -> TorshResult<Tensor> {
match dtype {
DType::F32 => {
let data: Vec<f32> = vec![0.0; shape.iter().product()];
Tensor::from_data(data, shape.to_vec(), DeviceType::Cpu)
.map_err(|e| TorshError::InvalidArgument(e.to_string()))
}
_ => {
let data: Vec<f32> = vec![0.0; shape.iter().product()];
Tensor::from_data(data, shape.to_vec(), DeviceType::Cpu)
.map_err(|e| TorshError::InvalidArgument(e.to_string()))
}
}
}
pub fn get_analytics(&self) -> MemoryAnalytics {
self.analytics
.lock()
.map(|guard| guard.clone())
.unwrap_or_default()
}
pub fn clear(&self) {
if let Ok(mut pools) = self.pools.lock() {
pools.clear();
}
if let Ok(mut analytics) = self.analytics.lock() {
*analytics = MemoryAnalytics::default();
}
}
pub fn get_pool_stats(&self) -> HashMap<String, usize> {
let mut stats = HashMap::new();
if let Ok(pools) = self.pools.lock() {
for (key, pool) in pools.iter() {
let key_str = format!("{:?}_{:?}", key.shape, key.dtype);
stats.insert(key_str, pool.len());
}
}
stats
}
}
impl MemoryPool {
/// Process-wide shared pool with default config, lazily initialized on first use.
pub fn global() -> &'static MemoryPool {
static GLOBAL_POOL: std::sync::OnceLock<MemoryPool> = std::sync::OnceLock::new();
GLOBAL_POOL.get_or_init(|| MemoryPool::new(PoolConfig::default()))
}
/// Allocates an F32 tensor of `shape`.
pub fn allocate_f32(&self, shape: &[usize]) -> TorshResult<Tensor> {
self.allocate_tensor(shape, DType::F32)
}
/// Allocates an I8 tensor of `shape`.
/// NOTE(review): `create_tensor` currently builds F32 backing for every
/// dtype (the tests assert the returned dtype is F32) — confirm intent.
pub fn allocate_i8(&self, shape: &[usize]) -> TorshResult<Tensor> {
self.allocate_tensor(shape, DType::I8)
}
/// Allocates a U8 tensor of `shape`.
/// NOTE(review): same F32 fallback as `allocate_i8` — see `create_tensor`.
pub fn allocate_u8(&self, shape: &[usize]) -> TorshResult<Tensor> {
self.allocate_tensor(shape, DType::U8)
}
}
impl MemoryPool {
    /// Reclaims empty pool buckets and refreshes the fragmentation score.
    ///
    /// Fix: the original `retain` closure returned `true` on both branches,
    /// so the pass never freed anything; empty buckets are now removed.
    pub fn garbage_collect(&self) -> TorshResult<()> {
        let start_time = std::time::Instant::now();
        if let Ok(mut pools) = self.pools.lock() {
            // Drop map entries whose deque holds no tensors.
            pools.retain(|_, pool| !pool.is_empty());
            if let Ok(mut analytics) = self.analytics.lock() {
                let gc_duration = start_time.elapsed();
                analytics.gc_time_us += gc_duration.as_micros() as u64;
                analytics.fragmentation_score = self.calculate_fragmentation_score(&pools);
            }
        }
        Ok(())
    }

    /// Returns true when memory pressure is high; as a side effect a
    /// best-effort GC runs and the pressure counter is bumped.
    pub fn check_memory_pressure(&self) -> bool {
        let analytics = self.get_analytics();
        // NOTE(review): current_memory_usage is never updated in this file,
        // so this ratio is effectively 0 until byte tracking is wired up —
        // confirm against the rest of the crate.
        let memory_usage_ratio =
            analytics.current_memory_usage as f64 / self.config.max_total_memory as f64;
        let high_pressure = memory_usage_ratio > 0.85
            || analytics.fragmentation_score > self.config.auto_gc_threshold;
        if high_pressure {
            let _ = self.garbage_collect();
            if let Ok(mut analytics) = self.analytics.lock() {
                analytics.pressure_events += 1;
            }
        }
        high_pressure
    }

    /// Averages two `[0, 1]` signals: the share of sparsely used buckets and
    /// the overall unused capacity. Returns 0.0 for an empty pool map.
    fn calculate_fragmentation_score(&self, pools: &HashMap<TensorKey, VecDeque<Tensor>>) -> f64 {
        if pools.is_empty() {
            return 0.0;
        }
        let total_pools = pools.len();
        let mut fragmented_pools = 0;
        let mut total_capacity = 0;
        let mut total_used = 0;
        for pool in pools.values() {
            let capacity = self.config.max_tensors_per_size;
            let used = pool.len();
            total_capacity += capacity;
            total_used += used;
            // Non-empty but under half full counts as fragmented.
            if used > 0 && used < capacity / 2 {
                fragmented_pools += 1;
            }
        }
        let pool_fragmentation = fragmented_pools as f64 / total_pools as f64;
        let usage_fragmentation = if total_capacity > 0 {
            1.0 - (total_used as f64 / total_capacity as f64)
        } else {
            0.0
        };
        (pool_fragmentation + usage_fragmentation) / 2.0
    }

    /// Heuristic cache-miss estimate for one allocation: 0 when tracking is
    /// disabled or the allocation is small; otherwise one miss per 64-byte
    /// line for large misaligned allocations.
    #[allow(dead_code)]
    fn estimate_cache_misses(&self, allocation_size: usize) -> usize {
        if !self.config.enable_cache_awareness
            || allocation_size < self.config.min_cache_tracked_size
        {
            return 0;
        }
        let alignment = self.config.memory_alignment;
        let misaligned = allocation_size % alignment != 0;
        if misaligned && allocation_size > alignment * 8 {
            allocation_size / 64
        } else {
            0
        }
    }

    /// Best-effort adaptive tuning: currently only forces a GC pass when
    /// fragmentation is high.
    pub fn adaptive_resize(&self) -> TorshResult<()> {
        if !self.config.enable_adaptive_sizing {
            return Ok(());
        }
        let analytics = self.get_analytics();
        // TODO: grow bucket capacities when analytics.hit_rate() < 50.0 —
        // the original left an empty branch for this case.
        if analytics.fragmentation_score > 0.7 {
            let _ = self.garbage_collect();
        }
        Ok(())
    }

    /// Assembles a human-readable snapshot of pool health.
    pub fn get_utilization_report(&self) -> PoolUtilizationReport {
        let analytics = self.get_analytics();
        let pool_stats = self.get_pool_stats();
        PoolUtilizationReport {
            total_pools: pool_stats.len(),
            total_tensors_pooled: pool_stats.values().sum(),
            hit_rate: analytics.hit_rate(),
            fragmentation_score: analytics.fragmentation_score,
            cache_efficiency: analytics.cache_efficiency(),
            memory_usage_mb: analytics.current_memory_usage / 1024 / 1024,
            peak_memory_usage_mb: analytics.peak_memory_usage / 1024 / 1024,
            pressure_events: analytics.pressure_events,
            gc_time_ms: analytics.gc_time_us / 1000,
            performance_score: analytics.performance_score(),
            needs_optimization: analytics.needs_optimization(),
            recommendations: analytics.get_optimization_recommendations(),
        }
    }

    /// Warms the pool for a predicted workload by creating and immediately
    /// releasing two tensors per (shape, dtype).
    ///
    /// # Errors
    /// Propagates the first tensor-creation failure.
    pub fn prefetch_for_workload(
        &self,
        predicted_shapes: &[(Vec<usize>, DType)],
    ) -> TorshResult<()> {
        for (shape, dtype) in predicted_shapes {
            for _ in 0..2 {
                let tensor = self.create_tensor(shape, *dtype)?;
                self.release_tensor(tensor);
            }
        }
        Ok(())
    }
}
/// Point-in-time snapshot of pool health, built by `get_utilization_report`.
#[derive(Debug, Clone)]
pub struct PoolUtilizationReport {
/// Number of (shape, dtype) buckets currently in the pool map.
pub total_pools: usize,
/// Total free tensors held across all buckets.
pub total_tensors_pooled: usize,
/// Pool hit rate in percent (0–100).
pub hit_rate: f64,
/// Fragmentation estimate in `[0.0, 1.0]`.
pub fragmentation_score: f64,
/// Estimated cache efficiency in percent (0–100).
pub cache_efficiency: f64,
/// Tracked current memory usage, in MiB.
pub memory_usage_mb: usize,
/// Tracked peak memory usage, in MiB.
pub peak_memory_usage_mb: usize,
/// Times memory pressure was detected.
pub pressure_events: usize,
/// Cumulative GC time, in milliseconds.
pub gc_time_ms: u64,
/// Composite performance score (0–100).
pub performance_score: f64,
/// True when any analytics threshold suggests tuning.
pub needs_optimization: bool,
/// Human-readable tuning suggestions.
pub recommendations: Vec<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with warm-up disabled so hit/miss counters start at zero.
    fn cold_config() -> PoolConfig {
        PoolConfig {
            pre_allocate_sizes: vec![],
            ..PoolConfig::default()
        }
    }

    #[test]
    fn test_memory_pool_basic() {
        let pool = MemoryPool::new(cold_config());
        let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
        assert_eq!(tensor.shape().dims(), &[32, 32]);
        assert_eq!(tensor.dtype(), DType::F32);
        pool.release_tensor(tensor);
        // Second allocation of the same shape must come from the pool.
        let tensor2 = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
        assert_eq!(tensor2.shape().dims(), &[32, 32]);
        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 2);
        assert_eq!(analytics.pool_hits, 1);
        assert_eq!(analytics.pool_misses, 1);
    }

    #[test]
    fn test_memory_pool_different_sizes() {
        let pool = MemoryPool::new(cold_config());
        let tensor1 = pool.allocate_tensor(&[64, 64], DType::F32).unwrap();
        let tensor2 = pool.allocate_tensor(&[128, 128], DType::F32).unwrap();
        assert_eq!(tensor1.shape().dims(), &[64, 64]);
        assert_eq!(tensor2.shape().dims(), &[128, 128]);
        pool.release_tensor(tensor1);
        pool.release_tensor(tensor2);
        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 2);
        assert_eq!(analytics.total_deallocations, 2);
        // Different shapes never hit each other's buckets.
        assert_eq!(analytics.pool_misses, 2);
        assert_eq!(analytics.pool_hits, 0);
    }

    #[test]
    fn test_memory_pool_analytics() {
        let pool = MemoryPool::new(cold_config());
        for _ in 0..5 {
            let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
            pool.release_tensor(tensor);
        }
        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 5);
        assert_eq!(analytics.total_deallocations, 5);
        // Only the first allocation misses; the rest reuse the released tensor.
        assert_eq!(analytics.pool_hits, 4);
        assert_eq!(analytics.pool_misses, 1);
        assert_eq!(analytics.hit_rate(), 80.0);
    }

    #[test]
    fn test_memory_pool_clear() {
        let pool = MemoryPool::new(PoolConfig::default());
        let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
        pool.release_tensor(tensor);
        pool.clear();
        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 0);
        assert_eq!(analytics.total_deallocations, 0);
    }

    #[test]
    fn test_convenience_functions() {
        let pool = MemoryPool::new(PoolConfig::default());
        let f32_tensor = pool.allocate_f32(&[16, 16]).unwrap();
        let i8_tensor = pool.allocate_i8(&[16, 16]).unwrap();
        let u8_tensor = pool.allocate_u8(&[16, 16]).unwrap();
        assert_eq!(f32_tensor.dtype(), DType::F32);
        // create_tensor currently falls back to F32 storage for every dtype.
        assert_eq!(i8_tensor.dtype(), DType::F32);
        assert_eq!(u8_tensor.dtype(), DType::F32);
        assert_eq!(f32_tensor.shape().dims(), &[16, 16]);
        assert_eq!(i8_tensor.shape().dims(), &[16, 16]);
        assert_eq!(u8_tensor.shape().dims(), &[16, 16]);
    }

    #[test]
    fn test_global_pool() {
        let pool = MemoryPool::global();
        let tensor = pool.allocate_f32(&[8, 8]).unwrap();
        assert_eq!(tensor.shape().dims(), &[8, 8]);
        pool.release_tensor(tensor);
    }

    #[test]
    fn test_advanced_analytics() {
        let pool = MemoryPool::new(PoolConfig::default());
        for i in 0..10 {
            let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
            if i % 2 == 0 {
                pool.release_tensor(tensor);
            }
        }
        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 10);
        assert!(analytics.performance_score() >= 0.0);
        assert!(analytics.performance_score() <= 100.0);
        let recommendations = analytics.get_optimization_recommendations();
        assert!(!recommendations.is_empty() || analytics.performance_score() > 70.0);
    }

    #[test]
    fn test_garbage_collection() {
        let pool = MemoryPool::new(PoolConfig::default());
        for i in 0..5 {
            let tensor = pool
                .allocate_tensor(&[i * 10 + 1, i * 10 + 1], DType::F32)
                .unwrap();
            if i % 2 == 0 {
                pool.release_tensor(tensor);
            }
        }
        pool.garbage_collect().unwrap();
        // Was `gc_time_us >= 0` — a tautology for u64 (unused_comparisons);
        // check a counter that the loop actually drives instead.
        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 5);
    }

    #[test]
    fn test_memory_pressure_detection() {
        let pool = MemoryPool::new(PoolConfig {
            max_total_memory: 1024, // tiny budget to make pressure plausible
            ..PoolConfig::default()
        });
        let initial_pressure = pool.check_memory_pressure();
        assert!(!initial_pressure);
        let _tensors: Vec<_> = (0..10)
            .map(|_| pool.allocate_tensor(&[32, 32], DType::F32).unwrap())
            .collect();
        let _final_pressure = pool.check_memory_pressure();
    }

    #[test]
    fn test_utilization_report() {
        let pool = MemoryPool::new(PoolConfig::default());
        let tensor1 = pool.allocate_tensor(&[64, 64], DType::F32).unwrap();
        let tensor2 = pool.allocate_tensor(&[128, 128], DType::F32).unwrap();
        pool.release_tensor(tensor1);
        pool.release_tensor(tensor2);
        let report = pool.get_utilization_report();
        // Was `total_pools >= 0` — a tautology for usize; the default config
        // pre-allocates several buckets, so the pool map is non-empty.
        assert!(report.total_pools > 0);
        assert!(report.hit_rate >= 0.0);
        assert!(report.performance_score >= 0.0);
        assert!(report.performance_score <= 100.0);
    }

    #[test]
    fn test_prefetch_workload() {
        let pool = MemoryPool::new(PoolConfig::default());
        let predicted_shapes = vec![
            (vec![32, 32], DType::F32),
            (vec![64, 64], DType::F32),
            (vec![128, 128], DType::F32),
        ];
        pool.prefetch_for_workload(&predicted_shapes).unwrap();
        let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
        assert_eq!(tensor.shape().dims(), &[32, 32]);
        let analytics = pool.get_analytics();
        assert!(analytics.total_allocations > 0);
    }

    #[test]
    fn test_adaptive_config() {
        let pool = MemoryPool::new(PoolConfig {
            enable_cache_awareness: true,
            enable_adaptive_sizing: true,
            auto_gc_threshold: 0.5,
            ..PoolConfig::default()
        });
        let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
        pool.release_tensor(tensor);
        pool.adaptive_resize().unwrap();
        let analytics = pool.get_analytics();
        assert_eq!(analytics.total_allocations, 1);
    }
}