use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::LazyLock;
use std::time::{Duration, Instant};
use zeph_common::ToolName;
use crate::executor::ToolOutput;
/// Names of tools whose results are never cached.
///
/// Consulted by `is_cacheable` as a deny-list; built lazily on first access.
static NON_CACHEABLE_TOOLS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    let denied = ["bash", "memory_save", "memory_search", "scheduler", "write"];
    denied.iter().copied().collect()
});
/// Reports whether results of the tool named `tool_name` may be cached.
///
/// Returns `false` for every name that starts with `mcp_` and for every name
/// listed in `NON_CACHEABLE_TOOLS`; all other names return `true`.
#[must_use]
pub fn is_cacheable(tool_name: &str) -> bool {
    let has_mcp_prefix = tool_name.starts_with("mcp_");
    // The prefix check short-circuits, so the deny-list is only consulted
    // for non-MCP names.
    !(has_mcp_prefix || NON_CACHEABLE_TOOLS.contains(tool_name))
}
/// Lookup key for a cached tool result.
///
/// Two keys are equal (and hash identically) only when both the tool name and
/// the argument hash match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
/// Name of the tool the cached output belongs to.
pub tool_name: ToolName,
/// Caller-supplied 64-bit hash identifying the call's arguments; it is
/// computed outside this type.
pub args_hash: u64,
}
impl CacheKey {
    /// Builds a key from anything convertible into a [`ToolName`] plus a
    /// pre-computed argument hash.
    #[must_use]
    pub fn new(tool_name: impl Into<ToolName>, args_hash: u64) -> Self {
        let tool_name = tool_name.into();
        Self { tool_name, args_hash }
    }
}
/// A cached tool output together with the moment it was stored.
#[derive(Debug, Clone)]
pub struct CacheEntry {
/// The output that is cloned and handed back on a cache hit.
pub output: ToolOutput,
/// When the entry was stored; its elapsed time is compared against the TTL.
pub inserted_at: Instant,
}
impl CacheEntry {
    /// Returns `true` once strictly more than `ttl` has elapsed since the
    /// entry was stored.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.inserted_at.elapsed();
        age > ttl
    }
}
/// Entry count at which `ToolResultCache::put` starts evicting the oldest
/// queued key before storing a new one.
const MAX_CACHE_ENTRIES: usize = 512;
/// In-memory cache of tool outputs with optional age-based expiry, a fixed
/// entry cap, and hit/miss counters.
#[derive(Debug)]
pub struct ToolResultCache {
/// Cached outputs indexed by key.
entries: HashMap<CacheKey, CacheEntry>,
/// Keys in the order `put` queued them; the front is the next eviction candidate.
insertion_order: VecDeque<CacheKey>,
/// Maximum entry age; `None` disables age-based expiry.
ttl: Option<Duration>,
/// When `false`, `get` always returns `None` and `put` stores nothing.
enabled: bool,
/// Number of lookups that returned a cached output.
hits: u64,
/// Number of lookups recorded as misses by `get`.
misses: u64,
}
impl ToolResultCache {
    /// Creates an empty cache.
    ///
    /// * `enabled` — when `false`, [`get`](Self::get) always returns `None`
    ///   and [`put`](Self::put) stores nothing.
    /// * `ttl` — maximum age of an entry; `None` means entries never expire
    ///   by age.
    #[must_use]
    pub fn new(enabled: bool, ttl: Option<Duration>) -> Self {
        Self {
            entries: HashMap::new(),
            insertion_order: VecDeque::new(),
            ttl,
            enabled,
            hits: 0,
            misses: 0,
        }
    }

    /// Looks up `key` and returns a clone of the cached output on a hit.
    ///
    /// A lookup counts as a miss when the key is absent **or** when the entry
    /// has outlived the TTL. An expired entry is dropped from both the map and
    /// the eviction queue, so it cannot leave a stale queue slot behind.
    ///
    /// A disabled cache returns `None` without touching the counters.
    pub fn get(&mut self, key: &CacheKey) -> Option<ToolOutput> {
        if !self.enabled {
            return None;
        }
        if let Some(entry) = self.entries.get(key) {
            let expired = self.ttl.is_some_and(|ttl| entry.is_expired(ttl));
            if !expired {
                let output = entry.output.clone();
                self.hits += 1;
                return Some(output);
            }
            // Expired: remove the entry and its queue slot together so the
            // queue never references keys that are no longer in the map.
            self.entries.remove(key);
            self.insertion_order.retain(|queued| queued != key);
        }
        self.misses += 1;
        None
    }

    /// Stores `output` under `key`, stamping it with the current time.
    ///
    /// * Overwriting an existing key refreshes the entry and moves the key to
    ///   the back of the eviction queue; nothing is evicted because the map
    ///   does not grow.
    /// * Inserting a new key while the cache holds `MAX_CACHE_ENTRIES` entries
    ///   first evicts the oldest queued entry. Eviction follows insertion
    ///   order; lookups do not reorder the queue.
    ///
    /// A disabled cache ignores the call.
    pub fn put(&mut self, key: CacheKey, output: ToolOutput) {
        if !self.enabled {
            return;
        }
        if self.entries.contains_key(&key) {
            // Drop the old queue slot so each key is queued exactly once.
            self.insertion_order.retain(|queued| queued != &key);
        } else {
            // Loop (rather than a single pop) so a queue slot that no longer
            // has a matching entry can never let the map grow past the cap.
            while self.entries.len() >= MAX_CACHE_ENTRIES {
                let Some(oldest_key) = self.insertion_order.pop_front() else {
                    break;
                };
                if self.entries.remove(&oldest_key).is_some() {
                    tracing::debug!(
                        tool = %oldest_key.tool_name,
                        args_hash = oldest_key.args_hash,
                        "tool cache: evicted oldest entry (LRU cap {})",
                        MAX_CACHE_ENTRIES
                    );
                }
            }
        }
        self.insertion_order.push_back(key.clone());
        self.entries.insert(
            key,
            CacheEntry {
                output,
                inserted_at: Instant::now(),
            },
        );
    }

    /// Removes every entry and resets the hit/miss counters to zero.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.insertion_order.clear();
        self.hits = 0;
        self.misses = 0;
    }

    /// Number of entries currently stored (expired entries that have not been
    /// looked up yet are still counted).
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when no entries are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Number of lookups that returned a cached output since creation or the
    /// last [`clear`](Self::clear).
    #[must_use]
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Number of lookups that found no usable entry since creation or the
    /// last [`clear`](Self::clear).
    #[must_use]
    pub fn misses(&self) -> u64 {
        self.misses
    }

    /// Whether the cache was created enabled.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// The TTL in whole seconds, or `0` when no TTL is configured.
    ///
    /// Note that a configured TTL shorter than one second also reports `0`.
    #[must_use]
    pub fn ttl_secs(&self) -> u64 {
        self.ttl.map_or(0, |d| d.as_secs())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;

    /// Builds a minimal `ToolOutput` whose only distinguishing field is `summary`.
    fn make_output(summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: ToolName::new("test"),
            summary: String::from(summary),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        }
    }

    /// Shorthand for constructing a cache key.
    fn key(name: &str, hash: u64) -> CacheKey {
        CacheKey::new(name, hash)
    }

    /// An enabled cache with a five-minute TTL, the fixture most tests share.
    fn five_minute_cache() -> ToolResultCache {
        ToolResultCache::new(true, Some(Duration::from_mins(5)))
    }

    #[test]
    fn miss_on_empty_cache() {
        let mut cache = five_minute_cache();
        let lookup = cache.get(&key("read", 1));
        assert!(lookup.is_none());
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
    }

    #[test]
    fn put_then_get_returns_cached() {
        let mut cache = five_minute_cache();
        cache.put(key("read", 42), make_output("file contents"));

        let cached = cache.get(&key("read", 42));
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().summary, "file contents");
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn different_hash_is_miss() {
        let mut cache = five_minute_cache();
        cache.put(key("read", 1), make_output("a"));
        let other_hash = key("read", 2);
        assert!(cache.get(&other_hash).is_none());
    }

    #[test]
    fn different_tool_name_is_miss() {
        let mut cache = five_minute_cache();
        cache.put(key("read", 1), make_output("a"));
        let other_tool = key("write", 1);
        assert!(cache.get(&other_tool).is_none());
    }

    #[test]
    fn ttl_none_never_expires() {
        let mut cache = ToolResultCache::new(true, None);
        let stored = key("read", 1);
        cache.put(stored.clone(), make_output("content"));
        assert!(cache.get(&stored).is_some());
        assert_eq!(cache.hits(), 1);
    }

    #[test]
    fn ttl_zero_duration_expires_immediately() {
        let mut cache = ToolResultCache::new(true, Some(Duration::ZERO));
        cache.put(key("read", 1), make_output("content"));

        let lookup = cache.get(&key("read", 1));
        assert!(
            lookup.is_none(),
            "Duration::ZERO entry must expire on first get()"
        );
        assert_eq!(cache.len(), 0, "expired entry must be removed from map");
    }

    #[test]
    fn ttl_expired_returns_none() {
        let short_ttl = Duration::from_millis(1);
        let mut cache = ToolResultCache::new(true, Some(short_ttl));
        cache.put(key("read", 1), make_output("content"));

        // Wait well past the TTL before looking the entry up.
        std::thread::sleep(Duration::from_millis(10));

        assert!(cache.get(&key("read", 1)).is_none());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn clear_removes_all_and_resets_counters() {
        let mut cache = five_minute_cache();
        cache.put(key("read", 1), make_output("a"));
        cache.put(key("web_scrape", 2), make_output("b"));

        // One hit and one miss, so the counters are non-zero before clearing.
        let _ = cache.get(&key("read", 1));
        let _ = cache.get(&key("missing", 99));
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 1);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 0);
        assert!(cache.get(&key("read", 1)).is_none());
    }

    #[test]
    fn disabled_cache_always_misses() {
        let mut cache = ToolResultCache::new(false, Some(Duration::from_mins(5)));
        let probe = key("read", 1);
        cache.put(probe.clone(), make_output("content"));

        assert!(cache.get(&probe).is_none());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn is_cacheable_returns_false_for_deny_list() {
        for tool in ["bash", "memory_save", "memory_search", "scheduler", "write"] {
            assert!(!is_cacheable(tool));
        }
    }

    #[test]
    fn is_cacheable_returns_false_for_mcp_prefix() {
        for tool in ["mcp_github_list_issues", "mcp_send_email", "mcp_"] {
            assert!(!is_cacheable(tool));
        }
    }

    #[test]
    fn is_cacheable_returns_true_for_read_only_tools() {
        for tool in ["read", "web_scrape", "search_code", "load_skill", "diagnostics"] {
            assert!(is_cacheable(tool));
        }
    }

    #[test]
    fn counter_increments_correctly() {
        let mut cache = five_minute_cache();
        cache.put(key("read", 1), make_output("a"));
        cache.put(key("read", 2), make_output("b"));

        // Two hits on the same key, then one miss on an unknown key.
        let _ = cache.get(&key("read", 1));
        let _ = cache.get(&key("read", 1));
        let _ = cache.get(&key("read", 99));

        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
    }

    #[test]
    fn ttl_secs_returns_zero_for_none() {
        let without_ttl = ToolResultCache::new(true, None);
        assert_eq!(without_ttl.ttl_secs(), 0);
    }

    #[test]
    fn ttl_secs_returns_seconds_for_some() {
        let with_ttl = five_minute_cache();
        assert_eq!(with_ttl.ttl_secs(), 300);
    }

    #[test]
    fn lru_eviction_at_capacity() {
        let cap = MAX_CACHE_ENTRIES as u64;
        let mut cache = ToolResultCache::new(true, None);

        // Fill the cache exactly to its cap with distinct keys 0..cap.
        for hash in 0..cap {
            cache.put(key("read", hash), make_output("v"));
        }
        assert_eq!(cache.len(), MAX_CACHE_ENTRIES);

        // One more distinct key must push out the first one inserted.
        cache.put(key("read", cap), make_output("new"));
        assert_eq!(cache.len(), MAX_CACHE_ENTRIES, "size must stay at cap");
        assert!(
            cache.get(&key("read", 0)).is_none(),
            "oldest entry must be evicted"
        );
        assert!(
            cache.get(&key("read", cap)).is_some(),
            "new entry must be present"
        );
    }
}