use dashmap::DashMap;
use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ToolCallKey {
pub tool: String,
pub param_hash: u64,
}
impl ToolCallKey {
pub fn new(tool: &str, params: &serde_json::Value) -> Self {
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
let mut hasher = DefaultHasher::new();
params.to_string().hash(&mut hasher);
Self {
tool: tool.to_string(),
param_hash: hasher.finish(),
}
}
}
#[derive(Debug, Clone)]
pub struct ToolResult {
pub data: serde_json::Value,
pub execution_time: Duration,
pub timestamp: Instant,
pub ttl: Duration,
}
impl ToolResult {
pub fn new(data: serde_json::Value, execution_time: Duration) -> Self {
Self {
data,
execution_time,
timestamp: Instant::now(),
ttl: Duration::from_secs(300), }
}
pub fn with_ttl(mut self, ttl: Duration) -> Self {
self.ttl = ttl;
self
}
pub fn is_expired(&self) -> bool {
self.timestamp.elapsed() > self.ttl
}
pub fn size(&self) -> usize {
self.data.to_string().len() + 32
}
}
pub struct ToolResultCache {
cache: DashMap<ToolCallKey, ToolResult>,
deterministic_tools: HashSet<String>,
tool_ttls: DashMap<String, Duration>,
stats: ToolCacheStats,
}
#[derive(Debug, Default)]
struct ToolCacheStats {
hits: AtomicU64,
misses: AtomicU64,
cached_executions: AtomicU64,
time_saved_ms: AtomicU64,
}
impl ToolResultCache {
pub fn new() -> Self {
let mut deterministic = HashSet::new();
deterministic.insert("get_weather".to_string());
deterministic.insert("calculate".to_string());
deterministic.insert("lookup_definition".to_string());
deterministic.insert("search_knowledge_base".to_string());
deterministic.insert("get_stock_price".to_string());
deterministic.insert("convert_units".to_string());
deterministic.insert("translate".to_string());
Self {
cache: DashMap::new(),
deterministic_tools: deterministic,
tool_ttls: DashMap::new(),
stats: ToolCacheStats::default(),
}
}
pub fn is_deterministic(&self, tool: &str) -> bool {
self.deterministic_tools.contains(tool)
}
pub fn mark_deterministic(&mut self, tool: impl Into<String>) {
self.deterministic_tools.insert(tool.into());
}
pub fn mark_non_deterministic(&mut self, tool: &str) {
self.deterministic_tools.remove(tool);
}
pub fn set_tool_ttl(&self, tool: impl Into<String>, ttl: Duration) {
self.tool_ttls.insert(tool.into(), ttl);
}
pub fn get(&self, key: &ToolCallKey) -> Option<ToolResult> {
if !self.is_deterministic(&key.tool) {
return None;
}
if let Some(result) = self.cache.get(key) {
if result.is_expired() {
drop(result);
self.cache.remove(key);
self.stats.misses.fetch_add(1, Ordering::Relaxed);
return None;
}
self.stats.hits.fetch_add(1, Ordering::Relaxed);
self.stats.time_saved_ms.fetch_add(
result.execution_time.as_millis() as u64,
Ordering::Relaxed,
);
Some(result.clone())
} else {
self.stats.misses.fetch_add(1, Ordering::Relaxed);
None
}
}
pub fn put(&self, key: ToolCallKey, result: ToolResult) {
if !self.is_deterministic(&key.tool) {
return;
}
let result = if let Some(ttl) = self.tool_ttls.get(&key.tool) {
result.with_ttl(*ttl)
} else {
result
};
self.cache.insert(key, result);
self.stats.cached_executions.fetch_add(1, Ordering::Relaxed);
}
pub async fn execute_with_cache<F, Fut>(
&self,
tool: &str,
params: &serde_json::Value,
executor: F,
) -> ToolResult
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = serde_json::Value>,
{
let key = ToolCallKey::new(tool, params);
if let Some(cached) = self.get(&key) {
return cached;
}
let start = Instant::now();
let data = executor().await;
let execution_time = start.elapsed();
let result = ToolResult::new(data, execution_time);
self.put(key, result.clone());
result
}
pub fn clear(&self) {
self.cache.clear();
}
pub fn clear_tool(&self, tool: &str) {
self.cache.retain(|k, _| k.tool != tool);
}
pub fn cleanup_expired(&self) {
self.cache.retain(|_, v| !v.is_expired());
}
pub fn stats(&self) -> ToolCacheStatsSnapshot {
let hits = self.stats.hits.load(Ordering::Relaxed);
let misses = self.stats.misses.load(Ordering::Relaxed);
let total = hits + misses;
ToolCacheStatsSnapshot {
cached_entries: self.cache.len(),
deterministic_tools: self.deterministic_tools.len(),
hits,
misses,
hit_rate: if total > 0 { hits as f64 / total as f64 } else { 0.0 },
cached_executions: self.stats.cached_executions.load(Ordering::Relaxed),
time_saved_ms: self.stats.time_saved_ms.load(Ordering::Relaxed),
}
}
}
impl Default for ToolResultCache {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct ToolCacheStatsSnapshot {
pub cached_entries: usize,
pub deterministic_tools: usize,
pub hits: u64,
pub misses: u64,
pub hit_rate: f64,
pub cached_executions: u64,
pub time_saved_ms: u64,
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_tool_call_key() {
let key1 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 2}));
let key2 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 2}));
let key3 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 3}));
assert_eq!(key1, key2);
assert_ne!(key1, key3);
}
#[test]
fn test_deterministic_check() {
let cache = ToolResultCache::new();
assert!(cache.is_deterministic("calculate"));
assert!(cache.is_deterministic("get_weather"));
assert!(!cache.is_deterministic("random_function"));
}
#[test]
fn test_cache_put_get() {
let cache = ToolResultCache::new();
let key = ToolCallKey::new("calculate", &json!({"expr": "2+2"}));
let result = ToolResult::new(json!(4), Duration::from_millis(10));
cache.put(key.clone(), result);
let cached = cache.get(&key);
assert!(cached.is_some());
assert_eq!(cached.unwrap().data, json!(4));
}
#[test]
fn test_non_deterministic_not_cached() {
let cache = ToolResultCache::new();
let key = ToolCallKey::new("random_tool", &json!({}));
let result = ToolResult::new(json!("result"), Duration::from_millis(10));
cache.put(key.clone(), result);
assert!(cache.get(&key).is_none());
}
#[test]
fn test_expired_entries() {
let cache = ToolResultCache::new();
let key = ToolCallKey::new("calculate", &json!({}));
let result = ToolResult::new(json!(1), Duration::from_millis(1))
.with_ttl(Duration::from_millis(1));
cache.put(key.clone(), result);
std::thread::sleep(Duration::from_millis(10));
assert!(cache.get(&key).is_none());
}
#[test]
fn test_stats() {
let cache = ToolResultCache::new();
let key = ToolCallKey::new("calculate", &json!({}));
let result = ToolResult::new(json!(1), Duration::from_millis(50));
cache.put(key.clone(), result);
cache.get(&key); cache.get(&key);
let key2 = ToolCallKey::new("calculate", &json!({"x": 1}));
cache.get(&key2);
let stats = cache.stats();
assert_eq!(stats.hits, 2);
assert_eq!(stats.misses, 1);
assert!(stats.time_saved_ms >= 100);
}
#[tokio::test]
async fn test_execute_with_cache() {
let cache = ToolResultCache::new();
let params = json!({"a": 5, "b": 3});
let mut call_count = 0;
let result1 = cache.execute_with_cache("calculate", ¶ms, || {
call_count += 1;
async { json!(8) }
}).await;
let result2 = cache.execute_with_cache("calculate", ¶ms, || {
call_count += 1;
async { json!(8) }
}).await;
assert_eq!(result1.data, json!(8));
assert_eq!(result2.data, json!(8));
}
}