1use crate::types::Position3D;
8use scirs2_core::ndarray::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, VecDeque};
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::RwLock;
14
/// Central memory front-end for the audio processing pipeline: pools
/// reusable `Array1`/`Array2` buffers and caches expensive computation
/// results (HRTFs, distance attenuation) behind async `RwLock`s so it can
/// be shared across tasks via cheap `Arc` clones.
pub struct MemoryManager {
    /// Pools of reusable 1-D `f32` buffers, keyed by buffer length.
    buffer_pools: Arc<RwLock<HashMap<usize, BufferPool<f32>>>>,
    /// Pools of reusable 2-D `f32` arrays, keyed by `(rows, cols)`.
    array2d_pools: Arc<RwLock<HashMap<(usize, usize), Array2Pool>>>,
    /// HRTF / distance / room caches plus their hit/miss statistics.
    cache_manager: Arc<RwLock<CacheManager>>,
    /// Aggregate allocation, usage, and pressure statistics.
    memory_stats: Arc<RwLock<MemoryStatistics>>,
    /// Immutable configuration captured at construction.
    config: MemoryConfig,
}
28
/// Free list of reusable 1-D buffers that all share a single length
/// (the pool is keyed by that length in `MemoryManager::buffer_pools`).
pub struct BufferPool<T> {
    /// Buffers currently available for reuse, taken from the front (FIFO).
    available: VecDeque<Array1<T>>,
    /// Maximum number of buffers retained in `available`; excess returns
    /// are dropped.
    max_size: usize,
    /// Number of fresh allocations made on pool misses.
    total_allocations: u64,
    /// Number of requests served from the pool.
    pool_hits: u64,
}
40
/// Free list of reusable 2-D `f32` arrays sharing one `(rows, cols)` shape.
pub struct Array2Pool {
    /// Arrays currently available for reuse, taken from the front (FIFO).
    available: VecDeque<Array2<f32>>,
    /// Shape shared by every array in this pool.
    /// NOTE(review): stored at construction but never read in this file —
    /// confirm whether it is needed or can be removed.
    dimensions: (usize, usize),
    /// Maximum number of arrays retained in `available`.
    max_size: usize,
    /// Number of fresh allocations made on pool misses.
    total_allocations: u64,
    /// Number of requests served from the pool.
    pool_hits: u64,
}
54
/// Bounded caches for expensive per-source computations, with shared
/// hit/miss/eviction counters.
pub struct CacheManager {
    /// HRTF impulse-response pairs keyed by quantized direction/distance.
    hrtf_cache: HashMap<HrtfCacheKey, HrtfCacheEntry>,
    /// Distance-attenuation scalars keyed by quantized distance and model id.
    distance_cache: HashMap<DistanceCacheKey, f32>,
    /// Room-response buffers keyed by hashed room/source/listener state.
    /// NOTE(review): declared but never read or written in this file —
    /// confirm intended use elsewhere.
    room_cache: HashMap<RoomCacheKey, Array1<f32>>,
    /// Aggregate request/hit/miss/eviction counters.
    cache_stats: CacheStatistics,
    /// Entry-count limit applied to each cache individually.
    max_cache_size: usize,
}
68
/// Tunable limits and policies for [`MemoryManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Maximum buffers retained per pool (also reused as the limit for the
    /// 2-D array pools).
    pub max_buffer_pool_size: usize,
    /// Maximum entries per cache in `CacheManager`.
    pub max_cache_size: usize,
    /// Whether memory monitoring is enabled.
    /// NOTE(review): not consulted anywhere in this file — confirm.
    pub enable_monitoring: bool,
    /// Pressure level in `[0, 1]` above which cleanup is triggered.
    pub memory_pressure_threshold: f32,
    /// Cache replacement policy.
    /// NOTE(review): not read by `CacheManager` in this file, which always
    /// evicts LRU for HRTF entries — confirm whether other policies are
    /// implemented elsewhere.
    pub cache_policy: CachePolicy,
    /// Requested buffer alignment in bytes.
    /// NOTE(review): not applied by any allocation path in this file.
    pub buffer_alignment: usize,
}
85
/// Cache replacement strategies selectable via [`MemoryConfig`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum CachePolicy {
    /// Evict the least-recently-used entry.
    LRU,
    /// Evict the least-frequently-used entry.
    LFU,
    /// Evict entries older than a time-to-live.
    TTL,
    /// Evict based on entry size.
    SizeBased,
}
98
/// Integer-quantized lookup key for cached HRTF pairs; quantization lets
/// nearby floating-point queries share one entry.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct HrtfCacheKey {
    // Azimuth angle (presumably degrees — callers pass i32 directly; confirm).
    azimuth: i32,
    // Elevation angle (same caveat as `azimuth`).
    elevation: i32,
    // Distance quantized to whole millimetres (metres * 1000, truncated).
    distance: u32,
}
106
/// Cached pair of head-related impulse responses plus LRU bookkeeping.
#[derive(Debug, Clone)]
struct HrtfCacheEntry {
    // Impulse response for the left ear.
    left_hrir: Array1<f32>,
    // Impulse response for the right ear.
    right_hrir: Array1<f32>,
    // Refreshed on every hit; used for LRU eviction ordering.
    last_accessed: Instant,
    // Total number of hits on this entry.
    access_count: u64,
}
115
/// Integer-quantized lookup key for cached distance-attenuation values.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct DistanceCacheKey {
    // Distance quantized to whole millimetres (metres * 1000, truncated).
    distance_mm: u32,
    // Discriminant identifying the attenuation model used.
    model_type: u8,
}
122
/// Hash-based lookup key for cached room responses. The hashes are
/// expected to be computed by the caller; no hashing logic lives here.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct RoomCacheKey {
    // Hash of the room geometry/parameters (computed upstream).
    room_hash: u64,
    // Hash of the source position (computed upstream).
    source_position_hash: u64,
    // Hash of the listener position (computed upstream).
    listener_position_hash: u64,
}
130
/// Snapshot of allocator-wide statistics, refreshed by
/// `MemoryManager::update_memory_stats` and per-request bookkeeping.
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
    /// Cumulative bytes allocated for fresh (non-pooled) buffers.
    pub total_allocated: u64,
    /// Bytes currently held by pooled (idle) buffers.
    pub memory_in_use: u64,
    /// High-water mark of `memory_in_use`.
    pub peak_memory_usage: u64,
    /// Per-pool counters, keyed by buffer length.
    pub buffer_pool_stats: HashMap<usize, BufferPoolStats>,
    /// Named cache hit rates.
    /// NOTE(review): never populated in this file — confirm producer.
    pub cache_hit_rates: HashMap<String, f64>,
    /// Normalized pressure in `[0, 1]`; compared against
    /// `MemoryConfig::memory_pressure_threshold`.
    pub memory_pressure: f32,
    /// Timestamp of the most recent statistics update.
    pub last_updated: Instant,
}
149
150impl Default for MemoryStatistics {
151 fn default() -> Self {
152 Self {
153 total_allocated: 0,
154 memory_in_use: 0,
155 peak_memory_usage: 0,
156 buffer_pool_stats: HashMap::new(),
157 cache_hit_rates: HashMap::new(),
158 memory_pressure: 0.0,
159 last_updated: Instant::now(),
160 }
161 }
162}
163
/// Per-pool counters exposed through [`MemoryStatistics`].
#[derive(Debug, Default, Clone)]
pub struct BufferPoolStats {
    /// Fresh allocations made on pool misses.
    pub total_allocations: u64,
    /// Requests served from the pool.
    pub pool_hits: u64,
    /// Number of buffers currently idle in the pool.
    pub current_pool_size: usize,
    /// Fraction of requests served from the pool.
    pub hit_rate: f64,
}
176
/// Aggregate counters maintained by [`CacheManager`].
#[derive(Debug, Default)]
struct CacheStatistics {
    // Total lookups observed across the caches.
    total_requests: u64,
    // Lookups that found an entry.
    cache_hits: u64,
    // Lookups that missed.
    cache_misses: u64,
    // Entries removed to make room for new ones.
    cache_evictions: u64,
    // Estimated bytes held by cached entries.
    // NOTE(review): never updated in this file — confirm producer.
    memory_usage: u64,
}
191
192impl Default for MemoryConfig {
193 fn default() -> Self {
194 Self {
195 max_buffer_pool_size: 128,
196 max_cache_size: 1024,
197 enable_monitoring: true,
198 memory_pressure_threshold: 0.8,
199 cache_policy: CachePolicy::LRU,
200 buffer_alignment: 32, }
202 }
203}
204
impl Default for MemoryManager {
    /// Equivalent to `MemoryManager::new(MemoryConfig::default())`.
    fn default() -> Self {
        Self::new(MemoryConfig::default())
    }
}
210
211impl MemoryManager {
212 pub fn new(config: MemoryConfig) -> Self {
214 Self {
215 buffer_pools: Arc::new(RwLock::new(HashMap::new())),
216 array2d_pools: Arc::new(RwLock::new(HashMap::new())),
217 cache_manager: Arc::new(RwLock::new(CacheManager::new(&config))),
218 memory_stats: Arc::new(RwLock::new(MemoryStatistics::default())),
219 config,
220 }
221 }
222
223 pub async fn get_buffer(&self, size: usize) -> Array1<f32> {
225 let mut pools = self.buffer_pools.write().await;
226 let pool = pools
227 .entry(size)
228 .or_insert_with(|| BufferPool::new(size, self.config.max_buffer_pool_size));
229
230 if let Some(mut buffer) = pool.available.pop_front() {
231 buffer.fill(0.0);
233 pool.pool_hits += 1;
234 self.update_buffer_stats(size, false).await;
235 buffer
236 } else {
237 pool.total_allocations += 1;
239 self.update_buffer_stats(size, true).await;
240 Array1::zeros(size)
241 }
242 }
243
244 pub async fn return_buffer(&self, buffer: Array1<f32>) {
246 let size = buffer.len();
247 let mut pools = self.buffer_pools.write().await;
248
249 if let Some(pool) = pools.get_mut(&size) {
250 if pool.available.len() < pool.max_size {
251 pool.available.push_back(buffer);
252 }
253 }
255 }
256
257 pub async fn get_array2d(&self, rows: usize, cols: usize) -> Array2<f32> {
259 let dims = (rows, cols);
260 let mut pools = self.array2d_pools.write().await;
261 let pool = pools
262 .entry(dims)
263 .or_insert_with(|| Array2Pool::new(dims, self.config.max_buffer_pool_size));
264
265 if let Some(mut array) = pool.available.pop_front() {
266 array.fill(0.0);
268 pool.pool_hits += 1;
269 array
270 } else {
271 pool.total_allocations += 1;
273 Array2::zeros(dims)
274 }
275 }
276
277 pub async fn return_array2d(&self, array: Array2<f32>) {
279 let dims = array.dim();
280 let mut pools = self.array2d_pools.write().await;
281
282 if let Some(pool) = pools.get_mut(&dims) {
283 if pool.available.len() < pool.max_size {
284 pool.available.push_back(array);
285 }
286 }
287 }
288
289 pub async fn cache_hrtf(
291 &self,
292 key: (i32, i32, f32),
293 left_hrir: Array1<f32>,
294 right_hrir: Array1<f32>,
295 ) {
296 let cache_key = HrtfCacheKey {
297 azimuth: key.0,
298 elevation: key.1,
299 distance: (key.2 * 1000.0) as u32, };
301
302 let entry = HrtfCacheEntry {
303 left_hrir,
304 right_hrir,
305 last_accessed: Instant::now(),
306 access_count: 1,
307 };
308
309 let mut cache_manager = self.cache_manager.write().await;
310 cache_manager.cache_hrtf(cache_key, entry).await;
311 }
312
313 pub async fn get_cached_hrtf(
315 &self,
316 key: (i32, i32, f32),
317 ) -> Option<(Array1<f32>, Array1<f32>)> {
318 let cache_key = HrtfCacheKey {
319 azimuth: key.0,
320 elevation: key.1,
321 distance: (key.2 * 1000.0) as u32,
322 };
323
324 let mut cache_manager = self.cache_manager.write().await;
325 cache_manager.get_hrtf(&cache_key).await
326 }
327
328 pub async fn cache_distance_attenuation(
330 &self,
331 distance: f32,
332 model_type: u8,
333 attenuation: f32,
334 ) {
335 let key = DistanceCacheKey {
336 distance_mm: (distance * 1000.0) as u32,
337 model_type,
338 };
339
340 let mut cache_manager = self.cache_manager.write().await;
341 cache_manager.cache_distance(key, attenuation).await;
342 }
343
344 pub async fn get_cached_distance_attenuation(
346 &self,
347 distance: f32,
348 model_type: u8,
349 ) -> Option<f32> {
350 let key = DistanceCacheKey {
351 distance_mm: (distance * 1000.0) as u32,
352 model_type,
353 };
354
355 let cache_manager = self.cache_manager.read().await;
356 cache_manager.get_distance(&key)
357 }
358
359 pub async fn get_memory_stats(&self) -> MemoryStatistics {
361 let stats = self.memory_stats.read().await;
362 stats.clone()
363 }
364
365 pub async fn check_memory_pressure(&self) -> bool {
367 let stats = self.memory_stats.read().await;
368 if stats.memory_pressure > self.config.memory_pressure_threshold {
369 drop(stats); self.cleanup_memory().await;
371 true
372 } else {
373 false
374 }
375 }
376
377 async fn cleanup_memory(&self) {
379 let mut cache_manager = self.cache_manager.write().await;
381 cache_manager
382 .evict_lru_entries(self.config.max_cache_size / 2)
383 .await;
384
385 self.trim_buffer_pools().await;
387
388 self.update_memory_stats().await;
390 }
391
392 async fn trim_buffer_pools(&self) {
394 let mut pools = self.buffer_pools.write().await;
395 for pool in pools.values_mut() {
396 pool.available.truncate(pool.max_size / 2);
397 }
398
399 let mut array_pools = self.array2d_pools.write().await;
400 for pool in array_pools.values_mut() {
401 pool.available.truncate(pool.max_size / 2);
402 }
403 }
404
405 async fn update_buffer_stats(&self, size: usize, is_new_allocation: bool) {
407 let mut stats = self.memory_stats.write().await;
408
409 if is_new_allocation {
411 stats.total_allocated += (size * std::mem::size_of::<f32>()) as u64;
412 }
413
414 {
416 let pool_stats = stats.buffer_pool_stats.entry(size).or_default();
417 if is_new_allocation {
418 pool_stats.total_allocations += 1;
419 } else {
420 pool_stats.pool_hits += 1;
421 }
422 pool_stats.hit_rate =
423 pool_stats.pool_hits as f64 / pool_stats.total_allocations.max(1) as f64;
424 }
425
426 stats.last_updated = Instant::now();
427 }
428
429 async fn update_memory_stats(&self) {
431 let mut stats = self.memory_stats.write().await;
432
433 let pools = self.buffer_pools.read().await;
435 let mut memory_in_use = 0u64;
436
437 for (size, pool) in pools.iter() {
438 let pool_memory = (pool.available.len() * size * std::mem::size_of::<f32>()) as u64;
439 memory_in_use += pool_memory;
440
441 let pool_stats = stats.buffer_pool_stats.entry(*size).or_default();
442 pool_stats.current_pool_size = pool.available.len();
443 }
444
445 stats.memory_in_use = memory_in_use;
446 if memory_in_use > stats.peak_memory_usage {
447 stats.peak_memory_usage = memory_in_use;
448 }
449
450 stats.memory_pressure = (memory_in_use as f32 / (1024.0 * 1024.0 * 1024.0)).min(1.0); stats.last_updated = Instant::now();
453 }
454}
455
456impl<T> BufferPool<T> {
457 fn new(size: usize, max_size: usize) -> Self {
458 Self {
459 available: VecDeque::with_capacity(max_size),
460 max_size,
461 total_allocations: 0,
462 pool_hits: 0,
463 }
464 }
465}
466
467impl Array2Pool {
468 fn new(dimensions: (usize, usize), max_size: usize) -> Self {
469 Self {
470 available: VecDeque::with_capacity(max_size),
471 dimensions,
472 max_size,
473 total_allocations: 0,
474 pool_hits: 0,
475 }
476 }
477}
478
479impl CacheManager {
480 fn new(config: &MemoryConfig) -> Self {
481 Self {
482 hrtf_cache: HashMap::new(),
483 distance_cache: HashMap::new(),
484 room_cache: HashMap::new(),
485 cache_stats: CacheStatistics::default(),
486 max_cache_size: config.max_cache_size,
487 }
488 }
489
490 async fn cache_hrtf(&mut self, key: HrtfCacheKey, entry: HrtfCacheEntry) {
491 if self.hrtf_cache.len() >= self.max_cache_size {
492 self.evict_lru_hrtf().await;
493 }
494 self.hrtf_cache.insert(key, entry);
495 }
496
497 async fn get_hrtf(&mut self, key: &HrtfCacheKey) -> Option<(Array1<f32>, Array1<f32>)> {
498 if let Some(entry) = self.hrtf_cache.get_mut(key) {
499 entry.last_accessed = Instant::now();
500 entry.access_count += 1;
501 self.cache_stats.cache_hits += 1;
502 Some((entry.left_hrir.clone(), entry.right_hrir.clone()))
503 } else {
504 self.cache_stats.cache_misses += 1;
505 None
506 }
507 }
508
509 async fn cache_distance(&mut self, key: DistanceCacheKey, value: f32) {
510 if self.distance_cache.len() >= self.max_cache_size {
511 if self.distance_cache.len() > self.max_cache_size * 3 / 4 {
513 let keys: Vec<_> = self.distance_cache.keys().cloned().collect();
514 for key in keys.iter().take(self.max_cache_size / 4) {
515 self.distance_cache.remove(key);
516 }
517 }
518 }
519 self.distance_cache.insert(key, value);
520 }
521
522 fn get_distance(&self, key: &DistanceCacheKey) -> Option<f32> {
523 self.distance_cache.get(key).copied()
524 }
525
526 async fn evict_lru_entries(&mut self, count: usize) {
527 let mut entries: Vec<_> = self.hrtf_cache.iter().collect();
529 entries.sort_by_key(|a| a.1.last_accessed);
530
531 let to_remove: Vec<_> = entries
532 .iter()
533 .take(count.min(entries.len()))
534 .map(|(k, _)| (*k).clone())
535 .collect();
536 for key in to_remove {
537 self.hrtf_cache.remove(&key);
538 self.cache_stats.cache_evictions += 1;
539 }
540 }
541
542 async fn evict_lru_hrtf(&mut self) {
543 if let Some((oldest_key, _)) = self
544 .hrtf_cache
545 .iter()
546 .min_by_key(|(_, entry)| entry.last_accessed)
547 {
548 let key_to_remove = oldest_key.clone();
549 self.hrtf_cache.remove(&key_to_remove);
550 self.cache_stats.cache_evictions += 1;
551 }
552 }
553}
554
555pub mod cache_optimization {
557 use super::*;
558
559 #[derive(Debug)]
561 pub struct SoAPositions {
562 pub x: Vec<f32>,
564 pub y: Vec<f32>,
566 pub z: Vec<f32>,
568 pub capacity: usize,
570 }
571
572 impl SoAPositions {
573 pub fn with_capacity(capacity: usize) -> Self {
575 Self {
576 x: Vec::with_capacity(capacity),
577 y: Vec::with_capacity(capacity),
578 z: Vec::with_capacity(capacity),
579 capacity,
580 }
581 }
582
583 pub fn push(&mut self, pos: Position3D) {
585 self.x.push(pos.x);
586 self.y.push(pos.y);
587 self.z.push(pos.z);
588 }
589
590 pub fn get(&self, index: usize) -> Option<Position3D> {
592 if index < self.len() {
593 Some(Position3D::new(self.x[index], self.y[index], self.z[index]))
594 } else {
595 None
596 }
597 }
598
599 pub fn len(&self) -> usize {
601 self.x.len()
602 }
603
604 pub fn is_empty(&self) -> bool {
606 self.len() == 0
607 }
608
609 pub fn clear(&mut self) {
611 self.x.clear();
612 self.y.clear();
613 self.z.clear();
614 }
615 }
616
617 #[cfg(target_arch = "x86_64")]
619 #[allow(unsafe_code)]
620 pub fn prefetch_data<T>(data: *const T) {
621 #[cfg(target_feature = "sse")]
622 unsafe {
623 std::arch::x86_64::_mm_prefetch(data as *const i8, std::arch::x86_64::_MM_HINT_T0);
624 }
625 }
626
627 #[cfg(not(target_arch = "x86_64"))]
628 pub fn prefetch_data<T>(_data: *const T) {
630 }
632}
633
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_manager_creation() {
        let config = MemoryConfig::default();
        let manager = MemoryManager::new(config);

        // A fresh manager has recorded no allocations.
        let stats = manager.get_memory_stats().await;
        assert_eq!(stats.total_allocated, 0);
    }

    #[tokio::test]
    async fn test_buffer_pool_reuse() {
        let config = MemoryConfig::default();
        let manager = MemoryManager::new(config);

        // First acquisition allocates; returning it seeds the pool.
        let buffer = manager.get_buffer(1024).await;
        assert_eq!(buffer.len(), 1024);
        manager.return_buffer(buffer).await;

        // A second request of the same size can be served from the pool.
        let buffer2 = manager.get_buffer(1024).await;
        assert_eq!(buffer2.len(), 1024);

        let stats = manager.get_memory_stats().await;
        assert!(stats.buffer_pool_stats.contains_key(&1024));
    }

    #[tokio::test]
    async fn test_hrtf_cache() {
        let config = MemoryConfig::default();
        let manager = MemoryManager::new(config);

        let left = Array1::zeros(256);
        let right = Array1::zeros(256);

        manager
            .cache_hrtf((45, 0, 2.0), left.clone(), right.clone())
            .await;

        // Same (azimuth, elevation, distance) key must hit.
        let cached = manager.get_cached_hrtf((45, 0, 2.0)).await;
        assert!(cached.is_some());

        let (cached_left, cached_right) = cached.expect("Cached HRTF should be available");
        assert_eq!(cached_left.len(), 256);
        assert_eq!(cached_right.len(), 256);
    }

    #[tokio::test]
    async fn test_distance_cache() {
        let config = MemoryConfig::default();
        let manager = MemoryManager::new(config);

        manager.cache_distance_attenuation(5.0, 1, 0.2).await;

        let cached = manager.get_cached_distance_attenuation(5.0, 1).await;
        assert_eq!(cached, Some(0.2));

        // A distance that was never cached must miss.
        let not_cached = manager.get_cached_distance_attenuation(10.0, 1).await;
        assert_eq!(not_cached, None);
    }

    #[tokio::test]
    async fn test_memory_pressure() {
        // Build the config in one expression instead of mutating a default
        // (clippy: field_reassign_with_default).
        let config = MemoryConfig {
            memory_pressure_threshold: 0.1,
            ..MemoryConfig::default()
        };
        let manager = MemoryManager::new(config);

        // Acquire buffers without returning them.
        let mut buffers = Vec::new();
        for _ in 0..100 {
            buffers.push(manager.get_buffer(1024).await);
        }

        manager.update_memory_stats().await;

        // `update_memory_stats` only counts buffers sitting idle in the
        // pools; none were returned, so measured pressure stays at zero and
        // no cleanup fires. (The original test computed this result and
        // asserted nothing.)
        let pressure_detected = manager.check_memory_pressure().await;
        assert!(!pressure_detected);
    }

    #[tokio::test]
    async fn test_soa_positions() {
        let mut positions = cache_optimization::SoAPositions::with_capacity(10);

        positions.push(Position3D::new(1.0, 2.0, 3.0));
        positions.push(Position3D::new(4.0, 5.0, 6.0));

        assert_eq!(positions.len(), 2);

        let pos = positions.get(0).expect("First position should exist");
        assert_eq!(pos.x, 1.0);
        assert_eq!(pos.y, 2.0);
        assert_eq!(pos.z, 3.0);
    }
}