use crate::error::{RusTorchError, RusTorchResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// Size, count and age limits that govern a [`ModelCache`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Budget, in bytes, for the summed recorded `file_size` of all cached models.
    /// When it is exceeded after a model is cached, least-recently-accessed models
    /// are evicted until the cache fits again.
    pub max_size_bytes: u64,
    /// Maximum number of cached models; exceeding it triggers the same LRU eviction.
    pub max_models: usize,
    /// Models whose `last_accessed` time is older than this many days are treated
    /// as expired by the cleanup pass.
    pub expiration_days: u64,
    /// When `true`, expired models are removed as soon as a cache is opened.
    pub auto_cleanup: bool,
}
impl Default for CacheConfig {
    /// Builds the default policy.
    ///
    /// The defaults are:
    /// - a 10 GiB size budget;
    /// - room for 50 models;
    /// - a 30-day expiry window;
    /// - automatic cleanup of expired models whenever a cache is opened.
    fn default() -> Self {
        const GIB: u64 = 1024 * 1024 * 1024;
        Self {
            auto_cleanup: true,
            expiration_days: 30,
            max_models: 50,
            max_size_bytes: 10 * GIB,
        }
    }
}
/// Metadata recorded for one cached model.
///
/// The cache's whole map of entries is persisted as pretty-printed JSON in its
/// metadata file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
    /// Name the model was cached under (also its key in the cache's entry map).
    pub model_name: String,
    /// Location of the cached model file.
    pub file_path: PathBuf,
    /// File size in bytes, as measured when the model was cached.
    pub file_size: u64,
    /// When the entry was created by [`ModelCache::cache_model`] (UTC).
    pub downloaded_at: chrono::DateTime<chrono::Utc>,
    /// Last time the model was cached or looked up via [`ModelCache::get_model_path`].
    /// Drives both LRU eviction and expiration.
    pub last_accessed: chrono::DateTime<chrono::Utc>,
    /// Lowercase hex SHA-256 digest of the file contents at caching time.
    /// `cache_model` always fills this in; `None` presumably covers metadata
    /// written without a checksum — TODO confirm.
    pub checksum: Option<String>,
}
/// On-disk cache of model files with LRU eviction and age-based expiry.
///
/// Model files are stored under `cache_dir`. An index of [`CacheEntry`] records
/// is kept in memory and mirrored to `cache_metadata.json` in the same directory.
pub struct ModelCache {
    /// Directory holding the cached model files and the metadata file.
    cache_dir: PathBuf,
    /// Limits applied by eviction and expiry.
    config: CacheConfig,
    /// In-memory index of cached models, keyed by model name.
    entries: HashMap<String, CacheEntry>,
    /// JSON file the entries are persisted to (`<cache_dir>/cache_metadata.json`).
    metadata_file: PathBuf,
}
impl ModelCache {
    /// Opens (or creates) a cache rooted at `cache_dir` using [`CacheConfig::default`].
    ///
    /// # Errors
    ///
    /// Same as [`ModelCache::with_config`].
    pub fn new<P: Into<PathBuf>>(cache_dir: P) -> RusTorchResult<Self> {
        Self::with_config(cache_dir, CacheConfig::default())
    }

    /// Opens (or creates) a cache rooted at `cache_dir` with an explicit `config`.
    ///
    /// Opening the cache does the following:
    /// - creates the directory if needed;
    /// - loads `cache_metadata.json`, dropping entries whose files no longer exist;
    /// - if `config.auto_cleanup` is set, removes entries not accessed within
    ///   `config.expiration_days`.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the directory cannot be created;
    /// - the metadata file cannot be read or parsed;
    /// - removing an expired model (or re-saving the metadata) fails.
    pub fn with_config<P: Into<PathBuf>>(
        cache_dir: P,
        config: CacheConfig,
    ) -> RusTorchResult<Self> {
        let cache_dir = cache_dir.into();
        let metadata_file = cache_dir.join("cache_metadata.json");
        std::fs::create_dir_all(&cache_dir)?;
        let mut cache = Self {
            cache_dir,
            config,
            entries: HashMap::new(),
            metadata_file,
        };
        cache.load_metadata()?;
        if cache.config.auto_cleanup {
            cache.cleanup_expired()?;
        }
        Ok(cache)
    }

    /// Returns the path of a cached model and marks it as recently used.
    ///
    /// If the model's file has been deleted behind the cache's back, the stale
    /// entry is dropped and `None` is returned.
    ///
    /// Persisting the new access time is best-effort: a failed metadata write
    /// does not hide a model that is present on disk.
    pub fn get_model_path(&mut self, model_name: &str) -> Option<PathBuf> {
        let entry = self.entries.get_mut(model_name)?;
        if !entry.file_path.exists() {
            self.entries.remove(model_name);
            // Best-effort: the in-memory index is already correct.
            let _ = self.save_metadata();
            return None;
        }
        entry.last_accessed = chrono::Utc::now();
        let path = entry.file_path.clone();
        let _ = self.save_metadata();
        Some(path)
    }

    /// Returns where `model_name` is (or would be) stored: `<cache_dir>/<model_name>.pth`.
    ///
    /// NOTE(review): `model_name` is joined verbatim. A name containing `..` or
    /// path separators resolves outside the cache directory. So does an absolute
    /// name, which replaces `cache_dir` entirely under [`Path::join`]. Validate
    /// untrusted names before calling this.
    pub fn get_download_path(&self, model_name: &str) -> PathBuf {
        self.cache_dir.join(format!("{}.pth", model_name))
    }

    /// Copies `source_path` into the cache as `model_name` and records it.
    ///
    /// No copy is made if `source_path` already *is* the cache file, for example
    /// a download written straight to [`ModelCache::get_download_path`]. This
    /// holds even when the two paths are spelled differently.
    ///
    /// After the entry is recorded, LRU eviction runs to honour the configured
    /// limits. The model being cached is never evicted by that pass, so the
    /// returned path is not deleted by this call.
    ///
    /// # Errors
    ///
    /// Fails if any of these steps fails: the copy, the size lookup, the checksum
    /// computation, eviction, or the metadata write.
    pub fn cache_model<P: AsRef<Path>>(
        &mut self,
        model_name: &str,
        source_path: P,
    ) -> RusTorchResult<PathBuf> {
        let source_path = source_path.as_ref();
        let target_path = self.get_download_path(model_name);
        // `fs::copy` onto the same file can truncate it, so compare what the
        // paths resolve to, not just how they are spelled.
        if !Self::is_same_file(source_path, &target_path) {
            std::fs::copy(source_path, &target_path)?;
        }
        let file_size = std::fs::metadata(&target_path)?.len();
        let checksum = self.calculate_checksum(&target_path)?;
        let now = chrono::Utc::now();
        let entry = CacheEntry {
            model_name: model_name.to_string(),
            file_path: target_path.clone(),
            file_size,
            downloaded_at: now,
            last_accessed: now,
            checksum: Some(checksum),
        };
        self.entries.insert(model_name.to_string(), entry);
        // Evicting the model we were just asked to cache would hand the caller a
        // path to a file that no longer exists.
        self.enforce_cache_limits(model_name)?;
        self.save_metadata()?;
        Ok(target_path)
    }

    /// Deletes a cached model's file and forgets its entry.
    ///
    /// Returns `Ok(false)` if no such model is cached. A file that is already
    /// missing is not an error.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be deleted (the entry is then kept, so memory
    /// and disk stay consistent) or the metadata cannot be saved.
    pub fn remove_model(&mut self, model_name: &str) -> RusTorchResult<bool> {
        let file_path = match self.entries.get(model_name) {
            Some(entry) => entry.file_path.clone(),
            None => return Ok(false),
        };
        match std::fs::remove_file(&file_path) {
            Ok(()) => {}
            // Already gone: nothing to delete, just forget the entry.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => return Err(e.into()),
        }
        self.entries.remove(model_name);
        self.save_metadata()?;
        Ok(true)
    }

    /// Deletes every cached model file (best-effort) and empties the index.
    ///
    /// # Errors
    ///
    /// Fails only if the emptied metadata cannot be saved.
    pub fn clear(&mut self) -> RusTorchResult<()> {
        for entry in self.entries.values() {
            // Best-effort: a file that cannot be deleted is simply forgotten.
            let _ = std::fs::remove_file(&entry.file_path);
        }
        self.entries.clear();
        self.save_metadata()
    }

    /// Returns `(number of cached models, total recorded size in bytes)`.
    pub fn stats(&self) -> (usize, u64) {
        let total_size = self.entries.values().map(|e| e.file_size).sum();
        (self.entries.len(), total_size)
    }

    /// Returns the names of all cached models, in no particular order.
    pub fn list_cached_models(&self) -> Vec<&str> {
        self.entries.keys().map(String::as_str).collect()
    }

    /// Loads persisted entries from the metadata file, if it exists.
    ///
    /// Entries whose model file has disappeared are skipped.
    fn load_metadata(&mut self) -> RusTorchResult<()> {
        if !self.metadata_file.exists() {
            return Ok(());
        }
        let content = std::fs::read_to_string(&self.metadata_file)?;
        let entries: HashMap<String, CacheEntry> = serde_json::from_str(&content)
            .map_err(|e| RusTorchError::DeserializationError(e.to_string()))?;
        self.entries
            .extend(entries.into_iter().filter(|(_, entry)| entry.file_path.exists()));
        Ok(())
    }

    /// Writes the in-memory index to the metadata file as pretty-printed JSON.
    ///
    /// The JSON goes to a sibling temp file that is then renamed into place. A
    /// crash mid-write therefore cannot leave a truncated metadata file, which
    /// would make every later `load_metadata` fail.
    fn save_metadata(&self) -> RusTorchResult<()> {
        let content = serde_json::to_string_pretty(&self.entries)
            .map_err(|e| RusTorchError::SerializationError(e.to_string()))?;
        let tmp_file = self.metadata_file.with_extension("json.tmp");
        std::fs::write(&tmp_file, content)?;
        std::fs::rename(&tmp_file, &self.metadata_file)?;
        Ok(())
    }

    /// Returns the lowercase hex SHA-256 digest of the file at `path`, read in 8 KiB chunks.
    fn calculate_checksum<P: AsRef<Path>>(&self, path: P) -> RusTorchResult<String> {
        use sha2::Digest;
        use std::io::Read;

        let mut file = std::fs::File::open(path)?;
        let mut hasher = sha2::Sha256::new();
        let mut buffer = [0u8; 8192];
        loop {
            let bytes_read = match file.read(&mut buffer) {
                Ok(0) => break,
                Ok(n) => n,
                // `read` may be interrupted by a signal before reading anything; retry.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e.into()),
            };
            hasher.update(&buffer[..bytes_read]);
        }
        Ok(format!("{:x}", hasher.finalize()))
    }

    /// Returns `true` when `a` and `b` are the same path, or resolve to the same
    /// existing file (relative vs absolute spelling, `..` components, symlinks).
    fn is_same_file(a: &Path, b: &Path) -> bool {
        if a == b {
            return true;
        }
        match (std::fs::canonicalize(a), std::fs::canonicalize(b)) {
            (Ok(a), Ok(b)) => a == b,
            _ => false,
        }
    }

    /// Returns the instant before which an entry counts as expired.
    ///
    /// Returns `None` when `expiration_days` is too large to represent; callers
    /// treat that as "never expire".
    fn expiration_threshold(&self) -> Option<chrono::DateTime<chrono::Utc>> {
        // Checked arithmetic instead of `as i64`. A huge `expiration_days` (e.g.
        // `u64::MAX` meant as "never") used to wrap to a negative window that
        // expired everything, or to panic inside `chrono::Duration::days`.
        let seconds = self.config.expiration_days.checked_mul(24 * 60 * 60)?;
        let window = chrono::Duration::from_std(std::time::Duration::from_secs(seconds)).ok()?;
        chrono::Utc::now().checked_sub_signed(window)
    }

    /// Removes every model whose `last_accessed` is older than the expiration window.
    fn cleanup_expired(&mut self) -> RusTorchResult<()> {
        let expiration_threshold = match self.expiration_threshold() {
            Some(threshold) => threshold,
            // The window reaches past the representable time range: nothing is old enough.
            None => return Ok(()),
        };
        let expired_models: Vec<String> = self
            .entries
            .iter()
            .filter(|(_, entry)| entry.last_accessed < expiration_threshold)
            .map(|(name, _)| name.clone())
            .collect();
        for model_name in expired_models {
            println!("Removing expired cached model: {}", model_name);
            self.remove_model(&model_name)?;
        }
        Ok(())
    }

    /// Evicts least-recently-accessed models until both `max_size_bytes` and
    /// `max_models` are satisfied.
    ///
    /// The model named `protected` is never evicted, so the cache may stay over
    /// budget if that model alone exceeds it. Deletion is best-effort: entries
    /// whose files could not be removed are kept. Metadata is saved only when
    /// eviction was needed.
    fn enforce_cache_limits(&mut self, protected: &str) -> RusTorchResult<()> {
        let max_size = self.config.max_size_bytes;
        let max_models = self.config.max_models;
        let mut current_size: u64 = self.entries.values().map(|e| e.file_size).sum();
        let mut current_count = self.entries.len();
        if current_size <= max_size && current_count <= max_models {
            return Ok(());
        }

        let mut eviction_order: Vec<(&String, &CacheEntry)> = self
            .entries
            .iter()
            .filter(|(name, _)| name.as_str() != protected)
            .collect();
        eviction_order.sort_by_key(|(_, entry)| entry.last_accessed);

        for (model_name, entry) in eviction_order {
            if current_size <= max_size && current_count <= max_models {
                break;
            }
            println!("Removing LRU cached model: {}", model_name);
            if entry.file_path.exists() {
                let _ = std::fs::remove_file(&entry.file_path);
            }
            current_size = current_size.saturating_sub(entry.file_size);
            current_count -= 1;
        }
        // Forget exactly the entries whose files are now gone.
        self.entries.retain(|_, entry| entry.file_path.exists());
        self.save_metadata()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `ModelCache`, each run against a throwaway temp directory.
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.max_size_bytes, 10 * 1024 * 1024 * 1024);
        assert_eq!(config.max_models, 50);
        assert_eq!(config.expiration_days, 30);
        assert!(config.auto_cleanup);
    }

    #[test]
    fn test_cache_creation() {
        let temp_dir = TempDir::new().unwrap();
        let cache = ModelCache::new(temp_dir.path());
        assert!(cache.is_ok());
        let cache = cache.unwrap();
        assert_eq!(cache.stats(), (0, 0));
    }

    #[test]
    fn test_cache_with_custom_config() {
        let temp_dir = TempDir::new().unwrap();
        let config = CacheConfig {
            max_size_bytes: 1024 * 1024,
            max_models: 5,
            expiration_days: 7,
            auto_cleanup: false,
        };
        let cache = ModelCache::with_config(temp_dir.path(), config);
        assert!(cache.is_ok());
        let cache = cache.unwrap();
        assert_eq!(cache.config.max_size_bytes, 1024 * 1024);
        assert_eq!(cache.config.max_models, 5);
        assert_eq!(cache.config.expiration_days, 7);
        assert!(!cache.config.auto_cleanup);
    }

    #[test]
    fn test_cache_model() {
        let temp_dir = TempDir::new().unwrap();
        let mut cache = ModelCache::new(temp_dir.path()).unwrap();
        let test_file = temp_dir.path().join("test_model.pth");
        let test_data = b"test model data";
        std::fs::write(&test_file, test_data).unwrap();
        let result = cache.cache_model("test_model", &test_file);
        assert!(result.is_ok());
        let cached_path = result.unwrap();
        assert!(cached_path.exists());
        assert_eq!(cache.stats().0, 1);

        let entry = &cache.entries["test_model"];
        assert_eq!(entry.file_size, test_data.len() as u64);
        let checksum = entry.checksum.as_deref().unwrap();
        assert_eq!(checksum.len(), 64);
        assert!(checksum
            .chars()
            .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn test_cache_model_copies_external_source() {
        let cache_dir = TempDir::new().unwrap();
        let source_dir = TempDir::new().unwrap();
        let mut cache = ModelCache::new(cache_dir.path()).unwrap();
        let source = source_dir.path().join("downloaded.bin");
        std::fs::write(&source, b"external weights").unwrap();

        let cached_path = cache.cache_model("external", &source).unwrap();

        assert_eq!(cached_path, cache.get_download_path("external"));
        assert_eq!(std::fs::read(&cached_path).unwrap(), b"external weights");
        // The source is copied, not moved.
        assert!(source.exists());
    }

    #[test]
    fn test_get_model_path() {
        let temp_dir = TempDir::new().unwrap();
        let mut cache = ModelCache::new(temp_dir.path()).unwrap();
        let test_file = temp_dir.path().join("test_model.pth");
        std::fs::write(&test_file, b"test data").unwrap();
        cache.cache_model("test_model", &test_file).unwrap();
        let path = cache.get_model_path("test_model");
        assert!(path.is_some());
        assert!(path.unwrap().exists());
        let no_path = cache.get_model_path("nonexistent");
        assert!(no_path.is_none());
    }

    #[test]
    fn test_get_model_path_drops_stale_entry() {
        let temp_dir = TempDir::new().unwrap();
        let mut cache = ModelCache::new(temp_dir.path()).unwrap();
        let test_file = temp_dir.path().join("test_model.pth");
        std::fs::write(&test_file, b"test data").unwrap();
        cache.cache_model("test_model", &test_file).unwrap();

        std::fs::remove_file(&test_file).unwrap();

        assert!(cache.get_model_path("test_model").is_none());
        assert_eq!(cache.stats(), (0, 0));
    }

    #[test]
    fn test_remove_model() {
        let temp_dir = TempDir::new().unwrap();
        let mut cache = ModelCache::new(temp_dir.path()).unwrap();
        let test_file = temp_dir.path().join("test_model.pth");
        std::fs::write(&test_file, b"test data").unwrap();
        cache.cache_model("test_model", &test_file).unwrap();
        assert_eq!(cache.stats().0, 1);
        let removed = cache.remove_model("test_model").unwrap();
        assert!(removed);
        assert_eq!(cache.stats().0, 0);
        assert!(!test_file.exists());
        let not_removed = cache.remove_model("nonexistent").unwrap();
        assert!(!not_removed);
    }

    #[test]
    fn test_list_cached_models() {
        let temp_dir = TempDir::new().unwrap();
        let mut cache = ModelCache::new(temp_dir.path()).unwrap();
        assert!(cache.list_cached_models().is_empty());
        for i in 0..3 {
            let test_file = temp_dir.path().join(format!("model_{}.pth", i));
            std::fs::write(&test_file, b"test data").unwrap();
            cache
                .cache_model(&format!("model_{}", i), &test_file)
                .unwrap();
        }
        let models = cache.list_cached_models();
        assert_eq!(models.len(), 3);
        assert!(models.contains(&"model_0"));
        assert!(models.contains(&"model_1"));
        assert!(models.contains(&"model_2"));
    }

    #[test]
    fn test_clear_cache() {
        let temp_dir = TempDir::new().unwrap();
        let mut cache = ModelCache::new(temp_dir.path()).unwrap();
        for i in 0..3 {
            let test_file = temp_dir.path().join(format!("model_{}.pth", i));
            std::fs::write(&test_file, b"test data").unwrap();
            cache
                .cache_model(&format!("model_{}", i), &test_file)
                .unwrap();
        }
        assert_eq!(cache.stats().0, 3);
        cache.clear().unwrap();
        assert_eq!(cache.stats().0, 0);
        for i in 0..3 {
            assert!(!temp_dir.path().join(format!("model_{}.pth", i)).exists());
        }
    }

    #[test]
    fn test_lru_eviction_respects_max_models() {
        let temp_dir = TempDir::new().unwrap();
        let config = CacheConfig {
            max_models: 2,
            auto_cleanup: false,
            ..CacheConfig::default()
        };
        let mut cache = ModelCache::with_config(temp_dir.path(), config).unwrap();
        for i in 0..2 {
            let test_file = temp_dir.path().join(format!("model_{}.pth", i));
            std::fs::write(&test_file, b"test data").unwrap();
            cache
                .cache_model(&format!("model_{}", i), &test_file)
                .unwrap();
        }
        // Make model_0 unambiguously the least recently used, independent of clock resolution.
        cache.entries.get_mut("model_0").unwrap().last_accessed =
            chrono::Utc::now() - chrono::Duration::hours(1);

        let test_file = temp_dir.path().join("model_2.pth");
        std::fs::write(&test_file, b"test data").unwrap();
        cache.cache_model("model_2", &test_file).unwrap();

        let models = cache.list_cached_models();
        assert_eq!(models.len(), 2);
        assert!(!models.contains(&"model_0"));
        assert!(models.contains(&"model_1"));
        assert!(models.contains(&"model_2"));
        assert!(!temp_dir.path().join("model_0.pth").exists());
    }

    #[test]
    fn test_cache_persistence() {
        let temp_dir = TempDir::new().unwrap();
        {
            let mut cache = ModelCache::new(temp_dir.path()).unwrap();
            let test_file = temp_dir.path().join("test_model.pth");
            std::fs::write(&test_file, b"test data").unwrap();
            cache.cache_model("test_model", &test_file).unwrap();
        }
        {
            let cache = ModelCache::new(temp_dir.path()).unwrap();
            assert_eq!(cache.stats().0, 1);
            assert!(cache.list_cached_models().contains(&"test_model"));
        }
    }
}