1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::{Instant, SystemTime};
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8 crate::core::pathutil::normalize_tool_path(path)
9}
10
/// Returns the token budget for the session cache.
///
/// The budget is read from the `LEAN_CTX_CACHE_MAX_TOKENS` environment variable on every
/// call. Surrounding whitespace is ignored, so a value such as `" 1000\n"` is honoured
/// instead of silently falling back. When the variable is unset, not valid Unicode, or not
/// a non-negative integer, the default of 500 000 tokens is used.
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;

    std::env::var("LEAN_CTX_CACHE_MAX_TOKENS")
        .ok()
        .and_then(|raw| raw.trim().parse().ok())
        .unwrap_or(DEFAULT_MAX_TOKENS)
}
17
/// One cached file: its compressed content plus the bookkeeping needed for hit detection,
/// staleness checks and eviction.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    /// File content as produced by `zstd_compress`; read it back through `content()`.
    compressed_content: Vec<u8>,
    /// Content hash supplied by whoever created the entry (`SessionCache::store` passes an
    /// MD5 hex digest).
    pub hash: String,
    /// Number of lines in the cached content.
    pub line_count: usize,
    /// Token count of the uncompressed content.
    pub original_tokens: usize,
    /// How many times the file has been read; starts at 1 when the entry is created.
    pub read_count: u32,
    /// Path string the entry was created with (`SessionCache::store` passes the normalized key).
    pub path: String,
    /// When the entry was created or last touched by a store / recorded hit.
    pub last_access: Instant,
    /// File modification time observed when the content was stored, if it could be read.
    pub stored_mtime: Option<SystemTime>,
    /// Derived outputs keyed by mode; `SessionCache::store` clears them when content changes.
    pub compressed_outputs: HashMap<String, String>,
    /// Set through `mark_full_delivered`; reset to `false` when the content changes.
    pub full_content_delivered: bool,
}
35
36const ZSTD_LEVEL: i32 = 3;
37
38fn zstd_compress(data: &str) -> Vec<u8> {
39 zstd::encode_all(data.as_bytes(), ZSTD_LEVEL).unwrap_or_else(|_| data.as_bytes().to_vec())
40}
41
42fn zstd_decompress(data: &[u8]) -> String {
43 zstd::decode_all(data)
44 .ok()
45 .and_then(|v| String::from_utf8(v).ok())
46 .unwrap_or_default()
47}
48
49impl CacheEntry {
50 pub fn new(
52 content: &str,
53 hash: String,
54 line_count: usize,
55 original_tokens: usize,
56 path: String,
57 stored_mtime: Option<SystemTime>,
58 ) -> Self {
59 let compressed_content = zstd_compress(content);
60 Self {
61 compressed_content,
62 hash,
63 line_count,
64 original_tokens,
65 read_count: 1,
66 path,
67 last_access: Instant::now(),
68 stored_mtime,
69 compressed_outputs: HashMap::new(),
70 full_content_delivered: false,
71 }
72 }
73
74 pub fn content(&self) -> String {
76 zstd_decompress(&self.compressed_content)
77 }
78
79 pub fn set_content(&mut self, content: &str) {
81 self.compressed_content = zstd_compress(content);
82 }
83
84 pub fn compressed_size(&self) -> usize {
86 self.compressed_content.len()
87 }
88}
89
/// Outcome of `SessionCache::store` for a single file.
#[derive(Debug, Clone)]
pub struct StoreResult {
    /// Line count of the content now held by the cache entry.
    pub line_count: usize,
    /// Token count of the content now held by the cache entry.
    pub original_tokens: usize,
    /// Read count of the entry after this store.
    pub read_count: u32,
    /// `true` when an entry already existed and its hash matched the new content.
    pub was_hit: bool,
    /// Whether the entry's full content had been marked as delivered; always `false` on a miss.
    pub full_content_delivered: bool,
}
100
101impl CacheEntry {
102 pub fn eviction_score_legacy(&self, now: Instant) -> f64 {
104 let elapsed = now
105 .checked_duration_since(self.last_access)
106 .unwrap_or_default()
107 .as_secs_f64();
108 let recency = 1.0 / (1.0 + elapsed.sqrt());
109 let frequency = (self.read_count as f64 + 1.0).ln();
110 let size_value = (self.original_tokens as f64 + 1.0).ln();
111 recency * 0.4 + frequency * 0.3 + size_value * 0.3
112 }
113
114 pub fn get_compressed(&self, mode_key: &str) -> Option<&String> {
115 self.compressed_outputs.get(mode_key)
116 }
117
118 pub fn set_compressed(&mut self, mode_key: &str, output: String) {
119 const MAX_COMPRESSED_VARIANTS: usize = 3;
120 if self.compressed_outputs.len() >= MAX_COMPRESSED_VARIANTS
121 && !self.compressed_outputs.contains_key(mode_key)
122 {
123 if let Some(oldest_key) = self.compressed_outputs.keys().next().cloned() {
124 self.compressed_outputs.remove(&oldest_key);
125 }
126 }
127 self.compressed_outputs.insert(mode_key.to_string(), output);
128 }
129
130 pub fn mark_full_delivered(&mut self) {
131 self.full_content_delivered = true;
132 }
133}
134
135const RRF_K: f64 = 60.0;
136
137pub fn eviction_scores_rrf(entries: &[(&String, &CacheEntry)], now: Instant) -> Vec<(String, f64)> {
142 if entries.is_empty() {
143 return Vec::new();
144 }
145
146 let n = entries.len();
147
148 let mut recency_order: Vec<usize> = (0..n).collect();
149 recency_order.sort_by(|&a, &b| {
150 let elapsed_a = now
151 .checked_duration_since(entries[a].1.last_access)
152 .unwrap_or_default()
153 .as_secs_f64();
154 let elapsed_b = now
155 .checked_duration_since(entries[b].1.last_access)
156 .unwrap_or_default()
157 .as_secs_f64();
158 elapsed_a
159 .partial_cmp(&elapsed_b)
160 .unwrap_or(std::cmp::Ordering::Equal)
161 });
162
163 let mut frequency_order: Vec<usize> = (0..n).collect();
164 frequency_order.sort_by(|&a, &b| entries[b].1.read_count.cmp(&entries[a].1.read_count));
165
166 let mut size_order: Vec<usize> = (0..n).collect();
167 size_order.sort_by(|&a, &b| {
168 entries[b]
169 .1
170 .original_tokens
171 .cmp(&entries[a].1.original_tokens)
172 });
173
174 let mut recency_ranks = vec![0usize; n];
175 let mut frequency_ranks = vec![0usize; n];
176 let mut size_ranks = vec![0usize; n];
177
178 for (rank, &idx) in recency_order.iter().enumerate() {
179 recency_ranks[idx] = rank;
180 }
181 for (rank, &idx) in frequency_order.iter().enumerate() {
182 frequency_ranks[idx] = rank;
183 }
184 for (rank, &idx) in size_order.iter().enumerate() {
185 size_ranks[idx] = rank;
186 }
187
188 entries
189 .iter()
190 .enumerate()
191 .map(|(i, (path, _))| {
192 let score = 1.0 / (RRF_K + recency_ranks[i] as f64)
193 + 1.0 / (RRF_K + frequency_ranks[i] as f64)
194 + 1.0 / (RRF_K + size_ranks[i] as f64);
195 ((*path).clone(), score)
196 })
197 .collect()
198}
199
/// Running counters describing how the session cache has been used.
#[derive(Debug)]
pub struct CacheStats {
    /// Number of reads seen: every `store` call and every recorded cache hit.
    pub total_reads: u64,
    /// Number of those reads that were served as cache hits.
    pub cache_hits: u64,
    /// Sum of the original token counts of all reads.
    pub total_original_tokens: u64,
    /// Tokens accounted as sent: the full content on a miss, a short notice on a hit.
    pub total_sent_tokens: u64,
    /// Incremented each time a brand-new entry is inserted into the cache.
    pub files_tracked: usize,
}
209
210impl CacheStats {
211 pub fn hit_rate(&self) -> f64 {
213 if self.total_reads == 0 {
214 return 0.0;
215 }
216 (self.cache_hits as f64 / self.total_reads as f64) * 100.0
217 }
218
219 pub fn tokens_saved(&self) -> u64 {
221 self.total_original_tokens
222 .saturating_sub(self.total_sent_tokens)
223 }
224
225 pub fn savings_percent(&self) -> f64 {
227 if self.total_original_tokens == 0 {
228 return 0.0;
229 }
230 (self.tokens_saved() as f64 / self.total_original_tokens as f64) * 100.0
231 }
232}
233
/// A piece of text that `SessionCache::apply_dedup` replaces with a short reference when
/// it appears in a file other than its canonical one.
#[derive(Clone, Debug)]
pub struct SharedBlock {
    /// Path of the file regarded as the home of this block; that file is never rewritten.
    pub canonical_path: String,
    /// Label printed in the replacement marker `[= <ref>:<start>-<end>]`.
    pub canonical_ref: String,
    /// First line number printed in the replacement marker.
    pub start_line: usize,
    /// Last line number printed in the replacement marker.
    pub end_line: usize,
    /// Exact text that is searched for and replaced.
    pub content: String,
}
243
244pub struct SessionCache {
247 entries: HashMap<String, CacheEntry>,
248 file_refs: HashMap<String, String>,
249 next_ref: usize,
250 stats: CacheStats,
251 shared_blocks: Vec<SharedBlock>,
252}
253
254impl Default for SessionCache {
255 fn default() -> Self {
256 Self::new()
257 }
258}
259
260impl SessionCache {
261 pub fn new() -> Self {
263 Self {
264 entries: HashMap::new(),
265 file_refs: HashMap::new(),
266 next_ref: 1,
267 shared_blocks: Vec::new(),
268 stats: CacheStats {
269 total_reads: 0,
270 cache_hits: 0,
271 total_original_tokens: 0,
272 total_sent_tokens: 0,
273 files_tracked: 0,
274 },
275 }
276 }
277
278 pub fn get_file_ref(&mut self, path: &str) -> String {
280 let key = normalize_key(path);
281 if let Some(r) = self.file_refs.get(&key) {
282 return r.clone();
283 }
284 let r = format!("F{}", self.next_ref);
285 self.next_ref += 1;
286 self.file_refs.insert(key, r.clone());
287 r
288 }
289
290 pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
292 self.file_refs.get(&normalize_key(path)).cloned()
293 }
294
295 pub fn get(&self, path: &str) -> Option<&CacheEntry> {
297 self.entries.get(&normalize_key(path))
298 }
299
300 pub fn get_full_content(&self, path: &str) -> Option<String> {
303 self.entries
304 .get(&normalize_key(path))
305 .map(CacheEntry::content)
306 }
307
308 pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
310 let key = normalize_key(path);
311 let ref_label = self
312 .file_refs
313 .get(&key)
314 .cloned()
315 .unwrap_or_else(|| "F?".to_string());
316 if let Some(entry) = self.entries.get_mut(&key) {
317 entry.read_count += 1;
318 entry.last_access = Instant::now();
319 self.stats.total_reads += 1;
320 self.stats.cache_hits += 1;
321 self.stats.total_original_tokens += entry.original_tokens as u64;
322 let hit_msg = format!(
323 "{ref_label} cached {}t {}L",
324 entry.read_count, entry.line_count
325 );
326 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
327 crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
328 Some(entry)
329 } else {
330 None
331 }
332 }
333
334 pub fn store(&mut self, path: &str, content: &str) -> StoreResult {
336 let key = normalize_key(path);
337 let hash = compute_md5(content);
338 let line_count = content.lines().count();
339 let original_tokens = count_tokens(content);
340 let stored_mtime = std::fs::metadata(path).and_then(|m| m.modified()).ok();
341 let now = Instant::now();
342
343 self.stats.total_reads += 1;
344 self.stats.total_original_tokens += original_tokens as u64;
345
346 if let Some(existing) = self.entries.get_mut(&key) {
347 existing.last_access = now;
348 if stored_mtime.is_some() {
349 existing.stored_mtime = stored_mtime;
350 }
351 if existing.hash == hash {
352 existing.read_count += 1;
353 self.stats.cache_hits += 1;
354 let hit_msg = format!(
355 "{} cached {}t {}L",
356 self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
357 existing.read_count,
358 existing.line_count,
359 );
360 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
361 return StoreResult {
362 line_count: existing.line_count,
363 original_tokens: existing.original_tokens,
364 read_count: existing.read_count,
365 was_hit: true,
366 full_content_delivered: existing.full_content_delivered,
367 };
368 }
369 existing.compressed_outputs.clear();
370 existing.set_content(content);
371 existing.hash = hash;
372 existing.line_count = line_count;
373 existing.original_tokens = original_tokens;
374 existing.read_count += 1;
375 existing.full_content_delivered = false;
376 if stored_mtime.is_some() {
377 existing.stored_mtime = stored_mtime;
378 }
379 self.stats.total_sent_tokens += original_tokens as u64;
380 return StoreResult {
381 line_count,
382 original_tokens,
383 read_count: existing.read_count,
384 was_hit: false,
385 full_content_delivered: false,
386 };
387 }
388
389 self.evict_if_needed(original_tokens);
390 self.get_file_ref(&key);
391
392 let entry = CacheEntry::new(
393 content,
394 hash,
395 line_count,
396 original_tokens,
397 key.clone(),
398 stored_mtime,
399 );
400
401 self.entries.insert(key, entry);
402 self.stats.files_tracked += 1;
403 self.stats.total_sent_tokens += original_tokens as u64;
404 StoreResult {
405 line_count,
406 original_tokens,
407 read_count: 1,
408 was_hit: false,
409 full_content_delivered: false,
410 }
411 }
412
413 pub fn total_cached_tokens(&self) -> usize {
415 self.entries.values().map(|e| e.original_tokens).sum()
416 }
417
418 pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
422 let max_tokens = max_cache_tokens();
423 let current = self.total_cached_tokens();
424 if current + incoming_tokens <= max_tokens {
425 return;
426 }
427
428 let mut freed = 0usize;
429 let target = (current + incoming_tokens).saturating_sub(max_tokens);
430
431 let mut probationary: Vec<(String, Instant)> = self
432 .entries
433 .iter()
434 .filter(|(_, e)| e.read_count <= 1)
435 .map(|(p, e)| (p.clone(), e.last_access))
436 .collect();
437 probationary.sort_by_key(|(_, t)| *t);
438
439 let mut protected: Vec<(String, Instant)> = self
440 .entries
441 .iter()
442 .filter(|(_, e)| e.read_count > 1)
443 .map(|(p, e)| (p.clone(), e.last_access))
444 .collect();
445 protected.sort_by_key(|(_, t)| *t);
446
447 for (path, _) in probationary.into_iter().chain(protected) {
448 if freed >= target {
449 break;
450 }
451 if let Some(entry) = self.entries.remove(&path) {
452 freed += entry.original_tokens;
453 self.file_refs.remove(&path);
454 }
455 }
456 }
457
458 pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
460 self.entries.iter().collect()
461 }
462
463 pub fn get_stats(&self) -> &CacheStats {
465 &self.stats
466 }
467
468 pub fn file_ref_map(&self) -> &HashMap<String, String> {
470 &self.file_refs
471 }
472
473 pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
475 self.shared_blocks = blocks;
476 }
477
478 pub fn get_shared_blocks(&self) -> &[SharedBlock] {
480 &self.shared_blocks
481 }
482
483 pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
485 if self.shared_blocks.is_empty() {
486 return None;
487 }
488 let refs: Vec<&SharedBlock> = self
489 .shared_blocks
490 .iter()
491 .filter(|b| b.canonical_path != path && content.contains(&b.content))
492 .collect();
493 if refs.is_empty() {
494 return None;
495 }
496 let mut result = content.to_string();
497 for block in refs {
498 result = result.replacen(
499 &block.content,
500 &format!(
501 "[= {}:{}-{}]",
502 block.canonical_ref, block.start_line, block.end_line
503 ),
504 1,
505 );
506 }
507 Some(result)
508 }
509
510 pub fn invalidate(&mut self, path: &str) -> bool {
512 self.entries.remove(&normalize_key(path)).is_some()
513 }
514
515 pub fn get_compressed(&self, path: &str, mode_key: &str) -> Option<&String> {
517 self.entries
518 .get(&normalize_key(path))?
519 .get_compressed(mode_key)
520 }
521
522 pub fn mark_full_delivered(&mut self, path: &str) {
524 if let Some(entry) = self.entries.get_mut(&normalize_key(path)) {
525 entry.mark_full_delivered();
526 }
527 }
528
529 pub fn set_compressed(&mut self, path: &str, mode_key: &str, output: String) {
531 if let Some(entry) = self.entries.get_mut(&normalize_key(path)) {
532 entry.set_compressed(mode_key, output);
533 }
534 }
535
536 pub fn clear(&mut self) -> usize {
538 let count = self.entries.len();
539 self.entries.clear();
540 self.file_refs.clear();
541 self.shared_blocks.clear();
542 self.next_ref = 1;
543 self.stats = CacheStats {
544 total_reads: 0,
545 cache_hits: 0,
546 total_original_tokens: 0,
547 total_sent_tokens: 0,
548 files_tracked: 0,
549 };
550 count
551 }
552}
553
/// Returns the last-modified time of the file at `path`.
///
/// `None` is returned when the metadata cannot be read (for example because the file
/// does not exist) or when the modification time is unavailable.
pub fn file_mtime(path: &str) -> Option<SystemTime> {
    let metadata = std::fs::metadata(path).ok()?;
    metadata.modified().ok()
}
557
558pub fn is_cache_entry_stale(path: &str, cached_mtime: Option<SystemTime>) -> bool {
559 let current = file_mtime(path);
560 match (cached_mtime, current) {
561 (_, None) => false,
562 (None, Some(_)) => true,
563 (Some(cached), Some(current)) => current > cached,
564 }
565}
566
567fn compute_md5(content: &str) -> String {
568 let mut hasher = Md5::new();
569 hasher.update(content.as_bytes());
570 format!("{:x}", hasher.finalize())
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a one-line, ten-token entry with no recorded mtime.
    fn sample_entry(content: &str, hash: &str, path: &str) -> CacheEntry {
        CacheEntry::new(content, hash.to_string(), 1, 10, path.to_string(), None)
    }

    /// Returns the RRF score reported for `key`, panicking when it is missing.
    fn score_for(scores: &[(String, f64)], key: &str) -> f64 {
        scores.iter().find(|(p, _)| p == key).unwrap().1
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let stored = cache.store("/test/file.rs", "fn main() {}");
        assert!(!stored.was_hit);
        assert_eq!(stored.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content");
        let second = cache.store("/test/file.rs", "content");
        assert!(second.was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content");
        let second = cache.store("/test/file.rs", "new content");
        assert!(!second.was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        // Asking again for a known path returns the label it already has.
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a");
        cache.store("/b.rs", "b");
        let dropped = cache.clear();
        assert_eq!(dropped, 2);
        assert!(cache.get("/a.rs").is_none());
        // Label numbering starts over after a clear.
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test");
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello");
        // The second store of identical content is the hit.
        cache.store("/a.rs", "hello");
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let first = compute_md5("test content");
        let second = compute_md5("test content");
        assert_eq!(first, second);
        assert_ne!(first, compute_md5("different"));
    }

    #[test]
    fn rrf_eviction_prefers_recent() {
        let base = Instant::now();
        std::thread::sleep(Duration::from_millis(5));
        let now = Instant::now();
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let recent = sample_entry("a", "h1", "/a.rs");
        let mut old = sample_entry("b", "h2", "/b.rs");
        old.last_access = base;

        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &recent), (&key_b, &old)];
        let scores = eviction_scores_rrf(&entries, now);
        let score_a = score_for(&scores, "a.rs");
        let score_b = score_for(&scores, "b.rs");
        assert!(
            score_a > score_b,
            "recently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn rrf_eviction_prefers_frequent() {
        let now = Instant::now();
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let mut frequent = sample_entry("a", "h1", "/a.rs");
        frequent.read_count = 20;
        let rare = sample_entry("b", "h2", "/b.rs");

        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &frequent), (&key_b, &rare)];
        let scores = eviction_scores_rrf(&entries, now);
        let score_a = score_for(&scores, "a.rs");
        let score_b = score_for(&scores, "b.rs");
        assert!(
            score_a > score_b,
            "frequently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        std::env::set_var("LEAN_CTX_CACHE_MAX_TOKENS", "50");
        let mut cache = SessionCache::new();
        let big_content = "a]".repeat(30);
        cache.store("/old.rs", &big_content);
        let new_content = "b ".repeat(30);
        cache.store("/new.rs", &new_content);
        assert!(
            cache.total_cached_tokens() <= 60,
            "eviction should have kicked in"
        );
        std::env::remove_var("LEAN_CTX_CACHE_MAX_TOKENS");
    }

    #[test]
    fn stale_detection_flags_newer_file() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("stale.txt");
        let file_str = file.to_string_lossy().to_string();

        std::fs::write(&file, "one").unwrap();
        let mut cache = SessionCache::new();
        cache.store(&file_str, "one");

        let fresh = cache.get(&file_str).unwrap();
        assert!(!is_cache_entry_stale(&file_str, fresh.stored_mtime));

        // Wait so the rewrite below gets a strictly later mtime.
        std::thread::sleep(Duration::from_secs(1));
        std::fs::write(&file, "two").unwrap();

        let outdated = cache.get(&file_str).unwrap();
        assert!(is_cache_entry_stale(&file_str, outdated.stored_mtime));
    }

    #[test]
    fn compressed_outputs_cached_and_retrieved() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "fn main() {}");
        cache.set_compressed("/test.rs", "map", "compressed map output".to_string());
        let expected = "compressed map output".to_string();
        assert_eq!(cache.get_compressed("/test.rs", "map"), Some(&expected));
        assert_eq!(cache.get_compressed("/test.rs", "signatures"), None);
    }

    #[test]
    fn compressed_outputs_cleared_on_content_change() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "old content");
        cache.set_compressed("/test.rs", "map", "old map".to_string());
        assert!(cache.get_compressed("/test.rs", "map").is_some());

        cache.store("/test.rs", "new content");
        assert_eq!(cache.get_compressed("/test.rs", "map"), None);
    }

    #[test]
    fn compressed_outputs_survive_same_content_store() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "content");
        cache.set_compressed("/test.rs", "map", "cached map".to_string());

        let again = cache.store("/test.rs", "content");
        assert!(again.was_hit);
        let expected = "cached map".to_string();
        assert_eq!(cache.get_compressed("/test.rs", "map"), Some(&expected));
    }

    #[test]
    fn compressed_outputs_cleared_on_invalidate() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "content");
        cache.set_compressed("/test.rs", "signatures", "cached sigs".to_string());
        cache.invalidate("/test.rs");
        assert_eq!(cache.get_compressed("/test.rs", "signatures"), None);
    }

    #[test]
    fn compressed_outputs_cleared_on_clear() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a");
        cache.set_compressed("/a.rs", "map", "map_a".to_string());
        cache.clear();
        assert_eq!(cache.get_compressed("/a.rs", "map"), None);
    }
}