1use crate::{Utils, constants};
5use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::env;
8use std::fs;
9use std::path::{Path, PathBuf};
10use tracing::{debug, info, warn};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct CacheConfig {
15 pub local_path: PathBuf,
17 pub server_endpoint: String,
19 pub timeout_secs: Option<u64>,
21 #[serde(default = "default_shared_storage")]
24 pub shared_storage: bool,
25 #[serde(default = "default_transfer_chunk_size")]
27 pub transfer_chunk_size: usize,
28}
29
30fn default_shared_storage() -> bool {
31 constants::DEFAULT_SHARED_STORAGE
32}
33
34fn default_transfer_chunk_size() -> usize {
35 constants::DEFAULT_TRANSFER_CHUNK_SIZE
36}
37
38impl Default for CacheConfig {
39 fn default() -> Self {
40 let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
41 Self {
42 local_path: PathBuf::from(home).join(constants::DEFAULT_CACHE_PATH),
43 server_endpoint: format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT),
44 timeout_secs: None,
45 shared_storage: constants::DEFAULT_SHARED_STORAGE,
46 transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
47 }
48 }
49}
50
51impl CacheConfig {
52 pub fn discover() -> Result<Self> {
54 if let Some(path) = Self::get_cache_path_from_args() {
63 return Self::from_path(path);
64 }
65
66 if let Ok(path) = env::var("MODEL_EXPRESS_CACHE_DIRECTORY") {
68 return Self::from_path(path);
69 }
70
71 if let Ok(config) = Self::from_config_file() {
73 return Ok(config);
74 }
75
76 if let Ok(config) = Self::auto_detect() {
78 return Ok(config);
79 }
80
81 debug!("Using default cache configuration");
83 Ok(Self::default())
84 }
85
86 pub fn new(local_path: PathBuf, server_endpoint: Option<String>) -> Result<Self> {
88 fs::create_dir_all(&local_path)
90 .with_context(|| format!("Failed to create cache directory: {local_path:?}"))?;
91
92 Ok(Self {
93 local_path,
94 server_endpoint: server_endpoint.unwrap_or_else(Self::get_default_server_endpoint),
95 timeout_secs: None,
96 shared_storage: constants::DEFAULT_SHARED_STORAGE,
97 transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
98 })
99 }
100
101 pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
103 let local_path = path.as_ref().to_path_buf();
104
105 fs::create_dir_all(&local_path)
107 .with_context(|| format!("Failed to create cache directory: {local_path:?}"))?;
108
109 Ok(Self {
110 local_path,
111 server_endpoint: Self::get_default_server_endpoint(),
112 timeout_secs: None,
113 shared_storage: constants::DEFAULT_SHARED_STORAGE,
114 transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
115 })
116 }
117
118 pub fn from_config_file() -> Result<Self> {
120 let config_path = Self::get_config_path()?;
121
122 if !config_path.exists() {
123 return Err(anyhow::anyhow!("Config file not found: {:?}", config_path));
124 }
125
126 let content = fs::read_to_string(&config_path)
127 .with_context(|| format!("Failed to read config file: {config_path:?}"))?;
128
129 let config: Self = serde_yaml::from_str(&content)
130 .with_context(|| format!("Failed to parse config file: {config_path:?}"))?;
131
132 Ok(config)
133 }
134
135 pub fn save_to_config_file(&self) -> Result<()> {
137 let config_path = Self::get_config_path()?;
138
139 if let Some(parent) = config_path.parent() {
141 fs::create_dir_all(parent)
142 .with_context(|| format!("Failed to create config directory: {parent:?}"))?;
143 }
144
145 let content = serde_yaml::to_string(self).context("Failed to serialize config")?;
146
147 fs::write(&config_path, content)
148 .with_context(|| format!("Failed to write config file: {config_path:?}"))?;
149
150 Ok(())
151 }
152
153 pub fn auto_detect() -> Result<Self> {
155 let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
156 let common_paths = vec![
157 PathBuf::from(&home).join(constants::DEFAULT_CACHE_PATH),
158 PathBuf::from(&home).join(constants::DEFAULT_HF_CACHE_PATH),
159 PathBuf::from("/cache"),
160 PathBuf::from("/app/models"),
161 PathBuf::from("./cache"),
162 PathBuf::from("./models"),
163 ];
164
165 for path in common_paths {
166 if path.exists() && path.is_dir() {
167 return Ok(Self {
168 local_path: path,
169 server_endpoint: Self::get_default_server_endpoint(),
170 timeout_secs: None,
171 shared_storage: constants::DEFAULT_SHARED_STORAGE,
172 transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
173 });
174 }
175 }
176
177 Err(anyhow::anyhow!(
178 "No cache directory found in common locations"
179 ))
180 }
181
182 pub fn from_server() -> Result<Self> {
184 Err(anyhow::anyhow!("Server not available for cache discovery"))
187 }
188
189 fn get_cache_path_from_args() -> Option<String> {
191 let args: Vec<String> = env::args().collect();
192
193 for (i, arg) in args.iter().enumerate() {
194 if arg == "--cache-path"
195 && let Some(next_arg) = args.get(i.saturating_add(1))
196 {
197 return Some(next_arg.clone());
198 }
199 }
200
201 None
202 }
203
204 fn get_default_server_endpoint() -> String {
206 env::var("MODEL_EXPRESS_SERVER_ENDPOINT")
207 .unwrap_or_else(|_| format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT))
208 }
209
210 fn get_config_path() -> Result<PathBuf> {
212 let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
213
214 Ok(PathBuf::from(home).join(constants::DEFAULT_CONFIG_PATH))
215 }
216
/// Converts a cache folder name into a model identifier.
///
/// A name starting with `models--` loses that prefix and has every remaining
/// `--` replaced by `/` (`models--org--name` becomes `org/name`). Any other
/// name, including `datasets--…` and `spaces--…`, is returned unchanged.
pub fn folder_name_to_model_id(folder_name: &str) -> String {
    match folder_name.strip_prefix("models--") {
        Some(remainder) => remainder.replace("--", "/"),
        // Datasets, spaces and unrecognised folder names pass through as-is.
        None => folder_name.to_string(),
    }
}
236
237 pub fn get_cache_stats(&self) -> Result<CacheStats> {
239 let mut stats = CacheStats {
240 total_models: 0,
241 total_size: 0,
242 models: Vec::new(),
243 };
244
245 if !self.local_path.exists() {
246 return Ok(stats);
247 }
248
249 for entry in fs::read_dir(&self.local_path)? {
250 let entry = entry?;
251 let path = entry.path();
252
253 if path.is_dir() {
254 let size = Self::get_directory_size(&path)?;
255 let folder_name = path
256 .file_name()
257 .and_then(|n| n.to_str())
258 .unwrap_or("unknown")
259 .to_string();
260 info!("Folder name: {}", folder_name);
261 let model_name = Self::folder_name_to_model_id(&folder_name);
263
264 stats.total_models = stats.total_models.saturating_add(1);
265 stats.total_size = stats.total_size.saturating_add(size);
266 stats.models.push(ModelInfo {
267 name: model_name,
268 size,
269 path: path.to_path_buf(),
270 });
271 }
272 }
273
274 Ok(stats)
275 }
276
277 fn get_directory_size(path: &Path) -> Result<u64> {
279 let mut size: u64 = 0;
280
281 for entry in fs::read_dir(path)? {
282 let entry = entry?;
283 let path = entry.path();
284
285 if path.is_file() {
286 size = size.saturating_add(fs::metadata(&path)?.len());
287 } else if path.is_dir() {
288 size = size.saturating_add(Self::get_directory_size(&path)?);
289 }
290 }
291
292 Ok(size)
293 }
294
295 pub fn clear_model(&self, model_name: &str) -> Result<()> {
297 let model_path = self.local_path.join(model_name);
298
299 if model_path.exists() {
300 fs::remove_dir_all(&model_path)
301 .with_context(|| format!("Failed to remove model: {model_path:?}"))?;
302 info!("Cleared model: {}", model_name);
303 } else {
304 warn!("Model not found in cache: {}", model_name);
305 }
306
307 Ok(())
308 }
309
310 pub fn clear_all(&self) -> Result<()> {
312 if self.local_path.exists() {
313 for entry in fs::read_dir(&self.local_path)
314 .with_context(|| format!("Failed to read cache directory: {:?}", self.local_path))?
315 {
316 let entry = entry
317 .with_context(|| format!("Failed to read entry in: {:?}", self.local_path))?;
318 let path = entry.path();
319 if path.is_dir() {
320 fs::remove_dir_all(&path)
321 .with_context(|| format!("Failed to remove directory: {:?}", path))?;
322 } else {
323 fs::remove_file(&path)
324 .with_context(|| format!("Failed to remove file: {:?}", path))?;
325 }
326 }
327 info!("Cleared entire cache");
328 } else {
329 warn!("Cache directory does not exist");
330 }
331
332 Ok(())
333 }
334}
335
/// Aggregate statistics for a cache directory.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Number of models counted.
    pub total_models: usize,
    /// Combined size of all counted models, in bytes.
    pub total_size: u64,
    /// Per-model details.
    pub models: Vec<ModelInfo>,
}

/// Name, size and location of one cached model.
#[derive(Clone, Debug)]
pub struct ModelInfo {
    /// Model name.
    pub name: String,
    /// Size in bytes.
    pub size: u64,
    /// Path associated with the model.
    pub path: PathBuf,
}

impl CacheStats {
    /// Renders a byte count with two decimals in the largest 1024-based unit
    /// (KB, MB or GB) that it reaches; values below 1024 are printed as a
    /// plain number of bytes.
    fn format_bytes(bytes: u64) -> String {
        const KB: u64 = 1 << 10;
        const MB: u64 = 1 << 20;
        const GB: u64 = 1 << 30;

        let size = bytes;
        if size >= GB {
            format!("{:.2} GB", size as f64 / GB as f64)
        } else if size >= MB {
            format!("{:.2} MB", size as f64 / MB as f64)
        } else if size >= KB {
            format!("{:.2} KB", size as f64 / KB as f64)
        } else {
            format!("{size} bytes")
        }
    }

    /// Human-readable form of `total_size`.
    pub fn format_total_size(&self) -> String {
        Self::format_bytes(self.total_size)
    }

    /// Human-readable form of `model.size`.
    pub fn format_model_size(&self, model: &ModelInfo) -> String {
        Self::format_bytes(model.size)
    }
}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Utils;
    use tempfile::TempDir;

    /// Builds a config rooted at `local_path`; the other fields get fixed
    /// values so the tests do not depend on environment variables.
    fn config_at(local_path: PathBuf) -> CacheConfig {
        CacheConfig {
            local_path,
            server_endpoint: "http://localhost:8001".to_string(),
            timeout_secs: None,
            shared_storage: false,
            transfer_chunk_size: 64 * 1024,
        }
    }

    /// Puts the on-disk config file back into its pre-test state when dropped,
    /// including when the test panics.
    struct ConfigFileRestorer {
        path: PathBuf,
        /// Previous file contents, or `None` when the file did not exist.
        previous: Option<Vec<u8>>,
    }

    impl Drop for ConfigFileRestorer {
        fn drop(&mut self) {
            // Best effort only: a failure here must not mask the test result.
            let _ = match &self.previous {
                Some(bytes) => fs::write(&self.path, bytes),
                None => fs::remove_file(&self.path),
            };
        }
    }

    /// `from_path` keeps the path it was given.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_cache_config_from_path() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let config =
            CacheConfig::from_path(temp_dir.path()).expect("Failed to create config from path");

        assert_eq!(config.local_path, temp_dir.path());
    }

    /// A saved config reads back with every field intact.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_cache_config_save_and_load() {
        // `save_to_config_file` writes to the real config path, not into the
        // temp directory, so remember what was there and restore it afterwards.
        let config_path = CacheConfig::get_config_path().expect("Failed to get config path");
        let previous = config_path
            .exists()
            .then(|| fs::read(&config_path).expect("Failed to back up existing config file"));
        let _restorer = ConfigFileRestorer {
            path: config_path,
            previous,
        };

        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let original_config = CacheConfig {
            timeout_secs: Some(30),
            ..config_at(temp_dir.path().join("cache"))
        };

        original_config
            .save_to_config_file()
            .expect("Failed to save config");

        let loaded_config = CacheConfig::from_config_file().expect("Failed to load config");

        assert_eq!(loaded_config.local_path, original_config.local_path);
        assert_eq!(
            loaded_config.server_endpoint,
            original_config.server_endpoint
        );
        assert_eq!(loaded_config.timeout_secs, original_config.timeout_secs);
        assert_eq!(loaded_config.shared_storage, original_config.shared_storage);
        assert_eq!(
            loaded_config.transfer_chunk_size,
            original_config.transfer_chunk_size
        );
    }

    /// Sizes are rendered in the largest unit they reach.
    #[test]
    fn test_cache_stats_formatting() {
        let stats = CacheStats {
            total_models: 2,
            total_size: 1024 * 1024 * 5,
            models: vec![
                ModelInfo {
                    name: "model1".to_string(),
                    size: 1024 * 1024 * 2,
                    path: PathBuf::from("/test/model1"),
                },
                ModelInfo {
                    name: "model2".to_string(),
                    size: 1024 * 1024 * 3,
                    path: PathBuf::from("/test/model2"),
                },
            ],
        };

        assert_eq!(stats.format_total_size(), "5.00 MB");
        assert_eq!(stats.format_model_size(&stats.models[0]), "2.00 MB");
        assert_eq!(stats.format_model_size(&stats.models[1]), "3.00 MB");
    }

    /// `Default` produces the documented fallback values.
    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();

        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
        assert_eq!(
            config.local_path,
            PathBuf::from(&home).join(constants::DEFAULT_CACHE_PATH)
        );
        assert_eq!(
            config.server_endpoint,
            String::from("http://localhost:8001")
        );
        assert_eq!(config.timeout_secs, None);
        assert!(config.shared_storage);
        assert_eq!(
            config.transfer_chunk_size,
            constants::DEFAULT_TRANSFER_CHUNK_SIZE
        );
    }

    /// The config path is the default config path under the home directory.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_get_config_path() {
        let config_path = CacheConfig::get_config_path().expect("Failed to get config path");

        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
        assert_eq!(
            config_path,
            PathBuf::from(&home).join(constants::DEFAULT_CONFIG_PATH)
        );
    }

    /// Folder names map to model IDs only for the `models--` prefix.
    #[test]
    fn test_folder_name_to_model_id() {
        // `models--org--name` becomes `org/name`.
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--google-t5--t5-small"),
            "google-t5/t5-small"
        );
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--microsoft--DialoGPT-medium"),
            "microsoft/DialoGPT-medium"
        );
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--huggingface--CodeBERTa-small-v1"),
            "huggingface/CodeBERTa-small-v1"
        );

        // No organisation part.
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--bert-base-uncased"),
            "bert-base-uncased"
        );

        // Datasets are returned unchanged.
        assert_eq!(
            CacheConfig::folder_name_to_model_id("datasets--squad"),
            "datasets--squad"
        );
        assert_eq!(
            CacheConfig::folder_name_to_model_id("datasets--huggingface--squad"),
            "datasets--huggingface--squad"
        );

        // Spaces are returned unchanged.
        assert_eq!(
            CacheConfig::folder_name_to_model_id("spaces--gradio--hello-world"),
            "spaces--gradio--hello-world"
        );

        // Unrecognised names are returned unchanged.
        assert_eq!(
            CacheConfig::folder_name_to_model_id("random-folder-name"),
            "random-folder-name"
        );
        assert_eq!(
            CacheConfig::folder_name_to_model_id("some--other--format"),
            "some--other--format"
        );

        // Edge cases around the prefix.
        assert_eq!(CacheConfig::folder_name_to_model_id("models--"), "");
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--single"),
            "single"
        );
    }

    /// `clear_all` empties the cache directory but leaves it in place.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_clear_all_removes_contents_but_keeps_directory() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("cache");
        fs::create_dir_all(&cache_path).expect("Failed to create cache directory");

        // One model directory with a file, plus one loose file.
        let model_dir = cache_path.join("models--test--model");
        fs::create_dir_all(&model_dir).expect("Failed to create model directory");
        fs::write(model_dir.join("config.json"), "{}").expect("Failed to write file");
        fs::write(cache_path.join("test_file.txt"), "test").expect("Failed to write file");

        let config = config_at(cache_path.clone());

        config.clear_all().expect("Failed to clear cache");

        assert!(cache_path.exists(), "Cache directory should still exist");
        assert!(
            fs::read_dir(&cache_path)
                .expect("Failed to read dir")
                .next()
                .is_none(),
            "Cache directory should be empty"
        );
    }

    /// `clear_all` succeeds on a missing directory and does not create it.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_clear_all_handles_nonexistent_directory() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("nonexistent_cache");

        let config = config_at(cache_path.clone());

        config
            .clear_all()
            .with_context(|| format!("Failed to clear cache: {cache_path:?}"))
            .expect("Failed to clear cache");
        assert!(!cache_path.exists());
    }

    /// `clear_all` removes deeply nested content.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_clear_all_removes_nested_directories() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("cache");
        fs::create_dir_all(&cache_path).expect("Failed to create cache directory");

        let deep_path = cache_path.join("a").join("b").join("c");
        fs::create_dir_all(&deep_path).expect("Failed to create nested directories");
        fs::write(deep_path.join("deep_file.txt"), "deep").expect("Failed to write file");

        let config = config_at(cache_path.clone());

        config.clear_all().expect("Failed to clear cache");

        assert!(cache_path.exists(), "Cache directory should still exist");
        assert!(
            fs::read_dir(&cache_path)
                .expect("Failed to read dir")
                .next()
                .is_none(),
            "Cache directory should be empty after clearing nested content"
        );
    }
}