use crate::error::{RusTorchError, RusTorchResult};
use crate::memory::analytics::AllocationPattern;
use crate::memory::pressure_monitor::{GcStrategy, PressureLevel};
use ndarray::{ArrayD, IxDyn};
use num_traits::Float;
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant, SystemTime};
/// Allocation strategy used by `MemoryOptimizer`.
///
/// When the configured strategy is `Adaptive`, the optimizer derives a concrete
/// strategy from the pressure level of the most recent usage snapshot.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum OptimizationStrategy {
    /// Picked by adaptive mode under `High`/`Critical` pressure
    /// (presumably favors a smaller footprint over speed).
    MemoryFirst,
    /// Picked by adaptive mode under `Low` pressure (presumably favors speed).
    SpeedFirst,
    /// Picked by adaptive mode under `Medium` pressure.
    Balanced,
    /// Let the optimizer choose one of the other strategies from memory pressure.
    Adaptive,
}
/// Tuning knobs for `MemoryOptimizer`.
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Strategy the optimizer starts with; `Adaptive` lets it switch based on pressure.
    pub strategy: OptimizationStrategy,
    /// Enables trend-based memory prediction.
    pub enable_prediction: bool,
    /// Number of most recent usage snapshots a prediction is based on.
    pub prediction_window: usize,
    /// NOTE(review): not read by `MemoryOptimizer` in this module.
    pub compression_threshold: f64,
    /// NOTE(review): not read by `MemoryOptimizer` in this module.
    pub defrag_threshold: f64,
    /// Byte budget checked before an array is added to the reuse cache.
    pub cache_size_limit: usize,
    /// NOTE(review): not read by `MemoryOptimizer` in this module.
    pub enable_zero_copy: bool,
    /// Multiplier applied to predicted memory when computing a recommended pre-allocation.
    pub preallocation_factor: f64,
}
impl Default for OptimizerConfig {
fn default() -> Self {
Self {
strategy: OptimizationStrategy::Adaptive,
enable_prediction: true,
prediction_window: 20,
compression_threshold: 0.8,
defrag_threshold: 0.6,
cache_size_limit: 512 * 1024 * 1024, enable_zero_copy: true,
preallocation_factor: 1.2,
}
}
}
/// Result of a memory-usage forecast.
#[derive(Clone, Debug)]
pub struct MemoryPrediction {
    /// Projected usage, in the same units as the snapshots' `total_usage`.
    pub predicted_memory: usize,
    /// Heuristic confidence in the projection, between 0 and 1.
    pub confidence: f64,
    /// How far ahead the projection is meant to reach.
    pub horizon: Duration,
    /// Suggested amount to pre-allocate beyond current usage.
    pub recommended_prealloc: usize,
}
/// Summary of an observed allocation pattern.
///
/// NOTE(review): nothing in this module constructs these yet; the field
/// descriptions below are inferred from names — confirm before relying on them.
#[derive(Clone, Debug)]
pub struct UsagePattern {
    /// Classification of the pattern.
    pub pattern_type: PatternType,
    /// Presumably the moments at which usage peaked.
    pub peak_times: Vec<SystemTime>,
    /// Presumably the mean allocation size.
    pub avg_allocation_size: usize,
    /// Presumably allocations per unit time.
    pub allocation_frequency: f64,
    /// Lifetime statistics of the allocations in this pattern.
    pub lifecycle: LifecycleCharacteristics,
}
/// Coarse classification of an allocation pattern.
///
/// NOTE(review): variant meanings are inferred from their names.
#[derive(Clone, Debug, PartialEq)]
pub enum PatternType {
    /// Roughly constant usage.
    Steady,
    /// Short spikes of allocation activity.
    Bursty,
    /// Usage that repeats on a cycle.
    Periodic,
    /// Usage trending upward.
    Growing,
    /// No discernible structure.
    Random,
}
/// Lifetime statistics for a group of allocations.
///
/// NOTE(review): not populated anywhere in this module; descriptions inferred from names.
#[derive(Clone, Debug)]
pub struct LifecycleCharacteristics {
    /// Mean time an allocation stays alive.
    pub avg_lifetime: Duration,
    /// Spread of allocation lifetimes.
    pub lifetime_variance: f64,
    /// Likelihood that a released buffer gets reused.
    pub reuse_probability: f64,
}
/// An array held in the optimizer's shape-keyed reuse cache.
#[derive(Clone, Debug)]
pub struct CacheEntry<T> {
    /// The cached array.
    pub data: ArrayD<T>,
    /// When the entry was inserted or last served as a cache hit.
    pub last_accessed: Instant,
    /// Starts at 1 on insertion and grows by one per cache hit.
    pub access_count: usize,
    /// Set to 128 on insertion; not otherwise read by the optimizer.
    pub priority: u8,
    /// Compression ratio, if the entry was compressed (inserted as `None`).
    pub compression_ratio: Option<f64>,
}
/// Counters reported by `MemoryOptimizer::get_stats`.
#[derive(Clone, Debug)]
pub struct OptimizationStats {
    /// Fresh allocations performed by the optimizer (cache/pool hits excluded).
    pub total_optimizations: usize,
    /// Element bytes released by defragmentation runs.
    pub memory_saved: usize,
    /// Exponentially weighted cache-hit indicator (each update weighs 0.1).
    pub cache_hit_ratio: f64,
    /// NOTE(review): never updated in this module.
    pub prediction_accuracy: f64,
    /// Number of completed defragmentation runs.
    pub defragmentations: usize,
    /// Allocations served from a pre-allocation pool.
    pub zero_copy_operations: usize,
    /// NOTE(review): never updated in this module.
    pub compression_operations: usize,
    /// Cumulative time spent in fresh allocations and defragmentation.
    pub optimization_time: Duration,
}
impl Default for OptimizationStats {
fn default() -> Self {
Self {
total_optimizations: 0,
memory_saved: 0,
cache_hit_ratio: 0.0,
prediction_accuracy: 0.0,
defragmentations: 0,
zero_copy_operations: 0,
compression_operations: 0,
optimization_time: Duration::from_millis(0),
}
}
}
/// Thread-safe allocator front-end that reuses arrays through a shape-keyed
/// cache and size-class pools, and tracks usage history for prediction.
pub struct MemoryOptimizer<T: Float + Clone + Send + Sync + 'static> {
    /// Settings fixed at construction.
    config: OptimizerConfig,
    /// Recent usage snapshots, oldest first.
    usage_history: RwLock<VecDeque<MemoryUsageSnapshot>>,
    /// Detected usage patterns (not populated in this module yet).
    patterns: RwLock<Vec<UsagePattern>>,
    /// Reusable arrays keyed by the `Debug` rendering of their shape.
    smart_cache: RwLock<HashMap<String, CacheEntry<T>>>,
    /// Returned arrays bucketed by size class, newest at the back.
    prealloc_pools: RwLock<BTreeMap<usize, VecDeque<ArrayD<T>>>>,
    /// Strategy currently in effect.
    current_strategy: RwLock<OptimizationStrategy>,
    /// Running counters exposed via `get_stats`.
    stats: RwLock<OptimizationStats>,
    /// NOTE(review): not read or written in this module.
    last_optimization: Mutex<Option<Instant>>,
}
/// A point-in-time memory usage report fed to `MemoryOptimizer::update_usage_snapshot`.
#[derive(Clone, Debug)]
pub struct MemoryUsageSnapshot {
    /// When the snapshot was taken.
    pub timestamp: SystemTime,
    /// Total usage; the quantity predictions extrapolate.
    pub total_usage: usize,
    /// Number of live allocations at snapshot time.
    pub active_allocations: usize,
    /// Pressure level; drives strategy selection in adaptive mode.
    pub pressure_level: PressureLevel,
    /// GC strategy in effect (not read by the optimizer in this module).
    pub gc_strategy: GcStrategy,
}
impl<T: Float + Clone + Send + Sync + 'static> MemoryOptimizer<T> {
pub fn new(config: OptimizerConfig) -> Self {
let strategy = config.strategy;
Self {
config,
usage_history: RwLock::new(VecDeque::new()),
patterns: RwLock::new(Vec::new()),
smart_cache: RwLock::new(HashMap::new()),
prealloc_pools: RwLock::new(BTreeMap::new()),
current_strategy: RwLock::new(strategy),
stats: RwLock::new(OptimizationStats::default()),
last_optimization: Mutex::new(None),
}
}
pub fn optimize_allocation(&self, shape: &[usize]) -> RusTorchResult<ArrayD<T>> {
let start_time = Instant::now();
let total_elements: usize = shape.iter().product();
let size_class = self.get_size_class(total_elements);
if let Some(cached) = self.try_cache_retrieval(shape)? {
self.update_stats_cache_hit();
return Ok(cached);
}
if let Some(prealloc) = self.try_prealloc_retrieval(size_class, shape)? {
self.update_stats_zero_copy();
return Ok(prealloc);
}
if self.config.enable_prediction {
if let Some(prediction) = self.predict_memory_usage()? {
self.consider_preallocation(&prediction)?;
}
}
let array = self.allocate_with_strategy(shape)?;
self.update_optimization_stats(start_time.elapsed());
Ok(array)
}
pub fn optimize_deallocation(&self, array: ArrayD<T>) -> RusTorchResult<()> {
let shape = array.shape().to_vec();
let total_elements: usize = shape.iter().product();
let size_class = self.get_size_class(total_elements);
if self.should_cache(&array) {
self.add_to_cache(array, &shape)?;
return Ok(());
}
if self.should_preallocate(size_class) {
self.add_to_prealloc_pool(array, size_class)?;
return Ok(());
}
Ok(())
}
pub fn defragment_memory(&self) -> RusTorchResult<usize> {
let start_time = Instant::now();
let mut memory_reclaimed = 0;
{
let mut cache = self.smart_cache.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire cache write lock".to_string())
})?;
let mut to_remove = Vec::new();
let now = Instant::now();
for (key, entry) in cache.iter() {
let age = now.duration_since(entry.last_accessed);
let access_rate = entry.access_count as f64 / age.as_secs_f64().max(1.0);
if access_rate < 0.1 && age > Duration::from_secs(300) {
to_remove.push(key.clone());
}
}
for key in to_remove {
if let Some(entry) = cache.remove(&key) {
memory_reclaimed += entry.data.len() * std::mem::size_of::<T>();
}
}
}
{
let mut pools = self.prealloc_pools.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire pools write lock".to_string())
})?;
for (_, pool) in pools.iter_mut() {
let keep_count = (pool.len() / 2).max(1);
let remove_count = pool.len() - keep_count;
for _ in 0..remove_count {
if let Some(array) = pool.pop_front() {
memory_reclaimed += array.len() * std::mem::size_of::<T>();
}
}
}
}
{
let mut stats = self.stats.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire stats write lock".to_string())
})?;
stats.defragmentations += 1;
stats.memory_saved += memory_reclaimed;
stats.optimization_time += start_time.elapsed();
}
Ok(memory_reclaimed)
}
pub fn predict_memory_usage(&self) -> RusTorchResult<Option<MemoryPrediction>> {
if !self.config.enable_prediction {
return Ok(None);
}
let history = self.usage_history.read().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire history read lock".to_string())
})?;
if history.len() < self.config.prediction_window {
return Ok(None);
}
let recent_snapshots: Vec<_> = history
.iter()
.rev()
.take(self.config.prediction_window)
.collect();
let mut total_usage_trend = 0.0;
for i in 1..recent_snapshots.len() {
let curr = recent_snapshots[i - 1];
let prev = recent_snapshots[i];
total_usage_trend += curr.total_usage as f64 - prev.total_usage as f64;
}
total_usage_trend /= (recent_snapshots.len() - 1) as f64;
let current_usage = recent_snapshots[0].total_usage;
let predicted_memory =
((current_usage as f64 + total_usage_trend * 5.0) as usize).max(current_usage);
let confidence = if total_usage_trend.abs() < current_usage as f64 * 0.1 {
0.8 } else {
0.5 };
let recommended_prealloc = ((predicted_memory as f64 * self.config.preallocation_factor)
as usize)
.saturating_sub(current_usage);
Ok(Some(MemoryPrediction {
predicted_memory,
confidence,
horizon: Duration::from_secs(300), recommended_prealloc,
}))
}
pub fn update_usage_snapshot(&self, snapshot: MemoryUsageSnapshot) -> RusTorchResult<()> {
let mut history = self.usage_history.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire history write lock".to_string())
})?;
history.push_back(snapshot);
while history.len() > self.config.prediction_window * 2 {
history.pop_front();
}
self.analyze_usage_patterns(&history)?;
self.update_strategy_if_adaptive(&history)?;
Ok(())
}
pub fn get_stats(&self) -> RusTorchResult<OptimizationStats> {
let stats = self.stats.read().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire stats read lock".to_string())
})?;
Ok(stats.clone())
}
pub fn compress_memory(&self) -> RusTorchResult<usize> {
Ok(0)
}
fn get_size_class(&self, total_elements: usize) -> usize {
match total_elements {
0..=64 => 0,
65..=256 => 1,
257..=1024 => 2,
1025..=4096 => 3,
4097..=16384 => 4,
_ => 5,
}
}
fn try_cache_retrieval(&self, shape: &[usize]) -> RusTorchResult<Option<ArrayD<T>>> {
let cache_key = format!("{:?}", shape);
let mut cache = self.smart_cache.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire cache write lock".to_string())
})?;
if let Some(entry) = cache.get_mut(&cache_key) {
entry.last_accessed = Instant::now();
entry.access_count += 1;
let mut result = entry.data.clone();
result.fill(T::zero());
return Ok(Some(result));
}
Ok(None)
}
fn try_prealloc_retrieval(
&self,
size_class: usize,
shape: &[usize],
) -> RusTorchResult<Option<ArrayD<T>>> {
let mut pools = self.prealloc_pools.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire pools write lock".to_string())
})?;
if let Some(pool) = pools.get_mut(&size_class) {
if let Some(array) = pool.pop_back() {
let total_elements: usize = shape.iter().product();
if array.len() == total_elements {
let array_clone = array.clone();
if let Ok(reshaped) = array_clone.into_shape_with_order(IxDyn(shape)) {
return Ok(Some(reshaped));
}
}
pool.push_back(array);
}
}
Ok(None)
}
fn allocate_with_strategy(&self, shape: &[usize]) -> RusTorchResult<ArrayD<T>> {
let strategy = *self.current_strategy.read().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire strategy read lock".to_string())
})?;
match strategy {
OptimizationStrategy::MemoryFirst => {
Ok(ArrayD::zeros(IxDyn(shape)))
}
OptimizationStrategy::SpeedFirst => {
Ok(ArrayD::zeros(IxDyn(shape)))
}
OptimizationStrategy::Balanced | OptimizationStrategy::Adaptive => {
Ok(ArrayD::zeros(IxDyn(shape)))
}
}
}
fn should_cache(&self, array: &ArrayD<T>) -> bool {
let current_cache_size = self.estimate_cache_size();
let array_size = array.len() * std::mem::size_of::<T>();
current_cache_size + array_size <= self.config.cache_size_limit
}
fn should_preallocate(&self, size_class: usize) -> bool {
size_class <= 3 }
fn add_to_cache(&self, array: ArrayD<T>, shape: &[usize]) -> RusTorchResult<()> {
let cache_key = format!("{:?}", shape);
let entry = CacheEntry {
data: array,
last_accessed: Instant::now(),
access_count: 1,
priority: 128, compression_ratio: None,
};
let mut cache = self.smart_cache.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire cache write lock".to_string())
})?;
cache.insert(cache_key, entry);
Ok(())
}
fn add_to_prealloc_pool(&self, array: ArrayD<T>, size_class: usize) -> RusTorchResult<()> {
let mut pools = self.prealloc_pools.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire pools write lock".to_string())
})?;
let pool = pools.entry(size_class).or_insert_with(VecDeque::new);
pool.push_back(array);
if pool.len() > 10 {
pool.pop_front();
}
Ok(())
}
fn consider_preallocation(&self, prediction: &MemoryPrediction) -> RusTorchResult<()> {
if prediction.confidence < 0.7 || prediction.recommended_prealloc < 1024 {
return Ok(()); }
Ok(())
}
fn estimate_cache_size(&self) -> usize {
if let Ok(cache) = self.smart_cache.read() {
cache.len() * 1024 } else {
0
}
}
fn analyze_usage_patterns(
&self,
history: &VecDeque<MemoryUsageSnapshot>,
) -> RusTorchResult<()> {
Ok(())
}
fn update_strategy_if_adaptive(
&self,
history: &VecDeque<MemoryUsageSnapshot>,
) -> RusTorchResult<()> {
let current_strategy = *self.current_strategy.read().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire strategy read lock".to_string())
})?;
if current_strategy != OptimizationStrategy::Adaptive {
return Ok(());
}
if let Some(latest) = history.back() {
let new_strategy = match latest.pressure_level {
PressureLevel::Low => OptimizationStrategy::SpeedFirst,
PressureLevel::Medium => OptimizationStrategy::Balanced,
PressureLevel::High | PressureLevel::Critical => OptimizationStrategy::MemoryFirst,
};
let mut strategy = self.current_strategy.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire strategy write lock".to_string())
})?;
*strategy = new_strategy;
}
Ok(())
}
fn update_stats_cache_hit(&self) {
if let Ok(mut stats) = self.stats.write() {
stats.cache_hit_ratio = (stats.cache_hit_ratio * 0.9) + 0.1; }
}
fn update_stats_zero_copy(&self) {
if let Ok(mut stats) = self.stats.write() {
stats.zero_copy_operations += 1;
}
}
fn update_optimization_stats(&self, duration: Duration) {
if let Ok(mut stats) = self.stats.write() {
stats.total_optimizations += 1;
stats.optimization_time += duration;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimizer_creation() {
        let config = OptimizerConfig::default();
        let optimizer: MemoryOptimizer<f32> = MemoryOptimizer::new(config);
        let stats = optimizer.get_stats().unwrap();
        assert_eq!(stats.total_optimizations, 0);
    }

    #[test]
    fn test_size_class_calculation() {
        let config = OptimizerConfig::default();
        let optimizer: MemoryOptimizer<f32> = MemoryOptimizer::new(config);
        assert_eq!(optimizer.get_size_class(32), 0);
        assert_eq!(optimizer.get_size_class(128), 1);
        assert_eq!(optimizer.get_size_class(512), 2);
        assert_eq!(optimizer.get_size_class(2048), 3);
        // Each class includes its upper bound.
        assert_eq!(optimizer.get_size_class(64), 0);
        assert_eq!(optimizer.get_size_class(65), 1);
        assert_eq!(optimizer.get_size_class(16384), 4);
        assert_eq!(optimizer.get_size_class(16385), 5);
    }

    #[test]
    fn test_optimization_allocation() {
        let config = OptimizerConfig::default();
        let optimizer: MemoryOptimizer<f32> = MemoryOptimizer::new(config);
        let array = optimizer.optimize_allocation(&[3, 4]).unwrap();
        assert_eq!(array.shape(), &[3, 4]);
        let stats = optimizer.get_stats().unwrap();
        assert_eq!(stats.total_optimizations, 1);
    }

    #[test]
    fn test_cache_optimization() {
        let config = OptimizerConfig::default();
        let optimizer: MemoryOptimizer<f32> = MemoryOptimizer::new(config);
        let mut array1 = optimizer.optimize_allocation(&[2, 2]).unwrap();
        array1.fill(7.0);
        optimizer.optimize_deallocation(array1).unwrap();
        let array2 = optimizer.optimize_allocation(&[2, 2]).unwrap();
        assert_eq!(array2.shape(), &[2, 2]);
        // Reused allocations must not leak the previous owner's data.
        assert!(array2.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_memory_prediction() {
        let config = OptimizerConfig {
            prediction_window: 5,
            ..OptimizerConfig::default()
        };
        let optimizer: MemoryOptimizer<f32> = MemoryOptimizer::new(config);
        for i in 0..10 {
            let snapshot = MemoryUsageSnapshot {
                timestamp: SystemTime::now(),
                total_usage: 1000 + i * 100,
                active_allocations: 10 + i,
                pressure_level: PressureLevel::Low,
                gc_strategy: GcStrategy::Conservative,
            };
            optimizer.update_usage_snapshot(snapshot).unwrap();
        }
        let prediction = optimizer.predict_memory_usage().unwrap();
        assert!(prediction.is_some());
        let pred = prediction.unwrap();
        assert!(pred.predicted_memory > 0);
        assert!(pred.confidence > 0.0 && pred.confidence <= 1.0);
    }

    #[test]
    fn test_defragmentation() {
        let config = OptimizerConfig::default();
        let optimizer: MemoryOptimizer<f32> = MemoryOptimizer::new(config);
        for _ in 0..5 {
            let array = optimizer.optimize_allocation(&[10, 10]).unwrap();
            optimizer.optimize_deallocation(array).unwrap();
        }
        let reclaimed = optimizer.defragment_memory().unwrap();
        let stats = optimizer.get_stats().unwrap();
        assert_eq!(stats.defragmentations, 1);
        // The reported reclaim is exactly what the stats accumulated.
        assert_eq!(stats.memory_saved, reclaimed);
    }
}