Skip to main content

modelexpress_common/
cache.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{Utils, constants};
5use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::env;
8use std::fs;
9use std::path::{Path, PathBuf};
10use tracing::{debug, info, warn};
11
/// Configuration for model cache management.
///
/// Derives `Serialize`/`Deserialize`; this is the type that `from_config_file`
/// parses from YAML and that `save_to_config_file` writes back out.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Local path where models are cached
    pub local_path: PathBuf,
    /// Server endpoint for model downloads
    pub server_endpoint: String,
    /// Optional timeout for cache operations.
    /// Presumably expressed in seconds, going by the field name — confirm against the consumer.
    pub timeout_secs: Option<u64>,
    /// Whether to use shared storage mode (client and server share a network drive)
    /// When false, files will be streamed from server to client
    /// Filled from `default_shared_storage()` when missing from deserialized input.
    #[serde(default = "default_shared_storage")]
    pub shared_storage: bool,
    /// Chunk size in bytes for file transfer streaming when shared_storage is false
    /// Filled from `default_transfer_chunk_size()` when missing from deserialized input.
    #[serde(default = "default_transfer_chunk_size")]
    pub transfer_chunk_size: usize,
}
29
/// Serde default provider for `CacheConfig::shared_storage`; supplies
/// `constants::DEFAULT_SHARED_STORAGE` when the field is absent from the input.
fn default_shared_storage() -> bool {
    constants::DEFAULT_SHARED_STORAGE
}
33
/// Serde default provider for `CacheConfig::transfer_chunk_size`; supplies
/// `constants::DEFAULT_TRANSFER_CHUNK_SIZE` when the field is absent from the input.
fn default_transfer_chunk_size() -> usize {
    constants::DEFAULT_TRANSFER_CHUNK_SIZE
}
37
38impl Default for CacheConfig {
39    fn default() -> Self {
40        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
41        Self {
42            local_path: PathBuf::from(home).join(constants::DEFAULT_CACHE_PATH),
43            server_endpoint: format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT),
44            timeout_secs: None,
45            shared_storage: constants::DEFAULT_SHARED_STORAGE,
46            transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
47        }
48    }
49}
50
51impl CacheConfig {
52    /// Discover cache configuration
53    pub fn discover() -> Result<Self> {
54        // Priority order:
55        // 1. Command line argument (--cache-path)
56        // 2. Environment variable (MODEL_EXPRESS_CACHE_DIRECTORY)
57        // 3. Config file (~/.model-express/config.yaml)
58        // 4. Auto-detection (common paths)
59        // 5. Default fallback
60
61        // Try command line args first
62        if let Some(path) = Self::get_cache_path_from_args() {
63            return Self::from_path(path);
64        }
65
66        // Try environment variable
67        if let Ok(path) = env::var("MODEL_EXPRESS_CACHE_DIRECTORY") {
68            return Self::from_path(path);
69        }
70
71        // Try config file
72        if let Ok(config) = Self::from_config_file() {
73            return Ok(config);
74        }
75
76        // Try auto-detection
77        if let Ok(config) = Self::auto_detect() {
78            return Ok(config);
79        }
80
81        // Use default configuration as fallback
82        debug!("Using default cache configuration");
83        Ok(Self::default())
84    }
85
86    /// Create a cache configuration with explicit parameters
87    pub fn new(local_path: PathBuf, server_endpoint: Option<String>) -> Result<Self> {
88        // Ensure the directory exists
89        fs::create_dir_all(&local_path)
90            .with_context(|| format!("Failed to create cache directory: {local_path:?}"))?;
91
92        Ok(Self {
93            local_path,
94            server_endpoint: server_endpoint.unwrap_or_else(Self::get_default_server_endpoint),
95            timeout_secs: None,
96            shared_storage: constants::DEFAULT_SHARED_STORAGE,
97            transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
98        })
99    }
100
101    /// Create config from a specific path
102    pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
103        let local_path = path.as_ref().to_path_buf();
104
105        // Ensure the directory exists
106        fs::create_dir_all(&local_path)
107            .with_context(|| format!("Failed to create cache directory: {local_path:?}"))?;
108
109        Ok(Self {
110            local_path,
111            server_endpoint: Self::get_default_server_endpoint(),
112            timeout_secs: None,
113            shared_storage: constants::DEFAULT_SHARED_STORAGE,
114            transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
115        })
116    }
117
118    /// Load configuration from file
119    pub fn from_config_file() -> Result<Self> {
120        let config_path = Self::get_config_path()?;
121
122        if !config_path.exists() {
123            return Err(anyhow::anyhow!("Config file not found: {:?}", config_path));
124        }
125
126        let content = fs::read_to_string(&config_path)
127            .with_context(|| format!("Failed to read config file: {config_path:?}"))?;
128
129        let config: Self = serde_yaml::from_str(&content)
130            .with_context(|| format!("Failed to parse config file: {config_path:?}"))?;
131
132        Ok(config)
133    }
134
135    /// Save configuration to file
136    pub fn save_to_config_file(&self) -> Result<()> {
137        let config_path = Self::get_config_path()?;
138
139        // Ensure config directory exists
140        if let Some(parent) = config_path.parent() {
141            fs::create_dir_all(parent)
142                .with_context(|| format!("Failed to create config directory: {parent:?}"))?;
143        }
144
145        let content = serde_yaml::to_string(self).context("Failed to serialize config")?;
146
147        fs::write(&config_path, content)
148            .with_context(|| format!("Failed to write config file: {config_path:?}"))?;
149
150        Ok(())
151    }
152
153    /// Auto-detect cache configuration
154    pub fn auto_detect() -> Result<Self> {
155        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
156        let common_paths = vec![
157            PathBuf::from(&home).join(constants::DEFAULT_CACHE_PATH),
158            PathBuf::from(&home).join(constants::DEFAULT_HF_CACHE_PATH),
159            PathBuf::from("/cache"),
160            PathBuf::from("/app/models"),
161            PathBuf::from("./cache"),
162            PathBuf::from("./models"),
163        ];
164
165        for path in common_paths {
166            if path.exists() && path.is_dir() {
167                return Ok(Self {
168                    local_path: path,
169                    server_endpoint: Self::get_default_server_endpoint(),
170                    timeout_secs: None,
171                    shared_storage: constants::DEFAULT_SHARED_STORAGE,
172                    transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
173                });
174            }
175        }
176
177        Err(anyhow::anyhow!(
178            "No cache directory found in common locations"
179        ))
180    }
181
182    /// Query server for cache information
183    pub fn from_server() -> Result<Self> {
184        // This would typically make an HTTP request to the server
185        // For now, we'll return an error to indicate server is not available
186        Err(anyhow::anyhow!("Server not available for cache discovery"))
187    }
188
189    /// Get cache path from command line arguments
190    fn get_cache_path_from_args() -> Option<String> {
191        let args: Vec<String> = env::args().collect();
192
193        for (i, arg) in args.iter().enumerate() {
194            if arg == "--cache-path"
195                && let Some(next_arg) = args.get(i.saturating_add(1))
196            {
197                return Some(next_arg.clone());
198            }
199        }
200
201        None
202    }
203
204    /// Get default server endpoint
205    fn get_default_server_endpoint() -> String {
206        env::var("MODEL_EXPRESS_SERVER_ENDPOINT")
207            .unwrap_or_else(|_| format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT))
208    }
209
210    /// Get configuration file path
211    fn get_config_path() -> Result<PathBuf> {
212        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
213
214        Ok(PathBuf::from(home).join(constants::DEFAULT_CONFIG_PATH))
215    }
216
217    /// Convert a Hugging Face folder name back to the original model ID
218    /// Examples:
219    /// - "models--google-t5--t5-small" -> "google-t5/t5-small"
220    pub fn folder_name_to_model_id(folder_name: &str) -> String {
221        // Handle models
222        if let Some(stripped) = folder_name.strip_prefix("models--") {
223            // Convert models--owner--repo to owner/repo
224            stripped.replace("--", "/")
225        } else if folder_name.starts_with("datasets--") {
226            // TODO: Handle datasets names conversion
227            folder_name.to_string()
228        } else if folder_name.starts_with("spaces--") {
229            // TODO: Handle spaces names conversion
230            folder_name.to_string()
231        } else {
232            // If it doesn't match the expected pattern, return as-is
233            folder_name.to_string()
234        }
235    }
236
237    /// Get cache statistics
238    pub fn get_cache_stats(&self) -> Result<CacheStats> {
239        let mut stats = CacheStats {
240            total_models: 0,
241            total_size: 0,
242            models: Vec::new(),
243        };
244
245        if !self.local_path.exists() {
246            return Ok(stats);
247        }
248
249        for entry in fs::read_dir(&self.local_path)? {
250            let entry = entry?;
251            let path = entry.path();
252
253            if path.is_dir() {
254                let size = Self::get_directory_size(&path)?;
255                let folder_name = path
256                    .file_name()
257                    .and_then(|n| n.to_str())
258                    .unwrap_or("unknown")
259                    .to_string();
260                info!("Folder name: {}", folder_name);
261                // Convert folder name back to human-readable model ID
262                let model_name = Self::folder_name_to_model_id(&folder_name);
263
264                stats.total_models = stats.total_models.saturating_add(1);
265                stats.total_size = stats.total_size.saturating_add(size);
266                stats.models.push(ModelInfo {
267                    name: model_name,
268                    size,
269                    path: path.to_path_buf(),
270                });
271            }
272        }
273
274        Ok(stats)
275    }
276
277    /// Get directory size recursively
278    fn get_directory_size(path: &Path) -> Result<u64> {
279        let mut size: u64 = 0;
280
281        for entry in fs::read_dir(path)? {
282            let entry = entry?;
283            let path = entry.path();
284
285            if path.is_file() {
286                size = size.saturating_add(fs::metadata(&path)?.len());
287            } else if path.is_dir() {
288                size = size.saturating_add(Self::get_directory_size(&path)?);
289            }
290        }
291
292        Ok(size)
293    }
294
295    /// Clear specific model from cache
296    pub fn clear_model(&self, model_name: &str) -> Result<()> {
297        let model_path = self.local_path.join(model_name);
298
299        if model_path.exists() {
300            fs::remove_dir_all(&model_path)
301                .with_context(|| format!("Failed to remove model: {model_path:?}"))?;
302            info!("Cleared model: {}", model_name);
303        } else {
304            warn!("Model not found in cache: {}", model_name);
305        }
306
307        Ok(())
308    }
309
310    /// Clear entire cache
311    pub fn clear_all(&self) -> Result<()> {
312        if self.local_path.exists() {
313            for entry in fs::read_dir(&self.local_path)
314                .with_context(|| format!("Failed to read cache directory: {:?}", self.local_path))?
315            {
316                let entry = entry
317                    .with_context(|| format!("Failed to read entry in: {:?}", self.local_path))?;
318                let path = entry.path();
319                if path.is_dir() {
320                    fs::remove_dir_all(&path)
321                        .with_context(|| format!("Failed to remove directory: {:?}", path))?;
322                } else {
323                    fs::remove_file(&path)
324                        .with_context(|| format!("Failed to remove file: {:?}", path))?;
325                }
326            }
327            info!("Cleared entire cache");
328        } else {
329            warn!("Cache directory does not exist");
330        }
331
332        Ok(())
333    }
334}
335
/// Aggregate view of the cache contents, as produced by `CacheConfig::get_cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of model directories found.
    pub total_models: usize,
    /// Combined size of all models, in bytes.
    pub total_size: u64,
    /// Per-model details.
    pub models: Vec<ModelInfo>,
}

/// A single cached model.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model name (the converted folder name).
    pub name: String,
    /// Size of the model directory, in bytes.
    pub size: u64,
    /// Location of the model directory on disk.
    pub path: PathBuf,
}

impl CacheStats {
    /// Render a byte count for display.
    ///
    /// Uses 1024-based units with two decimals ("KB", "MB", "GB"); values
    /// below 1024 are printed as a whole number of bytes. Nothing above GB is
    /// used, so very large values are still reported in GB.
    fn format_bytes(bytes: u64) -> String {
        // Largest unit first, so the first threshold reached is the one used.
        const UNITS: [(u64, &str); 3] = [
            (1024 * 1024 * 1024, "GB"),
            (1024 * 1024, "MB"),
            (1024, "KB"),
        ];

        for &(threshold, label) in UNITS.iter() {
            if bytes >= threshold {
                return format!("{:.2} {}", bytes as f64 / threshold as f64, label);
            }
        }

        format!("{bytes} bytes")
    }

    /// Human-readable form of the combined cache size.
    pub fn format_total_size(&self) -> String {
        Self::format_bytes(self.total_size)
    }

    /// Human-readable form of one model's size.
    pub fn format_model_size(&self, model: &ModelInfo) -> String {
        Self::format_bytes(model.size)
    }
}
377
378#[cfg(test)]
379mod tests {
380    use super::*;
381    use crate::Utils;
382    use tempfile::TempDir;
383
384    #[test]
385    #[allow(clippy::expect_used)]
386    fn test_cache_config_from_path() {
387        let temp_dir = TempDir::new().expect("Failed to create temp directory");
388        let config =
389            CacheConfig::from_path(temp_dir.path()).expect("Failed to create config from path");
390
391        assert_eq!(config.local_path, temp_dir.path());
392    }
393
394    #[test]
395    #[allow(clippy::expect_used)]
396    fn test_cache_config_save_and_load() {
397        let temp_dir = TempDir::new().expect("Failed to create temp directory");
398        let original_config = CacheConfig {
399            local_path: temp_dir.path().join("cache"),
400            server_endpoint: "http://localhost:8001".to_string(),
401            timeout_secs: Some(30),
402            shared_storage: false,
403            transfer_chunk_size: 64 * 1024,
404        };
405
406        // Save config
407        original_config
408            .save_to_config_file()
409            .expect("Failed to save config");
410
411        // Load config
412        let loaded_config = CacheConfig::from_config_file().expect("Failed to load config");
413
414        assert_eq!(loaded_config.local_path, original_config.local_path);
415        assert_eq!(
416            loaded_config.server_endpoint,
417            original_config.server_endpoint
418        );
419        assert_eq!(loaded_config.timeout_secs, original_config.timeout_secs);
420        assert_eq!(loaded_config.shared_storage, original_config.shared_storage);
421        assert_eq!(
422            loaded_config.transfer_chunk_size,
423            original_config.transfer_chunk_size
424        );
425    }
426
427    #[test]
428    fn test_cache_stats_formatting() {
429        let stats = CacheStats {
430            total_models: 2,
431            total_size: 1024 * 1024 * 5, // 5 MB
432            models: vec![
433                ModelInfo {
434                    name: "model1".to_string(),
435                    size: 1024 * 1024 * 2, // 2 MB
436                    path: PathBuf::from("/test/model1"),
437                },
438                ModelInfo {
439                    name: "model2".to_string(),
440                    size: 1024 * 1024 * 3, // 3 MB
441                    path: PathBuf::from("/test/model2"),
442                },
443            ],
444        };
445
446        assert_eq!(stats.format_total_size(), "5.00 MB");
447        assert_eq!(stats.format_model_size(&stats.models[0]), "2.00 MB");
448        assert_eq!(stats.format_model_size(&stats.models[1]), "3.00 MB");
449    }
450
451    #[test]
452    fn test_cache_config_default() {
453        let config = CacheConfig::default();
454
455        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
456        assert_eq!(
457            config.local_path,
458            PathBuf::from(&home).join(constants::DEFAULT_CACHE_PATH)
459        );
460        assert_eq!(
461            config.server_endpoint,
462            String::from("http://localhost:8001")
463        );
464        assert_eq!(config.timeout_secs, None);
465        assert!(config.shared_storage);
466        assert_eq!(
467            config.transfer_chunk_size,
468            constants::DEFAULT_TRANSFER_CHUNK_SIZE
469        );
470    }
471
472    #[test]
473    #[allow(clippy::expect_used)]
474    fn test_get_config_path() {
475        let config_path = CacheConfig::get_config_path().expect("Failed to get config path");
476
477        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
478        assert_eq!(
479            config_path,
480            PathBuf::from(&home).join(constants::DEFAULT_CONFIG_PATH)
481        );
482    }
483
484    #[test]
485    fn test_folder_name_to_model_id() {
486        // Test models conversion
487        assert_eq!(
488            CacheConfig::folder_name_to_model_id("models--google-t5--t5-small"),
489            "google-t5/t5-small"
490        );
491        assert_eq!(
492            CacheConfig::folder_name_to_model_id("models--microsoft--DialoGPT-medium"),
493            "microsoft/DialoGPT-medium"
494        );
495        assert_eq!(
496            CacheConfig::folder_name_to_model_id("models--huggingface--CodeBERTa-small-v1"),
497            "huggingface/CodeBERTa-small-v1"
498        );
499
500        // Test single name models (no organization)
501        assert_eq!(
502            CacheConfig::folder_name_to_model_id("models--bert-base-uncased"),
503            "bert-base-uncased"
504        );
505
506        // Test datasets (TODO - should return as-is for now)
507        assert_eq!(
508            CacheConfig::folder_name_to_model_id("datasets--squad"),
509            "datasets--squad"
510        );
511        assert_eq!(
512            CacheConfig::folder_name_to_model_id("datasets--huggingface--squad"),
513            "datasets--huggingface--squad"
514        );
515
516        // Test spaces (TODO - should return as-is for now)
517        assert_eq!(
518            CacheConfig::folder_name_to_model_id("spaces--gradio--hello-world"),
519            "spaces--gradio--hello-world"
520        );
521
522        // Test unrecognized patterns (should return as-is)
523        assert_eq!(
524            CacheConfig::folder_name_to_model_id("random-folder-name"),
525            "random-folder-name"
526        );
527        assert_eq!(
528            CacheConfig::folder_name_to_model_id("some--other--format"),
529            "some--other--format"
530        );
531
532        // Test edge cases
533        assert_eq!(CacheConfig::folder_name_to_model_id("models--"), "");
534        assert_eq!(
535            CacheConfig::folder_name_to_model_id("models--single"),
536            "single"
537        );
538    }
539
540    #[test]
541    #[allow(clippy::expect_used)]
542    fn test_clear_all_removes_contents_but_keeps_directory() {
543        let temp_dir = TempDir::new().expect("Failed to create temp directory");
544        let cache_path = temp_dir.path().join("cache");
545        fs::create_dir_all(&cache_path).expect("Failed to create cache directory");
546
547        // Create some test content
548        let model_dir = cache_path.join("models--test--model");
549        fs::create_dir_all(&model_dir).expect("Failed to create model directory");
550        fs::write(model_dir.join("config.json"), "{}").expect("Failed to write file");
551        fs::write(cache_path.join("test_file.txt"), "test").expect("Failed to write file");
552
553        let config = CacheConfig {
554            local_path: cache_path.clone(),
555            server_endpoint: "http://localhost:8001".to_string(),
556            timeout_secs: None,
557            shared_storage: false,
558            transfer_chunk_size: 64 * 1024,
559        };
560
561        // Clear cache
562        config.clear_all().expect("Failed to clear cache");
563
564        // Directory should still exist but be empty
565        assert!(cache_path.exists(), "Cache directory should still exist");
566        assert!(
567            fs::read_dir(&cache_path)
568                .expect("Failed to read dir")
569                .next()
570                .is_none(),
571            "Cache directory should be empty"
572        );
573    }
574
575    #[test]
576    #[allow(clippy::expect_used)]
577    fn test_clear_all_handles_nonexistent_directory() {
578        let temp_dir = TempDir::new().expect("Failed to create temp directory");
579        let cache_path = temp_dir.path().join("nonexistent_cache");
580
581        let config = CacheConfig {
582            local_path: cache_path.clone(),
583            server_endpoint: "http://localhost:8001".to_string(),
584            timeout_secs: None,
585            shared_storage: false,
586            transfer_chunk_size: 64 * 1024,
587        };
588
589        // Should succeed without error even if directory doesn't exist
590        config
591            .clear_all()
592            .with_context(|| format!("Failed to clear cache: {cache_path:?}"))
593            .expect("Failed to clear cache");
594        assert!(!cache_path.exists());
595    }
596
597    #[test]
598    #[allow(clippy::expect_used)]
599    fn test_clear_all_removes_nested_directories() {
600        let temp_dir = TempDir::new().expect("Failed to create temp directory");
601        let cache_path = temp_dir.path().join("cache");
602        fs::create_dir_all(&cache_path).expect("Failed to create cache directory");
603
604        // Create nested structure
605        let deep_path = cache_path.join("a").join("b").join("c");
606        fs::create_dir_all(&deep_path).expect("Failed to create nested directories");
607        fs::write(deep_path.join("deep_file.txt"), "deep").expect("Failed to write file");
608
609        let config = CacheConfig {
610            local_path: cache_path.clone(),
611            server_endpoint: "http://localhost:8001".to_string(),
612            timeout_secs: None,
613            shared_storage: false,
614            transfer_chunk_size: 64 * 1024,
615        };
616
617        config.clear_all().expect("Failed to clear cache");
618
619        assert!(cache_path.exists(), "Cache directory should still exist");
620        assert!(
621            fs::read_dir(&cache_path)
622                .expect("Failed to read dir")
623                .next()
624                .is_none(),
625            "Cache directory should be empty after clearing nested content"
626        );
627    }
628}