1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::Instant;
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8 crate::hooks::normalize_tool_path(path)
9}
10
/// Token budget for all cached file contents combined.
///
/// Read from `LEAN_CTX_CACHE_MAX_TOKENS` on every call. The default of
/// 500 000 tokens is used when the variable is unset, is not valid Unicode,
/// or does not parse as a `usize` (surrounding whitespace included).
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;

    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_MAX_TOKENS),
        Err(_) => DEFAULT_MAX_TOKENS,
    }
}
17
/// One cached file: its last stored content plus the bookkeeping used for
/// hit detection and eviction.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    /// File content as last stored.
    pub content: String,
    /// Content hash compared on re-reads to detect unchanged files
    /// (an MD5 hex string when created by `SessionCache::store`).
    pub hash: String,
    /// Number of lines in `content`, as counted by `str::lines`.
    pub line_count: usize,
    /// Token count of `content`; this is also the entry's weight against the cache budget.
    pub original_tokens: usize,
    /// How many times the entry has been stored or served as a hit.
    pub read_count: u32,
    /// Normalized key the entry is stored under.
    pub path: String,
    /// Last time the entry was stored or hit; feeds the recency term of the score.
    pub last_access: Instant,
}

impl CacheEntry {
    /// Retention score used by eviction. Entries with the *lowest* score are
    /// evicted first.
    ///
    /// The score is a weighted sum of three terms:
    /// - recency `1 / (1 + sqrt(age_secs))`, weight 0.4, where `age_secs` is
    ///   measured from `last_access` to `now`;
    /// - frequency `ln(read_count + 1)`, weight 0.3;
    /// - size `ln(original_tokens + 1)`, weight 0.3.
    pub fn eviction_score(&self, now: Instant) -> f64 {
        const RECENCY_WEIGHT: f64 = 0.4;
        const FREQUENCY_WEIGHT: f64 = 0.3;
        const SIZE_WEIGHT: f64 = 0.3;

        let age_secs = now.duration_since(self.last_access).as_secs_f64();
        let recency = (1.0 + age_secs.sqrt()).recip();
        let frequency = (f64::from(self.read_count) + 1.0).ln();
        let size_value = (self.original_tokens as f64 + 1.0).ln();

        RECENCY_WEIGHT * recency + FREQUENCY_WEIGHT * frequency + SIZE_WEIGHT * size_value
    }
}
40
/// Running counters describing how well a `SessionCache` is doing.
#[derive(Debug)]
pub struct CacheStats {
    /// Reads accounted so far (every `store` call plus every successful `record_cache_hit`).
    pub total_reads: u64,
    /// Reads answered from the cache because the content was unchanged.
    pub cache_hits: u64,
    /// Tokens all reads would have cost if every file had been sent in full.
    pub total_original_tokens: u64,
    /// Tokens actually charged: full content on misses, a short notice on hits.
    pub total_sent_tokens: u64,
    /// Count of new entries inserted by `store`. Eviction and invalidation
    /// do not decrement it.
    pub files_tracked: usize,
}

impl CacheStats {
    /// `part` as a percentage of `whole`, or `0.0` when `whole` is zero.
    fn percent_of(part: u64, whole: u64) -> f64 {
        if whole == 0 {
            0.0
        } else {
            part as f64 / whole as f64 * 100.0
        }
    }

    /// Share of reads that were cache hits, as a percentage (`0.0` before any read).
    pub fn hit_rate(&self) -> f64 {
        Self::percent_of(self.cache_hits, self.total_reads)
    }

    /// Tokens avoided so far: original minus sent, clamped at zero.
    pub fn tokens_saved(&self) -> u64 {
        self.total_original_tokens.saturating_sub(self.total_sent_tokens)
    }

    /// Saved tokens as a percentage of original tokens (`0.0` when nothing was read).
    pub fn savings_percent(&self) -> f64 {
        Self::percent_of(self.tokens_saved(), self.total_original_tokens)
    }
}
70
/// A block of text that `SessionCache::apply_dedup` replaces with a short
/// reference marker when it finds the same text in another file.
#[derive(Clone, Debug)]
pub struct SharedBlock {
    /// Path that holds the canonical copy; `apply_dedup` never rewrites content
    /// for this path (compared verbatim against the `path` argument).
    pub canonical_path: String,
    /// Label written into the marker `[= {canonical_ref}:{start_line}-{end_line}]`;
    /// presumably a file reference such as `F3` — confirm with the producer.
    pub canonical_ref: String,
    /// Start line shown in the marker (presumably within the canonical file).
    pub start_line: usize,
    /// End line shown in the marker (presumably within the canonical file).
    pub end_line: usize,
    /// Exact text of the block, matched verbatim by `apply_dedup`.
    pub content: String,
}
80
/// In-memory cache of file contents with short file reference labels,
/// token-savings statistics and optional shared-block deduplication.
pub struct SessionCache {
    /// Cached entries, keyed by normalized path.
    entries: HashMap<String, CacheEntry>,
    /// Normalized path -> short label (`F1`, `F2`, ...).
    file_refs: HashMap<String, String>,
    /// Number used for the next label handed out (labels are `F{next_ref}`).
    next_ref: usize,
    /// Running counters, exposed through `get_stats`.
    stats: CacheStats,
    /// Blocks consulted by `apply_dedup`.
    shared_blocks: Vec<SharedBlock>,
}
88
89impl Default for SessionCache {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl SessionCache {
96 pub fn new() -> Self {
97 Self {
98 entries: HashMap::new(),
99 file_refs: HashMap::new(),
100 next_ref: 1,
101 shared_blocks: Vec::new(),
102 stats: CacheStats {
103 total_reads: 0,
104 cache_hits: 0,
105 total_original_tokens: 0,
106 total_sent_tokens: 0,
107 files_tracked: 0,
108 },
109 }
110 }
111
112 pub fn get_file_ref(&mut self, path: &str) -> String {
113 let key = normalize_key(path);
114 if let Some(r) = self.file_refs.get(&key) {
115 return r.clone();
116 }
117 let r = format!("F{}", self.next_ref);
118 self.next_ref += 1;
119 self.file_refs.insert(key, r.clone());
120 r
121 }
122
123 pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
124 self.file_refs.get(&normalize_key(path)).cloned()
125 }
126
127 pub fn get(&self, path: &str) -> Option<&CacheEntry> {
128 self.entries.get(&normalize_key(path))
129 }
130
131 pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
132 let key = normalize_key(path);
133 let ref_label = self
134 .file_refs
135 .get(&key)
136 .cloned()
137 .unwrap_or_else(|| "F?".to_string());
138 if let Some(entry) = self.entries.get_mut(&key) {
139 entry.read_count += 1;
140 entry.last_access = Instant::now();
141 self.stats.total_reads += 1;
142 self.stats.cache_hits += 1;
143 self.stats.total_original_tokens += entry.original_tokens as u64;
144 let hit_msg = format!(
145 "{ref_label} cached {}t {}L",
146 entry.read_count, entry.line_count
147 );
148 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
149 crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
150 Some(entry)
151 } else {
152 None
153 }
154 }
155
156 pub fn store(&mut self, path: &str, content: String) -> (CacheEntry, bool) {
157 let key = normalize_key(path);
158 let hash = compute_md5(&content);
159 let line_count = content.lines().count();
160 let original_tokens = count_tokens(&content);
161 let now = Instant::now();
162
163 self.stats.total_reads += 1;
164 self.stats.total_original_tokens += original_tokens as u64;
165
166 if let Some(existing) = self.entries.get_mut(&key) {
167 existing.last_access = now;
168 if existing.hash == hash {
169 existing.read_count += 1;
170 self.stats.cache_hits += 1;
171 let hit_msg = format!(
172 "{} cached {}t {}L",
173 self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
174 existing.read_count,
175 existing.line_count,
176 );
177 let sent = count_tokens(&hit_msg) as u64;
178 self.stats.total_sent_tokens += sent;
179 return (existing.clone(), true);
180 }
181 existing.content = content;
182 existing.hash = hash.clone();
183 existing.line_count = line_count;
184 existing.original_tokens = original_tokens;
185 existing.read_count += 1;
186 self.stats.total_sent_tokens += original_tokens as u64;
187 return (existing.clone(), false);
188 }
189
190 self.evict_if_needed(original_tokens);
191 self.get_file_ref(&key);
192
193 let entry = CacheEntry {
194 content,
195 hash,
196 line_count,
197 original_tokens,
198 read_count: 1,
199 path: key.clone(),
200 last_access: now,
201 };
202
203 self.entries.insert(key, entry.clone());
204 self.stats.files_tracked += 1;
205 self.stats.total_sent_tokens += original_tokens as u64;
206 (entry, false)
207 }
208
209 pub fn total_cached_tokens(&self) -> usize {
210 self.entries.values().map(|e| e.original_tokens).sum()
211 }
212
213 pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
215 let max_tokens = max_cache_tokens();
216 let current = self.total_cached_tokens();
217 if current + incoming_tokens <= max_tokens {
218 return;
219 }
220
221 let now = Instant::now();
222 let mut scored: Vec<(String, f64)> = self
223 .entries
224 .iter()
225 .map(|(path, entry)| (path.clone(), entry.eviction_score(now)))
226 .collect();
227 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
228
229 let mut freed = 0usize;
230 let target = (current + incoming_tokens).saturating_sub(max_tokens);
231 for (path, _score) in &scored {
232 if freed >= target {
233 break;
234 }
235 if let Some(entry) = self.entries.remove(path) {
236 freed += entry.original_tokens;
237 self.file_refs.remove(path);
238 }
239 }
240 }
241
242 pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
243 self.entries.iter().collect()
244 }
245
246 pub fn get_stats(&self) -> &CacheStats {
247 &self.stats
248 }
249
250 pub fn file_ref_map(&self) -> &HashMap<String, String> {
251 &self.file_refs
252 }
253
254 #[allow(dead_code)]
255 pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
256 self.shared_blocks = blocks;
257 }
258
259 #[allow(dead_code)]
260 pub fn get_shared_blocks(&self) -> &[SharedBlock] {
261 &self.shared_blocks
262 }
263
264 #[allow(dead_code)]
266 pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
267 if self.shared_blocks.is_empty() {
268 return None;
269 }
270 let refs: Vec<&SharedBlock> = self
271 .shared_blocks
272 .iter()
273 .filter(|b| b.canonical_path != path && content.contains(&b.content))
274 .collect();
275 if refs.is_empty() {
276 return None;
277 }
278 let mut result = content.to_string();
279 for block in refs {
280 result = result.replacen(
281 &block.content,
282 &format!(
283 "[= {}:{}-{}]",
284 block.canonical_ref, block.start_line, block.end_line
285 ),
286 1,
287 );
288 }
289 Some(result)
290 }
291
292 pub fn invalidate(&mut self, path: &str) -> bool {
293 self.entries.remove(&normalize_key(path)).is_some()
294 }
295
296 pub fn clear(&mut self) -> usize {
297 let count = self.entries.len();
298 self.entries.clear();
299 self.file_refs.clear();
300 self.shared_blocks.clear();
301 self.next_ref = 1;
302 self.stats = CacheStats {
303 total_reads: 0,
304 cache_hits: 0,
305 total_original_tokens: 0,
306 total_sent_tokens: 0,
307 files_tracked: 0,
308 };
309 count
310 }
311}
312
313fn compute_md5(content: &str) -> String {
314 let mut hasher = Md5::new();
315 hasher.update(content.as_bytes());
316 format!("{:x}", hasher.finalize())
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};
    use std::time::Duration;

    /// `max_cache_tokens` reads a process-wide env var, and the test harness runs
    /// tests in parallel. Every test in this module that calls `store` (and may
    /// therefore evict) holds this lock, so the eviction test cannot shrink the
    /// budget underneath them.
    // NOTE(review): tests in other modules that call `store` are not covered.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    fn env_lock() -> MutexGuard<'static, ()> {
        // A failing test poisons the lock; the guarded data is `()`, so continuing is fine.
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets `LEAN_CTX_CACHE_MAX_TOKENS` and removes it again on drop, so the
    /// override cannot leak into other tests even when an assertion fails.
    struct MaxTokensOverride;

    impl MaxTokensOverride {
        fn set(value: &str) -> Self {
            std::env::set_var("LEAN_CTX_CACHE_MAX_TOKENS", value);
            MaxTokensOverride
        }
    }

    impl Drop for MaxTokensOverride {
        fn drop(&mut self) {
            std::env::remove_var("LEAN_CTX_CACHE_MAX_TOKENS");
        }
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let _env = env_lock();
        let mut cache = SessionCache::new();
        let (entry, was_hit) = cache.store("/test/file.rs", "fn main() {}".to_string());
        assert!(!was_hit);
        assert_eq!(entry.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let _env = env_lock();
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content".to_string());
        let (_, was_hit) = cache.store("/test/file.rs", "content".to_string());
        assert!(was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let _env = env_lock();
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content".to_string());
        let (_, was_hit) = cache.store("/test/file.rs", "new content".to_string());
        assert!(!was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
    }

    #[test]
    fn cache_clear_resets_everything() {
        let _env = env_lock();
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.store("/b.rs", "b".to_string());
        let count = cache.clear();
        assert_eq!(count, 2);
        assert!(cache.get("/a.rs").is_none());
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let _env = env_lock();
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test".to_string());
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let _env = env_lock();
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello".to_string());
        cache.store("/a.rs", "hello".to_string());
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let h1 = compute_md5("test content");
        let h2 = compute_md5("test content");
        assert_eq!(h1, h2);
        assert_ne!(h1, compute_md5("different"));
    }

    #[test]
    fn eviction_score_prefers_recent() {
        // Move forward from `now` instead of subtracting from it: `Instant - Duration`
        // can underflow on platforms whose clock starts near zero.
        let now = Instant::now();
        let later = now + Duration::from_secs(300);
        let recent = CacheEntry {
            content: "a".to_string(),
            hash: "h1".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 1,
            path: "/a.rs".to_string(),
            last_access: later,
        };
        let old = CacheEntry {
            content: "b".to_string(),
            hash: "h2".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 1,
            path: "/b.rs".to_string(),
            last_access: now,
        };
        assert!(
            recent.eviction_score(later) > old.eviction_score(later),
            "recently accessed entries should score higher"
        );
    }

    #[test]
    fn eviction_score_prefers_frequent() {
        let now = Instant::now();
        let frequent = CacheEntry {
            content: "a".to_string(),
            hash: "h1".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 20,
            path: "/a.rs".to_string(),
            last_access: now,
        };
        let rare = CacheEntry {
            content: "b".to_string(),
            hash: "h2".to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count: 1,
            path: "/b.rs".to_string(),
            last_access: now,
        };
        assert!(
            frequent.eviction_score(now) > rare.eviction_score(now),
            "frequently accessed entries should score higher"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        // Lock first, override second: the override is dropped (env var removed)
        // before the lock is released.
        let _env = env_lock();
        let _cap = MaxTokensOverride::set("50");
        let mut cache = SessionCache::new();
        cache.store("/old.rs", "a]".repeat(30));
        cache.store("/new.rs", "b ".repeat(30));
        assert!(
            cache.get("/new.rs").is_some(),
            "the incoming entry is stored even when eviction runs"
        );
        assert!(
            cache.total_cached_tokens() <= 60,
            "eviction should have kicked in"
        );
    }
}