Skip to main content

lean_ctx/core/
cache.rs

1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::Instant;
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8    crate::hooks::normalize_tool_path(path)
9}
10
/// Returns the cache's token budget.
///
/// The value is read from the `LEAN_CTX_CACHE_MAX_TOKENS` environment variable on every call.
/// When the variable is unset, not valid Unicode, or does not parse as a `usize`, the built-in
/// default of 500 000 tokens is used instead.
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;

    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_MAX_TOKENS),
        Err(_) => DEFAULT_MAX_TOKENS,
    }
}
17
/// A single cached file together with the bookkeeping used for hit detection and eviction.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    /// Full text of the file as last stored.
    pub content: String,
    /// Digest of `content`; compared on re-store to decide whether the file changed.
    pub hash: String,
    /// Number of lines in `content`.
    pub line_count: usize,
    /// Token count of `content` at the time it was stored.
    pub original_tokens: usize,
    /// How many times this entry has been stored or served as a hit.
    pub read_count: u32,
    /// Normalized path this entry is keyed under.
    pub path: String,
    /// Moment of the most recent store or hit.
    pub last_access: Instant,
}

impl CacheEntry {
    /// Boltzmann-inspired eviction score. Higher = more valuable = keep longer.
    ///
    /// `E = α·recency + β·frequency + γ·size_value`, where
    /// * `recency   = 1 / (1 + sqrt(seconds since last access))`,
    /// * `frequency = ln(read_count + 1)`,
    /// * `size_value = ln(original_tokens + 1)`.
    ///
    /// `now` is the reference instant the entry's age is measured against. If `now` is
    /// earlier than `last_access`, the age is treated as zero rather than relying on the
    /// version-dependent behavior of `Instant::duration_since`.
    pub fn eviction_score(&self, now: Instant) -> f64 {
        const RECENCY_WEIGHT: f64 = 0.4;
        const FREQUENCY_WEIGHT: f64 = 0.3;
        const SIZE_WEIGHT: f64 = 0.3;

        // Saturating: a `last_access` that lies after `now` must never panic.
        let elapsed = now
            .saturating_duration_since(self.last_access)
            .as_secs_f64();
        let recency = 1.0 / (1.0 + elapsed.sqrt());
        let frequency = (self.read_count as f64 + 1.0).ln();
        let size_value = (self.original_tokens as f64 + 1.0).ln();

        recency * RECENCY_WEIGHT + frequency * FREQUENCY_WEIGHT + size_value * SIZE_WEIGHT
    }
}
40
/// Running counters describing how the session cache has been used.
#[derive(Debug)]
pub struct CacheStats {
    /// Number of store / hit operations recorded.
    pub total_reads: u64,
    /// Number of those operations that were served as cache hits.
    pub cache_hits: u64,
    /// Sum of the full token counts of everything that was read.
    pub total_original_tokens: u64,
    /// Sum of the token counts that were actually accounted as sent.
    pub total_sent_tokens: u64,
    /// Count of files the cache considers tracked.
    pub files_tracked: usize,
}

impl CacheStats {
    /// `part / whole` expressed as a percentage; `0.0` when `whole` is zero.
    fn percent_of(part: u64, whole: u64) -> f64 {
        match whole {
            0 => 0.0,
            _ => (part as f64 / whole as f64) * 100.0,
        }
    }

    /// Percentage of reads that were cache hits (0.0 when nothing has been read).
    pub fn hit_rate(&self) -> f64 {
        Self::percent_of(self.cache_hits, self.total_reads)
    }

    /// Tokens avoided: original minus sent, clamped at zero.
    pub fn tokens_saved(&self) -> u64 {
        let original = self.total_original_tokens;
        original.saturating_sub(self.total_sent_tokens)
    }

    /// Saved tokens as a percentage of the original total (0.0 when the total is zero).
    pub fn savings_percent(&self) -> f64 {
        Self::percent_of(self.tokens_saved(), self.total_original_tokens)
    }
}
70
/// A run of text that appears in several files, described by the one file treated as its
/// canonical source.
#[derive(Clone, Debug)]
pub struct SharedBlock {
    /// Path of the canonical file; the block is never replaced inside this file.
    pub canonical_path: String,
    /// Label for the canonical file that is written into the replacement marker.
    pub canonical_ref: String,
    /// First line of the block, as written into the replacement marker.
    pub start_line: usize,
    /// Last line of the block, as written into the replacement marker.
    pub end_line: usize,
    /// Exact text of the block; matched verbatim against other files' content.
    pub content: String,
}
80
81pub struct SessionCache {
82    entries: HashMap<String, CacheEntry>,
83    file_refs: HashMap<String, String>,
84    next_ref: usize,
85    stats: CacheStats,
86    shared_blocks: Vec<SharedBlock>,
87}
88
89impl Default for SessionCache {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl SessionCache {
96    pub fn new() -> Self {
97        Self {
98            entries: HashMap::new(),
99            file_refs: HashMap::new(),
100            next_ref: 1,
101            shared_blocks: Vec::new(),
102            stats: CacheStats {
103                total_reads: 0,
104                cache_hits: 0,
105                total_original_tokens: 0,
106                total_sent_tokens: 0,
107                files_tracked: 0,
108            },
109        }
110    }
111
112    pub fn get_file_ref(&mut self, path: &str) -> String {
113        let key = normalize_key(path);
114        if let Some(r) = self.file_refs.get(&key) {
115            return r.clone();
116        }
117        let r = format!("F{}", self.next_ref);
118        self.next_ref += 1;
119        self.file_refs.insert(key, r.clone());
120        r
121    }
122
123    pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
124        self.file_refs.get(&normalize_key(path)).cloned()
125    }
126
127    pub fn get(&self, path: &str) -> Option<&CacheEntry> {
128        self.entries.get(&normalize_key(path))
129    }
130
131    pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
132        let key = normalize_key(path);
133        let ref_label = self
134            .file_refs
135            .get(&key)
136            .cloned()
137            .unwrap_or_else(|| "F?".to_string());
138        if let Some(entry) = self.entries.get_mut(&key) {
139            entry.read_count += 1;
140            entry.last_access = Instant::now();
141            self.stats.total_reads += 1;
142            self.stats.cache_hits += 1;
143            self.stats.total_original_tokens += entry.original_tokens as u64;
144            let hit_msg = format!(
145                "{ref_label} cached {}t {}L",
146                entry.read_count, entry.line_count
147            );
148            self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
149            crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
150            Some(entry)
151        } else {
152            None
153        }
154    }
155
156    pub fn store(&mut self, path: &str, content: String) -> (CacheEntry, bool) {
157        let key = normalize_key(path);
158        let hash = compute_md5(&content);
159        let line_count = content.lines().count();
160        let original_tokens = count_tokens(&content);
161        let now = Instant::now();
162
163        self.stats.total_reads += 1;
164        self.stats.total_original_tokens += original_tokens as u64;
165
166        if let Some(existing) = self.entries.get_mut(&key) {
167            existing.last_access = now;
168            if existing.hash == hash {
169                existing.read_count += 1;
170                self.stats.cache_hits += 1;
171                let hit_msg = format!(
172                    "{} cached {}t {}L",
173                    self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
174                    existing.read_count,
175                    existing.line_count,
176                );
177                let sent = count_tokens(&hit_msg) as u64;
178                self.stats.total_sent_tokens += sent;
179                return (existing.clone(), true);
180            }
181            existing.content = content;
182            existing.hash = hash.clone();
183            existing.line_count = line_count;
184            existing.original_tokens = original_tokens;
185            existing.read_count += 1;
186            self.stats.total_sent_tokens += original_tokens as u64;
187            return (existing.clone(), false);
188        }
189
190        self.evict_if_needed(original_tokens);
191        self.get_file_ref(&key);
192
193        let entry = CacheEntry {
194            content,
195            hash,
196            line_count,
197            original_tokens,
198            read_count: 1,
199            path: key.clone(),
200            last_access: now,
201        };
202
203        self.entries.insert(key, entry.clone());
204        self.stats.files_tracked += 1;
205        self.stats.total_sent_tokens += original_tokens as u64;
206        (entry, false)
207    }
208
209    pub fn total_cached_tokens(&self) -> usize {
210        self.entries.values().map(|e| e.original_tokens).sum()
211    }
212
213    /// Evict lowest-scoring entries until cache fits within token budget.
214    pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
215        let max_tokens = max_cache_tokens();
216        let current = self.total_cached_tokens();
217        if current + incoming_tokens <= max_tokens {
218            return;
219        }
220
221        let now = Instant::now();
222        let mut scored: Vec<(String, f64)> = self
223            .entries
224            .iter()
225            .map(|(path, entry)| (path.clone(), entry.eviction_score(now)))
226            .collect();
227        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
228
229        let mut freed = 0usize;
230        let target = (current + incoming_tokens).saturating_sub(max_tokens);
231        for (path, _score) in &scored {
232            if freed >= target {
233                break;
234            }
235            if let Some(entry) = self.entries.remove(path) {
236                freed += entry.original_tokens;
237                self.file_refs.remove(path);
238            }
239        }
240    }
241
242    pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
243        self.entries.iter().collect()
244    }
245
246    pub fn get_stats(&self) -> &CacheStats {
247        &self.stats
248    }
249
250    pub fn file_ref_map(&self) -> &HashMap<String, String> {
251        &self.file_refs
252    }
253
254    #[allow(dead_code)]
255    pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
256        self.shared_blocks = blocks;
257    }
258
259    #[allow(dead_code)]
260    pub fn get_shared_blocks(&self) -> &[SharedBlock] {
261        &self.shared_blocks
262    }
263
264    /// Replace shared blocks in content with cross-file references.
265    #[allow(dead_code)]
266    pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
267        if self.shared_blocks.is_empty() {
268            return None;
269        }
270        let refs: Vec<&SharedBlock> = self
271            .shared_blocks
272            .iter()
273            .filter(|b| b.canonical_path != path && content.contains(&b.content))
274            .collect();
275        if refs.is_empty() {
276            return None;
277        }
278        let mut result = content.to_string();
279        for block in refs {
280            result = result.replacen(
281                &block.content,
282                &format!(
283                    "[= {}:{}-{}]",
284                    block.canonical_ref, block.start_line, block.end_line
285                ),
286                1,
287            );
288        }
289        Some(result)
290    }
291
292    pub fn invalidate(&mut self, path: &str) -> bool {
293        self.entries.remove(&normalize_key(path)).is_some()
294    }
295
296    pub fn clear(&mut self) -> usize {
297        let count = self.entries.len();
298        self.entries.clear();
299        self.file_refs.clear();
300        self.shared_blocks.clear();
301        self.next_ref = 1;
302        self.stats = CacheStats {
303            total_reads: 0,
304            cache_hits: 0,
305            total_original_tokens: 0,
306            total_sent_tokens: 0,
307            files_tracked: 0,
308        };
309        count
310    }
311}
312
313fn compute_md5(content: &str) -> String {
314    let mut hasher = Md5::new();
315    hasher.update(content.as_bytes());
316    format!("{:x}", hasher.finalize())
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-line, 10-token entry; only the fields that the scoring tests vary are
    /// parameters.
    fn scoring_entry(path: &str, read_count: u32, last_access: Instant) -> CacheEntry {
        CacheEntry {
            content: "x".to_string(),
            hash: "h".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count,
            path: path.to_string(),
            last_access,
        }
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let (stored, hit) = cache.store("/test/file.rs", String::from("fn main() {}"));
        assert!(!hit);
        assert_eq!(stored.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", String::from("content"));
        let (_, hit) = cache.store("/test/file.rs", String::from("content"));
        assert!(hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", String::from("old content"));
        let (_, hit) = cache.store("/test/file.rs", String::from("new content"));
        assert!(!hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        let first = cache.get_file_ref("/a.rs");
        let second = cache.get_file_ref("/b.rs");
        let first_again = cache.get_file_ref("/a.rs");
        assert_eq!(first, "F1");
        assert_eq!(second, "F2");
        // Asking again for a known path must not issue a new label.
        assert_eq!(first_again, "F1");
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        for (path, body) in [("/a.rs", "a"), ("/b.rs", "b")] {
            cache.store(path, body.to_string());
        }
        let cleared = cache.clear();
        assert_eq!(cleared, 2);
        assert!(cache.get("/a.rs").is_none());
        // Label numbering restarts after a clear.
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", String::from("test"));
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", String::from("hello"));
        // Second store of identical content is the hit.
        cache.store("/a.rs", String::from("hello"));
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let first = compute_md5("test content");
        let second = compute_md5("test content");
        assert_eq!(first, second);
        assert_ne!(first, compute_md5("different"));
    }

    #[test]
    fn eviction_score_prefers_recent() {
        let now = Instant::now();
        let five_minutes_ago = now - std::time::Duration::from_secs(300);
        let recent = scoring_entry("/a.rs", 1, now);
        let old = scoring_entry("/b.rs", 1, five_minutes_ago);
        assert!(
            recent.eviction_score(now) > old.eviction_score(now),
            "recently accessed entries should score higher"
        );
    }

    #[test]
    fn eviction_score_prefers_frequent() {
        let now = Instant::now();
        let frequent = scoring_entry("/a.rs", 20, now);
        let rare = scoring_entry("/b.rs", 1, now);
        assert!(
            frequent.eviction_score(now) > rare.eviction_score(now),
            "frequently accessed entries should score higher"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        const BUDGET_VAR: &str = "LEAN_CTX_CACHE_MAX_TOKENS";
        std::env::set_var(BUDGET_VAR, "50");

        let mut cache = SessionCache::new();
        // Two files of roughly 30 tokens each (per the original author's estimate); together
        // they are expected to exceed the 50-token budget, so storing the second should push
        // the first out.
        cache.store("/old.rs", "a]".repeat(30));
        cache.store("/new.rs", "b ".repeat(30));

        // NOTE(review): the bound checked here (60) is looser than the configured budget (50).
        assert!(
            cache.total_cached_tokens() <= 60,
            "eviction should have kicked in"
        );

        std::env::remove_var(BUDGET_VAR);
    }
}