use crate::search::cache::SearchCache;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, warn};
/// Memory limit, in megabytes, that `MemoryMonitor::new` applies when the
/// caller does not supply one.
const DEFAULT_MEMORY_LIMIT_MB: usize = 500;
/// Assumed size in bytes (8 KiB) of a single search-cache entry. It is
/// multiplied by a cache's entry count to estimate that cache's footprint, so
/// every derived figure is an approximation rather than a measurement.
const ESTIMATED_BYTES_PER_SEARCH_ENTRY: usize = 8 * 1024;
/// Assumed size in bytes (8 KiB) of a single embedding-cache entry.
// The lint is silenced because nothing in the visible part of this file
// references this constant.
#[allow(dead_code)]
const ESTIMATED_BYTES_PER_EMBEDDING_ENTRY: usize = 8 * 1024;
/// Keeps a registry of named search caches and estimates their combined
/// memory footprint so it can be compared against a limit in megabytes.
pub struct MemoryMonitor {
/// Registered caches keyed by name, shared behind a read/write lock.
caches: Arc<RwLock<HashMap<String, Arc<SearchCache>>>>,
/// Limit, in megabytes, that the estimated total is checked against.
memory_limit_mb: usize,
}
impl MemoryMonitor {
    /// Utilization percentage above which usage counts as approaching the limit.
    const APPROACHING_LIMIT_PERCENT: f64 = 80.0;

    /// Creates a monitor that uses `DEFAULT_MEMORY_LIMIT_MB` as its limit.
    pub fn new() -> Self {
        Self::with_limit(DEFAULT_MEMORY_LIMIT_MB)
    }

    /// Creates a monitor with an explicit limit, in megabytes, and no
    /// registered caches.
    pub fn with_limit(memory_limit_mb: usize) -> Self {
        Self {
            caches: Arc::new(RwLock::new(HashMap::new())),
            memory_limit_mb,
        }
    }

    /// Acquires the registry for reading.
    ///
    /// A poisoned lock is recovered instead of panicking: the map only holds
    /// `Arc` handles, so a panic in another thread cannot leave it in a state
    /// that is unsafe to read, and a monitoring helper should not turn one
    /// panic into a panic on every later call.
    fn read_caches(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<String, Arc<SearchCache>>> {
        self.caches
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Acquires the registry for writing, recovering from poisoning for the
    /// same reason as `read_caches`.
    fn write_caches(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, HashMap<String, Arc<SearchCache>>> {
        self.caches
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Registers `cache` under `name` so it is included in later statistics.
    ///
    /// If a cache is already registered under `name` it is replaced, and a
    /// warning is logged because the previous cache is no longer counted.
    pub fn register_cache(&self, name: &str, cache: Arc<SearchCache>) {
        // The write guard is a temporary, so the lock is released before logging.
        let replaced = self
            .write_caches()
            .insert(name.to_string(), cache)
            .is_some();
        if replaced {
            warn!(
                "Cache '{}' was already registered for memory monitoring; replacing it",
                name
            );
        }
        debug!("Registered cache '{}' for memory monitoring", name);
    }

    /// Removes the cache registered under `name`.
    ///
    /// Returns `true` if a cache with that name was registered, `false`
    /// otherwise.
    pub fn unregister_cache(&self, name: &str) -> bool {
        self.write_caches().remove(name).is_some()
    }

    /// Builds a snapshot of the estimated memory use of every registered cache.
    ///
    /// Each cache's footprint is its entry count multiplied by
    /// `ESTIMATED_BYTES_PER_SEARCH_ENTRY`. Megabyte figures are whole
    /// megabytes, truncated. `utilization_percent` is derived from the
    /// truncated total and is reported as `0.0` when the limit is zero.
    pub fn stats(&self) -> MemoryStats {
        const BYTES_PER_MB: usize = 1024 * 1024;

        let caches = self.read_caches();
        let mut cache_stats = HashMap::with_capacity(caches.len());
        let mut total_bytes: usize = 0;

        for (name, cache) in caches.iter() {
            let stats = cache.stats();
            // Saturate rather than wrap or panic: on 32-bit targets a large
            // entry count times 8 KiB can exceed `usize::MAX`.
            let estimated_bytes = stats.size.saturating_mul(ESTIMATED_BYTES_PER_SEARCH_ENTRY);
            cache_stats.insert(
                name.clone(),
                CacheMemoryStats {
                    size: stats.size,
                    capacity: stats.capacity,
                    estimated_mb: estimated_bytes / BYTES_PER_MB,
                },
            );
            total_bytes = total_bytes.saturating_add(estimated_bytes);
        }

        let total_mb = total_bytes / BYTES_PER_MB;
        let utilization = if self.memory_limit_mb > 0 {
            (total_mb as f64 / self.memory_limit_mb as f64) * 100.0
        } else {
            // Avoid dividing by zero; a zero limit has no meaningful percentage.
            0.0
        };

        MemoryStats {
            total_mb,
            limit_mb: self.memory_limit_mb,
            utilization_percent: utilization,
            cache_stats,
        }
    }

    /// Returns `true` when the estimated total (whole megabytes) does not
    /// exceed the configured limit.
    pub fn is_within_limit(&self) -> bool {
        self.stats().total_mb <= self.memory_limit_mb
    }

    /// Returns `true` when estimated utilization is above 80% of the limit.
    pub fn is_approaching_limit(&self) -> bool {
        self.stats().utilization_percent > Self::APPROACHING_LIMIT_PERCENT
    }

    /// Logs the current usage: a warning when the limit is exceeded or
    /// utilization is above 80%, a debug message otherwise.
    pub fn check_and_log(&self) {
        let stats = self.stats();
        if stats.total_mb > self.memory_limit_mb {
            warn!(
                "Cache memory usage EXCEEDS limit: {}MB / {}MB ({:.1}%)",
                stats.total_mb, stats.limit_mb, stats.utilization_percent
            );
        } else if stats.utilization_percent > Self::APPROACHING_LIMIT_PERCENT {
            warn!(
                "Cache memory usage approaching limit: {}MB / {}MB ({:.1}%)",
                stats.total_mb, stats.limit_mb, stats.utilization_percent
            );
        } else {
            debug!(
                "Cache memory usage: {}MB / {}MB ({:.1}%)",
                stats.total_mb, stats.limit_mb, stats.utilization_percent
            );
        }
    }

    /// Calls `clear()` on every registered cache when the estimated total
    /// exceeds the limit.
    ///
    /// Returns the number of registered caches that were cleared, or `0` when
    /// usage is within the limit and nothing was touched.
    pub fn emergency_clear_if_needed(&self) -> usize {
        if self.is_within_limit() {
            return 0;
        }

        warn!("Emergency cache clear triggered due to memory limit");
        let caches = self.read_caches();
        for (name, cache) in caches.iter() {
            cache.clear();
            debug!("Cleared cache '{}'", name);
        }
        caches.len()
    }
}
impl Default for MemoryMonitor {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time summary of estimated cache memory use.
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Estimated total across all caches, in whole megabytes.
    pub total_mb: usize,
    /// Limit the total is compared against, in megabytes.
    pub limit_mb: usize,
    /// `total_mb` as a percentage of `limit_mb`. `MemoryMonitor::stats`
    /// reports `0.0` here when the limit is zero.
    pub utilization_percent: f64,
    /// Per-cache breakdown keyed by the cache's registered name.
    pub cache_stats: HashMap<String, CacheMemoryStats>,
}

impl MemoryStats {
    /// Utilization percentage below which usage is considered safe.
    const SAFE_UTILIZATION_PERCENT: f64 = 80.0;

    /// Returns `true` when utilization is below 80% and the limit is not
    /// exceeded.
    ///
    /// The explicit `is_critical` check keeps the two predicates mutually
    /// exclusive even when `utilization_percent` does not reflect an overrun,
    /// for example the `0.0` reported for a zero limit.
    pub fn is_safe(&self) -> bool {
        !self.is_critical() && self.utilization_percent < Self::SAFE_UTILIZATION_PERCENT
    }

    /// Returns `true` when the estimated total exceeds the limit.
    pub fn is_critical(&self) -> bool {
        self.total_mb > self.limit_mb
    }
}
/// Memory estimate for one registered cache, as filled in by
/// `MemoryMonitor::stats`.
#[derive(Debug, Clone)]
pub struct CacheMemoryStats {
/// Value of the cache's own `stats().size`; used as the entry count when
/// estimating memory.
pub size: usize,
/// Value of the cache's own `stats().capacity`.
pub capacity: usize,
/// Estimated footprint in whole megabytes (truncated), computed as
/// `size * ESTIMATED_BYTES_PER_SEARCH_ENTRY / (1024 * 1024)`.
pub estimated_mb: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-builds a `MemoryStats` value with an empty per-cache breakdown.
    fn snapshot(total_mb: usize, limit_mb: usize, utilization_percent: f64) -> MemoryStats {
        MemoryStats {
            total_mb,
            limit_mb,
            utilization_percent,
            cache_stats: HashMap::new(),
        }
    }

    /// Returns a default-limit monitor with one cache registered as `name`.
    fn monitor_with_cache(name: &str) -> MemoryMonitor {
        let monitor = MemoryMonitor::new();
        monitor.register_cache(name, Arc::new(SearchCache::new(100)));
        monitor
    }

    // A fresh monitor reports zero usage against the default limit.
    #[test]
    fn test_memory_monitor_creation() {
        let MemoryStats {
            total_mb,
            limit_mb,
            utilization_percent,
            ..
        } = MemoryMonitor::new().stats();

        assert_eq!(total_mb, 0);
        assert_eq!(limit_mb, DEFAULT_MEMORY_LIMIT_MB);
        assert_eq!(utilization_percent, 0.0);
    }

    // An explicit limit is echoed back in the snapshot.
    #[test]
    fn test_memory_monitor_custom_limit() {
        let reported_limit = MemoryMonitor::with_limit(1000).stats().limit_mb;
        assert_eq!(reported_limit, 1000);
    }

    // 60% utilization under the limit is safe and not critical.
    #[test]
    fn test_memory_stats_safety_checks() {
        let below_threshold = snapshot(300, 500, 60.0);
        assert!(below_threshold.is_safe());
        assert!(!below_threshold.is_critical());
    }

    // Usage above the limit is critical and not safe.
    #[test]
    fn test_memory_stats_critical() {
        let over_limit = snapshot(600, 500, 120.0);
        assert!(!over_limit.is_safe());
        assert!(over_limit.is_critical());
    }

    // A registered cache shows up in the per-cache breakdown.
    #[test]
    fn test_cache_registration() {
        let breakdown = monitor_with_cache("test_cache").stats().cache_stats;
        assert!(breakdown.contains_key("test_cache"));
    }

    // Unregistering succeeds once, fails the second time, and removes the entry.
    #[test]
    fn test_cache_unregistration() {
        let monitor = monitor_with_cache("test_cache");

        assert!(monitor.unregister_cache("test_cache"));
        assert!(!monitor.unregister_cache("test_cache"));

        let breakdown = monitor.stats().cache_stats;
        assert!(!breakdown.contains_key("test_cache"));
    }
}