// optirs_learned/transformer_based_optimizer/memory_manager.rs
1// Memory management for transformer-based optimizer
2
3use super::config::{CacheEvictionStrategy, MemoryConfig, TransformerBasedOptimizerConfig};
4use crate::error::Result;
5use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
6use scirs2_core::numeric::Float;
7use std::collections::{BTreeMap, HashMap, VecDeque};
8use std::fmt::Debug;
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11
/// Memory management strategy types
///
/// Selects how the [`TransformerMemoryManager`] routes and evicts cached
/// tensors. Derives `Eq` and `Hash` (the enum is fieldless and `Copy`) so
/// strategies can be used as map/set keys and compared exhaustively.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryManagementStrategy {
    /// Simple FIFO eviction
    FIFO,
    /// Least Recently Used
    LRU,
    /// Least Frequently Used
    LFU,
    /// Adaptive replacement cache
    ARC,
    /// Compressed memory storage
    Compressed,
    /// Hierarchical memory organization
    Hierarchical,
}
28
/// Transformer memory manager
///
/// Coordinates a tiered tensor store for the transformer-based optimizer:
/// a primary in-memory cache, an optional secondary overflow cache, and an
/// optional compressed tier, together with access-pattern and memory-pressure
/// bookkeeping. Retrieval falls through the tiers in that order.
pub struct TransformerMemoryManager<T: Float + Debug + Send + Sync + 'static> {
    /// Memory management strategy used to route `store` calls
    strategy: MemoryManagementStrategy,

    /// Memory configuration (cloned from the optimizer config at creation)
    config: MemoryConfig,

    /// Primary memory cache — first tier consulted on store and retrieve
    primary_cache: MemoryCache<T>,

    /// Secondary cache for overflow; only present for large cache budgets
    secondary_cache: Option<MemoryCache<T>>,

    /// Memory compression manager; only present when compression is enabled
    compression_manager: Option<CompressionManager<T>>,

    /// Aggregate store/retrieve statistics (hits, misses, timings)
    statistics: MemoryStatistics,

    /// Access patterns tracker (bounded read/write logs)
    access_tracker: AccessTracker,

    /// Memory pressure monitor driving eviction decisions
    pressure_monitor: MemoryPressureMonitor,

    /// Model dimension copied from the optimizer config
    // NOTE(review): stored but not read by any method in this file — presumably
    // used by callers or future sizing logic; confirm before removing.
    model_dimension: usize,
}
58
59impl<T: Float + Debug + Send + Sync + 'static> TransformerMemoryManager<T> {
60    /// Create new memory manager
61    pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
62        let memory_config = config.memory_config.clone();
63        let strategy = match memory_config.eviction_strategy {
64            CacheEvictionStrategy::LRU => MemoryManagementStrategy::LRU,
65            CacheEvictionStrategy::LFU => MemoryManagementStrategy::LFU,
66            CacheEvictionStrategy::FIFO => MemoryManagementStrategy::FIFO,
67            CacheEvictionStrategy::Random => MemoryManagementStrategy::LRU, // Fallback
68        };
69
70        let primary_cache = MemoryCache::new(
71            memory_config.max_cache_size / 2,
72            memory_config.eviction_strategy,
73        )?;
74
75        let secondary_cache = if memory_config.max_cache_size > 1024 * 1024 * 100 {
76            // 100MB
77            Some(MemoryCache::new(
78                memory_config.max_cache_size / 2,
79                CacheEvictionStrategy::FIFO,
80            )?)
81        } else {
82            None
83        };
84
85        let compression_manager = if memory_config.enable_compression {
86            Some(CompressionManager::new(0.5)?) // 50% compression ratio target
87        } else {
88            None
89        };
90
91        let statistics = MemoryStatistics::new();
92        let access_tracker = AccessTracker::new(1000);
93        let pressure_monitor = MemoryPressureMonitor::new();
94
95        Ok(Self {
96            strategy,
97            config: memory_config,
98            primary_cache,
99            secondary_cache,
100            compression_manager,
101            statistics,
102            access_tracker,
103            pressure_monitor,
104            model_dimension: config.model_dimension,
105        })
106    }
107
108    /// Store tensor in memory with key
109    pub fn store(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
110        let start_time = Instant::now();
111
112        // Check memory pressure and evict if necessary
113        self.pressure_monitor.update(self.get_memory_usage());
114        if self.pressure_monitor.is_high_pressure() {
115            self.evict_memory()?;
116        }
117
118        // Try to store in primary cache first
119        let storage_result = match self.strategy {
120            MemoryManagementStrategy::LRU => self.store_lru(key.clone(), tensor.clone()),
121            MemoryManagementStrategy::LFU => self.store_lfu(key.clone(), tensor.clone()),
122            MemoryManagementStrategy::FIFO => self.store_fifo(key.clone(), tensor.clone()),
123            MemoryManagementStrategy::ARC => self.store_arc(key.clone(), tensor.clone()),
124            MemoryManagementStrategy::Compressed => {
125                self.store_compressed(key.clone(), tensor.clone())
126            }
127            MemoryManagementStrategy::Hierarchical => {
128                self.store_hierarchical(key.clone(), tensor.clone())
129            }
130        };
131
132        // Get tensor length before potential move
133        let tensor_len = tensor.len();
134
135        // If primary cache is full, try secondary cache
136        if storage_result.is_err() && self.secondary_cache.is_some() {
137            if let Some(ref mut secondary) = self.secondary_cache {
138                secondary.store(key.clone(), tensor)?;
139            }
140        }
141
142        // Update statistics
143        let storage_time = start_time.elapsed();
144        self.statistics.record_storage(tensor_len, storage_time);
145        self.access_tracker.record_write(key);
146
147        storage_result
148    }
149
150    /// Retrieve tensor from memory
151    pub fn retrieve(&mut self, key: &str) -> Result<Option<Array2<T>>> {
152        let start_time = Instant::now();
153
154        // Try primary cache first
155        let result = self.primary_cache.retrieve(key)?;
156
157        if result.is_some() {
158            self.access_tracker.record_read(key.to_string());
159            let retrieval_time = start_time.elapsed();
160            self.statistics.record_retrieval(retrieval_time, true);
161            return Ok(result);
162        }
163
164        // Try secondary cache
165        if let Some(ref mut secondary) = self.secondary_cache {
166            let result = secondary.retrieve(key)?;
167            if result.is_some() {
168                self.access_tracker.record_read(key.to_string());
169                let retrieval_time = start_time.elapsed();
170                self.statistics.record_retrieval(retrieval_time, true);
171                return Ok(result);
172            }
173        }
174
175        // Check compressed storage
176        if let Some(ref mut compression) = self.compression_manager {
177            if let Some(compressed_data) = compression.retrieve(key)? {
178                let decompressed = compression.decompress(&compressed_data)?;
179                self.access_tracker.record_read(key.to_string());
180                let retrieval_time = start_time.elapsed();
181                self.statistics.record_retrieval(retrieval_time, true);
182                return Ok(Some(decompressed));
183            }
184        }
185
186        let retrieval_time = start_time.elapsed();
187        self.statistics.record_retrieval(retrieval_time, false);
188        Ok(None)
189    }
190
191    /// Remove tensor from memory
192    pub fn remove(&mut self, key: &str) -> Result<bool> {
193        let mut removed = false;
194
195        if self.primary_cache.remove(key)? {
196            removed = true;
197        }
198
199        if let Some(ref mut secondary) = self.secondary_cache {
200            if secondary.remove(key)? {
201                removed = true;
202            }
203        }
204
205        if let Some(ref mut compression) = self.compression_manager {
206            if compression.remove(key)? {
207                removed = true;
208            }
209        }
210
211        self.access_tracker.record_removal(key.to_string());
212        Ok(removed)
213    }
214
215    /// Clear all memory
216    pub fn clear(&mut self) -> Result<()> {
217        self.primary_cache.clear()?;
218
219        if let Some(ref mut secondary) = self.secondary_cache {
220            secondary.clear()?;
221        }
222
223        if let Some(ref mut compression) = self.compression_manager {
224            compression.clear()?;
225        }
226
227        self.statistics.reset();
228        self.access_tracker.clear();
229        self.pressure_monitor.reset();
230
231        Ok(())
232    }
233
234    /// Get memory usage statistics
235    pub fn get_memory_usage(&self) -> usize {
236        let primary_usage = self.primary_cache.get_memory_usage();
237        let secondary_usage = self
238            .secondary_cache
239            .as_ref()
240            .map(|cache| cache.get_memory_usage())
241            .unwrap_or(0);
242        let compression_usage = self
243            .compression_manager
244            .as_ref()
245            .map(|comp| comp.get_memory_usage())
246            .unwrap_or(0);
247
248        primary_usage + secondary_usage + compression_usage
249    }
250
251    /// Optimize memory layout
252    pub fn optimize_memory(&mut self) -> Result<OptimizationReport> {
253        let start_time = Instant::now();
254        let initial_usage = self.get_memory_usage();
255
256        // Analyze access patterns
257        let access_patterns = self.access_tracker.analyze_patterns();
258
259        // Reorganize based on access frequency
260        self.reorganize_by_frequency(&access_patterns)?;
261
262        // Compress frequently accessed but large items
263        if let Some(ref mut compression) = self.compression_manager {
264            compression.optimize_compression_ratios(&access_patterns)?;
265        }
266
267        // Defragment memory
268        self.defragment_memory()?;
269
270        let final_usage = self.get_memory_usage();
271        let optimization_time = start_time.elapsed();
272
273        Ok(OptimizationReport {
274            initial_memory_usage: initial_usage,
275            final_memory_usage: final_usage,
276            memory_saved: initial_usage.saturating_sub(final_usage),
277            optimization_time,
278            operations_performed: access_patterns.total_accesses,
279        })
280    }
281
282    /// Prefetch data based on access patterns
283    pub fn prefetch(&mut self, keys: Vec<String>) -> Result<usize> {
284        let mut prefetched_count = 0;
285
286        for key in keys {
287            if !self.primary_cache.contains(&key) {
288                // Try to move from secondary to primary cache
289                if let Some(ref mut secondary) = self.secondary_cache {
290                    if let Some(tensor) = secondary.retrieve(&key)? {
291                        if self.primary_cache.store(key.clone(), tensor).is_ok() {
292                            secondary.remove(&key)?;
293                            prefetched_count += 1;
294                        }
295                    }
296                }
297
298                // Try to decompress and move to primary cache
299                if let Some(ref mut compression) = self.compression_manager {
300                    if let Some(compressed_data) = compression.retrieve(&key)? {
301                        let decompressed = compression.decompress(&compressed_data)?;
302                        if self.primary_cache.store(key.clone(), decompressed).is_ok() {
303                            compression.remove(&key)?;
304                            prefetched_count += 1;
305                        }
306                    }
307                }
308            }
309        }
310
311        Ok(prefetched_count)
312    }
313
314    /// Storage strategy implementations
315    fn store_lru(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
316        self.primary_cache.store(key, tensor)
317    }
318
319    fn store_lfu(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
320        // For LFU, we need to track access frequency
321        self.primary_cache.store(key, tensor)
322    }
323
324    fn store_fifo(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
325        self.primary_cache.store(key, tensor)
326    }
327
328    fn store_arc(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
329        // Adaptive Replacement Cache - simplified implementation
330        self.primary_cache.store(key, tensor)
331    }
332
333    fn store_compressed(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
334        if let Some(ref mut compression) = self.compression_manager {
335            let compressed_data = compression.compress(&tensor)?;
336            compression.store(key, compressed_data)?;
337            Ok(())
338        } else {
339            self.primary_cache.store(key, tensor)
340        }
341    }
342
343    fn store_hierarchical(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
344        let tensor_size = tensor.len() * std::mem::size_of::<T>();
345
346        if tensor_size < self.config.allocation_block_size {
347            // Small tensors go to primary cache
348            self.primary_cache.store(key, tensor)
349        } else if let Some(ref mut secondary) = self.secondary_cache {
350            // Large tensors go to secondary cache
351            secondary.store(key, tensor)
352        } else {
353            // Fallback to primary cache
354            self.primary_cache.store(key, tensor)
355        }
356    }
357
358    fn evict_memory(&mut self) -> Result<()> {
359        // Evict from primary cache first
360        self.primary_cache.evict_lru()?;
361
362        // If still under pressure, evict from secondary cache
363        if self.pressure_monitor.is_high_pressure() {
364            if let Some(ref mut secondary) = self.secondary_cache {
365                secondary.evict_lru()?;
366            }
367        }
368
369        Ok(())
370    }
371
372    fn reorganize_by_frequency(&mut self, patterns: &AccessPatterns) -> Result<()> {
373        // Move frequently accessed items to primary cache
374        let frequent_keys: Vec<String> = patterns
375            .frequency_map
376            .iter()
377            .filter(|(_, &count)| count as f64 > patterns.average_frequency)
378            .map(|(key, _)| key.clone())
379            .collect();
380
381        self.prefetch(frequent_keys)?;
382        Ok(())
383    }
384
385    fn defragment_memory(&mut self) -> Result<()> {
386        // Simplified defragmentation - rebuild caches
387        let primary_items = self.primary_cache.get_all_items()?;
388        self.primary_cache.clear()?;
389
390        for (key, tensor) in primary_items {
391            self.primary_cache.store(key, tensor)?;
392        }
393
394        Ok(())
395    }
396
397    /// Get memory statistics
398    pub fn get_statistics(&self) -> &MemoryStatistics {
399        &self.statistics
400    }
401
402    /// Get access patterns
403    pub fn get_access_patterns(&self) -> AccessPatterns {
404        self.access_tracker.analyze_patterns()
405    }
406
407    /// Set memory management strategy
408    pub fn set_strategy(&mut self, strategy: MemoryManagementStrategy) {
409        self.strategy = strategy;
410    }
411
412    /// Get current memory pressure
413    pub fn get_memory_pressure(&self) -> f64 {
414        self.pressure_monitor.get_pressure_ratio()
415    }
416}
417
418/// Memory cache implementation
419pub struct MemoryCache<T: Float + Debug + Send + Sync + 'static> {
420    /// Stored tensors
421    storage: HashMap<String, CacheEntry<T>>,
422
423    /// Access order for LRU
424    access_order: VecDeque<String>,
425
426    /// Access frequency for LFU
427    access_frequency: HashMap<String, usize>,
428
429    /// Maximum cache size in bytes
430    max_size: usize,
431
432    /// Current cache size in bytes
433    current_size: usize,
434
435    /// Eviction strategy
436    eviction_strategy: CacheEvictionStrategy,
437}
438
439impl<T: Float + Debug + Send + Sync + 'static> MemoryCache<T> {
440    pub fn new(max_size: usize, eviction_strategy: CacheEvictionStrategy) -> Result<Self> {
441        Ok(Self {
442            storage: HashMap::new(),
443            access_order: VecDeque::new(),
444            access_frequency: HashMap::new(),
445            max_size,
446            current_size: 0,
447            eviction_strategy,
448        })
449    }
450
451    pub fn store(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
452        let tensor_size = tensor.len() * std::mem::size_of::<T>();
453
454        // Check if we need to evict
455        while self.current_size + tensor_size > self.max_size && !self.storage.is_empty() {
456            self.evict_one()?;
457        }
458
459        if tensor_size > self.max_size {
460            return Err(crate::error::OptimError::Other(
461                "Tensor too large for cache".to_string(),
462            ));
463        }
464
465        // Remove existing entry if present
466        if let Some(old_entry) = self.storage.remove(&key) {
467            self.current_size -= old_entry.size;
468            self.remove_from_access_order(&key);
469        }
470
471        // Add new entry
472        let entry = CacheEntry {
473            tensor,
474            size: tensor_size,
475            access_time: Instant::now(),
476            access_count: 1,
477        };
478
479        self.storage.insert(key.clone(), entry);
480        self.current_size += tensor_size;
481        self.update_access_tracking(&key);
482
483        Ok(())
484    }
485
486    pub fn retrieve(&mut self, key: &str) -> Result<Option<Array2<T>>> {
487        let tensor_result = if let Some(entry) = self.storage.get_mut(key) {
488            entry.access_time = Instant::now();
489            entry.access_count += 1;
490            Some(entry.tensor.clone())
491        } else {
492            None
493        };
494
495        // Update access tracking after releasing the mutable borrow
496        if tensor_result.is_some() {
497            self.update_access_tracking(key);
498        }
499
500        Ok(tensor_result)
501    }
502
503    pub fn remove(&mut self, key: &str) -> Result<bool> {
504        if let Some(entry) = self.storage.remove(key) {
505            self.current_size -= entry.size;
506            self.remove_from_access_order(key);
507            self.access_frequency.remove(key);
508            Ok(true)
509        } else {
510            Ok(false)
511        }
512    }
513
514    pub fn contains(&self, key: &str) -> bool {
515        self.storage.contains_key(key)
516    }
517
518    pub fn clear(&mut self) -> Result<()> {
519        self.storage.clear();
520        self.access_order.clear();
521        self.access_frequency.clear();
522        self.current_size = 0;
523        Ok(())
524    }
525
526    pub fn get_memory_usage(&self) -> usize {
527        self.current_size
528    }
529
530    pub fn evict_lru(&mut self) -> Result<()> {
531        if let Some(oldest_key) = self.access_order.front().cloned() {
532            self.remove(&oldest_key)?;
533        }
534        Ok(())
535    }
536
537    fn evict_one(&mut self) -> Result<()> {
538        match self.eviction_strategy {
539            CacheEvictionStrategy::LRU => self.evict_lru(),
540            CacheEvictionStrategy::LFU => self.evict_lfu(),
541            CacheEvictionStrategy::FIFO => self.evict_fifo(),
542            CacheEvictionStrategy::Random => self.evict_random(),
543        }
544    }
545
546    fn evict_lfu(&mut self) -> Result<()> {
547        if let Some((min_freq, lfu_key)) = self
548            .access_frequency
549            .iter()
550            .min_by_key(|(_, &freq)| freq)
551            .map(|(key, &freq)| (freq, key.clone()))
552        {
553            self.remove(&lfu_key)?;
554        }
555        Ok(())
556    }
557
558    fn evict_fifo(&mut self) -> Result<()> {
559        if let Some(first_key) = self.access_order.front().cloned() {
560            self.remove(&first_key)?;
561        }
562        Ok(())
563    }
564
565    fn evict_random(&mut self) -> Result<()> {
566        if let Some(random_key) = self.storage.keys().next().cloned() {
567            self.remove(&random_key)?;
568        }
569        Ok(())
570    }
571
572    fn update_access_tracking(&mut self, key: &str) {
573        // Update LRU order
574        self.remove_from_access_order(key);
575        self.access_order.push_back(key.to_string());
576
577        // Update LFU frequency
578        *self.access_frequency.entry(key.to_string()).or_insert(0) += 1;
579    }
580
581    fn remove_from_access_order(&mut self, key: &str) {
582        self.access_order.retain(|k| k != key);
583    }
584
585    pub fn get_all_items(&self) -> Result<Vec<(String, Array2<T>)>> {
586        let items = self
587            .storage
588            .iter()
589            .map(|(key, entry)| (key.clone(), entry.tensor.clone()))
590            .collect();
591        Ok(items)
592    }
593}
594
/// Cache entry
///
/// A stored tensor plus the bookkeeping `MemoryCache` needs for size
/// accounting and LRU/LFU eviction decisions.
#[derive(Debug, Clone)]
pub struct CacheEntry<T: Float + Debug + Send + Sync + 'static> {
    /// The cached tensor (cloned out on retrieval)
    pub tensor: Array2<T>,
    /// Payload size in bytes (`len * size_of::<T>()`)
    pub size: usize,
    /// Last access timestamp (refreshed on every retrieve)
    pub access_time: Instant,
    /// Number of accesses, starting at 1 on insertion
    pub access_count: usize,
}
603
604/// Compression manager
605pub struct CompressionManager<T: Float + Debug + Send + Sync + 'static> {
606    /// Compressed storage
607    compressed_storage: HashMap<String, CompressedData<T>>,
608
609    /// Compression ratio target
610    compression_ratio: f64,
611
612    /// Memory usage
613    memory_usage: usize,
614
615    /// Phantom data for type parameter
616    _phantom: std::marker::PhantomData<T>,
617}
618
619impl<T: Float + Debug + Send + Sync + 'static> CompressionManager<T> {
620    pub fn new(compression_ratio: f64) -> Result<Self> {
621        Ok(Self {
622            compressed_storage: HashMap::new(),
623            compression_ratio,
624            memory_usage: 0,
625            _phantom: std::marker::PhantomData,
626        })
627    }
628
629    pub fn compress(&self, tensor: &Array2<T>) -> Result<CompressedData<T>> {
630        // Simplified compression - just store dimensions and flattened data
631        let shape = tensor.shape().to_vec();
632        let data: Vec<T> = tensor.iter().cloned().collect();
633        let data_len = data.len(); // Get length before move
634
635        Ok(CompressedData::<T> {
636            shape,
637            data,
638            original_size: tensor.len() * std::mem::size_of::<T>(),
639            compressed_size: data_len * std::mem::size_of::<T>() / 2, // Simulated compression
640        })
641    }
642
643    pub fn decompress(&self, compressed: &CompressedData<T>) -> Result<Array2<T>> {
644        let array = Array2::from_shape_vec(
645            (compressed.shape[0], compressed.shape[1]),
646            compressed.data.clone(),
647        )
648        .map_err(|_| crate::error::OptimError::Other("Decompression failed".to_string()))?;
649        Ok(array)
650    }
651
652    pub fn store(&mut self, key: String, compressed: CompressedData<T>) -> Result<()> {
653        self.memory_usage += compressed.compressed_size;
654        self.compressed_storage.insert(key, compressed);
655        Ok(())
656    }
657
658    pub fn retrieve(&self, key: &str) -> Result<Option<CompressedData<T>>> {
659        Ok(self.compressed_storage.get(key).cloned())
660    }
661
662    pub fn remove(&mut self, key: &str) -> Result<bool> {
663        if let Some(compressed) = self.compressed_storage.remove(key) {
664            self.memory_usage -= compressed.compressed_size;
665            Ok(true)
666        } else {
667            Ok(false)
668        }
669    }
670
671    pub fn clear(&mut self) -> Result<()> {
672        self.compressed_storage.clear();
673        self.memory_usage = 0;
674        Ok(())
675    }
676
677    pub fn get_memory_usage(&self) -> usize {
678        self.memory_usage
679    }
680
681    pub fn optimize_compression_ratios(&mut self, _patterns: &AccessPatterns) -> Result<()> {
682        // Optimize compression based on access patterns
683        // This is a placeholder for more sophisticated compression optimization
684        Ok(())
685    }
686}
687
/// Compressed data structure
///
/// Shape plus flattened element data as produced by
/// `CompressionManager::compress`; sizes are in bytes.
#[derive(Debug, Clone)]
pub struct CompressedData<T: Float + Debug + Send + Sync + 'static> {
    /// Original tensor shape (2-D: rows, cols)
    pub shape: Vec<usize>,
    /// Flattened element data (stored uncompressed in this placeholder impl)
    pub data: Vec<T>, // Generic data type
    /// Uncompressed payload size in bytes
    pub original_size: usize,
    /// Reported (simulated) compressed size in bytes
    pub compressed_size: usize,
}
696
/// Memory statistics
///
/// Running counters and averages for store/retrieve operations. Averages are
/// maintained incrementally in nanosecond space to avoid overflow.
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
    /// Total storage operations
    pub total_stores: usize,

    /// Total retrieval operations
    pub total_retrievals: usize,

    /// Cache hits
    pub cache_hits: usize,

    /// Cache misses
    pub cache_misses: usize,

    /// Total bytes stored
    pub total_bytes_stored: usize,

    /// Average storage time
    pub average_storage_time: Duration,

    /// Average retrieval time
    pub average_retrieval_time: Duration,

    /// Memory pressure events
    pub pressure_events: usize,
}

impl Default for MemoryStatistics {
    fn default() -> Self {
        Self::new()
    }
}

impl MemoryStatistics {
    /// All counters zero, all averages zero duration.
    pub fn new() -> Self {
        Self {
            total_stores: 0,
            total_retrievals: 0,
            cache_hits: 0,
            cache_misses: 0,
            total_bytes_stored: 0,
            average_storage_time: Duration::new(0, 0),
            average_retrieval_time: Duration::new(0, 0),
            pressure_events: 0,
        }
    }

    /// Record one storage operation of `bytes` bytes taking `time`.
    pub fn record_storage(&mut self, bytes: usize, time: Duration) {
        self.total_stores += 1;
        self.total_bytes_stored += bytes;
        self.average_storage_time =
            Self::running_mean(self.average_storage_time, self.total_stores, time);
    }

    /// Record one retrieval taking `time`; `hit` marks a cache hit.
    pub fn record_retrieval(&mut self, time: Duration, hit: bool) {
        self.total_retrievals += 1;
        if hit {
            self.cache_hits += 1;
        } else {
            self.cache_misses += 1;
        }
        self.average_retrieval_time =
            Self::running_mean(self.average_retrieval_time, self.total_retrievals, time);
    }

    /// Incremental mean computed in u128 nanoseconds. The previous formula
    /// (`avg * (n - 1) as u32 + time`) truncated `usize` counts to `u32` and
    /// `Duration * u32` can overflow (panic) for long-running processes.
    fn running_mean(current: Duration, count: usize, sample: Duration) -> Duration {
        debug_assert!(count > 0, "running_mean requires count >= 1");
        let total = current.as_nanos() * (count as u128 - 1) + sample.as_nanos();
        let mean = total / count as u128;
        // Clamp to u64 nanoseconds (~584 years) — cannot realistically trip.
        Duration::from_nanos(mean.min(u64::MAX as u128) as u64)
    }

    /// Count one critical memory-pressure event.
    pub fn record_pressure_event(&mut self) {
        self.pressure_events += 1;
    }

    /// Fraction of retrievals that hit a cache tier; 0.0 when none recorded.
    pub fn get_hit_ratio(&self) -> f64 {
        if self.total_retrievals > 0 {
            self.cache_hits as f64 / self.total_retrievals as f64
        } else {
            0.0
        }
    }

    /// Reset every counter and average to the freshly-constructed state.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
781
/// Access tracker
///
/// Keeps bounded logs of read and write accesses and summarizes them into
/// per-key frequency statistics.
pub struct AccessTracker {
    /// Read access log
    read_log: VecDeque<AccessEvent>,

    /// Write access log
    write_log: VecDeque<AccessEvent>,

    /// Maximum log size
    max_log_size: usize,
}

impl AccessTracker {
    /// Build a tracker whose read and write logs each hold at most
    /// `max_log_size` entries.
    pub fn new(max_log_size: usize) -> Self {
        Self {
            read_log: VecDeque::new(),
            write_log: VecDeque::new(),
            max_log_size,
        }
    }

    /// Log a read of `key`, dropping the oldest read when the log is full.
    pub fn record_read(&mut self, key: String) {
        let event = AccessEvent {
            key,
            timestamp: Instant::now(),
        };
        Self::push_bounded(&mut self.read_log, event, self.max_log_size);
    }

    /// Log a write of `key`, dropping the oldest write when the log is full.
    pub fn record_write(&mut self, key: String) {
        let event = AccessEvent {
            key,
            timestamp: Instant::now(),
        };
        Self::push_bounded(&mut self.write_log, event, self.max_log_size);
    }

    /// Append `event` to `log`, evicting from the front past `cap` entries.
    fn push_bounded(log: &mut VecDeque<AccessEvent>, event: AccessEvent, cap: usize) {
        log.push_back(event);
        if log.len() > cap {
            log.pop_front();
        }
    }

    /// Removal bookkeeping — intentionally a no-op placeholder.
    pub fn record_removal(&mut self, _key: String) {}

    /// Summarize per-key access counts across both logs.
    pub fn analyze_patterns(&self) -> AccessPatterns {
        let mut frequency_map: HashMap<String, usize> = HashMap::new();
        for event in self.read_log.iter().chain(&self.write_log) {
            *frequency_map.entry(event.key.clone()).or_default() += 1;
        }

        let total_accesses: usize = frequency_map.values().sum();
        let average_frequency = match frequency_map.len() {
            0 => 0.0,
            n => total_accesses as f64 / n as f64,
        };

        AccessPatterns {
            frequency_map,
            average_frequency,
            total_accesses,
        }
    }

    /// Forget every logged access.
    pub fn clear(&mut self) {
        self.read_log.clear();
        self.write_log.clear();
    }
}

/// Access event
///
/// A single logged access: which key, and when.
#[derive(Debug, Clone)]
pub struct AccessEvent {
    pub key: String,
    pub timestamp: Instant,
}

/// Access patterns analysis
///
/// Per-key counts plus aggregate totals derived from the access logs.
#[derive(Debug, Clone)]
pub struct AccessPatterns {
    pub frequency_map: HashMap<String, usize>,
    pub average_frequency: f64,
    pub total_accesses: usize,
}
871
/// Memory pressure monitor
///
/// Tracks current usage against a fixed budget and classifies the resulting
/// usage ratio against warning and critical thresholds.
pub struct MemoryPressureMonitor {
    /// Current memory usage
    current_usage: usize,

    /// Maximum allowed memory
    max_memory: usize,

    /// Pressure thresholds
    warning_threshold: f64,
    critical_threshold: f64,

    /// Pressure history
    pressure_history: VecDeque<f64>,
}

impl Default for MemoryPressureMonitor {
    fn default() -> Self {
        MemoryPressureMonitor::new()
    }
}

impl MemoryPressureMonitor {
    /// Monitor with a 1 GB budget, 70% warning and 90% critical thresholds.
    pub fn new() -> Self {
        Self {
            current_usage: 0,
            max_memory: 1024 * 1024 * 1024, // 1GB default
            warning_threshold: 0.7,
            critical_threshold: 0.9,
            pressure_history: VecDeque::new(),
        }
    }

    /// Record the latest usage reading and append the resulting pressure
    /// ratio to a bounded (100-sample) history.
    pub fn update(&mut self, current_usage: usize) {
        self.current_usage = current_usage;
        let ratio = self.get_pressure_ratio();

        self.pressure_history.push_back(ratio);
        while self.pressure_history.len() > 100 {
            self.pressure_history.pop_front();
        }
    }

    /// Usage divided by budget; 0.0 when the budget is zero.
    pub fn get_pressure_ratio(&self) -> f64 {
        match self.max_memory {
            0 => 0.0,
            max => self.current_usage as f64 / max as f64,
        }
    }

    /// True once the pressure ratio exceeds the critical threshold.
    pub fn is_high_pressure(&self) -> bool {
        self.critical_threshold < self.get_pressure_ratio()
    }

    /// True once the pressure ratio exceeds the warning threshold.
    pub fn is_warning_pressure(&self) -> bool {
        self.warning_threshold < self.get_pressure_ratio()
    }

    /// Zero the usage reading and forget the pressure history.
    pub fn reset(&mut self) {
        self.current_usage = 0;
        self.pressure_history.clear();
    }
}
936
/// Optimization report
///
/// Before/after summary returned by
/// `TransformerMemoryManager::optimize_memory`.
#[derive(Debug, Clone)]
pub struct OptimizationReport {
    /// Total memory usage (bytes) before optimization
    pub initial_memory_usage: usize,
    /// Total memory usage (bytes) after optimization
    pub final_memory_usage: usize,
    /// `initial - final`, saturated at zero
    pub memory_saved: usize,
    /// Wall-clock time the optimization pass took
    pub optimization_time: Duration,
    /// Number of logged accesses analyzed during the pass
    pub operations_performed: usize,
}
946
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager construction should succeed with the default configuration.
    #[test]
    fn test_memory_manager_creation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let manager = TransformerMemoryManager::new(&config);
        assert!(manager.is_ok());
    }

    /// A stored tensor should be reported as resident by the cache.
    #[test]
    fn test_memory_cache() {
        let cache = MemoryCache::<f32>::new(1024 * 1024, CacheEvictionStrategy::LRU);
        assert!(cache.is_ok());

        let mut c = cache.expect("unwrap failed");
        let tensor = Array2::<f32>::ones((10, 10));
        assert!(c.store("test".to_string(), tensor).is_ok());
        assert!(c.contains("test"));
    }

    /// Compress then decompress should round-trip without error.
    #[test]
    fn test_compression_manager() {
        let compression = CompressionManager::<f32>::new(0.5);
        assert!(compression.is_ok());

        let comp = compression.expect("unwrap failed");
        let tensor = Array2::<f32>::ones((5, 5));
        let compressed = comp.compress(&tensor);
        assert!(compressed.is_ok());

        let decompressed = comp.decompress(&compressed.expect("unwrap failed"));
        assert!(decompressed.is_ok());
    }

    /// Logged reads/writes should surface in the analyzed patterns.
    #[test]
    fn test_access_tracker() {
        let mut tracker = AccessTracker::new(100);

        tracker.record_read("key1".to_string());
        tracker.record_write("key2".to_string());

        let patterns = tracker.analyze_patterns();
        assert!(patterns.total_accesses > 0);
    }

    /// Pressure classification should flip around the 90% critical threshold
    /// of the default 1 GB budget.
    #[test]
    fn test_memory_pressure_monitor() {
        let mut monitor = MemoryPressureMonitor::new();

        monitor.update(500 * 1024 * 1024); // 500MB
        assert!(!monitor.is_high_pressure());

        monitor.update(950 * 1024 * 1024); // 950MB
        assert!(monitor.is_high_pressure());
    }
}