modelexpress_common/
cache.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{Utils, constants};
5use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::env;
8use std::fs;
9use std::path::{Path, PathBuf};
10use tracing::{debug, info, warn};
11
/// Configuration for model cache management.
///
/// Serialized to and from YAML by `save_to_config_file` / `from_config_file`.
/// No serde attributes are applied, so the field names below are the YAML keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Local directory where models are cached.
    pub local_path: PathBuf,
    /// Server endpoint URL used for model downloads (e.g. `http://localhost:<port>`).
    pub server_endpoint: String,
    /// Optional timeout for cache operations, in seconds per the field name.
    /// NOTE(review): nothing in this file reads it; confirm semantics at call sites.
    pub timeout_secs: Option<u64>,
}
22
23impl Default for CacheConfig {
24    fn default() -> Self {
25        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
26        Self {
27            local_path: PathBuf::from(home).join(constants::DEFAULT_CACHE_PATH),
28            server_endpoint: format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT),
29            timeout_secs: None,
30        }
31    }
32}
33
34impl CacheConfig {
35    /// Discover cache configuration
36    pub fn discover() -> Result<Self> {
37        // Priority order:
38        // 1. Command line argument (--cache-path)
39        // 2. Environment variable (MODEL_EXPRESS_CACHE_DIRECTORY)
40        // 3. Config file (~/.model-express/config.yaml)
41        // 4. Auto-detection (common paths)
42        // 5. Default fallback
43
44        // Try command line args first
45        if let Some(path) = Self::get_cache_path_from_args() {
46            return Self::from_path(path);
47        }
48
49        // Try environment variable
50        if let Ok(path) = env::var("MODEL_EXPRESS_CACHE_DIRECTORY") {
51            return Self::from_path(path);
52        }
53
54        // Try config file
55        if let Ok(config) = Self::from_config_file() {
56            return Ok(config);
57        }
58
59        // Try auto-detection
60        if let Ok(config) = Self::auto_detect() {
61            return Ok(config);
62        }
63
64        // Use default configuration as fallback
65        debug!("Using default cache configuration");
66        Ok(Self::default())
67    }
68
69    /// Create a cache configuration with explicit parameters
70    pub fn new(local_path: PathBuf, server_endpoint: Option<String>) -> Result<Self> {
71        // Ensure the directory exists
72        fs::create_dir_all(&local_path)
73            .with_context(|| format!("Failed to create cache directory: {local_path:?}"))?;
74
75        Ok(Self {
76            local_path,
77            server_endpoint: server_endpoint.unwrap_or_else(Self::get_default_server_endpoint),
78            timeout_secs: None,
79        })
80    }
81
82    /// Create config from a specific path
83    pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
84        let local_path = path.as_ref().to_path_buf();
85
86        // Ensure the directory exists
87        fs::create_dir_all(&local_path)
88            .with_context(|| format!("Failed to create cache directory: {local_path:?}"))?;
89
90        Ok(Self {
91            local_path,
92            server_endpoint: Self::get_default_server_endpoint(),
93            timeout_secs: None,
94        })
95    }
96
97    /// Load configuration from file
98    pub fn from_config_file() -> Result<Self> {
99        let config_path = Self::get_config_path()?;
100
101        if !config_path.exists() {
102            return Err(anyhow::anyhow!("Config file not found: {:?}", config_path));
103        }
104
105        let content = fs::read_to_string(&config_path)
106            .with_context(|| format!("Failed to read config file: {config_path:?}"))?;
107
108        let config: Self = serde_yaml::from_str(&content)
109            .with_context(|| format!("Failed to parse config file: {config_path:?}"))?;
110
111        Ok(config)
112    }
113
114    /// Save configuration to file
115    pub fn save_to_config_file(&self) -> Result<()> {
116        let config_path = Self::get_config_path()?;
117
118        // Ensure config directory exists
119        if let Some(parent) = config_path.parent() {
120            fs::create_dir_all(parent)
121                .with_context(|| format!("Failed to create config directory: {parent:?}"))?;
122        }
123
124        let content = serde_yaml::to_string(self).context("Failed to serialize config")?;
125
126        fs::write(&config_path, content)
127            .with_context(|| format!("Failed to write config file: {config_path:?}"))?;
128
129        Ok(())
130    }
131
132    /// Auto-detect cache configuration
133    pub fn auto_detect() -> Result<Self> {
134        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
135        let common_paths = vec![
136            PathBuf::from(&home).join(constants::DEFAULT_CACHE_PATH),
137            PathBuf::from(&home).join(constants::DEFAULT_HF_CACHE_PATH),
138            PathBuf::from("/cache"),
139            PathBuf::from("/app/models"),
140            PathBuf::from("./cache"),
141            PathBuf::from("./models"),
142        ];
143
144        for path in common_paths {
145            if path.exists() && path.is_dir() {
146                return Ok(Self {
147                    local_path: path,
148                    server_endpoint: Self::get_default_server_endpoint(),
149                    timeout_secs: None,
150                });
151            }
152        }
153
154        Err(anyhow::anyhow!(
155            "No cache directory found in common locations"
156        ))
157    }
158
159    /// Query server for cache information
160    pub fn from_server() -> Result<Self> {
161        // This would typically make an HTTP request to the server
162        // For now, we'll return an error to indicate server is not available
163        Err(anyhow::anyhow!("Server not available for cache discovery"))
164    }
165
166    /// Get cache path from command line arguments
167    fn get_cache_path_from_args() -> Option<String> {
168        let args: Vec<String> = env::args().collect();
169
170        for (i, arg) in args.iter().enumerate() {
171            if arg == "--cache-path"
172                && let Some(next_arg) = args.get(i.saturating_add(1))
173            {
174                return Some(next_arg.clone());
175            }
176        }
177
178        None
179    }
180
181    /// Get default server endpoint
182    fn get_default_server_endpoint() -> String {
183        env::var("MODEL_EXPRESS_SERVER_ENDPOINT")
184            .unwrap_or_else(|_| format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT))
185    }
186
187    /// Get configuration file path
188    fn get_config_path() -> Result<PathBuf> {
189        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
190
191        Ok(PathBuf::from(home).join(constants::DEFAULT_CONFIG_PATH))
192    }
193
194    /// Convert a Hugging Face folder name back to the original model ID
195    /// Examples:
196    /// - "models--google-t5--t5-small" -> "google-t5/t5-small"
197    pub fn folder_name_to_model_id(folder_name: &str) -> String {
198        // Handle models
199        if let Some(stripped) = folder_name.strip_prefix("models--") {
200            // Convert models--owner--repo to owner/repo
201            stripped.replace("--", "/")
202        } else if folder_name.starts_with("datasets--") {
203            // TODO: Handle datasets names conversion
204            folder_name.to_string()
205        } else if folder_name.starts_with("spaces--") {
206            // TODO: Handle spaces names conversion
207            folder_name.to_string()
208        } else {
209            // If it doesn't match the expected pattern, return as-is
210            folder_name.to_string()
211        }
212    }
213
214    /// Get cache statistics
215    pub fn get_cache_stats(&self) -> Result<CacheStats> {
216        let mut stats = CacheStats {
217            total_models: 0,
218            total_size: 0,
219            models: Vec::new(),
220        };
221
222        if !self.local_path.exists() {
223            return Ok(stats);
224        }
225
226        for entry in fs::read_dir(&self.local_path)? {
227            let entry = entry?;
228            let path = entry.path();
229
230            if path.is_dir() {
231                let size = Self::get_directory_size(&path)?;
232                let folder_name = path
233                    .file_name()
234                    .and_then(|n| n.to_str())
235                    .unwrap_or("unknown")
236                    .to_string();
237                info!("Folder name: {}", folder_name);
238                // Convert folder name back to human-readable model ID
239                let model_name = Self::folder_name_to_model_id(&folder_name);
240
241                stats.total_models = stats.total_models.saturating_add(1);
242                stats.total_size = stats.total_size.saturating_add(size);
243                stats.models.push(ModelInfo {
244                    name: model_name,
245                    size,
246                    path: path.to_path_buf(),
247                });
248            }
249        }
250
251        Ok(stats)
252    }
253
254    /// Get directory size recursively
255    fn get_directory_size(path: &Path) -> Result<u64> {
256        let mut size: u64 = 0;
257
258        for entry in fs::read_dir(path)? {
259            let entry = entry?;
260            let path = entry.path();
261
262            if path.is_file() {
263                size = size.saturating_add(fs::metadata(&path)?.len());
264            } else if path.is_dir() {
265                size = size.saturating_add(Self::get_directory_size(&path)?);
266            }
267        }
268
269        Ok(size)
270    }
271
272    /// Clear specific model from cache
273    pub fn clear_model(&self, model_name: &str) -> Result<()> {
274        let model_path = self.local_path.join(model_name);
275
276        if model_path.exists() {
277            fs::remove_dir_all(&model_path)
278                .with_context(|| format!("Failed to remove model: {model_path:?}"))?;
279            info!("Cleared model: {}", model_name);
280        } else {
281            warn!("Model not found in cache: {}", model_name);
282        }
283
284        Ok(())
285    }
286
287    /// Clear entire cache
288    pub fn clear_all(&self) -> Result<()> {
289        if self.local_path.exists() {
290            fs::remove_dir_all(&self.local_path)
291                .with_context(|| format!("Failed to clear cache: {:?}", self.local_path))?;
292            info!("Cleared entire cache");
293        } else {
294            warn!("Cache directory does not exist");
295        }
296
297        Ok(())
298    }
299}
300
301/// Cache statistics
302#[derive(Debug, Clone)]
303pub struct CacheStats {
304    pub total_models: usize,
305    pub total_size: u64,
306    pub models: Vec<ModelInfo>,
307}
308
/// Information about a single cached model.
#[derive(Clone, Debug)]
pub struct ModelInfo {
    /// Human-readable model identifier.
    pub name: String,
    /// Total length of the model's files, in bytes.
    pub size: u64,
    /// Location of the model's directory.
    pub path: PathBuf,
}
316
317impl CacheStats {
318    /// Format bytes as human readable string
319    fn format_bytes(bytes: u64) -> String {
320        const KB: u64 = 1024;
321        const MB: u64 = KB * 1024;
322        const GB: u64 = MB * 1024;
323
324        match bytes {
325            size if size >= GB => format!("{:.2} GB", size as f64 / GB as f64),
326            size if size >= MB => format!("{:.2} MB", size as f64 / MB as f64),
327            size if size >= KB => format!("{:.2} KB", size as f64 / KB as f64),
328            size => format!("{size} bytes"),
329        }
330    }
331
332    /// Format total size as human readable string
333    pub fn format_total_size(&self) -> String {
334        Self::format_bytes(self.total_size)
335    }
336
337    /// Format individual model size as human readable string
338    pub fn format_model_size(&self, model: &ModelInfo) -> String {
339        Self::format_bytes(model.size)
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Utils;
    use tempfile::TempDir;

    /// Snapshot of the real user config file, restored when dropped.
    ///
    /// `save_to_config_file` / `from_config_file` always target the config path
    /// under the user's home directory. Without this guard, the round-trip test
    /// would overwrite a developer's real configuration. The restore runs from
    /// `Drop`, so it also happens when an assertion panics.
    struct ConfigFileGuard {
        path: PathBuf,
        existed: bool,
        original: Option<Vec<u8>>,
    }

    impl ConfigFileGuard {
        fn new(path: PathBuf) -> Self {
            let existed = path.exists();
            let original = if existed { fs::read(&path).ok() } else { None };
            Self {
                path,
                existed,
                original,
            }
        }
    }

    impl Drop for ConfigFileGuard {
        fn drop(&mut self) {
            // Best effort: `drop` cannot propagate errors.
            let _ = match self.original.take() {
                Some(bytes) => fs::write(&self.path, bytes),
                // Only delete a file this test created. If the original existed
                // but could not be read, leave whatever is there untouched.
                None if !self.existed => fs::remove_file(&self.path),
                None => Ok(()),
            };
        }
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_cache_config_from_path() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let config =
            CacheConfig::from_path(temp_dir.path()).expect("Failed to create config from path");

        assert_eq!(config.local_path, temp_dir.path());
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_cache_config_save_and_load() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let config_path = CacheConfig::get_config_path().expect("Failed to get config path");
        // Restore the user's real config file once this test finishes.
        let _config_guard = ConfigFileGuard::new(config_path);

        let original_config = CacheConfig {
            local_path: temp_dir.path().join("cache"),
            server_endpoint: "http://localhost:8001".to_string(),
            timeout_secs: Some(30),
        };

        // Save config
        original_config
            .save_to_config_file()
            .expect("Failed to save config");

        // Load config
        let loaded_config = CacheConfig::from_config_file().expect("Failed to load config");

        assert_eq!(loaded_config.local_path, original_config.local_path);
        assert_eq!(
            loaded_config.server_endpoint,
            original_config.server_endpoint
        );
        assert_eq!(loaded_config.timeout_secs, original_config.timeout_secs);
    }

    #[test]
    fn test_cache_stats_formatting() {
        let stats = CacheStats {
            total_models: 2,
            total_size: 1024 * 1024 * 5, // 5 MB
            models: vec![
                ModelInfo {
                    name: "model1".to_string(),
                    size: 1024 * 1024 * 2, // 2 MB
                    path: PathBuf::from("/test/model1"),
                },
                ModelInfo {
                    name: "model2".to_string(),
                    size: 1024 * 1024 * 3, // 3 MB
                    path: PathBuf::from("/test/model2"),
                },
            ],
        };

        assert_eq!(stats.format_total_size(), "5.00 MB");
        assert_eq!(stats.format_model_size(&stats.models[0]), "2.00 MB");
        assert_eq!(stats.format_model_size(&stats.models[1]), "3.00 MB");
    }

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();

        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
        assert_eq!(
            config.local_path,
            PathBuf::from(&home).join(constants::DEFAULT_CACHE_PATH)
        );
        assert_eq!(
            config.server_endpoint,
            String::from("http://localhost:8001")
        );
        assert_eq!(config.timeout_secs, None);
    }

    #[test]
    #[allow(clippy::expect_used)]
    fn test_get_config_path() {
        let config_path = CacheConfig::get_config_path().expect("Failed to get config path");

        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
        assert_eq!(
            config_path,
            PathBuf::from(&home).join(constants::DEFAULT_CONFIG_PATH)
        );
    }

    #[test]
    fn test_folder_name_to_model_id() {
        // Test models conversion
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--google-t5--t5-small"),
            "google-t5/t5-small"
        );
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--microsoft--DialoGPT-medium"),
            "microsoft/DialoGPT-medium"
        );
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--huggingface--CodeBERTa-small-v1"),
            "huggingface/CodeBERTa-small-v1"
        );

        // Test single name models (no organization)
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--bert-base-uncased"),
            "bert-base-uncased"
        );

        // Test datasets (TODO - should return as-is for now)
        assert_eq!(
            CacheConfig::folder_name_to_model_id("datasets--squad"),
            "datasets--squad"
        );
        assert_eq!(
            CacheConfig::folder_name_to_model_id("datasets--huggingface--squad"),
            "datasets--huggingface--squad"
        );

        // Test spaces (TODO - should return as-is for now)
        assert_eq!(
            CacheConfig::folder_name_to_model_id("spaces--gradio--hello-world"),
            "spaces--gradio--hello-world"
        );

        // Test unrecognized patterns (should return as-is)
        assert_eq!(
            CacheConfig::folder_name_to_model_id("random-folder-name"),
            "random-folder-name"
        );
        assert_eq!(
            CacheConfig::folder_name_to_model_id("some--other--format"),
            "some--other--format"
        );

        // Test edge cases
        assert_eq!(CacheConfig::folder_name_to_model_id("models--"), "");
        assert_eq!(
            CacheConfig::folder_name_to_model_id("models--single"),
            "single"
        );
    }
}