1use super::config::{CacheEvictionStrategy, MemoryConfig, TransformerBasedOptimizerConfig};
4use crate::error::Result;
5use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
6use scirs2_core::numeric::Float;
7use std::collections::{BTreeMap, HashMap, VecDeque};
8use std::fmt::Debug;
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11
/// Storage/eviction strategy used by [`TransformerMemoryManager`] when
/// deciding where and how cached tensors are kept.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MemoryManagementStrategy {
    /// First-in, first-out eviction order.
    FIFO,
    /// Least-recently-used eviction.
    LRU,
    /// Least-frequently-used eviction.
    LFU,
    /// Adaptive replacement cache. NOTE(review): currently stored exactly
    /// like LRU in the primary cache — confirm whether a real ARC policy
    /// is planned.
    ARC,
    /// Route tensors through the compression manager when one is configured.
    Compressed,
    /// Route tensors to the primary or secondary cache based on their size.
    Hierarchical,
}
28
/// Memory manager for transformer-based optimizers.
///
/// Coordinates a primary in-memory tensor cache, an optional secondary
/// cache, optional (simulated) compression, access-pattern tracking, and
/// memory-pressure monitoring.
pub struct TransformerMemoryManager<T: Float + Debug + Send + Sync + 'static> {
    /// Active storage/eviction strategy (derived from the config, mutable
    /// via `set_strategy`).
    strategy: MemoryManagementStrategy,

    /// Memory configuration cloned from the optimizer config.
    config: MemoryConfig,

    /// Primary tensor cache (gets half of `max_cache_size`).
    primary_cache: MemoryCache<T>,

    /// Optional FIFO secondary cache, enabled for budgets over 100 MB.
    secondary_cache: Option<MemoryCache<T>>,

    /// Optional compression backend, enabled via `enable_compression`.
    compression_manager: Option<CompressionManager<T>>,

    /// Aggregate store/retrieve statistics.
    statistics: MemoryStatistics,

    /// Bounded read/write log used for access-pattern analysis.
    access_tracker: AccessTracker,

    /// Tracks total usage against a fixed budget to trigger eviction.
    pressure_monitor: MemoryPressureMonitor,

    /// Model dimension copied from the optimizer config (not used within
    /// this module's visible code).
    model_dimension: usize,
}
58
impl<T: Float + Debug + Send + Sync + 'static> TransformerMemoryManager<T> {
    /// Builds a memory manager from the optimizer configuration.
    ///
    /// The primary cache receives half of `max_cache_size`; a FIFO
    /// secondary cache takes the other half when the budget exceeds
    /// 100 MB. A compression manager (fixed 0.5 ratio) is created when
    /// compression is enabled.
    ///
    /// NOTE(review): `CacheEvictionStrategy::Random` is silently mapped to
    /// LRU here — confirm this downgrade is intentional.
    pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
        let memory_config = config.memory_config.clone();
        let strategy = match memory_config.eviction_strategy {
            CacheEvictionStrategy::LRU => MemoryManagementStrategy::LRU,
            CacheEvictionStrategy::LFU => MemoryManagementStrategy::LFU,
            CacheEvictionStrategy::FIFO => MemoryManagementStrategy::FIFO,
            CacheEvictionStrategy::Random => MemoryManagementStrategy::LRU,
        };

        let primary_cache = MemoryCache::new(
            memory_config.max_cache_size / 2,
            memory_config.eviction_strategy,
        )?;

        // Only budgets above 100 MB get a second tier.
        let secondary_cache = if memory_config.max_cache_size > 1024 * 1024 * 100 {
            Some(MemoryCache::new(
                memory_config.max_cache_size / 2,
                CacheEvictionStrategy::FIFO,
            )?)
        } else {
            None
        };

        let compression_manager = if memory_config.enable_compression {
            Some(CompressionManager::new(0.5)?)
        } else {
            None
        };

        let statistics = MemoryStatistics::new();
        let access_tracker = AccessTracker::new(1000);
        let pressure_monitor = MemoryPressureMonitor::new();

        Ok(Self {
            strategy,
            config: memory_config,
            primary_cache,
            secondary_cache,
            compression_manager,
            statistics,
            access_tracker,
            pressure_monitor,
            model_dimension: config.model_dimension,
        })
    }

    /// Stores a tensor under `key` using the active strategy.
    ///
    /// Evicts from the caches first when memory pressure is critical, and
    /// attempts a secondary-cache fallback when the primary path fails.
    ///
    /// NOTE(review): the primary-path error (`storage_result`) is returned
    /// even when the secondary-cache fallback succeeds, so callers cannot
    /// tell the fallback worked — confirm this is intended.
    pub fn store(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
        let start_time = Instant::now();

        // Refresh pressure before admitting more data.
        self.pressure_monitor.update(self.get_memory_usage());
        if self.pressure_monitor.is_high_pressure() {
            self.evict_memory()?;
        }

        // Only one arm runs, so exactly one clone of `tensor` is made; the
        // original is kept for the length stat and the fallback below.
        let storage_result = match self.strategy {
            MemoryManagementStrategy::LRU => self.store_lru(key.clone(), tensor.clone()),
            MemoryManagementStrategy::LFU => self.store_lfu(key.clone(), tensor.clone()),
            MemoryManagementStrategy::FIFO => self.store_fifo(key.clone(), tensor.clone()),
            MemoryManagementStrategy::ARC => self.store_arc(key.clone(), tensor.clone()),
            MemoryManagementStrategy::Compressed => {
                self.store_compressed(key.clone(), tensor.clone())
            }
            MemoryManagementStrategy::Hierarchical => {
                self.store_hierarchical(key.clone(), tensor.clone())
            }
        };

        let tensor_len = tensor.len();

        // Fallback: park the tensor in the secondary cache if the primary
        // path errored.
        if storage_result.is_err() && self.secondary_cache.is_some() {
            if let Some(ref mut secondary) = self.secondary_cache {
                secondary.store(key.clone(), tensor)?;
            }
        }

        let storage_time = start_time.elapsed();
        // NOTE(review): the store is recorded in the statistics even when
        // both paths failed — confirm that is the desired accounting.
        self.statistics.record_storage(tensor_len, storage_time);
        self.access_tracker.record_write(key);

        storage_result
    }

    /// Retrieves a tensor, checking primary cache, then secondary cache,
    /// then compressed storage. Records a hit/miss in the statistics and
    /// logs the read on a hit.
    pub fn retrieve(&mut self, key: &str) -> Result<Option<Array2<T>>> {
        let start_time = Instant::now();

        let result = self.primary_cache.retrieve(key)?;

        if result.is_some() {
            self.access_tracker.record_read(key.to_string());
            let retrieval_time = start_time.elapsed();
            self.statistics.record_retrieval(retrieval_time, true);
            return Ok(result);
        }

        // Second tier: secondary cache. NOTE(review): hits here are NOT
        // promoted to the primary cache (promotion only happens via
        // `prefetch`) — confirm intended.
        if let Some(ref mut secondary) = self.secondary_cache {
            let result = secondary.retrieve(key)?;
            if result.is_some() {
                self.access_tracker.record_read(key.to_string());
                let retrieval_time = start_time.elapsed();
                self.statistics.record_retrieval(retrieval_time, true);
                return Ok(result);
            }
        }

        // Third tier: decompress from the compression manager.
        if let Some(ref mut compression) = self.compression_manager {
            if let Some(compressed_data) = compression.retrieve(key)? {
                let decompressed = compression.decompress(&compressed_data)?;
                self.access_tracker.record_read(key.to_string());
                let retrieval_time = start_time.elapsed();
                self.statistics.record_retrieval(retrieval_time, true);
                return Ok(Some(decompressed));
            }
        }

        // Miss across all tiers.
        let retrieval_time = start_time.elapsed();
        self.statistics.record_retrieval(retrieval_time, false);
        Ok(None)
    }

    /// Removes `key` from every tier. Returns `true` if any tier held it.
    pub fn remove(&mut self, key: &str) -> Result<bool> {
        let mut removed = false;

        if self.primary_cache.remove(key)? {
            removed = true;
        }

        if let Some(ref mut secondary) = self.secondary_cache {
            if secondary.remove(key)? {
                removed = true;
            }
        }

        if let Some(ref mut compression) = self.compression_manager {
            if compression.remove(key)? {
                removed = true;
            }
        }

        self.access_tracker.record_removal(key.to_string());
        Ok(removed)
    }

    /// Clears every tier and resets statistics, access logs, and the
    /// pressure monitor.
    pub fn clear(&mut self) -> Result<()> {
        self.primary_cache.clear()?;

        if let Some(ref mut secondary) = self.secondary_cache {
            secondary.clear()?;
        }

        if let Some(ref mut compression) = self.compression_manager {
            compression.clear()?;
        }

        self.statistics.reset();
        self.access_tracker.clear();
        self.pressure_monitor.reset();

        Ok(())
    }

    /// Total bytes reported across all tiers (compression usage is the
    /// *reported* compressed size, which may differ from real footprint).
    pub fn get_memory_usage(&self) -> usize {
        let primary_usage = self.primary_cache.get_memory_usage();
        let secondary_usage = self
            .secondary_cache
            .as_ref()
            .map(|cache| cache.get_memory_usage())
            .unwrap_or(0);
        let compression_usage = self
            .compression_manager
            .as_ref()
            .map(|comp| comp.get_memory_usage())
            .unwrap_or(0);

        primary_usage + secondary_usage + compression_usage
    }

    /// Runs a memory-optimization pass: promotes frequently-used keys to
    /// the primary cache, tunes compression (currently a no-op), and
    /// rebuilds the primary cache. Returns a before/after report.
    pub fn optimize_memory(&mut self) -> Result<OptimizationReport> {
        let start_time = Instant::now();
        let initial_usage = self.get_memory_usage();

        let access_patterns = self.access_tracker.analyze_patterns();

        self.reorganize_by_frequency(&access_patterns)?;

        if let Some(ref mut compression) = self.compression_manager {
            compression.optimize_compression_ratios(&access_patterns)?;
        }

        self.defragment_memory()?;

        let final_usage = self.get_memory_usage();
        let optimization_time = start_time.elapsed();

        Ok(OptimizationReport {
            initial_memory_usage: initial_usage,
            final_memory_usage: final_usage,
            memory_saved: initial_usage.saturating_sub(final_usage),
            optimization_time,
            operations_performed: access_patterns.total_accesses,
        })
    }

    /// Promotes the given keys into the primary cache from the secondary
    /// cache or compressed storage (removing the source copy on success).
    /// Returns how many keys were actually promoted.
    pub fn prefetch(&mut self, keys: Vec<String>) -> Result<usize> {
        let mut prefetched_count = 0;

        for key in keys {
            if !self.primary_cache.contains(&key) {
                if let Some(ref mut secondary) = self.secondary_cache {
                    if let Some(tensor) = secondary.retrieve(&key)? {
                        if self.primary_cache.store(key.clone(), tensor).is_ok() {
                            secondary.remove(&key)?;
                            prefetched_count += 1;
                        }
                    }
                }

                // NOTE(review): if a key exists in both tiers, it may be
                // counted twice here (secondary promotion then compressed
                // promotion) — confirm keys are unique per tier.
                if let Some(ref mut compression) = self.compression_manager {
                    if let Some(compressed_data) = compression.retrieve(&key)? {
                        let decompressed = compression.decompress(&compressed_data)?;
                        if self.primary_cache.store(key.clone(), decompressed).is_ok() {
                            compression.remove(&key)?;
                            prefetched_count += 1;
                        }
                    }
                }
            }
        }

        Ok(prefetched_count)
    }

    /// LRU store: delegates to the primary cache (its own eviction strategy
    /// governs the actual policy).
    fn store_lru(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
        self.primary_cache.store(key, tensor)
    }

    /// LFU store: delegates to the primary cache.
    fn store_lfu(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
        self.primary_cache.store(key, tensor)
    }

    /// FIFO store: delegates to the primary cache.
    fn store_fifo(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
        self.primary_cache.store(key, tensor)
    }

    /// ARC store: currently identical to the plain primary-cache store.
    fn store_arc(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
        self.primary_cache.store(key, tensor)
    }

    /// Compressed store: routes through the compression manager when one
    /// exists, otherwise falls back to the primary cache.
    fn store_compressed(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
        if let Some(ref mut compression) = self.compression_manager {
            let compressed_data = compression.compress(&tensor)?;
            compression.store(key, compressed_data)?;
            Ok(())
        } else {
            self.primary_cache.store(key, tensor)
        }
    }

    /// Hierarchical store: small tensors (under `allocation_block_size`
    /// bytes) go to the primary cache; large ones go to the secondary cache
    /// when present.
    fn store_hierarchical(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
        let tensor_size = tensor.len() * std::mem::size_of::<T>();

        if tensor_size < self.config.allocation_block_size {
            self.primary_cache.store(key, tensor)
        } else if let Some(ref mut secondary) = self.secondary_cache {
            secondary.store(key, tensor)
        } else {
            self.primary_cache.store(key, tensor)
        }
    }

    /// Evicts one LRU entry from the primary cache, and one from the
    /// secondary cache if pressure is still critical.
    ///
    /// NOTE(review): the pressure monitor is not refreshed between the two
    /// evictions, so the second check sees the pre-eviction reading.
    fn evict_memory(&mut self) -> Result<()> {
        self.primary_cache.evict_lru()?;

        if self.pressure_monitor.is_high_pressure() {
            if let Some(ref mut secondary) = self.secondary_cache {
                secondary.evict_lru()?;
            }
        }

        Ok(())
    }

    /// Promotes keys whose access count exceeds the average frequency.
    fn reorganize_by_frequency(&mut self, patterns: &AccessPatterns) -> Result<()> {
        let frequent_keys: Vec<String> = patterns
            .frequency_map
            .iter()
            .filter(|(_, &count)| count as f64 > patterns.average_frequency)
            .map(|(key, _)| key.clone())
            .collect();

        self.prefetch(frequent_keys)?;
        Ok(())
    }

    /// Rebuilds the primary cache by draining and re-inserting all items.
    ///
    /// NOTE(review): this transiently clones every cached tensor and does
    /// not reduce usage by itself — presumably intended to reset internal
    /// bookkeeping; confirm.
    fn defragment_memory(&mut self) -> Result<()> {
        let primary_items = self.primary_cache.get_all_items()?;
        self.primary_cache.clear()?;

        for (key, tensor) in primary_items {
            self.primary_cache.store(key, tensor)?;
        }

        Ok(())
    }

    /// Read-only access to aggregate statistics.
    pub fn get_statistics(&self) -> &MemoryStatistics {
        &self.statistics
    }

    /// Snapshot of the current access-pattern analysis.
    pub fn get_access_patterns(&self) -> AccessPatterns {
        self.access_tracker.analyze_patterns()
    }

    /// Switches the storage strategy for subsequent `store` calls.
    pub fn set_strategy(&mut self, strategy: MemoryManagementStrategy) {
        self.strategy = strategy;
    }

    /// Current usage as a fraction of the pressure monitor's budget.
    pub fn get_memory_pressure(&self) -> f64 {
        self.pressure_monitor.get_pressure_ratio()
    }
}
417
/// Size-bounded tensor cache with pluggable eviction.
pub struct MemoryCache<T: Float + Debug + Send + Sync + 'static> {
    /// Key → cached entry.
    storage: HashMap<String, CacheEntry<T>>,

    /// Keys in access order; front = least recently touched.
    access_order: VecDeque<String>,

    /// Key → number of accesses (stores + retrievals), used for LFU.
    access_frequency: HashMap<String, usize>,

    /// Byte budget for the cache.
    max_size: usize,

    /// Current total bytes of all stored tensors.
    current_size: usize,

    /// Which policy `evict_one` applies.
    eviction_strategy: CacheEvictionStrategy,
}
438
439impl<T: Float + Debug + Send + Sync + 'static> MemoryCache<T> {
440 pub fn new(max_size: usize, eviction_strategy: CacheEvictionStrategy) -> Result<Self> {
441 Ok(Self {
442 storage: HashMap::new(),
443 access_order: VecDeque::new(),
444 access_frequency: HashMap::new(),
445 max_size,
446 current_size: 0,
447 eviction_strategy,
448 })
449 }
450
451 pub fn store(&mut self, key: String, tensor: Array2<T>) -> Result<()> {
452 let tensor_size = tensor.len() * std::mem::size_of::<T>();
453
454 while self.current_size + tensor_size > self.max_size && !self.storage.is_empty() {
456 self.evict_one()?;
457 }
458
459 if tensor_size > self.max_size {
460 return Err(crate::error::OptimError::Other(
461 "Tensor too large for cache".to_string(),
462 ));
463 }
464
465 if let Some(old_entry) = self.storage.remove(&key) {
467 self.current_size -= old_entry.size;
468 self.remove_from_access_order(&key);
469 }
470
471 let entry = CacheEntry {
473 tensor,
474 size: tensor_size,
475 access_time: Instant::now(),
476 access_count: 1,
477 };
478
479 self.storage.insert(key.clone(), entry);
480 self.current_size += tensor_size;
481 self.update_access_tracking(&key);
482
483 Ok(())
484 }
485
486 pub fn retrieve(&mut self, key: &str) -> Result<Option<Array2<T>>> {
487 let tensor_result = if let Some(entry) = self.storage.get_mut(key) {
488 entry.access_time = Instant::now();
489 entry.access_count += 1;
490 Some(entry.tensor.clone())
491 } else {
492 None
493 };
494
495 if tensor_result.is_some() {
497 self.update_access_tracking(key);
498 }
499
500 Ok(tensor_result)
501 }
502
503 pub fn remove(&mut self, key: &str) -> Result<bool> {
504 if let Some(entry) = self.storage.remove(key) {
505 self.current_size -= entry.size;
506 self.remove_from_access_order(key);
507 self.access_frequency.remove(key);
508 Ok(true)
509 } else {
510 Ok(false)
511 }
512 }
513
514 pub fn contains(&self, key: &str) -> bool {
515 self.storage.contains_key(key)
516 }
517
518 pub fn clear(&mut self) -> Result<()> {
519 self.storage.clear();
520 self.access_order.clear();
521 self.access_frequency.clear();
522 self.current_size = 0;
523 Ok(())
524 }
525
526 pub fn get_memory_usage(&self) -> usize {
527 self.current_size
528 }
529
530 pub fn evict_lru(&mut self) -> Result<()> {
531 if let Some(oldest_key) = self.access_order.front().cloned() {
532 self.remove(&oldest_key)?;
533 }
534 Ok(())
535 }
536
537 fn evict_one(&mut self) -> Result<()> {
538 match self.eviction_strategy {
539 CacheEvictionStrategy::LRU => self.evict_lru(),
540 CacheEvictionStrategy::LFU => self.evict_lfu(),
541 CacheEvictionStrategy::FIFO => self.evict_fifo(),
542 CacheEvictionStrategy::Random => self.evict_random(),
543 }
544 }
545
546 fn evict_lfu(&mut self) -> Result<()> {
547 if let Some((min_freq, lfu_key)) = self
548 .access_frequency
549 .iter()
550 .min_by_key(|(_, &freq)| freq)
551 .map(|(key, &freq)| (freq, key.clone()))
552 {
553 self.remove(&lfu_key)?;
554 }
555 Ok(())
556 }
557
558 fn evict_fifo(&mut self) -> Result<()> {
559 if let Some(first_key) = self.access_order.front().cloned() {
560 self.remove(&first_key)?;
561 }
562 Ok(())
563 }
564
565 fn evict_random(&mut self) -> Result<()> {
566 if let Some(random_key) = self.storage.keys().next().cloned() {
567 self.remove(&random_key)?;
568 }
569 Ok(())
570 }
571
572 fn update_access_tracking(&mut self, key: &str) {
573 self.remove_from_access_order(key);
575 self.access_order.push_back(key.to_string());
576
577 *self.access_frequency.entry(key.to_string()).or_insert(0) += 1;
579 }
580
581 fn remove_from_access_order(&mut self, key: &str) {
582 self.access_order.retain(|k| k != key);
583 }
584
585 pub fn get_all_items(&self) -> Result<Vec<(String, Array2<T>)>> {
586 let items = self
587 .storage
588 .iter()
589 .map(|(key, entry)| (key.clone(), entry.tensor.clone()))
590 .collect();
591 Ok(items)
592 }
593}
594
/// A single cached tensor plus its bookkeeping metadata.
#[derive(Debug, Clone)]
pub struct CacheEntry<T: Float + Debug + Send + Sync + 'static> {
    /// The cached tensor.
    pub tensor: Array2<T>,
    /// Size in bytes (`len * size_of::<T>()`), used for budget accounting.
    pub size: usize,
    /// Last time this entry was stored or retrieved.
    pub access_time: Instant,
    /// Number of accesses since insertion.
    pub access_count: usize,
}
603
/// Keyed store for "compressed" tensor data.
///
/// NOTE(review): compression is currently simulated — data is kept
/// uncompressed and only the reported sizes are scaled; see `compress`.
pub struct CompressionManager<T: Float + Debug + Send + Sync + 'static> {
    /// Key → compressed payload.
    compressed_storage: HashMap<String, CompressedData<T>>,

    /// Target compression ratio used for reported sizes.
    compression_ratio: f64,

    /// Sum of `compressed_size` over all stored entries.
    memory_usage: usize,

    /// NOTE(review): redundant — `T` already appears in
    /// `compressed_storage`; kept as-is to avoid a struct change.
    _phantom: std::marker::PhantomData<T>,
}
618
619impl<T: Float + Debug + Send + Sync + 'static> CompressionManager<T> {
620 pub fn new(compression_ratio: f64) -> Result<Self> {
621 Ok(Self {
622 compressed_storage: HashMap::new(),
623 compression_ratio,
624 memory_usage: 0,
625 _phantom: std::marker::PhantomData,
626 })
627 }
628
629 pub fn compress(&self, tensor: &Array2<T>) -> Result<CompressedData<T>> {
630 let shape = tensor.shape().to_vec();
632 let data: Vec<T> = tensor.iter().cloned().collect();
633 let data_len = data.len(); Ok(CompressedData::<T> {
636 shape,
637 data,
638 original_size: tensor.len() * std::mem::size_of::<T>(),
639 compressed_size: data_len * std::mem::size_of::<T>() / 2, })
641 }
642
643 pub fn decompress(&self, compressed: &CompressedData<T>) -> Result<Array2<T>> {
644 let array = Array2::from_shape_vec(
645 (compressed.shape[0], compressed.shape[1]),
646 compressed.data.clone(),
647 )
648 .map_err(|_| crate::error::OptimError::Other("Decompression failed".to_string()))?;
649 Ok(array)
650 }
651
652 pub fn store(&mut self, key: String, compressed: CompressedData<T>) -> Result<()> {
653 self.memory_usage += compressed.compressed_size;
654 self.compressed_storage.insert(key, compressed);
655 Ok(())
656 }
657
658 pub fn retrieve(&self, key: &str) -> Result<Option<CompressedData<T>>> {
659 Ok(self.compressed_storage.get(key).cloned())
660 }
661
662 pub fn remove(&mut self, key: &str) -> Result<bool> {
663 if let Some(compressed) = self.compressed_storage.remove(key) {
664 self.memory_usage -= compressed.compressed_size;
665 Ok(true)
666 } else {
667 Ok(false)
668 }
669 }
670
671 pub fn clear(&mut self) -> Result<()> {
672 self.compressed_storage.clear();
673 self.memory_usage = 0;
674 Ok(())
675 }
676
677 pub fn get_memory_usage(&self) -> usize {
678 self.memory_usage
679 }
680
681 pub fn optimize_compression_ratios(&mut self, _patterns: &AccessPatterns) -> Result<()> {
682 Ok(())
685 }
686}
687
/// A "compressed" tensor payload plus size metadata.
#[derive(Debug, Clone)]
pub struct CompressedData<T: Float + Debug + Send + Sync + 'static> {
    /// Original tensor shape (expected to be 2-D).
    pub shape: Vec<usize>,
    /// Element data (currently stored uncompressed — see
    /// `CompressionManager::compress`).
    pub data: Vec<T>,
    /// Size of the original tensor in bytes.
    pub original_size: usize,
    /// Reported (simulated) compressed size in bytes.
    pub compressed_size: usize,
}
696
/// Aggregate counters and timings for cache activity.
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
    /// Number of `store` operations recorded.
    pub total_stores: usize,

    /// Number of `retrieve` operations recorded (hits + misses).
    pub total_retrievals: usize,

    /// Retrievals that found the key in some tier.
    pub cache_hits: usize,

    /// Retrievals that missed every tier.
    pub cache_misses: usize,

    /// Sum of element counts passed to `record_storage`.
    /// NOTE(review): callers pass `tensor.len()` (elements), not bytes —
    /// the field name overstates the unit; confirm.
    pub total_bytes_stored: usize,

    /// Running mean of storage durations.
    pub average_storage_time: Duration,

    /// Running mean of retrieval durations.
    pub average_retrieval_time: Duration,

    /// Count of pressure events (incremented via `record_pressure_event`).
    pub pressure_events: usize,
}
724
725impl Default for MemoryStatistics {
726 fn default() -> Self {
727 Self::new()
728 }
729}
730
731impl MemoryStatistics {
732 pub fn new() -> Self {
733 Self {
734 total_stores: 0,
735 total_retrievals: 0,
736 cache_hits: 0,
737 cache_misses: 0,
738 total_bytes_stored: 0,
739 average_storage_time: Duration::new(0, 0),
740 average_retrieval_time: Duration::new(0, 0),
741 pressure_events: 0,
742 }
743 }
744
745 pub fn record_storage(&mut self, bytes: usize, time: Duration) {
746 self.total_stores += 1;
747 self.total_bytes_stored += bytes;
748 self.average_storage_time = (self.average_storage_time * (self.total_stores - 1) as u32
749 + time)
750 / self.total_stores as u32;
751 }
752
753 pub fn record_retrieval(&mut self, time: Duration, hit: bool) {
754 self.total_retrievals += 1;
755 if hit {
756 self.cache_hits += 1;
757 } else {
758 self.cache_misses += 1;
759 }
760 self.average_retrieval_time =
761 (self.average_retrieval_time * (self.total_retrievals - 1) as u32 + time)
762 / self.total_retrievals as u32;
763 }
764
765 pub fn record_pressure_event(&mut self) {
766 self.pressure_events += 1;
767 }
768
769 pub fn get_hit_ratio(&self) -> f64 {
770 if self.total_retrievals > 0 {
771 self.cache_hits as f64 / self.total_retrievals as f64
772 } else {
773 0.0
774 }
775 }
776
777 pub fn reset(&mut self) {
778 *self = Self::new();
779 }
780}
781
/// Bounded logs of read and write accesses, used for pattern analysis.
pub struct AccessTracker {
    /// Most recent read events (front = oldest).
    read_log: VecDeque<AccessEvent>,

    /// Most recent write events (front = oldest).
    write_log: VecDeque<AccessEvent>,

    /// Maximum length of each log before old events are dropped.
    max_log_size: usize,
}
793
794impl AccessTracker {
795 pub fn new(max_log_size: usize) -> Self {
796 Self {
797 read_log: VecDeque::new(),
798 write_log: VecDeque::new(),
799 max_log_size,
800 }
801 }
802
803 pub fn record_read(&mut self, key: String) {
804 self.read_log.push_back(AccessEvent {
805 key,
806 timestamp: Instant::now(),
807 });
808
809 if self.read_log.len() > self.max_log_size {
810 self.read_log.pop_front();
811 }
812 }
813
814 pub fn record_write(&mut self, key: String) {
815 self.write_log.push_back(AccessEvent {
816 key,
817 timestamp: Instant::now(),
818 });
819
820 if self.write_log.len() > self.max_log_size {
821 self.write_log.pop_front();
822 }
823 }
824
825 pub fn record_removal(&mut self, _key: String) {
826 }
828
829 pub fn analyze_patterns(&self) -> AccessPatterns {
830 let mut frequency_map = HashMap::new();
831
832 for event in self.read_log.iter().chain(self.write_log.iter()) {
834 *frequency_map.entry(event.key.clone()).or_insert(0) += 1;
835 }
836
837 let total_accesses: usize = frequency_map.values().sum();
838 let average_frequency = if frequency_map.is_empty() {
839 0.0
840 } else {
841 total_accesses as f64 / frequency_map.len() as f64
842 };
843
844 AccessPatterns {
845 frequency_map,
846 average_frequency,
847 total_accesses,
848 }
849 }
850
851 pub fn clear(&mut self) {
852 self.read_log.clear();
853 self.write_log.clear();
854 }
855}
856
/// One logged cache access: which key and when.
#[derive(Debug, Clone)]
pub struct AccessEvent {
    /// The key that was accessed.
    pub key: String,
    /// When the access happened.
    pub timestamp: Instant,
}
863
/// Summary of access frequencies produced by `AccessTracker::analyze_patterns`.
#[derive(Debug, Clone)]
pub struct AccessPatterns {
    /// Key → number of logged accesses (reads + writes).
    pub frequency_map: HashMap<String, usize>,
    /// Mean accesses per distinct key (0.0 when no accesses were logged).
    pub average_frequency: f64,
    /// Total logged accesses across all keys.
    pub total_accesses: usize,
}
871
/// Tracks memory usage against a fixed budget and classifies pressure.
pub struct MemoryPressureMonitor {
    /// Most recently reported usage in bytes.
    current_usage: usize,

    /// Budget the pressure ratio is computed against.
    max_memory: usize,

    /// Ratio above which `is_warning_pressure` returns true.
    warning_threshold: f64,
    /// Ratio above which `is_high_pressure` returns true.
    critical_threshold: f64,

    /// Recent pressure ratios, bounded to 100 samples.
    /// NOTE(review): recorded but never read within this module — confirm
    /// whether consumers exist elsewhere.
    pressure_history: VecDeque<f64>,
}
887
888impl Default for MemoryPressureMonitor {
889 fn default() -> Self {
890 Self::new()
891 }
892}
893
894impl MemoryPressureMonitor {
895 pub fn new() -> Self {
896 Self {
897 current_usage: 0,
898 max_memory: 1024 * 1024 * 1024, warning_threshold: 0.7,
900 critical_threshold: 0.9,
901 pressure_history: VecDeque::new(),
902 }
903 }
904
905 pub fn update(&mut self, current_usage: usize) {
906 self.current_usage = current_usage;
907 let pressure_ratio = self.get_pressure_ratio();
908
909 self.pressure_history.push_back(pressure_ratio);
910 if self.pressure_history.len() > 100 {
911 self.pressure_history.pop_front();
912 }
913 }
914
915 pub fn get_pressure_ratio(&self) -> f64 {
916 if self.max_memory > 0 {
917 self.current_usage as f64 / self.max_memory as f64
918 } else {
919 0.0
920 }
921 }
922
923 pub fn is_high_pressure(&self) -> bool {
924 self.get_pressure_ratio() > self.critical_threshold
925 }
926
927 pub fn is_warning_pressure(&self) -> bool {
928 self.get_pressure_ratio() > self.warning_threshold
929 }
930
931 pub fn reset(&mut self) {
932 self.current_usage = 0;
933 self.pressure_history.clear();
934 }
935}
936
/// Before/after summary returned by `TransformerMemoryManager::optimize_memory`.
#[derive(Debug, Clone)]
pub struct OptimizationReport {
    /// Total bytes reported before the optimization pass.
    pub initial_memory_usage: usize,
    /// Total bytes reported after the pass.
    pub final_memory_usage: usize,
    /// `initial - final`, saturating at zero if usage grew.
    pub memory_saved: usize,
    /// Wall-clock duration of the pass.
    pub optimization_time: Duration,
    /// Total logged accesses at analysis time (proxy for work examined).
    pub operations_performed: usize,
}
946
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager construction from the default optimizer config succeeds.
    #[test]
    fn test_memory_manager_creation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let manager = TransformerMemoryManager::new(&config);
        assert!(manager.is_ok());
    }

    /// A stored tensor is reported as present by `contains`.
    #[test]
    fn test_memory_cache() {
        let cache = MemoryCache::<f32>::new(1024 * 1024, CacheEvictionStrategy::LRU);
        assert!(cache.is_ok());

        let mut c = cache.unwrap();
        let tensor = Array2::<f32>::ones((10, 10));
        assert!(c.store("test".to_string(), tensor).is_ok());
        assert!(c.contains("test"));
    }

    /// Compress then decompress round-trips without error.
    #[test]
    fn test_compression_manager() {
        let compression = CompressionManager::<f32>::new(0.5);
        assert!(compression.is_ok());

        let comp = compression.unwrap();
        let tensor = Array2::<f32>::ones((5, 5));
        let compressed = comp.compress(&tensor);
        assert!(compressed.is_ok());

        let decompressed = comp.decompress(&compressed.unwrap());
        assert!(decompressed.is_ok());
    }

    /// Logged reads/writes show up in the analyzed patterns.
    #[test]
    fn test_access_tracker() {
        let mut tracker = AccessTracker::new(100);

        tracker.record_read("key1".to_string());
        tracker.record_write("key2".to_string());

        let patterns = tracker.analyze_patterns();
        assert!(patterns.total_accesses > 0);
    }

    /// Pressure flips from low to high across the 0.9 critical threshold
    /// of the default 1 GiB budget.
    #[test]
    fn test_memory_pressure_monitor() {
        let mut monitor = MemoryPressureMonitor::new();

        // ~0.49 of budget: below critical.
        monitor.update(500 * 1024 * 1024);
        assert!(!monitor.is_high_pressure());

        // ~0.93 of budget: above critical.
        monitor.update(950 * 1024 * 1024);
        assert!(monitor.is_high_pressure());
    }
}