use md5::{Digest, Md5};
use std::collections::HashMap;
use std::time::{Instant, SystemTime};
use super::tokens::count_tokens;
/// Map key used for `path` throughout this module.
///
/// Delegates to the crate-wide tool-path normalizer so that every lookup,
/// insert and reference label uses the same spelling of a path.
fn normalize_key(path: &str) -> String {
    use crate::core::pathutil::normalize_tool_path;
    normalize_tool_path(path)
}
/// Token budget for the whole cache.
///
/// Read from `LEAN_CTX_CACHE_MAX_TOKENS` on every call; a missing or
/// unparsable value falls back to 500 000 tokens.
fn max_cache_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 500_000;
    match std::env::var("LEAN_CTX_CACHE_MAX_TOKENS") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_MAX_TOKENS),
        Err(_) => DEFAULT_MAX_TOKENS,
    }
}
/// One cached file: its content (kept compressed in memory), metadata derived
/// from that content, and per-file bookkeeping used by `SessionCache`.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    /// Content as produced by `zstd_compress`; read it back with `content()`.
    compressed_content: Vec<u8>,
    /// Content hash used to detect changes; `SessionCache::store` fills it
    /// with the hex digest from `compute_md5`.
    pub hash: String,
    /// Number of lines in the content (`str::lines().count()` in `store`).
    pub line_count: usize,
    /// Token count of the uncompressed content (from `count_tokens` in `store`).
    pub original_tokens: usize,
    /// Number of reads recorded for this entry; starts at 1 in `new`.
    pub read_count: u32,
    /// Path the entry was stored under; `store` passes the normalized key.
    pub path: String,
    /// Time of creation or most recent read; drives eviction ordering.
    pub last_access: Instant,
    /// File modification time seen by the latest `store` that could stat the file.
    pub stored_mtime: Option<SystemTime>,
    /// Rendered outputs keyed by mode (e.g. "map"), capped by `set_compressed`.
    pub compressed_outputs: HashMap<String, String>,
    /// Set by `mark_full_delivered`; reset when `store` sees changed content.
    pub full_content_delivered: bool,
}
/// Compression level passed to zstd for cached file bodies.
const ZSTD_LEVEL: i32 = 3;

/// Compresses `data` with zstd at `ZSTD_LEVEL`.
///
/// Best effort: if encoding fails, the raw UTF-8 bytes are returned unchanged
/// so the content is never lost.
fn zstd_compress(data: &str) -> Vec<u8> {
    let raw = data.as_bytes();
    match zstd::encode_all(raw, ZSTD_LEVEL) {
        Ok(packed) => packed,
        Err(_) => raw.to_vec(),
    }
}
/// Restores text previously produced by `zstd_compress`.
///
/// `zstd_compress` falls back to storing the raw UTF-8 bytes when encoding
/// fails, so bytes that are not a zstd frame are interpreted as plain UTF-8
/// instead of being discarded. Valid UTF-8 can never begin with the zstd frame
/// magic (`0x28 0xB5 ...`, where `0xB5` is a continuation byte), so a corrupted
/// zstd frame still decodes to an empty string rather than to garbage.
/// Returns an empty string when the bytes are neither a decodable frame nor
/// valid UTF-8.
fn zstd_decompress(data: &[u8]) -> String {
    match zstd::decode_all(data) {
        Ok(decoded) => String::from_utf8(decoded).unwrap_or_default(),
        // Uncompressed fallback written by `zstd_compress`.
        Err(_) => String::from_utf8(data.to_vec()).unwrap_or_default(),
    }
}
impl CacheEntry {
    /// Builds a fresh entry for `content`, compressing it immediately.
    ///
    /// The new entry has `read_count == 1`, `last_access` set to now, no cached
    /// compressed-mode outputs and `full_content_delivered == false`.
    pub fn new(
        content: &str,
        hash: String,
        line_count: usize,
        original_tokens: usize,
        path: String,
        stored_mtime: Option<SystemTime>,
    ) -> Self {
        Self {
            compressed_content: zstd_compress(content),
            hash,
            line_count,
            original_tokens,
            read_count: 1,
            path,
            last_access: Instant::now(),
            stored_mtime,
            compressed_outputs: HashMap::default(),
            full_content_delivered: false,
        }
    }

    /// Decompresses and returns the cached content.
    pub fn content(&self) -> String {
        zstd_decompress(self.compressed_content.as_slice())
    }

    /// Replaces the cached content (compressing it). Other metadata such as
    /// `hash` or `line_count` is left for the caller to update.
    pub fn set_content(&mut self, content: &str) {
        let packed = zstd_compress(content);
        self.compressed_content = packed;
    }

    /// Size in bytes of the content as held in memory.
    pub fn compressed_size(&self) -> usize {
        self.compressed_content.len()
    }
}
/// Outcome of `SessionCache::store`.
#[derive(Debug, Clone)]
pub struct StoreResult {
    /// Line count of the content now cached.
    pub line_count: usize,
    /// Token count of the content now cached.
    pub original_tokens: usize,
    /// The entry's read count after this store (1 for a newly inserted entry).
    pub read_count: u32,
    /// True when the stored content's hash matched the already-cached entry.
    pub was_hit: bool,
    /// The entry's `full_content_delivered` flag on a hit; always false otherwise.
    pub full_content_delivered: bool,
}
impl CacheEntry {
    /// Weighted "worth keeping" score (higher = more valuable) used by the
    /// legacy eviction policy:
    /// `0.4 * 1/(1 + sqrt(idle_secs)) + 0.3 * ln(read_count + 1) + 0.3 * ln(original_tokens + 1)`.
    ///
    /// If `last_access` is later than `now`, the idle time is taken as zero.
    pub fn eviction_score_legacy(&self, now: Instant) -> f64 {
        let idle_secs = match now.checked_duration_since(self.last_access) {
            Some(idle) => idle.as_secs_f64(),
            None => 0.0,
        };
        let recency = 1.0 / (1.0 + idle_secs.sqrt());
        let frequency = (f64::from(self.read_count) + 1.0).ln();
        let size_value = (self.original_tokens as f64 + 1.0).ln();
        0.4 * recency + 0.3 * frequency + 0.3 * size_value
    }

    /// Returns the cached rendering for `mode_key`, if any.
    pub fn get_compressed(&self, mode_key: &str) -> Option<&String> {
        self.compressed_outputs.get(mode_key)
    }

    /// Caches `output` under `mode_key`, keeping at most three renderings.
    ///
    /// When the cap is reached and `mode_key` is new, one existing rendering is
    /// dropped first. `HashMap` iteration order is unspecified, so the dropped
    /// variant is arbitrary — not necessarily the oldest one.
    pub fn set_compressed(&mut self, mode_key: &str, output: String) {
        const MAX_COMPRESSED_VARIANTS: usize = 3;
        let at_capacity = self.compressed_outputs.len() >= MAX_COMPRESSED_VARIANTS;
        let is_new_key = !self.compressed_outputs.contains_key(mode_key);
        if at_capacity && is_new_key {
            let victim = self.compressed_outputs.keys().next().cloned();
            if let Some(victim) = victim {
                self.compressed_outputs.remove(&victim);
            }
        }
        self.compressed_outputs.insert(mode_key.to_owned(), output);
    }

    /// Records that the full content has been delivered.
    pub fn mark_full_delivered(&mut self) {
        self.full_content_delivered = true;
    }
}
/// Constant `k` of Reciprocal Rank Fusion: a rank `r` contributes `1 / (k + r)`.
const RRF_K: f64 = 60.0;

/// Assigns 0-based standard competition ranks ("1224" ranking) to `keys`.
///
/// Smaller keys rank first, and equal keys share the rank of the first of
/// them, so an entry's rank never depends on its position in the input.
fn rrf_competition_ranks<K: Ord>(keys: &[K]) -> Vec<usize> {
    let mut order: Vec<usize> = (0..keys.len()).collect();
    order.sort_by(|&a, &b| keys[a].cmp(&keys[b]));

    let mut ranks = vec![0usize; keys.len()];
    let mut previous: Option<usize> = None;
    for (position, &idx) in order.iter().enumerate() {
        ranks[idx] = match previous {
            Some(prev) if keys[prev] == keys[idx] => ranks[prev],
            _ => position,
        };
        previous = Some(idx);
    }
    ranks
}

/// Scores cache entries with Reciprocal Rank Fusion over three rankings:
/// recency (least idle first), frequency (highest `read_count` first) and size
/// (largest `original_tokens` first).
///
/// A higher score means the entry is more valuable to keep. Entries that tie on
/// a metric share that metric's rank, so identical entries receive identical
/// scores regardless of their order in `entries`. An entry whose `last_access`
/// is later than `now` is treated as having zero idle time.
///
/// Returns one `(path, score)` pair per entry, in input order; empty input
/// yields an empty vector.
pub fn eviction_scores_rrf(entries: &[(&String, &CacheEntry)], now: Instant) -> Vec<(String, f64)> {
    if entries.is_empty() {
        return Vec::new();
    }

    // Compute each ranking key once instead of inside the comparators.
    let idle: Vec<_> = entries
        .iter()
        .map(|(_, e)| now.checked_duration_since(e.last_access).unwrap_or_default())
        .collect();
    let reads: Vec<_> = entries
        .iter()
        .map(|(_, e)| std::cmp::Reverse(e.read_count))
        .collect();
    let sizes: Vec<_> = entries
        .iter()
        .map(|(_, e)| std::cmp::Reverse(e.original_tokens))
        .collect();

    let recency_ranks = rrf_competition_ranks(&idle);
    let frequency_ranks = rrf_competition_ranks(&reads);
    let size_ranks = rrf_competition_ranks(&sizes);

    entries
        .iter()
        .enumerate()
        .map(|(i, (path, _))| {
            let score = 1.0 / (RRF_K + recency_ranks[i] as f64)
                + 1.0 / (RRF_K + frequency_ranks[i] as f64)
                + 1.0 / (RRF_K + size_ranks[i] as f64);
            ((*path).clone(), score)
        })
        .collect()
}
/// Running counters for a cache session.
#[derive(Debug)]
pub struct CacheStats {
    /// Reads recorded (every `store` call plus every recorded cache hit).
    pub total_reads: u64,
    /// Reads that were served from the cache.
    pub cache_hits: u64,
    /// Tokens the reads would have cost without caching.
    pub total_original_tokens: u64,
    /// Tokens actually charged (full content on misses, a short notice on hits).
    pub total_sent_tokens: u64,
    /// Number of entries inserted into the cache.
    pub files_tracked: usize,
}

impl CacheStats {
    /// Percentage (0–100) of reads served from cache; 0 when nothing was read.
    pub fn hit_rate(&self) -> f64 {
        match self.total_reads {
            0 => 0.0,
            reads => (self.cache_hits as f64 / reads as f64) * 100.0,
        }
    }

    /// Tokens not sent thanks to caching, clamped at zero.
    pub fn tokens_saved(&self) -> u64 {
        self.total_original_tokens.saturating_sub(self.total_sent_tokens)
    }

    /// `tokens_saved` as a percentage of the original tokens; 0 when there were none.
    pub fn savings_percent(&self) -> f64 {
        if self.total_original_tokens > 0 {
            let saved = self.tokens_saved() as f64;
            (saved / self.total_original_tokens as f64) * 100.0
        } else {
            0.0
        }
    }
}
/// A block of text shared between files. `SessionCache::apply_dedup` replaces
/// other files' copies of `content` with a `[= ref:start-end]` marker.
#[derive(Clone, Debug)]
pub struct SharedBlock {
    /// Path of the file holding the canonical copy; that file is not
    /// deduplicated against this block.
    pub canonical_path: String,
    /// Label written into the marker (presumably a file ref such as `F1` —
    /// confirm with whatever builds these blocks).
    pub canonical_ref: String,
    /// First line of the block in the canonical file (0- or 1-based is not
    /// visible here — TODO confirm).
    pub start_line: usize,
    /// Last line of the block in the canonical file (inclusivity not visible
    /// here — TODO confirm).
    pub end_line: usize,
    /// Exact text searched for and replaced in other files.
    pub content: String,
}
pub struct SessionCache {
entries: HashMap<String, CacheEntry>,
file_refs: HashMap<String, String>,
next_ref: usize,
stats: CacheStats,
shared_blocks: Vec<SharedBlock>,
}
impl Default for SessionCache {
fn default() -> Self {
Self::new()
}
}
impl SessionCache {
pub fn new() -> Self {
Self {
entries: HashMap::new(),
file_refs: HashMap::new(),
next_ref: 1,
shared_blocks: Vec::new(),
stats: CacheStats {
total_reads: 0,
cache_hits: 0,
total_original_tokens: 0,
total_sent_tokens: 0,
files_tracked: 0,
},
}
}
pub fn get_file_ref(&mut self, path: &str) -> String {
let key = normalize_key(path);
if let Some(r) = self.file_refs.get(&key) {
return r.clone();
}
let r = format!("F{}", self.next_ref);
self.next_ref += 1;
self.file_refs.insert(key, r.clone());
r
}
pub fn get_file_ref_readonly(&self, path: &str) -> Option<String> {
self.file_refs.get(&normalize_key(path)).cloned()
}
pub fn get(&self, path: &str) -> Option<&CacheEntry> {
self.entries.get(&normalize_key(path))
}
pub fn get_full_content(&self, path: &str) -> Option<String> {
self.entries
.get(&normalize_key(path))
.map(CacheEntry::content)
}
pub fn record_cache_hit(&mut self, path: &str) -> Option<&CacheEntry> {
let key = normalize_key(path);
let ref_label = self
.file_refs
.get(&key)
.cloned()
.unwrap_or_else(|| "F?".to_string());
if let Some(entry) = self.entries.get_mut(&key) {
entry.read_count += 1;
entry.last_access = Instant::now();
self.stats.total_reads += 1;
self.stats.cache_hits += 1;
self.stats.total_original_tokens += entry.original_tokens as u64;
let hit_msg = format!(
"{ref_label} cached {}t {}L",
entry.read_count, entry.line_count
);
self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
crate::core::events::emit_cache_hit(path, entry.original_tokens as u64);
Some(entry)
} else {
None
}
}
pub fn store(&mut self, path: &str, content: &str) -> StoreResult {
let key = normalize_key(path);
let hash = compute_md5(content);
let line_count = content.lines().count();
let original_tokens = count_tokens(content);
let stored_mtime = std::fs::metadata(path).and_then(|m| m.modified()).ok();
let now = Instant::now();
self.stats.total_reads += 1;
self.stats.total_original_tokens += original_tokens as u64;
if let Some(existing) = self.entries.get_mut(&key) {
existing.last_access = now;
if stored_mtime.is_some() {
existing.stored_mtime = stored_mtime;
}
if existing.hash == hash {
existing.read_count += 1;
self.stats.cache_hits += 1;
let hit_msg = format!(
"{} cached {}t {}L",
self.file_refs.get(&key).unwrap_or(&"F?".to_string()),
existing.read_count,
existing.line_count,
);
self.stats.total_sent_tokens += count_tokens(&hit_msg) as u64;
return StoreResult {
line_count: existing.line_count,
original_tokens: existing.original_tokens,
read_count: existing.read_count,
was_hit: true,
full_content_delivered: existing.full_content_delivered,
};
}
existing.compressed_outputs.clear();
existing.set_content(content);
existing.hash = hash;
existing.line_count = line_count;
existing.original_tokens = original_tokens;
existing.read_count += 1;
existing.full_content_delivered = false;
if stored_mtime.is_some() {
existing.stored_mtime = stored_mtime;
}
self.stats.total_sent_tokens += original_tokens as u64;
return StoreResult {
line_count,
original_tokens,
read_count: existing.read_count,
was_hit: false,
full_content_delivered: false,
};
}
self.evict_if_needed(original_tokens);
self.get_file_ref(&key);
let entry = CacheEntry::new(
content,
hash,
line_count,
original_tokens,
key.clone(),
stored_mtime,
);
self.entries.insert(key, entry);
self.stats.files_tracked += 1;
self.stats.total_sent_tokens += original_tokens as u64;
StoreResult {
line_count,
original_tokens,
read_count: 1,
was_hit: false,
full_content_delivered: false,
}
}
pub fn total_cached_tokens(&self) -> usize {
self.entries.values().map(|e| e.original_tokens).sum()
}
pub fn evict_if_needed(&mut self, incoming_tokens: usize) {
let max_tokens = max_cache_tokens();
let current = self.total_cached_tokens();
if current + incoming_tokens <= max_tokens {
return;
}
let mut freed = 0usize;
let target = (current + incoming_tokens).saturating_sub(max_tokens);
let mut probationary: Vec<(String, Instant)> = self
.entries
.iter()
.filter(|(_, e)| e.read_count <= 1)
.map(|(p, e)| (p.clone(), e.last_access))
.collect();
probationary.sort_by_key(|(_, t)| *t);
let mut protected: Vec<(String, Instant)> = self
.entries
.iter()
.filter(|(_, e)| e.read_count > 1)
.map(|(p, e)| (p.clone(), e.last_access))
.collect();
protected.sort_by_key(|(_, t)| *t);
for (path, _) in probationary.into_iter().chain(protected) {
if freed >= target {
break;
}
if let Some(entry) = self.entries.remove(&path) {
freed += entry.original_tokens;
self.file_refs.remove(&path);
}
}
}
pub fn get_all_entries(&self) -> Vec<(&String, &CacheEntry)> {
self.entries.iter().collect()
}
pub fn get_stats(&self) -> &CacheStats {
&self.stats
}
pub fn file_ref_map(&self) -> &HashMap<String, String> {
&self.file_refs
}
pub fn set_shared_blocks(&mut self, blocks: Vec<SharedBlock>) {
self.shared_blocks = blocks;
}
pub fn get_shared_blocks(&self) -> &[SharedBlock] {
&self.shared_blocks
}
pub fn apply_dedup(&self, path: &str, content: &str) -> Option<String> {
if self.shared_blocks.is_empty() {
return None;
}
let refs: Vec<&SharedBlock> = self
.shared_blocks
.iter()
.filter(|b| b.canonical_path != path && content.contains(&b.content))
.collect();
if refs.is_empty() {
return None;
}
let mut result = content.to_string();
for block in refs {
result = result.replacen(
&block.content,
&format!(
"[= {}:{}-{}]",
block.canonical_ref, block.start_line, block.end_line
),
1,
);
}
Some(result)
}
pub fn invalidate(&mut self, path: &str) -> bool {
self.entries.remove(&normalize_key(path)).is_some()
}
pub fn get_compressed(&self, path: &str, mode_key: &str) -> Option<&String> {
self.entries
.get(&normalize_key(path))?
.get_compressed(mode_key)
}
pub fn mark_full_delivered(&mut self, path: &str) {
if let Some(entry) = self.entries.get_mut(&normalize_key(path)) {
entry.mark_full_delivered();
}
}
pub fn set_compressed(&mut self, path: &str, mode_key: &str, output: String) {
if let Some(entry) = self.entries.get_mut(&normalize_key(path)) {
entry.set_compressed(mode_key, output);
}
}
pub fn clear(&mut self) -> usize {
let count = self.entries.len();
self.entries.clear();
self.file_refs.clear();
self.shared_blocks.clear();
self.next_ref = 1;
self.stats = CacheStats {
total_reads: 0,
cache_hits: 0,
total_original_tokens: 0,
total_sent_tokens: 0,
files_tracked: 0,
};
count
}
}
/// Returns the filesystem modification time of `path`.
///
/// Returns `None` if the file cannot be stat'ed, or if the platform does not
/// report modification times.
pub fn file_mtime(path: &str) -> Option<SystemTime> {
    std::fs::metadata(path).and_then(|meta| meta.modified()).ok()
}

/// Reports whether a cache entry for `path` recorded at `cached_mtime` is stale.
///
/// * The current mtime cannot be read (file missing/unreadable): not stale,
///   since there is nothing to reload.
/// * No mtime was recorded, but one is readable now: stale.
/// * Otherwise: stale whenever the mtime differs from the recorded one. Any
///   difference counts, not only a newer timestamp. Tools such as `cp -p`,
///   `tar` and `rsync -t` can replace a file with one carrying an older mtime.
pub fn is_cache_entry_stale(path: &str, cached_mtime: Option<SystemTime>) -> bool {
    match (cached_mtime, file_mtime(path)) {
        (_, None) => false,
        (None, Some(_)) => true,
        (Some(cached), Some(current)) => current != cached,
    }
}
/// Lowercase hex MD5 digest of `content`, used to detect content changes.
fn compute_md5(content: &str) -> String {
    let digest = Md5::digest(content.as_bytes());
    format!("{:x}", digest)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Sets an environment variable for the guard's lifetime.
    ///
    /// On drop, restores the previous value (or removes the variable), even if
    /// the test panics. The process environment is still shared with tests
    /// running in parallel.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<String>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var(key).ok();
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn cache_stores_and_retrieves() {
        let mut cache = SessionCache::new();
        let result = cache.store("/test/file.rs", "fn main() {}");
        assert!(!result.was_hit);
        assert_eq!(result.line_count, 1);
        assert!(cache.get("/test/file.rs").is_some());
    }

    #[test]
    fn cache_hit_on_same_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "content");
        let result = cache.store("/test/file.rs", "content");
        assert!(result.was_hit, "same content should be a cache hit");
    }

    #[test]
    fn cache_miss_on_changed_content() {
        let mut cache = SessionCache::new();
        cache.store("/test/file.rs", "old content");
        let result = cache.store("/test/file.rs", "new content");
        assert!(!result.was_hit, "changed content should not be a cache hit");
    }

    #[test]
    fn file_refs_are_sequential() {
        let mut cache = SessionCache::new();
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
        assert_eq!(cache.get_file_ref("/b.rs"), "F2");
        // Asking again for a known path returns its existing label.
        assert_eq!(cache.get_file_ref("/a.rs"), "F1");
    }

    #[test]
    fn cache_clear_resets_everything() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a");
        cache.store("/b.rs", "b");
        let count = cache.clear();
        assert_eq!(count, 2);
        assert!(cache.get("/a.rs").is_none());
        // Label numbering restarts after a clear.
        assert_eq!(cache.get_file_ref("/c.rs"), "F1");
    }

    #[test]
    fn cache_invalidate_removes_entry() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "test");
        assert!(cache.invalidate("/test.rs"));
        assert!(!cache.invalidate("/nonexistent.rs"));
    }

    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "hello");
        cache.store("/a.rs", "hello");
        let stats = cache.get_stats();
        assert_eq!(stats.total_reads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate() > 0.0);
    }

    #[test]
    fn md5_is_deterministic() {
        let h1 = compute_md5("test content");
        let h2 = compute_md5("test content");
        assert_eq!(h1, h2);
        assert_ne!(h1, compute_md5("different"));
    }

    #[test]
    fn rrf_eviction_prefers_recent() {
        let base = Instant::now();
        std::thread::sleep(Duration::from_millis(5));
        let now = Instant::now();
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let recent = CacheEntry::new("a", "h1".to_string(), 1, 10, "/a.rs".to_string(), None);
        let old = {
            let mut e = CacheEntry::new("b", "h2".to_string(), 1, 10, "/b.rs".to_string(), None);
            e.last_access = base;
            e
        };
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &recent), (&key_b, &old)];
        let scores = eviction_scores_rrf(&entries, now);
        let score_a = scores.iter().find(|(p, _)| p == "a.rs").unwrap().1;
        let score_b = scores.iter().find(|(p, _)| p == "b.rs").unwrap().1;
        assert!(
            score_a > score_b,
            "recently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn rrf_eviction_prefers_frequent() {
        let now = Instant::now();
        let key_a = "a.rs".to_string();
        let key_b = "b.rs".to_string();
        let frequent = {
            let mut e = CacheEntry::new("a", "h1".to_string(), 1, 10, "/a.rs".to_string(), None);
            e.read_count = 20;
            e
        };
        let rare = CacheEntry::new("b", "h2".to_string(), 1, 10, "/b.rs".to_string(), None);
        let entries: Vec<(&String, &CacheEntry)> = vec![(&key_a, &frequent), (&key_b, &rare)];
        let scores = eviction_scores_rrf(&entries, now);
        let score_a = scores.iter().find(|(p, _)| p == "a.rs").unwrap().1;
        let score_b = scores.iter().find(|(p, _)| p == "b.rs").unwrap().1;
        assert!(
            score_a > score_b,
            "frequently accessed entries should score higher via RRF"
        );
    }

    #[test]
    fn evict_if_needed_removes_lowest_score() {
        // Restored on drop, even if an assertion below fails.
        let _max_tokens = EnvVarGuard::set("LEAN_CTX_CACHE_MAX_TOKENS", "50");
        let mut cache = SessionCache::new();
        let big_content = "a]".repeat(30);
        cache.store("/old.rs", &big_content);
        let new_content = "b ".repeat(30);
        cache.store("/new.rs", &new_content);
        assert!(
            cache.total_cached_tokens() <= 60,
            "eviction should have kicked in"
        );
    }

    #[test]
    fn stale_detection_flags_newer_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("stale.txt");
        let p = path.to_string_lossy().to_string();
        std::fs::write(&path, "one").unwrap();
        let mut cache = SessionCache::new();
        cache.store(&p, "one");
        let entry = cache.get(&p).unwrap();
        assert!(!is_cache_entry_stale(&p, entry.stored_mtime));
        // Some filesystems have one-second mtime granularity.
        std::thread::sleep(Duration::from_secs(1));
        std::fs::write(&path, "two").unwrap();
        let entry = cache.get(&p).unwrap();
        assert!(is_cache_entry_stale(&p, entry.stored_mtime));
    }

    #[test]
    fn compressed_outputs_cached_and_retrieved() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "fn main() {}");
        cache.set_compressed("/test.rs", "map", "compressed map output".to_string());
        assert_eq!(
            cache.get_compressed("/test.rs", "map"),
            Some(&"compressed map output".to_string())
        );
        assert_eq!(cache.get_compressed("/test.rs", "signatures"), None);
    }

    #[test]
    fn compressed_outputs_cleared_on_content_change() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "old content");
        cache.set_compressed("/test.rs", "map", "old map".to_string());
        assert!(cache.get_compressed("/test.rs", "map").is_some());
        cache.store("/test.rs", "new content");
        assert_eq!(cache.get_compressed("/test.rs", "map"), None);
    }

    #[test]
    fn compressed_outputs_survive_same_content_store() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "content");
        cache.set_compressed("/test.rs", "map", "cached map".to_string());
        let result = cache.store("/test.rs", "content");
        assert!(result.was_hit);
        assert_eq!(
            cache.get_compressed("/test.rs", "map"),
            Some(&"cached map".to_string())
        );
    }

    #[test]
    fn compressed_outputs_cleared_on_invalidate() {
        let mut cache = SessionCache::new();
        cache.store("/test.rs", "content");
        cache.set_compressed("/test.rs", "signatures", "cached sigs".to_string());
        cache.invalidate("/test.rs");
        assert_eq!(cache.get_compressed("/test.rs", "signatures"), None);
    }

    #[test]
    fn compressed_outputs_cleared_on_clear() {
        let mut cache = SessionCache::new();
        cache.store("/a.rs", "a");
        cache.set_compressed("/a.rs", "map", "map_a".to_string());
        cache.clear();
        assert_eq!(cache.get_compressed("/a.rs", "map"), None);
    }
}