1use anyhow::Result;
2use directories::ProjectDirs;
3use rustc_hash::FxHashMap;
4use std::path::{Path, PathBuf};
5use std::sync::{Arc, Mutex};
6
7use super::file_cache::FileCache;
8use super::types::{CacheKey, CachedTokens};
9use crate::utils::lock_arc_mutex_safe;
10
/// Two-tier cache for token counts: an in-process map (`memory_cache`) in front
/// of a [`FileCache`] constructed over `cache_dir`.
#[derive(Debug)]
pub struct CacheManager {
    /// File-backed tier, created by `FileCache::new(cache_dir)`.
    file_cache: Arc<FileCache>,
    /// In-process tier plus hit/miss counters, guarded by a mutex.
    memory_cache: Arc<Mutex<MemoryCache>>,
    /// Directory handed to `FileCache::new`; `clear_all` deletes and recreates it.
    cache_dir: PathBuf,
}
18
/// In-process tier of [`CacheManager`]: cached token counts plus lookup counters.
#[derive(Debug, Default)]
struct MemoryCache {
    /// Cached counts by file key; each entry records the model it was counted for.
    tokens: FxHashMap<CacheKey, CachedTokens>,
    /// Lookups answered from either cache tier (memory or file).
    hits: usize,
    /// Lookups that fell through both tiers and ran the tokenizer.
    misses: usize,
}
26
27impl CacheManager {
28 pub fn new() -> Result<Self> {
30 let cache_dir = if let Some(proj_dirs) = ProjectDirs::from("", "", "mermaid") {
32 proj_dirs.cache_dir().to_path_buf()
33 } else {
34 let home = std::env::var("HOME")?;
36 PathBuf::from(home).join(".cache").join("mermaid")
37 };
38
39 let file_cache = Arc::new(FileCache::new(cache_dir.clone())?);
40 let memory_cache = Arc::new(Mutex::new(MemoryCache::default()));
41
42 Ok(Self {
43 file_cache,
44 memory_cache,
45 cache_dir,
46 })
47 }
48
49 pub fn get_or_compute_tokens(
51 &self,
52 path: &Path,
53 content: &str,
54 model_name: &str,
55 ) -> Result<usize> {
56 let key = FileCache::generate_key(path)?;
58
59 {
61 let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
62 if let Some(cached) = mem_cache.tokens.get(&key).cloned() {
63 if cached.model_name == model_name {
64 mem_cache.hits += 1;
65 return Ok(cached.count);
66 }
67 }
68 }
69
70 if let Some(cached) = self.file_cache.load::<CachedTokens>(&key)? {
75 if cached.model_name == model_name {
76 let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
78 mem_cache.tokens.insert(key.clone(), cached.clone());
79 mem_cache.hits += 1;
80 return Ok(cached.count);
81 }
82 }
83
84 {
86 let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
87 mem_cache.misses += 1;
88 }
89
90 let tokenizer = crate::utils::Tokenizer::new(model_name);
92 let count = tokenizer.count_tokens(content)?;
93
94 let cached = CachedTokens {
96 count,
97 model_name: model_name.to_string(),
98 };
99
100 self.file_cache.save(&key, &cached)?;
102
103 {
105 let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
106 mem_cache.tokens.insert(key, cached);
107 }
108
109 Ok(count)
110 }
111
112 pub fn invalidate(&self, path: &Path) -> Result<()> {
114 let key = FileCache::generate_key(path)?;
115
116 {
118 let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
119 mem_cache.tokens.remove(&key);
120 }
121
122 self.file_cache.remove(&key)?;
124
125 Ok(())
126 }
127
128 pub fn clear_all(&self) -> Result<()> {
130 {
132 let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
133 mem_cache.tokens.clear();
134 mem_cache.hits = 0;
135 mem_cache.misses = 0;
136 }
137
138 if self.cache_dir.exists() {
140 std::fs::remove_dir_all(&self.cache_dir)?;
141 std::fs::create_dir_all(&self.cache_dir)?;
142 }
143
144 Ok(())
145 }
146
147 pub fn get_stats(&self) -> Result<CacheStats> {
149 let file_stats = self.file_cache.get_stats()?;
150
151 let (memory_entries, hits, misses, hit_rate) = {
152 let mem_cache = lock_arc_mutex_safe(&self.memory_cache);
153 let total_requests = mem_cache.hits + mem_cache.misses;
154 let hit_rate = if total_requests > 0 {
155 (mem_cache.hits as f32 / total_requests as f32) * 100.0
156 } else {
157 0.0
158 };
159 (
160 mem_cache.tokens.len(),
161 mem_cache.hits,
162 mem_cache.misses,
163 hit_rate,
164 )
165 };
166
167 Ok(CacheStats {
168 file_cache_entries: file_stats.total_entries,
169 memory_cache_entries: memory_entries,
170 total_size: file_stats.total_size,
171 compressed_size: file_stats.total_compressed_size,
172 compression_ratio: file_stats.compression_ratio,
173 cache_hits: hits,
174 cache_misses: misses,
175 hit_rate,
176 cache_directory: self.cache_dir.clone(),
177 })
178 }
179}
180
/// Point-in-time statistics for a [`CacheManager`], as built by
/// `CacheManager::get_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Entry count reported by the file cache.
    pub file_cache_entries: usize,
    /// Number of entries currently held in the in-memory tier.
    pub memory_cache_entries: usize,
    /// Total size reported by the file cache, in bytes (`format` divides by
    /// 1_048_576 to show MB).
    pub total_size: usize,
    /// Compressed size reported by the file cache, in bytes.
    pub compressed_size: usize,
    /// Compression ratio reported by the file cache; `format` renders it as `<ratio>x`.
    pub compression_ratio: f32,
    /// Lookups answered from either cache tier.
    pub cache_hits: usize,
    /// Lookups that had to run the tokenizer.
    pub cache_misses: usize,
    /// `cache_hits` as a percentage of all lookups (0.0 when there were none).
    pub hit_rate: f32,
    /// Root directory of the file cache.
    pub cache_directory: PathBuf,
}
194
195impl CacheStats {
196 pub fn format(&self) -> String {
198 format!(
199 "Cache Statistics:\n\
200 Directory: {}\n\
201 File Cache: {} entries\n\
202 Memory Cache: {} entries\n\
203 Total Size: {:.2} MB\n\
204 Compressed: {:.2} MB (ratio: {:.1}x)\n\
205 Hit Rate: {:.1}% ({} hits, {} misses)",
206 self.cache_directory.display(),
207 self.file_cache_entries,
208 self.memory_cache_entries,
209 self.total_size as f64 / 1_048_576.0,
210 self.compressed_size as f64 / 1_048_576.0,
211 self.compression_ratio,
212 self.hit_rate,
213 self.cache_hits,
214 self.cache_misses
215 )
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Overwrites the manager's hit/miss counters so `get_stats` can be checked
    /// without running the tokenizer.
    fn set_counters(manager: &CacheManager, hits: usize, misses: usize) {
        let mut mem_cache = manager
            .memory_cache
            .lock()
            .expect("memory cache mutex poisoned");
        mem_cache.hits = hits;
        mem_cache.misses = misses;
    }

    #[test]
    fn test_cache_manager_new() {
        // `new` can legitimately fail where no cache/home directory can be
        // resolved (e.g. sandboxed CI), so only the success path is asserted.
        if let Ok(manager) = CacheManager::new() {
            let mem_cache = manager
                .memory_cache
                .lock()
                .expect("memory cache mutex poisoned");
            assert!(
                mem_cache.tokens.is_empty(),
                "New manager should start with an empty memory cache"
            );
            assert_eq!(mem_cache.hits, 0, "New manager should start with 0 hits");
            assert_eq!(
                mem_cache.misses, 0,
                "New manager should start with 0 misses"
            );
        }
    }

    #[test]
    fn test_cache_stats_hit_rate_calculation() {
        if let Ok(manager) = CacheManager::new() {
            set_counters(&manager, 100, 20);
            if let Ok(stats) = manager.get_stats() {
                assert_eq!(stats.cache_hits, 100, "Hits should be reported as-is");
                assert_eq!(stats.cache_misses, 20, "Misses should be reported as-is");

                let expected_hit_rate = (100.0 / 120.0) * 100.0;
                assert!(
                    (stats.hit_rate - expected_hit_rate).abs() < 0.1,
                    "Hit rate should be ~83.33%"
                );
            }
        }
    }

    #[test]
    fn test_cache_stats_compression_ratio() {
        let stats = CacheStats {
            file_cache_entries: 5,
            memory_cache_entries: 3,
            total_size: 1000,
            compressed_size: 400,
            compression_ratio: 2.5,
            cache_hits: 50,
            cache_misses: 10,
            hit_rate: 83.33,
            cache_directory: PathBuf::from("/cache"),
        };

        assert_eq!(
            stats.compression_ratio, 2.5,
            "Compression ratio should be 2.5"
        );
        assert_eq!(stats.total_size, 1000, "Total size should be 1000");
    }

    #[test]
    fn test_cache_stats_format_display() {
        let stats = CacheStats {
            file_cache_entries: 10,
            memory_cache_entries: 5,
            total_size: 1_048_576,
            compressed_size: 524_288,
            compression_ratio: 2.0,
            cache_hits: 100,
            cache_misses: 20,
            hit_rate: 83.33,
            cache_directory: PathBuf::from("/cache"),
        };

        // Pin the whole report, including the line-continuation layout of the
        // format string (no leading indentation on each line).
        assert_eq!(
            stats.format(),
            "Cache Statistics:\n\
             Directory: /cache\n\
             File Cache: 10 entries\n\
             Memory Cache: 5 entries\n\
             Total Size: 1.00 MB\n\
             Compressed: 0.50 MB (ratio: 2.0x)\n\
             Hit Rate: 83.3% (100 hits, 20 misses)",
            "Formatted report should match exactly"
        );
    }

    #[test]
    fn test_memory_cache_default() {
        let mem_cache = MemoryCache::default();
        assert_eq!(mem_cache.hits, 0, "Initial hits should be 0");
        assert_eq!(mem_cache.misses, 0, "Initial misses should be 0");
        assert!(
            mem_cache.tokens.is_empty(),
            "Initial tokens should be empty"
        );
    }

    #[test]
    fn test_cache_key_components() {
        let path = PathBuf::from("src/main.rs");
        let file_hash = "abc123def456".to_string();

        let key = CacheKey {
            file_path: path.clone(),
            file_hash: file_hash.clone(),
        };

        assert_eq!(key.file_path, path, "File path should match");
        assert_eq!(key.file_hash, file_hash, "File hash should match");
    }

    #[test]
    fn test_cached_tokens_structure() {
        let cached = CachedTokens {
            count: 1000,
            model_name: "ollama/tinyllama".to_string(),
        };

        assert_eq!(cached.count, 1000, "Token count should be 1000");
        assert_eq!(
            cached.model_name, "ollama/tinyllama",
            "Model name should match"
        );
    }

    #[test]
    fn test_cache_directory_structure() {
        // Check the directory `new` actually resolves, not a hard-coded literal.
        if let Ok(manager) = CacheManager::new() {
            assert!(
                manager.cache_dir.to_string_lossy().contains("mermaid"),
                "Should contain mermaid"
            );
            if let Ok(stats) = manager.get_stats() {
                assert_eq!(
                    stats.cache_directory, manager.cache_dir,
                    "Stats should report the manager's cache directory"
                );
            }
        }
    }

    #[test]
    fn test_hit_rate_percentages() {
        // (hits, misses, expected rate); the last row covers the no-lookups guard.
        let scenarios: [(usize, usize, f32); 5] = [
            (100, 0, 100.0),
            (0, 100, 0.0),
            (50, 50, 50.0),
            (75, 25, 75.0),
            (0, 0, 0.0),
        ];

        let manager = match CacheManager::new() {
            Ok(manager) => manager,
            Err(_) => return,
        };

        for &(hits, misses, expected_rate) in &scenarios {
            set_counters(&manager, hits, misses);
            let stats = match manager.get_stats() {
                Ok(stats) => stats,
                Err(_) => return,
            };
            assert!(
                (stats.hit_rate - expected_rate).abs() < 0.1,
                "Hit rate calculation for ({}, {}) failed",
                hits,
                misses
            );
        }
    }
}