1use crate::{Utils, constants};
5use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::env;
8use std::fs;
9use std::path::{Path, PathBuf};
10use tracing::{debug, info, warn};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct CacheConfig {
15 pub local_path: PathBuf,
17 pub server_endpoint: String,
19 pub timeout_secs: Option<u64>,
21}
22
23impl Default for CacheConfig {
24 fn default() -> Self {
25 let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
26 Self {
27 local_path: PathBuf::from(home).join(constants::DEFAULT_CACHE_PATH),
28 server_endpoint: format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT),
29 timeout_secs: None,
30 }
31 }
32}
33
34impl CacheConfig {
35 pub fn discover() -> Result<Self> {
37 if let Some(path) = Self::get_cache_path_from_args() {
46 return Self::from_path(path);
47 }
48
49 if let Ok(path) = env::var("MODEL_EXPRESS_CACHE_DIRECTORY") {
51 return Self::from_path(path);
52 }
53
54 if let Ok(config) = Self::from_config_file() {
56 return Ok(config);
57 }
58
59 if let Ok(config) = Self::auto_detect() {
61 return Ok(config);
62 }
63
64 debug!("Using default cache configuration");
66 Ok(Self::default())
67 }
68
69 pub fn new(local_path: PathBuf, server_endpoint: Option<String>) -> Result<Self> {
71 fs::create_dir_all(&local_path)
73 .with_context(|| format!("Failed to create cache directory: {local_path:?}"))?;
74
75 Ok(Self {
76 local_path,
77 server_endpoint: server_endpoint.unwrap_or_else(Self::get_default_server_endpoint),
78 timeout_secs: None,
79 })
80 }
81
82 pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
84 let local_path = path.as_ref().to_path_buf();
85
86 fs::create_dir_all(&local_path)
88 .with_context(|| format!("Failed to create cache directory: {local_path:?}"))?;
89
90 Ok(Self {
91 local_path,
92 server_endpoint: Self::get_default_server_endpoint(),
93 timeout_secs: None,
94 })
95 }
96
97 pub fn from_config_file() -> Result<Self> {
99 let config_path = Self::get_config_path()?;
100
101 if !config_path.exists() {
102 return Err(anyhow::anyhow!("Config file not found: {:?}", config_path));
103 }
104
105 let content = fs::read_to_string(&config_path)
106 .with_context(|| format!("Failed to read config file: {config_path:?}"))?;
107
108 let config: Self = serde_yaml::from_str(&content)
109 .with_context(|| format!("Failed to parse config file: {config_path:?}"))?;
110
111 Ok(config)
112 }
113
114 pub fn save_to_config_file(&self) -> Result<()> {
116 let config_path = Self::get_config_path()?;
117
118 if let Some(parent) = config_path.parent() {
120 fs::create_dir_all(parent)
121 .with_context(|| format!("Failed to create config directory: {parent:?}"))?;
122 }
123
124 let content = serde_yaml::to_string(self).context("Failed to serialize config")?;
125
126 fs::write(&config_path, content)
127 .with_context(|| format!("Failed to write config file: {config_path:?}"))?;
128
129 Ok(())
130 }
131
132 pub fn auto_detect() -> Result<Self> {
134 let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
135 let common_paths = vec![
136 PathBuf::from(&home).join(constants::DEFAULT_CACHE_PATH),
137 PathBuf::from(&home).join(constants::DEFAULT_HF_CACHE_PATH),
138 PathBuf::from("/cache"),
139 PathBuf::from("/app/models"),
140 PathBuf::from("./cache"),
141 PathBuf::from("./models"),
142 ];
143
144 for path in common_paths {
145 if path.exists() && path.is_dir() {
146 return Ok(Self {
147 local_path: path,
148 server_endpoint: Self::get_default_server_endpoint(),
149 timeout_secs: None,
150 });
151 }
152 }
153
154 Err(anyhow::anyhow!(
155 "No cache directory found in common locations"
156 ))
157 }
158
159 pub fn from_server() -> Result<Self> {
161 Err(anyhow::anyhow!("Server not available for cache discovery"))
164 }
165
166 fn get_cache_path_from_args() -> Option<String> {
168 let args: Vec<String> = env::args().collect();
169
170 for (i, arg) in args.iter().enumerate() {
171 if arg == "--cache-path"
172 && let Some(next_arg) = args.get(i.saturating_add(1))
173 {
174 return Some(next_arg.clone());
175 }
176 }
177
178 None
179 }
180
181 fn get_default_server_endpoint() -> String {
183 env::var("MODEL_EXPRESS_SERVER_ENDPOINT")
184 .unwrap_or_else(|_| format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT))
185 }
186
187 fn get_config_path() -> Result<PathBuf> {
189 let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
190
191 Ok(PathBuf::from(home).join(constants::DEFAULT_CONFIG_PATH))
192 }
193
194 pub fn folder_name_to_model_id(folder_name: &str) -> String {
198 if let Some(stripped) = folder_name.strip_prefix("models--") {
200 stripped.replace("--", "/")
202 } else if folder_name.starts_with("datasets--") {
203 folder_name.to_string()
205 } else if folder_name.starts_with("spaces--") {
206 folder_name.to_string()
208 } else {
209 folder_name.to_string()
211 }
212 }
213
214 pub fn get_cache_stats(&self) -> Result<CacheStats> {
216 let mut stats = CacheStats {
217 total_models: 0,
218 total_size: 0,
219 models: Vec::new(),
220 };
221
222 if !self.local_path.exists() {
223 return Ok(stats);
224 }
225
226 for entry in fs::read_dir(&self.local_path)? {
227 let entry = entry?;
228 let path = entry.path();
229
230 if path.is_dir() {
231 let size = Self::get_directory_size(&path)?;
232 let folder_name = path
233 .file_name()
234 .and_then(|n| n.to_str())
235 .unwrap_or("unknown")
236 .to_string();
237 info!("Folder name: {}", folder_name);
238 let model_name = Self::folder_name_to_model_id(&folder_name);
240
241 stats.total_models = stats.total_models.saturating_add(1);
242 stats.total_size = stats.total_size.saturating_add(size);
243 stats.models.push(ModelInfo {
244 name: model_name,
245 size,
246 path: path.to_path_buf(),
247 });
248 }
249 }
250
251 Ok(stats)
252 }
253
254 fn get_directory_size(path: &Path) -> Result<u64> {
256 let mut size: u64 = 0;
257
258 for entry in fs::read_dir(path)? {
259 let entry = entry?;
260 let path = entry.path();
261
262 if path.is_file() {
263 size = size.saturating_add(fs::metadata(&path)?.len());
264 } else if path.is_dir() {
265 size = size.saturating_add(Self::get_directory_size(&path)?);
266 }
267 }
268
269 Ok(size)
270 }
271
272 pub fn clear_model(&self, model_name: &str) -> Result<()> {
274 let model_path = self.local_path.join(model_name);
275
276 if model_path.exists() {
277 fs::remove_dir_all(&model_path)
278 .with_context(|| format!("Failed to remove model: {model_path:?}"))?;
279 info!("Cleared model: {}", model_name);
280 } else {
281 warn!("Model not found in cache: {}", model_name);
282 }
283
284 Ok(())
285 }
286
287 pub fn clear_all(&self) -> Result<()> {
289 if self.local_path.exists() {
290 fs::remove_dir_all(&self.local_path)
291 .with_context(|| format!("Failed to clear cache: {:?}", self.local_path))?;
292 info!("Cleared entire cache");
293 } else {
294 warn!("Cache directory does not exist");
295 }
296
297 Ok(())
298 }
299}
300
/// Aggregate view of the cache, as produced by `CacheConfig::get_cache_stats`.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Number of entries in `models`.
    pub total_models: usize,
    /// Sum of every entry's `size`, in bytes.
    pub total_size: u64,
    /// One entry per top-level cache directory.
    pub models: Vec<ModelInfo>,
}

/// A single cached model directory.
#[derive(Clone, Debug)]
pub struct ModelInfo {
    /// Model id derived from the folder name.
    pub name: String,
    /// Size in bytes.
    pub size: u64,
    /// Full path of the model's directory.
    pub path: PathBuf,
}

impl CacheStats {
    /// Render a byte count using binary (1024-based) units with two decimals.
    ///
    /// Values below 1 KB are printed as a plain integer followed by `bytes`; GB is the
    /// largest unit used.
    fn format_bytes(bytes: u64) -> String {
        // Largest unit first, so the first match gives the most compact representation.
        const UNITS: [(u64, &str); 3] = [(1 << 30, "GB"), (1 << 20, "MB"), (1 << 10, "KB")];

        UNITS
            .iter()
            .find(|&&(scale, _)| bytes >= scale)
            .map_or_else(
                || format!("{bytes} bytes"),
                |&(scale, unit)| format!("{:.2} {unit}", bytes as f64 / scale as f64),
            )
    }

    /// Human-readable form of `total_size`.
    pub fn format_total_size(&self) -> String {
        let total = self.total_size;
        Self::format_bytes(total)
    }

    /// Human-readable form of `model.size`; does not depend on `self`.
    pub fn format_model_size(&self, model: &ModelInfo) -> String {
        let size = model.size;
        Self::format_bytes(size)
    }
}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Utils;
    use tempfile::TempDir;

    /// Snapshot of the real config file, taken before a test writes to it.
    ///
    /// On drop (including during a panic) it restores the original contents, or removes
    /// the file if none existed. This keeps tests from clobbering the user's configuration
    /// under the real home directory.
    struct ConfigFileGuard {
        path: PathBuf,
        backup: std::io::Result<Vec<u8>>,
    }

    impl ConfigFileGuard {
        fn capture(path: PathBuf) -> Self {
            let backup = fs::read(&path);
            Self { path, backup }
        }
    }

    impl Drop for ConfigFileGuard {
        fn drop(&mut self) {
            // Best effort: a cleanup failure must not mask the test's own result.
            match &self.backup {
                Ok(contents) => {
                    let _ = fs::write(&self.path, contents);
                }
                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
                    let _ = fs::remove_file(&self.path);
                }
                // The original could not be read; leave the file alone rather than risk
                // deleting something we could not back up.
                Err(_) => {}
            }
        }
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_cache_config_from_path() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let config =
            CacheConfig::from_path(temp_dir.path()).expect("Failed to create config from path");

        assert_eq!(config.local_path, temp_dir.path());
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_cache_config_save_and_load() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let config_path = CacheConfig::get_config_path().expect("Failed to get config path");
        // The config file lives under the real home directory; restore it afterwards.
        let _guard = ConfigFileGuard::capture(config_path);

        let original_config = CacheConfig {
            local_path: temp_dir.path().join("cache"),
            server_endpoint: "http://localhost:8001".to_string(),
            timeout_secs: Some(30),
        };

        original_config
            .save_to_config_file()
            .expect("Failed to save config");

        let loaded_config = CacheConfig::from_config_file().expect("Failed to load config");

        assert_eq!(loaded_config.local_path, original_config.local_path);
        assert_eq!(
            loaded_config.server_endpoint,
            original_config.server_endpoint
        );
        assert_eq!(loaded_config.timeout_secs, original_config.timeout_secs);
    }

    #[test]
    fn test_cache_stats_formatting() {
        let stats = CacheStats {
            total_models: 2,
            total_size: 1024 * 1024 * 5,
            models: vec![
                ModelInfo {
                    name: "model1".to_string(),
                    size: 1024 * 1024 * 2,
                    path: PathBuf::from("/test/model1"),
                },
                ModelInfo {
                    name: "model2".to_string(),
                    size: 1024 * 1024 * 3,
                    path: PathBuf::from("/test/model2"),
                },
            ],
        };

        assert_eq!(stats.format_total_size(), "5.00 MB");
        assert_eq!(stats.format_model_size(&stats.models[0]), "2.00 MB");
        assert_eq!(stats.format_model_size(&stats.models[1]), "3.00 MB");
    }

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();

        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
        assert_eq!(
            config.local_path,
            PathBuf::from(&home).join(constants::DEFAULT_CACHE_PATH)
        );
        // Derived from the constant, like the path assertions above.
        assert_eq!(
            config.server_endpoint,
            format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT)
        );
        assert_eq!(config.timeout_secs, None);
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_get_config_path() {
        let config_path = CacheConfig::get_config_path().expect("Failed to get config path");

        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
        assert_eq!(
            config_path,
            PathBuf::from(&home).join(constants::DEFAULT_CONFIG_PATH)
        );
    }

    #[test]
    fn test_folder_name_to_model_id() {
        // Model folders: every `--` after the prefix becomes `/`.
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--google-t5--t5-small"),
            "google-t5/t5-small"
        );
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--microsoft--DialoGPT-medium"),
            "microsoft/DialoGPT-medium"
        );
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--huggingface--CodeBERTa-small-v1"),
            "huggingface/CodeBERTa-small-v1"
        );

        // Model without an organization.
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--bert-base-uncased"),
            "bert-base-uncased"
        );

        // Datasets are left unchanged.
        assert_eq!(
            CacheConfig::folder_name_to_model_id("datasets--squad"),
            "datasets--squad"
        );
        assert_eq!(
            CacheConfig::folder_name_to_model_id("datasets--huggingface--squad"),
            "datasets--huggingface--squad"
        );

        // Spaces are left unchanged.
        assert_eq!(
            CacheConfig::folder_name_to_model_id("spaces--gradio--hello-world"),
            "spaces--gradio--hello-world"
        );

        // Anything else is left unchanged.
        assert_eq!(
            CacheConfig::folder_name_to_model_id("random-folder-name"),
            "random-folder-name"
        );
        assert_eq!(
            CacheConfig::folder_name_to_model_id("some--other--format"),
            "some--other--format"
        );

        // Edge cases.
        assert_eq!(CacheConfig::folder_name_to_model_id("models--"), "");
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--single"),
            "single"
        );
    }
}