Skip to main content

modelexpress_common/
cache.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    Utils, constants, models::ModelProvider, providers::huggingface::HuggingFaceProviderCache,
6};
7use anyhow::{Context, Result};
8use serde::{Deserialize, Serialize};
9use std::env;
10use std::fs;
11use std::path::{Path, PathBuf};
12use tracing::{debug, info, warn};
13
/// Configuration for model cache management
///
/// The struct derives `Serialize`/`Deserialize`; `from_config_file` and
/// `save_to_config_file` round-trip it through YAML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Local path where models are cached
    pub local_path: PathBuf,
    /// Server endpoint for model downloads
    pub server_endpoint: String,
    /// Timeout for cache operations (`None` when no timeout is configured)
    pub timeout_secs: Option<u64>,
    /// Whether to use shared storage mode (client and server share a network drive)
    /// When false, files will be streamed from server to client
    // When the key is missing from deserialized input, serde calls
    // `default_shared_storage` to fill it in.
    #[serde(default = "default_shared_storage")]
    pub shared_storage: bool,
    /// Chunk size in bytes for file transfer streaming when shared_storage is false
    // When the key is missing from deserialized input, serde calls
    // `default_transfer_chunk_size` to fill it in.
    #[serde(default = "default_transfer_chunk_size")]
    pub transfer_chunk_size: usize,
}
31
32fn default_shared_storage() -> bool {
33    constants::DEFAULT_SHARED_STORAGE
34}
35
36fn default_transfer_chunk_size() -> usize {
37    constants::DEFAULT_TRANSFER_CHUNK_SIZE
38}
39
40impl Default for CacheConfig {
41    fn default() -> Self {
42        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
43        Self {
44            local_path: PathBuf::from(home).join(constants::DEFAULT_CACHE_PATH),
45            server_endpoint: format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT),
46            timeout_secs: None,
47            shared_storage: constants::DEFAULT_SHARED_STORAGE,
48            transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
49        }
50    }
51}
52
53impl CacheConfig {
54    /// Discover cache configuration
55    pub fn discover() -> Result<Self> {
56        // Priority order:
57        // 1. Command line argument (--cache-path)
58        // 2. Environment variable (MODEL_EXPRESS_CACHE_DIRECTORY)
59        // 3. Config file (~/.model-express/config.yaml)
60        // 4. Auto-detection (common paths)
61        // 5. Default fallback
62
63        // Try command line args first
64        if let Some(path) = Self::get_cache_path_from_args() {
65            return Self::from_path(path);
66        }
67
68        // Try environment variable
69        if let Ok(path) = env::var("MODEL_EXPRESS_CACHE_DIRECTORY") {
70            return Self::from_path(path);
71        }
72
73        // Try config file
74        if let Ok(config) = Self::from_config_file() {
75            return Ok(config);
76        }
77
78        // Try auto-detection
79        if let Ok(config) = Self::auto_detect() {
80            return Ok(config);
81        }
82
83        // Use default configuration as fallback
84        debug!("Using default cache configuration");
85        Ok(Self::default())
86    }
87
88    /// Create a cache configuration with explicit parameters
89    pub fn new(local_path: PathBuf, server_endpoint: Option<String>) -> Result<Self> {
90        // Ensure the directory exists
91        fs::create_dir_all(&local_path)
92            .with_context(|| format!("Failed to create cache directory: {local_path:?}"))?;
93
94        Ok(Self {
95            local_path,
96            server_endpoint: server_endpoint.unwrap_or_else(Self::get_default_server_endpoint),
97            timeout_secs: None,
98            shared_storage: constants::DEFAULT_SHARED_STORAGE,
99            transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
100        })
101    }
102
103    /// Create config from a specific path
104    pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
105        let local_path = path.as_ref().to_path_buf();
106
107        // Ensure the directory exists
108        fs::create_dir_all(&local_path)
109            .with_context(|| format!("Failed to create cache directory: {local_path:?}"))?;
110
111        Ok(Self {
112            local_path,
113            server_endpoint: Self::get_default_server_endpoint(),
114            timeout_secs: None,
115            shared_storage: constants::DEFAULT_SHARED_STORAGE,
116            transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
117        })
118    }
119
120    /// Load configuration from file
121    pub fn from_config_file() -> Result<Self> {
122        let config_path = Self::get_config_path()?;
123
124        if !config_path.exists() {
125            return Err(anyhow::anyhow!("Config file not found: {:?}", config_path));
126        }
127
128        let content = fs::read_to_string(&config_path)
129            .with_context(|| format!("Failed to read config file: {config_path:?}"))?;
130
131        let config: Self = serde_yaml::from_str(&content)
132            .with_context(|| format!("Failed to parse config file: {config_path:?}"))?;
133
134        Ok(config)
135    }
136
137    /// Save configuration to file
138    pub fn save_to_config_file(&self) -> Result<()> {
139        let config_path = Self::get_config_path()?;
140
141        // Ensure config directory exists
142        if let Some(parent) = config_path.parent() {
143            fs::create_dir_all(parent)
144                .with_context(|| format!("Failed to create config directory: {parent:?}"))?;
145        }
146
147        let content = serde_yaml::to_string(self).context("Failed to serialize config")?;
148
149        fs::write(&config_path, content)
150            .with_context(|| format!("Failed to write config file: {config_path:?}"))?;
151
152        Ok(())
153    }
154
155    /// Auto-detect cache configuration
156    pub fn auto_detect() -> Result<Self> {
157        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
158        let common_paths = vec![
159            PathBuf::from(&home).join(constants::DEFAULT_CACHE_PATH),
160            PathBuf::from(&home).join(constants::DEFAULT_HF_CACHE_PATH),
161            PathBuf::from("/cache"),
162            PathBuf::from("/app/models"),
163            PathBuf::from("./cache"),
164            PathBuf::from("./models"),
165        ];
166
167        for path in common_paths {
168            if path.exists() && path.is_dir() {
169                return Ok(Self {
170                    local_path: path,
171                    server_endpoint: Self::get_default_server_endpoint(),
172                    timeout_secs: None,
173                    shared_storage: constants::DEFAULT_SHARED_STORAGE,
174                    transfer_chunk_size: constants::DEFAULT_TRANSFER_CHUNK_SIZE,
175                });
176            }
177        }
178
179        Err(anyhow::anyhow!(
180            "No cache directory found in common locations"
181        ))
182    }
183
184    /// Query server for cache information
185    pub fn from_server() -> Result<Self> {
186        // This would typically make an HTTP request to the server
187        // For now, we'll return an error to indicate server is not available
188        Err(anyhow::anyhow!("Server not available for cache discovery"))
189    }
190
191    /// Get cache path from command line arguments
192    fn get_cache_path_from_args() -> Option<String> {
193        let args: Vec<String> = env::args().collect();
194
195        for (i, arg) in args.iter().enumerate() {
196            if arg == "--cache-path"
197                && let Some(next_arg) = args.get(i.saturating_add(1))
198            {
199                return Some(next_arg.clone());
200            }
201        }
202
203        None
204    }
205
206    /// Get default server endpoint
207    fn get_default_server_endpoint() -> String {
208        env::var("MODEL_EXPRESS_SERVER_ENDPOINT")
209            .unwrap_or_else(|_| format!("http://localhost:{}", constants::DEFAULT_GRPC_PORT))
210    }
211
212    /// Get configuration file path
213    fn get_config_path() -> Result<PathBuf> {
214        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
215
216        Ok(PathBuf::from(home).join(constants::DEFAULT_CONFIG_PATH))
217    }
218
219    /// Get cache statistics
220    pub fn get_cache_stats(&self) -> Result<CacheStats> {
221        let mut models = Vec::new();
222
223        if !self.local_path.exists() {
224            return Ok(CacheStats {
225                total_models: 0,
226                total_size: 0,
227                models,
228            });
229        }
230
231        let provider = ModelProvider::HuggingFace;
232        models.extend(cache_for_provider(provider).list_models(&self.local_path)?);
233
234        models.sort_by(|left, right| {
235            provider_sort_key(left.provider)
236                .cmp(&provider_sort_key(right.provider))
237                .then_with(|| left.name.cmp(&right.name))
238        });
239
240        let total_size = models.iter().map(|model| model.size).sum();
241
242        Ok(CacheStats {
243            total_models: models.len(),
244            total_size,
245            models,
246        })
247    }
248
249    /// Clear specific model from cache for a given provider.
250    pub fn clear_model(&self, model_name: &str, provider: ModelProvider) -> Result<()> {
251        cache_for_provider(provider).clear_model(&self.local_path, model_name)
252    }
253
254    /// Clear entire cache
255    pub fn clear_all(&self) -> Result<()> {
256        if self.local_path.exists() {
257            for entry in fs::read_dir(&self.local_path)
258                .with_context(|| format!("Failed to read cache directory: {:?}", self.local_path))?
259            {
260                let entry = entry
261                    .with_context(|| format!("Failed to read entry in: {:?}", self.local_path))?;
262                let path = entry.path();
263                if path.is_dir() {
264                    fs::remove_dir_all(&path)
265                        .with_context(|| format!("Failed to remove directory: {:?}", path))?;
266                } else {
267                    fs::remove_file(&path)
268                        .with_context(|| format!("Failed to remove file: {:?}", path))?;
269                }
270            }
271            info!("Cleared entire cache");
272        } else {
273            warn!("Cache directory does not exist");
274        }
275
276        Ok(())
277    }
278}
279
/// Cache statistics
///
/// Produced by `CacheConfig::get_cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries in `models`
    pub total_models: usize,
    /// Sum of the `size` fields of all entries in `models`
    pub total_size: u64,
    /// The individual models, as returned by the provider cache
    pub models: Vec<ModelInfo>,
}
287
/// Model information
///
/// One entry per cached model, as returned by `ProviderCache::list_models`.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Provider the model belongs to
    pub provider: ModelProvider,
    /// Model name (e.g. `google/t5-small` for a HuggingFace model)
    pub name: String,
    /// Size of the model; rendered as a byte count by `CacheStats::format_model_size`
    pub size: u64,
    /// Location of the model inside the cache
    pub path: PathBuf,
}
296
297impl CacheStats {
298    /// Format bytes as human readable string
299    fn format_bytes(bytes: u64) -> String {
300        const KB: u64 = 1024;
301        const MB: u64 = KB * 1024;
302        const GB: u64 = MB * 1024;
303
304        match bytes {
305            size if size >= GB => format!("{:.2} GB", size as f64 / GB as f64),
306            size if size >= MB => format!("{:.2} MB", size as f64 / MB as f64),
307            size if size >= KB => format!("{:.2} KB", size as f64 / KB as f64),
308            size => format!("{size} bytes"),
309        }
310    }
311
312    /// Format total size as human readable string
313    pub fn format_total_size(&self) -> String {
314        Self::format_bytes(self.total_size)
315    }
316
317    /// Format individual model size as human readable string
318    pub fn format_model_size(&self, model: &ModelInfo) -> String {
319        Self::format_bytes(model.size)
320    }
321}
322
/// Provider-specific cache operations.
///
/// `cache_for_provider` returns a `&'static dyn ProviderCache` for a given
/// `ModelProvider`; `CacheConfig` and the free function `resolve_model_path`
/// delegate to it. Implementations must be `Send + Sync`.
pub(crate) trait ProviderCache: Send + Sync {
    /// Clears `model_name` from the cache rooted at `cache_root`.
    fn clear_model(&self, cache_root: &Path, model_name: &str) -> Result<()>;
    /// Resolves the path of `model_name` under `cache_root`, optionally for a
    /// specific `revision`.
    // NOTE(review): how `revision == None` is resolved is up to the
    // implementation and is not visible in this file.
    fn resolve_model_path(
        &self,
        cache_root: &Path,
        model_name: &str,
        revision: Option<&str>,
    ) -> Result<PathBuf>;
    /// Lists the models found under `cache_root`.
    fn list_models(&self, cache_root: &Path) -> Result<Vec<ModelInfo>>;
}
333
334pub(crate) fn cache_for_provider(provider: ModelProvider) -> &'static dyn ProviderCache {
335    match provider {
336        ModelProvider::HuggingFace => &HuggingFaceProviderCache,
337    }
338}
339
340pub fn resolve_model_path(
341    cache_root: &Path,
342    provider: ModelProvider,
343    model_name: &str,
344    revision: Option<&str>,
345) -> Result<PathBuf> {
346    cache_for_provider(provider).resolve_model_path(cache_root, model_name, revision)
347}
348
349pub(crate) fn directory_size(path: &Path) -> Result<u64> {
350    let mut size: u64 = 0;
351
352    for entry in fs::read_dir(path)? {
353        let entry = entry?;
354        let path = entry.path();
355
356        if path.is_file() {
357            size = size.saturating_add(fs::metadata(&path)?.len());
358        } else if path.is_dir() {
359            size = size.saturating_add(directory_size(&path)?);
360        }
361    }
362
363    Ok(size)
364}
365
366fn provider_sort_key(provider: ModelProvider) -> u8 {
367    match provider {
368        ModelProvider::HuggingFace => 0,
369    }
370}
371
// Unit tests for cache configuration, statistics formatting and cache clearing.
// Filesystem tests work inside a fresh `TempDir`, except where noted.
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use crate::Utils;
    use tempfile::TempDir;

    // `from_path` keeps the path it was given.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_cache_config_from_path() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let config =
            CacheConfig::from_path(temp_dir.path()).expect("Failed to create config from path");

        assert_eq!(config.local_path, temp_dir.path());
    }

    // Round-trips a config through `save_to_config_file` / `from_config_file`.
    // NOTE(review): both calls go through `get_config_path`, which is based on
    // `Utils::get_home_dir()` and not on `temp_dir`, so this test writes the
    // config file under the home directory of whoever runs it — confirm that
    // is intended or isolate it.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_cache_config_save_and_load() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let original_config = CacheConfig {
            local_path: temp_dir.path().join("cache"),
            server_endpoint: "http://localhost:8001".to_string(),
            timeout_secs: Some(30),
            shared_storage: false,
            transfer_chunk_size: 64 * 1024,
        };

        // Save config
        original_config
            .save_to_config_file()
            .expect("Failed to save config");

        // Load config
        let loaded_config = CacheConfig::from_config_file().expect("Failed to load config");

        assert_eq!(loaded_config.local_path, original_config.local_path);
        assert_eq!(
            loaded_config.server_endpoint,
            original_config.server_endpoint
        );
        assert_eq!(loaded_config.timeout_secs, original_config.timeout_secs);
        assert_eq!(loaded_config.shared_storage, original_config.shared_storage);
        assert_eq!(
            loaded_config.transfer_chunk_size,
            original_config.transfer_chunk_size
        );
    }

    // Size formatting picks the MB unit and prints two decimals.
    #[test]
    fn test_cache_stats_formatting() {
        let stats = CacheStats {
            total_models: 2,
            total_size: 1024 * 1024 * 5, // 5 MB
            models: vec![
                ModelInfo {
                    provider: ModelProvider::HuggingFace,
                    name: "model1".to_string(),
                    size: 1024 * 1024 * 2, // 2 MB
                    path: PathBuf::from("/test/model1"),
                },
                ModelInfo {
                    provider: ModelProvider::HuggingFace,
                    name: "model2".to_string(),
                    size: 1024 * 1024 * 3, // 3 MB
                    path: PathBuf::from("/test/model2"),
                },
            ],
        };

        assert_eq!(stats.format_total_size(), "5.00 MB");
        assert_eq!(stats.format_model_size(&stats.models[0]), "2.00 MB");
        assert_eq!(stats.format_model_size(&stats.models[1]), "3.00 MB");
    }

    // `Default` yields the documented fallback values.
    // The endpoint and shared-storage assertions pin the current values of
    // `constants::DEFAULT_GRPC_PORT` and `constants::DEFAULT_SHARED_STORAGE`.
    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();

        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
        assert_eq!(
            config.local_path,
            PathBuf::from(&home).join(constants::DEFAULT_CACHE_PATH)
        );
        assert_eq!(
            config.server_endpoint,
            String::from("http://localhost:8001")
        );
        assert_eq!(config.timeout_secs, None);
        assert!(config.shared_storage);
        assert_eq!(
            config.transfer_chunk_size,
            constants::DEFAULT_TRANSFER_CHUNK_SIZE
        );
    }

    // The config path is `DEFAULT_CONFIG_PATH` under the home directory.
    #[test]
    #[allow(clippy::expect_used)]
    fn test_get_config_path() {
        let config_path = CacheConfig::get_config_path().expect("Failed to get config path");

        let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
        assert_eq!(
            config_path,
            PathBuf::from(&home).join(constants::DEFAULT_CONFIG_PATH)
        );
    }

    // HuggingFace models resolve to `models--<org>--<name>/snapshots/<revision>`.
    #[test]
    fn test_resolve_model_path_huggingface_uses_snapshot_layout() {
        let cache_root = Path::new("/tmp/cache");

        assert_eq!(
            resolve_model_path(
                cache_root,
                ModelProvider::HuggingFace,
                "google/t5-small",
                Some("abc123"),
            )
            .expect("Expected HF model path"),
            PathBuf::from("/tmp/cache/models--google--t5-small/snapshots/abc123")
        );
    }

    // Helper: builds a config that points at `local_path` without touching disk.
    fn create_test_cache_config(local_path: PathBuf) -> CacheConfig {
        CacheConfig {
            local_path,
            server_endpoint: "http://localhost:8001".to_string(),
            timeout_secs: None,
            shared_storage: false,
            transfer_chunk_size: 64 * 1024,
        }
    }

    // Only directories in the HuggingFace layout are reported; unrelated
    // directories such as `tmp` are ignored.
    #[test]
    fn test_get_cache_stats_supports_hf_layout() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("cache");
        fs::create_dir_all(&cache_path).expect("Failed to create cache directory");

        let hf_model_dir = cache_path.join("models--google--t5-small");
        fs::create_dir_all(&hf_model_dir).expect("Failed to create HF model directory");
        fs::write(hf_model_dir.join("config.json"), b"{}").expect("Failed to write HF file");

        let ignored_dir = cache_path.join("tmp");
        fs::create_dir_all(&ignored_dir).expect("Failed to create ignored directory");
        fs::write(ignored_dir.join("scratch.txt"), b"ignore")
            .expect("Failed to write ignored file");

        let stats = create_test_cache_config(cache_path)
            .get_cache_stats()
            .expect("Failed to get cache stats");

        // `config.json` holds the two bytes `{}`.
        assert_eq!(stats.total_models, 1);
        assert_eq!(stats.total_size, 2);
        assert_eq!(stats.models.len(), 1);

        assert_eq!(stats.models[0].provider, ModelProvider::HuggingFace);
        assert_eq!(stats.models[0].name, "google/t5-small");
        assert_eq!(stats.models[0].size, 2);
        assert_eq!(stats.models[0].path, hf_model_dir);
        assert!(stats.models.iter().all(|model| model.name != "tmp"));
    }

    // Clearing a model removes its HuggingFace directory.
    #[test]
    fn test_clear_model_removes_only_requested_layout() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("cache");
        fs::create_dir_all(&cache_path).expect("Failed to create cache directory");

        let hf_model_dir = cache_path.join("models--google--t5-small");
        fs::create_dir_all(&hf_model_dir).expect("Failed to create HF model directory");
        fs::write(hf_model_dir.join("config.json"), b"{}").expect("Failed to write HF file");

        let config = create_test_cache_config(cache_path);

        config
            .clear_model("google/t5-small", ModelProvider::HuggingFace)
            .expect("Failed to clear HF model");
        assert!(!hf_model_dir.exists(), "HF model should be removed");
    }

    // `clear_all` empties the cache directory but leaves the directory in place.
    #[test]
    fn test_clear_all_removes_contents_but_keeps_directory() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("cache");
        fs::create_dir_all(&cache_path).expect("Failed to create cache directory");

        // Create some test content
        let model_dir = cache_path.join("models--test--model");
        fs::create_dir_all(&model_dir).expect("Failed to create model directory");
        fs::write(model_dir.join("config.json"), "{}").expect("Failed to write file");
        fs::write(cache_path.join("test_file.txt"), "test").expect("Failed to write file");

        let config = create_test_cache_config(cache_path.clone());

        // Clear cache
        config.clear_all().expect("Failed to clear cache");

        // Directory should still exist but be empty
        assert!(cache_path.exists(), "Cache directory should still exist");
        assert!(
            fs::read_dir(&cache_path)
                .expect("Failed to read dir")
                .next()
                .is_none(),
            "Cache directory should be empty"
        );
    }

    // `clear_all` on a missing directory succeeds and does not create it.
    #[test]
    fn test_clear_all_handles_nonexistent_directory() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("nonexistent_cache");

        let config = create_test_cache_config(cache_path.clone());

        // Should succeed without error even if directory doesn't exist
        config
            .clear_all()
            .with_context(|| format!("Failed to clear cache: {cache_path:?}"))
            .expect("Failed to clear cache");
        assert!(!cache_path.exists());
    }

    // `clear_all` removes arbitrarily deep directory trees.
    #[test]
    fn test_clear_all_removes_nested_directories() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let cache_path = temp_dir.path().join("cache");
        fs::create_dir_all(&cache_path).expect("Failed to create cache directory");

        // Create nested structure
        let deep_path = cache_path.join("a").join("b").join("c");
        fs::create_dir_all(&deep_path).expect("Failed to create nested directories");
        fs::write(deep_path.join("deep_file.txt"), "deep").expect("Failed to write file");

        let config = create_test_cache_config(cache_path.clone());

        config.clear_all().expect("Failed to clear cache");

        assert!(cache_path.exists(), "Cache directory should still exist");
        assert!(
            fs::read_dir(&cache_path)
                .expect("Failed to read dir")
                .next()
                .is_none(),
            "Cache directory should be empty after clearing nested content"
        );
    }
}