1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::Instant;
4
5use super::tokens::count_tokens;
6
/// Returns the token budget for the session cache.
///
/// The budget is read from the `LEAN_CTX_CACHE_MAX_TOKENS` environment
/// variable on every call. When the variable is unset, is not valid Unicode,
/// or does not parse as a `usize`, the built-in default of 500 000 is used.
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;

    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_MAX_TOKENS),
        Err(_) => DEFAULT_MAX_TOKENS,
    }
}
13
/// Snapshot of one cached file together with the bookkeeping used for
/// hit detection and eviction.
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct CacheEntry {
    /// Full text stored for this path.
    pub content: String,
    /// Digest of `content`; compared on re-reads to detect unchanged files.
    pub hash: String,
    /// Number of lines in `content`.
    pub line_count: usize,
    /// Token count of `content` at the time it was stored.
    pub original_tokens: usize,
    /// How many times this path has been read through the cache.
    pub read_count: u32,
    /// Path the entry was stored under.
    pub path: String,
    /// Moment of the most recent read.
    pub last_access: Instant,
}

impl CacheEntry {
    /// Computes the retention score of this entry as seen at `now`.
    ///
    /// Higher scores mean "more worth keeping". The score is a weighted sum of
    /// three terms:
    ///
    /// * recency: `1 / (1 + sqrt(seconds since last access))`, weight 0.4
    /// * frequency: `ln(read_count + 1)`, weight 0.3
    /// * size: `ln(original_tokens + 1)`, weight 0.3
    ///
    /// If `now` is earlier than `last_access` the elapsed time is treated as
    /// zero instead of relying on `Instant::duration_since`, which is
    /// documented as allowed to panic in that situation.
    pub fn eviction_score(&self, now: Instant) -> f64 {
        let elapsed = now
            .saturating_duration_since(self.last_access)
            .as_secs_f64();
        let recency = 1.0 / (1.0 + elapsed.sqrt());
        let frequency = (f64::from(self.read_count) + 1.0).ln();
        let size_value = (self.original_tokens as f64 + 1.0).ln();
        recency * 0.4 + frequency * 0.3 + size_value * 0.3
    }
}
37
/// Running counters describing how the session cache has been used.
#[derive(Debug)]
pub struct CacheStats {
    /// Number of reads that went through the cache (hits and misses).
    pub total_reads: u64,
    /// Number of those reads that were answered from the cache.
    pub cache_hits: u64,
    /// Sum of the token counts of the content behind every read.
    pub total_original_tokens: u64,
    /// Sum of the token counts recorded as actually sent.
    pub total_sent_tokens: u64,
    /// Number of files counted as tracked by the cache.
    pub files_tracked: usize,
}

#[allow(dead_code)]
impl CacheStats {
    /// Percentage (0–100) of reads that were cache hits; `0.0` when nothing
    /// has been read yet.
    pub fn hit_rate(&self) -> f64 {
        match self.total_reads {
            0 => 0.0,
            reads => (self.cache_hits as f64 / reads as f64) * 100.0,
        }
    }

    /// Tokens that did not have to be sent, clamped at zero when more tokens
    /// were sent than the originals contained.
    pub fn tokens_saved(&self) -> u64 {
        let original = self.total_original_tokens;
        let sent = self.total_sent_tokens;
        if sent >= original {
            0
        } else {
            original - sent
        }
    }

    /// `tokens_saved` as a percentage (0–100) of the original token total;
    /// `0.0` when no original tokens have been recorded.
    pub fn savings_percent(&self) -> f64 {
        let original = self.total_original_tokens;
        if original == 0 {
            0.0
        } else {
            let saved = self.tokens_saved();
            (saved as f64 / original as f64) * 100.0
        }
    }
}
68
/// A span of text that also lives in another ("canonical") file and can be
/// replaced by a short reference to that file.
#[derive(Clone, Debug)]
pub struct SharedBlock {
    /// Path of the file that owns the canonical copy of the block.
    pub canonical_path: String,
    /// Reference label printed in the replacement marker for the canonical file.
    pub canonical_ref: String,
    /// First line of the block, as printed in the replacement marker.
    pub start_line: usize,
    /// Last line of the block, as printed in the replacement marker.
    pub end_line: usize,
    /// Exact text of the block; this is what gets searched for and replaced.
    pub content: String,
}
78
79pub struct SessionCache {
80 entries: HashMap<String, CacheEntry>,
81 file_refs: HashMap<String, String>,
82 next_ref: usize,
83 stats: CacheStats,
84 shared_blocks: Vec<SharedBlock>,
85}
86
87impl Default for SessionCache {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl SessionCache {
94 pub fn new() -> Self {
95 Self {
96 entries: HashMap::new(),
97 file_refs: HashMap::new(),
98 next_ref: 1,
99 shared_blocks: Vec::new(),
100 stats: CacheStats {
101 total_reads: 0,
102 cache_hits: 0,
103 total_original_tokens: 0,
104 total_sent_tokens: 0,
105 files_tracked: 0,
106 },
107 }
108 }
109
110 pub fn get_file_ref(&mut self, path: &str) -> String {
111 if let Some(r) = self.file_refs.get(path) {
112 return r.clone();
113 }
114 let r = format!("F{}", self.next_ref);
115 self.next_ref += 1;
116 self.file_refs.insert(path.to_string(), r.clone());
117 r
118 }
119
120 pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
121 self.file_refs.get(path).cloned()
122 }
123
124 pub fn get(&self, path: &str) -> Option<&CacheEntry> {
125 self.entries.get(path)
126 }
127
128 pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
129 let ref_label = self
130 .file_refs
131 .get(path)
132 .cloned()
133 .unwrap_or_else(|| "F?".to_string());
134 if let Some(entry) = self.entries.get_mut(path) {
135 entry.read_count += 1;
136 entry.last_access = Instant::now();
137 self.stats.total_reads += 1;
138 self.stats.cache_hits += 1;
139 self.stats.total_original_tokens += entry.original_tokens as u64;
140 let hit_msg = format!(
141 "{ref_label} cached {}t {}L",
142 entry.read_count, entry.line_count
143 );
144 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
145 Some(entry)
146 } else {
147 None
148 }
149 }
150
151 pub fn store(&mut self, path: &str, content: String) -> (CacheEntry, bool) {
152 let hash = compute_md5(&content);
153 let line_count = content.lines().count();
154 let original_tokens = count_tokens(&content);
155 let now = Instant::now();
156
157 self.stats.total_reads += 1;
158 self.stats.total_original_tokens += original_tokens as u64;
159
160 if let Some(existing) = self.entries.get_mut(path) {
161 existing.last_access = now;
162 if existing.hash == hash {
163 existing.read_count += 1;
164 self.stats.cache_hits += 1;
165 let hit_msg = format!(
166 "{} cached {}t {}L",
167 self.file_refs.get(path).unwrap_or(&"F?".to_string()),
168 existing.read_count,
169 existing.line_count,
170 );
171 let sent = count_tokens(&hit_msg) as u64;
172 self.stats.total_sent_tokens += sent;
173 return (existing.clone(), true);
174 }
175 existing.content = content;
176 existing.hash = hash.clone();
177 existing.line_count = line_count;
178 existing.original_tokens = original_tokens;
179 existing.read_count += 1;
180 self.stats.total_sent_tokens += original_tokens as u64;
181 return (existing.clone(), false);
182 }
183
184 self.evict_if_needed(original_tokens);
185 self.get_file_ref(path);
186
187 let entry = CacheEntry {
188 content,
189 hash,
190 line_count,
191 original_tokens,
192 read_count: 1,
193 path: path.to_string(),
194 last_access: now,
195 };
196
197 self.entries.insert(path.to_string(), entry.clone());
198 self.stats.files_tracked += 1;
199 self.stats.total_sent_tokens += original_tokens as u64;
200 (entry, false)
201 }
202
203 pub fn total_cached_tokens(&self) -> usize {
204 self.entries.values().map(|e| e.original_tokens).sum()
205 }
206
207 pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
209 let max_tokens = max_cache_tokens();
210 let current = self.total_cached_tokens();
211 if current + incoming_tokens <= max_tokens {
212 return;
213 }
214
215 let now = Instant::now();
216 let mut scored: Vec<(String, f64)> = self
217 .entries
218 .iter()
219 .map(|(path, entry)| (path.clone(), entry.eviction_score(now)))
220 .collect();
221 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
222
223 let mut freed = 0usize;
224 let target = (current + incoming_tokens).saturating_sub(max_tokens);
225 for (path, _score) in &scored {
226 if freed >= target {
227 break;
228 }
229 if let Some(entry) = self.entries.remove(path) {
230 freed += entry.original_tokens;
231 self.file_refs.remove(path);
232 }
233 }
234 }
235
236 pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
237 self.entries.iter().collect()
238 }
239
240 pub fn get_stats(&self) -> &CacheStats {
241 &self.stats
242 }
243
244 pub fn file_ref_map(&self) -> &HashMap<String, String> {
245 &self.file_refs
246 }
247
248 #[allow(dead_code)]
249 pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
250 self.shared_blocks = blocks;
251 }
252
253 #[allow(dead_code)]
254 pub fn get_shared_blocks(&self) -> &[SharedBlock] {
255 &self.shared_blocks
256 }
257
258 #[allow(dead_code)]
260 pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
261 if self.shared_blocks.is_empty() {
262 return None;
263 }
264 let refs: Vec<&SharedBlock> = self
265 .shared_blocks
266 .iter()
267 .filter(|b| b.canonical_path != path && content.contains(&b.content))
268 .collect();
269 if refs.is_empty() {
270 return None;
271 }
272 let mut result = content.to_string();
273 for block in refs {
274 result = result.replacen(
275 &block.content,
276 &format!(
277 "[= {}:{}-{}]",
278 block.canonical_ref, block.start_line, block.end_line
279 ),
280 1,
281 );
282 }
283 Some(result)
284 }
285
286 pub fn invalidate(&mut self, path: &str) -> bool {
287 self.entries.remove(path).is_some()
288 }
289
290 pub fn clear(&mut self) -> usize {
291 let count = self.entries.len();
292 self.entries.clear();
293 self.file_refs.clear();
294 self.shared_blocks.clear();
295 self.next_ref = 1;
296 self.stats = CacheStats {
297 total_reads: 0,
298 cache_hits: 0,
299 total_original_tokens: 0,
300 total_sent_tokens: 0,
301 files_tracked: 0,
302 };
303 count
304 }
305}
306
307fn compute_md5(content: &str) -> String {
308 let mut hasher = Md5::new();
309 hasher.update(content.as_bytes());
310 format!("{:x}", hasher.finalize())
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Environment variable that overrides the cache token budget.
    const MAX_TOKENS_VAR: &str = "LEAN_CTX_CACHE_MAX_TOKENS";

    /// Builds a one-line, ten-token entry; only the fields that differ between
    /// the score tests are parameters.
    fn entry_fixture(
        content: &str,
        hash: &str,
        path: &str,
        read_count: u32,
        last_access: Instant,
    ) -> CacheEntry {
        CacheEntry {
            content: content.to_string(),
            hash: hash.to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count,
            path: path.to_string(),
            last_access,
        }
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let path = "/test/file.rs";
        let mut cache = SessionCache::new();
        let (stored, hit) = cache.store(path, String::from("fn main() {}"));
        assert!(!hit);
        assert_eq!(stored.line_count, 1);
        assert!(cache.get(path).is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let path = "/test/file.rs";
        let mut cache = SessionCache::new();
        cache.store(path, String::from("content"));
        let (_, hit) = cache.store(path, String::from("content"));
        assert!(hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let path = "/test/file.rs";
        let mut cache = SessionCache::new();
        cache.store(path, String::from("old content"));
        let (_, hit) = cache.store(path, String::from("new content"));
        assert!(!hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        // The third lookup repeats the first path and must reuse its label.
        for (path, expected) in [("/a.rs", "F1"), ("/b.rs", "F2"), ("/a.rs", "F1")] {
            assert_eq!(cache.get_file_ref(path), expected);
        }
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", String::from("a"));
        cache.store("/b.rs", String::from("b"));
        let removed = cache.clear();
        assert_eq!(removed, 2);
        assert!(cache.get("/a.rs").is_none());
        // Label numbering starts over after a clear.
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", String::from("test"));
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        // Second store of identical content is the hit.
        for _ in 0..2 {
            cache.store("/a.rs", String::from("hello"));
        }
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let first = compute_md5("test content");
        let second = compute_md5("test content");
        assert_eq!(first, second);
        assert_ne!(first, compute_md5("different"));
    }

    #[test]
    fn eviction_score_prefers_recent() {
        let now = Instant::now();
        let five_minutes_ago = now - std::time::Duration::from_secs(300);
        let recent = entry_fixture("a", "h1", "/a.rs", 1, now);
        let old = entry_fixture("b", "h2", "/b.rs", 1, five_minutes_ago);
        assert!(
            recent.eviction_score(now) > old.eviction_score(now),
            "recently accessed entries should score higher"
        );
    }

    #[test]
    fn eviction_score_prefers_frequent() {
        let now = Instant::now();
        let frequent = entry_fixture("a", "h1", "/a.rs", 20, now);
        let rare = entry_fixture("b", "h2", "/b.rs", 1, now);
        assert!(
            frequent.eviction_score(now) > rare.eviction_score(now),
            "frequently accessed entries should score higher"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        std::env::set_var(MAX_TOKENS_VAR, "50");
        let mut cache = SessionCache::new();
        for (path, unit) in [("/old.rs", "a]"), ("/new.rs", "b ")] {
            cache.store(path, unit.repeat(30));
        }
        let cached = cache.total_cached_tokens();
        assert!(cached <= 60, "eviction should have kicked in");
        std::env::remove_var(MAX_TOKENS_VAR);
    }
}