use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, OnceLock, PoisonError};
use std::time::{Duration, Instant};

use tokio::sync::RwLock;
use tracing::{debug, info, warn};

use crate::error::{ModelError, Result};
use crate::local::{LocalModel, LocalModelConfig};
static MODEL_CACHE: OnceLock<ModelCache> = OnceLock::new();
/// One resident model plus the bookkeeping used for LRU eviction and stats.
struct CachedModel {
    /// Shared handle to the loaded model; clones of it are handed to callers.
    model: Arc<RwLock<LocalModel>>,
    /// Number of times this entry has been returned, counting the initial load.
    access_count: usize,
    /// When the model was inserted into the cache.
    loaded_at: Instant,
    /// Last time the entry was returned; drives LRU eviction and idle cleanup.
    last_accessed: Instant,
}
pub struct ModelCache {
cache: Mutex<HashMap<PathBuf, CachedModel>>,
max_cached_models: usize,
max_idle_duration: Duration,
enabled: Mutex<bool>,
}
impl ModelCache {
pub fn new() -> Self {
Self {
cache: Mutex::new(HashMap::new()),
max_cached_models: 3, max_idle_duration: Duration::from_secs(3600), enabled: Mutex::new(true),
}
}
pub fn with_config(max_cached_models: usize, max_idle_duration: Duration) -> Self {
Self {
cache: Mutex::new(HashMap::new()),
max_cached_models,
max_idle_duration,
enabled: Mutex::new(true),
}
}
pub fn set_enabled(&self, enabled: bool) {
*self.enabled.lock().unwrap() = enabled;
info!("Model caching {}", if enabled { "enabled" } else { "disabled" });
}
pub fn is_enabled(&self) -> bool {
*self.enabled.lock().unwrap()
}
pub async fn get_or_load(&self, config: LocalModelConfig) -> Result<Arc<RwLock<LocalModel>>> {
if !self.is_enabled() {
debug!("Caching disabled, loading model directly");
let model = LocalModel::load(config).await?;
return Ok(Arc::new(RwLock::new(model)));
}
let model_path = config.model_path.clone();
{
let mut cache = self.cache.lock().unwrap();
if let Some(cached) = cache.get_mut(&model_path) {
cached.last_accessed = Instant::now();
cached.access_count += 1;
debug!(
"Cache hit for model '{}' (access #{})",
model_path.display(),
cached.access_count
);
return Ok(cached.model.clone());
}
}
debug!("Cache miss for model '{}', loading...", model_path.display());
let model = LocalModel::load(config).await?;
let model_arc = Arc::new(RwLock::new(model));
{
let mut cache = self.cache.lock().unwrap();
let cached = CachedModel {
model: model_arc.clone(),
last_accessed: Instant::now(),
loaded_at: Instant::now(),
access_count: 1,
};
self.evict_if_needed(&mut cache);
cache.insert(model_path, cached);
info!(
"Loaded and cached model ({} models in cache)",
cache.len()
);
}
Ok(model_arc)
}
pub fn get_cached(&self, model_path: &PathBuf) -> Option<Arc<RwLock<LocalModel>>> {
if !self.is_enabled() {
return None;
}
let mut cache = self.cache.lock().unwrap();
if let Some(cached) = cache.get_mut(model_path) {
cached.last_accessed = Instant::now();
cached.access_count += 1;
debug!(
"Cache hit for model '{}' (access #{})",
model_path.display(),
cached.access_count
);
return Some(cached.model.clone());
}
None
}
pub async fn preload(&self, config: LocalModelConfig) -> Result<Arc<RwLock<LocalModel>>> {
info!("Preloading model '{}'", config.model_path.display());
self.get_or_load(config).await
}
pub fn evict(&self, model_path: &PathBuf) {
let mut cache = self.cache.lock().unwrap();
if cache.remove(model_path).is_some() {
info!("Evicted model '{}' from cache", model_path.display());
}
}
pub fn clear(&self) {
let mut cache = self.cache.lock().unwrap();
let count = cache.len();
cache.clear();
info!("Cleared all {} cached model(s)", count);
}
pub fn stats(&self) -> CacheStats {
let cache = self.cache.lock().unwrap();
let now = Instant::now();
let models: Vec<_> = cache.iter().map(|(path, cached)| {
CacheModelInfo {
path: path.clone(),
access_count: cached.access_count,
last_accessed: now.duration_since(cached.last_accessed),
loaded_at: now.duration_since(cached.loaded_at),
}
}).collect();
CacheStats {
cached_models: cache.len(),
max_cached_models: self.max_cached_models,
enabled: self.is_enabled(),
models,
}
}
pub fn cleanup_idle(&self) {
let mut cache = self.cache.lock().unwrap();
let now = Instant::now();
let idle_models: Vec<PathBuf> = cache
.iter()
.filter(|(_, cached)| {
now.duration_since(cached.last_accessed) > self.max_idle_duration
})
.map(|(path, _)| path.clone())
.collect();
for path in idle_models {
cache.remove(&path);
info!("Removed idle model '{}' from cache", path.display());
}
}
fn evict_if_needed(&self, cache: &mut HashMap<PathBuf, CachedModel>) {
if cache.len() >= self.max_cached_models {
if let Some((lru_path, _)) = cache
.iter()
.min_by_key(|(_, cached)| cached.last_accessed)
{
let path = lru_path.clone();
cache.remove(&path);
warn!(
"Evicted LRU model '{}' from cache (capacity: {})",
path.display(),
self.max_cached_models
);
}
}
}
}
/// Point-in-time snapshot of a model cache, as returned by `ModelCache::stats`.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Number of models currently resident.
    pub cached_models: usize,
    /// Configured capacity of the cache.
    pub max_cached_models: usize,
    /// Whether caching was enabled when the snapshot was taken.
    pub enabled: bool,
    /// Per-model details, in no particular order.
    pub models: Vec<CacheModelInfo>,
}

/// Per-model entry inside [`CacheStats`].
#[derive(Clone, Debug)]
pub struct CacheModelInfo {
    /// Path the model was loaded from (the cache key).
    pub path: PathBuf,
    /// How many times the entry has been handed out, counting its initial load.
    pub access_count: usize,
    /// Time elapsed since the entry was last accessed (an age, not a timestamp).
    pub last_accessed: Duration,
    /// Time elapsed since the model was loaded (an age, not a timestamp).
    pub loaded_at: Duration,
}
/// Returns the process-wide model cache, creating it with `ModelCache::new` on first use.
pub fn global_model_cache() -> &'static ModelCache {
    MODEL_CACHE.get_or_init(ModelCache::new)
}
/// Loads the model described by `config` through the global cache, reusing a
/// resident copy when one exists.
///
/// # Errors
/// Propagates any error from loading the model.
pub async fn get_or_load_model(config: LocalModelConfig) -> Result<Arc<RwLock<LocalModel>>> {
    let cache = global_model_cache();
    cache.get_or_load(config).await
}
/// Looks up `model_path` in the global cache without loading anything.
///
/// Returns `None` on a miss or when caching is disabled.
pub fn get_cached_model(model_path: &PathBuf) -> Option<Arc<RwLock<LocalModel>>> {
    let cache = global_model_cache();
    cache.get_cached(model_path)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `cache` currently reports no resident models.
    fn assert_no_models(cache: &ModelCache) {
        assert_eq!(cache.stats().cached_models, 0);
    }

    #[test]
    fn test_cache_creation() {
        let c = ModelCache::new();
        assert_eq!(c.max_cached_models, 3);
        assert!(c.is_enabled());
    }

    #[test]
    fn test_cache_configuration() {
        let c = ModelCache::with_config(5, Duration::from_secs(7200));
        assert_eq!(c.max_idle_duration.as_secs(), 7200);
        assert_eq!(c.max_cached_models, 5);
    }

    #[test]
    fn test_enable_disable() {
        let c = ModelCache::new();
        assert!(c.is_enabled());
        for flag in [false, true] {
            c.set_enabled(flag);
            assert_eq!(c.is_enabled(), flag);
        }
    }

    #[test]
    fn test_evict() {
        let c = ModelCache::new();
        c.evict(&PathBuf::from("/test/model"));
        assert_no_models(&c);
    }

    #[test]
    fn test_clear() {
        let c = ModelCache::new();
        c.clear();
        let snapshot = c.stats();
        assert!(snapshot.enabled);
        assert_eq!(snapshot.cached_models, 0);
    }

    #[test]
    fn test_stats() {
        let snapshot = ModelCache::new().stats();
        assert!(snapshot.enabled);
        assert!(snapshot.models.is_empty());
        assert_eq!(snapshot.max_cached_models, 3);
        assert_eq!(snapshot.cached_models, 0);
    }

    #[test]
    fn test_cleanup_idle() {
        let c = ModelCache::with_config(10, Duration::from_secs(1));
        c.cleanup_idle();
        assert_no_models(&c);
    }
}