1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::{Instant, SystemTime};
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8 crate::core::pathutil::normalize_tool_path(path)
9}
10
/// Token budget for the whole session cache.
///
/// Read from `LEAN_CTX_CACHE_MAX_TOKENS` on every call, so changes to the
/// environment take effect immediately. A missing variable or a value that does not
/// parse as `usize` falls back to 500 000 tokens.
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;
    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_MAX_TOKENS),
        Err(_) => DEFAULT_MAX_TOKENS,
    }
}
17
/// One cached file: the content last stored for it plus the bookkeeping used for
/// hit detection, eviction and reuse of compressed renderings.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    /// Full text most recently passed to `SessionCache::store` for this path.
    pub content: String,
    /// Lower-case hex MD5 of `content`; an equal hash on re-store counts as a cache hit.
    pub hash: String,
    /// `content.lines().count()` at store time.
    pub line_count: usize,
    /// Token count of `content` (via `count_tokens`); charged against the cache budget.
    pub original_tokens: usize,
    /// Incremented on every store and recorded hit; entries with `read_count <= 1`
    /// are the first candidates for eviction.
    pub read_count: u32,
    /// Normalized cache key this entry is stored under.
    pub path: String,
    /// Last time the entry was stored or hit; drives recency-based eviction.
    pub last_access: Instant,
    /// File modification time captured at store time, `None` if it could not be read.
    pub stored_mtime: Option<SystemTime>,
    /// Compressed renderings keyed by mode (e.g. `"map"`, `"signatures"`); at most
    /// three are kept, and all are discarded when `content` changes.
    pub compressed_outputs: HashMap<String, String>,
    /// Set by `mark_full_delivered`; reset to `false` when `content` changes.
    pub full_content_delivered: bool,
}
35
/// Outcome of `SessionCache::store`, describing the entry as it is after the store.
#[derive(Debug, Clone)]
pub struct StoreResult {
    /// Line count of the cached content.
    pub line_count: usize,
    /// Token count of the cached content.
    pub original_tokens: usize,
    /// Read count of the entry after this store (`1` for a newly inserted entry).
    pub read_count: u32,
    /// `true` when the stored content's hash matched the existing entry.
    pub was_hit: bool,
    /// The entry's `full_content_delivered` flag; always `false` unless `was_hit`.
    pub full_content_delivered: bool,
}
46
47impl CacheEntry {
48 pub fn eviction_score_legacy(&self, now: Instant) -> f64 {
50 let elapsed = now
51 .checked_duration_since(self.last_access)
52 .unwrap_or_default()
53 .as_secs_f64();
54 let recency = 1.0 / (1.0 + elapsed.sqrt());
55 let frequency = (self.read_count as f64 + 1.0).ln();
56 let size_value = (self.original_tokens as f64 + 1.0).ln();
57 recency * 0.4 + frequency * 0.3 + size_value * 0.3
58 }
59
60 pub fn get_compressed(&self, mode_key: &str) -> Option<&String> {
61 self.compressed_outputs.get(mode_key)
62 }
63
64 pub fn set_compressed(&mut self, mode_key: &str, output: String) {
65 const MAX_COMPRESSED_VARIANTS: usize = 3;
66 if self.compressed_outputs.len() >= MAX_COMPRESSED_VARIANTS
67 && !self.compressed_outputs.contains_key(mode_key)
68 {
69 if let Some(oldest_key) = self.compressed_outputs.keys().next().cloned() {
70 self.compressed_outputs.remove(&oldest_key);
71 }
72 }
73 self.compressed_outputs.insert(mode_key.to_string(), output);
74 }
75
76 pub fn mark_full_delivered(&mut self) {
77 self.full_content_delivered = true;
78 }
79}
80
81const RRF_K: f64 = 60.0;
82
83pub fn eviction_scores_rrf(entries: &[(&String, &CacheEntry)], now: Instant) -> Vec<(String, f64)> {
88 if entries.is_empty() {
89 return Vec::new();
90 }
91
92 let n = entries.len();
93
94 let mut recency_order: Vec<usize> = (0..n).collect();
95 recency_order.sort_by(|&a, &b| {
96 let elapsed_a = now
97 .checked_duration_since(entries[a].1.last_access)
98 .unwrap_or_default()
99 .as_secs_f64();
100 let elapsed_b = now
101 .checked_duration_since(entries[b].1.last_access)
102 .unwrap_or_default()
103 .as_secs_f64();
104 elapsed_a
105 .partial_cmp(&elapsed_b)
106 .unwrap_or(std::cmp::Ordering::Equal)
107 });
108
109 let mut frequency_order: Vec<usize> = (0..n).collect();
110 frequency_order.sort_by(|&a, &b| entries[b].1.read_count.cmp(&entries[a].1.read_count));
111
112 let mut size_order: Vec<usize> = (0..n).collect();
113 size_order.sort_by(|&a, &b| {
114 entries[b]
115 .1
116 .original_tokens
117 .cmp(&entries[a].1.original_tokens)
118 });
119
120 let mut recency_ranks = vec![0usize; n];
121 let mut frequency_ranks = vec![0usize; n];
122 let mut size_ranks = vec![0usize; n];
123
124 for (rank, &idx) in recency_order.iter().enumerate() {
125 recency_ranks[idx] = rank;
126 }
127 for (rank, &idx) in frequency_order.iter().enumerate() {
128 frequency_ranks[idx] = rank;
129 }
130 for (rank, &idx) in size_order.iter().enumerate() {
131 size_ranks[idx] = rank;
132 }
133
134 entries
135 .iter()
136 .enumerate()
137 .map(|(i, (path, _))| {
138 let score = 1.0 / (RRF_K + recency_ranks[i] as f64)
139 + 1.0 / (RRF_K + frequency_ranks[i] as f64)
140 + 1.0 / (RRF_K + size_ranks[i] as f64);
141 ((*path).clone(), score)
142 })
143 .collect()
144}
145
/// Running counters for a `SessionCache`; reset by `SessionCache::clear`.
#[derive(Debug)]
pub struct CacheStats {
    /// Every `store` call plus every recorded cache hit.
    pub total_reads: u64,
    /// Reads answered from the cache (unchanged content or `record_cache_hit`).
    pub cache_hits: u64,
    /// Tokens the reads would have cost if the full content were always sent.
    pub total_original_tokens: u64,
    /// Tokens actually sent: full content on a miss, a short notice on a hit.
    pub total_sent_tokens: u64,
    /// Paths inserted as new entries; not decremented by eviction or invalidation.
    pub files_tracked: usize,
}

impl CacheStats {
    /// Percentage (0–100) of reads served from the cache; `0.0` before any read.
    pub fn hit_rate(&self) -> f64 {
        match self.total_reads {
            0 => 0.0,
            reads => self.cache_hits as f64 / reads as f64 * 100.0,
        }
    }

    /// Tokens not sent thanks to the cache, clamped at zero.
    pub fn tokens_saved(&self) -> u64 {
        self.total_original_tokens
            .checked_sub(self.total_sent_tokens)
            .unwrap_or(0)
    }

    /// [`Self::tokens_saved`] as a percentage of the original tokens; `0.0` when no
    /// tokens have been recorded.
    pub fn savings_percent(&self) -> f64 {
        if self.total_original_tokens == 0 {
            0.0
        } else {
            let saved = self.tokens_saved() as f64;
            saved / self.total_original_tokens as f64 * 100.0
        }
    }
}
179
/// A chunk of text that appears in several files. `SessionCache::apply_dedup`
/// replaces it in every file except its canonical one with a
/// `[= <canonical_ref>:<start_line>-<end_line>]` pointer.
#[derive(Clone, Debug)]
pub struct SharedBlock {
    /// File that owns the block; `apply_dedup` never rewrites this file.
    pub canonical_path: String,
    /// Label printed in the pointer. NOTE(review): presumably the owner's `F<n>`
    /// file reference — confirm where blocks are produced.
    pub canonical_ref: String,
    /// First line number printed in the pointer.
    pub start_line: usize,
    /// Last line number printed in the pointer (inclusivity not established here).
    pub end_line: usize,
    /// Exact text searched for and replaced in other files.
    pub content: String,
}
189
/// Per-session cache of file contents keyed by normalized path, with short `F<n>`
/// file references, read/hit statistics and shared-block deduplication.
pub struct SessionCache {
    /// Cached content and bookkeeping, keyed by `normalize_key(path)`.
    entries: HashMap<String, CacheEntry>,
    /// Normalized path -> short reference label (`"F1"`, `"F2"`, ...).
    file_refs: HashMap<String, String>,
    /// Number used for the next reference label handed out.
    next_ref: usize,
    /// Running read/hit/token counters.
    stats: CacheStats,
    /// Blocks whose copies in other files `apply_dedup` replaces with pointers.
    shared_blocks: Vec<SharedBlock>,
}
199
200impl Default for SessionCache {
201 fn default() -> Self {
202 Self::new()
203 }
204}
205
206impl SessionCache {
207 pub fn new() -> Self {
209 Self {
210 entries: HashMap::new(),
211 file_refs: HashMap::new(),
212 next_ref: 1,
213 shared_blocks: Vec::new(),
214 stats: CacheStats {
215 total_reads: 0,
216 cache_hits: 0,
217 total_original_tokens: 0,
218 total_sent_tokens: 0,
219 files_tracked: 0,
220 },
221 }
222 }
223
224 pub fn get_file_ref(&mut self, path: &str) -> String {
226 let key = normalize_key(path);
227 if let Some(r) = self.file_refs.get(&key) {
228 return r.clone();
229 }
230 let r = format!("F{}", self.next_ref);
231 self.next_ref += 1;
232 self.file_refs.insert(key, r.clone());
233 r
234 }
235
236 pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
238 self.file_refs.get(&normalize_key(path)).cloned()
239 }
240
241 pub fn get(&self, path: &str) -> Option<&CacheEntry> {
243 self.entries.get(&normalize_key(path))
244 }
245
246 pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
248 let key = normalize_key(path);
249 let ref_label = self
250 .file_refs
251 .get(&key)
252 .cloned()
253 .unwrap_or_else(|| "F?".to_string());
254 if let Some(entry) = self.entries.get_mut(&key) {
255 entry.read_count += 1;
256 entry.last_access = Instant::now();
257 self.stats.total_reads += 1;
258 self.stats.cache_hits += 1;
259 self.stats.total_original_tokens += entry.original_tokens as u64;
260 let hit_msg = format!(
261 "{ref_label} cached {}t {}L",
262 entry.read_count, entry.line_count
263 );
264 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
265 crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
266 Some(entry)
267 } else {
268 None
269 }
270 }
271
272 pub fn store(&mut self, path: &str, content: String) -> StoreResult {
274 let key = normalize_key(path);
275 let hash = compute_md5(&content);
276 let line_count = content.lines().count();
277 let original_tokens = count_tokens(&content);
278 let stored_mtime = std::fs::metadata(path).and_then(|m| m.modified()).ok();
279 let now = Instant::now();
280
281 self.stats.total_reads += 1;
282 self.stats.total_original_tokens += original_tokens as u64;
283
284 if let Some(existing) = self.entries.get_mut(&key) {
285 existing.last_access = now;
286 if stored_mtime.is_some() {
287 existing.stored_mtime = stored_mtime;
288 }
289 if existing.hash == hash {
290 existing.read_count += 1;
291 self.stats.cache_hits += 1;
292 let hit_msg = format!(
293 "{} cached {}t {}L",
294 self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
295 existing.read_count,
296 existing.line_count,
297 );
298 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
299 return StoreResult {
300 line_count: existing.line_count,
301 original_tokens: existing.original_tokens,
302 read_count: existing.read_count,
303 was_hit: true,
304 full_content_delivered: existing.full_content_delivered,
305 };
306 }
307 existing.compressed_outputs.clear();
308 existing.content = content;
309 existing.hash = hash;
310 existing.line_count = line_count;
311 existing.original_tokens = original_tokens;
312 existing.read_count += 1;
313 existing.full_content_delivered = false;
314 if stored_mtime.is_some() {
315 existing.stored_mtime = stored_mtime;
316 }
317 self.stats.total_sent_tokens += original_tokens as u64;
318 return StoreResult {
319 line_count,
320 original_tokens,
321 read_count: existing.read_count,
322 was_hit: false,
323 full_content_delivered: false,
324 };
325 }
326
327 self.evict_if_needed(original_tokens);
328 self.get_file_ref(&key);
329
330 let entry = CacheEntry {
331 content,
332 hash,
333 line_count,
334 original_tokens,
335 read_count: 1,
336 path: key.clone(),
337 last_access: now,
338 stored_mtime,
339 compressed_outputs: HashMap::new(),
340 full_content_delivered: false,
341 };
342
343 self.entries.insert(key, entry);
344 self.stats.files_tracked += 1;
345 self.stats.total_sent_tokens += original_tokens as u64;
346 StoreResult {
347 line_count,
348 original_tokens,
349 read_count: 1,
350 was_hit: false,
351 full_content_delivered: false,
352 }
353 }
354
355 pub fn total_cached_tokens(&self) -> usize {
357 self.entries.values().map(|e| e.original_tokens).sum()
358 }
359
360 pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
364 let max_tokens = max_cache_tokens();
365 let current = self.total_cached_tokens();
366 if current + incoming_tokens <= max_tokens {
367 return;
368 }
369
370 let mut freed = 0usize;
371 let target = (current + incoming_tokens).saturating_sub(max_tokens);
372
373 let mut probationary: Vec<(String, Instant)> = self
374 .entries
375 .iter()
376 .filter(|(_, e)| e.read_count <= 1)
377 .map(|(p, e)| (p.clone(), e.last_access))
378 .collect();
379 probationary.sort_by_key(|(_, t)| *t);
380
381 let mut protected: Vec<(String, Instant)> = self
382 .entries
383 .iter()
384 .filter(|(_, e)| e.read_count > 1)
385 .map(|(p, e)| (p.clone(), e.last_access))
386 .collect();
387 protected.sort_by_key(|(_, t)| *t);
388
389 for (path, _) in probationary.into_iter().chain(protected) {
390 if freed >= target {
391 break;
392 }
393 if let Some(entry) = self.entries.remove(&path) {
394 freed += entry.original_tokens;
395 self.file_refs.remove(&path);
396 }
397 }
398 }
399
400 pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
402 self.entries.iter().collect()
403 }
404
405 pub fn get_stats(&self) -> &CacheStats {
407 &self.stats
408 }
409
410 pub fn file_ref_map(&self) -> &HashMap<String, String> {
412 &self.file_refs
413 }
414
415 pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
417 self.shared_blocks = blocks;
418 }
419
420 pub fn get_shared_blocks(&self) -> &[SharedBlock] {
422 &self.shared_blocks
423 }
424
425 pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
427 if self.shared_blocks.is_empty() {
428 return None;
429 }
430 let refs: Vec<&SharedBlock> = self
431 .shared_blocks
432 .iter()
433 .filter(|b| b.canonical_path != path && content.contains(&b.content))
434 .collect();
435 if refs.is_empty() {
436 return None;
437 }
438 let mut result = content.to_string();
439 for block in refs {
440 result = result.replacen(
441 &block.content,
442 &format!(
443 "[= {}:{}-{}]",
444 block.canonical_ref, block.start_line, block.end_line
445 ),
446 1,
447 );
448 }
449 Some(result)
450 }
451
452 pub fn invalidate(&mut self, path: &str) -> bool {
454 self.entries.remove(&normalize_key(path)).is_some()
455 }
456
457 pub fn get_compressed(&self, path: &str, mode_key: &str) -> Option<&String> {
459 self.entries
460 .get(&normalize_key(path))?
461 .get_compressed(mode_key)
462 }
463
464 pub fn mark_full_delivered(&mut self, path: &str) {
466 if let Some(entry) = self.entries.get_mut(&normalize_key(path)) {
467 entry.mark_full_delivered();
468 }
469 }
470
471 pub fn set_compressed(&mut self, path: &str, mode_key: &str, output: String) {
473 if let Some(entry) = self.entries.get_mut(&normalize_key(path)) {
474 entry.set_compressed(mode_key, output);
475 }
476 }
477
478 pub fn clear(&mut self) -> usize {
480 let count = self.entries.len();
481 self.entries.clear();
482 self.file_refs.clear();
483 self.shared_blocks.clear();
484 self.next_ref = 1;
485 self.stats = CacheStats {
486 total_reads: 0,
487 cache_hits: 0,
488 total_original_tokens: 0,
489 total_sent_tokens: 0,
490 files_tracked: 0,
491 };
492 count
493 }
494}
495
/// Last-modification time of the file at `path`.
///
/// Returns `None` when the metadata cannot be read (e.g. the file is missing) or
/// the platform cannot report a modification time.
pub fn file_mtime(path: &str) -> Option<SystemTime> {
    match std::fs::metadata(path) {
        Ok(meta) => meta.modified().ok(),
        Err(_) => None,
    }
}
499
500pub fn is_cache_entry_stale(path: &str, cached_mtime: Option<SystemTime>) -> bool {
501 let current = file_mtime(path);
502 match (cached_mtime, current) {
503 (_, None) => false,
504 (None, Some(_)) => true,
505 (Some(cached), Some(current)) => current > cached,
506 }
507}
508
509fn compute_md5(content: &str) -> String {
510 let mut hasher = Md5::new();
511 hasher.update(content.as_bytes());
512 format!("{:x}", hasher.finalize())
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Removes an environment variable when dropped, so a failing assertion cannot
    /// leak a tiny token budget into tests that run afterwards.
    struct EnvVarGuard(&'static str);

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.0);
        }
    }

    /// Minimal 10-token, one-line entry for the RRF ranking tests.
    fn rrf_entry(path: &str, hash: &str, read_count: u32, last_access: Instant) -> CacheEntry {
        CacheEntry {
            content: path.trim_start_matches('/').to_string(),
            hash: hash.to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count,
            path: path.to_string(),
            last_access,
            stored_mtime: None,
            compressed_outputs: HashMap::new(),
            full_content_delivered: false,
        }
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let result = cache.store("/test/file.rs", "fn main() {}".to_string());
        assert!(!result.was_hit);
        assert_eq!(result.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content".to_string());
        let result = cache.store("/test/file.rs", "content".to_string());
        assert!(result.was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content".to_string());
        let result = cache.store("/test/file.rs", "new content".to_string());
        assert!(!result.was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        // Repeated lookups reuse the existing label.
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.store("/b.rs", "b".to_string());
        let count = cache.clear();
        assert_eq!(count, 2);
        assert!(cache.get("/a.rs").is_none());
        // Reference numbering restarts after a clear.
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test".to_string());
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello".to_string());
        cache.store("/a.rs", "hello".to_string());
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let h1 = compute_md5("test content");
        let h2 = compute_md5("test content");
        assert_eq!(h1, h2);
        assert_ne!(h1, compute_md5("different"));
    }

    #[test]
    fn rrf_eviction_prefers_recent() {
        let base = Instant::now();
        // A synthetic clock instead of sleeping: only the difference matters.
        let now = base + Duration::from_millis(5);
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let recent = rrf_entry("/a.rs", "h1", 1, now);
        let old = rrf_entry("/b.rs", "h2", 1, base);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &recent), (&key_b, &old)];
        let scores = eviction_scores_rrf(&entries, now);
        let score_a = scores.iter().find(|(p, _)| p == "a.rs").unwrap().1;
        let score_b = scores.iter().find(|(p, _)| p == "b.rs").unwrap().1;
        assert!(
            score_a > score_b,
            "recently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn rrf_eviction_prefers_frequent() {
        let now = Instant::now();
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let frequent = rrf_entry("/a.rs", "h1", 20, now);
        let rare = rrf_entry("/b.rs", "h2", 1, now);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &frequent), (&key_b, &rare)];
        let scores = eviction_scores_rrf(&entries, now);
        let score_a = scores.iter().find(|(p, _)| p == "a.rs").unwrap().1;
        let score_b = scores.iter().find(|(p, _)| p == "b.rs").unwrap().1;
        assert!(
            score_a > score_b,
            "frequently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        let old_content = "a ".repeat(30);
        let new_content = "b ".repeat(30);
        // The budget fits either file alone but never both, whatever the tokenizer
        // counts, so the second store must evict.
        let budget = count_tokens(&old_content).max(count_tokens(&new_content));
        std::env::set_var("LEAN_CTX_CACHE_MAX_TOKENS", budget.to_string());
        let _restore = EnvVarGuard("LEAN_CTX_CACHE_MAX_TOKENS");

        let mut cache = SessionCache::new();
        cache.store("/old.rs", old_content);
        cache.store("/new.rs", new_content);

        assert!(
            cache.get("/old.rs").is_none(),
            "least recently used probationary entry should be evicted"
        );
        assert!(cache.get("/new.rs").is_some(), "incoming entry must be kept");
        assert!(
            cache.total_cached_tokens() <= budget,
            "eviction should have kicked in"
        );
    }

    #[test]
    fn stale_detection_flags_newer_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("stale.txt");
        let p = path.to_string_lossy().to_string();

        std::fs::write(&path, "one").unwrap();
        let mut cache = SessionCache::new();
        cache.store(&p, "one".to_string());

        let entry = cache.get(&p).unwrap();
        assert!(!is_cache_entry_stale(&p, entry.stored_mtime));

        // Coarse filesystems only record whole seconds; wait so the mtime changes.
        std::thread::sleep(Duration::from_secs(1));
        std::fs::write(&path, "two").unwrap();

        let entry = cache.get(&p).unwrap();
        assert!(is_cache_entry_stale(&p, entry.stored_mtime));
    }

    #[test]
    fn compressed_outputs_cached_and_retrieved() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "fn main() {}".to_string());
        cache.set_compressed("/test.rs", "map", "compressed map output".to_string());
        assert_eq!(
            cache.get_compressed("/test.rs", "map"),
            Some(&"compressed map output".to_string())
        );
        assert_eq!(cache.get_compressed("/test.rs", "signatures"), None);
    }

    #[test]
    fn compressed_outputs_cleared_on_content_change() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "old content".to_string());
        cache.set_compressed("/test.rs", "map", "old map".to_string());
        assert!(cache.get_compressed("/test.rs", "map").is_some());

        cache.store("/test.rs", "new content".to_string());
        assert_eq!(cache.get_compressed("/test.rs", "map"), None);
    }

    #[test]
    fn compressed_outputs_survive_same_content_store() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "content".to_string());
        cache.set_compressed("/test.rs", "map", "cached map".to_string());

        let result = cache.store("/test.rs", "content".to_string());
        assert!(result.was_hit);
        assert_eq!(
            cache.get_compressed("/test.rs", "map"),
            Some(&"cached map".to_string())
        );
    }

    #[test]
    fn compressed_outputs_cleared_on_invalidate() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "content".to_string());
        cache.set_compressed("/test.rs", "signatures", "cached sigs".to_string());
        cache.invalidate("/test.rs");
        assert_eq!(cache.get_compressed("/test.rs", "signatures"), None);
    }

    #[test]
    fn compressed_outputs_cleared_on_clear() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.set_compressed("/a.rs", "map", "map_a".to_string());
        cache.clear();
        assert_eq!(cache.get_compressed("/a.rs", "map"), None);
    }
}