1use md5::{Digest, Md5};
2use std::collections::HashMap;
3use std::time::Instant;
4
5use super::tokens::count_tokens;
6
7fn normalize_key(path: &str) -> String {
8 crate::hooks::normalize_tool_path(path)
9}
10
/// Returns the token budget for the cache.
///
/// The budget is read from the `LEAN_CTX_CACHE_MAX_TOKENS` environment
/// variable on every call. If the variable is unset, not valid Unicode, or
/// does not parse as a `usize`, the default of 500 000 tokens is used.
fn max_cache_tokens() -> usize {
    const ENV_VAR: &str = "LEAN_CTX_CACHE_MAX_TOKENS";
    const DEFAULT_MAX_TOKENS: usize = 500_000;

    std::env::var(ENV_VAR)
        .ok()
        // Trim first: values coming from env files or shell substitution often
        // carry a trailing newline/space, which would otherwise make `parse`
        // fail and silently fall back to the default.
        .and_then(|raw| raw.trim().parse().ok())
        .unwrap_or(DEFAULT_MAX_TOKENS)
}
17
/// One cached file together with the bookkeeping used for hit detection and
/// eviction.
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct CacheEntry {
    /// The cached file text.
    pub content: String,
    /// Digest of `content`; `SessionCache::store` compares it to detect changes.
    pub hash: String,
    /// Number of lines in `content`.
    pub line_count: usize,
    /// Token count of `content` as measured when it was stored.
    pub original_tokens: usize,
    /// How many times this entry has been stored or served.
    pub read_count: u32,
    /// Normalized path this entry is keyed under.
    pub path: String,
    /// Moment of the most recent store or hit.
    pub last_access: Instant,
}

impl CacheEntry {
    /// Scores how valuable this entry is to keep; lower scores are evicted first.
    ///
    /// The score is a weighted sum of three terms:
    /// * recency   (weight 0.4): `1 / (1 + sqrt(seconds since last access))`
    /// * frequency (weight 0.3): `ln(read_count + 1)`
    /// * size      (weight 0.3): `ln(original_tokens + 1)`
    ///
    /// `now` is the reference instant. If it lies before `last_access`, the
    /// elapsed time is treated as zero.
    pub fn eviction_score(&self, now: Instant) -> f64 {
        // `duration_since` is documented as possibly panicking again in future
        // Rust versions when `last_access` is later than `now`; the saturating
        // variant pins the "clamp to zero" behavior explicitly.
        let elapsed = now
            .saturating_duration_since(self.last_access)
            .as_secs_f64();
        let recency = 1.0 / (1.0 + elapsed.sqrt());
        let frequency = (f64::from(self.read_count) + 1.0).ln();
        let size_value = (self.original_tokens as f64 + 1.0).ln();
        recency * 0.4 + frequency * 0.3 + size_value * 0.3
    }
}
41
/// Running counters describing how the cache has been used.
#[derive(Debug)]
pub struct CacheStats {
    /// Number of reads recorded (stores and explicit cache hits).
    pub total_reads: u64,
    /// Number of those reads that were served as cache hits.
    pub cache_hits: u64,
    /// Sum of the token counts of the content behind every read.
    pub total_original_tokens: u64,
    /// Sum of the token counts actually accounted as sent.
    pub total_sent_tokens: u64,
    /// Counter bumped each time a new entry is inserted.
    pub files_tracked: usize,
}

#[allow(dead_code)]
impl CacheStats {
    /// Percentage (0–100) of reads that were cache hits; `0.0` when nothing
    /// has been read yet.
    pub fn hit_rate(&self) -> f64 {
        match self.total_reads {
            0 => 0.0,
            reads => (self.cache_hits as f64 / reads as f64) * 100.0,
        }
    }

    /// Tokens not sent thanks to caching: original minus sent, floored at zero.
    pub fn tokens_saved(&self) -> u64 {
        let original = self.total_original_tokens;
        let sent = self.total_sent_tokens;
        original.saturating_sub(sent)
    }

    /// `tokens_saved` as a percentage (0–100) of the original tokens; `0.0`
    /// when no original tokens have been recorded.
    pub fn savings_percent(&self) -> f64 {
        let original = self.total_original_tokens;
        if original != 0 {
            let saved = self.tokens_saved();
            (saved as f64 / original as f64) * 100.0
        } else {
            0.0
        }
    }
}
72
/// A piece of text that `SessionCache::apply_dedup` can replace with a short
/// `[= <canonical_ref>:<start_line>-<end_line>]` marker.
#[derive(Debug, Clone)]
pub struct SharedBlock {
    /// Path that owns the block; content read from this path is never rewritten.
    pub canonical_path: String,
    /// Label printed in the replacement marker.
    pub canonical_ref: String,
    /// First line number printed in the replacement marker.
    pub start_line: usize,
    /// Last line number printed in the replacement marker.
    pub end_line: usize,
    /// Exact text that is searched for and replaced.
    pub content: String,
}
82
83pub struct SessionCache {
84 entries: HashMap<String, CacheEntry>,
85 file_refs: HashMap<String, String>,
86 next_ref: usize,
87 stats: CacheStats,
88 shared_blocks: Vec<SharedBlock>,
89}
90
91impl Default for SessionCache {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97impl SessionCache {
98 pub fn new() -> Self {
99 Self {
100 entries: HashMap::new(),
101 file_refs: HashMap::new(),
102 next_ref: 1,
103 shared_blocks: Vec::new(),
104 stats: CacheStats {
105 total_reads: 0,
106 cache_hits: 0,
107 total_original_tokens: 0,
108 total_sent_tokens: 0,
109 files_tracked: 0,
110 },
111 }
112 }
113
114 pub fn get_file_ref(&mut self, path: &str) -> String {
115 let key = normalize_key(path);
116 if let Some(r) = self.file_refs.get(&key) {
117 return r.clone();
118 }
119 let r = format!("F{}", self.next_ref);
120 self.next_ref += 1;
121 self.file_refs.insert(key, r.clone());
122 r
123 }
124
125 pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
126 self.file_refs.get(&normalize_key(path)).cloned()
127 }
128
129 pub fn get(&self, path: &str) -> Option<&CacheEntry> {
130 self.entries.get(&normalize_key(path))
131 }
132
133 pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
134 let key = normalize_key(path);
135 let ref_label = self
136 .file_refs
137 .get(&key)
138 .cloned()
139 .unwrap_or_else(|| "F?".to_string());
140 if let Some(entry) = self.entries.get_mut(&key) {
141 entry.read_count += 1;
142 entry.last_access = Instant::now();
143 self.stats.total_reads += 1;
144 self.stats.cache_hits += 1;
145 self.stats.total_original_tokens += entry.original_tokens as u64;
146 let hit_msg = format!(
147 "{ref_label} cached {}t {}L",
148 entry.read_count, entry.line_count
149 );
150 self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
151 Some(entry)
152 } else {
153 None
154 }
155 }
156
157 pub fn store(&mut self, path: &str, content: String) -> (CacheEntry, bool) {
158 let key = normalize_key(path);
159 let hash = compute_md5(&content);
160 let line_count = content.lines().count();
161 let original_tokens = count_tokens(&content);
162 let now = Instant::now();
163
164 self.stats.total_reads += 1;
165 self.stats.total_original_tokens += original_tokens as u64;
166
167 if let Some(existing) = self.entries.get_mut(&key) {
168 existing.last_access = now;
169 if existing.hash == hash {
170 existing.read_count += 1;
171 self.stats.cache_hits += 1;
172 let hit_msg = format!(
173 "{} cached {}t {}L",
174 self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
175 existing.read_count,
176 existing.line_count,
177 );
178 let sent = count_tokens(&hit_msg) as u64;
179 self.stats.total_sent_tokens += sent;
180 return (existing.clone(), true);
181 }
182 existing.content = content;
183 existing.hash = hash.clone();
184 existing.line_count = line_count;
185 existing.original_tokens = original_tokens;
186 existing.read_count += 1;
187 self.stats.total_sent_tokens += original_tokens as u64;
188 return (existing.clone(), false);
189 }
190
191 self.evict_if_needed(original_tokens);
192 self.get_file_ref(&key);
193
194 let entry = CacheEntry {
195 content,
196 hash,
197 line_count,
198 original_tokens,
199 read_count: 1,
200 path: key.clone(),
201 last_access: now,
202 };
203
204 self.entries.insert(key, entry.clone());
205 self.stats.files_tracked += 1;
206 self.stats.total_sent_tokens += original_tokens as u64;
207 (entry, false)
208 }
209
210 pub fn total_cached_tokens(&self) -> usize {
211 self.entries.values().map(|e| e.original_tokens).sum()
212 }
213
214 pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
216 let max_tokens = max_cache_tokens();
217 let current = self.total_cached_tokens();
218 if current + incoming_tokens <= max_tokens {
219 return;
220 }
221
222 let now = Instant::now();
223 let mut scored: Vec<(String, f64)> = self
224 .entries
225 .iter()
226 .map(|(path, entry)| (path.clone(), entry.eviction_score(now)))
227 .collect();
228 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
229
230 let mut freed = 0usize;
231 let target = (current + incoming_tokens).saturating_sub(max_tokens);
232 for (path, _score) in &scored {
233 if freed >= target {
234 break;
235 }
236 if let Some(entry) = self.entries.remove(path) {
237 freed += entry.original_tokens;
238 self.file_refs.remove(path);
239 }
240 }
241 }
242
243 pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
244 self.entries.iter().collect()
245 }
246
247 pub fn get_stats(&self) -> &CacheStats {
248 &self.stats
249 }
250
251 pub fn file_ref_map(&self) -> &HashMap<String, String> {
252 &self.file_refs
253 }
254
255 #[allow(dead_code)]
256 pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
257 self.shared_blocks = blocks;
258 }
259
260 #[allow(dead_code)]
261 pub fn get_shared_blocks(&self) -> &[SharedBlock] {
262 &self.shared_blocks
263 }
264
265 #[allow(dead_code)]
267 pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
268 if self.shared_blocks.is_empty() {
269 return None;
270 }
271 let refs: Vec<&SharedBlock> = self
272 .shared_blocks
273 .iter()
274 .filter(|b| b.canonical_path != path && content.contains(&b.content))
275 .collect();
276 if refs.is_empty() {
277 return None;
278 }
279 let mut result = content.to_string();
280 for block in refs {
281 result = result.replacen(
282 &block.content,
283 &format!(
284 "[= {}:{}-{}]",
285 block.canonical_ref, block.start_line, block.end_line
286 ),
287 1,
288 );
289 }
290 Some(result)
291 }
292
293 pub fn invalidate(&mut self, path: &str) -> bool {
294 self.entries.remove(&normalize_key(path)).is_some()
295 }
296
297 pub fn clear(&mut self) -> usize {
298 let count = self.entries.len();
299 self.entries.clear();
300 self.file_refs.clear();
301 self.shared_blocks.clear();
302 self.next_ref = 1;
303 self.stats = CacheStats {
304 total_reads: 0,
305 cache_hits: 0,
306 total_original_tokens: 0,
307 total_sent_tokens: 0,
308 files_tracked: 0,
309 };
310 count
311 }
312}
313
314fn compute_md5(content: &str) -> String {
315 let mut hasher = Md5::new();
316 hasher.update(content.as_bytes());
317 format!("{:x}", hasher.finalize())
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-line, ten-token entry; only the given fields vary.
    fn sample_entry(
        content: &str,
        hash: &str,
        path: &str,
        read_count: u32,
        last_access: Instant,
    ) -> CacheEntry {
        CacheEntry {
            content: content.to_string(),
            hash: hash.to_string(),
            line_count: 1,
            original_tokens: 10,
            read_count,
            path: path.to_string(),
            last_access,
        }
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let stored = cache.store("/test/file.rs", "fn main() {}".to_string());
        assert!(!stored.1);
        assert_eq!(stored.0.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content".to_string());
        let was_hit = cache.store("/test/file.rs", "content".to_string()).1;
        assert!(was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content".to_string());
        let was_hit = cache.store("/test/file.rs", "new content".to_string()).1;
        assert!(!was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        let first = cache.get_file_ref("/a.rs");
        assert_eq!(first, "F1");
        let second = cache.get_file_ref("/b.rs");
        assert_eq!(second, "F2");
        // Asking again for a known path must return the label it already has.
        let repeated = cache.get_file_ref("/a.rs");
        assert_eq!(repeated, "F1");
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a".to_string());
        cache.store("/b.rs", "b".to_string());
        let cleared = cache.clear();
        assert_eq!(cleared, 2);
        assert!(cache.get("/a.rs").is_none());
        // Label numbering starts over after a clear.
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test".to_string());
        let removed_existing = cache.invalidate("/test.rs");
        assert!(removed_existing);
        let removed_missing = cache.invalidate("/nonexistent.rs");
        assert!(!removed_missing);
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello".to_string());
        // Second store with identical content counts as a hit.
        cache.store("/a.rs", "hello".to_string());
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let first = compute_md5("test content");
        let second = compute_md5("test content");
        assert_eq!(first, second);
        let other = compute_md5("different");
        assert_ne!(first, other);
    }

    #[test]
    fn eviction_score_prefers_recent() {
        let now = Instant::now();
        let recent = sample_entry("a", "h1", "/a.rs", 1, now);
        let old = sample_entry(
            "b",
            "h2",
            "/b.rs",
            1,
            now - std::time::Duration::from_secs(300),
        );
        assert!(
            recent.eviction_score(now) > old.eviction_score(now),
            "recently accessed entries should score higher"
        );
    }

    #[test]
    fn eviction_score_prefers_frequent() {
        let now = Instant::now();
        let frequent = sample_entry("a", "h1", "/a.rs", 20, now);
        let rare = sample_entry("b", "h2", "/b.rs", 1, now);
        assert!(
            frequent.eviction_score(now) > rare.eviction_score(now),
            "frequently accessed entries should score higher"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        std::env::set_var("LEAN_CTX_CACHE_MAX_TOKENS", "50");
        let mut cache = SessionCache::new();

        let big_content = "a]".repeat(30);
        cache.store("/old.rs", big_content);

        let new_content = "b ".repeat(30);
        cache.store("/new.rs", new_content);

        let cached_total = cache.total_cached_tokens();
        assert!(cached_total <= 60, "eviction should have kicked in");
        std::env::remove_var("LEAN_CTX_CACHE_MAX_TOKENS");
    }
}