1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::{Instant, SystemTime};
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8 crate::hooks::normalize_tool_path(path)
9}
10
/// Token budget for the whole cache.
///
/// Read from `LEAN_CTX_CACHE_MAX_TOKENS` on every call. Falls back to 500 000 when the
/// variable is unset, not valid Unicode, or not a plain unsigned integer.
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;
    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_MAX_TOKENS),
        Err(_) => DEFAULT_MAX_TOKENS,
    }
}
17
/// One cached file, stored in `SessionCache` under its normalized path.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    /// Full text passed to the most recent `SessionCache::store` for this key.
    pub content: String,
    /// Lowercase hex MD5 of `content`; `store` compares it to detect unchanged content.
    pub hash: String,
    /// Number of lines in `content`, counted with `str::lines`.
    pub line_count: usize,
    /// Token count of `content` from `count_tokens`; this is what the cache budget sums.
    pub original_tokens: usize,
    /// Read counter: 1 on insert, incremented by each later `store` or `record_cache_hit`.
    pub read_count: u32,
    /// Normalized key the entry was stored under (not necessarily the caller's raw path).
    pub path: String,
    /// Time of the last store or recorded hit; feeds the recency signal in eviction scoring.
    pub last_access: Instant,
    /// File modification time captured when stored; `None` if it could not be read.
    pub stored_mtime: Option<SystemTime>,
}
30
/// Outcome of `SessionCache::store`.
#[derive(Debug, Clone)]
pub struct StoreResult {
    /// Line count of the content now cached for the path.
    pub line_count: usize,
    /// Token count of the content now cached for the path.
    pub original_tokens: usize,
    /// The entry's read count after this store (1 for a newly inserted entry).
    pub read_count: u32,
    /// `true` when the content hash matched the existing entry, i.e. nothing changed.
    pub was_hit: bool,
}
39
40impl CacheEntry {
41 pub fn eviction_score_legacy(&self, now: Instant) -> f64 {
43 let elapsed = now
44 .checked_duration_since(self.last_access)
45 .unwrap_or_default()
46 .as_secs_f64();
47 let recency = 1.0 / (1.0 + elapsed.sqrt());
48 let frequency = (self.read_count as f64 + 1.0).ln();
49 let size_value = (self.original_tokens as f64 + 1.0).ln();
50 recency * 0.4 + frequency * 0.3 + size_value * 0.3
51 }
52}
53
54const RRF_K: f64 = 60.0;
55
56pub fn eviction_scores_rrf(entries: &[(&String, &CacheEntry)], now: Instant) -> Vec<(String, f64)> {
61 if entries.is_empty() {
62 return Vec::new();
63 }
64
65 let n = entries.len();
66
67 let mut recency_order: Vec<usize> = (0..n).collect();
68 recency_order.sort_by(|&a, &b| {
69 let elapsed_a = now
70 .checked_duration_since(entries[a].1.last_access)
71 .unwrap_or_default()
72 .as_secs_f64();
73 let elapsed_b = now
74 .checked_duration_since(entries[b].1.last_access)
75 .unwrap_or_default()
76 .as_secs_f64();
77 elapsed_a
78 .partial_cmp(&elapsed_b)
79 .unwrap_or(std::cmp::Ordering::Equal)
80 });
81
82 let mut frequency_order: Vec<usize> = (0..n).collect();
83 frequency_order.sort_by(|&a, &b| entries[b].1.read_count.cmp(&entries[a].1.read_count));
84
85 let mut size_order: Vec<usize> = (0..n).collect();
86 size_order.sort_by(|&a, &b| {
87 entries[b]
88 .1
89 .original_tokens
90 .cmp(&entries[a].1.original_tokens)
91 });
92
93 let mut recency_ranks = vec![0usize; n];
94 let mut frequency_ranks = vec![0usize; n];
95 let mut size_ranks = vec![0usize; n];
96
97 for (rank, &idx) in recency_order.iter().enumerate() {
98 recency_ranks[idx] = rank;
99 }
100 for (rank, &idx) in frequency_order.iter().enumerate() {
101 frequency_ranks[idx] = rank;
102 }
103 for (rank, &idx) in size_order.iter().enumerate() {
104 size_ranks[idx] = rank;
105 }
106
107 entries
108 .iter()
109 .enumerate()
110 .map(|(i, (path, _))| {
111 let score = 1.0 / (RRF_K + recency_ranks[i] as f64)
112 + 1.0 / (RRF_K + frequency_ranks[i] as f64)
113 + 1.0 / (RRF_K + size_ranks[i] as f64);
114 ((*path).clone(), score)
115 })
116 .collect()
117}
118
/// Running counters for a `SessionCache`; `SessionCache::clear` resets all of them.
#[derive(Debug)]
pub struct CacheStats {
    /// Every `store` call plus every successful `record_cache_hit`.
    pub total_reads: u64,
    /// Reads answered from cache (unchanged content on `store`, or a recorded hit).
    pub cache_hits: u64,
    /// Sum of full-content token counts over all reads; the baseline for `tokens_saved`.
    pub total_original_tokens: u64,
    /// Tokens actually charged as sent: full content on a miss, only the short
    /// "cached" notice on a hit.
    pub total_sent_tokens: u64,
    /// Incremented when a new entry is inserted by `store`. In this file, eviction
    /// and `invalidate` do not decrement it.
    pub files_tracked: usize,
}
128
129impl CacheStats {
130 pub fn hit_rate(&self) -> f64 {
132 if self.total_reads == 0 {
133 return 0.0;
134 }
135 (self.cache_hits as f64 / self.total_reads as f64) * 100.0
136 }
137
138 pub fn tokens_saved(&self) -> u64 {
140 self.total_original_tokens
141 .saturating_sub(self.total_sent_tokens)
142 }
143
144 pub fn savings_percent(&self) -> f64 {
146 if self.total_original_tokens == 0 {
147 return 0.0;
148 }
149 (self.tokens_saved() as f64 / self.total_original_tokens as f64) * 100.0
150 }
151}
152
/// A block of text that `SessionCache::apply_dedup` replaces with a
/// `[= <canonical_ref>:<start_line>-<end_line>]` marker wherever it appears in a file
/// other than `canonical_path`.
#[derive(Clone, Debug)]
pub struct SharedBlock {
    /// Path of the file holding the canonical copy; dedup is not applied to this path.
    pub canonical_path: String,
    /// Label written into the marker (presumably the canonical file's `F<n>` ref; confirm with the producer).
    pub canonical_ref: String,
    /// Start line written into the marker (line-number base not visible here; TODO confirm).
    pub start_line: usize,
    /// End line written into the marker.
    pub end_line: usize,
    /// Exact text of the block; matched verbatim against other files' content.
    pub content: String,
}
162
/// Cache of file contents with short file references (`F1`, `F2`, …), token
/// accounting, budgeted eviction, and optional shared-block dedup.
pub struct SessionCache {
    /// Cached entries keyed by normalized path.
    entries: HashMap<String, CacheEntry>,
    /// Normalized path -> short reference label such as `F1`.
    file_refs: HashMap<String, String>,
    /// Number used for the next label handed out by `get_file_ref`.
    next_ref: usize,
    /// Read/hit/token counters exposed via `get_stats`.
    stats: CacheStats,
    /// Blocks consulted by `apply_dedup`; replaced wholesale by `set_shared_blocks`.
    shared_blocks: Vec<SharedBlock>,
}
171
172impl Default for SessionCache {
173 fn default() -> Self {
174 Self::new()
175 }
176}
177
178impl SessionCache {
179 pub fn new() -> Self {
181 Self {
182 entries: HashMap::new(),
183 file_refs: HashMap::new(),
184 next_ref: 1,
185 shared_blocks: Vec::new(),
186 stats: CacheStats {
187 total_reads: 0,
188 cache_hits: 0,
189 total_original_tokens: 0,
190 total_sent_tokens: 0,
191 files_tracked: 0,
192 },
193 }
194 }
195
196 pub fn get_file_ref(&mut self, path: &str) -> String {
198 let key = normalize_key(path);
199 if let Some(r) = self.file_refs.get(&key) {
200 return r.clone();
201 }
202 let r = format!("F{}", self.next_ref);
203 self.next_ref += 1;
204 self.file_refs.insert(key, r.clone());
205 r
206 }
207
208 pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
210 self.file_refs.get(&normalize_key(path)).cloned()
211 }
212
213 pub fn get(&self, path: &str) -> Option<&CacheEntry> {
215 self.entries.get(&normalize_key(path))
216 }
217
218 pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
220 let key = normalize_key(path);
221 let ref_label = self
222 .file_refs
223 .get(&key)
224 .cloned()
225 .unwrap_or_else(|| "F?".to_string());
226 if let Some(entry) = self.entries.get_mut(&key) {
227 entry.read_count += 1;
228 entry.last_access = Instant::now();
229 self.stats.total_reads += 1;
230 self.stats.cache_hits += 1;
231 self.stats.total_original_tokens += entry.original_tokens as u64;
232 let hit_msg = format!(
233 "{ref_label} cached {}t {}L",
234 entry.read_count, entry.line_count
235 );
236 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
237 crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
238 Some(entry)
239 } else {
240 None
241 }
242 }
243
244 pub fn store(&mut self, path: &str, content: String) -> StoreResult {
246 let key = normalize_key(path);
247 let hash = compute_md5(&content);
248 let line_count = content.lines().count();
249 let original_tokens = count_tokens(&content);
250 let stored_mtime = std::fs::metadata(path).and_then(|m| m.modified()).ok();
251 let now = Instant::now();
252
253 self.stats.total_reads += 1;
254 self.stats.total_original_tokens += original_tokens as u64;
255
256 if let Some(existing) = self.entries.get_mut(&key) {
257 existing.last_access = now;
258 if stored_mtime.is_some() {
259 existing.stored_mtime = stored_mtime;
260 }
261 if existing.hash == hash {
262 existing.read_count += 1;
263 self.stats.cache_hits += 1;
264 let hit_msg = format!(
265 "{} cached {}t {}L",
266 self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
267 existing.read_count,
268 existing.line_count,
269 );
270 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
271 return StoreResult {
272 line_count: existing.line_count,
273 original_tokens: existing.original_tokens,
274 read_count: existing.read_count,
275 was_hit: true,
276 };
277 }
278 existing.content = content;
279 existing.hash = hash;
280 existing.line_count = line_count;
281 existing.original_tokens = original_tokens;
282 existing.read_count += 1;
283 if stored_mtime.is_some() {
284 existing.stored_mtime = stored_mtime;
285 }
286 self.stats.total_sent_tokens += original_tokens as u64;
287 return StoreResult {
288 line_count,
289 original_tokens,
290 read_count: existing.read_count,
291 was_hit: false,
292 };
293 }
294
295 self.evict_if_needed(original_tokens);
296 self.get_file_ref(&key);
297
298 let entry = CacheEntry {
299 content,
300 hash,
301 line_count,
302 original_tokens,
303 read_count: 1,
304 path: key.clone(),
305 last_access: now,
306 stored_mtime,
307 };
308
309 self.entries.insert(key, entry);
310 self.stats.files_tracked += 1;
311 self.stats.total_sent_tokens += original_tokens as u64;
312 StoreResult {
313 line_count,
314 original_tokens,
315 read_count: 1,
316 was_hit: false,
317 }
318 }
319
320 pub fn total_cached_tokens(&self) -> usize {
322 self.entries.values().map(|e| e.original_tokens).sum()
323 }
324
325 pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
327 let max_tokens = max_cache_tokens();
328 let current = self.total_cached_tokens();
329 if current + incoming_tokens <= max_tokens {
330 return;
331 }
332
333 let now = Instant::now();
334 let all_entries: Vec<(&String, &CacheEntry)> = self.entries.iter().collect();
335 let mut scored = eviction_scores_rrf(&all_entries, now);
336 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
337
338 let mut freed = 0usize;
339 let target = (current + incoming_tokens).saturating_sub(max_tokens);
340 for (path, _score) in &scored {
341 if freed >= target {
342 break;
343 }
344 if let Some(entry) = self.entries.remove(path) {
345 freed += entry.original_tokens;
346 self.file_refs.remove(path);
347 }
348 }
349 }
350
351 pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
353 self.entries.iter().collect()
354 }
355
356 pub fn get_stats(&self) -> &CacheStats {
358 &self.stats
359 }
360
361 pub fn file_ref_map(&self) -> &HashMap<String, String> {
363 &self.file_refs
364 }
365
366 pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
368 self.shared_blocks = blocks;
369 }
370
371 pub fn get_shared_blocks(&self) -> &[SharedBlock] {
373 &self.shared_blocks
374 }
375
376 pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
378 if self.shared_blocks.is_empty() {
379 return None;
380 }
381 let refs: Vec<&SharedBlock> = self
382 .shared_blocks
383 .iter()
384 .filter(|b| b.canonical_path != path && content.contains(&b.content))
385 .collect();
386 if refs.is_empty() {
387 return None;
388 }
389 let mut result = content.to_string();
390 for block in refs {
391 result = result.replacen(
392 &block.content,
393 &format!(
394 "[= {}:{}-{}]",
395 block.canonical_ref, block.start_line, block.end_line
396 ),
397 1,
398 );
399 }
400 Some(result)
401 }
402
403 pub fn invalidate(&mut self, path: &str) -> bool {
405 self.entries.remove(&normalize_key(path)).is_some()
406 }
407
408 pub fn clear(&mut self) -> usize {
410 let count = self.entries.len();
411 self.entries.clear();
412 self.file_refs.clear();
413 self.shared_blocks.clear();
414 self.next_ref = 1;
415 self.stats = CacheStats {
416 total_reads: 0,
417 cache_hits: 0,
418 total_original_tokens: 0,
419 total_sent_tokens: 0,
420 files_tracked: 0,
421 };
422 count
423 }
424}
425
/// Modification time of the file at `path`, or `None` when its metadata cannot be read
/// or the platform does not report modification times.
pub fn file_mtime(path: &str) -> Option<SystemTime> {
    let metadata = std::fs::metadata(path).ok()?;
    metadata.modified().ok()
}
429
430pub fn is_cache_entry_stale(path: &str, cached_mtime: Option<SystemTime>) -> bool {
431 let current = file_mtime(path);
432 match (cached_mtime, current) {
433 (_, None) => false,
434 (None, Some(_)) => true,
435 (Some(cached), Some(current)) => current > cached,
436 }
437}
438
439fn compute_md5(content: &str) -> String {
440 let mut hasher = Md5::new();
441 hasher.update(content.as_bytes());
442 format!("{:x}", hasher.finalize())
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Sets an environment variable and removes it when dropped, so a failing
    /// assertion cannot leave the override in place for later tests.
    struct EnvVarGuard {
        key: &'static str,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            std::env::set_var(key, value);
            Self { key }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.key);
        }
    }

    /// Builds a one-line, 10-token entry; the RRF tests vary only reads and access time.
    fn scored_entry(path: &str, read_count: u32, last_access: Instant) -> CacheEntry {
        CacheEntry {
            content: path.to_string(),
            hash: format!("hash:{path}"),
            line_count: 1,
            original_tokens: 10,
            read_count,
            path: path.to_string(),
            last_access,
            stored_mtime: None,
        }
    }

    /// Looks up the RRF score computed for `key`.
    fn score_for(scores: &[(String, f64)], key: &str) -> f64 {
        scores
            .iter()
            .find(|(p, _)| p == key)
            .map(|(_, s)| *s)
            .expect("every input entry gets a score")
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let result = cache.store("/test/file.rs", "fn main() {}".to_string());
        assert!(!result.was_hit);
        assert_eq!(result.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content".to_string());
        let result = cache.store("/test/file.rs", "content".to_string());
        assert!(result.was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content".to_string());
        let result = cache.store("/test/file.rs", "new content".to_string());
        assert!(!result.was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        // A known path keeps its label.
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.store("/b.rs", "b".to_string());
        let count = cache.clear();
        assert_eq!(count, 2);
        assert!(cache.get("/a.rs").is_none());
        // Label numbering restarts after a clear.
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test".to_string());
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello".to_string());
        cache.store("/a.rs", "hello".to_string());
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let h1 = compute_md5("test content");
        let h2 = compute_md5("test content");
        assert_eq!(h1, h2);
        assert_ne!(h1, compute_md5("different"));
    }

    #[test]
    fn rrf_eviction_prefers_recent() {
        // Build the two instants arithmetically instead of sleeping.
        let base = Instant::now();
        let now = base + Duration::from_millis(5);
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let recent = scored_entry("/a.rs", 1, now);
        let old = scored_entry("/b.rs", 1, base);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &recent), (&key_b, &old)];
        let scores = eviction_scores_rrf(&entries, now);
        assert!(
            score_for(&scores, "a.rs") > score_for(&scores, "b.rs"),
            "recently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn rrf_eviction_prefers_frequent() {
        let now = Instant::now();
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let frequent = scored_entry("/a.rs", 20, now);
        let rare = scored_entry("/b.rs", 1, now);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &frequent), (&key_b, &rare)];
        let scores = eviction_scores_rrf(&entries, now);
        assert!(
            score_for(&scores, "a.rs") > score_for(&scores, "b.rs"),
            "frequently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        // The guard removes the override even if an assertion below panics.
        let _budget = EnvVarGuard::set("LEAN_CTX_CACHE_MAX_TOKENS", "50");
        let mut cache = SessionCache::new();
        cache.store("/old.rs", "a]".repeat(30));
        cache.store("/new.rs", "b ".repeat(30));
        assert!(
            cache.total_cached_tokens() <= 60,
            "eviction should have kicked in"
        );
        assert!(
            cache.get("/new.rs").is_some(),
            "the incoming entry is inserted after eviction, so it survives"
        );
    }

    #[test]
    fn stale_detection_flags_newer_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("stale.txt");
        let p = path.to_string_lossy().to_string();

        std::fs::write(&path, "one").unwrap();
        let mut cache = SessionCache::new();
        cache.store(&p, "one".to_string());

        let entry = cache.get(&p).unwrap();
        assert!(!is_cache_entry_stale(&p, entry.stored_mtime));

        // Some filesystems only store mtimes with one-second resolution.
        std::thread::sleep(Duration::from_secs(1));
        std::fs::write(&path, "two").unwrap();

        let entry = cache.get(&p).unwrap();
        assert!(is_cache_entry_stale(&p, entry.stored_mtime));
    }
}