1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::{Instant, SystemTime};
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8 crate::hooks::normalize_tool_path(path)
9}
10
/// Token budget for a session cache.
///
/// Read from `LEAN_CTX_CACHE_MAX_TOKENS` on every call; a missing or unparsable
/// value falls back to 500 000 tokens.
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;
    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_MAX_TOKENS),
        Err(_) => DEFAULT_MAX_TOKENS,
    }
}
17
/// One cached file: its last stored content plus the bookkeeping used for hit
/// detection, staleness checks and eviction scoring.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Full text of the file as last stored.
    pub content: String,
    /// Lowercase hex MD5 of `content`; an equal digest on re-store means "unchanged".
    pub hash: String,
    /// Number of lines in `content`.
    pub line_count: usize,
    /// Token count of `content`; this is what counts against the cache budget.
    pub original_tokens: usize,
    /// How many times the file has been stored or served from cache.
    pub read_count: u32,
    /// Normalized key under which this entry is stored.
    pub path: String,
    /// When the entry was last stored or served.
    pub last_access: Instant,
    /// File modification time observed when the content was stored, if readable.
    pub stored_mtime: Option<SystemTime>,
    /// Compressed renderings of `content` keyed by mode (e.g. `"map"`).
    pub compressed_outputs: HashMap<String, String>,
}
32
/// Outcome of `SessionCache::store`.
#[derive(Clone, Debug)]
pub struct StoreResult {
    /// Line count of the stored content.
    pub line_count: usize,
    /// Token count of the stored content.
    pub original_tokens: usize,
    /// Entry's read count after this store.
    pub read_count: u32,
    /// `true` when the content was identical to what was already cached.
    pub was_hit: bool,
}
41
42impl CacheEntry {
43 pub fn eviction_score_legacy(&self, now: Instant) -> f64 {
45 let elapsed = now
46 .checked_duration_since(self.last_access)
47 .unwrap_or_default()
48 .as_secs_f64();
49 let recency = 1.0 / (1.0 + elapsed.sqrt());
50 let frequency = (self.read_count as f64 + 1.0).ln();
51 let size_value = (self.original_tokens as f64 + 1.0).ln();
52 recency * 0.4 + frequency * 0.3 + size_value * 0.3
53 }
54
55 pub fn get_compressed(&self, mode_key: &str) -> Option<&String> {
56 self.compressed_outputs.get(mode_key)
57 }
58
59 pub fn set_compressed(&mut self, mode_key: &str, output: String) {
60 self.compressed_outputs.insert(mode_key.to_string(), output);
61 }
62}
63
64const RRF_K: f64 = 60.0;
65
66pub fn eviction_scores_rrf(entries: &[(&String, &CacheEntry)], now: Instant) -> Vec<(String, f64)> {
71 if entries.is_empty() {
72 return Vec::new();
73 }
74
75 let n = entries.len();
76
77 let mut recency_order: Vec<usize> = (0..n).collect();
78 recency_order.sort_by(|&a, &b| {
79 let elapsed_a = now
80 .checked_duration_since(entries[a].1.last_access)
81 .unwrap_or_default()
82 .as_secs_f64();
83 let elapsed_b = now
84 .checked_duration_since(entries[b].1.last_access)
85 .unwrap_or_default()
86 .as_secs_f64();
87 elapsed_a
88 .partial_cmp(&elapsed_b)
89 .unwrap_or(std::cmp::Ordering::Equal)
90 });
91
92 let mut frequency_order: Vec<usize> = (0..n).collect();
93 frequency_order.sort_by(|&a, &b| entries[b].1.read_count.cmp(&entries[a].1.read_count));
94
95 let mut size_order: Vec<usize> = (0..n).collect();
96 size_order.sort_by(|&a, &b| {
97 entries[b]
98 .1
99 .original_tokens
100 .cmp(&entries[a].1.original_tokens)
101 });
102
103 let mut recency_ranks = vec![0usize; n];
104 let mut frequency_ranks = vec![0usize; n];
105 let mut size_ranks = vec![0usize; n];
106
107 for (rank, &idx) in recency_order.iter().enumerate() {
108 recency_ranks[idx] = rank;
109 }
110 for (rank, &idx) in frequency_order.iter().enumerate() {
111 frequency_ranks[idx] = rank;
112 }
113 for (rank, &idx) in size_order.iter().enumerate() {
114 size_ranks[idx] = rank;
115 }
116
117 entries
118 .iter()
119 .enumerate()
120 .map(|(i, (path, _))| {
121 let score = 1.0 / (RRF_K + recency_ranks[i] as f64)
122 + 1.0 / (RRF_K + frequency_ranks[i] as f64)
123 + 1.0 / (RRF_K + size_ranks[i] as f64);
124 ((*path).clone(), score)
125 })
126 .collect()
127}
128
/// Running counters for a session cache. They are zeroed on creation and by
/// `SessionCache::clear`.
#[derive(Debug)]
pub struct CacheStats {
    /// Every store or served cache hit, whether or not it hit.
    pub total_reads: u64,
    /// Reads answered from the cache with unchanged content.
    pub cache_hits: u64,
    /// Sum of the full token size of every read.
    pub total_original_tokens: u64,
    /// Tokens actually charged as sent: full content on a miss, a short notice on a hit.
    pub total_sent_tokens: u64,
    /// Incremented for each newly inserted path; eviction and invalidation do not decrement it.
    pub files_tracked: usize,
}
138
139impl CacheStats {
140 pub fn hit_rate(&self) -> f64 {
142 if self.total_reads == 0 {
143 return 0.0;
144 }
145 (self.cache_hits as f64 / self.total_reads as f64) * 100.0
146 }
147
148 pub fn tokens_saved(&self) -> u64 {
150 self.total_original_tokens
151 .saturating_sub(self.total_sent_tokens)
152 }
153
154 pub fn savings_percent(&self) -> f64 {
156 if self.total_original_tokens == 0 {
157 return 0.0;
158 }
159 (self.tokens_saved() as f64 / self.total_original_tokens as f64) * 100.0
160 }
161}
162
/// A run of text that appears in several files. `SessionCache::apply_dedup`
/// replaces it with a `[= REF:START-END]` marker in every file except its owner.
#[derive(Debug, Clone)]
pub struct SharedBlock {
    /// Path of the file that owns the block (it is never replaced there).
    pub canonical_path: String,
    /// Reference label written into the marker (e.g. `F1`).
    pub canonical_ref: String,
    /// First line number written into the marker.
    pub start_line: usize,
    /// Last line number written into the marker.
    pub end_line: usize,
    /// Exact text that gets replaced.
    pub content: String,
}
172
173pub struct SessionCache {
175 entries: HashMap<String, CacheEntry>,
176 file_refs: HashMap<String, String>,
177 next_ref: usize,
178 stats: CacheStats,
179 shared_blocks: Vec<SharedBlock>,
180}
181
182impl Default for SessionCache {
183 fn default() -> Self {
184 Self::new()
185 }
186}
187
188impl SessionCache {
189 pub fn new() -> Self {
191 Self {
192 entries: HashMap::new(),
193 file_refs: HashMap::new(),
194 next_ref: 1,
195 shared_blocks: Vec::new(),
196 stats: CacheStats {
197 total_reads: 0,
198 cache_hits: 0,
199 total_original_tokens: 0,
200 total_sent_tokens: 0,
201 files_tracked: 0,
202 },
203 }
204 }
205
206 pub fn get_file_ref(&mut self, path: &str) -> String {
208 let key = normalize_key(path);
209 if let Some(r) = self.file_refs.get(&key) {
210 return r.clone();
211 }
212 let r = format!("F{}", self.next_ref);
213 self.next_ref += 1;
214 self.file_refs.insert(key, r.clone());
215 r
216 }
217
218 pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
220 self.file_refs.get(&normalize_key(path)).cloned()
221 }
222
223 pub fn get(&self, path: &str) -> Option<&CacheEntry> {
225 self.entries.get(&normalize_key(path))
226 }
227
228 pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
230 let key = normalize_key(path);
231 let ref_label = self
232 .file_refs
233 .get(&key)
234 .cloned()
235 .unwrap_or_else(|| "F?".to_string());
236 if let Some(entry) = self.entries.get_mut(&key) {
237 entry.read_count += 1;
238 entry.last_access = Instant::now();
239 self.stats.total_reads += 1;
240 self.stats.cache_hits += 1;
241 self.stats.total_original_tokens += entry.original_tokens as u64;
242 let hit_msg = format!(
243 "{ref_label} cached {}t {}L",
244 entry.read_count, entry.line_count
245 );
246 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
247 crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
248 Some(entry)
249 } else {
250 None
251 }
252 }
253
254 pub fn store(&mut self, path: &str, content: String) -> StoreResult {
256 let key = normalize_key(path);
257 let hash = compute_md5(&content);
258 let line_count = content.lines().count();
259 let original_tokens = count_tokens(&content);
260 let stored_mtime = std::fs::metadata(path).and_then(|m| m.modified()).ok();
261 let now = Instant::now();
262
263 self.stats.total_reads += 1;
264 self.stats.total_original_tokens += original_tokens as u64;
265
266 if let Some(existing) = self.entries.get_mut(&key) {
267 existing.last_access = now;
268 if stored_mtime.is_some() {
269 existing.stored_mtime = stored_mtime;
270 }
271 if existing.hash == hash {
272 existing.read_count += 1;
273 self.stats.cache_hits += 1;
274 let hit_msg = format!(
275 "{} cached {}t {}L",
276 self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
277 existing.read_count,
278 existing.line_count,
279 );
280 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
281 return StoreResult {
282 line_count: existing.line_count,
283 original_tokens: existing.original_tokens,
284 read_count: existing.read_count,
285 was_hit: true,
286 };
287 }
288 existing.compressed_outputs.clear();
289 existing.content = content;
290 existing.hash = hash;
291 existing.line_count = line_count;
292 existing.original_tokens = original_tokens;
293 existing.read_count += 1;
294 if stored_mtime.is_some() {
295 existing.stored_mtime = stored_mtime;
296 }
297 self.stats.total_sent_tokens += original_tokens as u64;
298 return StoreResult {
299 line_count,
300 original_tokens,
301 read_count: existing.read_count,
302 was_hit: false,
303 };
304 }
305
306 self.evict_if_needed(original_tokens);
307 self.get_file_ref(&key);
308
309 let entry = CacheEntry {
310 content,
311 hash,
312 line_count,
313 original_tokens,
314 read_count: 1,
315 path: key.clone(),
316 last_access: now,
317 stored_mtime,
318 compressed_outputs: HashMap::new(),
319 };
320
321 self.entries.insert(key, entry);
322 self.stats.files_tracked += 1;
323 self.stats.total_sent_tokens += original_tokens as u64;
324 StoreResult {
325 line_count,
326 original_tokens,
327 read_count: 1,
328 was_hit: false,
329 }
330 }
331
332 pub fn total_cached_tokens(&self) -> usize {
334 self.entries.values().map(|e| e.original_tokens).sum()
335 }
336
337 pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
339 let max_tokens = max_cache_tokens();
340 let current = self.total_cached_tokens();
341 if current + incoming_tokens <= max_tokens {
342 return;
343 }
344
345 let now = Instant::now();
346 let all_entries: Vec<(&String, &CacheEntry)> = self.entries.iter().collect();
347 let mut scored = eviction_scores_rrf(&all_entries, now);
348 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
349
350 let mut freed = 0usize;
351 let target = (current + incoming_tokens).saturating_sub(max_tokens);
352 for (path, _score) in &scored {
353 if freed >= target {
354 break;
355 }
356 if let Some(entry) = self.entries.remove(path) {
357 freed += entry.original_tokens;
358 self.file_refs.remove(path);
359 }
360 }
361 }
362
363 pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
365 self.entries.iter().collect()
366 }
367
368 pub fn get_stats(&self) -> &CacheStats {
370 &self.stats
371 }
372
373 pub fn file_ref_map(&self) -> &HashMap<String, String> {
375 &self.file_refs
376 }
377
378 pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
380 self.shared_blocks = blocks;
381 }
382
383 pub fn get_shared_blocks(&self) -> &[SharedBlock] {
385 &self.shared_blocks
386 }
387
388 pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
390 if self.shared_blocks.is_empty() {
391 return None;
392 }
393 let refs: Vec<&SharedBlock> = self
394 .shared_blocks
395 .iter()
396 .filter(|b| b.canonical_path != path && content.contains(&b.content))
397 .collect();
398 if refs.is_empty() {
399 return None;
400 }
401 let mut result = content.to_string();
402 for block in refs {
403 result = result.replacen(
404 &block.content,
405 &format!(
406 "[= {}:{}-{}]",
407 block.canonical_ref, block.start_line, block.end_line
408 ),
409 1,
410 );
411 }
412 Some(result)
413 }
414
415 pub fn invalidate(&mut self, path: &str) -> bool {
417 self.entries.remove(&normalize_key(path)).is_some()
418 }
419
420 pub fn get_compressed(&self, path: &str, mode_key: &str) -> Option<&String> {
422 self.entries
423 .get(&normalize_key(path))?
424 .get_compressed(mode_key)
425 }
426
427 pub fn set_compressed(&mut self, path: &str, mode_key: &str, output: String) {
429 if let Some(entry) = self.entries.get_mut(&normalize_key(path)) {
430 entry.set_compressed(mode_key, output);
431 }
432 }
433
434 pub fn clear(&mut self) -> usize {
436 let count = self.entries.len();
437 self.entries.clear();
438 self.file_refs.clear();
439 self.shared_blocks.clear();
440 self.next_ref = 1;
441 self.stats = CacheStats {
442 total_reads: 0,
443 cache_hits: 0,
444 total_original_tokens: 0,
445 total_sent_tokens: 0,
446 files_tracked: 0,
447 };
448 count
449 }
450}
451
/// Filesystem modification time of `path`, or `None` if the file cannot be
/// stat'ed or the platform does not report modification times.
pub fn file_mtime(path: &str) -> Option<SystemTime> {
    std::fs::metadata(path).and_then(|m| m.modified()).ok()
}

/// Reports whether a cached copy of `path` may no longer match the file on disk.
///
/// - If the current mtime cannot be read (missing file, virtual path), the entry
///   is treated as fresh: there is nothing to compare against.
/// - If no mtime was recorded at store time but one is readable now, it is stale.
/// - Otherwise it is stale whenever the mtime differs from the recorded one. An
///   *older* mtime counts too: VCS checkouts, backup restores and archive
///   extraction commonly replace content while setting an earlier timestamp.
pub fn is_cache_entry_stale(path: &str, cached_mtime: Option<SystemTime>) -> bool {
    match (cached_mtime, file_mtime(path)) {
        (_, None) => false,
        (None, Some(_)) => true,
        (Some(cached), Some(current)) => current != cached,
    }
}
464
465fn compute_md5(content: &str) -> String {
466 let mut hasher = Md5::new();
467 hasher.update(content.as_bytes());
468 format!("{:x}", hasher.finalize())
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds an entry for the RRF scoring tests. Only `read_count`,
    /// `original_tokens` and `last_access` influence the scores.
    fn scoring_entry(
        path: &str,
        read_count: u32,
        original_tokens: usize,
        last_access: Instant,
    ) -> CacheEntry {
        CacheEntry {
            content: String::new(),
            hash: String::new(),
            line_count: 1,
            original_tokens,
            read_count,
            path: path.to_string(),
            last_access,
            stored_mtime: None,
            compressed_outputs: HashMap::new(),
        }
    }

    /// Score assigned to `key` in the output of `eviction_scores_rrf`.
    fn score_of(scores: &[(String, f64)], key: &str) -> f64 {
        scores
            .iter()
            .find(|(p, _)| p == key)
            .map(|(_, s)| *s)
            .expect("key should be scored")
    }

    /// Sets an environment variable for the guard's lifetime and removes it on
    /// drop. A failing assertion therefore cannot leak the override into the rest
    /// of the test process.
    struct EnvVarGuard(&'static str);

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            std::env::set_var(name, value);
            EnvVarGuard(name)
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.0);
        }
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let result = cache.store("/test/file.rs", "fn main() {}".to_string());
        assert!(!result.was_hit);
        assert_eq!(result.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content".to_string());
        let result = cache.store("/test/file.rs", "content".to_string());
        assert!(result.was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content".to_string());
        let result = cache.store("/test/file.rs", "new content".to_string());
        assert!(!result.was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn read_count_advances_on_hits_and_changes() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.store("/rc.rs", "v1".to_string()).read_count, 1);
        assert_eq!(cache.store("/rc.rs", "v1".to_string()).read_count, 2);
        let changed = cache.store("/rc.rs", "v2".to_string());
        assert!(!changed.was_hit);
        assert_eq!(changed.read_count, 3);
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
    }

    #[test]
    fn readonly_file_ref_does_not_assign() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref_readonly("/x.rs"), None);
        assert_eq!(cache.get_file_ref("/x.rs"), "F1");
        assert_eq!(cache.get_file_ref_readonly("/x.rs"), Some("F1".to_string()));
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.store("/b.rs", "b".to_string());
        let count = cache.clear();
        assert_eq!(count, 2);
        assert!(cache.get("/a.rs").is_none());
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test".to_string());
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello".to_string());
        cache.store("/a.rs", "hello".to_string());
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn stats_ratios_handle_zero_and_saturate() {
        let empty = CacheStats {
            total_reads: 0,
            cache_hits: 0,
            total_original_tokens: 0,
            total_sent_tokens: 0,
            files_tracked: 0,
        };
        assert_eq!(empty.hit_rate(), 0.0);
        assert_eq!(empty.savings_percent(), 0.0);

        let stats = CacheStats {
            total_reads: 4,
            cache_hits: 1,
            total_original_tokens: 200,
            total_sent_tokens: 50,
            files_tracked: 1,
        };
        assert_eq!(stats.hit_rate(), 25.0);
        assert_eq!(stats.tokens_saved(), 150);
        assert_eq!(stats.savings_percent(), 75.0);

        let over = CacheStats {
            total_reads: 1,
            cache_hits: 0,
            total_original_tokens: 10,
            total_sent_tokens: 20,
            files_tracked: 1,
        };
        assert_eq!(over.tokens_saved(), 0);
    }

    #[test]
    fn md5_is_deterministic() {
        let h1 = compute_md5("test content");
        let h2 = compute_md5("test content");
        assert_eq!(h1, h2);
        assert_ne!(h1, compute_md5("different"));
    }

    #[test]
    fn rrf_eviction_prefers_recent() {
        let base = Instant::now();
        std::thread::sleep(Duration::from_millis(5));
        let now = Instant::now();
        let (key_a, key_b) = ("a.rs".to_string(), "b.rs".to_string());
        let recent = scoring_entry("/a.rs", 1, 10, now);
        let old = scoring_entry("/b.rs", 1, 10, base);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &recent), (&key_b, &old)];
        let scores = eviction_scores_rrf(&entries, now);
        assert!(
            score_of(&scores, "a.rs") > score_of(&scores, "b.rs"),
            "recently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn rrf_eviction_prefers_frequent() {
        let now = Instant::now();
        let (key_a, key_b) = ("a.rs".to_string(), "b.rs".to_string());
        let frequent = scoring_entry("/a.rs", 20, 10, now);
        let rare = scoring_entry("/b.rs", 1, 10, now);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &frequent), (&key_b, &rare)];
        let scores = eviction_scores_rrf(&entries, now);
        assert!(
            score_of(&scores, "a.rs") > score_of(&scores, "b.rs"),
            "frequently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        // NOTE(review): the override is process-wide, so tests running in parallel
        // may briefly see this budget too. They store only a few short strings per
        // cache. The guard restores the environment even if the assertion fails.
        let _budget = EnvVarGuard::set("LEAN_CTX_CACHE_MAX_TOKENS", "50");
        let mut cache = SessionCache::new();
        cache.store("/old.rs", "a]".repeat(30));
        cache.store("/new.rs", "b ".repeat(30));
        assert!(
            cache.total_cached_tokens() <= 60,
            "eviction should have kicked in"
        );
    }

    #[test]
    fn stale_detection_flags_newer_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("stale.txt");
        let p = path.to_string_lossy().to_string();

        std::fs::write(&path, "one").unwrap();
        let mut cache = SessionCache::new();
        cache.store(&p, "one".to_string());

        let entry = cache.get(&p).unwrap();
        assert!(!is_cache_entry_stale(&p, entry.stored_mtime));

        // Many filesystems record mtimes with one-second granularity.
        std::thread::sleep(Duration::from_secs(1));
        std::fs::write(&path, "two").unwrap();

        let entry = cache.get(&p).unwrap();
        assert!(is_cache_entry_stale(&p, entry.stored_mtime));
    }

    #[test]
    fn apply_dedup_replaces_blocks_from_other_files() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.apply_dedup("/b.rs", "anything"), None);
        cache.set_shared_blocks(vec![SharedBlock {
            canonical_path: "/a.rs".to_string(),
            canonical_ref: "F1".to_string(),
            start_line: 3,
            end_line: 5,
            content: "shared".to_string(),
        }]);
        assert_eq!(
            cache.apply_dedup("/b.rs", "before shared after"),
            Some("before [= F1:3-5] after".to_string())
        );
        assert_eq!(cache.apply_dedup("/a.rs", "before shared after"), None);
        assert_eq!(cache.apply_dedup("/b.rs", "unrelated"), None);
    }

    #[test]
    fn compressed_outputs_cached_and_retrieved() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "fn main() {}".to_string());
        cache.set_compressed("/test.rs", "map", "compressed map output".to_string());
        assert_eq!(
            cache.get_compressed("/test.rs", "map"),
            Some(&"compressed map output".to_string())
        );
        assert_eq!(cache.get_compressed("/test.rs", "signatures"), None);
    }

    #[test]
    fn compressed_outputs_cleared_on_content_change() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "old content".to_string());
        cache.set_compressed("/test.rs", "map", "old map".to_string());
        assert!(cache.get_compressed("/test.rs", "map").is_some());

        cache.store("/test.rs", "new content".to_string());
        assert_eq!(cache.get_compressed("/test.rs", "map"), None);
    }

    #[test]
    fn compressed_outputs_survive_same_content_store() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "content".to_string());
        cache.set_compressed("/test.rs", "map", "cached map".to_string());

        let result = cache.store("/test.rs", "content".to_string());
        assert!(result.was_hit);
        assert_eq!(
            cache.get_compressed("/test.rs", "map"),
            Some(&"cached map".to_string())
        );
    }

    #[test]
    fn compressed_outputs_cleared_on_invalidate() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "content".to_string());
        cache.set_compressed("/test.rs", "signatures", "cached sigs".to_string());
        cache.invalidate("/test.rs");
        assert_eq!(cache.get_compressed("/test.rs", "signatures"), None);
    }

    #[test]
    fn compressed_outputs_cleared_on_clear() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.set_compressed("/a.rs", "map", "map_a".to_string());
        cache.clear();
        assert_eq!(cache.get_compressed("/a.rs", "map"), None);
    }
}