1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::Instant;
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8 crate::hooks::normalize_tool_path(path)
9}
10
/// Token budget for the whole cache.
///
/// Read from `LEAN_CTX_CACHE_MAX_TOKENS` on every call. Falls back to 500 000 when the
/// variable is unset or is not a valid `usize`.
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;
    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_MAX_TOKENS),
        Err(_) => DEFAULT_MAX_TOKENS,
    }
}
17
/// One cached file snapshot held by `SessionCache`, stored under its normalized path.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    /// Full file content as last passed to `SessionCache::store`.
    pub content: String,
    /// Lowercase hex MD5 digest of `content`; an equal digest on re-store counts as a hit.
    pub hash: String,
    /// Number of lines in `content`, as counted by `str::lines`.
    pub line_count: usize,
    /// Token count of `content`; also the unit charged against the cache token budget.
    pub original_tokens: usize,
    /// Number of stores/hits for this path; 1 on first insert.
    pub read_count: u32,
    /// Normalized path key this entry was stored under.
    pub path: String,
    /// Time of the last store or hit; drives the recency component of eviction scoring.
    pub last_access: Instant,
}
28
/// Outcome of `SessionCache::store`, describing the entry after the store.
#[derive(Debug, Clone)]
pub struct StoreResult {
    /// Line count of the cached content.
    pub line_count: usize,
    /// Token count of the cached content.
    pub original_tokens: usize,
    /// The entry's read count after this store (1 for a newly inserted path).
    pub read_count: u32,
    /// `true` when the path was already cached with identical content (same MD5).
    pub was_hit: bool,
}
36
37impl CacheEntry {
38 pub fn eviction_score_legacy(&self, now: Instant) -> f64 {
39 let elapsed = now
40 .checked_duration_since(self.last_access)
41 .unwrap_or_default()
42 .as_secs_f64();
43 let recency = 1.0 / (1.0 + elapsed.sqrt());
44 let frequency = (self.read_count as f64 + 1.0).ln();
45 let size_value = (self.original_tokens as f64 + 1.0).ln();
46 recency * 0.4 + frequency * 0.3 + size_value * 0.3
47 }
48}
49
50const RRF_K: f64 = 60.0;
51
52pub fn eviction_scores_rrf(entries: &[(&String, &CacheEntry)], now: Instant) -> Vec<(String, f64)> {
57 if entries.is_empty() {
58 return Vec::new();
59 }
60
61 let n = entries.len();
62
63 let mut recency_order: Vec<usize> = (0..n).collect();
64 recency_order.sort_by(|&a, &b| {
65 let elapsed_a = now
66 .checked_duration_since(entries[a].1.last_access)
67 .unwrap_or_default()
68 .as_secs_f64();
69 let elapsed_b = now
70 .checked_duration_since(entries[b].1.last_access)
71 .unwrap_or_default()
72 .as_secs_f64();
73 elapsed_a
74 .partial_cmp(&elapsed_b)
75 .unwrap_or(std::cmp::Ordering::Equal)
76 });
77
78 let mut frequency_order: Vec<usize> = (0..n).collect();
79 frequency_order.sort_by(|&a, &b| entries[b].1.read_count.cmp(&entries[a].1.read_count));
80
81 let mut size_order: Vec<usize> = (0..n).collect();
82 size_order.sort_by(|&a, &b| {
83 entries[b]
84 .1
85 .original_tokens
86 .cmp(&entries[a].1.original_tokens)
87 });
88
89 let mut recency_ranks = vec![0usize; n];
90 let mut frequency_ranks = vec![0usize; n];
91 let mut size_ranks = vec![0usize; n];
92
93 for (rank, &idx) in recency_order.iter().enumerate() {
94 recency_ranks[idx] = rank;
95 }
96 for (rank, &idx) in frequency_order.iter().enumerate() {
97 frequency_ranks[idx] = rank;
98 }
99 for (rank, &idx) in size_order.iter().enumerate() {
100 size_ranks[idx] = rank;
101 }
102
103 entries
104 .iter()
105 .enumerate()
106 .map(|(i, (path, _))| {
107 let score = 1.0 / (RRF_K + recency_ranks[i] as f64)
108 + 1.0 / (RRF_K + frequency_ranks[i] as f64)
109 + 1.0 / (RRF_K + size_ranks[i] as f64);
110 ((*path).clone(), score)
111 })
112 .collect()
113}
114
/// Aggregate read and token counters for a session cache, plus derived ratios.
#[derive(Debug)]
pub struct CacheStats {
    /// Every recorded read, whether hit or miss.
    pub total_reads: u64,
    /// Reads whose content was already cached unchanged.
    pub cache_hits: u64,
    /// Sum of the full content token counts over all reads.
    pub total_original_tokens: u64,
    /// Tokens actually emitted: full content on misses, a short marker on hits.
    pub total_sent_tokens: u64,
    /// Count of newly inserted paths since creation or the last clear.
    pub files_tracked: usize,
}

impl CacheStats {
    /// Percentage (0–100) of reads that were cache hits; `0.0` before any read.
    pub fn hit_rate(&self) -> f64 {
        match self.total_reads {
            0 => 0.0,
            reads => self.cache_hits as f64 / reads as f64 * 100.0,
        }
    }

    /// Tokens avoided by serving hits, clamped at zero if more was sent than read.
    pub fn tokens_saved(&self) -> u64 {
        u64::saturating_sub(self.total_original_tokens, self.total_sent_tokens)
    }

    /// `tokens_saved` as a percentage of all original tokens; `0.0` when nothing was read.
    pub fn savings_percent(&self) -> f64 {
        let original = self.total_original_tokens;
        if original == 0 {
            0.0
        } else {
            self.tokens_saved() as f64 / original as f64 * 100.0
        }
    }
}
144
/// A block of text shared between files.
///
/// `SessionCache::apply_dedup` uses it to replace duplicated text with a short reference.
#[derive(Clone, Debug)]
pub struct SharedBlock {
    /// Path of the file that owns the block; `apply_dedup` skips the block for this path.
    pub canonical_path: String,
    /// Label written into the `[= ref:start-end]` marker (presumably the owner's `F<n>` ref — TODO confirm).
    pub canonical_ref: String,
    /// Start line of the block in the canonical file (1-based or not is not shown here — confirm).
    pub start_line: usize,
    /// End line of the block in the canonical file.
    pub end_line: usize,
    /// Exact text searched for, and replaced, in other files' content.
    pub content: String,
}
154
/// Per-session file cache: snapshots keyed by normalized path, short `F<n>` file refs,
/// aggregate read statistics and shared blocks for deduplication.
pub struct SessionCache {
    /// Cached snapshots keyed by normalized path.
    entries: HashMap<String, CacheEntry>,
    /// Normalized path → ref label (`"F1"`, `"F2"`, …).
    file_refs: HashMap<String, String>,
    /// Number used for the next ref label to be allocated.
    next_ref: usize,
    /// Aggregate read/token counters.
    stats: CacheStats,
    /// Blocks available for cross-file deduplication in `apply_dedup`.
    shared_blocks: Vec<SharedBlock>,
}
162
163impl Default for SessionCache {
164 fn default() -> Self {
165 Self::new()
166 }
167}
168
169impl SessionCache {
170 pub fn new() -> Self {
171 Self {
172 entries: HashMap::new(),
173 file_refs: HashMap::new(),
174 next_ref: 1,
175 shared_blocks: Vec::new(),
176 stats: CacheStats {
177 total_reads: 0,
178 cache_hits: 0,
179 total_original_tokens: 0,
180 total_sent_tokens: 0,
181 files_tracked: 0,
182 },
183 }
184 }
185
186 pub fn get_file_ref(&mut self, path: &str) -> String {
187 let key = normalize_key(path);
188 if let Some(r) = self.file_refs.get(&key) {
189 return r.clone();
190 }
191 let r = format!("F{}", self.next_ref);
192 self.next_ref += 1;
193 self.file_refs.insert(key, r.clone());
194 r
195 }
196
197 pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
198 self.file_refs.get(&normalize_key(path)).cloned()
199 }
200
201 pub fn get(&self, path: &str) -> Option<&CacheEntry> {
202 self.entries.get(&normalize_key(path))
203 }
204
205 pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
206 let key = normalize_key(path);
207 let ref_label = self
208 .file_refs
209 .get(&key)
210 .cloned()
211 .unwrap_or_else(|| "F?".to_string());
212 if let Some(entry) = self.entries.get_mut(&key) {
213 entry.read_count += 1;
214 entry.last_access = Instant::now();
215 self.stats.total_reads += 1;
216 self.stats.cache_hits += 1;
217 self.stats.total_original_tokens += entry.original_tokens as u64;
218 let hit_msg = format!(
219 "{ref_label} cached {}t {}L",
220 entry.read_count, entry.line_count
221 );
222 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
223 crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
224 Some(entry)
225 } else {
226 None
227 }
228 }
229
230 pub fn store(&mut self, path: &str, content: String) -> StoreResult {
231 let key = normalize_key(path);
232 let hash = compute_md5(&content);
233 let line_count = content.lines().count();
234 let original_tokens = count_tokens(&content);
235 let now = Instant::now();
236
237 self.stats.total_reads += 1;
238 self.stats.total_original_tokens += original_tokens as u64;
239
240 if let Some(existing) = self.entries.get_mut(&key) {
241 existing.last_access = now;
242 if existing.hash == hash {
243 existing.read_count += 1;
244 self.stats.cache_hits += 1;
245 let hit_msg = format!(
246 "{} cached {}t {}L",
247 self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
248 existing.read_count,
249 existing.line_count,
250 );
251 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
252 return StoreResult {
253 line_count: existing.line_count,
254 original_tokens: existing.original_tokens,
255 read_count: existing.read_count,
256 was_hit: true,
257 };
258 }
259 existing.content = content;
260 existing.hash = hash;
261 existing.line_count = line_count;
262 existing.original_tokens = original_tokens;
263 existing.read_count += 1;
264 self.stats.total_sent_tokens += original_tokens as u64;
265 return StoreResult {
266 line_count,
267 original_tokens,
268 read_count: existing.read_count,
269 was_hit: false,
270 };
271 }
272
273 self.evict_if_needed(original_tokens);
274 self.get_file_ref(&key);
275
276 let entry = CacheEntry {
277 content,
278 hash,
279 line_count,
280 original_tokens,
281 read_count: 1,
282 path: key.clone(),
283 last_access: now,
284 };
285
286 self.entries.insert(key, entry);
287 self.stats.files_tracked += 1;
288 self.stats.total_sent_tokens += original_tokens as u64;
289 StoreResult {
290 line_count,
291 original_tokens,
292 read_count: 1,
293 was_hit: false,
294 }
295 }
296
297 pub fn total_cached_tokens(&self) -> usize {
298 self.entries.values().map(|e| e.original_tokens).sum()
299 }
300
301 pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
303 let max_tokens = max_cache_tokens();
304 let current = self.total_cached_tokens();
305 if current + incoming_tokens <= max_tokens {
306 return;
307 }
308
309 let now = Instant::now();
310 let all_entries: Vec<(&String, &CacheEntry)> = self.entries.iter().collect();
311 let mut scored = eviction_scores_rrf(&all_entries, now);
312 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
313
314 let mut freed = 0usize;
315 let target = (current + incoming_tokens).saturating_sub(max_tokens);
316 for (path, _score) in &scored {
317 if freed >= target {
318 break;
319 }
320 if let Some(entry) = self.entries.remove(path) {
321 freed += entry.original_tokens;
322 self.file_refs.remove(path);
323 }
324 }
325 }
326
327 pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
328 self.entries.iter().collect()
329 }
330
331 pub fn get_stats(&self) -> &CacheStats {
332 &self.stats
333 }
334
335 pub fn file_ref_map(&self) -> &HashMap<String, String> {
336 &self.file_refs
337 }
338
339 pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
340 self.shared_blocks = blocks;
341 }
342
343 pub fn get_shared_blocks(&self) -> &[SharedBlock] {
344 &self.shared_blocks
345 }
346
347 pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
349 if self.shared_blocks.is_empty() {
350 return None;
351 }
352 let refs: Vec<&SharedBlock> = self
353 .shared_blocks
354 .iter()
355 .filter(|b| b.canonical_path != path && content.contains(&b.content))
356 .collect();
357 if refs.is_empty() {
358 return None;
359 }
360 let mut result = content.to_string();
361 for block in refs {
362 result = result.replacen(
363 &block.content,
364 &format!(
365 "[= {}:{}-{}]",
366 block.canonical_ref, block.start_line, block.end_line
367 ),
368 1,
369 );
370 }
371 Some(result)
372 }
373
374 pub fn invalidate(&mut self, path: &str) -> bool {
375 self.entries.remove(&normalize_key(path)).is_some()
376 }
377
378 pub fn clear(&mut self) -> usize {
379 let count = self.entries.len();
380 self.entries.clear();
381 self.file_refs.clear();
382 self.shared_blocks.clear();
383 self.next_ref = 1;
384 self.stats = CacheStats {
385 total_reads: 0,
386 cache_hits: 0,
387 total_original_tokens: 0,
388 total_sent_tokens: 0,
389 files_tracked: 0,
390 };
391 count
392 }
393}
394
395fn compute_md5(content: &str) -> String {
396 let mut hasher = Md5::new();
397 hasher.update(content.as_bytes());
398 format!("{:x}", hasher.finalize())
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-line, 10-token entry for the RRF scoring tests.
    fn entry_at(
        content: &str,
        hash: &str,
        read_count: u32,
        path: &str,
        last_access: Instant,
    ) -> CacheEntry {
        CacheEntry {
            content: content.to_string(),
            hash: hash.to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count,
            path: path.to_string(),
            last_access,
        }
    }

    /// Returns the RRF score recorded for `key`; panics if `key` was not scored.
    fn score_for(scores: &[(String, f64)], key: &str) -> f64 {
        scores.iter().find(|(p, _)| p == key).unwrap().1
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let first = cache.store("/test/file.rs", "fn main() {}".to_string());
        assert!(!first.was_hit);
        assert_eq!(first.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content".to_string());
        let again = cache.store("/test/file.rs", "content".to_string());
        assert!(again.was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content".to_string());
        let updated = cache.store("/test/file.rs", "new content".to_string());
        assert!(!updated.was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        let labels: Vec<String> = ["/a.rs", "/b.rs", "/a.rs"]
            .iter()
            .map(|p| cache.get_file_ref(p))
            .collect();
        assert_eq!(labels, ["F1", "F2", "F1"]);
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        for (path, body) in [("/a.rs", "a"), ("/b.rs", "b")] {
            cache.store(path, body.to_string());
        }
        assert_eq!(cache.clear(), 2);
        assert!(cache.get("/a.rs").is_none());
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test".to_string());
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        for _ in 0..2 {
            cache.store("/a.rs", "hello".to_string());
        }
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let digest = compute_md5("test content");
        assert_eq!(digest, compute_md5("test content"));
        assert_ne!(digest, compute_md5("different"));
    }

    #[test]
    fn rrf_eviction_prefers_recent() {
        let base = Instant::now();
        std::thread::sleep(std::time::Duration::from_millis(5));
        let now = Instant::now();
        let (key_a, key_b) = ("a.rs".to_string(), "b.rs".to_string());
        let recent = entry_at("a", "h1", 1, "/a.rs", now);
        let old = entry_at("b", "h2", 1, "/b.rs", base);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &recent), (&key_b, &old)];
        let scores = eviction_scores_rrf(&entries, now);
        let (score_a, score_b) = (score_for(&scores, "a.rs"), score_for(&scores, "b.rs"));
        assert!(
            score_a > score_b,
            "recently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn rrf_eviction_prefers_frequent() {
        let now = Instant::now();
        let (key_a, key_b) = ("a.rs".to_string(), "b.rs".to_string());
        let frequent = entry_at("a", "h1", 20, "/a.rs", now);
        let rare = entry_at("b", "h2", 1, "/b.rs", now);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &frequent), (&key_b, &rare)];
        let scores = eviction_scores_rrf(&entries, now);
        let (score_a, score_b) = (score_for(&scores, "a.rs"), score_for(&scores, "b.rs"));
        assert!(
            score_a > score_b,
            "frequently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        // NOTE(review): environment variables are process-global and tests run in parallel.
        // While this test runs, the 50-token budget also applies to other tests that call
        // `store`. "a]" also looks like a typo for "a ". Both are left as-is to preserve
        // the test's behavior.
        std::env::set_var("LEAN_CTX_CACHE_MAX_TOKENS", "50");
        let mut cache = SessionCache::new();
        cache.store("/old.rs", "a]".repeat(30));
        cache.store("/new.rs", "b ".repeat(30));
        assert!(
            cache.total_cached_tokens() <= 60,
            "eviction should have kicked in"
        );
        std::env::remove_var("LEAN_CTX_CACHE_MAX_TOKENS");
    }
}
564}