Skip to main content

lean_ctx/core/
cache.rs

1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::Instant;
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8    crate::hooks::normalize_tool_path(path)
9}
10
/// Token budget for the whole cache.
///
/// Read from the `LEAN_CTX_CACHE_MAX_TOKENS` environment variable on every call.
/// An unset variable, or one that does not parse as `usize`, yields 500,000.
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;
    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_MAX_TOKENS),
        Err(_) => DEFAULT_MAX_TOKENS,
    }
}
17
/// One cached file read.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct CacheEntry {
    /// Full content as last stored.
    pub content: String,
    /// Lowercase hex MD5 of `content`; used to detect unchanged re-reads.
    pub hash: String,
    /// Number of lines in `content` (`str::lines` semantics).
    pub line_count: usize,
    /// Token count of `content`.
    pub original_tokens: usize,
    /// How many times this entry has been read (the initial store counts as one).
    pub read_count: u32,
    /// Normalized path this entry is keyed under.
    pub path: String,
    /// When the entry was last stored or read.
    pub last_access: Instant,
}

impl CacheEntry {
    /// Boltzmann-inspired retention score: higher means more valuable, so the entry
    /// is evicted later.
    ///
    /// `E = 0.4·recency + 0.3·frequency + 0.3·size_value`, with
    /// - `recency    = 1 / (1 + sqrt(seconds since last_access))`
    /// - `frequency  = ln(read_count + 1)`
    /// - `size_value = ln(original_tokens + 1)`
    pub fn eviction_score(&self, now: Instant) -> f64 {
        const RECENCY_WEIGHT: f64 = 0.4;
        const FREQUENCY_WEIGHT: f64 = 0.3;
        const SIZE_WEIGHT: f64 = 0.3;

        let idle_secs = now.duration_since(self.last_access).as_secs_f64();
        let recency = (1.0 + idle_secs.sqrt()).recip();
        let frequency = (f64::from(self.read_count) + 1.0).ln();
        let size_value = (self.original_tokens as f64 + 1.0).ln();

        RECENCY_WEIGHT * recency + FREQUENCY_WEIGHT * frequency + SIZE_WEIGHT * size_value
    }
}
41
/// Running counters describing how well the session cache is doing.
#[derive(Debug)]
pub struct CacheStats {
    /// Every recorded read, hit or miss.
    pub total_reads: u64,
    /// Reads answered from cache.
    pub cache_hits: u64,
    /// Tokens the reads would have cost without caching.
    pub total_original_tokens: u64,
    /// Tokens actually sent (full content on a miss, a short notice on a hit).
    pub total_sent_tokens: u64,
    /// Number of files tracked by the cache.
    pub files_tracked: usize,
}

#[allow(dead_code)]
impl CacheStats {
    /// Share of reads served from cache, as a percentage in `0.0..=100.0`.
    /// Returns `0.0` before any read has been recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.total_reads {
            0 => 0.0,
            reads => self.cache_hits as f64 / reads as f64 * 100.0,
        }
    }

    /// Tokens avoided by caching; never negative (clamped at zero).
    pub fn tokens_saved(&self) -> u64 {
        if self.total_sent_tokens >= self.total_original_tokens {
            0
        } else {
            self.total_original_tokens - self.total_sent_tokens
        }
    }

    /// [`Self::tokens_saved`] as a percentage of the original token volume.
    /// Returns `0.0` when no tokens have been recorded.
    pub fn savings_percent(&self) -> f64 {
        match self.total_original_tokens {
            0 => 0.0,
            original => self.tokens_saved() as f64 / original as f64 * 100.0,
        }
    }
}
72
/// A run of lines that occurs in several files, recorded once against the file it
/// canonically belongs to.
///
/// Other files may replace an identical copy of `content` with a reference of the
/// form `[= {canonical_ref}:{start_line}-{end_line}]`.
#[derive(Debug, Clone)]
pub struct SharedBlock {
    /// Path of the file that owns the canonical copy.
    pub canonical_path: String,
    /// Short label of that file used in references (e.g. `F3`).
    pub canonical_ref: String,
    /// First line of the block in the canonical file.
    pub start_line: usize,
    /// Last line of the block in the canonical file.
    pub end_line: usize,
    /// Exact text of the block.
    pub content: String,
}
82
83pub struct SessionCache {
84    entries: HashMap<String, CacheEntry>,
85    file_refs: HashMap<String, String>,
86    next_ref: usize,
87    stats: CacheStats,
88    shared_blocks: Vec<SharedBlock>,
89}
90
91impl Default for SessionCache {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl SessionCache {
98    pub fn new() -> Self {
99        Self {
100            entries: HashMap::new(),
101            file_refs: HashMap::new(),
102            next_ref: 1,
103            shared_blocks: Vec::new(),
104            stats: CacheStats {
105                total_reads: 0,
106                cache_hits: 0,
107                total_original_tokens: 0,
108                total_sent_tokens: 0,
109                files_tracked: 0,
110            },
111        }
112    }
113
114    pub fn get_file_ref(&mut self, path: &str) -> String {
115        let key = normalize_key(path);
116        if let Some(r) = self.file_refs.get(&key) {
117            return r.clone();
118        }
119        let r = format!("F{}", self.next_ref);
120        self.next_ref += 1;
121        self.file_refs.insert(key, r.clone());
122        r
123    }
124
125    pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
126        self.file_refs.get(&normalize_key(path)).cloned()
127    }
128
129    pub fn get(&self, path: &str) -> Option<&CacheEntry> {
130        self.entries.get(&normalize_key(path))
131    }
132
133    pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
134        let key = normalize_key(path);
135        let ref_label = self
136            .file_refs
137            .get(&key)
138            .cloned()
139            .unwrap_or_else(|| "F?".to_string());
140        if let Some(entry) = self.entries.get_mut(&key) {
141            entry.read_count += 1;
142            entry.last_access = Instant::now();
143            self.stats.total_reads += 1;
144            self.stats.cache_hits += 1;
145            self.stats.total_original_tokens += entry.original_tokens as u64;
146            let hit_msg = format!(
147                "{ref_label} cached {}t {}L",
148                entry.read_count, entry.line_count
149            );
150            self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
151            crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
152            Some(entry)
153        } else {
154            None
155        }
156    }
157
158    pub fn store(&mut self, path: &str, content: String) -> (CacheEntry, bool) {
159        let key = normalize_key(path);
160        let hash = compute_md5(&content);
161        let line_count = content.lines().count();
162        let original_tokens = count_tokens(&content);
163        let now = Instant::now();
164
165        self.stats.total_reads += 1;
166        self.stats.total_original_tokens += original_tokens as u64;
167
168        if let Some(existing) = self.entries.get_mut(&key) {
169            existing.last_access = now;
170            if existing.hash == hash {
171                existing.read_count += 1;
172                self.stats.cache_hits += 1;
173                let hit_msg = format!(
174                    "{} cached {}t {}L",
175                    self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
176                    existing.read_count,
177                    existing.line_count,
178                );
179                let sent = count_tokens(&hit_msg) as u64;
180                self.stats.total_sent_tokens += sent;
181                return (existing.clone(), true);
182            }
183            existing.content = content;
184            existing.hash = hash.clone();
185            existing.line_count = line_count;
186            existing.original_tokens = original_tokens;
187            existing.read_count += 1;
188            self.stats.total_sent_tokens += original_tokens as u64;
189            return (existing.clone(), false);
190        }
191
192        self.evict_if_needed(original_tokens);
193        self.get_file_ref(&key);
194
195        let entry = CacheEntry {
196            content,
197            hash,
198            line_count,
199            original_tokens,
200            read_count: 1,
201            path: key.clone(),
202            last_access: now,
203        };
204
205        self.entries.insert(key, entry.clone());
206        self.stats.files_tracked += 1;
207        self.stats.total_sent_tokens += original_tokens as u64;
208        (entry, false)
209    }
210
211    pub fn total_cached_tokens(&self) -> usize {
212        self.entries.values().map(|e| e.original_tokens).sum()
213    }
214
215    /// Evict lowest-scoring entries until cache fits within token budget.
216    pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
217        let max_tokens = max_cache_tokens();
218        let current = self.total_cached_tokens();
219        if current + incoming_tokens <= max_tokens {
220            return;
221        }
222
223        let now = Instant::now();
224        let mut scored: Vec<(String, f64)> = self
225            .entries
226            .iter()
227            .map(|(path, entry)| (path.clone(), entry.eviction_score(now)))
228            .collect();
229        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
230
231        let mut freed = 0usize;
232        let target = (current + incoming_tokens).saturating_sub(max_tokens);
233        for (path, _score) in &scored {
234            if freed >= target {
235                break;
236            }
237            if let Some(entry) = self.entries.remove(path) {
238                freed += entry.original_tokens;
239                self.file_refs.remove(path);
240            }
241        }
242    }
243
244    pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
245        self.entries.iter().collect()
246    }
247
248    pub fn get_stats(&self) -> &CacheStats {
249        &self.stats
250    }
251
252    pub fn file_ref_map(&self) -> &HashMap<String, String> {
253        &self.file_refs
254    }
255
256    #[allow(dead_code)]
257    pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
258        self.shared_blocks = blocks;
259    }
260
261    #[allow(dead_code)]
262    pub fn get_shared_blocks(&self) -> &[SharedBlock] {
263        &self.shared_blocks
264    }
265
266    /// Replace shared blocks in content with cross-file references.
267    #[allow(dead_code)]
268    pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
269        if self.shared_blocks.is_empty() {
270            return None;
271        }
272        let refs: Vec<&SharedBlock> = self
273            .shared_blocks
274            .iter()
275            .filter(|b| b.canonical_path != path && content.contains(&b.content))
276            .collect();
277        if refs.is_empty() {
278            return None;
279        }
280        let mut result = content.to_string();
281        for block in refs {
282            result = result.replacen(
283                &block.content,
284                &format!(
285                    "[= {}:{}-{}]",
286                    block.canonical_ref, block.start_line, block.end_line
287                ),
288                1,
289            );
290        }
291        Some(result)
292    }
293
294    pub fn invalidate(&mut self, path: &str) -> bool {
295        self.entries.remove(&normalize_key(path)).is_some()
296    }
297
298    pub fn clear(&mut self) -> usize {
299        let count = self.entries.len();
300        self.entries.clear();
301        self.file_refs.clear();
302        self.shared_blocks.clear();
303        self.next_ref = 1;
304        self.stats = CacheStats {
305            total_reads: 0,
306            cache_hits: 0,
307            total_original_tokens: 0,
308            total_sent_tokens: 0,
309            files_tracked: 0,
310        };
311        count
312    }
313}
314
315fn compute_md5(content: &str) -> String {
316    let mut hasher = Md5::new();
317    hasher.update(content.as_bytes());
318    format!("{:x}", hasher.finalize())
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let (entry, was_hit) = cache.store("/test/file.rs", "fn main() {}".to_string());
        assert!(!was_hit);
        assert_eq!(entry.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content".to_string());
        let (_, was_hit) = cache.store("/test/file.rs", "content".to_string());
        assert!(was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content".to_string());
        let (_, was_hit) = cache.store("/test/file.rs", "new content".to_string());
        assert!(!was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        assert_eq!(cache.get_file_ref("/a.rs"), "F1"); // stable
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.store("/b.rs", "b".to_string());
        let count = cache.clear();
        assert_eq!(count, 2);
        assert!(cache.get("/a.rs").is_none());
        assert_eq!(cache.get_file_ref("/c.rs"), "F1"); // refs reset
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test".to_string());
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello".to_string());
        cache.store("/a.rs", "hello".to_string()); // hit
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let h1 = compute_md5("test content");
        let h2 = compute_md5("test content");
        assert_eq!(h1, h2);
        assert_ne!(h1, compute_md5("different"));
    }

    #[test]
    fn eviction_score_prefers_recent() {
        // Step forward from `earlier` rather than subtracting from `now`:
        // `Instant - Duration` can panic when it would precede the clock's origin.
        let earlier = Instant::now();
        let now = earlier + Duration::from_secs(300);
        let recent = CacheEntry {
            content: "a".to_string(),
            hash: "h1".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 1,
            path: "/a.rs".to_string(),
            last_access: now,
        };
        let old = CacheEntry {
            content: "b".to_string(),
            hash: "h2".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 1,
            path: "/b.rs".to_string(),
            last_access: earlier,
        };
        assert!(
            recent.eviction_score(now) > old.eviction_score(now),
            "recently accessed entries should score higher"
        );
    }

    #[test]
    fn eviction_score_prefers_frequent() {
        let now = Instant::now();
        let frequent = CacheEntry {
            content: "a".to_string(),
            hash: "h1".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 20,
            path: "/a.rs".to_string(),
            last_access: now,
        };
        let rare = CacheEntry {
            content: "b".to_string(),
            hash: "h2".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 1,
            path: "/b.rs".to_string(),
            last_access: now,
        };
        assert!(
            frequent.eviction_score(now) > rare.eviction_score(now),
            "frequently accessed entries should score higher"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        let old_content = "a ".repeat(30);
        let new_content = "b ".repeat(30);
        let old_tokens = count_tokens(&old_content);
        let new_tokens = count_tokens(&new_content);
        assert!(old_tokens > 0 && new_tokens > 0);
        // A budget that fits either file alone but not both, whatever the tokenizer.
        let budget = old_tokens + new_tokens - 1;

        std::env::set_var("LEAN_CTX_CACHE_MAX_TOKENS", budget.to_string());
        let mut cache = SessionCache::new();
        cache.store("/old.rs", old_content);
        cache.store("/new.rs", new_content);
        let old_present = cache.get("/old.rs").is_some();
        let new_present = cache.get("/new.rs").is_some();
        let total = cache.total_cached_tokens();
        // Restore the environment before asserting so a failure cannot leak the budget.
        std::env::remove_var("LEAN_CTX_CACHE_MAX_TOKENS");

        assert!(!old_present, "/old.rs was the only candidate and must be evicted");
        assert!(new_present, "the incoming entry is always inserted");
        assert!(total <= budget, "cache must fit the budget after eviction");
        assert_eq!(total, new_tokens);
    }

    #[test]
    fn apply_dedup_replaces_shared_block() {
        let mut cache = SessionCache::new();
        cache.set_shared_blocks(vec![SharedBlock {
            canonical_path: "/lib.rs".to_string(),
            canonical_ref: "F1".to_string(),
            start_line: 3,
            end_line: 5,
            content: "shared()".to_string(),
        }]);
        assert_eq!(
            cache.apply_dedup("/main.rs", "a\nshared()\nb").as_deref(),
            Some("a\n[= F1:3-5]\nb")
        );
        // The canonical file itself is never rewritten.
        assert_eq!(cache.apply_dedup("/lib.rs", "shared()"), None);
        assert_eq!(cache.apply_dedup("/main.rs", "nothing here"), None);
    }
}