use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use tokio::time::{Duration as TokioDuration, interval};
use tracing::{debug, error, info, warn};
use std::path::PathBuf;
use std::sync::Arc;
use crate::registry::backend::ModelRecord;
use crate::registry::state::RegistryManager;
use modelexpress_common::config::DurationConfig;
use modelexpress_common::download::get_provider;
use modelexpress_common::models::ModelStatus;
/// Top-level configuration for the cache eviction service.
///
/// Durations deserialize either from strings such as `"2h"` / `"7d"` or from integer
/// seconds (both forms are exercised in this module's tests).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEvictionConfig {
    /// When `false`, `CacheEvictionService::start` logs and returns without running cycles.
    pub enabled: bool,
    /// Policy used to choose which models to evict.
    pub policy: EvictionPolicyType,
    /// Time between eviction cycles.
    pub check_interval: DurationConfig,
}
impl Default for CacheEvictionConfig {
    /// Eviction enabled, LRU policy with its own defaults, one cycle per hour.
    fn default() -> Self {
        let policy = EvictionPolicyType::Lru(Default::default());
        Self {
            check_interval: DurationConfig::hours(1),
            enabled: true,
            policy,
        }
    }
}
/// Available eviction policies.
///
/// Internally tagged for serde: the variant name goes in a lowercase `"type"` field and the
/// variant's config fields sit alongside it, e.g.
/// `{"type": "lru", "unused_threshold": "7d"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum EvictionPolicyType {
    /// Least-recently-used eviction.
    Lru(LruConfig),
}
/// Settings for the least-recently-used eviction policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LruConfig {
    /// Downloaded models whose `last_used_at` is older than `now - unused_threshold` are evicted.
    pub unused_threshold: DurationConfig,
    /// Optional cap on the number of downloaded models; the least recently used ones beyond
    /// the cap are evicted.
    pub max_models: Option<u32>,
    /// NOTE(review): presumably a minimum free-disk-space target in bytes. Disk-space
    /// inspection is still a stub in this module, so this setting is not enforced.
    pub min_free_space_bytes: Option<u64>,
}
impl Default for LruConfig {
    /// Evict models unused for seven days; no model-count or free-space limits.
    fn default() -> Self {
        let seven_days = Duration::days(7);
        Self {
            unused_threshold: DurationConfig::new(seven_days),
            max_models: None,
            min_free_space_bytes: None,
        }
    }
}
/// Outcome of an eviction cycle or a manual eviction request.
#[derive(Debug, Clone)]
pub struct EvictionResult {
    /// Number of models successfully evicted (the length of `evicted_models` wherever this
    /// module builds a result).
    pub evicted_count: u32,
    /// Names of the models that were successfully evicted.
    pub evicted_models: Vec<String>,
    /// Bytes reclaimed, if known. Every constructor in this module currently sets `None`.
    pub bytes_freed: Option<u64>,
    /// Why the eviction happened.
    pub reason: EvictionReason,
}
/// Reason attached to an [`EvictionResult`].
#[derive(Debug, Clone)]
pub enum EvictionReason {
    /// Automatic eviction cycle. This module reports it for every automatic cycle, including
    /// cycles whose selections came from the count limit.
    TimeThreshold,
    /// Intended for count-limit evictions (not constructed anywhere in this module).
    CountLimit,
    /// Intended for disk-space evictions (not constructed anywhere in this module).
    DiskSpace,
    /// Eviction explicitly requested via `CacheEvictionService::manual_evict`.
    Manual,
}
/// Strategy that decides which models an eviction cycle should remove.
#[async_trait::async_trait]
pub trait EvictionPolicyTrait {
    /// Returns the names of the entries in `models` that should be evicted under `config`.
    async fn select_for_eviction(
        &self,
        models: &[ModelRecord],
        config: &CacheEvictionConfig,
    ) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>>;
}
/// Stateless least-recently-used eviction policy, driven by [`LruConfig`].
pub struct LruEvictionPolicy;
impl LruEvictionPolicy {
    /// Returns `true` when `model` has not been used within `threshold` of the current time.
    ///
    /// If `now - threshold` cannot be represented (the threshold reaches back past the
    /// earliest representable date), no `last_used_at` can be older than the cutoff, so the
    /// model is reported as not expired. Falling back to `now` as the cutoff would instead
    /// make an effectively infinite threshold expire every model.
    fn is_time_expired(model: &ModelRecord, threshold: &DurationConfig) -> bool {
        let threshold_duration = threshold.as_chrono_duration();
        match Utc::now().checked_sub_signed(threshold_duration) {
            Some(cutoff_time) => model.last_used_at < cutoff_time,
            None => false,
        }
    }

    /// Placeholder for disk-space inspection; always returns `None` for now.
    ///
    /// NOTE(review): presumably `(total_bytes, free_bytes)` once implemented — confirm.
    async fn get_disk_space_info() -> Option<(u64, u64)> {
        None
    }
}
#[async_trait::async_trait]
impl EvictionPolicyTrait for LruEvictionPolicy {
async fn select_for_eviction(
&self,
models: &[ModelRecord],
config: &CacheEvictionConfig,
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
let EvictionPolicyType::Lru(lru_config) = &config.policy;
let mut candidates_for_eviction = Vec::new();
let downloaded_models: Vec<&ModelRecord> = models
.iter()
.filter(|model| model.status == ModelStatus::DOWNLOADED)
.collect();
debug!(
"Evaluating {downloaded_count} downloaded models for eviction",
downloaded_count = downloaded_models.len()
);
for model in &downloaded_models {
if Self::is_time_expired(model, &lru_config.unused_threshold) {
debug!(
"Model '{model_name}' is expired (last used: {last_used_at})",
model_name = model.model_name,
last_used_at = model.last_used_at
);
candidates_for_eviction.push(model.model_name.clone());
}
}
if let Some(max_models) = lru_config.max_models {
let models_to_remove_by_count =
downloaded_models.len().saturating_sub(max_models as usize);
if models_to_remove_by_count > 0 {
debug!(
"Need to remove {models_to_remove_by_count} models due to count limit (have: {downloaded_count}, max: {max_models})",
models_to_remove_by_count = models_to_remove_by_count,
downloaded_count = downloaded_models.len(),
max_models = max_models
);
let mut sorted_models = downloaded_models.clone();
sorted_models.sort_by_key(|model| model.last_used_at);
for model in sorted_models.iter().take(models_to_remove_by_count) {
if !candidates_for_eviction.contains(&model.model_name) {
candidates_for_eviction.push(model.model_name.clone());
}
}
}
}
if let Some(_min_free_space) = lru_config.min_free_space_bytes
&& let Some((_total_space, _free_space)) = Self::get_disk_space_info().await
{
debug!("Disk space checking is not yet implemented");
}
debug!(
"Selected {evicted_count} models for eviction: {candidates:?}",
evicted_count = candidates_for_eviction.len(),
candidates = candidates_for_eviction
);
Ok(candidates_for_eviction)
}
}
/// Background service that periodically evicts cached models according to its config.
pub struct CacheEvictionService {
    /// Registry used to list model records and to delete the records of evicted models.
    registry: Arc<RegistryManager>,
    /// Eviction settings (enable flag, policy, check interval).
    config: CacheEvictionConfig,
    /// Directory passed to the model provider when deleting a model's files.
    cache_directory: PathBuf,
}
impl CacheEvictionService {
pub fn new(
registry: Arc<RegistryManager>,
config: CacheEvictionConfig,
cache_directory: PathBuf,
) -> Self {
Self {
registry,
config,
cache_directory,
}
}
pub async fn start(
self,
mut shutdown_receiver: tokio::sync::oneshot::Receiver<()>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if !self.config.enabled {
info!("Cache eviction service is disabled");
return Ok(());
}
info!(
"Starting cache eviction service with policy: {policy:?}, check interval: {interval}s",
policy = self.config.policy,
interval = self.config.check_interval.num_seconds()
);
let mut interval_timer = interval(TokioDuration::from_secs(
self.config.check_interval.num_seconds() as u64,
));
loop {
tokio::select! {
_ = interval_timer.tick() => {
if let Err(e) = self.run_eviction_cycle().await {
error!("Error during cache eviction cycle: {e}", e = e);
}
}
_ = &mut shutdown_receiver => {
info!("Cache eviction service received shutdown signal");
break;
}
}
}
info!("Cache eviction service stopped");
Ok(())
}
async fn run_eviction_cycle(
&self,
) -> Result<EvictionResult, Box<dyn std::error::Error + Send + Sync>> {
debug!("Starting cache eviction cycle");
let models = self.registry.get_models_by_last_used(None).await?;
debug!(
"Found {total_models} total models in database",
total_models = models.len()
);
let models_to_evict = match &self.config.policy {
EvictionPolicyType::Lru(_) => {
let lru_policy = LruEvictionPolicy;
lru_policy
.select_for_eviction(&models, &self.config)
.await?
}
};
let evicted_count = models_to_evict.len() as u32;
if evicted_count == 0 {
debug!("No models selected for eviction");
return Ok(EvictionResult {
evicted_count: 0,
evicted_models: Vec::new(),
bytes_freed: None,
reason: EvictionReason::TimeThreshold,
});
}
info!(
"Evicting {evicted_count} models: {models:?}",
evicted_count = evicted_count,
models = models_to_evict
);
let mut successfully_evicted = Vec::new();
for model_name in &models_to_evict {
match self.evict_model(model_name).await {
Ok(()) => {
successfully_evicted.push(model_name.clone());
info!(
"Successfully evicted model: {model_name}",
model_name = model_name
);
}
Err(e) => {
warn!(
"Failed to evict model '{model_name}': {e}",
model_name = model_name,
e = e
);
}
}
}
let result = EvictionResult {
evicted_count: successfully_evicted.len() as u32,
evicted_models: successfully_evicted,
bytes_freed: None, reason: EvictionReason::TimeThreshold,
};
if result.evicted_count > 0 {
info!(
"Cache eviction cycle completed: {evicted_count} models evicted",
evicted_count = result.evicted_count
);
} else {
debug!("Cache eviction cycle completed: no models evicted");
}
Ok(result)
}
async fn evict_model(
&self,
model_name: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let record = self
.registry
.get_model_record(model_name)
.await?
.ok_or_else(|| format!("model '{model_name}' not found in registry"))?;
let provider = get_provider(record.provider);
provider
.delete_model(model_name, self.cache_directory.clone())
.await
.map_err(|e| format!("failed to delete model files for '{model_name}': {e}"))?;
self.registry.delete_model(model_name).await?;
Ok(())
}
pub async fn manual_evict(
&self,
model_names: &[String],
) -> Result<EvictionResult, Box<dyn std::error::Error + Send + Sync>> {
info!(
"Manual eviction requested for models: {models:?}",
models = model_names
);
let mut successfully_evicted = Vec::new();
for model_name in model_names {
match self.evict_model(model_name).await {
Ok(()) => {
successfully_evicted.push(model_name.clone());
info!(
"Successfully evicted model: {model_name}",
model_name = model_name
);
}
Err(e) => {
warn!(
"Failed to evict model '{model_name}': {e}",
model_name = model_name,
e = e
);
}
}
}
Ok(EvictionResult {
evicted_count: successfully_evicted.len() as u32,
evicted_models: successfully_evicted,
bytes_freed: None,
reason: EvictionReason::Manual,
})
}
pub async fn get_cache_stats(
&self,
) -> Result<CacheStats, Box<dyn std::error::Error + Send + Sync>> {
let models = self.registry.get_models_by_last_used(None).await?;
let (downloading, downloaded, error) = self.registry.get_status_counts().await?;
let _now = Utc::now();
let mut oldest_model: Option<DateTime<Utc>> = None;
let mut newest_model: Option<DateTime<Utc>> = None;
for model in &models {
if model.status == ModelStatus::DOWNLOADED {
if oldest_model.is_none_or(|oldest| model.last_used_at < oldest) {
oldest_model = Some(model.last_used_at);
}
if newest_model.is_none_or(|newest| model.last_used_at > newest) {
newest_model = Some(model.last_used_at);
}
}
}
Ok(CacheStats {
total_models: models.len() as u32,
downloading_models: downloading,
downloaded_models: downloaded,
error_models: error,
oldest_model_last_used: oldest_model,
newest_model_last_used: newest_model,
})
}
}
/// Snapshot of cache contents returned by `CacheEvictionService::get_cache_stats`.
#[derive(Debug, Clone, Serialize)]
pub struct CacheStats {
    /// Number of model records returned by the registry (all statuses).
    pub total_models: u32,
    /// Count of models currently downloading, as reported by the registry.
    pub downloading_models: u32,
    /// Count of downloaded models, as reported by the registry.
    pub downloaded_models: u32,
    /// Count of models in an error state, as reported by the registry.
    pub error_models: u32,
    /// Earliest `last_used_at` among `DOWNLOADED` models, or `None` if there are none.
    pub oldest_model_last_used: Option<DateTime<Utc>>,
    /// Latest `last_used_at` among `DOWNLOADED` models, or `None` if there are none.
    pub newest_model_last_used: Option<DateTime<Utc>>,
}
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    // Unit tests for config defaults and parsing, the LRU selection policy, and the
    // service's use of the registry (through a mocked backend).
    use super::*;
    use crate::registry::backend::MockRegistryBackend;
    use modelexpress_common::models::ModelProvider;
    use tempfile::TempDir;

    // Builds a service over `mock`, rooted at a fresh temporary cache directory. The TempDir
    // is returned alongside the service so the caller keeps it alive for the whole test.
    fn service_with_mock(
        mock: MockRegistryBackend,
        config: CacheEvictionConfig,
    ) -> (CacheEvictionService, TempDir) {
        let registry = Arc::new(RegistryManager::with_backend(Arc::new(mock)));
        let cache_dir = TempDir::new().expect("Failed to create cache directory");
        let service = CacheEvictionService::new(registry, config, cache_dir.path().to_path_buf());
        (service, cache_dir)
    }

    // Default config: enabled, hourly checks, LRU policy.
    #[test]
    fn test_default_config() {
        let config = CacheEvictionConfig::default();
        assert!(config.enabled);
        assert_eq!(config.check_interval.num_seconds(), 3600);
        assert!(matches!(config.policy, EvictionPolicyType::Lru(_)));
    }

    // Default LRU config: 7-day unused threshold, no count or free-space limits.
    #[test]
    fn test_lru_config_defaults() {
        let lru_config = LruConfig::default();
        assert_eq!(lru_config.unused_threshold.num_seconds(), 7 * 24 * 3600);
        assert!(lru_config.max_models.is_none());
        assert!(lru_config.min_free_space_bytes.is_none());
    }

    // Durations deserialize from human-readable strings or from integer seconds.
    #[test]
    fn test_duration_config_parsing() {
        use modelexpress_common::config::parse_duration_string;
        // String form ("7d", "2h").
        let json = r#"{"enabled": true, "policy": {"type": "lru", "unused_threshold": "7d"}, "check_interval": "2h"}"#;
        let config: CacheEvictionConfig =
            serde_json::from_str(json).expect("Failed to parse config");
        assert_eq!(config.check_interval.num_seconds(), 2 * 3600);
        // Integer-seconds form.
        let json = r#"{"enabled": true, "policy": {"type": "lru", "unused_threshold": 604800}, "check_interval": 1800}"#;
        let config: CacheEvictionConfig =
            serde_json::from_str(json).expect("Failed to parse config");
        assert_eq!(config.check_interval.num_seconds(), 1800);
        // The parser accepts single units and compound values such as "2h30m".
        assert_eq!(
            parse_duration_string("30m")
                .expect("Failed to parse 30m")
                .num_seconds(),
            30 * 60
        );
        assert_eq!(
            parse_duration_string("45s")
                .expect("Failed to parse 45s")
                .num_seconds(),
            45
        );
        assert_eq!(
            parse_duration_string("1d")
                .expect("Failed to parse 1d")
                .num_seconds(),
            24 * 3600
        );
        assert_eq!(
            parse_duration_string("2h30m")
                .expect("Failed to parse 2h30m")
                .num_seconds(),
            2 * 3600 + 30 * 60
        );
    }

    // With a 7-day threshold, a model last used 8 days ago is expired and one used 5 days
    // ago is not.
    #[test]
    fn test_is_time_expired() {
        let now = Utc::now();
        let old_model = ModelRecord {
            model_name: "old-model".to_string(),
            provider: ModelProvider::HuggingFace,
            status: ModelStatus::DOWNLOADED,
            created_at: now - Duration::days(10),
            last_used_at: now - Duration::days(8),
            message: None,
        };
        let recent_model = ModelRecord {
            model_name: "recent-model".to_string(),
            provider: ModelProvider::HuggingFace,
            status: ModelStatus::DOWNLOADED,
            created_at: now - Duration::days(6),
            last_used_at: now - Duration::days(5),
            message: None,
        };
        let threshold = DurationConfig::new(Duration::days(7));
        assert!(LruEvictionPolicy::is_time_expired(&old_model, &threshold));
        assert!(!LruEvictionPolicy::is_time_expired(
            &recent_model,
            &threshold
        ));
    }

    // Time-based selection evicts only the stale DOWNLOADED model. The equally stale
    // DOWNLOADING model is ignored.
    #[tokio::test]
    async fn test_lru_eviction_policy_time_based() {
        let now = Utc::now();
        let models = vec![
            ModelRecord {
                model_name: "old-model".to_string(),
                provider: ModelProvider::HuggingFace,
                status: ModelStatus::DOWNLOADED,
                created_at: now - Duration::days(10),
                last_used_at: now - Duration::days(8),
                message: None,
            },
            ModelRecord {
                model_name: "recent-model".to_string(),
                provider: ModelProvider::HuggingFace,
                status: ModelStatus::DOWNLOADED,
                created_at: now - Duration::days(6),
                last_used_at: now - Duration::days(5),
                message: None,
            },
            ModelRecord {
                model_name: "downloading-model".to_string(),
                provider: ModelProvider::HuggingFace,
                status: ModelStatus::DOWNLOADING,
                created_at: now - Duration::days(10),
                last_used_at: now - Duration::days(8),
                message: None,
            },
        ];
        let config = CacheEvictionConfig {
            enabled: true,
            policy: EvictionPolicyType::Lru(LruConfig {
                unused_threshold: DurationConfig::new(Duration::days(7)),
                max_models: None,
                min_free_space_bytes: None,
            }),
            check_interval: DurationConfig::hours(1),
        };
        let policy = LruEvictionPolicy;
        let evicted = policy
            .select_for_eviction(&models, &config)
            .await
            .expect("Failed to select models for eviction");
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0], "old-model");
    }

    // Count-based selection: with max_models = 2 and three fresh models, the least recently
    // used one ("model1") is evicted.
    #[tokio::test]
    async fn test_lru_eviction_policy_count_based() {
        let now = Utc::now();
        let models = vec![
            ModelRecord {
                model_name: "model1".to_string(),
                provider: ModelProvider::HuggingFace,
                status: ModelStatus::DOWNLOADED,
                created_at: now - Duration::days(3),
                last_used_at: now - Duration::days(3),
                message: None,
            },
            ModelRecord {
                model_name: "model2".to_string(),
                provider: ModelProvider::HuggingFace,
                status: ModelStatus::DOWNLOADED,
                created_at: now - Duration::days(2),
                last_used_at: now - Duration::days(2),
                message: None,
            },
            ModelRecord {
                model_name: "model3".to_string(),
                provider: ModelProvider::HuggingFace,
                status: ModelStatus::DOWNLOADED,
                created_at: now - Duration::days(1),
                last_used_at: now - Duration::days(1),
                message: None,
            },
        ];
        let config = CacheEvictionConfig {
            enabled: true,
            policy: EvictionPolicyType::Lru(LruConfig {
                // Threshold is long enough that no model is time-expired.
                unused_threshold: DurationConfig::new(Duration::days(30)),
                max_models: Some(2),
                min_free_space_bytes: None,
            }),
            check_interval: DurationConfig::hours(1),
        };
        let policy = LruEvictionPolicy;
        let evicted = policy
            .select_for_eviction(&models, &config)
            .await
            .expect("Failed to select models for eviction");
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0], "model1");
    }

    // Constructing the service keeps the supplied config.
    #[tokio::test]
    async fn test_cache_eviction_service_creation() {
        let mock = MockRegistryBackend::new();
        let (service, _cache_dir) = service_with_mock(mock, CacheEvictionConfig::default());
        assert!(service.config.enabled);
    }

    // get_cache_stats makes exactly one listing call and one status-count call, and forwards
    // the counts.
    #[tokio::test]
    async fn test_get_cache_stats_uses_registry() {
        let now = Utc::now();
        let mut mock = MockRegistryBackend::new();
        mock.expect_get_models_by_last_used()
            .once()
            .returning(move |_| {
                Ok(vec![ModelRecord {
                    model_name: "model1".to_string(),
                    provider: ModelProvider::HuggingFace,
                    status: ModelStatus::DOWNLOADED,
                    created_at: now,
                    last_used_at: now,
                    message: None,
                }])
            });
        mock.expect_get_status_counts()
            .once()
            .returning(|| Ok((1, 1, 1)));
        let (service, _cache_dir) = service_with_mock(mock, CacheEvictionConfig::default());
        let stats = service.get_cache_stats().await.expect("stats");
        assert_eq!(stats.total_models, 1);
        assert_eq!(stats.downloaded_models, 1);
        assert_eq!(stats.downloading_models, 1);
        assert_eq!(stats.error_models, 1);
    }
}