mofa_plugins/tts/
cache.rs1use anyhow::{Context, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use std::time::SystemTime;
13use tokio::sync::RwLock;
14use tracing::{debug, info, warn};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ModelMetadata {
19 pub model_id: String,
21 pub version: String,
23 pub file_size: u64,
25 pub checksum: String,
27 pub downloaded_at: SystemTime,
29 pub last_accessed: SystemTime,
31 pub access_count: u64,
33}
34
35pub struct ModelCache {
37 cache_dir: PathBuf,
39 metadata: Arc<RwLock<HashMap<String, ModelMetadata>>>,
41}
42
43impl ModelCache {
44 pub fn new(cache_dir: Option<PathBuf>) -> Result<Self> {
46 let cache_dir = cache_dir.unwrap_or_else(|| {
47 dirs::home_dir()
48 .expect("Failed to determine home directory")
49 .join(".mofa")
50 .join("models")
51 .join("tts")
52 });
53
54 fs::create_dir_all(&cache_dir)
56 .context(format!("Failed to create cache directory: {:?}", cache_dir))?;
57
58 info!("Model cache initialized at: {:?}", cache_dir);
59
60 Ok(Self {
61 cache_dir,
62 metadata: Arc::new(RwLock::new(HashMap::new())),
63 })
64 }
65
66 pub fn cache_dir(&self) -> &Path {
68 &self.cache_dir
69 }
70
71 pub fn model_path(&self, model_id: &str) -> PathBuf {
73 let safe_id = model_id.replace('/', "-");
75 self.cache_dir.join(safe_id)
76 }
77
78 pub fn metadata_path(&self, model_id: &str) -> PathBuf {
80 let mut path = self.model_path(model_id);
81 path.set_extension("json");
82 path
83 }
84
85 pub async fn exists(&self, model_id: &str) -> bool {
87 let model_path = self.model_path(model_id);
88 model_path.exists()
89 }
90
91 pub async fn load_metadata(&self, model_id: &str) -> Result<Option<ModelMetadata>> {
93 let metadata_path = self.metadata_path(model_id);
94
95 if !metadata_path.exists() {
96 return Ok(None);
97 }
98
99 let content = fs::read_to_string(&metadata_path)
100 .context(format!("Failed to read metadata: {:?}", metadata_path))?;
101
102 let metadata: ModelMetadata = serde_json::from_str(&content)
103 .context(format!("Failed to parse metadata: {:?}", metadata_path))?;
104
105 let mut cache = self.metadata.write().await;
107 cache.insert(model_id.to_string(), metadata.clone());
108
109 debug!("Loaded metadata for model: {}", model_id);
110 Ok(Some(metadata))
111 }
112
113 pub async fn save_metadata(&self, metadata: &ModelMetadata) -> Result<()> {
115 let metadata_path = self.metadata_path(&metadata.model_id);
116
117 let content =
118 serde_json::to_string_pretty(metadata).context("Failed to serialize metadata")?;
119
120 fs::write(&metadata_path, content)
121 .context(format!("Failed to write metadata: {:?}", metadata_path))?;
122
123 let mut cache = self.metadata.write().await;
125 cache.insert(metadata.model_id.clone(), metadata.clone());
126
127 debug!("Saved metadata for model: {}", metadata.model_id);
128 Ok(())
129 }
130
131 pub async fn validate(&self, model_id: &str, expected_checksum: Option<&str>) -> Result<bool> {
133 let model_path = self.model_path(model_id);
134
135 if !model_path.exists() {
136 return Ok(false);
137 }
138
139 let metadata = match self.load_metadata(model_id).await? {
141 Some(m) => m,
142 None => return Ok(false),
143 };
144
145 let file_size = fs::metadata(&model_path)?.len();
147 if file_size != metadata.file_size {
148 warn!(
149 "Model file size mismatch for {}: expected {}, got {}",
150 model_id, metadata.file_size, file_size
151 );
152 return Ok(false);
153 }
154
155 if let Some(expected) = expected_checksum
157 && metadata.checksum != expected
158 {
159 warn!("Model checksum mismatch for {}", model_id);
160 return Ok(false);
161 }
162
163 debug!("Model validation passed: {}", model_id);
164 Ok(true)
165 }
166
167 pub async fn get_size(&self, model_id: &str) -> Result<u64> {
169 let model_path = self.model_path(model_id);
170 let metadata =
171 fs::metadata(&model_path).context(format!("Model not found: {:?}", model_path))?;
172 Ok(metadata.len())
173 }
174
175 pub async fn update_access(&self, model_id: &str) -> Result<()> {
177 if let Some(mut metadata) = self.load_metadata(model_id).await? {
178 metadata.last_accessed = SystemTime::now();
179 metadata.access_count += 1;
180 self.save_metadata(&metadata).await?;
181 }
182 Ok(())
183 }
184
185 pub async fn list_models(&self) -> Result<Vec<String>> {
187 let mut models = Vec::new();
188
189 let entries = fs::read_dir(&self.cache_dir).context(format!(
190 "Failed to read cache directory: {:?}",
191 self.cache_dir
192 ))?;
193
194 for entry in entries {
195 let entry = entry?;
196 let path = entry.path();
197
198 if path.is_dir() || path.extension().is_some_and(|e| e == "json") {
200 continue;
201 }
202
203 if let Some(name) = path.file_stem().and_then(|n| n.to_str()) {
205 let model_id = name.replace('-', "/");
206 models.push(model_id);
207 }
208 }
209
210 Ok(models)
211 }
212
213 pub async fn delete_model(&self, model_id: &str) -> Result<()> {
215 let model_path = self.model_path(model_id);
216 let metadata_path = self.metadata_path(model_id);
217
218 if model_path.exists() {
220 fs::remove_file(&model_path)
221 .context(format!("Failed to delete model: {:?}", model_path))?;
222 }
223
224 if metadata_path.exists() {
226 fs::remove_file(&metadata_path)
227 .context(format!("Failed to delete metadata: {:?}", metadata_path))?;
228 }
229
230 let mut cache = self.metadata.write().await;
232 cache.remove(model_id);
233
234 info!("Deleted cached model: {}", model_id);
235 Ok(())
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a cache inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is handed back so the directory stays alive for the
    /// duration of the calling test.
    fn scratch_cache() -> (TempDir, ModelCache) {
        let dir = TempDir::new().unwrap();
        let cache = ModelCache::new(Some(dir.path().to_path_buf())).unwrap();
        (dir, cache)
    }

    /// The cache reports the directory it was constructed with.
    #[tokio::test]
    async fn test_cache_creation() {
        let (dir, cache) = scratch_cache();

        assert_eq!(cache.cache_dir(), dir.path());
    }

    /// Slashes in the model id are flattened to dashes in the model file name.
    #[tokio::test]
    async fn test_model_path() {
        let (_dir, cache) = scratch_cache();

        let resolved = cache.model_path("hexgrad/Kokoro-82M");
        assert!(resolved.ends_with("hexgrad-Kokoro-82M"));
    }

    /// The metadata sidecar carries the model file name plus a `.json` suffix.
    #[tokio::test]
    async fn test_metadata_path() {
        let (_dir, cache) = scratch_cache();

        let resolved = cache.metadata_path("hexgrad/Kokoro-82M");
        assert!(resolved.ends_with("hexgrad-Kokoro-82M.json"));
    }

    /// Metadata written with `save_metadata` is returned by `load_metadata`.
    #[tokio::test]
    async fn test_save_and_load_metadata() {
        let (_dir, cache) = scratch_cache();

        let now = SystemTime::now();
        let record = ModelMetadata {
            model_id: "test/model".to_string(),
            version: "v1.0".to_string(),
            file_size: 12345,
            checksum: "abc123".to_string(),
            downloaded_at: now,
            last_accessed: now,
            access_count: 0,
        };

        cache.save_metadata(&record).await.unwrap();

        let round_tripped = cache
            .load_metadata("test/model")
            .await
            .unwrap()
            .expect("metadata should be present after saving");
        assert_eq!(round_tripped.model_id, "test/model");
        assert_eq!(round_tripped.file_size, 12345);
        assert_eq!(round_tripped.checksum, "abc123");
    }

    /// `exists` flips from false to true once the model file is written.
    #[tokio::test]
    async fn test_exists() {
        let (_dir, cache) = scratch_cache();

        assert!(!cache.exists("test/model").await);

        fs::write(cache.model_path("test/model"), b"test data").unwrap();

        assert!(cache.exists("test/model").await);
    }
}