mermaid_cli/cache/cache_manager.rs

1use anyhow::Result;
2use directories::ProjectDirs;
3use rustc_hash::FxHashMap;
4use std::path::{Path, PathBuf};
5use std::sync::{Arc, Mutex};
6
7use super::file_cache::FileCache;
8use super::types::{CacheKey, CachedTokens};
9use crate::utils::lock_arc_mutex_safe;
10
11/// Main cache manager for the application
12#[derive(Debug)]
13pub struct CacheManager {
14    file_cache: Arc<FileCache>,
15    memory_cache: Arc<Mutex<MemoryCache>>,
16    cache_dir: PathBuf,
17}
18
19/// In-memory cache for hot data
20#[derive(Debug, Default)]
21struct MemoryCache {
22    tokens: FxHashMap<CacheKey, CachedTokens>,
23    hits: usize,
24    misses: usize,
25}
26
27impl CacheManager {
28    /// Create a new cache manager
29    pub fn new() -> Result<Self> {
30        // Get cache directory (~/.cache/mermaid on Linux, ~/Library/Caches/mermaid on macOS)
31        let cache_dir = if let Some(proj_dirs) = ProjectDirs::from("", "", "mermaid") {
32            proj_dirs.cache_dir().to_path_buf()
33        } else {
34            // Fallback to ~/.cache/mermaid
35            let home = std::env::var("HOME")?;
36            PathBuf::from(home).join(".cache").join("mermaid")
37        };
38
39        let file_cache = Arc::new(FileCache::new(cache_dir.clone())?);
40        let memory_cache = Arc::new(Mutex::new(MemoryCache::default()));
41
42        Ok(Self {
43            file_cache,
44            memory_cache,
45            cache_dir,
46        })
47    }
48
49    /// Get or compute token count for content
50    pub fn get_or_compute_tokens(
51        &self,
52        path: &Path,
53        content: &str,
54        model_name: &str,
55    ) -> Result<usize> {
56        // Generate cache key
57        let key = FileCache::generate_key(path)?;
58
59        // Check memory cache first
60        {
61            let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
62            if let Some(cached) = mem_cache.tokens.get(&key).cloned() {
63                if cached.model_name == model_name {
64                    mem_cache.hits += 1;
65                    return Ok(cached.count);
66                }
67            }
68        }
69
70        // Check file cache
71        if let Some(cached) = self.file_cache.load::<CachedTokens>(&key)? {
72            if cached.model_name == model_name {
73                // Validate cache
74                if self.file_cache.is_valid(&key)? {
75                    // Store in memory cache
76                    let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
77                    mem_cache.tokens.insert(key.clone(), cached.clone());
78                    mem_cache.hits += 1;
79                    return Ok(cached.count);
80                } else {
81                    // Invalid cache, remove it
82                    self.file_cache.remove(&key)?;
83                }
84            }
85        }
86
87        // Cache miss - compute and cache
88        {
89            let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
90            mem_cache.misses += 1;
91        }
92
93        // Count tokens
94        let tokenizer = crate::utils::Tokenizer::new(model_name);
95        let count = tokenizer.count_tokens(content)?;
96
97        // Cache the results
98        let cached = CachedTokens {
99            count,
100            model_name: model_name.to_string(),
101        };
102
103        // Save to file cache
104        self.file_cache.save(&key, &cached)?;
105
106        // Save to memory cache
107        {
108            let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
109            mem_cache.tokens.insert(key, cached);
110        }
111
112        Ok(count)
113    }
114
115    /// Invalidate cache for a specific file
116    pub fn invalidate(&self, path: &Path) -> Result<()> {
117        let key = FileCache::generate_key(path)?;
118
119        // Remove from memory cache
120        {
121            let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
122            mem_cache.tokens.remove(&key);
123        }
124
125        // Remove from file cache
126        self.file_cache.remove(&key)?;
127
128        Ok(())
129    }
130
131    /// Clear all caches
132    pub fn clear_all(&self) -> Result<()> {
133        // Clear memory cache
134        {
135            let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
136            mem_cache.tokens.clear();
137            mem_cache.hits = 0;
138            mem_cache.misses = 0;
139        }
140
141        // Clear file cache (remove cache directory)
142        if self.cache_dir.exists() {
143            std::fs::remove_dir_all(&self.cache_dir)?;
144            std::fs::create_dir_all(&self.cache_dir)?;
145        }
146
147        Ok(())
148    }
149
150    /// Get cache statistics
151    pub fn get_stats(&self) -> Result<CacheStats> {
152        let file_stats = self.file_cache.get_stats()?;
153
154        let (memory_entries, hits, misses, hit_rate) = {
155            let mem_cache = lock_arc_mutex_safe(&self.memory_cache);
156            let total_requests = mem_cache.hits + mem_cache.misses;
157            let hit_rate = if total_requests > 0 {
158                (mem_cache.hits as f32 / total_requests as f32) * 100.0
159            } else {
160                0.0
161            };
162            (
163                mem_cache.tokens.len(),
164                mem_cache.hits,
165                mem_cache.misses,
166                hit_rate,
167            )
168        };
169
170        Ok(CacheStats {
171            file_cache_entries: file_stats.total_entries,
172            memory_cache_entries: memory_entries,
173            total_size: file_stats.total_size,
174            compressed_size: file_stats.total_compressed_size,
175            compression_ratio: file_stats.compression_ratio,
176            cache_hits: hits,
177            cache_misses: misses,
178            hit_rate,
179            cache_directory: self.cache_dir.clone(),
180        })
181    }
182}
183
/// Snapshot of cache statistics, as returned by `CacheManager::get_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries in the on-disk cache.
    pub file_cache_entries: usize,
    /// Number of entries currently held in memory.
    pub memory_cache_entries: usize,
    /// Total size of on-disk entries in bytes (presumably uncompressed).
    pub total_size: usize,
    /// Compressed size of on-disk entries in bytes.
    pub compressed_size: usize,
    /// Compression ratio reported by the file cache.
    pub compression_ratio: f32,
    /// Lookups answered from a cache layer.
    pub cache_hits: usize,
    /// Lookups that required a fresh computation.
    pub cache_misses: usize,
    /// Hit rate as a percentage (`0.0..=100.0`).
    pub hit_rate: f32,
    /// Root directory of the on-disk cache.
    pub cache_directory: PathBuf,
}

impl CacheStats {
    /// Format cache stats for display.
    ///
    /// Equivalent to `self.to_string()`; kept for existing callers.
    pub fn format(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for CacheStats {
    /// Render a multi-line, human-readable summary. Sizes are shown in MB,
    /// i.e. bytes divided by 1_048_576.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const BYTES_PER_MB: f64 = 1_048_576.0;

        writeln!(f, "Cache Statistics:")?;
        writeln!(f, "Directory: {}", self.cache_directory.display())?;
        writeln!(f, "File Cache: {} entries", self.file_cache_entries)?;
        writeln!(f, "Memory Cache: {} entries", self.memory_cache_entries)?;
        writeln!(
            f,
            "Total Size: {:.2} MB",
            self.total_size as f64 / BYTES_PER_MB
        )?;
        writeln!(
            f,
            "Compressed: {:.2} MB (ratio: {:.1}x)",
            self.compressed_size as f64 / BYTES_PER_MB,
            self.compression_ratio
        )?;
        write!(
            f,
            "Hit Rate: {:.1}% ({} hits, {} misses)",
            self.hit_rate, self.cache_hits, self.cache_misses
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Phase 4 Test Suite: Cache Manager - focused core logic tests

    #[test]
    fn test_cache_manager_new() {
        // Construction depends on the environment (a resolvable cache directory
        // and a writable filesystem). Invariants are therefore only checked when
        // construction succeeds.
        if let Ok(manager) = CacheManager::new() {
            if let Ok(stats) = manager.get_stats() {
                assert_eq!(
                    stats.memory_cache_entries, 0,
                    "Fresh manager should have no memory entries"
                );
                assert_eq!(stats.cache_hits, 0, "Fresh manager should have no hits");
                assert_eq!(stats.cache_misses, 0, "Fresh manager should have no misses");
                assert_eq!(stats.hit_rate, 0.0, "Hit rate should be 0 with no requests");
                assert_eq!(
                    stats.cache_directory, manager.cache_dir,
                    "Stats should report the manager's cache directory"
                );
            }
        }
    }

    #[test]
    fn test_cache_stats_hit_rate_calculation() {
        // Test hit rate value and how it is rendered
        let stats = CacheStats {
            file_cache_entries: 10,
            memory_cache_entries: 5,
            total_size: 1_000_000,
            compressed_size: 500_000,
            compression_ratio: 2.0,
            cache_hits: 100,
            cache_misses: 20,
            hit_rate: 83.33,
            cache_directory: PathBuf::from("/cache"),
        };

        let expected_hit_rate = (100.0 / 120.0) * 100.0;
        assert!(
            (stats.hit_rate - expected_hit_rate).abs() < 0.1,
            "Hit rate should be ~83.33%"
        );
        assert!(
            stats
                .format()
                .contains("Hit Rate: 83.3% (100 hits, 20 misses)"),
            "Hit rate should be rendered with one decimal plus hit/miss counts"
        );
    }

    #[test]
    fn test_cache_stats_compression_ratio() {
        // Test compression ratio is carried through unchanged
        let stats = CacheStats {
            file_cache_entries: 5,
            memory_cache_entries: 3,
            total_size: 1000,
            compressed_size: 400,
            compression_ratio: 2.5,
            cache_hits: 50,
            cache_misses: 10,
            hit_rate: 83.33,
            cache_directory: PathBuf::from("/cache"),
        };

        assert_eq!(
            stats.compression_ratio, 2.5,
            "Compression ratio should be 2.5"
        );
        assert_eq!(stats.total_size, 1000, "Total size should be 1000");
    }

    #[test]
    fn test_cache_stats_format_display() {
        // Test that stats can be formatted for display
        let stats = CacheStats {
            file_cache_entries: 10,
            memory_cache_entries: 5,
            total_size: 1_048_576,
            compressed_size: 524_288,
            compression_ratio: 2.0,
            cache_hits: 100,
            cache_misses: 20,
            hit_rate: 83.33,
            cache_directory: PathBuf::from("/cache"),
        };

        let formatted = stats.format();
        assert!(
            formatted.contains("Cache Statistics"),
            "Should include header"
        );
        assert!(
            formatted.contains("/cache"),
            "Should include cache directory"
        );
        assert!(
            formatted.contains("File Cache: 10"),
            "Should include file cache entries"
        );
        assert!(
            formatted.contains("Memory Cache: 5"),
            "Should include memory cache entries"
        );
        assert!(
            formatted.contains("Total Size: 1.00 MB"),
            "Should render total size in MB"
        );
        assert!(
            formatted.contains("Compressed: 0.50 MB (ratio: 2.0x)"),
            "Should render compressed size and ratio"
        );
    }

    #[test]
    fn test_memory_cache_default() {
        // Test default MemoryCache initialization
        let mem_cache = MemoryCache::default();
        assert_eq!(mem_cache.hits, 0, "Initial hits should be 0");
        assert_eq!(mem_cache.misses, 0, "Initial misses should be 0");
        assert!(
            mem_cache.tokens.is_empty(),
            "Initial tokens should be empty"
        );
    }

    #[test]
    fn test_cache_key_components() {
        // Test cache key structure and components
        let path = PathBuf::from("src/main.rs");
        let file_hash = "abc123def456".to_string();

        let key = CacheKey {
            file_path: path.clone(),
            file_hash: file_hash.clone(),
        };

        assert_eq!(key.file_path, path, "File path should match");
        assert_eq!(key.file_hash, file_hash, "File hash should match");
    }

    #[test]
    fn test_cached_tokens_structure() {
        // Test CachedTokens structure
        let cached = CachedTokens {
            count: 1000,
            model_name: "ollama/tinyllama".to_string(),
        };

        assert_eq!(cached.count, 1000, "Token count should be 1000");
        assert_eq!(
            cached.model_name, "ollama/tinyllama",
            "Model name should match"
        );
    }

    #[test]
    fn test_cache_directory_structure() {
        // Mirror the `$HOME/.cache/mermaid` fallback from `CacheManager::new`.
        // The root is chosen per platform because "/home/user" is not an
        // absolute path on Windows.
        let home = if cfg!(windows) {
            PathBuf::from(r"C:\Users\user")
        } else {
            PathBuf::from("/home/user")
        };
        let cache_dir = home.join(".cache").join("mermaid");

        assert!(
            cache_dir.is_absolute(),
            "Cache directory should be absolute"
        );
        assert!(
            cache_dir.ends_with(Path::new(".cache").join("mermaid")),
            "Cache directory should end with .cache/mermaid"
        );
        assert!(
            cache_dir.to_string_lossy().contains("mermaid"),
            "Should contain mermaid"
        );
    }

    #[test]
    fn test_hit_rate_percentages() {
        // Test hit rate calculation for various scenarios
        let scenarios = vec![
            (100, 0, 100.0), // All hits
            (0, 100, 0.0),   // All misses
            (50, 50, 50.0),  // Even split
            (75, 25, 75.0),  // 75% hit rate
        ];

        for (hits, misses, expected_rate) in scenarios {
            let total = hits + misses;
            let rate = if total > 0 {
                (hits as f32 / total as f32) * 100.0
            } else {
                0.0
            };

            assert!(
                (rate - expected_rate).abs() < 0.1,
                "Hit rate calculation for ({}, {}) failed",
                hits,
                misses
            );
        }
    }
}