1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::{Instant, SystemTime};
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8 crate::hooks::normalize_tool_path(path)
9}
10
/// Returns the token budget for the session cache.
///
/// The budget is read from the `LEAN_CTX_CACHE_MAX_TOKENS` environment
/// variable on every call. Surrounding whitespace (for example a trailing
/// newline picked up from a shell script or `.env` file) is ignored. When the
/// variable is unset, is not valid Unicode, or does not parse as a `usize`,
/// the default of 500_000 tokens is used.
fn max_cache_tokens() -> usize {
    const ENV_VAR: &str = "LEAN_CTX_CACHE_MAX_TOKENS";
    const DEFAULT_MAX_TOKENS: usize = 500_000;

    std::env::var(ENV_VAR)
        .ok()
        // Trim first: `str::parse::<usize>` rejects any padding, which would
        // otherwise silently discard an intentional override.
        .and_then(|raw| raw.trim().parse().ok())
        .unwrap_or(DEFAULT_MAX_TOKENS)
}
17
/// One cached file plus the bookkeeping used for hit detection and eviction.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// File content as most recently stored.
    pub content: String,
    /// Hex digest of `content`; compared on re-store to detect changes.
    pub hash: String,
    /// Number of lines in `content`.
    pub line_count: usize,
    /// Token count of `content`.
    pub original_tokens: usize,
    /// How many times this entry has been stored or served as a hit.
    pub read_count: u32,
    /// Normalized path this entry is keyed under.
    pub path: String,
    /// Moment of the most recent store or recorded hit.
    pub last_access: Instant,
    /// Filesystem modification time captured when the entry was stored, if it
    /// could be read.
    pub stored_mtime: Option<SystemTime>,
    /// Derived outputs keyed by mode key; cleared when the content changes.
    pub compressed_outputs: HashMap<String, String>,
    /// Whether the full content has been marked as delivered; reset when the
    /// content changes.
    pub full_content_delivered: bool,
}
35
/// Summary returned to the caller after a file has been stored in the cache.
#[derive(Clone, Debug)]
pub struct StoreResult {
    /// Line count of the cached content.
    pub line_count: usize,
    /// Token count of the cached content.
    pub original_tokens: usize,
    /// Read count of the entry after this store.
    pub read_count: u32,
    /// `true` when the path was already cached with identical content.
    pub was_hit: bool,
    /// Whether the entry was already marked as fully delivered; always `false`
    /// for new or changed content.
    pub full_content_delivered: bool,
}
46
47impl CacheEntry {
48 pub fn eviction_score_legacy(&self, now: Instant) -> f64 {
50 let elapsed = now
51 .checked_duration_since(self.last_access)
52 .unwrap_or_default()
53 .as_secs_f64();
54 let recency = 1.0 / (1.0 + elapsed.sqrt());
55 let frequency = (self.read_count as f64 + 1.0).ln();
56 let size_value = (self.original_tokens as f64 + 1.0).ln();
57 recency * 0.4 + frequency * 0.3 + size_value * 0.3
58 }
59
60 pub fn get_compressed(&self, mode_key: &str) -> Option<&String> {
61 self.compressed_outputs.get(mode_key)
62 }
63
64 pub fn set_compressed(&mut self, mode_key: &str, output: String) {
65 const MAX_COMPRESSED_VARIANTS: usize = 3;
66 if self.compressed_outputs.len() >= MAX_COMPRESSED_VARIANTS
67 && !self.compressed_outputs.contains_key(mode_key)
68 {
69 if let Some(oldest_key) = self.compressed_outputs.keys().next().cloned() {
70 self.compressed_outputs.remove(&oldest_key);
71 }
72 }
73 self.compressed_outputs.insert(mode_key.to_string(), output);
74 }
75
76 pub fn mark_full_delivered(&mut self) {
77 self.full_content_delivered = true;
78 }
79}
80
81const RRF_K: f64 = 60.0;
82
83pub fn eviction_scores_rrf(entries: &[(&String, &CacheEntry)], now: Instant) -> Vec<(String, f64)> {
88 if entries.is_empty() {
89 return Vec::new();
90 }
91
92 let n = entries.len();
93
94 let mut recency_order: Vec<usize> = (0..n).collect();
95 recency_order.sort_by(|&a, &b| {
96 let elapsed_a = now
97 .checked_duration_since(entries[a].1.last_access)
98 .unwrap_or_default()
99 .as_secs_f64();
100 let elapsed_b = now
101 .checked_duration_since(entries[b].1.last_access)
102 .unwrap_or_default()
103 .as_secs_f64();
104 elapsed_a
105 .partial_cmp(&elapsed_b)
106 .unwrap_or(std::cmp::Ordering::Equal)
107 });
108
109 let mut frequency_order: Vec<usize> = (0..n).collect();
110 frequency_order.sort_by(|&a, &b| entries[b].1.read_count.cmp(&entries[a].1.read_count));
111
112 let mut size_order: Vec<usize> = (0..n).collect();
113 size_order.sort_by(|&a, &b| {
114 entries[b]
115 .1
116 .original_tokens
117 .cmp(&entries[a].1.original_tokens)
118 });
119
120 let mut recency_ranks = vec![0usize; n];
121 let mut frequency_ranks = vec![0usize; n];
122 let mut size_ranks = vec![0usize; n];
123
124 for (rank, &idx) in recency_order.iter().enumerate() {
125 recency_ranks[idx] = rank;
126 }
127 for (rank, &idx) in frequency_order.iter().enumerate() {
128 frequency_ranks[idx] = rank;
129 }
130 for (rank, &idx) in size_order.iter().enumerate() {
131 size_ranks[idx] = rank;
132 }
133
134 entries
135 .iter()
136 .enumerate()
137 .map(|(i, (path, _))| {
138 let score = 1.0 / (RRF_K + recency_ranks[i] as f64)
139 + 1.0 / (RRF_K + frequency_ranks[i] as f64)
140 + 1.0 / (RRF_K + size_ranks[i] as f64);
141 ((*path).clone(), score)
142 })
143 .collect()
144}
145
/// Running counters describing how the session cache has been used.
#[derive(Debug)]
pub struct CacheStats {
    /// Number of reads recorded, hits and misses alike.
    pub total_reads: u64,
    /// Number of reads that were served as cache hits.
    pub cache_hits: u64,
    /// Sum of the full token counts of everything that was read.
    pub total_original_tokens: u64,
    /// Sum of the token counts accounted as actually sent.
    pub total_sent_tokens: u64,
    /// Number of new entries inserted into the cache.
    pub files_tracked: usize,
}

impl CacheStats {
    /// Percentage (0–100) of reads that were cache hits; `0.0` before any read.
    pub fn hit_rate(&self) -> f64 {
        match self.total_reads {
            0 => 0.0,
            reads => (self.cache_hits as f64 / reads as f64) * 100.0,
        }
    }

    /// Tokens not sent thanks to the cache, clamped at zero.
    pub fn tokens_saved(&self) -> u64 {
        u64::saturating_sub(self.total_original_tokens, self.total_sent_tokens)
    }

    /// Saved tokens as a percentage (0–100) of the original tokens; `0.0` when
    /// nothing has been read yet.
    pub fn savings_percent(&self) -> f64 {
        let original = self.total_original_tokens;
        if original == 0 {
            0.0
        } else {
            (self.tokens_saved() as f64 / original as f64) * 100.0
        }
    }
}
179
/// A piece of text that occurs in several files and can be replaced by a
/// reference to one canonical occurrence.
#[derive(Debug, Clone)]
pub struct SharedBlock {
    /// Path of the file holding the canonical occurrence; that file itself is
    /// never rewritten.
    pub canonical_path: String,
    /// Reference label printed in the replacement placeholder.
    pub canonical_ref: String,
    /// First line of the block, as printed in the placeholder.
    pub start_line: usize,
    /// Last line of the block, as printed in the placeholder.
    pub end_line: usize,
    /// Exact text that is searched for and replaced.
    pub content: String,
}
189
190pub struct SessionCache {
192 entries: HashMap<String, CacheEntry>,
193 file_refs: HashMap<String, String>,
194 next_ref: usize,
195 stats: CacheStats,
196 shared_blocks: Vec<SharedBlock>,
197}
198
199impl Default for SessionCache {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
205impl SessionCache {
206 pub fn new() -> Self {
208 Self {
209 entries: HashMap::new(),
210 file_refs: HashMap::new(),
211 next_ref: 1,
212 shared_blocks: Vec::new(),
213 stats: CacheStats {
214 total_reads: 0,
215 cache_hits: 0,
216 total_original_tokens: 0,
217 total_sent_tokens: 0,
218 files_tracked: 0,
219 },
220 }
221 }
222
223 pub fn get_file_ref(&mut self, path: &str) -> String {
225 let key = normalize_key(path);
226 if let Some(r) = self.file_refs.get(&key) {
227 return r.clone();
228 }
229 let r = format!("F{}", self.next_ref);
230 self.next_ref += 1;
231 self.file_refs.insert(key, r.clone());
232 r
233 }
234
235 pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
237 self.file_refs.get(&normalize_key(path)).cloned()
238 }
239
240 pub fn get(&self, path: &str) -> Option<&CacheEntry> {
242 self.entries.get(&normalize_key(path))
243 }
244
245 pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
247 let key = normalize_key(path);
248 let ref_label = self
249 .file_refs
250 .get(&key)
251 .cloned()
252 .unwrap_or_else(|| "F?".to_string());
253 if let Some(entry) = self.entries.get_mut(&key) {
254 entry.read_count += 1;
255 entry.last_access = Instant::now();
256 self.stats.total_reads += 1;
257 self.stats.cache_hits += 1;
258 self.stats.total_original_tokens += entry.original_tokens as u64;
259 let hit_msg = format!(
260 "{ref_label} cached {}t {}L",
261 entry.read_count, entry.line_count
262 );
263 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
264 crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
265 Some(entry)
266 } else {
267 None
268 }
269 }
270
271 pub fn store(&mut self, path: &str, content: String) -> StoreResult {
273 let key = normalize_key(path);
274 let hash = compute_md5(&content);
275 let line_count = content.lines().count();
276 let original_tokens = count_tokens(&content);
277 let stored_mtime = std::fs::metadata(path).and_then(|m| m.modified()).ok();
278 let now = Instant::now();
279
280 self.stats.total_reads += 1;
281 self.stats.total_original_tokens += original_tokens as u64;
282
283 if let Some(existing) = self.entries.get_mut(&key) {
284 existing.last_access = now;
285 if stored_mtime.is_some() {
286 existing.stored_mtime = stored_mtime;
287 }
288 if existing.hash == hash {
289 existing.read_count += 1;
290 self.stats.cache_hits += 1;
291 let hit_msg = format!(
292 "{} cached {}t {}L",
293 self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
294 existing.read_count,
295 existing.line_count,
296 );
297 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
298 return StoreResult {
299 line_count: existing.line_count,
300 original_tokens: existing.original_tokens,
301 read_count: existing.read_count,
302 was_hit: true,
303 full_content_delivered: existing.full_content_delivered,
304 };
305 }
306 existing.compressed_outputs.clear();
307 existing.content = content;
308 existing.hash = hash;
309 existing.line_count = line_count;
310 existing.original_tokens = original_tokens;
311 existing.read_count += 1;
312 existing.full_content_delivered = false;
313 if stored_mtime.is_some() {
314 existing.stored_mtime = stored_mtime;
315 }
316 self.stats.total_sent_tokens += original_tokens as u64;
317 return StoreResult {
318 line_count,
319 original_tokens,
320 read_count: existing.read_count,
321 was_hit: false,
322 full_content_delivered: false,
323 };
324 }
325
326 self.evict_if_needed(original_tokens);
327 self.get_file_ref(&key);
328
329 let entry = CacheEntry {
330 content,
331 hash,
332 line_count,
333 original_tokens,
334 read_count: 1,
335 path: key.clone(),
336 last_access: now,
337 stored_mtime,
338 compressed_outputs: HashMap::new(),
339 full_content_delivered: false,
340 };
341
342 self.entries.insert(key, entry);
343 self.stats.files_tracked += 1;
344 self.stats.total_sent_tokens += original_tokens as u64;
345 StoreResult {
346 line_count,
347 original_tokens,
348 read_count: 1,
349 was_hit: false,
350 full_content_delivered: false,
351 }
352 }
353
354 pub fn total_cached_tokens(&self) -> usize {
356 self.entries.values().map(|e| e.original_tokens).sum()
357 }
358
359 pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
361 let max_tokens = max_cache_tokens();
362 let current = self.total_cached_tokens();
363 if current + incoming_tokens <= max_tokens {
364 return;
365 }
366
367 let now = Instant::now();
368 let all_entries: Vec<(&String, &CacheEntry)> = self.entries.iter().collect();
369 let mut scored = eviction_scores_rrf(&all_entries, now);
370 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
371
372 let mut freed = 0usize;
373 let target = (current + incoming_tokens).saturating_sub(max_tokens);
374 for (path, _score) in &scored {
375 if freed >= target {
376 break;
377 }
378 if let Some(entry) = self.entries.remove(path) {
379 freed += entry.original_tokens;
380 self.file_refs.remove(path);
381 }
382 }
383 }
384
385 pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
387 self.entries.iter().collect()
388 }
389
390 pub fn get_stats(&self) -> &CacheStats {
392 &self.stats
393 }
394
395 pub fn file_ref_map(&self) -> &HashMap<String, String> {
397 &self.file_refs
398 }
399
400 pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
402 self.shared_blocks = blocks;
403 }
404
405 pub fn get_shared_blocks(&self) -> &[SharedBlock] {
407 &self.shared_blocks
408 }
409
410 pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
412 if self.shared_blocks.is_empty() {
413 return None;
414 }
415 let refs: Vec<&SharedBlock> = self
416 .shared_blocks
417 .iter()
418 .filter(|b| b.canonical_path != path && content.contains(&b.content))
419 .collect();
420 if refs.is_empty() {
421 return None;
422 }
423 let mut result = content.to_string();
424 for block in refs {
425 result = result.replacen(
426 &block.content,
427 &format!(
428 "[= {}:{}-{}]",
429 block.canonical_ref, block.start_line, block.end_line
430 ),
431 1,
432 );
433 }
434 Some(result)
435 }
436
437 pub fn invalidate(&mut self, path: &str) -> bool {
439 self.entries.remove(&normalize_key(path)).is_some()
440 }
441
442 pub fn get_compressed(&self, path: &str, mode_key: &str) -> Option<&String> {
444 self.entries
445 .get(&normalize_key(path))?
446 .get_compressed(mode_key)
447 }
448
449 pub fn mark_full_delivered(&mut self, path: &str) {
451 if let Some(entry) = self.entries.get_mut(&normalize_key(path)) {
452 entry.mark_full_delivered();
453 }
454 }
455
456 pub fn set_compressed(&mut self, path: &str, mode_key: &str, output: String) {
458 if let Some(entry) = self.entries.get_mut(&normalize_key(path)) {
459 entry.set_compressed(mode_key, output);
460 }
461 }
462
463 pub fn clear(&mut self) -> usize {
465 let count = self.entries.len();
466 self.entries.clear();
467 self.file_refs.clear();
468 self.shared_blocks.clear();
469 self.next_ref = 1;
470 self.stats = CacheStats {
471 total_reads: 0,
472 cache_hits: 0,
473 total_original_tokens: 0,
474 total_sent_tokens: 0,
475 files_tracked: 0,
476 };
477 count
478 }
479}
480
/// Returns the filesystem modification time of `path`.
///
/// Yields `None` when the metadata cannot be read (for example because the
/// file does not exist) or when no modification time is available.
pub fn file_mtime(path: &str) -> Option<SystemTime> {
    let meta = std::fs::metadata(path).ok()?;
    meta.modified().ok()
}
484
485pub fn is_cache_entry_stale(path: &str, cached_mtime: Option<SystemTime>) -> bool {
486 let current = file_mtime(path);
487 match (cached_mtime, current) {
488 (_, None) => false,
489 (None, Some(_)) => true,
490 (Some(cached), Some(current)) => current > cached,
491 }
492}
493
494fn compute_md5(content: &str) -> String {
495 let mut hasher = Md5::new();
496 hasher.update(content.as_bytes());
497 format!("{:x}", hasher.finalize())
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a minimal entry in which only the scoring inputs under test vary.
    fn scoring_entry(path: &str, read_count: u32, last_access: Instant) -> CacheEntry {
        CacheEntry {
            content: "x".to_string(),
            hash: "h".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count,
            path: path.to_string(),
            last_access,
            stored_mtime: None,
            compressed_outputs: HashMap::new(),
            full_content_delivered: false,
        }
    }

    /// Looks up the score reported for `key`.
    fn score_of(scores: &[(String, f64)], key: &str) -> f64 {
        scores
            .iter()
            .find(|(p, _)| p == key)
            .map(|(_, score)| *score)
            .expect("a score should be reported for every entry")
    }

    /// Sets an environment variable and removes it again on drop, so a failing
    /// assertion cannot leak the override into other tests.
    struct EnvVarGuard(&'static str);

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            std::env::set_var(name, value);
            Self(name)
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.0);
        }
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let result = cache.store("/test/file.rs", "fn main() {}".to_string());
        assert!(!result.was_hit);
        assert_eq!(result.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content".to_string());
        let result = cache.store("/test/file.rs", "content".to_string());
        assert!(result.was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content".to_string());
        let result = cache.store("/test/file.rs", "new content".to_string());
        assert!(!result.was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        // Asking again for a known path must reuse its ref.
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.store("/b.rs", "b".to_string());
        let count = cache.clear();
        assert_eq!(count, 2);
        assert!(cache.get("/a.rs").is_none());
        // Ref numbering restarts after a clear.
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test".to_string());
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello".to_string());
        // Second store of identical content counts as a hit.
        cache.store("/a.rs", "hello".to_string());
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let h1 = compute_md5("test content");
        let h2 = compute_md5("test content");
        assert_eq!(h1, h2);
        assert_ne!(h1, compute_md5("different"));
    }

    #[test]
    fn rrf_eviction_prefers_recent() {
        // Derive `now` arithmetically instead of sleeping: the scorer only
        // looks at the difference between the two instants.
        let base = Instant::now();
        let now = base + Duration::from_millis(5);
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let recent = scoring_entry("/a.rs", 1, now);
        let old = scoring_entry("/b.rs", 1, base);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &recent), (&key_b, &old)];
        let scores = eviction_scores_rrf(&entries, now);
        assert!(
            score_of(&scores, "a.rs") > score_of(&scores, "b.rs"),
            "recently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn rrf_eviction_prefers_frequent() {
        let now = Instant::now();
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let frequent = scoring_entry("/a.rs", 20, now);
        let rare = scoring_entry("/b.rs", 1, now);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &frequent), (&key_b, &rare)];
        let scores = eviction_scores_rrf(&entries, now);
        assert!(
            score_of(&scores, "a.rs") > score_of(&scores, "b.rs"),
            "frequently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        let _budget = EnvVarGuard::set("LEAN_CTX_CACHE_MAX_TOKENS", "50");
        let mut cache = SessionCache::new();
        cache.store("/old.rs", "a]".repeat(30));
        cache.store("/new.rs", "b ".repeat(30));
        assert!(
            cache.total_cached_tokens() <= 60,
            "eviction should have kicked in"
        );
    }

    #[test]
    fn stale_detection_flags_newer_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("stale.txt");
        let p = path.to_string_lossy().to_string();

        std::fs::write(&path, "one").unwrap();
        let mut cache = SessionCache::new();
        cache.store(&p, "one".to_string());

        let entry = cache.get(&p).unwrap();
        assert!(!is_cache_entry_stale(&p, entry.stored_mtime));

        // Presumably needed so the rewrite lands on a strictly newer mtime on
        // filesystems with coarse timestamp resolution.
        std::thread::sleep(Duration::from_secs(1));
        std::fs::write(&path, "two").unwrap();

        let entry = cache.get(&p).unwrap();
        assert!(is_cache_entry_stale(&p, entry.stored_mtime));
    }

    #[test]
    fn compressed_outputs_cached_and_retrieved() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "fn main() {}".to_string());
        cache.set_compressed("/test.rs", "map", "compressed map output".to_string());
        assert_eq!(
            cache.get_compressed("/test.rs", "map"),
            Some(&"compressed map output".to_string())
        );
        assert_eq!(cache.get_compressed("/test.rs", "signatures"), None);
    }

    #[test]
    fn compressed_outputs_cleared_on_content_change() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "old content".to_string());
        cache.set_compressed("/test.rs", "map", "old map".to_string());
        assert!(cache.get_compressed("/test.rs", "map").is_some());

        cache.store("/test.rs", "new content".to_string());
        assert_eq!(cache.get_compressed("/test.rs", "map"), None);
    }

    #[test]
    fn compressed_outputs_survive_same_content_store() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "content".to_string());
        cache.set_compressed("/test.rs", "map", "cached map".to_string());

        let result = cache.store("/test.rs", "content".to_string());
        assert!(result.was_hit);
        assert_eq!(
            cache.get_compressed("/test.rs", "map"),
            Some(&"cached map".to_string())
        );
    }

    #[test]
    fn compressed_outputs_cleared_on_invalidate() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "content".to_string());
        cache.set_compressed("/test.rs", "signatures", "cached sigs".to_string());
        cache.invalidate("/test.rs");
        assert_eq!(cache.get_compressed("/test.rs", "signatures"), None);
    }

    #[test]
    fn compressed_outputs_cleared_on_clear() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.set_compressed("/a.rs", "map", "map_a".to_string());
        cache.clear();
        assert_eq!(cache.get_compressed("/a.rs", "map"), None);
    }
}