Skip to main content

lean_ctx/core/
cache.rs

1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::Instant;
4
5use super::tokens::count_tokens;
6
/// Token budget for the session cache.
///
/// Reads `LEAN_CTX_CACHE_MAX_TOKENS` on every call. If the variable is unset,
/// not valid Unicode, or does not parse as a `usize`, the built-in default of
/// 500,000 tokens is returned.
fn max_cache_tokens() -> usize {
    const DEFAULT_BUDGET: usize = 500_000;

    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_BUDGET),
        Err(_) => DEFAULT_BUDGET,
    }
}
13
/// One cached file plus the bookkeeping used for hit detection and eviction.
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct CacheEntry {
    /// File content as of the most recent store.
    pub content: String,
    /// Hex digest of `content`, compared on re-store to detect changes.
    pub hash: String,
    /// Number of lines in `content`.
    pub line_count: usize,
    /// Token count of `content` at the time it was stored.
    pub original_tokens: usize,
    /// How many times this entry has been read (starts at 1 on first store).
    pub read_count: u32,
    /// Path the entry was stored under.
    pub path: String,
    /// Moment of the most recent read or store touching this entry.
    pub last_access: Instant,
}

impl CacheEntry {
    /// Weight of the recency term in [`CacheEntry::eviction_score`].
    const RECENCY_WEIGHT: f64 = 0.4;
    /// Weight of the frequency term in [`CacheEntry::eviction_score`].
    const FREQUENCY_WEIGHT: f64 = 0.3;
    /// Weight of the size term in [`CacheEntry::eviction_score`].
    const SIZE_WEIGHT: f64 = 0.3;

    /// Boltzmann-inspired eviction score. Higher = more valuable = keep longer.
    ///
    /// `E = α·recency + β·frequency + γ·size_value`, where
    /// * `recency = 1 / (1 + sqrt(seconds since last access))`,
    /// * `frequency = ln(read_count + 1)`,
    /// * `size_value = ln(original_tokens + 1)`.
    ///
    /// `now` is the reference instant. If `now` is earlier than
    /// `last_access`, the elapsed time is treated as zero rather than relying
    /// on `Instant::duration_since`, which the standard library documents as
    /// allowed to panic in that situation.
    pub fn eviction_score(&self, now: Instant) -> f64 {
        let idle_secs = now
            .saturating_duration_since(self.last_access)
            .as_secs_f64();

        let recency = 1.0 / (1.0 + idle_secs.sqrt());
        let frequency = (self.read_count as f64 + 1.0).ln();
        let size_value = (self.original_tokens as f64 + 1.0).ln();

        recency * Self::RECENCY_WEIGHT
            + frequency * Self::FREQUENCY_WEIGHT
            + size_value * Self::SIZE_WEIGHT
    }
}
37
/// Running counters describing how the cache has been used.
#[derive(Debug)]
pub struct CacheStats {
    pub total_reads: u64,
    pub cache_hits: u64,
    pub total_original_tokens: u64,
    pub total_sent_tokens: u64,
    pub files_tracked: usize,
}

#[allow(dead_code)]
impl CacheStats {
    /// `part / whole` expressed as a percentage; `0.0` when `whole` is zero.
    fn percentage(part: u64, whole: u64) -> f64 {
        match whole {
            0 => 0.0,
            _ => (part as f64 / whole as f64) * 100.0,
        }
    }

    /// Share of reads that were cache hits, in percent (0.0 with no reads).
    pub fn hit_rate(&self) -> f64 {
        Self::percentage(self.cache_hits, self.total_reads)
    }

    /// Original tokens minus sent tokens, clamped at zero.
    pub fn tokens_saved(&self) -> u64 {
        let sent = self.total_sent_tokens;
        self.total_original_tokens.saturating_sub(sent)
    }

    /// Saved tokens as a percentage of original tokens (0.0 when there are
    /// no original tokens).
    pub fn savings_percent(&self) -> f64 {
        Self::percentage(self.tokens_saved(), self.total_original_tokens)
    }
}
68
/// A block shared across multiple files, identified by its canonical source.
///
/// `SessionCache::apply_dedup` replaces occurrences of `content` found in
/// other files with a marker of the form
/// `[= {canonical_ref}:{start_line}-{end_line}]`.
#[derive(Clone, Debug)]
pub struct SharedBlock {
    /// Path of the canonical file; dedup is skipped when the file being
    /// processed has this same path.
    pub canonical_path: String,
    /// Label written into the replacement marker.
    /// NOTE(review): presumably the `F{n}` file reference of the canonical
    /// file — confirm against the code that builds shared blocks.
    pub canonical_ref: String,
    /// First number written into the marker's `start-end` range.
    /// NOTE(review): presumably a line number in the canonical file — confirm.
    pub start_line: usize,
    /// Second number written into the marker's `start-end` range.
    pub end_line: usize,
    /// Exact text that is searched for and replaced in other files.
    pub content: String,
}
78
/// In-memory cache of file contents for one session, with per-file short
/// reference labels, usage statistics, and an optional list of shared blocks.
pub struct SessionCache {
    // Cached files keyed by the path they were stored under.
    entries: HashMap<String, CacheEntry>,
    // Path -> short reference label of the form `F{n}`.
    file_refs: HashMap<String, String>,
    // Number used for the next newly assigned `F{n}` label (starts at 1).
    next_ref: usize,
    // Running read / hit / token counters.
    stats: CacheStats,
    // Blocks used by `apply_dedup` to replace duplicated text with markers.
    shared_blocks: Vec<SharedBlock>,
}
86
87impl Default for SessionCache {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl SessionCache {
94    pub fn new() -> Self {
95        Self {
96            entries: HashMap::new(),
97            file_refs: HashMap::new(),
98            next_ref: 1,
99            shared_blocks: Vec::new(),
100            stats: CacheStats {
101                total_reads: 0,
102                cache_hits: 0,
103                total_original_tokens: 0,
104                total_sent_tokens: 0,
105                files_tracked: 0,
106            },
107        }
108    }
109
110    pub fn get_file_ref(&mut self, path: &str) -> String {
111        if let Some(r) = self.file_refs.get(path) {
112            return r.clone();
113        }
114        let r = format!("F{}", self.next_ref);
115        self.next_ref += 1;
116        self.file_refs.insert(path.to_string(), r.clone());
117        r
118    }
119
120    pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
121        self.file_refs.get(path).cloned()
122    }
123
124    pub fn get(&self, path: &str) -> Option<&CacheEntry> {
125        self.entries.get(path)
126    }
127
128    pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
129        let ref_label = self
130            .file_refs
131            .get(path)
132            .cloned()
133            .unwrap_or_else(|| "F?".to_string());
134        if let Some(entry) = self.entries.get_mut(path) {
135            entry.read_count += 1;
136            entry.last_access = Instant::now();
137            self.stats.total_reads += 1;
138            self.stats.cache_hits += 1;
139            self.stats.total_original_tokens += entry.original_tokens as u64;
140            let hit_msg = format!(
141                "{ref_label} cached {}t {}L",
142                entry.read_count, entry.line_count
143            );
144            self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
145            Some(entry)
146        } else {
147            None
148        }
149    }
150
151    pub fn store(&mut self, path: &str, content: String) -> (CacheEntry, bool) {
152        let hash = compute_md5(&content);
153        let line_count = content.lines().count();
154        let original_tokens = count_tokens(&content);
155        let now = Instant::now();
156
157        self.stats.total_reads += 1;
158        self.stats.total_original_tokens += original_tokens as u64;
159
160        if let Some(existing) = self.entries.get_mut(path) {
161            existing.last_access = now;
162            if existing.hash == hash {
163                existing.read_count += 1;
164                self.stats.cache_hits += 1;
165                let hit_msg = format!(
166                    "{} cached {}t {}L",
167                    self.file_refs.get(path).unwrap_or(&"F?".to_string()),
168                    existing.read_count,
169                    existing.line_count,
170                );
171                let sent = count_tokens(&hit_msg) as u64;
172                self.stats.total_sent_tokens += sent;
173                return (existing.clone(), true);
174            }
175            existing.content = content;
176            existing.hash = hash.clone();
177            existing.line_count = line_count;
178            existing.original_tokens = original_tokens;
179            existing.read_count += 1;
180            self.stats.total_sent_tokens += original_tokens as u64;
181            return (existing.clone(), false);
182        }
183
184        self.evict_if_needed(original_tokens);
185        self.get_file_ref(path);
186
187        let entry = CacheEntry {
188            content,
189            hash,
190            line_count,
191            original_tokens,
192            read_count: 1,
193            path: path.to_string(),
194            last_access: now,
195        };
196
197        self.entries.insert(path.to_string(), entry.clone());
198        self.stats.files_tracked += 1;
199        self.stats.total_sent_tokens += original_tokens as u64;
200        (entry, false)
201    }
202
203    pub fn total_cached_tokens(&self) -> usize {
204        self.entries.values().map(|e| e.original_tokens).sum()
205    }
206
207    /// Evict lowest-scoring entries until cache fits within token budget.
208    pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
209        let max_tokens = max_cache_tokens();
210        let current = self.total_cached_tokens();
211        if current + incoming_tokens <= max_tokens {
212            return;
213        }
214
215        let now = Instant::now();
216        let mut scored: Vec<(String, f64)> = self
217            .entries
218            .iter()
219            .map(|(path, entry)| (path.clone(), entry.eviction_score(now)))
220            .collect();
221        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
222
223        let mut freed = 0usize;
224        let target = (current + incoming_tokens).saturating_sub(max_tokens);
225        for (path, _score) in &scored {
226            if freed >= target {
227                break;
228            }
229            if let Some(entry) = self.entries.remove(path) {
230                freed += entry.original_tokens;
231                self.file_refs.remove(path);
232            }
233        }
234    }
235
236    pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
237        self.entries.iter().collect()
238    }
239
240    pub fn get_stats(&self) -> &CacheStats {
241        &self.stats
242    }
243
244    pub fn file_ref_map(&self) -> &HashMap<String, String> {
245        &self.file_refs
246    }
247
248    #[allow(dead_code)]
249    pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
250        self.shared_blocks = blocks;
251    }
252
253    #[allow(dead_code)]
254    pub fn get_shared_blocks(&self) -> &[SharedBlock] {
255        &self.shared_blocks
256    }
257
258    /// Replace shared blocks in content with cross-file references.
259    #[allow(dead_code)]
260    pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
261        if self.shared_blocks.is_empty() {
262            return None;
263        }
264        let refs: Vec<&SharedBlock> = self
265            .shared_blocks
266            .iter()
267            .filter(|b| b.canonical_path != path && content.contains(&b.content))
268            .collect();
269        if refs.is_empty() {
270            return None;
271        }
272        let mut result = content.to_string();
273        for block in refs {
274            result = result.replacen(
275                &block.content,
276                &format!(
277                    "[= {}:{}-{}]",
278                    block.canonical_ref, block.start_line, block.end_line
279                ),
280                1,
281            );
282        }
283        Some(result)
284    }
285
286    pub fn invalidate(&mut self, path: &str) -> bool {
287        self.entries.remove(path).is_some()
288    }
289
290    pub fn clear(&mut self) -> usize {
291        let count = self.entries.len();
292        self.entries.clear();
293        self.file_refs.clear();
294        self.shared_blocks.clear();
295        self.next_ref = 1;
296        self.stats = CacheStats {
297            total_reads: 0,
298            cache_hits: 0,
299            total_original_tokens: 0,
300            total_sent_tokens: 0,
301            files_tracked: 0,
302        };
303        count
304    }
305}
306
307fn compute_md5(content: &str) -> String {
308    let mut hasher = Md5::new();
309    hasher.update(content.as_bytes());
310    format!("{:x}", hasher.finalize())
311}
312
// Unit tests for the session cache: storing, hit detection, reference
// labels, clearing, invalidation, statistics, hashing, and eviction scoring.
#[cfg(test)]
mod tests {
    use super::*;

    // A first store is a miss and makes the entry retrievable.
    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let (entry, was_hit) = cache.store("/test/file.rs", "fn main() {}".to_string());
        assert!(!was_hit);
        assert_eq!(entry.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    // Re-storing identical content under the same path is reported as a hit.
    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content".to_string());
        let (_, was_hit) = cache.store("/test/file.rs", "content".to_string());
        assert!(was_hit, "same content should be a cache hit");
    }

    // Re-storing different content under the same path is reported as a miss.
    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content".to_string());
        let (_, was_hit) = cache.store("/test/file.rs", "new content".to_string());
        assert!(!was_hit, "changed content should not be a cache hit");
    }

    // Labels are handed out as F1, F2, ... and repeat for a known path.
    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        assert_eq!(cache.get_file_ref("/a.rs"), "F1"); // stable
    }

    // `clear` reports the previous entry count and restarts label numbering.
    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.store("/b.rs", "b".to_string());
        let count = cache.clear();
        assert_eq!(count, 2);
        assert!(cache.get("/a.rs").is_none());
        assert_eq!(cache.get_file_ref("/c.rs"), "F1"); // refs reset
    }

    // `invalidate` returns true only when an entry was actually removed.
    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test".to_string());
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    // Two stores of the same content count as two reads and one hit.
    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello".to_string());
        cache.store("/a.rs", "hello".to_string()); // hit
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    // Hashing the same input twice agrees; different input disagrees.
    #[test]
    fn md5_is_deterministic() {
        let h1 = compute_md5("test content");
        let h2 = compute_md5("test content");
        assert_eq!(h1, h2);
        assert_ne!(h1, compute_md5("different"));
    }

    // With equal read counts and sizes, the more recently accessed entry
    // must score higher.
    #[test]
    fn eviction_score_prefers_recent() {
        let now = Instant::now();
        let recent = CacheEntry {
            content: "a".to_string(),
            hash: "h1".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 1,
            path: "/a.rs".to_string(),
            last_access: now,
        };
        let old = CacheEntry {
            content: "b".to_string(),
            hash: "h2".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 1,
            path: "/b.rs".to_string(),
            last_access: now - std::time::Duration::from_secs(300),
        };
        assert!(
            recent.eviction_score(now) > old.eviction_score(now),
            "recently accessed entries should score higher"
        );
    }

    // With equal access times and sizes, the entry read more often must
    // score higher.
    #[test]
    fn eviction_score_prefers_frequent() {
        let now = Instant::now();
        let frequent = CacheEntry {
            content: "a".to_string(),
            hash: "h1".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 20,
            path: "/a.rs".to_string(),
            last_access: now,
        };
        let rare = CacheEntry {
            content: "b".to_string(),
            hash: "h2".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 1,
            path: "/b.rs".to_string(),
            last_access: now,
        };
        assert!(
            frequent.eviction_score(now) > rare.eviction_score(now),
            "frequently accessed entries should score higher"
        );
    }

    // Exercises eviction by lowering the budget through the environment.
    //
    // NOTE(review): this mutates a process-wide environment variable; if
    // tests run in parallel, other tests that call `store` may observe the
    // lowered budget while this test is running.
    // NOTE(review): the assertion below allows a total of up to 60 tokens,
    // which is weaker than the "<= 50" stated in the comment above it and
    // may pass even if nothing was evicted — the exact token counts depend
    // on `count_tokens`, which is not visible here.
    #[test]
    fn evict_if_needed_removes_lowest_score() {
        std::env::set_var("LEAN_CTX_CACHE_MAX_TOKENS", "50");
        let mut cache = SessionCache::new();
        let big_content = "a]".repeat(30); // ~30 tokens
        cache.store("/old.rs", big_content);
        // /old.rs now in cache with ~30 tokens

        let new_content = "b ".repeat(30); // ~30 tokens incoming
        cache.store("/new.rs", new_content);
        // should have evicted /old.rs to make room
        // (total would be ~60 which exceeds 50)

        // At least one should remain, total should be <= 50
        assert!(
            cache.total_cached_tokens() <= 60,
            "eviction should have kicked in"
        );
        std::env::remove_var("LEAN_CTX_CACHE_MAX_TOKENS");
    }
}