// mermaid_cli/cache/cache_manager.rs

use anyhow::{Context, Result};
use directories::ProjectDirs;
use rustc_hash::FxHashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};

use super::file_cache::FileCache;
use super::types::{CacheKey, CachedTokens};
use crate::utils::lock_arc_mutex_safe;
10
11/// Main cache manager for the application
12#[derive(Debug)]
13pub struct CacheManager {
14    file_cache: Arc<FileCache>,
15    memory_cache: Arc<Mutex<MemoryCache>>,
16    cache_dir: PathBuf,
17}
18
19/// In-memory cache for hot data
20#[derive(Debug, Default)]
21struct MemoryCache {
22    tokens: FxHashMap<CacheKey, CachedTokens>,
23    hits: usize,
24    misses: usize,
25}
26
27impl CacheManager {
28    /// Create a new cache manager
29    pub fn new() -> Result<Self> {
30        // Get cache directory (~/.cache/mermaid on Linux, ~/Library/Caches/mermaid on macOS)
31        let cache_dir = if let Some(proj_dirs) = ProjectDirs::from("", "", "mermaid") {
32            proj_dirs.cache_dir().to_path_buf()
33        } else {
34            // Fallback to ~/.cache/mermaid
35            let home = std::env::var("HOME")?;
36            PathBuf::from(home).join(".cache").join("mermaid")
37        };
38
39        let file_cache = Arc::new(FileCache::new(cache_dir.clone())?);
40        let memory_cache = Arc::new(Mutex::new(MemoryCache::default()));
41
42        Ok(Self {
43            file_cache,
44            memory_cache,
45            cache_dir,
46        })
47    }
48
49    /// Get or compute token count for content
50    pub fn get_or_compute_tokens(
51        &self,
52        path: &Path,
53        content: &str,
54        model_name: &str,
55    ) -> Result<usize> {
56        // Generate cache key
57        let key = FileCache::generate_key(path)?;
58
59        // Check memory cache first
60        {
61            let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
62            if let Some(cached) = mem_cache.tokens.get(&key).cloned() {
63                if cached.model_name == model_name {
64                    mem_cache.hits += 1;
65                    return Ok(cached.count);
66                }
67            }
68        }
69
70        // Check file cache
71        // Note: is_valid() is not needed here because generate_key() already hashed the
72        // current file. The cache path is derived from that hash, so a hit means the entry
73        // was stored for this exact file content. Re-hashing would always match.
74        if let Some(cached) = self.file_cache.load::<CachedTokens>(&key)? {
75            if cached.model_name == model_name {
76                // Store in memory cache
77                let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
78                mem_cache.tokens.insert(key.clone(), cached.clone());
79                mem_cache.hits += 1;
80                return Ok(cached.count);
81            }
82        }
83
84        // Cache miss - compute and cache
85        {
86            let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
87            mem_cache.misses += 1;
88        }
89
90        // Count tokens
91        let tokenizer = crate::utils::Tokenizer::new(model_name);
92        let count = tokenizer.count_tokens(content)?;
93
94        // Cache the results
95        let cached = CachedTokens {
96            count,
97            model_name: model_name.to_string(),
98        };
99
100        // Save to file cache
101        self.file_cache.save(&key, &cached)?;
102
103        // Save to memory cache
104        {
105            let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
106            mem_cache.tokens.insert(key, cached);
107        }
108
109        Ok(count)
110    }
111
112    /// Invalidate cache for a specific file
113    pub fn invalidate(&self, path: &Path) -> Result<()> {
114        let key = FileCache::generate_key(path)?;
115
116        // Remove from memory cache
117        {
118            let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
119            mem_cache.tokens.remove(&key);
120        }
121
122        // Remove from file cache
123        self.file_cache.remove(&key)?;
124
125        Ok(())
126    }
127
128    /// Clear all caches
129    pub fn clear_all(&self) -> Result<()> {
130        // Clear memory cache
131        {
132            let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
133            mem_cache.tokens.clear();
134            mem_cache.hits = 0;
135            mem_cache.misses = 0;
136        }
137
138        // Clear file cache (remove cache directory)
139        if self.cache_dir.exists() {
140            std::fs::remove_dir_all(&self.cache_dir)?;
141            std::fs::create_dir_all(&self.cache_dir)?;
142        }
143
144        Ok(())
145    }
146
147    /// Get cache statistics
148    pub fn get_stats(&self) -> Result<CacheStats> {
149        let file_stats = self.file_cache.get_stats()?;
150
151        let (memory_entries, hits, misses, hit_rate) = {
152            let mem_cache = lock_arc_mutex_safe(&self.memory_cache);
153            let total_requests = mem_cache.hits + mem_cache.misses;
154            let hit_rate = if total_requests > 0 {
155                (mem_cache.hits as f32 / total_requests as f32) * 100.0
156            } else {
157                0.0
158            };
159            (
160                mem_cache.tokens.len(),
161                mem_cache.hits,
162                mem_cache.misses,
163                hit_rate,
164            )
165        };
166
167        Ok(CacheStats {
168            file_cache_entries: file_stats.total_entries,
169            memory_cache_entries: memory_entries,
170            total_size: file_stats.total_size,
171            compressed_size: file_stats.total_compressed_size,
172            compression_ratio: file_stats.compression_ratio,
173            cache_hits: hits,
174            cache_misses: misses,
175            hit_rate,
176            cache_directory: self.cache_dir.clone(),
177        })
178    }
179}
180
/// Cache statistics snapshot, as produced by `CacheManager::get_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries in the on-disk file cache.
    pub file_cache_entries: usize,
    /// Number of entries currently held in memory.
    pub memory_cache_entries: usize,
    /// Total size reported by the file cache, in bytes.
    pub total_size: usize,
    /// Compressed size reported by the file cache, in bytes.
    pub compressed_size: usize,
    /// Compression ratio reported by the file cache (displayed as `N.Nx`).
    pub compression_ratio: f32,
    /// Number of lookups answered from a cache.
    pub cache_hits: usize,
    /// Number of lookups that had to compute a fresh value.
    pub cache_misses: usize,
    /// Hit percentage (`0.0`–`100.0`).
    pub hit_rate: f32,
    /// Root directory of the on-disk cache.
    pub cache_directory: PathBuf,
}

impl CacheStats {
    /// Format cache stats for display as a multi-line, human-readable report.
    ///
    /// Equivalent to `self.to_string()` via the [`std::fmt::Display`] impl. Sizes
    /// are divided by 1,048,576 (MiB) and labelled `MB`.
    pub fn format(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for CacheStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Bytes per mebibyte; the report labels these units "MB".
        const BYTES_PER_MB: f64 = 1_048_576.0;

        writeln!(f, "Cache Statistics:")?;
        writeln!(f, "Directory: {}", self.cache_directory.display())?;
        writeln!(f, "File Cache: {} entries", self.file_cache_entries)?;
        writeln!(f, "Memory Cache: {} entries", self.memory_cache_entries)?;
        writeln!(
            f,
            "Total Size: {:.2} MB",
            self.total_size as f64 / BYTES_PER_MB
        )?;
        writeln!(
            f,
            "Compressed: {:.2} MB (ratio: {:.1}x)",
            self.compressed_size as f64 / BYTES_PER_MB,
            self.compression_ratio
        )?;
        write!(
            f,
            "Hit Rate: {:.1}% ({} hits, {} misses)",
            self.hit_rate, self.cache_hits, self.cache_misses
        )
    }
}
218
#[cfg(test)]
mod tests {
    use super::*;

    // Phase 4 Test Suite: Cache Manager - focused core logic tests

    #[test]
    fn test_cache_manager_new() {
        // Construction touches the user's cache directory, which may be unavailable
        // in sandboxed environments, so only inspect the manager when it succeeds.
        if let Ok(manager) = CacheManager::new() {
            assert!(
                !manager.cache_dir.as_os_str().is_empty(),
                "Cache directory should be set"
            );
            let mem_cache = manager.memory_cache.lock().unwrap();
            assert!(
                mem_cache.tokens.is_empty(),
                "A new manager should start with an empty memory cache"
            );
            assert_eq!(mem_cache.hits, 0, "A new manager should have no hits");
            assert_eq!(mem_cache.misses, 0, "A new manager should have no misses");
        }
    }

    #[test]
    fn test_cache_stats_hit_rate_calculation() {
        // Test hit rate calculation logic
        let stats = CacheStats {
            file_cache_entries: 10,
            memory_cache_entries: 5,
            total_size: 1_000_000,
            compressed_size: 500_000,
            compression_ratio: 2.0,
            cache_hits: 100,
            cache_misses: 20,
            hit_rate: 83.33,
            cache_directory: PathBuf::from("/cache"),
        };

        // Verify hit rate calculation
        let expected_hit_rate = (100.0 / 120.0) * 100.0;
        assert!(
            (stats.hit_rate - expected_hit_rate).abs() < 0.1,
            "Hit rate should be ~83.33%"
        );
    }

    #[test]
    fn test_cache_stats_compression_ratio() {
        // Test compression ratio is calculated correctly
        let stats = CacheStats {
            file_cache_entries: 5,
            memory_cache_entries: 3,
            total_size: 1000,
            compressed_size: 400,
            compression_ratio: 2.5,
            cache_hits: 50,
            cache_misses: 10,
            hit_rate: 83.33,
            cache_directory: PathBuf::from("/cache"),
        };

        assert_eq!(
            stats.compression_ratio, 2.5,
            "Compression ratio should be 2.5"
        );
        assert_eq!(stats.total_size, 1000, "Total size should be 1000");
    }

    #[test]
    fn test_cache_stats_format_display() {
        // Test that stats can be formatted for display
        let stats = CacheStats {
            file_cache_entries: 10,
            memory_cache_entries: 5,
            total_size: 1_048_576,
            compressed_size: 524_288,
            compression_ratio: 2.0,
            cache_hits: 100,
            cache_misses: 20,
            hit_rate: 83.33,
            cache_directory: PathBuf::from("/cache"),
        };

        let formatted = stats.format();
        assert!(
            formatted.contains("Cache Statistics"),
            "Should include header"
        );
        assert!(
            formatted.contains("/cache"),
            "Should include cache directory"
        );
        assert!(
            formatted.contains("File Cache: 10"),
            "Should include file cache entries"
        );
        assert!(
            formatted.contains("Memory Cache: 5"),
            "Should include memory cache entries"
        );
        assert!(
            formatted.contains("Total Size: 1.00 MB"),
            "Should report total size in MB"
        );
        assert!(
            formatted.contains("Compressed: 0.50 MB (ratio: 2.0x)"),
            "Should report compressed size and ratio"
        );
        assert!(
            formatted.contains("Hit Rate: 83.3% (100 hits, 20 misses)"),
            "Should report hit rate and counters"
        );
    }

    #[test]
    fn test_memory_cache_default() {
        // Test default MemoryCache initialization
        let mem_cache = MemoryCache::default();
        assert_eq!(mem_cache.hits, 0, "Initial hits should be 0");
        assert_eq!(mem_cache.misses, 0, "Initial misses should be 0");
        assert!(
            mem_cache.tokens.is_empty(),
            "Initial tokens should be empty"
        );
    }

    #[test]
    fn test_cache_key_components() {
        // Test cache key structure and components
        let path = PathBuf::from("src/main.rs");
        let file_hash = "abc123def456".to_string();

        let key = CacheKey {
            file_path: path.clone(),
            file_hash: file_hash.clone(),
        };

        assert_eq!(key.file_path, path, "File path should match");
        assert_eq!(key.file_hash, file_hash, "File hash should match");
    }

    #[test]
    fn test_cached_tokens_structure() {
        // Test CachedTokens structure
        let cached = CachedTokens {
            count: 1000,
            model_name: "ollama/tinyllama".to_string(),
        };

        assert_eq!(cached.count, 1000, "Token count should be 1000");
        assert_eq!(
            cached.model_name, "ollama/tinyllama",
            "Model name should match"
        );
    }

    #[test]
    fn test_cache_directory_structure() {
        // Test cache directory path construction with platform-specific paths
        #[cfg(windows)]
        let cache_dir = PathBuf::from("C:\\Users\\user\\AppData\\Local\\mermaid");

        #[cfg(not(windows))]
        let cache_dir = PathBuf::from("/home/user/.cache/mermaid");

        // Verify cache directory is a valid PathBuf
        assert!(
            cache_dir.is_absolute(),
            "Cache directory should be absolute"
        );
        assert!(
            cache_dir.to_string_lossy().contains("mermaid"),
            "Should contain mermaid"
        );
    }

    #[test]
    fn test_hit_rate_percentages() {
        // Test hit rate calculation for various scenarios
        let scenarios = vec![
            (100, 0, 100.0), // All hits
            (0, 100, 0.0),   // All misses
            (50, 50, 50.0),  // Even split
            (75, 25, 75.0),  // 75% hit rate
        ];

        for (hits, misses, expected_rate) in scenarios {
            let total = hits + misses;
            let rate = if total > 0 {
                (hits as f32 / total as f32) * 100.0
            } else {
                0.0
            };

            assert!(
                (rate - expected_rate).abs() < 0.1,
                "Hit rate calculation for ({}, {}) failed",
                hits,
                misses
            );
        }
    }
}