// mcpkit_rs/bundle/cache.rs

1//! Local cache for bundles
2
3use std::{
4    collections::HashMap,
5    path::{Path, PathBuf},
6    sync::RwLock,
7};
8
9use super::{Bundle, BundleError, parse_oci_uri};
10
/// Errors returned by [`BundleCache`] operations.
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
    /// A filesystem operation failed. Converted automatically from
    /// [`std::io::Error`] via `#[from]`; the cache also uses this variant to
    /// wrap failures reported while saving or loading a bundle.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// The requested URI has no entry in the cache index. Carries the URI.
    #[error("Bundle not found in cache: {0}")]
    NotFound(String),

    /// The cache could not make sense of its input or contents. `put` uses
    /// this when the URI cannot be mapped to a cache path; the payload is the
    /// underlying error message.
    #[error("Cache corrupted: {0}")]
    Corrupted(String),

    /// The `RwLock` guarding the index was poisoned.
    #[error("Lock poisoned")]
    LockPoisoned,
}
26
/// Bundle cache for local storage.
///
/// Bundles live on disk under `cache_dir`; an in-memory index maps each
/// bundle URI to the directory that holds it.
pub struct BundleCache {
    /// Root directory under which all cached bundles are stored.
    cache_dir: PathBuf,
    /// URI -> bundle directory. Guarded by an `RwLock` so the `&self`
    /// methods can read and update it.
    index: RwLock<HashMap<String, PathBuf>>,
}
32
33impl BundleCache {
34    /// Create a new bundle cache
35    pub fn new<P: AsRef<Path>>(cache_dir: P) -> Result<Self, CacheError> {
36        let cache_dir = cache_dir.as_ref().to_path_buf();
37        std::fs::create_dir_all(&cache_dir)?;
38
39        let mut cache = Self {
40            cache_dir,
41            index: RwLock::new(HashMap::new()),
42        };
43
44        cache.rebuild_index()?;
45        Ok(cache)
46    }
47
48    /// Get default cache directory
49    pub fn default_dir() -> PathBuf {
50        dirs::home_dir()
51            .unwrap_or_else(|| PathBuf::from("."))
52            .join(".mcpkit")
53            .join("bundles")
54    }
55
56    /// Rebuild the cache index by scanning the filesystem
57    fn rebuild_index(&mut self) -> Result<(), CacheError> {
58        let mut index = self.index.write().map_err(|_| CacheError::LockPoisoned)?;
59        index.clear();
60
61        // Scan cache directory for bundles
62        self.scan_directory(&self.cache_dir.clone(), &mut index)?;
63
64        Ok(())
65    }
66
67    /// Recursively scan directory for bundles
68    fn scan_directory(
69        &self,
70        dir: &Path,
71        index: &mut HashMap<String, PathBuf>,
72    ) -> Result<(), CacheError> {
73        if !dir.exists() {
74            return Ok(());
75        }
76
77        for entry in std::fs::read_dir(dir)? {
78            let entry = entry?;
79            let path = entry.path();
80
81            if path.is_dir() {
82                // Check if this is a bundle directory (contains module.wasm and config.yaml)
83                let wasm_path = path.join("module.wasm");
84                let config_path = path.join("config.yaml");
85
86                if wasm_path.exists() && config_path.exists() {
87                    // Try to reconstruct the URI from the path
88                    if let Some(uri) = self.path_to_uri(&path) {
89                        index.insert(uri, path.clone());
90                    }
91                } else {
92                    // Recurse into subdirectory
93                    self.scan_directory(&path, index)?;
94                }
95            }
96        }
97
98        Ok(())
99    }
100
101    /// Convert cache path back to URI
102    fn path_to_uri(&self, path: &Path) -> Option<String> {
103        let relative = path.strip_prefix(&self.cache_dir).ok()?;
104        let components: Vec<&str> = relative
105            .components()
106            .filter_map(|c| c.as_os_str().to_str())
107            .collect();
108
109        if components.len() >= 3 {
110            // Format: registry/org/repo/version
111            // Convert back to: oci://registry/org/repo:version
112            let registry = components[0];
113            let repo = components[1..components.len() - 1].join("/");
114            let version = components[components.len() - 1];
115
116            Some(format!("oci://{}/{}:{}", registry, repo, version))
117        } else {
118            None
119        }
120    }
121
122    /// Get the cache path for a URI
123    pub fn uri_to_path(&self, uri: &str) -> Result<PathBuf, BundleError> {
124        let (registry, repository, tag) = parse_oci_uri(uri)?;
125        let tag = tag.unwrap_or_else(|| "latest".to_string());
126
127        // Create path: cache_dir/registry/repository/tag/
128        let path = self.cache_dir.join(&registry).join(&repository).join(&tag);
129
130        Ok(path)
131    }
132
133    /// Store a bundle in cache
134    pub fn put(&self, uri: &str, bundle: &Bundle) -> Result<(), CacheError> {
135        let path = self
136            .uri_to_path(uri)
137            .map_err(|e| CacheError::Corrupted(e.to_string()))?;
138
139        // Create directory and save bundle
140        std::fs::create_dir_all(&path)?;
141        bundle
142            .save_to_directory(&path)
143            .map_err(|e| CacheError::IoError(std::io::Error::other(e.to_string())))?;
144
145        // Update index
146        let mut index = self.index.write().map_err(|_| CacheError::LockPoisoned)?;
147        index.insert(uri.to_string(), path);
148
149        Ok(())
150    }
151
152    /// Get a bundle from cache
153    pub fn get(&self, uri: &str) -> Result<Bundle, CacheError> {
154        let index = self.index.read().map_err(|_| CacheError::LockPoisoned)?;
155
156        let path = index
157            .get(uri)
158            .ok_or_else(|| CacheError::NotFound(uri.to_string()))?;
159
160        Bundle::from_directory(path)
161            .map_err(|e| CacheError::IoError(std::io::Error::other(e.to_string())))
162    }
163
164    /// Check if a bundle exists in cache
165    pub fn exists(&self, uri: &str) -> bool {
166        let index = self.index.read().ok();
167        index.is_some_and(|idx| idx.contains_key(uri))
168    }
169
170    /// Remove a bundle from cache
171    pub fn remove(&self, uri: &str) -> Result<(), CacheError> {
172        let mut index = self.index.write().map_err(|_| CacheError::LockPoisoned)?;
173
174        if let Some(path) = index.remove(uri) {
175            if path.exists() {
176                std::fs::remove_dir_all(&path)?;
177            }
178        }
179
180        Ok(())
181    }
182
183    /// Clear entire cache
184    pub fn clear(&self) -> Result<(), CacheError> {
185        let mut index = self.index.write().map_err(|_| CacheError::LockPoisoned)?;
186
187        // Remove all cached bundles
188        for path in index.values() {
189            if path.exists() {
190                std::fs::remove_dir_all(path)?;
191            }
192        }
193
194        index.clear();
195        Ok(())
196    }
197
198    /// List all cached bundles
199    pub fn list(&self) -> Result<Vec<String>, CacheError> {
200        let index = self.index.read().map_err(|_| CacheError::LockPoisoned)?;
201
202        Ok(index.keys().cloned().collect())
203    }
204
205    /// Get cache statistics
206    pub fn stats(&self) -> Result<CacheStats, CacheError> {
207        let index = self.index.read().map_err(|_| CacheError::LockPoisoned)?;
208
209        let mut total_size = 0u64;
210        let mut bundle_count = 0usize;
211
212        for path in index.values() {
213            if path.exists() {
214                bundle_count += 1;
215                total_size += Self::dir_size(path)?;
216            }
217        }
218
219        Ok(CacheStats {
220            bundle_count,
221            total_size,
222            cache_dir: self.cache_dir.clone(),
223        })
224    }
225
226    /// Calculate directory size recursively
227    fn dir_size(path: &Path) -> Result<u64, CacheError> {
228        let mut size = 0u64;
229
230        if path.is_dir() {
231            for entry in std::fs::read_dir(path)? {
232                let entry = entry?;
233                let path = entry.path();
234
235                if path.is_dir() {
236                    size += Self::dir_size(&path)?;
237                } else {
238                    size += entry.metadata()?.len();
239                }
240            }
241        } else {
242            size = std::fs::metadata(path)?.len();
243        }
244
245        Ok(size)
246    }
247
248    /// Verify cache integrity
249    pub fn verify(&self) -> Result<Vec<String>, CacheError> {
250        let index = self.index.read().map_err(|_| CacheError::LockPoisoned)?;
251
252        let mut corrupted = Vec::new();
253
254        for (uri, path) in index.iter() {
255            // Check if files exist
256            let wasm_path = path.join("module.wasm");
257            let config_path = path.join("config.yaml");
258
259            if !wasm_path.exists() || !config_path.exists() {
260                corrupted.push(uri.clone());
261                continue;
262            }
263
264            // Try to load and verify bundle
265            if let Ok(bundle) = Bundle::from_directory(path) {
266                if bundle.verify().is_err() {
267                    corrupted.push(uri.clone());
268                }
269            } else {
270                corrupted.push(uri.clone());
271            }
272        }
273
274        Ok(corrupted)
275    }
276}
277
/// Cache statistics
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of bundles counted.
    pub bundle_count: usize,
    /// Combined size of the counted bundles, in bytes.
    pub total_size: u64,
    /// Root directory of the cache these statistics describe.
    pub cache_dir: PathBuf,
}

impl CacheStats {
    /// Format `total_size` in human-readable form, e.g. `"1.50 KB"`.
    ///
    /// The value is divided by 1024 until it drops below 1024 or the largest
    /// unit is reached, then printed with two decimals. The unit table runs
    /// up to `EB`, which covers the whole `u64` range, so large caches are
    /// never shown as thousands of gigabytes.
    pub fn format_size(&self) -> String {
        const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB", "PB", "EB"];
        // Precision loss in the cast is irrelevant at two displayed decimals.
        let mut size = self.total_size as f64;
        let mut unit_idx = 0;

        while size >= 1024.0 && unit_idx < UNITS.len() - 1 {
            size /= 1024.0;
            unit_idx += 1;
        }

        format!("{:.2} {}", size, UNITS[unit_idx])
    }
}
301
302#[cfg(test)]
303mod tests {
304    use tempfile::TempDir;
305
306    use super::*;
307
308    #[test]
309    fn test_cache_operations() {
310        let temp_dir = TempDir::new().unwrap();
311        let cache = BundleCache::new(temp_dir.path()).unwrap();
312
313        // Test URI to path conversion
314        let uri = "oci://ghcr.io/org/tool:v1.0.0";
315        let path = cache.uri_to_path(uri).unwrap();
316        assert!(path.ends_with("ghcr.io/org/tool/v1.0.0"));
317
318        // Test cache existence check
319        assert!(!cache.exists(uri));
320
321        // Create and store a bundle
322        let bundle = Bundle::new(
323            vec![0x00, 0x61, 0x73, 0x6d],
324            b"version: 1.0".to_vec(),
325            "ghcr.io/org/tool".to_string(),
326            "v1.0.0".to_string(),
327        );
328
329        cache.put(uri, &bundle).unwrap();
330        assert!(cache.exists(uri));
331
332        // Retrieve bundle
333        let retrieved = cache.get(uri).unwrap();
334        assert_eq!(retrieved.wasm, bundle.wasm);
335        assert_eq!(retrieved.config, bundle.config);
336
337        // List cached bundles
338        let list = cache.list().unwrap();
339        assert_eq!(list.len(), 1);
340        assert!(list.contains(&uri.to_string()));
341
342        // Get stats
343        let stats = cache.stats().unwrap();
344        assert_eq!(stats.bundle_count, 1);
345        assert!(stats.total_size > 0);
346
347        // Verify cache
348        let corrupted = cache.verify().unwrap();
349        assert!(corrupted.is_empty());
350
351        // Remove bundle
352        cache.remove(uri).unwrap();
353        assert!(!cache.exists(uri));
354
355        // Clear cache
356        cache.put(uri, &bundle).unwrap();
357        cache.clear().unwrap();
358        assert!(cache.list().unwrap().is_empty());
359    }
360
361    #[test]
362    fn test_cache_path_conversion() {
363        let temp_dir = TempDir::new().unwrap();
364        let cache = BundleCache::new(temp_dir.path()).unwrap();
365
366        // Test various URI formats
367        let test_cases = vec![
368            ("oci://ghcr.io/org/tool:latest", "ghcr.io/org/tool/latest"),
369            ("oci://docker.io/user/app:v2.0", "docker.io/user/app/v2.0"),
370            (
371                "oci://localhost:5000/test/bundle:tag",
372                "localhost:5000/test/bundle/tag",
373            ),
374        ];
375
376        for (uri, expected_suffix) in test_cases {
377            let path = cache.uri_to_path(uri).unwrap();
378            assert!(path.to_string_lossy().ends_with(expected_suffix));
379        }
380    }
381
382    #[test]
383    fn test_stats_format_size() {
384        let stats = CacheStats {
385            bundle_count: 5,
386            total_size: 1536, // 1.5 KB
387            cache_dir: PathBuf::from("/tmp/cache"),
388        };
389
390        assert_eq!(stats.format_size(), "1.50 KB");
391
392        let stats_mb = CacheStats {
393            bundle_count: 10,
394            total_size: 5_242_880, // 5 MB
395            cache_dir: PathBuf::from("/tmp/cache"),
396        };
397
398        assert_eq!(stats_mb.format_size(), "5.00 MB");
399    }
400}