use serde_json::Value;
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
/// One cached result together with the timing data used to expire it.
struct CacheEntry {
/// The stored JSON value; a clone of it is handed back on a cache hit.
result: Value,
/// Moment the entry was created; expiry is measured from here.
inserted_at: Instant,
/// How long after `inserted_at` the entry remains valid.
ttl: Duration,
}
impl CacheEntry {
    /// Returns `true` once the entry has lived for at least its TTL.
    ///
    /// The comparison is inclusive so that a TTL of zero is *always* expired,
    /// even when the clock has not visibly advanced since insertion. With a
    /// strict `>` the outcome for a zero TTL would depend on timer resolution.
    fn is_expired(&self) -> bool {
        self.inserted_at.elapsed() >= self.ttl
    }
}
/// Point-in-time snapshot of the cache's counters.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared wholesale, and
/// `Default` so an all-zero snapshot is easy to construct.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of lookups that returned a live (non-expired) entry.
    pub hits: u64,
    /// Number of lookups that found no entry, or only an expired one.
    pub misses: u64,
    /// Number of entries held in the cache when the snapshot was taken.
    pub entries: usize,
}
/// Async in-memory cache of tool results with per-tool TTLs and hit/miss counters.
pub struct ResultCache {
/// Cached results, keyed by the string produced by `cache_key`.
entries: Mutex<HashMap<String, CacheEntry>>,
/// TTL for each tool that has caching enabled; tools absent from this map are never cached.
tool_ttls: Mutex<HashMap<String, Duration>>,
/// Count of `get` calls that returned a live entry.
hits: AtomicU64,
/// Count of `get` calls on cache-enabled tools that found no entry or an expired one.
misses: AtomicU64,
}
impl ResultCache {
    /// Creates an empty cache: no tools enabled, no entries, counters at zero.
    pub fn new() -> Self {
        Self {
            entries: Mutex::new(HashMap::new()),
            tool_ttls: Mutex::new(HashMap::new()),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        }
    }

    /// Enables caching for `tool`, keeping its results for `ttl_secs` seconds.
    ///
    /// Calling this again for the same tool replaces its TTL. Entries already
    /// stored keep the TTL they were inserted with.
    pub async fn enable_caching(&self, tool: &str, ttl_secs: u64) {
        self.tool_ttls
            .lock()
            .await
            .insert(tool.to_string(), Duration::from_secs(ttl_secs));
    }

    /// Looks up the cached result for `tool` called with `params`.
    ///
    /// Returns a clone of the stored value when a live entry exists (counted
    /// as a hit). A missing or expired entry yields `None` and is counted as a
    /// miss; an expired entry is also evicted. Tools without caching enabled
    /// return `None` without touching either counter.
    pub async fn get(&self, tool: &str, params: &Value) -> Option<Value> {
        // The TTL guard is dropped at the end of this statement, so the two
        // locks are never held at the same time.
        let enabled = self.tool_ttls.lock().await.contains_key(tool);
        if !enabled {
            return None;
        }

        let key = cache_key(tool, params);
        let mut entries = self.entries.lock().await;
        let fresh = entries
            .get(&key)
            .filter(|entry| !entry.is_expired())
            .map(|entry| entry.result.clone());

        if fresh.is_some() {
            self.hits.fetch_add(1, Ordering::Relaxed);
        } else {
            // Either nothing was stored or it has expired; removing an absent
            // key is a no-op, so one call covers both cases.
            entries.remove(&key);
            self.misses.fetch_add(1, Ordering::Relaxed);
        }
        fresh
    }

    /// Stores `result` for `tool` called with `params`.
    ///
    /// Does nothing when caching is not enabled for `tool`. An existing entry
    /// for the same key is overwritten and its age reset.
    pub async fn put(&self, tool: &str, params: &Value, result: Value) {
        // Copy the TTL out so the TTL lock is released before locking entries.
        let ttl = match self.tool_ttls.lock().await.get(tool).copied() {
            Some(ttl) => ttl,
            None => return,
        };

        let key = cache_key(tool, params);
        self.entries.lock().await.insert(
            key,
            CacheEntry {
                result,
                inserted_at: Instant::now(),
                ttl,
            },
        );
    }

    /// Removes every cached entry belonging to `tool`.
    ///
    /// Keys have the form `<tool>:<hex hash>`. The hash part never contains a
    /// `:`, so the tool name is everything before the *last* `:`. Comparing
    /// that segment exactly (rather than testing for a `"<tool>:"` prefix)
    /// keeps invalidating `fs` from also wiping a tool named `fs:read`.
    pub async fn invalidate(&self, tool: &str) {
        let mut entries = self.entries.lock().await;
        entries.retain(|key, _| {
            key.rsplit_once(':')
                .map_or(true, |(owner, _)| owner != tool)
        });
    }

    /// Removes every cached entry for every tool. Counters are left untouched.
    pub async fn invalidate_all(&self) {
        self.entries.lock().await.clear();
    }

    /// Returns a snapshot of the hit/miss counters and the current entry count.
    ///
    /// Expired entries that have not yet been evicted by a `get` are still
    /// included in `entries`.
    pub async fn stats(&self) -> CacheStats {
        let entries = self.entries.lock().await;
        CacheStats {
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            entries: entries.len(),
        }
    }
}
impl Default for ResultCache {
fn default() -> Self {
Self::new()
}
}
/// Builds the map key for one tool invocation: `<tool>:<hex hash of params>`.
///
/// The params are serialized to a JSON string and that string is hashed into a
/// 64-bit value. If serialization fails, the empty string is hashed instead.
// NOTE(review): the key carries only the 64-bit hash, so two different
// parameter sets whose hashes collide would share one cache entry.
fn cache_key(tool: &str, params: &Value) -> String {
    let mut state = DefaultHasher::new();
    serde_json::to_string(params)
        .unwrap_or_default()
        .hash(&mut state);
    format!("{}:{:x}", tool, state.finish())
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A stored result comes back on lookup and is counted as a hit.
    #[tokio::test]
    async fn test_cache_hit_returns_stored_result() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 60).await;
        let params = json!({"a": 1, "b": 2});
        let result = json!(3);
        cache.put("add", &params, result.clone()).await;
        let cached = cache.get("add", &params).await;
        assert_eq!(cached, Some(result));
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.entries, 1);
    }

    /// With a zero TTL the entry is already stale when it is read back.
    #[tokio::test]
    async fn test_expired_entries_return_none() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 0).await;
        let params = json!({"a": 1, "b": 2});
        cache.put("add", &params, json!(3)).await;
        let cached = cache.get("add", &params).await;
        assert_eq!(cached, None);
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);
    }

    /// Distinct parameter sets for the same tool are cached independently.
    #[tokio::test]
    async fn test_different_params_produce_different_keys() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 60).await;
        let params_a = json!({"a": 1, "b": 2});
        let params_b = json!({"a": 3, "b": 4});
        cache.put("add", &params_a, json!(3)).await;
        cache.put("add", &params_b, json!(7)).await;
        assert_eq!(cache.get("add", &params_a).await, Some(json!(3)));
        assert_eq!(cache.get("add", &params_b).await, Some(json!(7)));
        let stats = cache.stats().await;
        assert_eq!(stats.entries, 2);
    }

    /// Invalidating one tool leaves other tools' entries in place.
    #[tokio::test]
    async fn test_invalidate_clears_tool_entries() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 60).await;
        cache.enable_caching("echo", 60).await;
        cache.put("add", &json!({"a": 1}), json!(1)).await;
        cache.put("echo", &json!({"msg": "hi"}), json!("hi")).await;
        cache.invalidate("add").await;
        assert_eq!(cache.get("add", &json!({"a": 1})).await, None);
        assert_eq!(
            cache.get("echo", &json!({"msg": "hi"})).await,
            Some(json!("hi"))
        );
    }

    /// `invalidate_all` empties the cache regardless of tool.
    #[tokio::test]
    async fn test_invalidate_all_clears_everything() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 60).await;
        cache.enable_caching("echo", 60).await;
        cache.put("add", &json!({"a": 1}), json!(1)).await;
        cache.put("echo", &json!({"msg": "hi"}), json!("hi")).await;
        cache.invalidate_all().await;
        let stats = cache.stats().await;
        assert_eq!(stats.entries, 0);
    }

    /// A tool that never had caching enabled is neither stored nor returned.
    #[tokio::test]
    async fn test_uncacheable_tool_returns_none() {
        let cache = ResultCache::new();
        cache.put("add", &json!({"a": 1}), json!(1)).await;
        assert_eq!(cache.get("add", &json!({"a": 1})).await, None);
    }
}