use crate::types::Position3D;
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Central allocator front-end for the spatial-audio pipeline: hands out
/// pooled 1-D/2-D `f32` scratch buffers and memoizes HRTF / distance /
/// room computations. All state sits behind async `RwLock`s so a single
/// manager can be shared across tasks.
pub struct MemoryManager {
    /// Reusable 1-D `f32` buffers, keyed by element count.
    buffer_pools: Arc<RwLock<HashMap<usize, BufferPool<f32>>>>,
    /// Reusable 2-D `f32` arrays, keyed by `(rows, cols)`.
    array2d_pools: Arc<RwLock<HashMap<(usize, usize), Array2Pool>>>,
    /// HRTF / distance / room result caches plus their statistics.
    cache_manager: Arc<RwLock<CacheManager>>,
    /// Aggregated allocation and pool-usage counters.
    memory_stats: Arc<RwLock<MemoryStatistics>>,
    /// Sizing and policy knobs (pool caps, cache caps, pressure threshold).
    config: MemoryConfig,
}
/// Free list of reusable 1-D buffers that all share a single length
/// (the length is the key of the owning map in `MemoryManager`).
pub struct BufferPool<T> {
    /// Buffers ready for reuse, handed out FIFO.
    available: VecDeque<Array1<T>>,
    /// Maximum number of buffers retained in `available`.
    max_size: usize,
    /// Count of fresh allocations (pool misses).
    total_allocations: u64,
    /// Count of requests served from `available`.
    pool_hits: u64,
}
/// Free list of reusable 2-D `f32` arrays of one fixed shape.
pub struct Array2Pool {
    /// Arrays ready for reuse, handed out FIFO.
    available: VecDeque<Array2<f32>>,
    /// `(rows, cols)` recorded at construction.
    /// NOTE(review): never read elsewhere in this file — confirm it is used.
    dimensions: (usize, usize),
    /// Maximum number of arrays retained in `available`.
    max_size: usize,
    /// Count of fresh allocations (pool misses).
    total_allocations: u64,
    /// Count of requests served from `available`.
    pool_hits: u64,
}
/// Owns the HRTF, distance-attenuation and room-response caches together
/// with shared statistics and a common per-map entry bound.
pub struct CacheManager {
    /// HRIR pairs keyed by quantized direction/distance.
    hrtf_cache: HashMap<HrtfCacheKey, HrtfCacheEntry>,
    /// Attenuation values keyed by quantized distance and model id.
    distance_cache: HashMap<DistanceCacheKey, f32>,
    /// Cached room response vectors.
    /// NOTE(review): never populated in this file — confirm external writers.
    room_cache: HashMap<RoomCacheKey, Array1<f32>>,
    /// Hit/miss/eviction counters for all caches combined.
    cache_stats: CacheStatistics,
    /// Entry cap applied independently to each cache map.
    max_cache_size: usize,
}
/// Tunable limits and policies for pooling, caching and monitoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Max buffers retained per pool (also used to cap 2-D array pools).
    pub max_buffer_pool_size: usize,
    /// Max entries per cache map.
    pub max_cache_size: usize,
    /// Whether memory statistics should be collected.
    /// NOTE(review): not read anywhere in this file — confirm callers use it.
    pub enable_monitoring: bool,
    /// Pressure ratio (0..=1) above which `check_memory_pressure` cleans up.
    pub memory_pressure_threshold: f32,
    /// Requested eviction policy. NOTE(review): the eviction code in this
    /// file is LRU regardless of this setting.
    pub cache_policy: CachePolicy,
    /// Desired buffer alignment in bytes.
    /// NOTE(review): not applied by any allocation path in this file.
    pub buffer_alignment: usize,
}
/// Cache eviction strategies selectable via `MemoryConfig::cache_policy`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum CachePolicy {
    /// Evict the least-recently-used entry (the only strategy the code in
    /// this file actually implements).
    LRU,
    /// Evict the least-frequently-used entry.
    LFU,
    /// Evict entries older than a time-to-live.
    TTL,
    /// Evict based on entry size.
    SizeBased,
}
/// Integer-quantized HRTF lookup key (`f32` is neither `Eq` nor `Hash`,
/// so the distance is stored as integer millimeters).
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct HrtfCacheKey {
    /// Azimuth exactly as supplied by the caller (degrees, presumably —
    /// TODO confirm with the HRTF producer).
    azimuth: i32,
    /// Elevation exactly as supplied by the caller.
    elevation: i32,
    /// Source distance in millimeters (caller's f32 × 1000, truncated).
    distance: u32,
}
/// Cached left/right head-related impulse responses plus LRU bookkeeping.
#[derive(Debug, Clone)]
struct HrtfCacheEntry {
    /// Impulse response for the left ear.
    left_hrir: Array1<f32>,
    /// Impulse response for the right ear.
    right_hrir: Array1<f32>,
    /// Refreshed on every hit; LRU eviction removes the oldest timestamp.
    last_accessed: Instant,
    /// Times this entry has been served (starts at 1 on insert).
    access_count: u64,
}
/// Key for cached distance attenuation: millimeter-quantized distance plus
/// a numeric attenuation-model discriminant supplied by the caller.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct DistanceCacheKey {
    /// Distance in millimeters (caller's f32 × 1000, truncated).
    distance_mm: u32,
    /// Opaque attenuation-model id (semantics defined by callers).
    model_type: u8,
}
/// Key for cached room responses, built from pre-hashed room geometry and
/// source/listener positions. NOTE(review): `room_cache` is never written
/// in this file, so the hashing scheme cannot be confirmed here.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct RoomCacheKey {
    /// Hash of the room description.
    room_hash: u64,
    /// Hash of the source position.
    source_position_hash: u64,
    /// Hash of the listener position.
    listener_position_hash: u64,
}
/// Aggregated allocation/cache counters returned by
/// `MemoryManager::get_memory_stats`.
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
    /// Cumulative bytes freshly allocated for 1-D buffers.
    pub total_allocated: u64,
    /// Bytes currently held by idle (pooled) 1-D buffers.
    pub memory_in_use: u64,
    /// High-water mark of `memory_in_use`.
    pub peak_memory_usage: u64,
    /// Per-buffer-size pool counters, keyed by element count.
    pub buffer_pool_stats: HashMap<usize, BufferPoolStats>,
    /// Named cache hit rates.
    /// NOTE(review): never populated in this file — confirm external writers.
    pub cache_hit_rates: HashMap<String, f64>,
    /// `memory_in_use` as a fraction of a fixed 1 GiB budget, clamped to 1.
    pub memory_pressure: f32,
    /// When these statistics were last refreshed.
    pub last_updated: Instant,
}
impl Default for MemoryStatistics {
fn default() -> Self {
Self {
total_allocated: 0,
memory_in_use: 0,
peak_memory_usage: 0,
buffer_pool_stats: HashMap::new(),
cache_hit_rates: HashMap::new(),
memory_pressure: 0.0,
last_updated: Instant::now(),
}
}
}
/// Per-size pool counters surfaced inside `MemoryStatistics`.
#[derive(Debug, Default, Clone)]
pub struct BufferPoolStats {
    /// Fresh allocations (pool misses).
    pub total_allocations: u64,
    /// Requests served from the pool.
    pub pool_hits: u64,
    /// Buffers currently idle in the pool.
    pub current_pool_size: usize,
    /// Pool hits relative to request volume; maintained by
    /// `MemoryManager::update_buffer_stats`.
    pub hit_rate: f64,
}
/// Internal hit/miss/eviction counters for the `CacheManager`'s maps.
#[derive(Debug, Default)]
struct CacheStatistics {
    /// Total lookup count (intended to be hits + misses).
    total_requests: u64,
    /// Lookups that found an entry.
    cache_hits: u64,
    /// Lookups that found nothing.
    cache_misses: u64,
    /// Entries removed by any eviction path.
    cache_evictions: u64,
    /// Estimated cache memory footprint in bytes.
    /// NOTE(review): never updated in this file — confirm it is maintained.
    memory_usage: u64,
}
impl Default for MemoryConfig {
fn default() -> Self {
Self {
max_buffer_pool_size: 128,
max_cache_size: 1024,
enable_monitoring: true,
memory_pressure_threshold: 0.8,
cache_policy: CachePolicy::LRU,
buffer_alignment: 32, }
}
}
impl Default for MemoryManager {
fn default() -> Self {
Self::new(MemoryConfig::default())
}
}
impl MemoryManager {
    /// Creates a manager with empty pools and caches configured by `config`.
    pub fn new(config: MemoryConfig) -> Self {
        Self {
            buffer_pools: Arc::new(RwLock::new(HashMap::new())),
            array2d_pools: Arc::new(RwLock::new(HashMap::new())),
            cache_manager: Arc::new(RwLock::new(CacheManager::new(&config))),
            memory_stats: Arc::new(RwLock::new(MemoryStatistics::default())),
            config,
        }
    }

    /// Returns a zero-filled 1-D buffer of `size` elements, reusing a
    /// pooled buffer when one is available.
    pub async fn get_buffer(&self, size: usize) -> Array1<f32> {
        // Pop from the pool inside a short critical section and release the
        // pool lock BEFORE touching `memory_stats`. The original awaited
        // the stats lock while still holding the pool lock — the opposite
        // acquisition order to `update_memory_stats` — which could deadlock
        // two concurrent tasks (classic ABBA inversion).
        let recycled = {
            let mut pools = self.buffer_pools.write().await;
            let pool = pools
                .entry(size)
                .or_insert_with(|| BufferPool::new(size, self.config.max_buffer_pool_size));
            match pool.available.pop_front() {
                Some(mut buffer) => {
                    buffer.fill(0.0); // recycled buffers hold stale data
                    pool.pool_hits += 1;
                    Some(buffer)
                }
                None => {
                    pool.total_allocations += 1;
                    None
                }
            }
        };
        match recycled {
            Some(buffer) => {
                self.update_buffer_stats(size, false).await;
                buffer
            }
            None => {
                self.update_buffer_stats(size, true).await;
                Array1::zeros(size)
            }
        }
    }

    /// Returns `buffer` to its size-keyed pool; silently dropped when the
    /// pool is full or was never created for this size.
    pub async fn return_buffer(&self, buffer: Array1<f32>) {
        let size = buffer.len();
        let mut pools = self.buffer_pools.write().await;
        if let Some(pool) = pools.get_mut(&size) {
            if pool.available.len() < pool.max_size {
                pool.available.push_back(buffer);
            }
        }
    }

    /// Returns a zero-filled `rows x cols` array, reusing a pooled one when
    /// available.
    pub async fn get_array2d(&self, rows: usize, cols: usize) -> Array2<f32> {
        let dims = (rows, cols);
        let mut pools = self.array2d_pools.write().await;
        let pool = pools
            .entry(dims)
            .or_insert_with(|| Array2Pool::new(dims, self.config.max_buffer_pool_size));
        if let Some(mut array) = pool.available.pop_front() {
            array.fill(0.0); // recycled arrays hold stale data
            pool.pool_hits += 1;
            array
        } else {
            pool.total_allocations += 1;
            Array2::zeros(dims)
        }
    }

    /// Returns `array` to its dimension-keyed pool; dropped when full.
    pub async fn return_array2d(&self, array: Array2<f32>) {
        let dims = array.dim();
        let mut pools = self.array2d_pools.write().await;
        if let Some(pool) = pools.get_mut(&dims) {
            if pool.available.len() < pool.max_size {
                pool.available.push_back(array);
            }
        }
    }

    /// Caches an HRIR pair under `(azimuth, elevation, distance)`.
    pub async fn cache_hrtf(
        &self,
        key: (i32, i32, f32),
        left_hrir: Array1<f32>,
        right_hrir: Array1<f32>,
    ) {
        let cache_key = HrtfCacheKey {
            azimuth: key.0,
            elevation: key.1,
            // f32 is not Hash/Eq; quantize the distance to millimeters.
            distance: (key.2 * 1000.0) as u32,
        };
        let entry = HrtfCacheEntry {
            left_hrir,
            right_hrir,
            last_accessed: Instant::now(),
            access_count: 1,
        };
        let mut cache_manager = self.cache_manager.write().await;
        cache_manager.cache_hrtf(cache_key, entry).await;
    }

    /// Looks up a cached HRIR pair; refreshes its LRU metadata on a hit.
    pub async fn get_cached_hrtf(
        &self,
        key: (i32, i32, f32),
    ) -> Option<(Array1<f32>, Array1<f32>)> {
        let cache_key = HrtfCacheKey {
            azimuth: key.0,
            elevation: key.1,
            distance: (key.2 * 1000.0) as u32,
        };
        // Write lock: a hit mutates LRU bookkeeping inside the cache.
        let mut cache_manager = self.cache_manager.write().await;
        cache_manager.get_hrtf(&cache_key).await
    }

    /// Caches an attenuation value for `(distance, model_type)`.
    pub async fn cache_distance_attenuation(
        &self,
        distance: f32,
        model_type: u8,
        attenuation: f32,
    ) {
        let key = DistanceCacheKey {
            distance_mm: (distance * 1000.0) as u32,
            model_type,
        };
        let mut cache_manager = self.cache_manager.write().await;
        cache_manager.cache_distance(key, attenuation).await;
    }

    /// Looks up a cached attenuation value; read-only, no LRU update.
    pub async fn get_cached_distance_attenuation(
        &self,
        distance: f32,
        model_type: u8,
    ) -> Option<f32> {
        let key = DistanceCacheKey {
            distance_mm: (distance * 1000.0) as u32,
            model_type,
        };
        let cache_manager = self.cache_manager.read().await;
        cache_manager.get_distance(&key)
    }

    /// Snapshot of the current memory statistics.
    pub async fn get_memory_stats(&self) -> MemoryStatistics {
        self.memory_stats.read().await.clone()
    }

    /// Runs cleanup and returns `true` when the last recorded pressure
    /// exceeds the configured threshold. The reading is only as fresh as
    /// the most recent `update_memory_stats` call.
    pub async fn check_memory_pressure(&self) -> bool {
        let over_threshold = {
            let stats = self.memory_stats.read().await;
            stats.memory_pressure > self.config.memory_pressure_threshold
        };
        if over_threshold {
            self.cleanup_memory().await;
        }
        over_threshold
    }

    /// Frees memory by evicting half the cache budget and halving pools.
    async fn cleanup_memory(&self) {
        {
            let mut cache_manager = self.cache_manager.write().await;
            cache_manager
                .evict_lru_entries(self.config.max_cache_size / 2)
                .await;
        } // release the cache lock before touching pools and stats
        self.trim_buffer_pools().await;
        self.update_memory_stats().await;
    }

    /// Shrinks every pool's free list to half of its retention cap.
    async fn trim_buffer_pools(&self) {
        {
            let mut pools = self.buffer_pools.write().await;
            for pool in pools.values_mut() {
                pool.available.truncate(pool.max_size / 2);
            }
        } // drop the 1-D pool lock before locking the 2-D pools
        let mut array_pools = self.array2d_pools.write().await;
        for pool in array_pools.values_mut() {
            pool.available.truncate(pool.max_size / 2);
        }
    }

    /// Records one buffer request (hit or fresh allocation) in the stats.
    async fn update_buffer_stats(&self, size: usize, is_new_allocation: bool) {
        let mut stats = self.memory_stats.write().await;
        if is_new_allocation {
            stats.total_allocated += (size * std::mem::size_of::<f32>()) as u64;
        }
        let pool_stats = stats.buffer_pool_stats.entry(size).or_default();
        if is_new_allocation {
            pool_stats.total_allocations += 1;
        } else {
            pool_stats.pool_hits += 1;
        }
        // Hit rate over ALL requests. The original divided hits by fresh
        // allocations only, so the "rate" could exceed 1.0 once buffers
        // started being recycled.
        let requests = pool_stats.pool_hits + pool_stats.total_allocations;
        pool_stats.hit_rate = pool_stats.pool_hits as f64 / requests.max(1) as f64;
        stats.last_updated = Instant::now();
    }

    /// Recomputes pooled-memory usage, the peak, and the pressure ratio
    /// (fraction of a fixed 1 GiB budget, clamped to 1.0).
    async fn update_memory_stats(&self) {
        // Snapshot the pools first and drop that lock before locking the
        // stats. The original held the stats write lock while awaiting the
        // pool lock, inverting `get_buffer`'s order (deadlock risk).
        let pool_snapshot: Vec<(usize, usize)> = {
            let pools = self.buffer_pools.read().await;
            pools
                .iter()
                .map(|(size, pool)| (*size, pool.available.len()))
                .collect()
        };
        let mut stats = self.memory_stats.write().await;
        let mut memory_in_use = 0u64;
        for (size, pooled) in pool_snapshot {
            memory_in_use += (pooled * size * std::mem::size_of::<f32>()) as u64;
            let pool_stats = stats.buffer_pool_stats.entry(size).or_default();
            pool_stats.current_pool_size = pooled;
        }
        stats.memory_in_use = memory_in_use;
        stats.peak_memory_usage = stats.peak_memory_usage.max(memory_in_use);
        stats.memory_pressure = (memory_in_use as f32 / (1024.0 * 1024.0 * 1024.0)).min(1.0);
        stats.last_updated = Instant::now();
    }
}
impl<T> BufferPool<T> {
    /// Creates an empty pool that retains at most `max_size` buffers.
    ///
    /// The buffer length is fixed by the pool's key in
    /// `MemoryManager::buffer_pools`, so it is not stored here; the first
    /// parameter is kept (underscored) to preserve the existing call
    /// signature while silencing the unused-variable warning the original
    /// `size` parameter produced.
    fn new(_size: usize, max_size: usize) -> Self {
        Self {
            available: VecDeque::with_capacity(max_size),
            max_size,
            total_allocations: 0,
            pool_hits: 0,
        }
    }
}
impl Array2Pool {
fn new(dimensions: (usize, usize), max_size: usize) -> Self {
Self {
available: VecDeque::with_capacity(max_size),
dimensions,
max_size,
total_allocations: 0,
pool_hits: 0,
}
}
}
impl CacheManager {
    /// Builds an empty cache set bounded by `config.max_cache_size`.
    fn new(config: &MemoryConfig) -> Self {
        Self {
            hrtf_cache: HashMap::new(),
            distance_cache: HashMap::new(),
            room_cache: HashMap::new(),
            cache_stats: CacheStatistics::default(),
            max_cache_size: config.max_cache_size,
        }
    }

    /// Inserts an HRTF entry, first evicting the least-recently-used entry
    /// when the cache is at capacity.
    async fn cache_hrtf(&mut self, key: HrtfCacheKey, entry: HrtfCacheEntry) {
        if self.hrtf_cache.len() >= self.max_cache_size {
            self.evict_lru_hrtf().await;
        }
        self.hrtf_cache.insert(key, entry);
    }

    /// Looks up an HRIR pair, refreshing its LRU metadata on a hit.
    async fn get_hrtf(&mut self, key: &HrtfCacheKey) -> Option<(Array1<f32>, Array1<f32>)> {
        // `total_requests` was declared but never incremented anywhere,
        // making the hits/misses counters impossible to cross-check.
        self.cache_stats.total_requests += 1;
        if let Some(entry) = self.hrtf_cache.get_mut(key) {
            entry.last_accessed = Instant::now();
            entry.access_count += 1;
            self.cache_stats.cache_hits += 1;
            Some((entry.left_hrir.clone(), entry.right_hrir.clone()))
        } else {
            self.cache_stats.cache_misses += 1;
            None
        }
    }

    /// Inserts a distance-attenuation value, evicting a quarter of the
    /// entries (in arbitrary map order) when at capacity.
    async fn cache_distance(&mut self, key: DistanceCacheKey, value: f32) {
        if self.distance_cache.len() >= self.max_cache_size {
            // The original nested a second `len > 3*max/4` check here that
            // is always true once `len >= max`, and cloned EVERY key before
            // taking a quarter of them. Collect only the keys to evict, and
            // count the evictions like the HRTF path does.
            let evict: Vec<DistanceCacheKey> = self
                .distance_cache
                .keys()
                .take(self.max_cache_size / 4)
                .cloned()
                .collect();
            for key in &evict {
                self.distance_cache.remove(key);
                self.cache_stats.cache_evictions += 1;
            }
        }
        self.distance_cache.insert(key, value);
    }

    /// Read-only attenuation lookup.
    /// NOTE(review): takes `&self`, so hit/miss statistics cannot be
    /// updated here without interior mutability.
    fn get_distance(&self, key: &DistanceCacheKey) -> Option<f32> {
        self.distance_cache.get(key).copied()
    }

    /// Removes up to `count` least-recently-used HRTF entries.
    async fn evict_lru_entries(&mut self, count: usize) {
        // Sort (key, timestamp) pairs instead of borrowing whole entries so
        // the map can be mutated in the loop below.
        let mut by_age: Vec<(HrtfCacheKey, Instant)> = self
            .hrtf_cache
            .iter()
            .map(|(key, entry)| (key.clone(), entry.last_accessed))
            .collect();
        by_age.sort_by_key(|&(_, accessed)| accessed);
        for (key, _) in by_age.into_iter().take(count) {
            self.hrtf_cache.remove(&key);
            self.cache_stats.cache_evictions += 1;
        }
    }

    /// Evicts the single least-recently-used HRTF entry, if any exists.
    async fn evict_lru_hrtf(&mut self) {
        let oldest = self
            .hrtf_cache
            .iter()
            .min_by_key(|(_, entry)| entry.last_accessed)
            .map(|(key, _)| key.clone());
        if let Some(key) = oldest {
            self.hrtf_cache.remove(&key);
            self.cache_stats.cache_evictions += 1;
        }
    }
}
pub mod cache_optimization {
    //! Cache-friendly data layouts and prefetch helpers.
    use super::*;

    /// Structure-of-arrays storage for 3-D positions: each coordinate lives
    /// in its own contiguous vector, which keeps sequential per-axis access
    /// cache-friendly.
    #[derive(Debug)]
    pub struct SoAPositions {
        pub x: Vec<f32>,
        pub y: Vec<f32>,
        pub z: Vec<f32>,
        /// Initial capacity hint; not updated if the vectors later grow.
        pub capacity: usize,
    }

    impl SoAPositions {
        /// Creates empty storage with room for `capacity` positions.
        pub fn with_capacity(capacity: usize) -> Self {
            Self {
                x: Vec::with_capacity(capacity),
                y: Vec::with_capacity(capacity),
                z: Vec::with_capacity(capacity),
                capacity,
            }
        }

        /// Appends one position, scattering its components across the three
        /// coordinate arrays (which therefore stay in lockstep).
        pub fn push(&mut self, pos: Position3D) {
            let (px, py, pz) = (pos.x, pos.y, pos.z);
            self.x.push(px);
            self.y.push(py);
            self.z.push(pz);
        }

        /// Reassembles the position at `index`, or `None` when out of range.
        pub fn get(&self, index: usize) -> Option<Position3D> {
            let px = *self.x.get(index)?;
            let py = *self.y.get(index)?;
            let pz = *self.z.get(index)?;
            Some(Position3D::new(px, py, pz))
        }

        /// Number of stored positions.
        pub fn len(&self) -> usize {
            self.x.len()
        }

        /// True when no positions are stored.
        pub fn is_empty(&self) -> bool {
            self.x.is_empty()
        }

        /// Removes every position while keeping the allocated capacity.
        pub fn clear(&mut self) {
            for axis in [&mut self.x, &mut self.y, &mut self.z] {
                axis.clear();
            }
        }
    }

    /// Issues a best-effort prefetch of the cache line containing `data`.
    #[cfg(target_arch = "x86_64")]
    #[allow(unsafe_code)]
    pub fn prefetch_data<T>(data: *const T) {
        #[cfg(target_feature = "sse")]
        // SAFETY: `_mm_prefetch` is only a hint to the hardware prefetcher;
        // per Intel's documentation it does not fault even for invalid
        // addresses, so any pointer value is sound to pass.
        unsafe {
            std::arch::x86_64::_mm_prefetch(data as *const i8, std::arch::x86_64::_MM_HINT_T0);
        }
    }

    /// No-op fallback on architectures without an exposed prefetch hint.
    #[cfg(not(target_arch = "x86_64"))]
    pub fn prefetch_data<T>(_data: *const T) {}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh manager reports zero bytes allocated.
    #[tokio::test]
    async fn test_memory_manager_creation() {
        let config = MemoryConfig::default();
        let manager = MemoryManager::new(config);
        let stats = manager.get_memory_stats().await;
        assert_eq!(stats.total_allocated, 0);
    }

    /// A returned buffer is tracked so a same-size request can reuse it.
    #[tokio::test]
    async fn test_buffer_pool_reuse() {
        let config = MemoryConfig::default();
        let manager = MemoryManager::new(config);
        let buffer = manager.get_buffer(1024).await;
        assert_eq!(buffer.len(), 1024);
        manager.return_buffer(buffer).await;
        let buffer2 = manager.get_buffer(1024).await;
        assert_eq!(buffer2.len(), 1024);
        let stats = manager.get_memory_stats().await;
        assert!(stats.buffer_pool_stats.contains_key(&1024));
    }

    /// Cached HRIR pairs round-trip through the cache with their lengths
    /// intact.
    #[tokio::test]
    async fn test_hrtf_cache() {
        let config = MemoryConfig::default();
        let manager = MemoryManager::new(config);
        let left = Array1::zeros(256);
        let right = Array1::zeros(256);
        manager
            .cache_hrtf((45, 0, 2.0), left.clone(), right.clone())
            .await;
        let cached = manager.get_cached_hrtf((45, 0, 2.0)).await;
        assert!(cached.is_some());
        let (cached_left, cached_right) = cached.expect("Cached HRTF should be available");
        assert_eq!(cached_left.len(), 256);
        assert_eq!(cached_right.len(), 256);
    }

    /// Attenuation is cached per (distance, model) pair; other keys miss.
    #[tokio::test]
    async fn test_distance_cache() {
        let config = MemoryConfig::default();
        let manager = MemoryManager::new(config);
        manager.cache_distance_attenuation(5.0, 1, 0.2).await;
        let cached = manager.get_cached_distance_attenuation(5.0, 1).await;
        assert_eq!(cached, Some(0.2));
        let not_cached = manager.get_cached_distance_attenuation(10.0, 1).await;
        assert_eq!(not_cached, None);
    }

    /// The pressure-check/cleanup path runs to completion with a low
    /// threshold and outstanding buffers.
    #[tokio::test]
    async fn test_memory_pressure() {
        // Struct-update syntax instead of mutating after `default()`.
        let config = MemoryConfig {
            memory_pressure_threshold: 0.1,
            ..MemoryConfig::default()
        };
        let manager = MemoryManager::new(config);
        let mut buffers = Vec::new();
        for _ in 0..100 {
            buffers.push(manager.get_buffer(1024).await);
        }
        manager.update_memory_stats().await;
        // The original bound this result to an unused variable (a warning);
        // the boolean depends on pool occupancy, so only assert completion.
        let _ = manager.check_memory_pressure().await;
    }

    /// SoA storage preserves per-component values and insertion order.
    #[tokio::test]
    async fn test_soa_positions() {
        let mut positions = cache_optimization::SoAPositions::with_capacity(10);
        positions.push(Position3D::new(1.0, 2.0, 3.0));
        positions.push(Position3D::new(4.0, 5.0, 6.0));
        assert_eq!(positions.len(), 2);
        let pos = positions.get(0).expect("First position should exist");
        assert_eq!(pos.x, 1.0);
        assert_eq!(pos.y, 2.0);
        assert_eq!(pos.z, 3.0);
    }
}