//! lean_ctx/core/cache.rs — session-scoped file cache with MD5 change
//! detection, token-savings statistics, and RRF-based eviction.
1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::Instant;
4
5use super::tokens::count_tokens;
6
/// Normalize `path` into the canonical cache-key form.
///
/// Delegates to `crate::hooks::normalize_tool_path`; presumably this collapses
/// different spellings of the same path onto one key so repeated reads hit the
/// same entry — see that helper for the exact rules (TODO confirm).
fn normalize_key(path: &str) -> String {
    crate::hooks::normalize_tool_path(path)
}
10
/// Token budget for the whole cache, overridable via the
/// `LEAN_CTX_CACHE_MAX_TOKENS` environment variable.
///
/// The variable is re-read on every call (so tests and callers can change it
/// at runtime); an unset or unparsable value falls back to 500k.
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;
    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_MAX_TOKENS),
        Err(_) => DEFAULT_MAX_TOKENS,
    }
}
17
/// Cached contents and bookkeeping for a single file.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    // Full file contents as last stored.
    pub content: String,
    // MD5 hex digest of `content`; used to detect changed content on re-store.
    pub hash: String,
    // Number of lines in `content`.
    pub line_count: usize,
    // Token count of `content` (via `count_tokens`).
    pub original_tokens: usize,
    // Times this file has been read (initial store + subsequent hits).
    pub read_count: u32,
    // Normalized cache key for this file.
    pub path: String,
    // Last store/hit time; drives recency in eviction scoring.
    pub last_access: Instant,
}
28
/// Outcome of a `SessionCache::store` call, echoing the entry's bookkeeping.
#[derive(Debug, Clone)]
pub struct StoreResult {
    // Line count of the stored content.
    pub line_count: usize,
    // Token count of the stored content.
    pub original_tokens: usize,
    // Read count after this store was applied.
    pub read_count: u32,
    // True when the content was byte-identical to what was already cached.
    pub was_hit: bool,
}
36
37impl CacheEntry {
38    pub fn eviction_score_legacy(&self, now: Instant) -> f64 {
39        let elapsed = now
40            .checked_duration_since(self.last_access)
41            .unwrap_or_default()
42            .as_secs_f64();
43        let recency = 1.0 / (1.0 + elapsed.sqrt());
44        let frequency = (self.read_count as f64 + 1.0).ln();
45        let size_value = (self.original_tokens as f64 + 1.0).ln();
46        recency * 0.4 + frequency * 0.3 + size_value * 0.3
47    }
48}
49
50const RRF_K: f64 = 60.0;
51
52/// Compute Reciprocal Rank Fusion eviction scores for a batch of cache entries.
53/// Each signal (recency, frequency, size) produces an independent ranking.
54/// The final score is the sum of `1/(k + rank)` across all signals.
55/// Higher score = more valuable = keep longer.
56pub fn eviction_scores_rrf(entries: &[(&String, &CacheEntry)], now: Instant) -> Vec<(String, f64)> {
57    if entries.is_empty() {
58        return Vec::new();
59    }
60
61    let n = entries.len();
62
63    let mut recency_order: Vec<usize> = (0..n).collect();
64    recency_order.sort_by(|&a, &b| {
65        let elapsed_a = now
66            .checked_duration_since(entries[a].1.last_access)
67            .unwrap_or_default()
68            .as_secs_f64();
69        let elapsed_b = now
70            .checked_duration_since(entries[b].1.last_access)
71            .unwrap_or_default()
72            .as_secs_f64();
73        elapsed_a
74            .partial_cmp(&elapsed_b)
75            .unwrap_or(std::cmp::Ordering::Equal)
76    });
77
78    let mut frequency_order: Vec<usize> = (0..n).collect();
79    frequency_order.sort_by(|&a, &b| entries[b].1.read_count.cmp(&entries[a].1.read_count));
80
81    let mut size_order: Vec<usize> = (0..n).collect();
82    size_order.sort_by(|&a, &b| {
83        entries[b]
84            .1
85            .original_tokens
86            .cmp(&entries[a].1.original_tokens)
87    });
88
89    let mut recency_ranks = vec![0usize; n];
90    let mut frequency_ranks = vec![0usize; n];
91    let mut size_ranks = vec![0usize; n];
92
93    for (rank, &idx) in recency_order.iter().enumerate() {
94        recency_ranks[idx] = rank;
95    }
96    for (rank, &idx) in frequency_order.iter().enumerate() {
97        frequency_ranks[idx] = rank;
98    }
99    for (rank, &idx) in size_order.iter().enumerate() {
100        size_ranks[idx] = rank;
101    }
102
103    entries
104        .iter()
105        .enumerate()
106        .map(|(i, (path, _))| {
107            let score = 1.0 / (RRF_K + recency_ranks[i] as f64)
108                + 1.0 / (RRF_K + frequency_ranks[i] as f64)
109                + 1.0 / (RRF_K + size_ranks[i] as f64);
110            ((*path).clone(), score)
111        })
112        .collect()
113}
114
/// Aggregate counters describing cache effectiveness over a session.
#[derive(Debug)]
pub struct CacheStats {
    // Total reads observed (hits, misses, and content refreshes).
    pub total_reads: u64,
    // Reads satisfied with unchanged cached content.
    pub cache_hits: u64,
    // Sum of full token counts of everything read.
    pub total_original_tokens: u64,
    // Sum of tokens actually sent: full bodies on misses/refreshes, short
    // hit notices on hits.
    pub total_sent_tokens: u64,
    // Count of files ever stored; NOT decremented by eviction or
    // invalidation (see `SessionCache::evict_if_needed`).
    pub files_tracked: usize,
}
123
124impl CacheStats {
125    pub fn hit_rate(&self) -> f64 {
126        if self.total_reads == 0 {
127            return 0.0;
128        }
129        (self.cache_hits as f64 / self.total_reads as f64) * 100.0
130    }
131
132    pub fn tokens_saved(&self) -> u64 {
133        self.total_original_tokens
134            .saturating_sub(self.total_sent_tokens)
135    }
136
137    pub fn savings_percent(&self) -> f64 {
138        if self.total_original_tokens == 0 {
139            return 0.0;
140        }
141        (self.tokens_saved() as f64 / self.total_original_tokens as f64) * 100.0
142    }
143}
144
/// A block shared across multiple files, identified by its canonical source.
#[derive(Clone, Debug)]
pub struct SharedBlock {
    // Path of the file that "owns" the canonical copy of this block.
    pub canonical_path: String,
    // Short file ref (e.g. "F3") of the canonical file, used in dedup markers.
    pub canonical_ref: String,
    // First line of the block within the canonical file.
    pub start_line: usize,
    // Last line of the block within the canonical file.
    pub end_line: usize,
    // The block's literal text, matched verbatim by `apply_dedup`.
    pub content: String,
}
154
/// Session-scoped file cache: content keyed by normalized path, short file
/// refs ("F1", "F2", ...), aggregate stats, and shared-block dedup data.
pub struct SessionCache {
    // Cached file entries keyed by normalized path.
    entries: HashMap<String, CacheEntry>,
    // Normalized path -> short ref label ("F1", "F2", ...).
    file_refs: HashMap<String, String>,
    // Next numeric suffix handed out for a new file ref; never reused.
    next_ref: usize,
    // Session-wide counters.
    stats: CacheStats,
    // Cross-file shared blocks consulted by `apply_dedup`.
    shared_blocks: Vec<SharedBlock>,
}
162
impl Default for SessionCache {
    /// Equivalent to [`SessionCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
168
169impl SessionCache {
170    pub fn new() -> Self {
171        Self {
172            entries: HashMap::new(),
173            file_refs: HashMap::new(),
174            next_ref: 1,
175            shared_blocks: Vec::new(),
176            stats: CacheStats {
177                total_reads: 0,
178                cache_hits: 0,
179                total_original_tokens: 0,
180                total_sent_tokens: 0,
181                files_tracked: 0,
182            },
183        }
184    }
185
186    pub fn get_file_ref(&mut self, path: &str) -> String {
187        let key = normalize_key(path);
188        if let Some(r) = self.file_refs.get(&key) {
189            return r.clone();
190        }
191        let r = format!("F{}", self.next_ref);
192        self.next_ref += 1;
193        self.file_refs.insert(key, r.clone());
194        r
195    }
196
197    pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
198        self.file_refs.get(&normalize_key(path)).cloned()
199    }
200
201    pub fn get(&self, path: &str) -> Option<&CacheEntry> {
202        self.entries.get(&normalize_key(path))
203    }
204
205    pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
206        let key = normalize_key(path);
207        let ref_label = self
208            .file_refs
209            .get(&key)
210            .cloned()
211            .unwrap_or_else(|| "F?".to_string());
212        if let Some(entry) = self.entries.get_mut(&key) {
213            entry.read_count += 1;
214            entry.last_access = Instant::now();
215            self.stats.total_reads += 1;
216            self.stats.cache_hits += 1;
217            self.stats.total_original_tokens += entry.original_tokens as u64;
218            let hit_msg = format!(
219                "{ref_label} cached {}t {}L",
220                entry.read_count, entry.line_count
221            );
222            self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
223            crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
224            Some(entry)
225        } else {
226            None
227        }
228    }
229
230    pub fn store(&mut self, path: &str, content: String) -> StoreResult {
231        let key = normalize_key(path);
232        let hash = compute_md5(&content);
233        let line_count = content.lines().count();
234        let original_tokens = count_tokens(&content);
235        let now = Instant::now();
236
237        self.stats.total_reads += 1;
238        self.stats.total_original_tokens += original_tokens as u64;
239
240        if let Some(existing) = self.entries.get_mut(&key) {
241            existing.last_access = now;
242            if existing.hash == hash {
243                existing.read_count += 1;
244                self.stats.cache_hits += 1;
245                let hit_msg = format!(
246                    "{} cached {}t {}L",
247                    self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
248                    existing.read_count,
249                    existing.line_count,
250                );
251                self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
252                return StoreResult {
253                    line_count: existing.line_count,
254                    original_tokens: existing.original_tokens,
255                    read_count: existing.read_count,
256                    was_hit: true,
257                };
258            }
259            existing.content = content;
260            existing.hash = hash;
261            existing.line_count = line_count;
262            existing.original_tokens = original_tokens;
263            existing.read_count += 1;
264            self.stats.total_sent_tokens += original_tokens as u64;
265            return StoreResult {
266                line_count,
267                original_tokens,
268                read_count: existing.read_count,
269                was_hit: false,
270            };
271        }
272
273        self.evict_if_needed(original_tokens);
274        self.get_file_ref(&key);
275
276        let entry = CacheEntry {
277            content,
278            hash,
279            line_count,
280            original_tokens,
281            read_count: 1,
282            path: key.clone(),
283            last_access: now,
284        };
285
286        self.entries.insert(key, entry);
287        self.stats.files_tracked += 1;
288        self.stats.total_sent_tokens += original_tokens as u64;
289        StoreResult {
290            line_count,
291            original_tokens,
292            read_count: 1,
293            was_hit: false,
294        }
295    }
296
297    pub fn total_cached_tokens(&self) -> usize {
298        self.entries.values().map(|e| e.original_tokens).sum()
299    }
300
301    /// Evict lowest-scoring entries (by RRF) until cache fits within token budget.
302    pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
303        let max_tokens = max_cache_tokens();
304        let current = self.total_cached_tokens();
305        if current + incoming_tokens <= max_tokens {
306            return;
307        }
308
309        let now = Instant::now();
310        let all_entries: Vec<(&String, &CacheEntry)> = self.entries.iter().collect();
311        let mut scored = eviction_scores_rrf(&all_entries, now);
312        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
313
314        let mut freed = 0usize;
315        let target = (current + incoming_tokens).saturating_sub(max_tokens);
316        for (path, _score) in &scored {
317            if freed >= target {
318                break;
319            }
320            if let Some(entry) = self.entries.remove(path) {
321                freed += entry.original_tokens;
322                self.file_refs.remove(path);
323            }
324        }
325    }
326
327    pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
328        self.entries.iter().collect()
329    }
330
331    pub fn get_stats(&self) -> &CacheStats {
332        &self.stats
333    }
334
335    pub fn file_ref_map(&self) -> &HashMap<String, String> {
336        &self.file_refs
337    }
338
339    pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
340        self.shared_blocks = blocks;
341    }
342
343    pub fn get_shared_blocks(&self) -> &[SharedBlock] {
344        &self.shared_blocks
345    }
346
347    /// Replace shared blocks in content with cross-file references.
348    pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
349        if self.shared_blocks.is_empty() {
350            return None;
351        }
352        let refs: Vec<&SharedBlock> = self
353            .shared_blocks
354            .iter()
355            .filter(|b| b.canonical_path != path && content.contains(&b.content))
356            .collect();
357        if refs.is_empty() {
358            return None;
359        }
360        let mut result = content.to_string();
361        for block in refs {
362            result = result.replacen(
363                &block.content,
364                &format!(
365                    "[= {}:{}-{}]",
366                    block.canonical_ref, block.start_line, block.end_line
367                ),
368                1,
369            );
370        }
371        Some(result)
372    }
373
374    pub fn invalidate(&mut self, path: &str) -> bool {
375        self.entries.remove(&normalize_key(path)).is_some()
376    }
377
378    pub fn clear(&mut self) -> usize {
379        let count = self.entries.len();
380        self.entries.clear();
381        self.file_refs.clear();
382        self.shared_blocks.clear();
383        self.next_ref = 1;
384        self.stats = CacheStats {
385            total_reads: 0,
386            cache_hits: 0,
387            total_original_tokens: 0,
388            total_sent_tokens: 0,
389            files_tracked: 0,
390        };
391        count
392    }
393}
394
395fn compute_md5(content: &str) -> String {
396    let mut hasher = Md5::new();
397    hasher.update(content.as_bytes());
398    format!("{:x}", hasher.finalize())
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let result = cache.store("/test/file.rs", "fn main() {}".to_string());
        assert!(!result.was_hit);
        assert_eq!(result.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content".to_string());
        let result = cache.store("/test/file.rs", "content".to_string());
        assert!(result.was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content".to_string());
        let result = cache.store("/test/file.rs", "new content".to_string());
        assert!(!result.was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        assert_eq!(cache.get_file_ref("/a.rs"), "F1"); // repeated lookups return the same ref
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.store("/b.rs", "b".to_string());
        let count = cache.clear();
        assert_eq!(count, 2);
        assert!(cache.get("/a.rs").is_none());
        assert_eq!(cache.get_file_ref("/c.rs"), "F1"); // ref counter restarts at F1
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test".to_string());
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello".to_string());
        cache.store("/a.rs", "hello".to_string()); // identical content -> hit
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let h1 = compute_md5("test content");
        let h2 = compute_md5("test content");
        assert_eq!(h1, h2);
        assert_ne!(h1, compute_md5("different"));
    }

    #[test]
    fn rrf_eviction_prefers_recent() {
        // Sleep guarantees `base` and `now` are measurably apart so the
        // recency ranking cannot tie.
        let base = Instant::now();
        std::thread::sleep(std::time::Duration::from_millis(5));
        let now = Instant::now();
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let recent = CacheEntry {
            content: "a".to_string(),
            hash: "h1".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 1,
            path: "/a.rs".to_string(),
            last_access: now,
        };
        let old = CacheEntry {
            content: "b".to_string(),
            hash: "h2".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 1,
            path: "/b.rs".to_string(),
            last_access: base,
        };
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &recent), (&key_b, &old)];
        let scores = eviction_scores_rrf(&entries, now);
        let score_a = scores.iter().find(|(p, _)| p == "a.rs").unwrap().1;
        let score_b = scores.iter().find(|(p, _)| p == "b.rs").unwrap().1;
        assert!(
            score_a > score_b,
            "recently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn rrf_eviction_prefers_frequent() {
        let now = Instant::now();
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        // Same recency and size, so only the frequency signal differs.
        let frequent = CacheEntry {
            content: "a".to_string(),
            hash: "h1".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 20,
            path: "/a.rs".to_string(),
            last_access: now,
        };
        let rare = CacheEntry {
            content: "b".to_string(),
            hash: "h2".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 1,
            path: "/b.rs".to_string(),
            last_access: now,
        };
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &frequent), (&key_b, &rare)];
        let scores = eviction_scores_rrf(&entries, now);
        let score_a = scores.iter().find(|(p, _)| p == "a.rs").unwrap().1;
        let score_b = scores.iter().find(|(p, _)| p == "b.rs").unwrap().1;
        assert!(
            score_a > score_b,
            "frequently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        // NOTE(review): mutating the process environment can race with other
        // tests running in parallel that also call max_cache_tokens() (their
        // stores are tiny, so in practice the 50-token budget never triggers
        // eviction for them) — TODO consider serializing env-dependent tests.
        std::env::set_var("LEAN_CTX_CACHE_MAX_TOKENS", "50");
        let mut cache = SessionCache::new();
        let big_content = "a]".repeat(30); // roughly 30 tokens — depends on count_tokens
        cache.store("/old.rs", big_content);
        // /old.rs is now cached at roughly 30 tokens

        let new_content = "b ".repeat(30); // roughly 30 tokens incoming
        cache.store("/new.rs", new_content);
        // Storing both would total ~60 tokens, exceeding the 50 budget, so
        // eviction should have removed the lower-scoring entry.

        // Token counts are approximate, so assert a deliberately loose bound
        // (<= 60, i.e. less than both files' combined weight) rather than the
        // exact 50 budget.
        assert!(
            cache.total_cached_tokens() <= 60,
            "eviction should have kicked in"
        );
        std::env::remove_var("LEAN_CTX_CACHE_MAX_TOKENS");
    }
}