Skip to main content

mofa_plugins/tts/
cache.rs

1//! Model Cache Management
2//!
3//! Handles caching, validation, and retrieval of TTS models
4//! stored in ~/.mofa/models/tts/
5
6use anyhow::{Context, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use std::time::SystemTime;
13use tokio::sync::RwLock;
14use tracing::{debug, info, warn};
15
16/// Model metadata stored alongside cached models
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ModelMetadata {
19    /// Model identifier (e.g., "hexgrad/Kokoro-82M")
20    pub model_id: String,
21    /// Model version/tag
22    pub version: String,
23    /// File size in bytes
24    pub file_size: u64,
25    /// MD5 checksum
26    pub checksum: String,
27    /// Download timestamp
28    pub downloaded_at: SystemTime,
29    /// Last access timestamp
30    pub last_accessed: SystemTime,
31    /// Number of times accessed
32    pub access_count: u64,
33}
34
35/// Model cache manager
36pub struct ModelCache {
37    /// Cache root directory (e.g., ~/.mofa/models/tts/)
38    cache_dir: PathBuf,
39    /// In-memory metadata cache
40    metadata: Arc<RwLock<HashMap<String, ModelMetadata>>>,
41}
42
43impl ModelCache {
44    /// Create a new model cache manager
45    pub fn new(cache_dir: Option<PathBuf>) -> Result<Self> {
46        let cache_dir = cache_dir.unwrap_or_else(|| {
47            dirs::home_dir()
48                .expect("Failed to determine home directory")
49                .join(".mofa")
50                .join("models")
51                .join("tts")
52        });
53
54        // Ensure cache directory exists
55        fs::create_dir_all(&cache_dir)
56            .context(format!("Failed to create cache directory: {:?}", cache_dir))?;
57
58        info!("Model cache initialized at: {:?}", cache_dir);
59
60        Ok(Self {
61            cache_dir,
62            metadata: Arc::new(RwLock::new(HashMap::new())),
63        })
64    }
65
66    /// Get the cache directory
67    pub fn cache_dir(&self) -> &Path {
68        &self.cache_dir
69    }
70
71    /// Get path for a specific model
72    pub fn model_path(&self, model_id: &str) -> PathBuf {
73        // Sanitize model_id (replace '/' with '-')
74        let safe_id = model_id.replace('/', "-");
75        self.cache_dir.join(safe_id)
76    }
77
78    /// Get metadata file path for a model
79    pub fn metadata_path(&self, model_id: &str) -> PathBuf {
80        let mut path = self.model_path(model_id);
81        path.set_extension("json");
82        path
83    }
84
85    /// Check if model exists in cache
86    pub async fn exists(&self, model_id: &str) -> bool {
87        let model_path = self.model_path(model_id);
88        model_path.exists()
89    }
90
91    /// Load metadata for a cached model
92    pub async fn load_metadata(&self, model_id: &str) -> Result<Option<ModelMetadata>> {
93        let metadata_path = self.metadata_path(model_id);
94
95        if !metadata_path.exists() {
96            return Ok(None);
97        }
98
99        let content = fs::read_to_string(&metadata_path)
100            .context(format!("Failed to read metadata: {:?}", metadata_path))?;
101
102        let metadata: ModelMetadata = serde_json::from_str(&content)
103            .context(format!("Failed to parse metadata: {:?}", metadata_path))?;
104
105        // Update in-memory cache
106        let mut cache = self.metadata.write().await;
107        cache.insert(model_id.to_string(), metadata.clone());
108
109        debug!("Loaded metadata for model: {}", model_id);
110        Ok(Some(metadata))
111    }
112
113    /// Save metadata for a cached model
114    pub async fn save_metadata(&self, metadata: &ModelMetadata) -> Result<()> {
115        let metadata_path = self.metadata_path(&metadata.model_id);
116
117        let content =
118            serde_json::to_string_pretty(metadata).context("Failed to serialize metadata")?;
119
120        fs::write(&metadata_path, content)
121            .context(format!("Failed to write metadata: {:?}", metadata_path))?;
122
123        // Update in-memory cache
124        let mut cache = self.metadata.write().await;
125        cache.insert(metadata.model_id.clone(), metadata.clone());
126
127        debug!("Saved metadata for model: {}", metadata.model_id);
128        Ok(())
129    }
130
131    /// Validate model file integrity using checksum
132    pub async fn validate(&self, model_id: &str, expected_checksum: Option<&str>) -> Result<bool> {
133        let model_path = self.model_path(model_id);
134
135        if !model_path.exists() {
136            return Ok(false);
137        }
138
139        // Load metadata
140        let metadata = match self.load_metadata(model_id).await? {
141            Some(m) => m,
142            None => return Ok(false),
143        };
144
145        // Verify file exists and has correct size
146        let file_size = fs::metadata(&model_path)?.len();
147        if file_size != metadata.file_size {
148            warn!(
149                "Model file size mismatch for {}: expected {}, got {}",
150                model_id, metadata.file_size, file_size
151            );
152            return Ok(false);
153        }
154
155        // Verify checksum if provided
156        if let Some(expected) = expected_checksum
157            && metadata.checksum != expected
158        {
159            warn!("Model checksum mismatch for {}", model_id);
160            return Ok(false);
161        }
162
163        debug!("Model validation passed: {}", model_id);
164        Ok(true)
165    }
166
167    /// Get model size in bytes
168    pub async fn get_size(&self, model_id: &str) -> Result<u64> {
169        let model_path = self.model_path(model_id);
170        let metadata =
171            fs::metadata(&model_path).context(format!("Model not found: {:?}", model_path))?;
172        Ok(metadata.len())
173    }
174
175    /// Update last access time for a model
176    pub async fn update_access(&self, model_id: &str) -> Result<()> {
177        if let Some(mut metadata) = self.load_metadata(model_id).await? {
178            metadata.last_accessed = SystemTime::now();
179            metadata.access_count += 1;
180            self.save_metadata(&metadata).await?;
181        }
182        Ok(())
183    }
184
185    /// List all cached models
186    pub async fn list_models(&self) -> Result<Vec<String>> {
187        let mut models = Vec::new();
188
189        let entries = fs::read_dir(&self.cache_dir).context(format!(
190            "Failed to read cache directory: {:?}",
191            self.cache_dir
192        ))?;
193
194        for entry in entries {
195            let entry = entry?;
196            let path = entry.path();
197
198            // Skip metadata files and directories
199            if path.is_dir() || path.extension().is_some_and(|e| e == "json") {
200                continue;
201            }
202
203            // Convert filename back to model_id
204            if let Some(name) = path.file_stem().and_then(|n| n.to_str()) {
205                let model_id = name.replace('-', "/");
206                models.push(model_id);
207            }
208        }
209
210        Ok(models)
211    }
212
213    /// Delete a cached model
214    pub async fn delete_model(&self, model_id: &str) -> Result<()> {
215        let model_path = self.model_path(model_id);
216        let metadata_path = self.metadata_path(model_id);
217
218        // Delete model file
219        if model_path.exists() {
220            fs::remove_file(&model_path)
221                .context(format!("Failed to delete model: {:?}", model_path))?;
222        }
223
224        // Delete metadata
225        if metadata_path.exists() {
226            fs::remove_file(&metadata_path)
227                .context(format!("Failed to delete metadata: {:?}", metadata_path))?;
228        }
229
230        // Remove from in-memory cache
231        let mut cache = self.metadata.write().await;
232        cache.remove(model_id);
233
234        info!("Deleted cached model: {}", model_id);
235        Ok(())
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a cache rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned with the cache. Callers must bind it to a named
    /// variable (not `_`), or the directory is removed before the test body runs.
    fn temp_cache() -> (TempDir, ModelCache) {
        let dir = TempDir::new().unwrap();
        let cache = ModelCache::new(Some(dir.path().to_path_buf())).unwrap();
        (dir, cache)
    }

    #[tokio::test]
    async fn test_cache_creation() {
        let (dir, cache) = temp_cache();
        assert_eq!(cache.cache_dir(), dir.path());
    }

    #[tokio::test]
    async fn test_model_path() {
        let (_dir, cache) = temp_cache();
        let resolved = cache.model_path("hexgrad/Kokoro-82M");
        assert!(resolved.ends_with("hexgrad-Kokoro-82M"));
    }

    #[tokio::test]
    async fn test_metadata_path() {
        let (_dir, cache) = temp_cache();
        let resolved = cache.metadata_path("hexgrad/Kokoro-82M");
        assert!(resolved.ends_with("hexgrad-Kokoro-82M.json"));
    }

    #[tokio::test]
    async fn test_save_and_load_metadata() {
        let (_dir, cache) = temp_cache();

        let record = ModelMetadata {
            model_id: "test/model".to_string(),
            version: "v1.0".to_string(),
            file_size: 12345,
            checksum: "abc123".to_string(),
            downloaded_at: SystemTime::now(),
            last_accessed: SystemTime::now(),
            access_count: 0,
        };
        cache.save_metadata(&record).await.unwrap();

        let reloaded = cache
            .load_metadata("test/model")
            .await
            .unwrap()
            .expect("metadata should be present after saving");
        assert_eq!(reloaded.model_id, "test/model");
        assert_eq!(reloaded.file_size, 12345);
        assert_eq!(reloaded.checksum, "abc123");
    }

    #[tokio::test]
    async fn test_exists() {
        let (_dir, cache) = temp_cache();
        assert!(!cache.exists("test/model").await);

        // Materialize a dummy model file at the expected location.
        fs::write(cache.model_path("test/model"), b"test data").unwrap();
        assert!(cache.exists("test/model").await);
    }
}