1use anyhow::Result;
2use directories::ProjectDirs;
3use rustc_hash::FxHashMap;
4use std::path::{Path, PathBuf};
5use std::sync::{Arc, Mutex};
6
7use super::file_cache::FileCache;
8use super::types::{CacheKey, CachedTokens};
9use crate::utils::lock_arc_mutex_safe;
10
11#[derive(Debug)]
13pub struct CacheManager {
14 file_cache: Arc<FileCache>,
15 memory_cache: Arc<Mutex<MemoryCache>>,
16 cache_dir: PathBuf,
17}
18
19#[derive(Debug, Default)]
21struct MemoryCache {
22 tokens: FxHashMap<CacheKey, CachedTokens>,
23 hits: usize,
24 misses: usize,
25}
26
27impl CacheManager {
28 pub fn new() -> Result<Self> {
30 let cache_dir = if let Some(proj_dirs) = ProjectDirs::from("", "", "mermaid") {
32 proj_dirs.cache_dir().to_path_buf()
33 } else {
34 let home = std::env::var("HOME")?;
36 PathBuf::from(home).join(".cache").join("mermaid")
37 };
38
39 let file_cache = Arc::new(FileCache::new(cache_dir.clone())?);
40 let memory_cache = Arc::new(Mutex::new(MemoryCache::default()));
41
42 Ok(Self {
43 file_cache,
44 memory_cache,
45 cache_dir,
46 })
47 }
48
49 pub fn get_or_compute_tokens(
51 &self,
52 path: &Path,
53 content: &str,
54 model_name: &str,
55 ) -> Result<usize> {
56 let key = FileCache::generate_key(path)?;
58
59 {
61 let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
62 if let Some(cached) = mem_cache.tokens.get(&key).cloned() {
63 if cached.model_name == model_name {
64 mem_cache.hits += 1;
65 return Ok(cached.count);
66 }
67 }
68 }
69
70 if let Some(cached) = self.file_cache.load::<CachedTokens>(&key)? {
72 if cached.model_name == model_name {
73 if self.file_cache.is_valid(&key)? {
75 let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
77 mem_cache.tokens.insert(key.clone(), cached.clone());
78 mem_cache.hits += 1;
79 return Ok(cached.count);
80 } else {
81 self.file_cache.remove(&key)?;
83 }
84 }
85 }
86
87 {
89 let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
90 mem_cache.misses += 1;
91 }
92
93 let tokenizer = crate::utils::Tokenizer::new(model_name);
95 let count = tokenizer.count_tokens(content)?;
96
97 let cached = CachedTokens {
99 count,
100 model_name: model_name.to_string(),
101 };
102
103 self.file_cache.save(&key, &cached)?;
105
106 {
108 let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
109 mem_cache.tokens.insert(key, cached);
110 }
111
112 Ok(count)
113 }
114
115 pub fn invalidate(&self, path: &Path) -> Result<()> {
117 let key = FileCache::generate_key(path)?;
118
119 {
121 let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
122 mem_cache.tokens.remove(&key);
123 }
124
125 self.file_cache.remove(&key)?;
127
128 Ok(())
129 }
130
131 pub fn clear_all(&self) -> Result<()> {
133 {
135 let mut mem_cache = lock_arc_mutex_safe(&self.memory_cache);
136 mem_cache.tokens.clear();
137 mem_cache.hits = 0;
138 mem_cache.misses = 0;
139 }
140
141 if self.cache_dir.exists() {
143 std::fs::remove_dir_all(&self.cache_dir)?;
144 std::fs::create_dir_all(&self.cache_dir)?;
145 }
146
147 Ok(())
148 }
149
150 pub fn get_stats(&self) -> Result<CacheStats> {
152 let file_stats = self.file_cache.get_stats()?;
153
154 let (memory_entries, hits, misses, hit_rate) = {
155 let mem_cache = lock_arc_mutex_safe(&self.memory_cache);
156 let total_requests = mem_cache.hits + mem_cache.misses;
157 let hit_rate = if total_requests > 0 {
158 (mem_cache.hits as f32 / total_requests as f32) * 100.0
159 } else {
160 0.0
161 };
162 (
163 mem_cache.tokens.len(),
164 mem_cache.hits,
165 mem_cache.misses,
166 hit_rate,
167 )
168 };
169
170 Ok(CacheStats {
171 file_cache_entries: file_stats.total_entries,
172 memory_cache_entries: memory_entries,
173 total_size: file_stats.total_size,
174 compressed_size: file_stats.total_compressed_size,
175 compression_ratio: file_stats.compression_ratio,
176 cache_hits: hits,
177 cache_misses: misses,
178 hit_rate,
179 cache_directory: self.cache_dir.clone(),
180 })
181 }
182}
183
/// Point-in-time snapshot of cache statistics.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries in the on-disk cache.
    pub file_cache_entries: usize,
    /// Number of entries currently held in memory.
    pub memory_cache_entries: usize,
    /// Total size of the on-disk entries in bytes, as reported by the file cache.
    pub total_size: usize,
    /// Compressed size of the on-disk entries in bytes.
    pub compressed_size: usize,
    /// Compression ratio reported by the file cache (shown as `N.Nx`).
    pub compression_ratio: f32,
    /// Lookups answered from a cache tier.
    pub cache_hits: usize,
    /// Lookups that required recomputation.
    pub cache_misses: usize,
    /// Hit rate as a percentage.
    pub hit_rate: f32,
    /// Root directory of the on-disk cache.
    pub cache_directory: PathBuf,
}

impl CacheStats {
    /// Renders the statistics as a multi-line, human-readable report.
    ///
    /// Identical to `self.to_string()`; kept for existing callers.
    pub fn format(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for CacheStats {
    /// Writes the same report as [`CacheStats::format`], one statistic per line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Sizes are stored in bytes; divide by 2^20 to report them in MB.
        const BYTES_PER_MB: f64 = 1_048_576.0;

        write!(
            f,
            "Cache Statistics:\n\
             Directory: {}\n\
             File Cache: {} entries\n\
             Memory Cache: {} entries\n\
             Total Size: {:.2} MB\n\
             Compressed: {:.2} MB (ratio: {:.1}x)\n\
             Hit Rate: {:.1}% ({} hits, {} misses)",
            self.cache_directory.display(),
            self.file_cache_entries,
            self.memory_cache_entries,
            self.total_size as f64 / BYTES_PER_MB,
            self.compressed_size as f64 / BYTES_PER_MB,
            self.compression_ratio,
            self.hit_rate,
            self.cache_hits,
            self.cache_misses
        )
    }
}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_manager_new() {
        // Construction depends on the host environment (platform cache dir or
        // `$HOME`), so failure is tolerated. A manager that *is* created must be
        // rooted in a `mermaid`-namespaced directory.
        if let Ok(manager) = CacheManager::new() {
            assert!(
                manager.cache_dir.to_string_lossy().contains("mermaid"),
                "Cache directory should be namespaced under mermaid"
            );
        }
    }

    #[test]
    fn test_cache_stats_hit_rate_calculation() {
        let stats = CacheStats {
            file_cache_entries: 10,
            memory_cache_entries: 5,
            total_size: 1_000_000,
            compressed_size: 500_000,
            compression_ratio: 2.0,
            cache_hits: 100,
            cache_misses: 20,
            hit_rate: 83.33,
            cache_directory: PathBuf::from("/cache"),
        };

        // Derive the expectation from the snapshot's own counters.
        let total = stats.cache_hits + stats.cache_misses;
        let expected_hit_rate = stats.cache_hits as f32 / total as f32 * 100.0;
        assert!(
            (stats.hit_rate - expected_hit_rate).abs() < 0.1,
            "Hit rate should be ~83.33%"
        );
    }

    #[test]
    fn test_cache_stats_compression_ratio() {
        let stats = CacheStats {
            file_cache_entries: 5,
            memory_cache_entries: 3,
            total_size: 1000,
            compressed_size: 400,
            compression_ratio: 2.5,
            cache_hits: 50,
            cache_misses: 10,
            hit_rate: 83.33,
            cache_directory: PathBuf::from("/cache"),
        };

        assert_eq!(
            stats.compression_ratio, 2.5,
            "Compression ratio should be 2.5"
        );
        assert_eq!(stats.total_size, 1000, "Total size should be 1000");
    }

    #[test]
    fn test_cache_stats_format_display() {
        let stats = CacheStats {
            file_cache_entries: 10,
            memory_cache_entries: 5,
            total_size: 1_048_576,
            compressed_size: 524_288,
            compression_ratio: 2.0,
            cache_hits: 100,
            cache_misses: 20,
            hit_rate: 83.33,
            cache_directory: PathBuf::from("/cache"),
        };

        let formatted = stats.format();
        assert!(
            formatted.contains("Cache Statistics"),
            "Should include header"
        );
        assert!(
            formatted.contains("/cache"),
            "Should include cache directory"
        );
        assert!(
            formatted.contains("File Cache: 10"),
            "Should include file cache entries"
        );
        assert!(
            formatted.contains("Memory Cache: 5"),
            "Should include memory cache entries"
        );
        assert!(
            formatted.contains("Total Size: 1.00 MB"),
            "Should render total size in MB"
        );
        assert!(
            formatted.contains("Compressed: 0.50 MB (ratio: 2.0x)"),
            "Should render compressed size and ratio"
        );
        assert!(
            formatted.contains("Hit Rate: 83.3% (100 hits, 20 misses)"),
            "Should render hit rate with counters"
        );
    }

    #[test]
    fn test_memory_cache_default() {
        let mem_cache = MemoryCache::default();
        assert_eq!(mem_cache.hits, 0, "Initial hits should be 0");
        assert_eq!(mem_cache.misses, 0, "Initial misses should be 0");
        assert!(
            mem_cache.tokens.is_empty(),
            "Initial tokens should be empty"
        );
    }

    #[test]
    fn test_cache_key_components() {
        let path = PathBuf::from("src/main.rs");
        let file_hash = "abc123def456".to_string();

        let key = CacheKey {
            file_path: path.clone(),
            file_hash: file_hash.clone(),
        };

        assert_eq!(key.file_path, path, "File path should match");
        assert_eq!(key.file_hash, file_hash, "File hash should match");
    }

    #[test]
    fn test_cached_tokens_structure() {
        let cached = CachedTokens {
            count: 1000,
            model_name: "ollama/tinyllama".to_string(),
        };

        assert_eq!(cached.count, 1000, "Token count should be 1000");
        assert_eq!(
            cached.model_name, "ollama/tinyllama",
            "Model name should match"
        );
    }

    #[test]
    fn test_cache_directory_structure() {
        let cache_dir = PathBuf::from("/home/user/.cache/mermaid");

        assert!(
            cache_dir.is_absolute(),
            "Cache directory should be absolute"
        );
        assert!(
            cache_dir.to_string_lossy().contains("mermaid"),
            "Should contain mermaid"
        );
    }

    #[test]
    fn test_hit_rate_percentages() {
        // (hits, misses, expected percentage); (0, 0) exercises the
        // divide-by-zero guard used by `CacheManager::get_stats`.
        let scenarios = [
            (100, 0, 100.0),
            (0, 100, 0.0),
            (50, 50, 50.0),
            (75, 25, 75.0),
            (0, 0, 0.0),
        ];

        for (hits, misses, expected_rate) in scenarios {
            let total = hits + misses;
            let rate = if total > 0 {
                (hits as f32 / total as f32) * 100.0
            } else {
                0.0
            };

            assert!(
                (rate - expected_rate).abs() < 0.1,
                "Hit rate calculation for ({}, {}) failed",
                hits,
                misses
            );
        }
    }
}