1use std::{
4 collections::HashMap,
5 path::{Path, PathBuf},
6 sync::RwLock,
7};
8
9use super::{Bundle, BundleError, parse_oci_uri};
10
11#[derive(Debug, thiserror::Error)]
13pub enum CacheError {
14 #[error("IO error: {0}")]
15 IoError(#[from] std::io::Error),
16
17 #[error("Bundle not found in cache: {0}")]
18 NotFound(String),
19
20 #[error("Cache corrupted: {0}")]
21 Corrupted(String),
22
23 #[error("Lock poisoned")]
24 LockPoisoned,
25}
26
/// On-disk cache of OCI bundles with an in-memory index from bundle URI to
/// bundle directory.
#[derive(Debug)]
pub struct BundleCache {
    /// Root directory; bundles live at `<cache_dir>/<registry>/<repository>/<tag>`.
    cache_dir: PathBuf,
    /// Bundle URI -> bundle directory, guarded for concurrent readers/writers.
    index: RwLock<HashMap<String, PathBuf>>,
}
32
33impl BundleCache {
34 pub fn new<P: AsRef<Path>>(cache_dir: P) -> Result<Self, CacheError> {
36 let cache_dir = cache_dir.as_ref().to_path_buf();
37 std::fs::create_dir_all(&cache_dir)?;
38
39 let mut cache = Self {
40 cache_dir,
41 index: RwLock::new(HashMap::new()),
42 };
43
44 cache.rebuild_index()?;
45 Ok(cache)
46 }
47
48 pub fn default_dir() -> PathBuf {
50 dirs::home_dir()
51 .unwrap_or_else(|| PathBuf::from("."))
52 .join(".mcpkit")
53 .join("bundles")
54 }
55
56 fn rebuild_index(&mut self) -> Result<(), CacheError> {
58 let mut index = self.index.write().map_err(|_| CacheError::LockPoisoned)?;
59 index.clear();
60
61 self.scan_directory(&self.cache_dir.clone(), &mut index)?;
63
64 Ok(())
65 }
66
67 fn scan_directory(
69 &self,
70 dir: &Path,
71 index: &mut HashMap<String, PathBuf>,
72 ) -> Result<(), CacheError> {
73 if !dir.exists() {
74 return Ok(());
75 }
76
77 for entry in std::fs::read_dir(dir)? {
78 let entry = entry?;
79 let path = entry.path();
80
81 if path.is_dir() {
82 let wasm_path = path.join("module.wasm");
84 let config_path = path.join("config.yaml");
85
86 if wasm_path.exists() && config_path.exists() {
87 if let Some(uri) = self.path_to_uri(&path) {
89 index.insert(uri, path.clone());
90 }
91 } else {
92 self.scan_directory(&path, index)?;
94 }
95 }
96 }
97
98 Ok(())
99 }
100
101 fn path_to_uri(&self, path: &Path) -> Option<String> {
103 let relative = path.strip_prefix(&self.cache_dir).ok()?;
104 let components: Vec<&str> = relative
105 .components()
106 .filter_map(|c| c.as_os_str().to_str())
107 .collect();
108
109 if components.len() >= 3 {
110 let registry = components[0];
113 let repo = components[1..components.len() - 1].join("/");
114 let version = components[components.len() - 1];
115
116 Some(format!("oci://{}/{}:{}", registry, repo, version))
117 } else {
118 None
119 }
120 }
121
122 pub fn uri_to_path(&self, uri: &str) -> Result<PathBuf, BundleError> {
124 let (registry, repository, tag) = parse_oci_uri(uri)?;
125 let tag = tag.unwrap_or_else(|| "latest".to_string());
126
127 let path = self.cache_dir.join(®istry).join(&repository).join(&tag);
129
130 Ok(path)
131 }
132
133 pub fn put(&self, uri: &str, bundle: &Bundle) -> Result<(), CacheError> {
135 let path = self
136 .uri_to_path(uri)
137 .map_err(|e| CacheError::Corrupted(e.to_string()))?;
138
139 std::fs::create_dir_all(&path)?;
141 bundle
142 .save_to_directory(&path)
143 .map_err(|e| CacheError::IoError(std::io::Error::other(e.to_string())))?;
144
145 let mut index = self.index.write().map_err(|_| CacheError::LockPoisoned)?;
147 index.insert(uri.to_string(), path);
148
149 Ok(())
150 }
151
152 pub fn get(&self, uri: &str) -> Result<Bundle, CacheError> {
154 let index = self.index.read().map_err(|_| CacheError::LockPoisoned)?;
155
156 let path = index
157 .get(uri)
158 .ok_or_else(|| CacheError::NotFound(uri.to_string()))?;
159
160 Bundle::from_directory(path)
161 .map_err(|e| CacheError::IoError(std::io::Error::other(e.to_string())))
162 }
163
164 pub fn exists(&self, uri: &str) -> bool {
166 let index = self.index.read().ok();
167 index.is_some_and(|idx| idx.contains_key(uri))
168 }
169
170 pub fn remove(&self, uri: &str) -> Result<(), CacheError> {
172 let mut index = self.index.write().map_err(|_| CacheError::LockPoisoned)?;
173
174 if let Some(path) = index.remove(uri) {
175 if path.exists() {
176 std::fs::remove_dir_all(&path)?;
177 }
178 }
179
180 Ok(())
181 }
182
183 pub fn clear(&self) -> Result<(), CacheError> {
185 let mut index = self.index.write().map_err(|_| CacheError::LockPoisoned)?;
186
187 for path in index.values() {
189 if path.exists() {
190 std::fs::remove_dir_all(path)?;
191 }
192 }
193
194 index.clear();
195 Ok(())
196 }
197
198 pub fn list(&self) -> Result<Vec<String>, CacheError> {
200 let index = self.index.read().map_err(|_| CacheError::LockPoisoned)?;
201
202 Ok(index.keys().cloned().collect())
203 }
204
205 pub fn stats(&self) -> Result<CacheStats, CacheError> {
207 let index = self.index.read().map_err(|_| CacheError::LockPoisoned)?;
208
209 let mut total_size = 0u64;
210 let mut bundle_count = 0usize;
211
212 for path in index.values() {
213 if path.exists() {
214 bundle_count += 1;
215 total_size += Self::dir_size(path)?;
216 }
217 }
218
219 Ok(CacheStats {
220 bundle_count,
221 total_size,
222 cache_dir: self.cache_dir.clone(),
223 })
224 }
225
226 fn dir_size(path: &Path) -> Result<u64, CacheError> {
228 let mut size = 0u64;
229
230 if path.is_dir() {
231 for entry in std::fs::read_dir(path)? {
232 let entry = entry?;
233 let path = entry.path();
234
235 if path.is_dir() {
236 size += Self::dir_size(&path)?;
237 } else {
238 size += entry.metadata()?.len();
239 }
240 }
241 } else {
242 size = std::fs::metadata(path)?.len();
243 }
244
245 Ok(size)
246 }
247
248 pub fn verify(&self) -> Result<Vec<String>, CacheError> {
250 let index = self.index.read().map_err(|_| CacheError::LockPoisoned)?;
251
252 let mut corrupted = Vec::new();
253
254 for (uri, path) in index.iter() {
255 let wasm_path = path.join("module.wasm");
257 let config_path = path.join("config.yaml");
258
259 if !wasm_path.exists() || !config_path.exists() {
260 corrupted.push(uri.clone());
261 continue;
262 }
263
264 if let Ok(bundle) = Bundle::from_directory(path) {
266 if bundle.verify().is_err() {
267 corrupted.push(uri.clone());
268 }
269 } else {
270 corrupted.push(uri.clone());
271 }
272 }
273
274 Ok(corrupted)
275 }
276}
277
/// Snapshot of the cache's on-disk footprint, as produced by `BundleCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of indexed bundles whose directory currently exists.
    pub bundle_count: usize,
    /// Combined size in bytes of those bundle directories.
    pub total_size: u64,
    /// Root directory of the cache.
    pub cache_dir: PathBuf,
}

impl CacheStats {
    /// Renders `total_size` with two decimals in the largest 1024-based unit
    /// (B, KB, MB, GB) that keeps the value below 1024, capping at GB.
    pub fn format_size(&self) -> String {
        const UNITS: &[&str] = &["B", "KB", "MB", "GB"];

        // Step up one unit at a time; once the value drops below 1024 it stays
        // put, so later steps leave the accumulator unchanged.
        let (value, unit) = UNITS.iter().skip(1).fold(
            (self.total_size as f64, UNITS[0]),
            |(value, unit), &next_unit| {
                if value >= 1024.0 {
                    (value / 1024.0, next_unit)
                } else {
                    (value, unit)
                }
            },
        );

        format!("{:.2} {}", value, unit)
    }
}
301
#[cfg(test)]
mod tests {
    use tempfile::TempDir;

    use super::*;

    /// Minimal bundle: the `\0asm` wasm magic bytes plus a one-line config.
    fn sample_bundle() -> Bundle {
        Bundle::new(
            vec![0x00, 0x61, 0x73, 0x6d],
            b"version: 1.0".to_vec(),
            "ghcr.io/org/tool".to_string(),
            "v1.0.0".to_string(),
        )
    }

    #[test]
    fn test_cache_operations() {
        let temp_dir = TempDir::new().unwrap();
        let cache = BundleCache::new(temp_dir.path()).unwrap();

        let uri = "oci://ghcr.io/org/tool:v1.0.0";
        let path = cache.uri_to_path(uri).unwrap();
        assert!(path.ends_with("ghcr.io/org/tool/v1.0.0"));

        assert!(!cache.exists(uri));

        let bundle = sample_bundle();
        cache.put(uri, &bundle).unwrap();
        assert!(cache.exists(uri));

        let retrieved = cache.get(uri).unwrap();
        assert_eq!(retrieved.wasm, bundle.wasm);
        assert_eq!(retrieved.config, bundle.config);

        let list = cache.list().unwrap();
        assert_eq!(list.len(), 1);
        assert!(list.contains(&uri.to_string()));

        let stats = cache.stats().unwrap();
        assert_eq!(stats.bundle_count, 1);
        assert!(stats.total_size > 0);

        let corrupted = cache.verify().unwrap();
        assert!(corrupted.is_empty());

        cache.remove(uri).unwrap();
        assert!(!cache.exists(uri));

        cache.put(uri, &bundle).unwrap();
        cache.clear().unwrap();
        assert!(cache.list().unwrap().is_empty());
    }

    /// A fresh `BundleCache` over an existing directory must rediscover bundles
    /// written by an earlier instance under the same URI.
    #[test]
    fn test_index_rebuilt_from_disk() {
        let temp_dir = TempDir::new().unwrap();
        let uri = "oci://ghcr.io/org/tool:v1.0.0";

        BundleCache::new(temp_dir.path())
            .unwrap()
            .put(uri, &sample_bundle())
            .unwrap();

        let reopened = BundleCache::new(temp_dir.path()).unwrap();
        assert!(reopened.exists(uri));
        assert_eq!(reopened.list().unwrap(), vec![uri.to_string()]);
        assert_eq!(reopened.get(uri).unwrap().wasm, sample_bundle().wasm);
    }

    #[test]
    fn test_cache_path_conversion() {
        let temp_dir = TempDir::new().unwrap();
        let cache = BundleCache::new(temp_dir.path()).unwrap();

        let test_cases = vec![
            ("oci://ghcr.io/org/tool:latest", "ghcr.io/org/tool/latest"),
            ("oci://docker.io/user/app:v2.0", "docker.io/user/app/v2.0"),
            (
                "oci://localhost:5000/test/bundle:tag",
                "localhost:5000/test/bundle/tag",
            ),
        ];

        for (uri, expected_suffix) in test_cases {
            let path = cache.uri_to_path(uri).unwrap();
            // Component-wise comparison, independent of the platform separator.
            assert!(
                path.ends_with(expected_suffix),
                "{} -> {}",
                uri,
                path.display()
            );
        }
    }

    #[test]
    fn test_stats_format_size() {
        let stats = CacheStats {
            bundle_count: 5,
            total_size: 1536,
            cache_dir: PathBuf::from("/tmp/cache"),
        };
        assert_eq!(stats.format_size(), "1.50 KB");

        let stats_mb = CacheStats {
            bundle_count: 10,
            total_size: 5_242_880,
            cache_dir: PathBuf::from("/tmp/cache"),
        };
        assert_eq!(stats_mb.format_size(), "5.00 MB");
    }
}