pub mod analytics;
pub mod enhanced_pool;
pub mod optimizer;
pub mod pressure_monitor;
pub use analytics::{AllocationRecord, AnalyticsConfig, MemoryAnalytics, MemoryReport};
pub use enhanced_pool::{AllocationStrategy, EnhancedMemoryPool, EnhancedPoolStats, PoolConfig};
pub use optimizer::{MemoryOptimizer, MemoryPrediction, OptimizationStrategy, OptimizerConfig};
pub use pressure_monitor::{AdaptivePressureMonitor, GcStrategy, MonitorConfig, PressureLevel};
use crate::error::{RusTorchError, RusTorchResult};
use lazy_static::lazy_static;
use ndarray::{ArrayD, IxDyn};
use num_traits::Float;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
/// Size-bucketed cache of `ArrayD<T>` buffers that `MemoryPool::allocate` can hand out
/// again instead of allocating fresh storage.
pub struct MemoryPool<T>
where
    T: Float,
{
    /// One queue of cached arrays per size bucket; buckets are created on first use.
    pools: Vec<Arc<Mutex<VecDeque<ArrayD<T>>>>>,
    /// Maximum number of arrays retained in any single bucket.
    max_pool_size: usize,
}
impl<T: Float + Clone + 'static> MemoryPool<T> {
    /// Creates an empty pool that keeps at most `max_pool_size` cached arrays per size bucket.
    pub fn new(max_pool_size: usize) -> Self {
        Self {
            pools: Vec::new(),
            max_pool_size,
        }
    }

    /// Maps an element count to its size bucket.
    ///
    /// Buckets are (inclusive upper bounds): 64, 256, 1024, 4096, 16384, 65536, and
    /// everything larger in bucket 6.
    fn get_pool_index(&self, total_elements: usize) -> usize {
        const BUCKET_LIMITS: [usize; 6] = [64, 256, 1024, 4096, 16384, 65536];
        BUCKET_LIMITS
            .iter()
            .position(|&limit| total_elements <= limit)
            .unwrap_or(BUCKET_LIMITS.len())
    }

    /// Grows `pools` so that bucket `index` exists, creating empty queues as needed.
    fn ensure_pool(&mut self, index: usize) {
        while self.pools.len() <= index {
            self.pools.push(Arc::new(Mutex::new(VecDeque::new())));
        }
    }

    /// Returns a zero-filled array of `shape`, reusing a cached buffer when possible.
    ///
    /// Only the front entry of the matching bucket is inspected. It is reused when it has
    /// exactly `shape`, or when it holds the same number of elements in standard
    /// (row-major contiguous) layout so it can be reshaped in place. Otherwise it is put
    /// back and a fresh zeroed array is allocated. Every returned array is zeroed.
    pub fn allocate(&mut self, shape: &[usize]) -> ArrayD<T> {
        let total_elements: usize = shape.iter().product();
        let pool_index = self.get_pool_index(total_elements);
        self.ensure_pool(pool_index);
        if let Ok(mut pool) = self.pools[pool_index].lock() {
            if let Some(mut array) = pool.pop_front() {
                if array.shape() == shape {
                    array.fill(T::zero());
                    return array;
                }
                // `into_shape_with_order` only succeeds when the element count matches
                // exactly and the data is contiguous in row-major order, so check that up
                // front instead of cloning the buffer to "try" the reshape.
                if array.len() == total_elements && array.is_standard_layout() {
                    if let Ok(mut reshaped) = array.into_shape_with_order(IxDyn(shape)) {
                        // Recycled storage still holds the previous user's values; zero it
                        // so this path matches the same-shape and fresh-allocation paths.
                        reshaped.fill(T::zero());
                        return reshaped;
                    }
                    // Not expected given the checks above; the buffer is simply dropped.
                } else {
                    pool.push_back(array);
                }
            }
        }
        ArrayD::zeros(IxDyn(shape))
    }

    /// Returns `array` to the bucket for its element count, or drops it if that bucket
    /// already holds `max_pool_size` arrays.
    pub fn deallocate(&mut self, array: ArrayD<T>) {
        let total_elements = array.len();
        let pool_index = self.get_pool_index(total_elements);
        self.ensure_pool(pool_index);
        if let Ok(mut pool) = self.pools[pool_index].lock() {
            if pool.len() < self.max_pool_size {
                pool.push_back(array);
            }
        }
    }

    /// Summarizes how many arrays are cached. Buckets whose lock is poisoned are skipped
    /// in `pool_sizes` and `total_cached_arrays` but still counted in `total_pools`.
    pub fn stats(&self) -> PoolStats {
        let mut total_cached = 0;
        let mut pool_sizes = Vec::new();
        for pool in &self.pools {
            if let Ok(pool) = pool.lock() {
                let size = pool.len();
                pool_sizes.push(size);
                total_cached += size;
            }
        }
        PoolStats {
            total_pools: self.pools.len(),
            total_cached_arrays: total_cached,
            pool_sizes,
            max_pool_size: self.max_pool_size,
        }
    }

    /// Drops every cached array. The (now empty) buckets themselves are kept.
    pub fn clear(&mut self) {
        for pool in &self.pools {
            if let Ok(mut pool) = pool.lock() {
                pool.clear();
            }
        }
    }
}
/// Snapshot of a `MemoryPool`'s cache occupancy, as produced by `MemoryPool::stats`.
#[derive(Debug)]
pub struct PoolStats {
    /// Number of size buckets created so far.
    pub total_pools: usize,
    /// Total number of arrays cached across all buckets that could be locked.
    pub total_cached_arrays: usize,
    /// Cached-array count for each bucket that could be locked, in bucket order.
    pub pool_sizes: Vec<usize>,
    /// Per-bucket retention limit the pool was created with.
    pub max_pool_size: usize,
}
impl std::fmt::Display for PoolStats {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "Memory Pool Statistics:")?;
writeln!(f, " Total pools: {}", self.total_pools)?;
writeln!(f, " Total cached arrays: {}", self.total_cached_arrays)?;
writeln!(f, " Max pool size: {}", self.max_pool_size)?;
writeln!(f, " Pool sizes: {:?}", self.pool_sizes)?;
Ok(())
}
}
lazy_static! {
    /// Process-wide `f32` array pool retaining up to 100 arrays per size bucket.
    static ref GLOBAL_POOL_F32: Arc<Mutex<MemoryPool<f32>>> =
        Arc::new(Mutex::new(MemoryPool::<f32>::new(100)));
    /// Process-wide `f64` array pool retaining up to 100 arrays per size bucket.
    static ref GLOBAL_POOL_F64: Arc<Mutex<MemoryPool<f64>>> =
        Arc::new(Mutex::new(MemoryPool::<f64>::new(100)));
}
/// Returns a shared handle to the process-wide `f32` memory pool.
pub fn get_f32_pool() -> Arc<Mutex<MemoryPool<f32>>> {
    Arc::clone(&*GLOBAL_POOL_F32)
}
/// Returns a shared handle to the process-wide `f64` memory pool.
pub fn get_f64_pool() -> Arc<Mutex<MemoryPool<f64>>> {
    Arc::clone(&*GLOBAL_POOL_F64)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_pool_creation() {
        let pool: MemoryPool<f32> = MemoryPool::new(10);
        let stats = pool.stats();
        assert_eq!(stats.total_pools, 0);
        assert_eq!(stats.total_cached_arrays, 0);
    }

    #[test]
    fn test_pool_index_buckets() {
        let pool: MemoryPool<f32> = MemoryPool::new(10);
        assert_eq!(pool.get_pool_index(0), 0);
        assert_eq!(pool.get_pool_index(64), 0);
        assert_eq!(pool.get_pool_index(65), 1);
        assert_eq!(pool.get_pool_index(256), 1);
        assert_eq!(pool.get_pool_index(257), 2);
        assert_eq!(pool.get_pool_index(65536), 5);
        assert_eq!(pool.get_pool_index(65537), 6);
    }

    #[test]
    fn test_allocate_and_deallocate() {
        let mut pool: MemoryPool<f32> = MemoryPool::new(10);
        let array1 = pool.allocate(&[2, 3]);
        assert_eq!(array1.shape(), &[2, 3]);
        pool.deallocate(array1);
        let stats = pool.stats();
        assert_eq!(stats.total_cached_arrays, 1);
    }

    #[test]
    fn test_reuse_from_pool() {
        let mut pool: MemoryPool<f32> = MemoryPool::new(10);
        let mut array1 = pool.allocate(&[2, 3]);
        array1.fill(5.0);
        pool.deallocate(array1);
        let array2 = pool.allocate(&[2, 3]);
        assert_eq!(array2.shape(), &[2, 3]);
        // A recycled buffer must come back zeroed, not with its previous contents.
        assert!(array2.iter().all(|&x| x == 0.0));
        let stats = pool.stats();
        assert_eq!(stats.total_cached_arrays, 0);
    }

    #[test]
    fn test_pool_size_limit() {
        let mut pool: MemoryPool<f32> = MemoryPool::new(2);
        // Hold every array before returning any, so the bucket actually reaches its cap;
        // an allocate/deallocate loop would never cache more than one array.
        let arrays: Vec<_> = (0..5).map(|_| pool.allocate(&[2, 2])).collect();
        for array in arrays {
            pool.deallocate(array);
        }
        let stats = pool.stats();
        assert_eq!(stats.total_cached_arrays, 2);
    }

    #[test]
    fn test_clear_empties_buckets() {
        let mut pool: MemoryPool<f32> = MemoryPool::new(10);
        let a = pool.allocate(&[4]);
        let b = pool.allocate(&[4]);
        pool.deallocate(a);
        pool.deallocate(b);
        pool.clear();
        let stats = pool.stats();
        assert_eq!(stats.total_cached_arrays, 0);
        // `clear` empties buckets but keeps them allocated.
        assert_eq!(stats.total_pools, 1);
    }

    #[test]
    fn test_global_pools() {
        let pool_f32 = get_f32_pool();
        let pool_f64 = get_f64_pool();
        assert!(pool_f32.lock().is_ok());
        assert!(pool_f64.lock().is_ok());
    }
}
/// Facade bundling the enhanced pool, pressure monitor, analytics and optimizer
/// subsystems behind one API.
pub struct ComprehensiveMemoryManager<T>
where
    T: Float + Clone + Send + Sync + 'static,
{
    /// Enhanced pool; used for statistics and garbage collection.
    pool: EnhancedMemoryPool<T>,
    /// Pressure monitor; supplies snapshots and trend analysis.
    monitor: AdaptivePressureMonitor,
    /// Records allocations/deallocations and produces reports.
    analytics: MemoryAnalytics,
    /// Performs the actual allocations plus defragmentation/compression.
    optimizer: MemoryOptimizer<T>,
}
impl<T: Float + Clone + Send + Sync + 'static> ComprehensiveMemoryManager<T> {
pub fn new() -> Self {
let pool_config = PoolConfig::default();
let monitor_config = MonitorConfig::default();
let analytics_config = AnalyticsConfig::default();
let optimizer_config = OptimizerConfig::default();
Self {
pool: EnhancedMemoryPool::new(pool_config),
monitor: AdaptivePressureMonitor::new(monitor_config),
analytics: MemoryAnalytics::new(analytics_config),
optimizer: MemoryOptimizer::new(optimizer_config),
}
}
pub fn with_configs(
pool_config: PoolConfig,
monitor_config: MonitorConfig,
analytics_config: AnalyticsConfig,
optimizer_config: OptimizerConfig,
) -> Self {
Self {
pool: EnhancedMemoryPool::new(pool_config),
monitor: AdaptivePressureMonitor::new(monitor_config),
analytics: MemoryAnalytics::new(analytics_config),
optimizer: MemoryOptimizer::new(optimizer_config),
}
}
pub fn allocate(&self, shape: &[usize]) -> RusTorchResult<ArrayD<T>> {
let source_location = format!("tensor::allocate::{:?}", shape);
let alloc_id = self.analytics.record_allocation(
shape.iter().product::<usize>() * std::mem::size_of::<T>(),
source_location,
)?;
let array = self.optimizer.optimize_allocation(shape)?;
Ok(array)
}
pub fn deallocate(&self, array: ArrayD<T>, alloc_id: u64) -> RusTorchResult<()> {
self.analytics.record_deallocation(alloc_id)?;
self.optimizer.optimize_deallocation(array)?;
Ok(())
}
pub fn start_all_systems(&self) -> RusTorchResult<()> {
self.monitor.start_monitoring()?;
self.analytics.start_analysis()?;
Ok(())
}
pub fn stop_all_systems(&self) -> RusTorchResult<()> {
self.monitor.stop_monitoring()?;
self.analytics.stop_analysis()?;
Ok(())
}
pub fn generate_comprehensive_report(&self) -> RusTorchResult<ComprehensiveReport> {
let pool_stats = self.pool.get_stats()?;
let monitor_stats = self.monitor.get_stats()?;
let analytics_report = self.analytics.generate_report()?;
let optimizer_stats = self.optimizer.get_stats()?;
Ok(ComprehensiveReport {
pool_stats,
monitor_stats,
analytics_report,
optimizer_stats,
})
}
pub fn optimize_all(&self) -> RusTorchResult<OptimizationSummary> {
let start_time = std::time::Instant::now();
let gc_stats = self.pool.garbage_collect()?;
let defrag_reclaimed = self.optimizer.defragment_memory()?;
let compression_saved = self.optimizer.compress_memory()?;
let total_time = start_time.elapsed();
Ok(OptimizationSummary {
gc_memory_reclaimed: gc_stats.memory_reclaimed,
defrag_memory_reclaimed: defrag_reclaimed,
compression_memory_saved: compression_saved,
total_optimization_time: total_time,
})
}
pub fn get_health_status(&self) -> RusTorchResult<SystemHealthStatus> {
let snapshot = self.monitor.get_current_snapshot()?;
let trend = self.monitor.analyze_trend()?;
let prediction = self.optimizer.predict_memory_usage()?;
let health_score = self.calculate_health_score(&snapshot, &trend, &prediction)?;
Ok(SystemHealthStatus {
current_snapshot: snapshot,
trend_analysis: trend,
memory_prediction: prediction,
health_score,
recommendations: self.generate_recommendations(health_score)?,
})
}
fn calculate_health_score(
&self,
snapshot: &Option<pressure_monitor::MemorySnapshot>,
trend: &Option<pressure_monitor::PressureTrend>,
prediction: &Option<MemoryPrediction>,
) -> RusTorchResult<f64> {
let mut score = 1.0;
if let Some(snap) = snapshot {
score -= snap.pressure_ratio * 0.4; }
if let Some(t) = trend {
if t.direction > 0.0 {
score -= t.strength * 0.3;
}
}
if let Some(pred) = prediction {
if pred.confidence < 0.5 {
score -= 0.1; }
}
Ok(score.max(0.0).min(1.0))
}
fn generate_recommendations(&self, health_score: f64) -> RusTorchResult<Vec<String>> {
let mut recommendations = Vec::new();
if health_score < 0.3 {
recommendations.push("Critical: Immediate memory optimization required".to_string());
recommendations.push("Consider reducing batch sizes or model complexity".to_string());
recommendations.push("Run comprehensive memory cleanup".to_string());
} else if health_score < 0.6 {
recommendations.push("Warning: High memory pressure detected".to_string());
recommendations.push("Consider running memory defragmentation".to_string());
recommendations.push("Monitor memory usage closely".to_string());
} else if health_score < 0.8 {
recommendations.push("Moderate memory pressure".to_string());
recommendations.push("Consider periodic cleanup".to_string());
} else {
recommendations.push("Memory system operating normally".to_string());
}
Ok(recommendations)
}
}
impl<T: Float + Clone + Send + Sync + 'static> Default for ComprehensiveMemoryManager<T> {
fn default() -> Self {
Self::new()
}
}
/// Aggregated statistics from every memory subsystem, produced by
/// `ComprehensiveMemoryManager::generate_comprehensive_report`.
#[derive(Debug, Clone)]
pub struct ComprehensiveReport {
    /// Statistics from the enhanced memory pool.
    pub pool_stats: EnhancedPoolStats,
    /// Statistics from the adaptive pressure monitor.
    pub monitor_stats: pressure_monitor::MonitorStats,
    /// Report generated by the analytics subsystem.
    pub analytics_report: MemoryReport,
    /// Statistics from the memory optimizer.
    pub optimizer_stats: optimizer::OptimizationStats,
}
/// Outcome of one `ComprehensiveMemoryManager::optimize_all` pass.
#[derive(Debug, Clone)]
pub struct OptimizationSummary {
    /// Amount reported as reclaimed by pool garbage collection (units as reported by
    /// the pool — presumably bytes; confirm).
    pub gc_memory_reclaimed: usize,
    /// Amount reported as reclaimed by optimizer defragmentation.
    pub defrag_memory_reclaimed: usize,
    /// Amount reported as saved by optimizer compression.
    pub compression_memory_saved: usize,
    /// Wall-clock time spent on the whole pass.
    pub total_optimization_time: std::time::Duration,
}
/// Health assessment produced by `ComprehensiveMemoryManager::get_health_status`.
#[derive(Debug, Clone)]
pub struct SystemHealthStatus {
    /// Latest pressure snapshot, if the monitor had one.
    pub current_snapshot: Option<pressure_monitor::MemorySnapshot>,
    /// Pressure trend analysis, if available.
    pub trend_analysis: Option<pressure_monitor::PressureTrend>,
    /// Optimizer's memory usage prediction, if available.
    pub memory_prediction: Option<MemoryPrediction>,
    /// Score in `[0.0, 1.0]`; lower values mean higher memory pressure.
    pub health_score: f64,
    /// Human-readable advice derived from `health_score`.
    pub recommendations: Vec<String>,
}
impl std::fmt::Display for ComprehensiveReport {
    /// Renders the report as titled sections: pool, monitor, analytics, optimizer.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let monitor = &self.monitor_stats;
        let optimizer = &self.optimizer_stats;

        writeln!(f, "Comprehensive Memory Management Report")?;
        writeln!(f, "=====================================")?;
        writeln!(f)?;

        writeln!(f, "Enhanced Pool Statistics:")?;
        writeln!(f, "{}", self.pool_stats)?;
        writeln!(f)?;

        // Pressure values are ratios; display them as percentages.
        writeln!(f, "Monitor Statistics:")?;
        writeln!(f, " Total Snapshots: {}", monitor.total_snapshots)?;
        writeln!(f, " Average Pressure: {:.2}%", monitor.avg_pressure * 100.0)?;
        writeln!(f, " Peak Pressure: {:.2}%", monitor.peak_pressure * 100.0)?;
        writeln!(f)?;

        // The analytics report supplies its own heading and layout.
        writeln!(f, "{}", self.analytics_report)?;
        writeln!(f)?;

        writeln!(f, "Optimizer Statistics:")?;
        writeln!(f, " Total Optimizations: {}", optimizer.total_optimizations)?;
        writeln!(f, " Memory Saved: {} bytes", optimizer.memory_saved)?;
        writeln!(f, " Cache Hit Ratio: {:.2}%", optimizer.cache_hit_ratio * 100.0)?;
        writeln!(f, " Zero-Copy Operations: {}", optimizer.zero_copy_operations)
    }
}
lazy_static! {
    /// Process-wide `f32` memory manager built with default subsystem configurations.
    static ref GLOBAL_MEMORY_MANAGER_F32: Arc<Mutex<ComprehensiveMemoryManager<f32>>> =
        Arc::new(Mutex::new(ComprehensiveMemoryManager::<f32>::new()));
    /// Process-wide `f64` memory manager built with default subsystem configurations.
    static ref GLOBAL_MEMORY_MANAGER_F64: Arc<Mutex<ComprehensiveMemoryManager<f64>>> =
        Arc::new(Mutex::new(ComprehensiveMemoryManager::<f64>::new()));
}
/// Returns a shared handle to the process-wide `f32` memory manager.
pub fn get_global_memory_manager_f32() -> Arc<Mutex<ComprehensiveMemoryManager<f32>>> {
    Arc::clone(&*GLOBAL_MEMORY_MANAGER_F32)
}
/// Returns a shared handle to the process-wide `f64` memory manager.
pub fn get_global_memory_manager_f64() -> Arc<Mutex<ComprehensiveMemoryManager<f64>>> {
    Arc::clone(&*GLOBAL_MEMORY_MANAGER_F64)
}