1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::{Instant, SystemTime};
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8 crate::core::pathutil::normalize_tool_path(path)
9}
10
/// Token budget for the whole cache, read from `LEAN_CTX_CACHE_MAX_TOKENS` on every call.
///
/// Falls back to 500 000 tokens when the variable is unset, is not valid Unicode, or does
/// not parse as a `usize`.
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;
    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_MAX_TOKENS),
        Err(_) => DEFAULT_MAX_TOKENS,
    }
}
17
/// Cached state for one file, stored in `SessionCache` under its normalized path.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    /// Full text passed to the most recent content-changing `store`.
    pub content: String,
    /// Lowercase hex MD5 of `content`; compared on every `store` to detect changes.
    pub hash: String,
    /// Number of lines in `content`, as counted by `str::lines`.
    pub line_count: usize,
    /// Token count of `content`; also this entry's weight against the cache budget.
    pub original_tokens: usize,
    /// How many times the file has been stored or served from cache.
    pub read_count: u32,
    /// Normalized path this entry is stored under.
    pub path: String,
    /// When the entry was last stored or hit; the recency input for eviction.
    pub last_access: Instant,
    /// File modification time observed when the entry was stored, if it could be read.
    pub stored_mtime: Option<SystemTime>,
    /// Rendered variants of `content` keyed by mode (e.g. `"map"`). At most three are kept
    /// by `set_compressed`, and all are dropped when the content changes.
    pub compressed_outputs: HashMap<String, String>,
    /// Whether the full `content` has been delivered; reset when the content changes.
    pub full_content_delivered: bool,
}

/// Outcome of `SessionCache::store`.
#[derive(Debug, Clone)]
pub struct StoreResult {
    /// Line count of the content now cached.
    pub line_count: usize,
    /// Token count of the content now cached.
    pub original_tokens: usize,
    /// The entry's `read_count` after this store.
    pub read_count: u32,
    /// `true` when the stored content matched the cached hash.
    pub was_hit: bool,
    /// On a hit, the entry's delivery flag; always `false` on a miss.
    pub full_content_delivered: bool,
}

impl CacheEntry {
    /// Weighted "worth keeping" score used before RRF scoring existed (presumably superseded
    /// by `eviction_scores_rrf` — confirm before removing).
    ///
    /// `0.4 * recency + 0.3 * ln(read_count + 1) + 0.3 * ln(original_tokens + 1)`, where
    /// `recency = 1 / (1 + sqrt(seconds since last_access))`. A `last_access` later than
    /// `now` counts as zero elapsed time.
    pub fn eviction_score_legacy(&self, now: Instant) -> f64 {
        let elapsed = now
            .checked_duration_since(self.last_access)
            .unwrap_or_default()
            .as_secs_f64();
        let recency = 1.0 / (1.0 + elapsed.sqrt());
        let frequency = (self.read_count as f64 + 1.0).ln();
        let size_value = (self.original_tokens as f64 + 1.0).ln();
        recency * 0.4 + frequency * 0.3 + size_value * 0.3
    }

    /// Returns the cached rendering of `content` for `mode_key`, if one was stored.
    pub fn get_compressed(&self, mode_key: &str) -> Option<&String> {
        self.compressed_outputs.get(mode_key)
    }

    /// Stores `output` as the rendering for `mode_key`, keeping at most three variants.
    ///
    /// Overwriting an existing mode never evicts. Adding a new mode when the map is full
    /// evicts the lexicographically smallest key(s) first. A `HashMap` records no insertion
    /// order, so true oldest-first eviction is impossible here. The previous `keys().next()`
    /// choice depended on per-process hash seeds; this choice is at least reproducible.
    pub fn set_compressed(&mut self, mode_key: &str, output: String) {
        const MAX_COMPRESSED_VARIANTS: usize = 3;
        if !self.compressed_outputs.contains_key(mode_key) {
            // Loop rather than a single removal: the field is public and may already hold
            // more than the cap if callers filled it directly.
            while self.compressed_outputs.len() >= MAX_COMPRESSED_VARIANTS {
                match self.compressed_outputs.keys().min().cloned() {
                    Some(victim) => {
                        self.compressed_outputs.remove(&victim);
                    }
                    None => break,
                }
            }
        }
        self.compressed_outputs.insert(mode_key.to_string(), output);
    }

    /// Records that the full content of this entry has been delivered.
    pub fn mark_full_delivered(&mut self) {
        self.full_content_delivered = true;
    }
}
80
81const RRF_K: f64 = 60.0;
82
83pub fn eviction_scores_rrf(entries: &[(&String, &CacheEntry)], now: Instant) -> Vec<(String, f64)> {
88 if entries.is_empty() {
89 return Vec::new();
90 }
91
92 let n = entries.len();
93
94 let mut recency_order: Vec<usize> = (0..n).collect();
95 recency_order.sort_by(|&a, &b| {
96 let elapsed_a = now
97 .checked_duration_since(entries[a].1.last_access)
98 .unwrap_or_default()
99 .as_secs_f64();
100 let elapsed_b = now
101 .checked_duration_since(entries[b].1.last_access)
102 .unwrap_or_default()
103 .as_secs_f64();
104 elapsed_a
105 .partial_cmp(&elapsed_b)
106 .unwrap_or(std::cmp::Ordering::Equal)
107 });
108
109 let mut frequency_order: Vec<usize> = (0..n).collect();
110 frequency_order.sort_by(|&a, &b| entries[b].1.read_count.cmp(&entries[a].1.read_count));
111
112 let mut size_order: Vec<usize> = (0..n).collect();
113 size_order.sort_by(|&a, &b| {
114 entries[b]
115 .1
116 .original_tokens
117 .cmp(&entries[a].1.original_tokens)
118 });
119
120 let mut recency_ranks = vec![0usize; n];
121 let mut frequency_ranks = vec![0usize; n];
122 let mut size_ranks = vec![0usize; n];
123
124 for (rank, &idx) in recency_order.iter().enumerate() {
125 recency_ranks[idx] = rank;
126 }
127 for (rank, &idx) in frequency_order.iter().enumerate() {
128 frequency_ranks[idx] = rank;
129 }
130 for (rank, &idx) in size_order.iter().enumerate() {
131 size_ranks[idx] = rank;
132 }
133
134 entries
135 .iter()
136 .enumerate()
137 .map(|(i, (path, _))| {
138 let score = 1.0 / (RRF_K + recency_ranks[i] as f64)
139 + 1.0 / (RRF_K + frequency_ranks[i] as f64)
140 + 1.0 / (RRF_K + size_ranks[i] as f64);
141 ((*path).clone(), score)
142 })
143 .collect()
144}
145
/// Running counters for a `SessionCache`; reset to zero by `SessionCache::clear`.
///
/// Derives `Default` (all counters zero) and `Clone` so callers can build or snapshot the
/// counters without spelling out every field.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Number of `store` calls plus recorded cache hits.
    pub total_reads: u64,
    /// Reads answered from cache (matching content hash or an explicit recorded hit).
    pub cache_hits: u64,
    /// Sum of the full token count of the file on every read.
    pub total_original_tokens: u64,
    /// Tokens actually sent: full content on misses, a short "cached" stub on hits.
    pub total_sent_tokens: u64,
    /// Number of new entries inserted; not decremented by eviction or invalidation.
    pub files_tracked: usize,
}

impl CacheStats {
    /// Percentage of reads served from cache; `0.0` before any read.
    pub fn hit_rate(&self) -> f64 {
        if self.total_reads == 0 {
            return 0.0;
        }
        (self.cache_hits as f64 / self.total_reads as f64) * 100.0
    }

    /// Tokens not sent thanks to caching; saturates at zero rather than underflowing.
    pub fn tokens_saved(&self) -> u64 {
        self.total_original_tokens
            .saturating_sub(self.total_sent_tokens)
    }

    /// `tokens_saved` as a percentage of `total_original_tokens`; `0.0` when nothing was read.
    pub fn savings_percent(&self) -> f64 {
        if self.total_original_tokens == 0 {
            return 0.0;
        }
        (self.tokens_saved() as f64 / self.total_original_tokens as f64) * 100.0
    }
}
179
/// A block of text shared between files, used by `SessionCache::apply_dedup` to replace
/// repeated text with a short `[= ref:start-end]` marker.
#[derive(Clone, Debug)]
pub struct SharedBlock {
    /// Path of the file holding the canonical copy; `apply_dedup` skips this block when
    /// asked to rewrite this same path.
    pub canonical_path: String,
    /// Label rendered in the marker `[= {canonical_ref}:{start_line}-{end_line}]`
    /// (presumably a file reference such as `F3` — TODO confirm with the producer).
    pub canonical_ref: String,
    /// Start line rendered in the marker (presumably within the canonical file — confirm).
    pub start_line: usize,
    /// End line rendered in the marker (presumably within the canonical file — confirm).
    pub end_line: usize,
    /// Exact text searched for in other files' content and replaced by the marker.
    pub content: String,
}
189
/// Cache of file contents, short file reference labels, dedup blocks and usage statistics.
pub struct SessionCache {
    // Cached file state keyed by the output of `normalize_key`.
    entries: HashMap<String, CacheEntry>,
    // Normalized path -> short label such as `F1`, assigned on first use.
    file_refs: HashMap<String, String>,
    // Number used for the next label handed out by `get_file_ref` (starts at 1).
    next_ref: usize,
    // Read/hit/token counters exposed through `get_stats`.
    stats: CacheStats,
    // Cross-file duplicate blocks consulted by `apply_dedup`.
    shared_blocks: Vec<SharedBlock>,
}
199
200impl Default for SessionCache {
201 fn default() -> Self {
202 Self::new()
203 }
204}
205
206impl SessionCache {
207 pub fn new() -> Self {
209 Self {
210 entries: HashMap::new(),
211 file_refs: HashMap::new(),
212 next_ref: 1,
213 shared_blocks: Vec::new(),
214 stats: CacheStats {
215 total_reads: 0,
216 cache_hits: 0,
217 total_original_tokens: 0,
218 total_sent_tokens: 0,
219 files_tracked: 0,
220 },
221 }
222 }
223
224 pub fn get_file_ref(&mut self, path: &str) -> String {
226 let key = normalize_key(path);
227 if let Some(r) = self.file_refs.get(&key) {
228 return r.clone();
229 }
230 let r = format!("F{}", self.next_ref);
231 self.next_ref += 1;
232 self.file_refs.insert(key, r.clone());
233 r
234 }
235
236 pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
238 self.file_refs.get(&normalize_key(path)).cloned()
239 }
240
241 pub fn get(&self, path: &str) -> Option<&CacheEntry> {
243 self.entries.get(&normalize_key(path))
244 }
245
246 pub fn get_full_content(&self, path: &str) -> Option<String> {
249 self.entries
250 .get(&normalize_key(path))
251 .map(|e| e.content.clone())
252 }
253
254 pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
256 let key = normalize_key(path);
257 let ref_label = self
258 .file_refs
259 .get(&key)
260 .cloned()
261 .unwrap_or_else(|| "F?".to_string());
262 if let Some(entry) = self.entries.get_mut(&key) {
263 entry.read_count += 1;
264 entry.last_access = Instant::now();
265 self.stats.total_reads += 1;
266 self.stats.cache_hits += 1;
267 self.stats.total_original_tokens += entry.original_tokens as u64;
268 let hit_msg = format!(
269 "{ref_label} cached {}t {}L",
270 entry.read_count, entry.line_count
271 );
272 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
273 crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
274 Some(entry)
275 } else {
276 None
277 }
278 }
279
280 pub fn store(&mut self, path: &str, content: String) -> StoreResult {
282 let key = normalize_key(path);
283 let hash = compute_md5(&content);
284 let line_count = content.lines().count();
285 let original_tokens = count_tokens(&content);
286 let stored_mtime = std::fs::metadata(path).and_then(|m| m.modified()).ok();
287 let now = Instant::now();
288
289 self.stats.total_reads += 1;
290 self.stats.total_original_tokens += original_tokens as u64;
291
292 if let Some(existing) = self.entries.get_mut(&key) {
293 existing.last_access = now;
294 if stored_mtime.is_some() {
295 existing.stored_mtime = stored_mtime;
296 }
297 if existing.hash == hash {
298 existing.read_count += 1;
299 self.stats.cache_hits += 1;
300 let hit_msg = format!(
301 "{} cached {}t {}L",
302 self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
303 existing.read_count,
304 existing.line_count,
305 );
306 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
307 return StoreResult {
308 line_count: existing.line_count,
309 original_tokens: existing.original_tokens,
310 read_count: existing.read_count,
311 was_hit: true,
312 full_content_delivered: existing.full_content_delivered,
313 };
314 }
315 existing.compressed_outputs.clear();
316 existing.content = content;
317 existing.hash = hash;
318 existing.line_count = line_count;
319 existing.original_tokens = original_tokens;
320 existing.read_count += 1;
321 existing.full_content_delivered = false;
322 if stored_mtime.is_some() {
323 existing.stored_mtime = stored_mtime;
324 }
325 self.stats.total_sent_tokens += original_tokens as u64;
326 return StoreResult {
327 line_count,
328 original_tokens,
329 read_count: existing.read_count,
330 was_hit: false,
331 full_content_delivered: false,
332 };
333 }
334
335 self.evict_if_needed(original_tokens);
336 self.get_file_ref(&key);
337
338 let entry = CacheEntry {
339 content,
340 hash,
341 line_count,
342 original_tokens,
343 read_count: 1,
344 path: key.clone(),
345 last_access: now,
346 stored_mtime,
347 compressed_outputs: HashMap::new(),
348 full_content_delivered: false,
349 };
350
351 self.entries.insert(key, entry);
352 self.stats.files_tracked += 1;
353 self.stats.total_sent_tokens += original_tokens as u64;
354 StoreResult {
355 line_count,
356 original_tokens,
357 read_count: 1,
358 was_hit: false,
359 full_content_delivered: false,
360 }
361 }
362
363 pub fn total_cached_tokens(&self) -> usize {
365 self.entries.values().map(|e| e.original_tokens).sum()
366 }
367
368 pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
372 let max_tokens = max_cache_tokens();
373 let current = self.total_cached_tokens();
374 if current + incoming_tokens <= max_tokens {
375 return;
376 }
377
378 let mut freed = 0usize;
379 let target = (current + incoming_tokens).saturating_sub(max_tokens);
380
381 let mut probationary: Vec<(String, Instant)> = self
382 .entries
383 .iter()
384 .filter(|(_, e)| e.read_count <= 1)
385 .map(|(p, e)| (p.clone(), e.last_access))
386 .collect();
387 probationary.sort_by_key(|(_, t)| *t);
388
389 let mut protected: Vec<(String, Instant)> = self
390 .entries
391 .iter()
392 .filter(|(_, e)| e.read_count > 1)
393 .map(|(p, e)| (p.clone(), e.last_access))
394 .collect();
395 protected.sort_by_key(|(_, t)| *t);
396
397 for (path, _) in probationary.into_iter().chain(protected) {
398 if freed >= target {
399 break;
400 }
401 if let Some(entry) = self.entries.remove(&path) {
402 freed += entry.original_tokens;
403 self.file_refs.remove(&path);
404 }
405 }
406 }
407
408 pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
410 self.entries.iter().collect()
411 }
412
413 pub fn get_stats(&self) -> &CacheStats {
415 &self.stats
416 }
417
418 pub fn file_ref_map(&self) -> &HashMap<String, String> {
420 &self.file_refs
421 }
422
423 pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
425 self.shared_blocks = blocks;
426 }
427
428 pub fn get_shared_blocks(&self) -> &[SharedBlock] {
430 &self.shared_blocks
431 }
432
433 pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
435 if self.shared_blocks.is_empty() {
436 return None;
437 }
438 let refs: Vec<&SharedBlock> = self
439 .shared_blocks
440 .iter()
441 .filter(|b| b.canonical_path != path && content.contains(&b.content))
442 .collect();
443 if refs.is_empty() {
444 return None;
445 }
446 let mut result = content.to_string();
447 for block in refs {
448 result = result.replacen(
449 &block.content,
450 &format!(
451 "[= {}:{}-{}]",
452 block.canonical_ref, block.start_line, block.end_line
453 ),
454 1,
455 );
456 }
457 Some(result)
458 }
459
460 pub fn invalidate(&mut self, path: &str) -> bool {
462 self.entries.remove(&normalize_key(path)).is_some()
463 }
464
465 pub fn get_compressed(&self, path: &str, mode_key: &str) -> Option<&String> {
467 self.entries
468 .get(&normalize_key(path))?
469 .get_compressed(mode_key)
470 }
471
472 pub fn mark_full_delivered(&mut self, path: &str) {
474 if let Some(entry) = self.entries.get_mut(&normalize_key(path)) {
475 entry.mark_full_delivered();
476 }
477 }
478
479 pub fn set_compressed(&mut self, path: &str, mode_key: &str, output: String) {
481 if let Some(entry) = self.entries.get_mut(&normalize_key(path)) {
482 entry.set_compressed(mode_key, output);
483 }
484 }
485
486 pub fn clear(&mut self) -> usize {
488 let count = self.entries.len();
489 self.entries.clear();
490 self.file_refs.clear();
491 self.shared_blocks.clear();
492 self.next_ref = 1;
493 self.stats = CacheStats {
494 total_reads: 0,
495 cache_hits: 0,
496 total_original_tokens: 0,
497 total_sent_tokens: 0,
498 files_tracked: 0,
499 };
500 count
501 }
502}
503
/// Last-modification time of the file at `path`.
///
/// Returns `None` when the file cannot be stat'ed or the platform reports no mtime.
pub fn file_mtime(path: &str) -> Option<SystemTime> {
    std::fs::metadata(path).ok()?.modified().ok()
}

/// Reports whether the file at `path` changed since `cached_mtime` was recorded.
///
/// - File cannot be stat'ed now → `false`.
/// - No mtime was recorded but the file has one → `true`.
/// - Otherwise → `true` whenever the current mtime differs from the recorded one.
///
/// Any difference counts, not just a newer time. Replacing a file with a copy that carries
/// an older mtime (`cp -p`, `rsync -t`, extracting an archive) moves the mtime backwards
/// while still changing the content; a `>` comparison would miss that.
pub fn is_cache_entry_stale(path: &str, cached_mtime: Option<SystemTime>) -> bool {
    match (cached_mtime, file_mtime(path)) {
        (_, None) => false,
        (None, Some(_)) => true,
        (Some(cached), Some(current)) => current != cached,
    }
}
516
517fn compute_md5(content: &str) -> String {
518 let mut hasher = Md5::new();
519 hasher.update(content.as_bytes());
520 format!("{:x}", hasher.finalize())
521}
522
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds an entry for the RRF scoring tests. Only `read_count` and `last_access` vary
    /// between fixtures; every other field is fixed so it cannot influence the ranking.
    fn scoring_entry(path: &str, read_count: u32, last_access: Instant) -> CacheEntry {
        CacheEntry {
            content: String::new(),
            hash: String::new(),
            line_count: 1,
            original_tokens: 10,
            read_count,
            path: path.to_string(),
            last_access,
            stored_mtime: None,
            compressed_outputs: HashMap::new(),
            full_content_delivered: false,
        }
    }

    /// Looks up the score computed for `key`.
    fn score_for(scores: &[(String, f64)], key: &str) -> f64 {
        scores
            .iter()
            .find(|(p, _)| p == key)
            .map(|(_, s)| *s)
            .expect("every input key gets a score")
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let result = cache.store("/test/file.rs", "fn main() {}".to_string());
        assert!(!result.was_hit);
        assert_eq!(result.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content".to_string());
        let result = cache.store("/test/file.rs", "content".to_string());
        assert!(result.was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content".to_string());
        let result = cache.store("/test/file.rs", "new content".to_string());
        assert!(!result.was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        // A known path reuses its label instead of allocating a new one.
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.store("/b.rs", "b".to_string());
        let count = cache.clear();
        assert_eq!(count, 2);
        assert!(cache.get("/a.rs").is_none());
        // Label numbering restarts after a clear.
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test".to_string());
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello".to_string());
        // The second identical store is a hit.
        cache.store("/a.rs", "hello".to_string());
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let h1 = compute_md5("test content");
        let h2 = compute_md5("test content");
        assert_eq!(h1, h2);
        assert_ne!(h1, compute_md5("different"));
    }

    #[test]
    fn rrf_eviction_prefers_recent() {
        // Derive both timestamps from one Instant instead of sleeping: "old" is 5ms behind.
        let base = Instant::now();
        let now = base + Duration::from_millis(5);
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let recent = scoring_entry("/a.rs", 1, now);
        let old = scoring_entry("/b.rs", 1, base);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &recent), (&key_b, &old)];
        let scores = eviction_scores_rrf(&entries, now);
        assert!(
            score_for(&scores, "a.rs") > score_for(&scores, "b.rs"),
            "recently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn rrf_eviction_prefers_frequent() {
        let now = Instant::now();
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let frequent = scoring_entry("/a.rs", 20, now);
        let rare = scoring_entry("/b.rs", 1, now);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &frequent), (&key_b, &rare)];
        let scores = eviction_scores_rrf(&entries, now);
        assert!(
            score_for(&scores, "a.rs") > score_for(&scores, "b.rs"),
            "frequently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        std::env::set_var("LEAN_CTX_CACHE_MAX_TOKENS", "50");
        let mut cache = SessionCache::new();
        let big_content = "a]".repeat(30);
        cache.store("/old.rs", big_content);
        let new_content = "b ".repeat(30);
        cache.store("/new.rs", new_content);
        let total = cache.total_cached_tokens();
        // Restore the environment before asserting: a failing assert must not leave the
        // tiny budget set for other tests running in the same process.
        std::env::remove_var("LEAN_CTX_CACHE_MAX_TOKENS");
        assert!(total <= 60, "eviction should have kicked in");
    }

    #[test]
    fn stale_detection_flags_newer_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("stale.txt");
        let p = path.to_string_lossy().to_string();

        std::fs::write(&path, "one").unwrap();
        let mut cache = SessionCache::new();
        cache.store(&p, "one".to_string());

        let entry = cache.get(&p).unwrap();
        assert!(!is_cache_entry_stale(&p, entry.stored_mtime));

        // Wait so the rewrite gets a different mtime even on filesystems with
        // one-second timestamp resolution.
        std::thread::sleep(Duration::from_secs(1));
        std::fs::write(&path, "two").unwrap();

        let entry = cache.get(&p).unwrap();
        assert!(is_cache_entry_stale(&p, entry.stored_mtime));
    }

    #[test]
    fn compressed_outputs_cached_and_retrieved() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "fn main() {}".to_string());
        cache.set_compressed("/test.rs", "map", "compressed map output".to_string());
        assert_eq!(
            cache.get_compressed("/test.rs", "map"),
            Some(&"compressed map output".to_string())
        );
        assert_eq!(cache.get_compressed("/test.rs", "signatures"), None);
    }

    #[test]
    fn compressed_outputs_cleared_on_content_change() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "old content".to_string());
        cache.set_compressed("/test.rs", "map", "old map".to_string());
        assert!(cache.get_compressed("/test.rs", "map").is_some());

        cache.store("/test.rs", "new content".to_string());
        assert_eq!(cache.get_compressed("/test.rs", "map"), None);
    }

    #[test]
    fn compressed_outputs_survive_same_content_store() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "content".to_string());
        cache.set_compressed("/test.rs", "map", "cached map".to_string());

        let result = cache.store("/test.rs", "content".to_string());
        assert!(result.was_hit);
        assert_eq!(
            cache.get_compressed("/test.rs", "map"),
            Some(&"cached map".to_string())
        );
    }

    #[test]
    fn compressed_outputs_cleared_on_invalidate() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "content".to_string());
        cache.set_compressed("/test.rs", "signatures", "cached sigs".to_string());
        cache.invalidate("/test.rs");
        assert_eq!(cache.get_compressed("/test.rs", "signatures"), None);
    }

    #[test]
    fn compressed_outputs_cleared_on_clear() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.set_compressed("/a.rs", "map", "map_a".to_string());
        cache.clear();
        assert_eq!(cache.get_compressed("/a.rs", "map"), None);
    }
}
775}