1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::Instant;
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8 crate::hooks::normalize_tool_path(path)
9}
10
/// Returns the token budget for the cache.
///
/// The value is read from the `LEAN_CTX_CACHE_MAX_TOKENS` environment
/// variable on every call. If the variable is unset, not valid Unicode, or
/// does not parse as a `usize`, the built-in default of 500 000 is used.
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;

    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_MAX_TOKENS),
        Err(_) => DEFAULT_MAX_TOKENS,
    }
}
17
/// One cached file plus the bookkeeping used for change detection and
/// eviction.
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct CacheEntry {
    /// File content as of the most recent store.
    pub content: String,
    /// Digest of `content`; compared on re-store to detect changes.
    pub hash: String,
    /// Number of lines in `content`.
    pub line_count: usize,
    /// Token count of `content`.
    pub original_tokens: usize,
    /// How many times this entry has been read (starts at 1 on first store).
    pub read_count: u32,
    /// Normalized path this entry is stored under.
    pub path: String,
    /// Moment of the most recent read or store.
    pub last_access: Instant,
}

impl CacheEntry {
    /// Computes how valuable this entry is to keep; higher means "keep".
    ///
    /// The score is a weighted sum of three terms:
    /// * recency   (weight 0.4): `1 / (1 + sqrt(seconds since last access))`
    /// * frequency (weight 0.3): `ln(read_count + 1)`
    /// * size      (weight 0.3): `ln(original_tokens + 1)`
    ///
    /// `now` is the reference instant against which recency is measured. If
    /// `last_access` is later than `now`, the elapsed time is treated as
    /// zero instead of relying on version-dependent `duration_since`
    /// behaviour.
    pub fn eviction_score(&self, now: Instant) -> f64 {
        const RECENCY_WEIGHT: f64 = 0.4;
        const FREQUENCY_WEIGHT: f64 = 0.3;
        const SIZE_WEIGHT: f64 = 0.3;

        // Saturating: a `last_access` in the "future" counts as just accessed.
        let elapsed = now
            .saturating_duration_since(self.last_access)
            .as_secs_f64();
        let recency = 1.0 / (1.0 + elapsed.sqrt());
        let frequency = (f64::from(self.read_count) + 1.0).ln();
        let size_value = (self.original_tokens as f64 + 1.0).ln();

        recency * RECENCY_WEIGHT + frequency * FREQUENCY_WEIGHT + size_value * SIZE_WEIGHT
    }
}
41
/// Running counters describing how the cache has been used.
#[derive(Debug)]
pub struct CacheStats {
    /// Total number of reads recorded (hits and misses).
    pub total_reads: u64,
    /// Number of reads that were served as cache hits.
    pub cache_hits: u64,
    /// Sum of the full token counts of everything that was read.
    pub total_original_tokens: u64,
    /// Sum of the tokens accounted as actually sent.
    pub total_sent_tokens: u64,
    /// File count maintained by the owning cache.
    pub files_tracked: usize,
}

#[allow(dead_code)]
impl CacheStats {
    /// Percentage (0–100) of reads that were cache hits; `0.0` when nothing
    /// has been read yet.
    pub fn hit_rate(&self) -> f64 {
        match self.total_reads {
            0 => 0.0,
            reads => (self.cache_hits as f64 / reads as f64) * 100.0,
        }
    }

    /// Tokens that did not have to be sent: original minus sent, clamped at
    /// zero.
    pub fn tokens_saved(&self) -> u64 {
        let original = self.total_original_tokens;
        let sent = self.total_sent_tokens;
        original.saturating_sub(sent)
    }

    /// `tokens_saved` as a percentage (0–100) of the original tokens; `0.0`
    /// when no original tokens have been recorded.
    pub fn savings_percent(&self) -> f64 {
        let original = self.total_original_tokens;
        if original > 0 {
            (self.tokens_saved() as f64 / original as f64) * 100.0
        } else {
            0.0
        }
    }
}
72
/// A piece of text that occurs in more than one file and can be replaced by
/// a short reference to its canonical location.
#[derive(Debug, Clone)]
pub struct SharedBlock {
    /// Path of the file that holds the canonical copy of the block.
    pub canonical_path: String,
    /// Reference label printed in the replacement marker.
    pub canonical_ref: String,
    /// First line of the block, as printed in the replacement marker.
    pub start_line: usize,
    /// Last line of the block, as printed in the replacement marker.
    pub end_line: usize,
    /// Exact text of the block; this is what gets searched for and replaced.
    pub content: String,
}
82
83pub struct SessionCache {
84 entries: HashMap<String, CacheEntry>,
85 file_refs: HashMap<String, String>,
86 next_ref: usize,
87 stats: CacheStats,
88 shared_blocks: Vec<SharedBlock>,
89}
90
91impl Default for SessionCache {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97impl SessionCache {
98 pub fn new() -> Self {
99 Self {
100 entries: HashMap::new(),
101 file_refs: HashMap::new(),
102 next_ref: 1,
103 shared_blocks: Vec::new(),
104 stats: CacheStats {
105 total_reads: 0,
106 cache_hits: 0,
107 total_original_tokens: 0,
108 total_sent_tokens: 0,
109 files_tracked: 0,
110 },
111 }
112 }
113
114 pub fn get_file_ref(&mut self, path: &str) -> String {
115 let key = normalize_key(path);
116 if let Some(r) = self.file_refs.get(&key) {
117 return r.clone();
118 }
119 let r = format!("F{}", self.next_ref);
120 self.next_ref += 1;
121 self.file_refs.insert(key, r.clone());
122 r
123 }
124
125 pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
126 self.file_refs.get(&normalize_key(path)).cloned()
127 }
128
129 pub fn get(&self, path: &str) -> Option<&CacheEntry> {
130 self.entries.get(&normalize_key(path))
131 }
132
133 pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
134 let key = normalize_key(path);
135 let ref_label = self
136 .file_refs
137 .get(&key)
138 .cloned()
139 .unwrap_or_else(|| "F?".to_string());
140 if let Some(entry) = self.entries.get_mut(&key) {
141 entry.read_count += 1;
142 entry.last_access = Instant::now();
143 self.stats.total_reads += 1;
144 self.stats.cache_hits += 1;
145 self.stats.total_original_tokens += entry.original_tokens as u64;
146 let hit_msg = format!(
147 "{ref_label} cached {}t {}L",
148 entry.read_count, entry.line_count
149 );
150 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
151 crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
152 Some(entry)
153 } else {
154 None
155 }
156 }
157
158 pub fn store(&mut self, path: &str, content: String) -> (CacheEntry, bool) {
159 let key = normalize_key(path);
160 let hash = compute_md5(&content);
161 let line_count = content.lines().count();
162 let original_tokens = count_tokens(&content);
163 let now = Instant::now();
164
165 self.stats.total_reads += 1;
166 self.stats.total_original_tokens += original_tokens as u64;
167
168 if let Some(existing) = self.entries.get_mut(&key) {
169 existing.last_access = now;
170 if existing.hash == hash {
171 existing.read_count += 1;
172 self.stats.cache_hits += 1;
173 let hit_msg = format!(
174 "{} cached {}t {}L",
175 self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
176 existing.read_count,
177 existing.line_count,
178 );
179 let sent = count_tokens(&hit_msg) as u64;
180 self.stats.total_sent_tokens += sent;
181 return (existing.clone(), true);
182 }
183 existing.content = content;
184 existing.hash = hash.clone();
185 existing.line_count = line_count;
186 existing.original_tokens = original_tokens;
187 existing.read_count += 1;
188 self.stats.total_sent_tokens += original_tokens as u64;
189 return (existing.clone(), false);
190 }
191
192 self.evict_if_needed(original_tokens);
193 self.get_file_ref(&key);
194
195 let entry = CacheEntry {
196 content,
197 hash,
198 line_count,
199 original_tokens,
200 read_count: 1,
201 path: key.clone(),
202 last_access: now,
203 };
204
205 self.entries.insert(key, entry.clone());
206 self.stats.files_tracked += 1;
207 self.stats.total_sent_tokens += original_tokens as u64;
208 (entry, false)
209 }
210
211 pub fn total_cached_tokens(&self) -> usize {
212 self.entries.values().map(|e| e.original_tokens).sum()
213 }
214
215 pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
217 let max_tokens = max_cache_tokens();
218 let current = self.total_cached_tokens();
219 if current + incoming_tokens <= max_tokens {
220 return;
221 }
222
223 let now = Instant::now();
224 let mut scored: Vec<(String, f64)> = self
225 .entries
226 .iter()
227 .map(|(path, entry)| (path.clone(), entry.eviction_score(now)))
228 .collect();
229 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
230
231 let mut freed = 0usize;
232 let target = (current + incoming_tokens).saturating_sub(max_tokens);
233 for (path, _score) in &scored {
234 if freed >= target {
235 break;
236 }
237 if let Some(entry) = self.entries.remove(path) {
238 freed += entry.original_tokens;
239 self.file_refs.remove(path);
240 }
241 }
242 }
243
244 pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
245 self.entries.iter().collect()
246 }
247
248 pub fn get_stats(&self) -> &CacheStats {
249 &self.stats
250 }
251
252 pub fn file_ref_map(&self) -> &HashMap<String, String> {
253 &self.file_refs
254 }
255
256 #[allow(dead_code)]
257 pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
258 self.shared_blocks = blocks;
259 }
260
261 #[allow(dead_code)]
262 pub fn get_shared_blocks(&self) -> &[SharedBlock] {
263 &self.shared_blocks
264 }
265
266 #[allow(dead_code)]
268 pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
269 if self.shared_blocks.is_empty() {
270 return None;
271 }
272 let refs: Vec<&SharedBlock> = self
273 .shared_blocks
274 .iter()
275 .filter(|b| b.canonical_path != path && content.contains(&b.content))
276 .collect();
277 if refs.is_empty() {
278 return None;
279 }
280 let mut result = content.to_string();
281 for block in refs {
282 result = result.replacen(
283 &block.content,
284 &format!(
285 "[= {}:{}-{}]",
286 block.canonical_ref, block.start_line, block.end_line
287 ),
288 1,
289 );
290 }
291 Some(result)
292 }
293
294 pub fn invalidate(&mut self, path: &str) -> bool {
295 self.entries.remove(&normalize_key(path)).is_some()
296 }
297
298 pub fn clear(&mut self) -> usize {
299 let count = self.entries.len();
300 self.entries.clear();
301 self.file_refs.clear();
302 self.shared_blocks.clear();
303 self.next_ref = 1;
304 self.stats = CacheStats {
305 total_reads: 0,
306 cache_hits: 0,
307 total_original_tokens: 0,
308 total_sent_tokens: 0,
309 files_tracked: 0,
310 };
311 count
312 }
313}
314
315fn compute_md5(content: &str) -> String {
316 let mut hasher = Md5::new();
317 hasher.update(content.as_bytes());
318 format!("{:x}", hasher.finalize())
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a one-line, ten-token entry; only the fields that vary between
    /// the scoring tests are parameters.
    fn sample_entry(
        content: &str,
        hash: &str,
        path: &str,
        read_count: u32,
        last_access: Instant,
    ) -> CacheEntry {
        CacheEntry {
            content: content.to_string(),
            hash: hash.to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count,
            path: path.to_string(),
            last_access,
        }
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let (stored, hit) = cache.store("/test/file.rs", String::from("fn main() {}"));
        assert!(!hit);
        assert_eq!(stored.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", String::from("content"));
        let (_, hit) = cache.store("/test/file.rs", String::from("content"));
        assert!(hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", String::from("old content"));
        let (_, hit) = cache.store("/test/file.rs", String::from("new content"));
        assert!(!hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        // Asking again for a known path returns the same label.
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", String::from("a"));
        cache.store("/b.rs", String::from("b"));
        let cleared = cache.clear();
        assert_eq!(cleared, 2);
        assert!(cache.get("/a.rs").is_none());
        // Label numbering restarts after a clear.
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", String::from("test"));
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", String::from("hello"));
        // Second store with identical content is the hit.
        cache.store("/a.rs", String::from("hello"));
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let first = compute_md5("test content");
        let second = compute_md5("test content");
        assert_eq!(first, second);
        assert_ne!(first, compute_md5("different"));
    }

    #[test]
    fn eviction_score_prefers_recent() {
        let now = Instant::now();
        let recent = sample_entry("a", "h1", "/a.rs", 1, now);
        let old = sample_entry("b", "h2", "/b.rs", 1, now - Duration::from_secs(300));
        assert!(
            recent.eviction_score(now) > old.eviction_score(now),
            "recently accessed entries should score higher"
        );
    }

    #[test]
    fn eviction_score_prefers_frequent() {
        let now = Instant::now();
        let frequent = sample_entry("a", "h1", "/a.rs", 20, now);
        let rare = sample_entry("b", "h2", "/b.rs", 1, now);
        assert!(
            frequent.eviction_score(now) > rare.eviction_score(now),
            "frequently accessed entries should score higher"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        std::env::set_var("LEAN_CTX_CACHE_MAX_TOKENS", "50");
        let mut cache = SessionCache::new();
        cache.store("/old.rs", "a]".repeat(30));
        cache.store("/new.rs", "b ".repeat(30));
        assert!(
            cache.total_cached_tokens() <= 60,
            "eviction should have kicked in"
        );
        std::env::remove_var("LEAN_CTX_CACHE_MAX_TOKENS");
    }
}