Skip to main content

lean_ctx/core/
cache.rs

1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::{Instant, SystemTime};
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8    crate::hooks::normalize_tool_path(path)
9}
10
/// Returns the token budget for the cache.
///
/// Reads `LEAN_CTX_CACHE_MAX_TOKENS` on every call. When the variable is
/// unset, is not valid Unicode, or does not parse as a `usize`, the default
/// of 500 000 tokens is returned.
fn max_cache_tokens() -> usize {
    const DEFAULT_BUDGET: usize = 500_000;

    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_BUDGET),
        Err(_) => DEFAULT_BUDGET,
    }
}
17
/// One cached file read plus the bookkeeping used for hit detection and eviction.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Text of the file as it was last stored.
    pub content: String,
    /// Digest of `content`; compared on re-store to detect unchanged content.
    pub hash: String,
    /// Number of lines in `content`.
    pub line_count: usize,
    /// Token count of `content` at the time it was stored.
    pub original_tokens: usize,
    /// How many times this entry has been stored or served as a hit.
    pub read_count: u32,
    /// Normalized path under which the entry is keyed.
    pub path: String,
    /// Moment of the most recent store or recorded hit.
    pub last_access: Instant,
    /// File modification time observed when the entry was stored, if it could be read.
    pub stored_mtime: Option<SystemTime>,
}
30
/// Outcome of a cache store: the entry's size figures and whether the store was a hit.
#[derive(Clone, Debug)]
pub struct StoreResult {
    /// Line count of the stored content.
    pub line_count: usize,
    /// Token count of the stored content.
    pub original_tokens: usize,
    /// The entry's read counter after this store.
    pub read_count: u32,
    /// `true` when the content matched what was already cached for the path.
    pub was_hit: bool,
}
39
40impl CacheEntry {
41    /// Computes a legacy eviction score blending recency, frequency, and size.
42    pub fn eviction_score_legacy(&self, now: Instant) -> f64 {
43        let elapsed = now
44            .checked_duration_since(self.last_access)
45            .unwrap_or_default()
46            .as_secs_f64();
47        let recency = 1.0 / (1.0 + elapsed.sqrt());
48        let frequency = (self.read_count as f64 + 1.0).ln();
49        let size_value = (self.original_tokens as f64 + 1.0).ln();
50        recency * 0.4 + frequency * 0.3 + size_value * 0.3
51    }
52}
53
54const RRF_K: f64 = 60.0;
55
56/// Compute Reciprocal Rank Fusion eviction scores for a batch of cache entries.
57/// Each signal (recency, frequency, size) produces an independent ranking.
58/// The final score is the sum of `1/(k + rank)` across all signals.
59/// Higher score = more valuable = keep longer.
60pub fn eviction_scores_rrf(entries: &[(&String, &CacheEntry)], now: Instant) -> Vec<(String, f64)> {
61    if entries.is_empty() {
62        return Vec::new();
63    }
64
65    let n = entries.len();
66
67    let mut recency_order: Vec<usize> = (0..n).collect();
68    recency_order.sort_by(|&a, &b| {
69        let elapsed_a = now
70            .checked_duration_since(entries[a].1.last_access)
71            .unwrap_or_default()
72            .as_secs_f64();
73        let elapsed_b = now
74            .checked_duration_since(entries[b].1.last_access)
75            .unwrap_or_default()
76            .as_secs_f64();
77        elapsed_a
78            .partial_cmp(&elapsed_b)
79            .unwrap_or(std::cmp::Ordering::Equal)
80    });
81
82    let mut frequency_order: Vec<usize> = (0..n).collect();
83    frequency_order.sort_by(|&a, &b| entries[b].1.read_count.cmp(&entries[a].1.read_count));
84
85    let mut size_order: Vec<usize> = (0..n).collect();
86    size_order.sort_by(|&a, &b| {
87        entries[b]
88            .1
89            .original_tokens
90            .cmp(&entries[a].1.original_tokens)
91    });
92
93    let mut recency_ranks = vec![0usize; n];
94    let mut frequency_ranks = vec![0usize; n];
95    let mut size_ranks = vec![0usize; n];
96
97    for (rank, &idx) in recency_order.iter().enumerate() {
98        recency_ranks[idx] = rank;
99    }
100    for (rank, &idx) in frequency_order.iter().enumerate() {
101        frequency_ranks[idx] = rank;
102    }
103    for (rank, &idx) in size_order.iter().enumerate() {
104        size_ranks[idx] = rank;
105    }
106
107    entries
108        .iter()
109        .enumerate()
110        .map(|(i, (path, _))| {
111            let score = 1.0 / (RRF_K + recency_ranks[i] as f64)
112                + 1.0 / (RRF_K + frequency_ranks[i] as f64)
113                + 1.0 / (RRF_K + size_ranks[i] as f64);
114            ((*path).clone(), score)
115        })
116        .collect()
117}
118
/// Running totals for the cache: read counts, hits, and token accounting.
#[derive(Debug)]
pub struct CacheStats {
    /// Number of stores and recorded hits.
    pub total_reads: u64,
    /// Number of those reads that were served as cache hits.
    pub cache_hits: u64,
    /// Sum of the original token counts of every read.
    pub total_original_tokens: u64,
    /// Sum of the tokens counted as sent: full content on a miss, the short hit message on a hit.
    pub total_sent_tokens: u64,
    /// Counter incremented each time a new entry is inserted into the cache.
    pub files_tracked: usize,
}
128
129impl CacheStats {
130    /// Returns the cache hit rate as a percentage (0–100).
131    pub fn hit_rate(&self) -> f64 {
132        if self.total_reads == 0 {
133            return 0.0;
134        }
135        (self.cache_hits as f64 / self.total_reads as f64) * 100.0
136    }
137
138    /// Returns the total number of tokens saved by cache hits.
139    pub fn tokens_saved(&self) -> u64 {
140        self.total_original_tokens
141            .saturating_sub(self.total_sent_tokens)
142    }
143
144    /// Returns the savings as a percentage of total original tokens.
145    pub fn savings_percent(&self) -> f64 {
146        if self.total_original_tokens == 0 {
147            return 0.0;
148        }
149        (self.tokens_saved() as f64 / self.total_original_tokens as f64) * 100.0
150    }
151}
152
/// A run of text that appears in several files, attributed to one canonical file.
#[derive(Debug, Clone)]
pub struct SharedBlock {
    /// Path of the file treated as the block's canonical source.
    pub canonical_path: String,
    /// Reference label written into the replacement marker for this block.
    pub canonical_ref: String,
    /// First line number written into the replacement marker.
    pub start_line: usize,
    /// Last line number written into the replacement marker.
    pub end_line: usize,
    /// The exact text that is searched for and replaced in other files.
    pub content: String,
}
162
163/// In-memory file cache with RRF-based eviction, file references, and cross-file dedup.
164pub struct SessionCache {
165    entries: HashMap<String, CacheEntry>,
166    file_refs: HashMap<String, String>,
167    next_ref: usize,
168    stats: CacheStats,
169    shared_blocks: Vec<SharedBlock>,
170}
171
172impl Default for SessionCache {
173    fn default() -> Self {
174        Self::new()
175    }
176}
177
178impl SessionCache {
179    /// Creates an empty session cache with default stats.
180    pub fn new() -> Self {
181        Self {
182            entries: HashMap::new(),
183            file_refs: HashMap::new(),
184            next_ref: 1,
185            shared_blocks: Vec::new(),
186            stats: CacheStats {
187                total_reads: 0,
188                cache_hits: 0,
189                total_original_tokens: 0,
190                total_sent_tokens: 0,
191                files_tracked: 0,
192            },
193        }
194    }
195
196    /// Returns or assigns a short file reference label (F1, F2, ...) for the given path.
197    pub fn get_file_ref(&mut self, path: &str) -> String {
198        let key = normalize_key(path);
199        if let Some(r) = self.file_refs.get(&key) {
200            return r.clone();
201        }
202        let r = format!("F{}", self.next_ref);
203        self.next_ref += 1;
204        self.file_refs.insert(key, r.clone());
205        r
206    }
207
208    /// Returns the file reference label for a path without assigning a new one.
209    pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
210        self.file_refs.get(&normalize_key(path)).cloned()
211    }
212
213    /// Looks up a cached entry by file path.
214    pub fn get(&self, path: &str) -> Option<&CacheEntry> {
215        self.entries.get(&normalize_key(path))
216    }
217
218    /// Records a cache hit, updates access stats, and emits a cache-hit event.
219    pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
220        let key = normalize_key(path);
221        let ref_label = self
222            .file_refs
223            .get(&key)
224            .cloned()
225            .unwrap_or_else(|| "F?".to_string());
226        if let Some(entry) = self.entries.get_mut(&key) {
227            entry.read_count += 1;
228            entry.last_access = Instant::now();
229            self.stats.total_reads += 1;
230            self.stats.cache_hits += 1;
231            self.stats.total_original_tokens += entry.original_tokens as u64;
232            let hit_msg = format!(
233                "{ref_label} cached {}t {}L",
234                entry.read_count, entry.line_count
235            );
236            self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
237            crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
238            Some(entry)
239        } else {
240            None
241        }
242    }
243
244    /// Stores file content in the cache; returns a hit if content hash matches.
245    pub fn store(&mut self, path: &str, content: String) -> StoreResult {
246        let key = normalize_key(path);
247        let hash = compute_md5(&content);
248        let line_count = content.lines().count();
249        let original_tokens = count_tokens(&content);
250        let stored_mtime = std::fs::metadata(path).and_then(|m| m.modified()).ok();
251        let now = Instant::now();
252
253        self.stats.total_reads += 1;
254        self.stats.total_original_tokens += original_tokens as u64;
255
256        if let Some(existing) = self.entries.get_mut(&key) {
257            existing.last_access = now;
258            if stored_mtime.is_some() {
259                existing.stored_mtime = stored_mtime;
260            }
261            if existing.hash == hash {
262                existing.read_count += 1;
263                self.stats.cache_hits += 1;
264                let hit_msg = format!(
265                    "{} cached {}t {}L",
266                    self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
267                    existing.read_count,
268                    existing.line_count,
269                );
270                self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
271                return StoreResult {
272                    line_count: existing.line_count,
273                    original_tokens: existing.original_tokens,
274                    read_count: existing.read_count,
275                    was_hit: true,
276                };
277            }
278            existing.content = content;
279            existing.hash = hash;
280            existing.line_count = line_count;
281            existing.original_tokens = original_tokens;
282            existing.read_count += 1;
283            if stored_mtime.is_some() {
284                existing.stored_mtime = stored_mtime;
285            }
286            self.stats.total_sent_tokens += original_tokens as u64;
287            return StoreResult {
288                line_count,
289                original_tokens,
290                read_count: existing.read_count,
291                was_hit: false,
292            };
293        }
294
295        self.evict_if_needed(original_tokens);
296        self.get_file_ref(&key);
297
298        let entry = CacheEntry {
299            content,
300            hash,
301            line_count,
302            original_tokens,
303            read_count: 1,
304            path: key.clone(),
305            last_access: now,
306            stored_mtime,
307        };
308
309        self.entries.insert(key, entry);
310        self.stats.files_tracked += 1;
311        self.stats.total_sent_tokens += original_tokens as u64;
312        StoreResult {
313            line_count,
314            original_tokens,
315            read_count: 1,
316            was_hit: false,
317        }
318    }
319
320    /// Returns the sum of original token counts across all cached entries.
321    pub fn total_cached_tokens(&self) -> usize {
322        self.entries.values().map(|e| e.original_tokens).sum()
323    }
324
325    /// Evict lowest-scoring entries (by RRF) until cache fits within token budget.
326    pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
327        let max_tokens = max_cache_tokens();
328        let current = self.total_cached_tokens();
329        if current + incoming_tokens <= max_tokens {
330            return;
331        }
332
333        let now = Instant::now();
334        let all_entries: Vec<(&String, &CacheEntry)> = self.entries.iter().collect();
335        let mut scored = eviction_scores_rrf(&all_entries, now);
336        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
337
338        let mut freed = 0usize;
339        let target = (current + incoming_tokens).saturating_sub(max_tokens);
340        for (path, _score) in &scored {
341            if freed >= target {
342                break;
343            }
344            if let Some(entry) = self.entries.remove(path) {
345                freed += entry.original_tokens;
346                self.file_refs.remove(path);
347            }
348        }
349    }
350
351    /// Returns all cached entries as (path, entry) pairs.
352    pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
353        self.entries.iter().collect()
354    }
355
356    /// Returns a reference to the aggregated cache statistics.
357    pub fn get_stats(&self) -> &CacheStats {
358        &self.stats
359    }
360
361    /// Returns the path-to-file-ref mapping (e.g. "/src/main.rs" → "F1").
362    pub fn file_ref_map(&self) -> &HashMap<String, String> {
363        &self.file_refs
364    }
365
366    /// Replaces the cross-file shared blocks used for deduplication.
367    pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
368        self.shared_blocks = blocks;
369    }
370
371    /// Returns the current set of cross-file shared blocks.
372    pub fn get_shared_blocks(&self) -> &[SharedBlock] {
373        &self.shared_blocks
374    }
375
376    /// Replace shared blocks in content with cross-file references.
377    pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
378        if self.shared_blocks.is_empty() {
379            return None;
380        }
381        let refs: Vec<&SharedBlock> = self
382            .shared_blocks
383            .iter()
384            .filter(|b| b.canonical_path != path && content.contains(&b.content))
385            .collect();
386        if refs.is_empty() {
387            return None;
388        }
389        let mut result = content.to_string();
390        for block in refs {
391            result = result.replacen(
392                &block.content,
393                &format!(
394                    "[= {}:{}-{}]",
395                    block.canonical_ref, block.start_line, block.end_line
396                ),
397                1,
398            );
399        }
400        Some(result)
401    }
402
403    /// Removes a file from the cache, forcing a fresh read on next access.
404    pub fn invalidate(&mut self, path: &str) -> bool {
405        self.entries.remove(&normalize_key(path)).is_some()
406    }
407
408    /// Clears all cached entries, file refs, and resets stats. Returns the number of entries removed.
409    pub fn clear(&mut self) -> usize {
410        let count = self.entries.len();
411        self.entries.clear();
412        self.file_refs.clear();
413        self.shared_blocks.clear();
414        self.next_ref = 1;
415        self.stats = CacheStats {
416            total_reads: 0,
417            cache_hits: 0,
418            total_original_tokens: 0,
419            total_sent_tokens: 0,
420            files_tracked: 0,
421        };
422        count
423    }
424}
425
/// Returns the modification time of the file at `path`, or `None` when the
/// metadata or the timestamp cannot be read.
pub fn file_mtime(path: &str) -> Option<SystemTime> {
    let metadata = std::fs::metadata(path).ok()?;
    metadata.modified().ok()
}

/// Decides whether a cache entry recorded with `cached_mtime` is out of date
/// with respect to the file currently at `path`.
///
/// * Current mtime unreadable (e.g. the file is gone): not stale — there is
///   nothing newer to read.
/// * No cached mtime but the file has one: stale.
/// * Both known: stale whenever they differ. A file whose mtime moved
///   backwards (restored from a backup, checked out from version control,
///   clock adjusted) has changed just as much as one whose mtime moved
///   forwards.
pub fn is_cache_entry_stale(path: &str, cached_mtime: Option<SystemTime>) -> bool {
    match (cached_mtime, file_mtime(path)) {
        (_, None) => false,
        (None, Some(_)) => true,
        (Some(cached), Some(current)) => current != cached,
    }
}
438
439fn compute_md5(content: &str) -> String {
440    let mut hasher = Md5::new();
441    hasher.update(content.as_bytes());
442    format!("{:x}", hasher.finalize())
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a small entry whose only varying signals are the read count and access time.
    fn entry(path: &str, read_count: u32, last_access: Instant) -> CacheEntry {
        CacheEntry {
            content: "x".to_string(),
            hash: "h".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count,
            path: path.to_string(),
            last_access,
            stored_mtime: None,
        }
    }

    /// Looks up the score computed for `key`, panicking if it is missing.
    fn score_for(scores: &[(String, f64)], key: &str) -> f64 {
        scores
            .iter()
            .find(|(path, _)| path.as_str() == key)
            .map(|(_, score)| *score)
            .expect("a score for every input key")
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let result = cache.store("/test/file.rs", "fn main() {}".to_string());
        assert!(!result.was_hit);
        assert_eq!(result.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content".to_string());
        let result = cache.store("/test/file.rs", "content".to_string());
        assert!(result.was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content".to_string());
        let result = cache.store("/test/file.rs", "new content".to_string());
        assert!(!result.was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        assert_eq!(cache.get_file_ref("/a.rs"), "F1"); // stable
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.store("/b.rs", "b".to_string());
        let count = cache.clear();
        assert_eq!(count, 2);
        assert!(cache.get("/a.rs").is_none());
        assert_eq!(cache.get_file_ref("/c.rs"), "F1"); // refs reset
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test".to_string());
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello".to_string());
        cache.store("/a.rs", "hello".to_string()); // hit
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let h1 = compute_md5("test content");
        let h2 = compute_md5("test content");
        assert_eq!(h1, h2);
        assert_ne!(h1, compute_md5("different"));
    }

    #[test]
    fn rrf_eviction_prefers_recent() {
        // Derive the two instants arithmetically; no need to sleep.
        let base = Instant::now();
        let now = base + Duration::from_millis(5);
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let recent = entry("/a.rs", 1, now);
        let old = entry("/b.rs", 1, base);

        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &recent), (&key_b, &old)];
        let scores = eviction_scores_rrf(&entries, now);
        assert!(
            score_for(&scores, "a.rs") > score_for(&scores, "b.rs"),
            "recently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn rrf_eviction_prefers_frequent() {
        let now = Instant::now();
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let frequent = entry("/a.rs", 20, now);
        let rare = entry("/b.rs", 1, now);

        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &frequent), (&key_b, &rare)];
        let scores = eviction_scores_rrf(&entries, now);
        assert!(
            score_for(&scores, "a.rs") > score_for(&scores, "b.rs"),
            "frequently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        let mut cache = SessionCache::new();
        let old_tokens = cache
            .store("/old.rs", "a ".repeat(200))
            .original_tokens;

        // Set the budget to exactly what the first file occupies, so any
        // further tokens must evict it regardless of how tokens are counted.
        // NOTE: the variable is process-wide; it is restored right after the
        // store that needs it.
        std::env::set_var("LEAN_CTX_CACHE_MAX_TOKENS", old_tokens.to_string());
        let new_tokens = cache
            .store("/new.rs", "b ".repeat(200))
            .original_tokens;
        std::env::remove_var("LEAN_CTX_CACHE_MAX_TOKENS");

        assert!(new_tokens > 0, "test content should count as at least one token");
        assert!(
            cache.get("/old.rs").is_none(),
            "the older entry should have been evicted to make room"
        );
        assert!(cache.get("/new.rs").is_some(), "the new entry should be cached");
        assert_eq!(cache.total_cached_tokens(), new_tokens);
    }

    #[test]
    fn stale_detection_flags_newer_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("stale.txt");
        let p = path.to_string_lossy().to_string();

        std::fs::write(&path, "one").unwrap();
        let mut cache = SessionCache::new();
        cache.store(&p, "one".to_string());

        let entry = cache.get(&p).unwrap();
        assert!(!is_cache_entry_stale(&p, entry.stored_mtime));

        // Ensure mtime granularity differences don't make this flaky.
        std::thread::sleep(Duration::from_secs(1));
        std::fs::write(&path, "two").unwrap();

        let entry = cache.get(&p).unwrap();
        assert!(is_cache_entry_stale(&p, entry.stored_mtime));
    }
}
634}