use crate::{
error::{Result, VoirsError},
traits::{CacheStats, ModelCache},
};
use async_trait::async_trait;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
any::Any,
collections::{HashMap, HashSet},
path::{Path, PathBuf},
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use tokio::fs;
use tracing::{debug, info, warn};
#[cfg(feature = "cloud")]
use sha2::{Digest, Sha256};
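/// Two-tier (memory plus optional disk) model cache with LRU eviction,
/// per-model priorities, pinning, TTL expiry, and optional integrity checks.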
pub struct AdvancedModelCache {
memory_cache: Arc<RwLock<HashMap<String, CachedModel>>>,
disk_cache_dir: Option<PathBuf>,
max_memory_bytes: usize,
#[allow(dead_code)]
max_disk_bytes: usize,
current_memory_usage: Arc<RwLock<usize>>,
access_order: Arc<RwLock<Vec<String>>>,
loading_queue: Arc<RwLock<HashSet<String>>>,
config: ModelCacheConfig,
stats: Arc<RwLock<ModelCacheStats>>,
#[allow(dead_code)]
metadata_cache: Arc<RwLock<HashMap<String, ModelMetadata>>>,
#[allow(dead_code)]
dependency_graph: Arc<RwLock<HashMap<String, Vec<String>>>>,
}
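/// Tunable behavior for [`AdvancedModelCache`]; the [`Default`] values give a
/// 1 GiB memory tier, an 8 GiB disk tier, and a 24-hour TTL.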
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCacheConfig {
pub memory_cache_enabled: bool,
pub disk_cache_enabled: bool,
pub memory_cache_size_mb: usize,
pub disk_cache_size_mb: usize,
pub model_ttl_seconds: u64,
pub enable_compression: bool,
pub enable_cache_warming: bool,
pub enable_preloading: bool,
pub cleanup_interval_seconds: u64,
pub max_concurrent_loads: usize,
pub verify_integrity: bool,
pub priority_levels: HashMap<String, ModelPriority>,
}
impl Default for ModelCacheConfig {
fn default() -> Self {
Self {
memory_cache_enabled: true,
disk_cache_enabled: true,
            memory_cache_size_mb: 1024,
            disk_cache_size_mb: 8192,
            model_ttl_seconds: 86400,
            enable_compression: true,
enable_cache_warming: true,
enable_preloading: false,
            cleanup_interval_seconds: 3600,
            max_concurrent_loads: 4,
verify_integrity: true,
priority_levels: HashMap::new(),
}
}
}
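/// Relative importance of a cached model: lower-priority entries are evicted
/// first, and `Critical` entries are evicted last.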
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelPriority {
Critical,
High,
Normal,
Low,
}
impl Default for ModelPriority {
fn default() -> Self {
Self::Normal
}
}
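/// A model resident in the memory tier, together with the bookkeeping needed
/// for LRU ordering, expiry, and integrity verification.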
#[derive(Debug)]
pub struct CachedModel {
pub data: Box<dyn Any + Send + Sync>,
pub metadata: ModelMetadata,
pub size_bytes: usize,
pub cached_at: SystemTime,
pub expires_at: SystemTime,
pub last_accessed: SystemTime,
pub access_count: u64,
pub priority: ModelPriority,
pub pinned: bool,
pub checksum: Option<String>,
}
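/// Descriptive information stored alongside a cached model, covering its
/// provenance, precision, and hardware requirements.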
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
pub name: String,
pub version: String,
pub model_type: ModelType,
pub architecture: String,
pub languages: Vec<String>,
pub parameter_count: Option<u64>,
pub file_path: Option<PathBuf>,
pub dependencies: Vec<String>,
pub created_at: SystemTime,
pub source: ModelSource,
pub config_hash: u64,
pub precision: ModelPrecision,
pub hardware_requirements: HardwareRequirements,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
G2P,
Acoustic,
Vocoder,
LanguageModel,
Encoder,
Decoder,
Enhancement,
Other(u8),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelSource {
Local(PathBuf),
Remote(String),
HuggingFace { repo: String, revision: String },
Builtin(String),
Custom(String),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelPrecision {
FP32,
FP16,
INT8,
INT4,
Mixed,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareRequirements {
pub min_ram_mb: u64,
pub gpu_memory_mb: Option<u64>,
pub cpu_features: Vec<String>,
pub gpu_requirements: Option<GpuRequirements>,
pub min_compute_capability: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuRequirements {
pub min_memory_mb: u64,
pub required_apis: Vec<String>,
pub preferred_vendor: Option<String>,
pub min_compute_capability: Option<String>,
}
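/// Extended statistics for the model cache, layered over the basic
/// [`CacheStats`] shared with other cache implementations.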
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelCacheStats {
pub basic_stats: CacheStats,
pub models_loaded: u64,
pub models_evicted: u64,
    pub load_failures: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
pub warming_time_ms: u64,
pub avg_load_time_ms: f64,
pub memory_fragmentation: f64,
pub disk_usage_bytes: u64,
pub queue_size: usize,
pub hot_models: Vec<String>,
pub cold_models: Vec<String>,
pub priority_distribution: HashMap<ModelPriority, usize>,
}
impl AdvancedModelCache {
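    /// Creates a cache with the given configuration, creating `disk_cache_dir`
    /// when the disk tier is enabled.
    ///
    /// A minimal usage sketch:
    /// ```ignore
    /// let cache = AdvancedModelCache::new(ModelCacheConfig::default(), None)?;
    /// ```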
pub fn new(config: ModelCacheConfig, disk_cache_dir: Option<PathBuf>) -> Result<Self> {
if let Some(ref dir) = disk_cache_dir {
if config.disk_cache_enabled {
std::fs::create_dir_all(dir).map_err(|e| {
VoirsError::cache_error(format!("Failed to create cache directory: {e}"))
})?;
}
}
Ok(Self {
memory_cache: Arc::new(RwLock::new(HashMap::new())),
disk_cache_dir,
max_memory_bytes: config.memory_cache_size_mb * 1024 * 1024,
max_disk_bytes: config.disk_cache_size_mb * 1024 * 1024,
current_memory_usage: Arc::new(RwLock::new(0)),
access_order: Arc::new(RwLock::new(Vec::new())),
loading_queue: Arc::new(RwLock::new(HashSet::new())),
config,
stats: Arc::new(RwLock::new(ModelCacheStats::default())),
metadata_cache: Arc::new(RwLock::new(HashMap::new())),
dependency_graph: Arc::new(RwLock::new(HashMap::new())),
})
}
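    /// Preloads the named models from the disk tier, recording per-model
    /// successes and failures in the cache statistics.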
pub async fn warm_cache(&self, model_names: Vec<String>) -> Result<()> {
let start_time = Instant::now();
info!("Starting cache warming for {} models", model_names.len());
let mut successful_loads = 0;
let mut failed_loads = 0;
for model_name in model_names {
match self.preload_model(&model_name).await {
Ok(_) => {
successful_loads += 1;
debug!("Preloaded model: {}", model_name);
}
Err(e) => {
failed_loads += 1;
warn!("Failed to preload model '{}': {}", model_name, e);
}
}
}
let warming_time = start_time.elapsed().as_millis() as u64;
        {
            let mut stats = self.stats.write();
            stats.warming_time_ms = warming_time;
            // `put_cached_model` already counts each successful load in
            // `models_loaded`; only the failures are recorded here.
            stats.load_failures += failed_loads;
        }
info!(
"Cache warming completed in {}ms: {} successful, {} failed",
warming_time, successful_loads, failed_loads
);
Ok(())
}
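    /// Loads a single model into the memory tier if it is not already cached,
    /// rejecting concurrent loads of the same model via the loading queue.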
pub async fn preload_model(&self, model_name: &str) -> Result<()> {
if self.contains_key(model_name).await {
return Ok(());
}
{
let loading_queue = self.loading_queue.read();
if loading_queue.contains(model_name) {
return Err(VoirsError::cache_error(format!(
"Model '{model_name}' is already being loaded"
)));
}
}
{
let mut loading_queue = self.loading_queue.write();
loading_queue.insert(model_name.to_string());
}
let loaded = if self.config.disk_cache_enabled {
self.load_from_disk_cache(model_name).await.unwrap_or(false)
} else {
false
};
{
let mut loading_queue = self.loading_queue.write();
loading_queue.remove(model_name);
}
if loaded {
info!("Loaded model '{}' from disk cache", model_name);
Ok(())
} else {
Err(VoirsError::cache_error(format!(
"Model '{model_name}' not found in cache"
)))
}
}
pub async fn contains_key(&self, key: &str) -> bool {
let cache = self.memory_cache.read();
cache.contains_key(key)
}
async fn load_from_disk_cache(&self, model_name: &str) -> Result<bool> {
if let Some(ref cache_dir) = self.disk_cache_dir {
let model_path = cache_dir.join(format!("{model_name}.cache"));
let metadata_path = cache_dir.join(format!("{model_name}.meta"));
if model_path.exists() && metadata_path.exists() {
let metadata_content = fs::read_to_string(&metadata_path).await.map_err(|e| {
VoirsError::cache_error(format!("Failed to read metadata: {e}"))
})?;
let metadata: ModelMetadata =
serde_json::from_str(&metadata_content).map_err(|e| {
VoirsError::cache_error(format!("Failed to parse metadata: {e}"))
})?;
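                // Placeholder payload: actual model deserialization is
                // backend-specific and handled outside this cache layer.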
let model_data = format!("Model data for {model_name}");
let checksum = self.calculate_file_checksum(&model_path).await.ok();
                let size_bytes = fs::metadata(&model_path)
                    .await
                    .map_err(|e| {
                        VoirsError::cache_error(format!("Failed to get file metadata: {e}"))
                    })?
                    .len() as usize;
let cached_model = CachedModel {
data: Box::new(model_data),
metadata: metadata.clone(),
size_bytes,
cached_at: SystemTime::now(),
expires_at: SystemTime::now()
+ Duration::from_secs(self.config.model_ttl_seconds),
last_accessed: SystemTime::now(),
access_count: 0,
priority: self
.config
.priority_levels
.get(model_name)
.copied()
.unwrap_or_default(),
pinned: false,
checksum,
};
self.put_cached_model(model_name, cached_model).await?;
return Ok(true);
}
}
Ok(false)
}
async fn put_cached_model(&self, key: &str, model: CachedModel) -> Result<()> {
self.ensure_memory_capacity(model.size_bytes).await?;
{
let mut cache = self.memory_cache.write();
let mut current_usage = self.current_memory_usage.write();
let model_size = model.size_bytes;
cache.insert(key.to_string(), model);
*current_usage += model_size;
}
self.update_access_order(key).await;
{
let mut stats = self.stats.write();
stats.basic_stats.total_entries += 1;
stats.basic_stats.memory_usage_bytes = *self.current_memory_usage.read();
stats.models_loaded += 1;
}
Ok(())
}
async fn ensure_memory_capacity(&self, required_bytes: usize) -> Result<()> {
let current_usage = *self.current_memory_usage.read();
if current_usage + required_bytes > self.max_memory_bytes {
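            // Best-effort: if every resident model is pinned, eviction may free
            // less than requested and the insert will overshoot the budget.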
self.evict_lru_models(required_bytes).await?;
}
Ok(())
}
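    /// Evicts unpinned models, lowest priority and least recently used first,
    /// until at least `required_bytes` have been freed.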
async fn evict_lru_models(&self, required_bytes: usize) -> Result<()> {
let mut freed_bytes = 0;
let mut evicted_count = 0;
let models_to_evict = self.get_eviction_candidates(required_bytes).await;
for model_key in models_to_evict {
if let Some(model) = self.remove_model(&model_key).await? {
freed_bytes += model.size_bytes;
evicted_count += 1;
debug!("Evicted model '{}' ({} bytes)", model_key, model.size_bytes);
if freed_bytes >= required_bytes {
break;
}
}
}
{
let mut stats = self.stats.write();
stats.models_evicted += evicted_count;
}
info!(
"Evicted {} models, freed {} bytes",
evicted_count, freed_bytes
);
Ok(())
}
async fn get_eviction_candidates(&self, _required_bytes: usize) -> Vec<String> {
let cache = self.memory_cache.read();
let access_order = self.access_order.read();
let mut candidates = Vec::new();
let mut sortable_models: Vec<_> = cache
.iter()
            .filter(|(_, model)| !model.pinned)
            .map(|(key, model)| {
let access_index = access_order
.iter()
.position(|k| k == key)
.unwrap_or(usize::MAX);
(key.clone(), model.priority, access_index, model.size_bytes)
})
.collect();
        sortable_models.sort_by(|a, b| {
            use std::cmp::Ordering;
            // Lowest priority first (`Low` sorts before `Critical`), breaking
            // ties by least recently used (a larger access index means an
            // older access, since index 0 is the most recent).
            match (b.1 as u8).cmp(&(a.1 as u8)) {
                Ordering::Equal => b.2.cmp(&a.2),
                other => other,
            }
        });
for (key, _, _, _) in sortable_models {
candidates.push(key);
}
candidates
}
async fn remove_model(&self, key: &str) -> Result<Option<CachedModel>> {
let removed_model = {
let mut cache = self.memory_cache.write();
cache.remove(key)
};
if let Some(ref model) = removed_model {
{
let mut current_usage = self.current_memory_usage.write();
*current_usage = current_usage.saturating_sub(model.size_bytes);
}
{
let mut access_order = self.access_order.write();
access_order.retain(|k| k != key);
}
{
let mut stats = self.stats.write();
stats.basic_stats.total_entries = stats.basic_stats.total_entries.saturating_sub(1);
stats.basic_stats.memory_usage_bytes = *self.current_memory_usage.read();
}
}
Ok(removed_model)
}
async fn update_access_order(&self, key: &str) {
let mut access_order = self.access_order.write();
access_order.retain(|k| k != key);
access_order.insert(0, key.to_string());
}
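    /// Marks a model as pinned, excluding it from LRU eviction and TTL cleanup.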
pub async fn pin_model(&self, key: &str) -> Result<()> {
let mut cache = self.memory_cache.write();
if let Some(model) = cache.get_mut(key) {
model.pinned = true;
info!("Pinned model '{}'", key);
Ok(())
} else {
Err(VoirsError::cache_error(format!(
"Model '{key}' not found in cache"
)))
}
}
pub async fn unpin_model(&self, key: &str) -> Result<()> {
let mut cache = self.memory_cache.write();
if let Some(model) = cache.get_mut(key) {
model.pinned = false;
info!("Unpinned model '{}'", key);
Ok(())
} else {
Err(VoirsError::cache_error(format!(
"Model '{key}' not found in cache"
)))
}
}
pub async fn get_model_metadata(&self, key: &str) -> Option<ModelMetadata> {
let cache = self.memory_cache.read();
cache.get(key).map(|model| model.metadata.clone())
}
pub async fn list_cached_models(&self) -> Vec<String> {
let cache = self.memory_cache.read();
cache.keys().cloned().collect()
}
pub async fn get_usage_summary(&self) -> CacheUsageSummary {
let cache = self.memory_cache.read();
let current_usage = *self.current_memory_usage.read();
let model_count = cache.len();
let total_accesses: u64 = cache.values().map(|m| m.access_count).sum();
let pinned_count = cache.values().filter(|m| m.pinned).count();
CacheUsageSummary {
total_models: model_count,
memory_usage_bytes: current_usage,
memory_usage_mb: current_usage / (1024 * 1024),
memory_utilization: (current_usage as f64 / self.max_memory_bytes as f64) * 100.0,
total_accesses,
pinned_models: pinned_count,
avg_model_size: current_usage.checked_div(model_count).unwrap_or(0),
}
}
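    /// Removes expired, unpinned models and refreshes derived statistics;
    /// intended to run on the configured cleanup interval.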
pub async fn perform_maintenance(&self) -> Result<()> {
info!("Starting cache maintenance");
let expired_count = self.cleanup_expired_models().await?;
self.update_cache_statistics().await;
info!(
"Cache maintenance completed: {} expired models removed",
expired_count
);
Ok(())
}
async fn cleanup_expired_models(&self) -> Result<usize> {
let now = SystemTime::now();
let mut expired_keys = Vec::new();
{
let cache = self.memory_cache.read();
for (key, model) in cache.iter() {
if model.expires_at <= now && !model.pinned {
expired_keys.push(key.clone());
}
}
}
let mut removed_count = 0;
for key in expired_keys {
if self.remove_model(&key).await?.is_some() {
removed_count += 1;
debug!("Removed expired model: {}", key);
}
}
Ok(removed_count)
}
async fn update_cache_statistics(&self) {
let cache = self.memory_cache.read();
let mut stats = self.stats.write();
stats.basic_stats.total_entries = cache.len();
stats.basic_stats.memory_usage_bytes = *self.current_memory_usage.read();
let used_memory = stats.basic_stats.memory_usage_bytes;
let allocated_memory = self.max_memory_bytes;
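        // Fraction of reserved capacity currently unused: a coarse proxy for
        // fragmentation rather than a true fragmentation measurement.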
stats.memory_fragmentation = if allocated_memory > 0 {
1.0 - (used_memory as f64 / allocated_memory as f64)
} else {
0.0
};
stats.priority_distribution.clear();
for model in cache.values() {
*stats
.priority_distribution
.entry(model.priority)
.or_insert(0) += 1;
}
stats.queue_size = self.loading_queue.read().len();
let mut models_by_access: Vec<_> = cache
.iter()
.map(|(key, model)| (key.clone(), model.access_count))
.collect();
models_by_access.sort_by_key(|b| std::cmp::Reverse(b.1));
        // Top quarter by access count is considered "hot"; the bottom quarter
        // is considered "cold".
        let hot_threshold = models_by_access.len() / 4;
        let cold_threshold = models_by_access.len() * 3 / 4;
stats.hot_models = models_by_access
.iter()
.take(hot_threshold)
.map(|(key, _)| key.clone())
.collect();
stats.cold_models = models_by_access
.iter()
.skip(cold_threshold)
.map(|(key, _)| key.clone())
.collect();
}
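    /// Computes a SHA-256 checksum of `file_path`. Requires the `cloud`
    /// feature; without it a placeholder string is returned.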
async fn calculate_file_checksum(&self, file_path: &Path) -> Result<String> {
#[cfg(feature = "cloud")]
{
let data = fs::read(file_path)
.await
.map_err(|e| VoirsError::FileCorrupted {
path: file_path.to_path_buf(),
reason: format!("Failed to read file for checksum calculation: {}", e),
})?;
let mut hasher = Sha256::new();
hasher.update(&data);
let result = hasher.finalize();
Ok(format!("{:x}", result))
}
#[cfg(not(feature = "cloud"))]
{
warn!("Checksum calculation requires 'cloud' feature to be enabled");
Ok("no-checksum-available".to_string())
}
}
#[allow(dead_code)]
async fn verify_checksum(&self, file_path: &Path, expected_checksum: &str) -> Result<bool> {
let calculated = self.calculate_file_checksum(file_path).await?;
Ok(calculated == expected_checksum)
}
}
#[async_trait]
impl ModelCache for AdvancedModelCache {
async fn get_any(&self, key: &str) -> Result<Option<Box<dyn Any + Send + Sync>>> {
        let hit = {
            let cache = self.memory_cache.read();
            cache.contains_key(key)
        };
        if hit {
self.update_access_order(key).await;
{
let mut cache = self.memory_cache.write();
if let Some(model) = cache.get_mut(key) {
model.access_count += 1;
model.last_accessed = SystemTime::now();
}
}
            {
                let mut stats = self.stats.write();
                // Maintain exact running counters; the percentage fields are
                // derived from them on every request.
                stats.cache_hits += 1;
                let total = (stats.cache_hits + stats.cache_misses) as f64;
                stats.basic_stats.hit_rate = (stats.cache_hits as f64 / total) * 100.0;
                stats.basic_stats.miss_rate = 100.0 - stats.basic_stats.hit_rate;
            }
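            // The stored `Box<dyn Any>` cannot be cloned out of the cache, so a
            // hit only refreshes access metadata and statistics here.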
Ok(None)
} else {
            {
                let mut stats = self.stats.write();
                stats.cache_misses += 1;
                let total = (stats.cache_hits + stats.cache_misses) as f64;
                stats.basic_stats.miss_rate = (stats.cache_misses as f64 / total) * 100.0;
                stats.basic_stats.hit_rate = 100.0 - stats.basic_stats.miss_rate;
            }
Ok(None)
}
}
async fn put_any(&self, key: &str, value: Box<dyn Any + Send + Sync>) -> Result<()> {
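        // Minimal default metadata for opaque payloads stored through the
        // type-erased trait API.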
let metadata = ModelMetadata {
name: key.to_string(),
version: "1.0.0".to_string(),
model_type: ModelType::Other(0),
architecture: "unknown".to_string(),
languages: vec!["en".to_string()],
parameter_count: None,
file_path: None,
dependencies: vec![],
created_at: SystemTime::now(),
source: ModelSource::Custom("memory".to_string()),
config_hash: 0,
precision: ModelPrecision::FP32,
hardware_requirements: HardwareRequirements {
min_ram_mb: 0,
gpu_memory_mb: None,
cpu_features: vec![],
gpu_requirements: None,
min_compute_capability: None,
},
};
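        // Rough size estimate: `size_of_val` on the trait object reports only
        // the concrete type's inline size, not heap allocations it owns.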
let estimated_size = std::mem::size_of_val(&*value) + key.len();
let cached_model = CachedModel {
data: value,
metadata,
size_bytes: estimated_size,
cached_at: SystemTime::now(),
expires_at: SystemTime::now() + Duration::from_secs(self.config.model_ttl_seconds),
last_accessed: SystemTime::now(),
access_count: 0,
priority: self
.config
.priority_levels
.get(key)
.copied()
.unwrap_or_default(),
pinned: false,
checksum: None,
};
self.put_cached_model(key, cached_model).await
}
async fn remove(&self, key: &str) -> Result<()> {
self.remove_model(key).await?;
Ok(())
}
async fn clear(&self) -> Result<()> {
{
let mut cache = self.memory_cache.write();
let mut current_usage = self.current_memory_usage.write();
let mut access_order = self.access_order.write();
cache.clear();
*current_usage = 0;
access_order.clear();
}
{
let mut stats = self.stats.write();
stats.basic_stats.total_entries = 0;
stats.basic_stats.memory_usage_bytes = 0;
}
info!("Cleared all models from cache");
Ok(())
}
fn stats(&self) -> CacheStats {
let stats = self.stats.read();
stats.basic_stats
}
}
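/// Point-in-time snapshot of cache occupancy returned by
/// [`AdvancedModelCache::get_usage_summary`].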
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheUsageSummary {
pub total_models: usize,
pub memory_usage_bytes: usize,
pub memory_usage_mb: usize,
pub memory_utilization: f64,
pub total_accesses: u64,
pub pinned_models: usize,
pub avg_model_size: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_advanced_model_cache_creation() {
let config = ModelCacheConfig::default();
let cache = AdvancedModelCache::new(config, None).unwrap();
assert_eq!(cache.list_cached_models().await.len(), 0);
}
#[tokio::test]
async fn test_model_caching_and_retrieval() {
let config = ModelCacheConfig::default();
let cache = AdvancedModelCache::new(config, None).unwrap();
let test_data = "test model data".to_string();
cache
.put_any("test_model", Box::new(test_data))
.await
.unwrap();
assert!(cache.contains_key("test_model").await);
assert_eq!(cache.list_cached_models().await.len(), 1);
}
#[tokio::test]
async fn test_model_pinning() {
let config = ModelCacheConfig::default();
let cache = AdvancedModelCache::new(config, None).unwrap();
let test_data = "pinned model".to_string();
cache
.put_any("pinned_model", Box::new(test_data))
.await
.unwrap();
cache.pin_model("pinned_model").await.unwrap();
let metadata = cache.get_model_metadata("pinned_model").await;
assert!(metadata.is_some());
}
#[tokio::test]
async fn test_cache_maintenance() {
let config = ModelCacheConfig::default();
let cache = AdvancedModelCache::new(config, None).unwrap();
let test_data = "maintenance test".to_string();
cache
.put_any("test_model", Box::new(test_data))
.await
.unwrap();
cache.perform_maintenance().await.unwrap();
let summary = cache.get_usage_summary().await;
        assert_eq!(summary.total_models, 1);
        assert!(summary.memory_usage_bytes > 0);
}
#[tokio::test]
async fn test_cache_warming() {
let config = ModelCacheConfig::default();
let cache = AdvancedModelCache::new(config, None).unwrap();
let result = cache.warm_cache(vec![]).await;
assert!(result.is_ok());
}
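    // Additional coverage sketch: removing a model should also release the
    // memory the cache attributes to it.
    #[tokio::test]
    async fn test_model_removal_frees_memory() {
        let config = ModelCacheConfig::default();
        let cache = AdvancedModelCache::new(config, None).unwrap();
        cache
            .put_any("temp_model", Box::new("temp".to_string()))
            .await
            .unwrap();
        assert!(cache.contains_key("temp_model").await);
        cache.remove("temp_model").await.unwrap();
        assert!(!cache.contains_key("temp_model").await);
        assert_eq!(cache.get_usage_summary().await.memory_usage_bytes, 0);
    }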
}