1use crate::{
5 Utils, constants, models::ModelProvider, providers::huggingface::HuggingFaceProviderCache,
6};
7use anyhow::{Context, Result};
8use serde::{Deserialize, Serialize};
9use std::env;
10use std::fs;
11use std::path::{Path, PathBuf};
12use tracing::{debug, info, warn};
13
/// Location of the local model cache and the model server it talks to.
///
/// Persisted as YAML by `CacheConfig::save_to_config_file` and read back by
/// `CacheConfig::from_config_file`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Root directory of the local model cache.
    pub local_path: PathBuf,
    /// URL of the model server, e.g. `http://localhost:<port>`.
    pub server_endpoint: String,
    /// Optional timeout in seconds; `None` when not configured.
    /// NOTE(review): which operations this bounds is not visible here — confirm with callers.
    pub timeout_secs: Option<u64>,
    /// Whether the cache storage is shared; the exact semantics are decided by consumers of
    /// this config (not visible here). Falls back to `constants::DEFAULT_SHARED_STORAGE`
    /// when the key is missing from the YAML.
    #[serde(default = "default_shared_storage")]
    pub shared_storage: bool,
    /// Chunk size used for transfers — presumably in bytes (TODO confirm). Falls back to
    /// `constants::DEFAULT_TRANSFER_CHUNK_SIZE` when the key is missing from the YAML.
    #[serde(default = "default_transfer_chunk_size")]
    pub transfer_chunk_size: usize,
}
31
32fn default_shared_storage() -> bool {
33 constants::DEFAULT_SHARED_STORAGE
34}
35
36fn default_transfer_chunk_size() -> usize {
37 constants::DEFAULT_TRANSFER_CHUNK_SIZE
38}
39
40impl Default for CacheConfig {
41 fn default() -> Self {
42 let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
43 Self {
44 local_path: PathBuf::from(home).join(constants::DEFAULT_CACHE_PATH),
45 server_endpoint: format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT),
46 timeout_secs: None,
47 shared_storage: constants::DEFAULT_SHARED_STORAGE,
48 transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
49 }
50 }
51}
52
53impl CacheConfig {
54 pub fn discover() -> Result<Self> {
56 if let Some(path) = Self::get_cache_path_from_args() {
65 return Self::from_path(path);
66 }
67
68 if let Ok(path) = env::var("MODEL_EXPRESS_CACHE_DIRECTORY") {
70 return Self::from_path(path);
71 }
72
73 if let Ok(config) = Self::from_config_file() {
75 return Ok(config);
76 }
77
78 if let Ok(config) = Self::auto_detect() {
80 return Ok(config);
81 }
82
83 debug!("Using default cache configuration");
85 Ok(Self::default())
86 }
87
88 pub fn new(local_path: PathBuf, server_endpoint: Option<String>) -> Result<Self> {
90 fs::create_dir_all(&local_path)
92 .with_context(|| format!("Failed to create cache directory: {local_path:?}"))?;
93
94 Ok(Self {
95 local_path,
96 server_endpoint: server_endpoint.unwrap_or_else(Self::get_default_server_endpoint),
97 timeout_secs: None,
98 shared_storage: constants::DEFAULT_SHARED_STORAGE,
99 transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
100 })
101 }
102
103 pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
105 let local_path = path.as_ref().to_path_buf();
106
107 fs::create_dir_all(&local_path)
109 .with_context(|| format!("Failed to create cache directory: {local_path:?}"))?;
110
111 Ok(Self {
112 local_path,
113 server_endpoint: Self::get_default_server_endpoint(),
114 timeout_secs: None,
115 shared_storage: constants::DEFAULT_SHARED_STORAGE,
116 transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
117 })
118 }
119
120 pub fn from_config_file() -> Result<Self> {
122 let config_path = Self::get_config_path()?;
123
124 if !config_path.exists() {
125 return Err(anyhow::anyhow!("Config file not found: {:?}", config_path));
126 }
127
128 let content = fs::read_to_string(&config_path)
129 .with_context(|| format!("Failed to read config file: {config_path:?}"))?;
130
131 let config: Self = serde_yaml::from_str(&content)
132 .with_context(|| format!("Failed to parse config file: {config_path:?}"))?;
133
134 Ok(config)
135 }
136
137 pub fn save_to_config_file(&self) -> Result<()> {
139 let config_path = Self::get_config_path()?;
140
141 if let Some(parent) = config_path.parent() {
143 fs::create_dir_all(parent)
144 .with_context(|| format!("Failed to create config directory: {parent:?}"))?;
145 }
146
147 let content = serde_yaml::to_string(self).context("Failed to serialize config")?;
148
149 fs::write(&config_path, content)
150 .with_context(|| format!("Failed to write config file: {config_path:?}"))?;
151
152 Ok(())
153 }
154
155 pub fn auto_detect() -> Result<Self> {
157 let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
158 let common_paths = vec![
159 PathBuf::from(&home).join(constants::DEFAULT_CACHE_PATH),
160 PathBuf::from(&home).join(constants::DEFAULT_HF_CACHE_PATH),
161 PathBuf::from("/cache"),
162 PathBuf::from("/app/models"),
163 PathBuf::from("./cache"),
164 PathBuf::from("./models"),
165 ];
166
167 for path in common_paths {
168 if path.exists() && path.is_dir() {
169 return Ok(Self {
170 local_path: path,
171 server_endpoint: Self::get_default_server_endpoint(),
172 timeout_secs: None,
173 shared_storage: constants::DEFAULT_SHARED_STORAGE,
174 transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
175 });
176 }
177 }
178
179 Err(anyhow::anyhow!(
180 "No cache directory found in common locations"
181 ))
182 }
183
184 pub fn from_server() -> Result<Self> {
186 Err(anyhow::anyhow!("Server not available for cache discovery"))
189 }
190
191 fn get_cache_path_from_args() -> Option<String> {
193 let args: Vec<String> = env::args().collect();
194
195 for (i, arg) in args.iter().enumerate() {
196 if arg == "--cache-path"
197 && let Some(next_arg) = args.get(i.saturating_add(1))
198 {
199 return Some(next_arg.clone());
200 }
201 }
202
203 None
204 }
205
206 fn get_default_server_endpoint() -> String {
208 env::var("MODEL_EXPRESS_SERVER_ENDPOINT")
209 .unwrap_or_else(|_| format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT))
210 }
211
212 fn get_config_path() -> Result<PathBuf> {
214 let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
215
216 Ok(PathBuf::from(home).join(constants::DEFAULT_CONFIG_PATH))
217 }
218
219 pub fn get_cache_stats(&self) -> Result<CacheStats> {
221 let mut models = Vec::new();
222
223 if !self.local_path.exists() {
224 return Ok(CacheStats {
225 total_models: 0,
226 total_size: 0,
227 models,
228 });
229 }
230
231 let provider = ModelProvider::HuggingFace;
232 models.extend(cache_for_provider(provider).list_models(&self.local_path)?);
233
234 models.sort_by(|left, right| {
235 provider_sort_key(left.provider)
236 .cmp(&provider_sort_key(right.provider))
237 .then_with(|| left.name.cmp(&right.name))
238 });
239
240 let total_size = models.iter().map(|model| model.size).sum();
241
242 Ok(CacheStats {
243 total_models: models.len(),
244 total_size,
245 models,
246 })
247 }
248
249 pub fn clear_model(&self, model_name: &str, provider: ModelProvider) -> Result<()> {
251 cache_for_provider(provider).clear_model(&self.local_path, model_name)
252 }
253
254 pub fn clear_all(&self) -> Result<()> {
256 if self.local_path.exists() {
257 for entry in fs::read_dir(&self.local_path)
258 .with_context(|| format!("Failed to read cache directory: {:?}", self.local_path))?
259 {
260 let entry = entry
261 .with_context(|| format!("Failed to read entry in: {:?}", self.local_path))?;
262 let path = entry.path();
263 if path.is_dir() {
264 fs::remove_dir_all(&path)
265 .with_context(|| format!("Failed to remove directory: {:?}", path))?;
266 } else {
267 fs::remove_file(&path)
268 .with_context(|| format!("Failed to remove file: {:?}", path))?;
269 }
270 }
271 info!("Cleared entire cache");
272 } else {
273 warn!("Cache directory does not exist");
274 }
275
276 Ok(())
277 }
278}
279
/// Summary of the models present in a cache directory, as produced by
/// `CacheConfig::get_cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries in `models`.
    pub total_models: usize,
    /// Combined size of all models, in bytes (rendered by `format_total_size`).
    pub total_size: u64,
    /// Per-model details; `get_cache_stats` sorts them by provider, then by name.
    pub models: Vec<ModelInfo>,
}
287
/// One cached model as reported by a provider cache backend.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Provider whose cache layout the model was found in.
    pub provider: ModelProvider,
    /// Model identifier, e.g. `google/t5-small`.
    pub name: String,
    /// Size of the model's files, in bytes (rendered by `CacheStats::format_model_size`).
    pub size: u64,
    /// Model directory in the cache, e.g. `<root>/models--google--t5-small` for Hugging Face.
    pub path: PathBuf,
}
296
297impl CacheStats {
298 fn format_bytes(bytes: u64) -> String {
300 const KB: u64 = 1024;
301 const MB: u64 = KB * 1024;
302 const GB: u64 = MB * 1024;
303
304 match bytes {
305 size if size >= GB => format!("{:.2} GB", size as f64 / GB as f64),
306 size if size >= MB => format!("{:.2} MB", size as f64 / MB as f64),
307 size if size >= KB => format!("{:.2} KB", size as f64 / KB as f64),
308 size => format!("{size} bytes"),
309 }
310 }
311
312 pub fn format_total_size(&self) -> String {
314 Self::format_bytes(self.total_size)
315 }
316
317 pub fn format_model_size(&self, model: &ModelInfo) -> String {
319 Self::format_bytes(model.size)
320 }
321}
322
323pub(crate) trait ProviderCache: Send + Sync {
324 fn clear_model(&self, cache_root: &Path, model_name: &str) -> Result<()>;
325 fn resolve_model_path(
326 &self,
327 cache_root: &Path,
328 model_name: &str,
329 revision: Option<&str>,
330 ) -> Result<PathBuf>;
331 fn list_models(&self, cache_root: &Path) -> Result<Vec<ModelInfo>>;
332}
333
334pub(crate) fn cache_for_provider(provider: ModelProvider) -> &'static dyn ProviderCache {
335 match provider {
336 ModelProvider::HuggingFace => &HuggingFaceProviderCache,
337 }
338}
339
340pub fn resolve_model_path(
341 cache_root: &Path,
342 provider: ModelProvider,
343 model_name: &str,
344 revision: Option<&str>,
345) -> Result<PathBuf> {
346 cache_for_provider(provider).resolve_model_path(cache_root, model_name, revision)
347}
348
/// Returns the total size in bytes of all regular files below `path`, recursively.
///
/// Symlinks are followed, so a symlinked file counts with its target's size. A symlink
/// leading back into a directory that is already being walked is skipped, so a link
/// cycle cannot recurse until the stack overflows. Entries whose metadata cannot be read
/// contribute nothing. The total saturates at `u64::MAX` instead of overflowing.
///
/// # Errors
///
/// Returns an error if `path` or any subdirectory cannot be resolved or listed.
pub(crate) fn directory_size(path: &Path) -> Result<u64> {
    let mut ancestors = Vec::new();
    directory_size_guarded(path, &mut ancestors)
}

/// Recursive worker for `directory_size`. `ancestors` holds the canonical paths of the
/// directories on the current descent, which is what detects symlink cycles.
fn directory_size_guarded(path: &Path, ancestors: &mut Vec<PathBuf>) -> Result<u64> {
    let canonical = fs::canonicalize(path)?;
    if ancestors.contains(&canonical) {
        // Reached again through a symlink while still inside it: this is a cycle.
        return Ok(0);
    }
    ancestors.push(canonical);

    let mut size: u64 = 0;
    for entry in fs::read_dir(path)? {
        let entry_path = entry?.path();

        // `fs::metadata` follows symlinks; one call serves both the type check and the length.
        match fs::metadata(&entry_path) {
            Ok(metadata) if metadata.is_file() => size = size.saturating_add(metadata.len()),
            Ok(metadata) if metadata.is_dir() => {
                size = size.saturating_add(directory_size_guarded(&entry_path, ancestors)?);
            }
            // Unreadable metadata or special files (sockets, FIFOs, ...) add nothing.
            _ => {}
        }
    }

    ancestors.pop();
    Ok(size)
}
365
366fn provider_sort_key(provider: ModelProvider) -> u8 {
367 match provider {
368 ModelProvider::HuggingFace => 0,
369 }
370}
371
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use crate::Utils;
    use tempfile::TempDir;

    /// Restores the real config file under the user's home directory when dropped.
    ///
    /// `save_to_config_file` and `from_config_file` always use
    /// `CacheConfig::get_config_path()`. Without this guard, the round-trip test would
    /// overwrite a developer's actual configuration with one that points at a temporary
    /// directory, which is deleted when the test ends.
    struct ConfigFileGuard {
        path: PathBuf,
        /// Previous file contents, or `None` if no config file existed.
        original: Option<Vec<u8>>,
        /// Parent directory that did not exist before the test (saving creates it).
        created_parent: Option<PathBuf>,
    }

    impl ConfigFileGuard {
        fn new() -> Self {
            let path = CacheConfig::get_config_path().expect("Failed to get config path");
            let original = if path.exists() {
                // Panic before anything is overwritten if the backup cannot be taken.
                Some(fs::read(&path).expect("Failed to back up existing config file"))
            } else {
                None
            };
            let created_parent = path
                .parent()
                .filter(|parent| !parent.exists())
                .map(Path::to_path_buf);

            Self {
                path,
                original,
                created_parent,
            }
        }
    }

    impl Drop for ConfigFileGuard {
        fn drop(&mut self) {
            // Best-effort: Drop cannot propagate errors, and a failed restore must not
            // mask an assertion failure that may be unwinding through here.
            if let Some(bytes) = &self.original {
                let _ = fs::write(&self.path, bytes);
                return;
            }
            let _ = fs::remove_file(&self.path);
            if let Some(parent) = &self.created_parent {
                // `remove_dir` only removes an empty directory, so anything else placed
                // there in the meantime is left alone.
                let _ = fs::remove_dir(parent);
            }
        }
    }

    #[test]
    fn test_cache_config_from_path() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let config =
            CacheConfig::from_path(temp_dir.path()).expect("Failed to create config from path");

        assert_eq!(config.local_path, temp_dir.path());
    }

    #[test]
    fn test_cache_config_save_and_load() {
        // Declared first so it is dropped last, after everything else in this test.
        let _config_guard = ConfigFileGuard::new();
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let original_config = CacheConfig {
            local_path: temp_dir.path().join("cache"),
            server_endpoint: "http://localhost:8001".to_string(),
            timeout_secs: Some(30),
            shared_storage: false,
            transfer_chunk_size: 64 * 1024,
        };

        original_config
            .save_to_config_file()
            .expect("Failed to save config");

        let loaded_config = CacheConfig::from_config_file().expect("Failed to load config");

        assert_eq!(loaded_config.local_path, original_config.local_path);
        assert_eq!(
            loaded_config.server_endpoint,
            original_config.server_endpoint
        );
        assert_eq!(loaded_config.timeout_secs, original_config.timeout_secs);
        assert_eq!(loaded_config.shared_storage, original_config.shared_storage);
        assert_eq!(
            loaded_config.transfer_chunk_size,
            original_config.transfer_chunk_size
        );
    }

    #[test]
    fn test_cache_stats_formatting() {
        let stats = CacheStats {
            total_models: 2,
            total_size: 1024 * 1024 * 5,
            models: vec![
                ModelInfo {
                    provider: ModelProvider::HuggingFace,
                    name: "model1".to_string(),
                    size: 1024 * 1024 * 2,
                    path: PathBuf::from("/test/model1"),
                },
                ModelInfo {
                    provider: ModelProvider::HuggingFace,
                    name: "model2".to_string(),
                    size: 1024 * 1024 * 3,
                    path: PathBuf::from("/test/model2"),
                },
            ],
        };

        assert_eq!(stats.format_total_size(), "5.00 MB");
        assert_eq!(stats.format_model_size(&stats.models[0]), "2.00 MB");
        assert_eq!(stats.format_model_size(&stats.models[1]), "3.00 MB");
    }

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();

        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
        assert_eq!(
            config.local_path,
            PathBuf::from(&home).join(constants::DEFAULT_CACHE_PATH)
        );
        assert_eq!(
            config.server_endpoint,
            String::from("http://localhost:8001")
        );
        assert_eq!(config.timeout_secs, None);
        assert!(config.shared_storage);
        assert_eq!(
            config.transfer_chunk_size,
            constants::DEFAULT_TRANSFER_CHUNK_SIZE
        );
    }

    #[test]
    fn test_get_config_path() {
        let config_path = CacheConfig::get_config_path().expect("Failed to get config path");

        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
        assert_eq!(
            config_path,
            PathBuf::from(&home).join(constants::DEFAULT_CONFIG_PATH)
        );
    }

    #[test]
    fn test_resolve_model_path_huggingface_uses_snapshot_layout() {
        let cache_root = Path::new("/tmp/cache");

        assert_eq!(
            resolve_model_path(
                cache_root,
                ModelProvider::HuggingFace,
                "google/t5-small",
                Some("abc123"),
            )
            .expect("Expected HF model path"),
            PathBuf::from("/tmp/cache/models--google--t5-small/snapshots/abc123")
        );
    }

    fn create_test_cache_config(local_path: PathBuf) -> CacheConfig {
        CacheConfig {
            local_path,
            server_endpoint: "http://localhost:8001".to_string(),
            timeout_secs: None,
            shared_storage: false,
            transfer_chunk_size: 64 * 1024,
        }
    }

    #[test]
    fn test_get_cache_stats_supports_hf_layout() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("cache");
        fs::create_dir_all(&cache_path).expect("Failed to create cache directory");

        let hf_model_dir = cache_path.join("models--google--t5-small");
        fs::create_dir_all(&hf_model_dir).expect("Failed to create HF model directory");
        fs::write(hf_model_dir.join("config.json"), b"{}").expect("Failed to write HF file");

        let ignored_dir = cache_path.join("tmp");
        fs::create_dir_all(&ignored_dir).expect("Failed to create ignored directory");
        fs::write(ignored_dir.join("scratch.txt"), b"ignore")
            .expect("Failed to write ignored file");

        let stats = create_test_cache_config(cache_path)
            .get_cache_stats()
            .expect("Failed to get cache stats");

        assert_eq!(stats.total_models, 1);
        assert_eq!(stats.total_size, 2);
        assert_eq!(stats.models.len(), 1);

        assert_eq!(stats.models[0].provider, ModelProvider::HuggingFace);
        assert_eq!(stats.models[0].name, "google/t5-small");
        assert_eq!(stats.models[0].size, 2);
        assert_eq!(stats.models[0].path, hf_model_dir);
        assert!(stats.models.iter().all(|model| model.name != "tmp"));
    }

    #[test]
    fn test_clear_model_removes_only_requested_layout() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("cache");
        fs::create_dir_all(&cache_path).expect("Failed to create cache directory");

        let hf_model_dir = cache_path.join("models--google--t5-small");
        fs::create_dir_all(&hf_model_dir).expect("Failed to create HF model directory");
        fs::write(hf_model_dir.join("config.json"), b"{}").expect("Failed to write HF file");

        let config = create_test_cache_config(cache_path);

        config
            .clear_model("google/t5-small", ModelProvider::HuggingFace)
            .expect("Failed to clear HF model");
        assert!(!hf_model_dir.exists(), "HF model should be removed");
    }

    #[test]
    fn test_clear_all_removes_contents_but_keeps_directory() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("cache");
        fs::create_dir_all(&cache_path).expect("Failed to create cache directory");

        let model_dir = cache_path.join("models--test--model");
        fs::create_dir_all(&model_dir).expect("Failed to create model directory");
        fs::write(model_dir.join("config.json"), "{}").expect("Failed to write file");
        fs::write(cache_path.join("test_file.txt"), "test").expect("Failed to write file");

        let config = create_test_cache_config(cache_path.clone());

        config.clear_all().expect("Failed to clear cache");

        assert!(cache_path.exists(), "Cache directory should still exist");
        assert!(
            fs::read_dir(&cache_path)
                .expect("Failed to read dir")
                .next()
                .is_none(),
            "Cache directory should be empty"
        );
    }

    #[test]
    fn test_clear_all_handles_nonexistent_directory() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("nonexistent_cache");

        let config = create_test_cache_config(cache_path.clone());

        config
            .clear_all()
            .with_context(|| format!("Failed to clear cache: {cache_path:?}"))
            .expect("Failed to clear cache");
        assert!(!cache_path.exists());
    }

    #[test]
    fn test_clear_all_removes_nested_directories() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("cache");
        fs::create_dir_all(&cache_path).expect("Failed to create cache directory");

        let deep_path = cache_path.join("a").join("b").join("c");
        fs::create_dir_all(&deep_path).expect("Failed to create nested directories");
        fs::write(deep_path.join("deep_file.txt"), "deep").expect("Failed to write file");

        let config = create_test_cache_config(cache_path.clone());

        config.clear_all().expect("Failed to clear cache");

        assert!(cache_path.exists(), "Cache directory should still exist");
        assert!(
            fs::read_dir(&cache_path)
                .expect("Failed to read dir")
                .next()
                .is_none(),
            "Cache directory should be empty after clearing nested content"
        );
    }
}