Skip to main content

heliosdb_proxy/distribcache/ai/
tools.rs

//! Tool result cache
//!
//! Caches the results of deterministic tool calls to avoid redundant execution.
//! Useful for AI agents that may call the same tool with the same parameters multiple times.
5
6use dashmap::DashMap;
7use std::collections::HashSet;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::{Duration, Instant};
10
/// Cache key identifying one tool invocation: the tool's name plus a hash of its parameters.
///
/// Only the 64-bit hash of the parameters is stored, not the parameters themselves,
/// so two different parameter sets that collide on the hash map to the same key.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ToolCallKey {
    /// Name of the invoked tool.
    pub tool: String,
    /// Hash of the serialized parameters (see `ToolCallKey::new`).
    pub param_hash: u64,
}
19
20impl ToolCallKey {
21    /// Create a new tool call key
22    pub fn new(tool: &str, params: &serde_json::Value) -> Self {
23        use std::hash::{Hash, Hasher};
24        use std::collections::hash_map::DefaultHasher;
25
26        let mut hasher = DefaultHasher::new();
27        params.to_string().hash(&mut hasher);
28
29        Self {
30            tool: tool.to_string(),
31            param_hash: hasher.finish(),
32        }
33    }
34}
35
36/// Tool execution result
37#[derive(Debug, Clone)]
38pub struct ToolResult {
39    /// Result data
40    pub data: serde_json::Value,
41    /// Execution time
42    pub execution_time: Duration,
43    /// Timestamp
44    pub timestamp: Instant,
45    /// TTL
46    pub ttl: Duration,
47}
48
49impl ToolResult {
50    /// Create a new result
51    pub fn new(data: serde_json::Value, execution_time: Duration) -> Self {
52        Self {
53            data,
54            execution_time,
55            timestamp: Instant::now(),
56            ttl: Duration::from_secs(300), // Default 5 minutes
57        }
58    }
59
60    /// Set TTL
61    pub fn with_ttl(mut self, ttl: Duration) -> Self {
62        self.ttl = ttl;
63        self
64    }
65
66    /// Check if expired
67    pub fn is_expired(&self) -> bool {
68        self.timestamp.elapsed() > self.ttl
69    }
70
71    /// Approximate size
72    pub fn size(&self) -> usize {
73        self.data.to_string().len() + 32
74    }
75}
76
77/// Tool result cache
78pub struct ToolResultCache {
79    /// Cache storage
80    cache: DashMap<ToolCallKey, ToolResult>,
81
82    /// Deterministic tools (safe to cache)
83    deterministic_tools: HashSet<String>,
84
85    /// Custom TTLs per tool
86    tool_ttls: DashMap<String, Duration>,
87
88    /// Statistics
89    stats: ToolCacheStats,
90}
91
/// Atomic counters behind `ToolResultCache::stats`; all of them start at zero.
#[derive(Default, Debug)]
struct ToolCacheStats {
    /// Lookups answered from the cache.
    hits: AtomicU64,
    /// Lookups that found no entry, or only an expired one.
    misses: AtomicU64,
    /// Results stored in the cache.
    cached_executions: AtomicU64,
    /// Sum, in milliseconds, of the recorded execution times of all hits.
    time_saved_ms: AtomicU64,
}
100
101impl ToolResultCache {
102    /// Create a new tool cache
103    pub fn new() -> Self {
104        // Default deterministic tools
105        let mut deterministic = HashSet::new();
106        deterministic.insert("get_weather".to_string());
107        deterministic.insert("calculate".to_string());
108        deterministic.insert("lookup_definition".to_string());
109        deterministic.insert("search_knowledge_base".to_string());
110        deterministic.insert("get_stock_price".to_string());
111        deterministic.insert("convert_units".to_string());
112        deterministic.insert("translate".to_string());
113
114        Self {
115            cache: DashMap::new(),
116            deterministic_tools: deterministic,
117            tool_ttls: DashMap::new(),
118            stats: ToolCacheStats::default(),
119        }
120    }
121
122    /// Check if tool is deterministic (cacheable)
123    pub fn is_deterministic(&self, tool: &str) -> bool {
124        self.deterministic_tools.contains(tool)
125    }
126
127    /// Mark a tool as deterministic
128    pub fn mark_deterministic(&mut self, tool: impl Into<String>) {
129        self.deterministic_tools.insert(tool.into());
130    }
131
132    /// Mark a tool as non-deterministic
133    pub fn mark_non_deterministic(&mut self, tool: &str) {
134        self.deterministic_tools.remove(tool);
135    }
136
137    /// Set custom TTL for a tool
138    pub fn set_tool_ttl(&self, tool: impl Into<String>, ttl: Duration) {
139        self.tool_ttls.insert(tool.into(), ttl);
140    }
141
142    /// Get cached result
143    pub fn get(&self, key: &ToolCallKey) -> Option<ToolResult> {
144        // Check if tool is deterministic
145        if !self.is_deterministic(&key.tool) {
146            return None;
147        }
148
149        if let Some(result) = self.cache.get(key) {
150            if result.is_expired() {
151                drop(result);
152                self.cache.remove(key);
153                self.stats.misses.fetch_add(1, Ordering::Relaxed);
154                return None;
155            }
156
157            self.stats.hits.fetch_add(1, Ordering::Relaxed);
158            self.stats.time_saved_ms.fetch_add(
159                result.execution_time.as_millis() as u64,
160                Ordering::Relaxed,
161            );
162
163            Some(result.clone())
164        } else {
165            self.stats.misses.fetch_add(1, Ordering::Relaxed);
166            None
167        }
168    }
169
170    /// Cache a tool result
171    pub fn put(&self, key: ToolCallKey, result: ToolResult) {
172        // Only cache deterministic tools
173        if !self.is_deterministic(&key.tool) {
174            return;
175        }
176
177        // Apply custom TTL if configured
178        let result = if let Some(ttl) = self.tool_ttls.get(&key.tool) {
179            result.with_ttl(*ttl)
180        } else {
181            result
182        };
183
184        self.cache.insert(key, result);
185        self.stats.cached_executions.fetch_add(1, Ordering::Relaxed);
186    }
187
188    /// Execute with caching
189    pub async fn execute_with_cache<F, Fut>(
190        &self,
191        tool: &str,
192        params: &serde_json::Value,
193        executor: F,
194    ) -> ToolResult
195    where
196        F: FnOnce() -> Fut,
197        Fut: std::future::Future<Output = serde_json::Value>,
198    {
199        let key = ToolCallKey::new(tool, params);
200
201        // Check cache
202        if let Some(cached) = self.get(&key) {
203            return cached;
204        }
205
206        // Execute
207        let start = Instant::now();
208        let data = executor().await;
209        let execution_time = start.elapsed();
210
211        let result = ToolResult::new(data, execution_time);
212
213        // Cache result
214        self.put(key, result.clone());
215
216        result
217    }
218
219    /// Clear all cached results
220    pub fn clear(&self) {
221        self.cache.clear();
222    }
223
224    /// Clear cached results for a tool
225    pub fn clear_tool(&self, tool: &str) {
226        self.cache.retain(|k, _| k.tool != tool);
227    }
228
229    /// Remove expired entries
230    pub fn cleanup_expired(&self) {
231        self.cache.retain(|_, v| !v.is_expired());
232    }
233
234    /// Get statistics
235    pub fn stats(&self) -> ToolCacheStatsSnapshot {
236        let hits = self.stats.hits.load(Ordering::Relaxed);
237        let misses = self.stats.misses.load(Ordering::Relaxed);
238        let total = hits + misses;
239
240        ToolCacheStatsSnapshot {
241            cached_entries: self.cache.len(),
242            deterministic_tools: self.deterministic_tools.len(),
243            hits,
244            misses,
245            hit_rate: if total > 0 { hits as f64 / total as f64 } else { 0.0 },
246            cached_executions: self.stats.cached_executions.load(Ordering::Relaxed),
247            time_saved_ms: self.stats.time_saved_ms.load(Ordering::Relaxed),
248        }
249    }
250}
251
252impl Default for ToolResultCache {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
/// Plain-value copy of the cache statistics, as returned by `ToolResultCache::stats`.
#[derive(Clone, Debug)]
pub struct ToolCacheStatsSnapshot {
    /// Number of entries currently stored, including expired ones not yet evicted.
    pub cached_entries: usize,
    /// Number of tools currently marked as deterministic.
    pub deterministic_tools: usize,
    /// Lookups served from the cache.
    pub hits: u64,
    /// Lookups that found no entry, or only an expired one.
    pub misses: u64,
    /// `hits / (hits + misses)`, or `0.0` when no lookup has been counted.
    pub hit_rate: f64,
    /// Number of results stored in the cache.
    pub cached_executions: u64,
    /// Sum, in milliseconds, of the recorded execution times of all hits.
    pub time_saved_ms: u64,
}
269
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_call_key() {
        let key1 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 2}));
        let key2 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 2}));
        let key3 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 3}));

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_deterministic_check() {
        let cache = ToolResultCache::new();

        assert!(cache.is_deterministic("calculate"));
        assert!(cache.is_deterministic("get_weather"));
        assert!(!cache.is_deterministic("random_function"));
    }

    #[test]
    fn test_cache_put_get() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("calculate", &json!({"expr": "2+2"}));
        let result = ToolResult::new(json!(4), Duration::from_millis(10));

        cache.put(key.clone(), result);

        let cached = cache.get(&key);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().data, json!(4));
    }

    #[test]
    fn test_non_deterministic_not_cached() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("random_tool", &json!({}));
        let result = ToolResult::new(json!("result"), Duration::from_millis(10));

        cache.put(key.clone(), result);

        // Should not be cached
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_expired_entries() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("calculate", &json!({}));
        let result = ToolResult::new(json!(1), Duration::from_millis(1))
            .with_ttl(Duration::from_millis(1));

        cache.put(key.clone(), result);

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(10));

        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_custom_tool_ttl_overrides_result_ttl() {
        let cache = ToolResultCache::new();
        cache.set_tool_ttl("calculate", Duration::from_millis(1));

        let key = ToolCallKey::new("calculate", &json!({"expr": "1+1"}));
        // The result carries the default 5-minute TTL; the per-tool override must win.
        cache.put(key.clone(), ToolResult::new(json!(2), Duration::from_millis(1)));

        std::thread::sleep(Duration::from_millis(10));

        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_clear_tool() {
        let cache = ToolResultCache::new();

        let calc = ToolCallKey::new("calculate", &json!({}));
        let conv = ToolCallKey::new("convert_units", &json!({}));
        cache.put(calc.clone(), ToolResult::new(json!(1), Duration::from_millis(0)));
        cache.put(conv.clone(), ToolResult::new(json!(2), Duration::from_millis(0)));

        cache.clear_tool("calculate");

        assert!(cache.get(&calc).is_none());
        assert_eq!(cache.get(&conv).map(|r| r.data), Some(json!(2)));
    }

    #[test]
    fn test_stats() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("calculate", &json!({}));
        let result = ToolResult::new(json!(1), Duration::from_millis(50));

        cache.put(key.clone(), result);
        cache.get(&key); // Hit
        cache.get(&key); // Hit

        let key2 = ToolCallKey::new("calculate", &json!({"x": 1}));
        cache.get(&key2); // Miss

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!(stats.time_saved_ms >= 100);
    }

    #[tokio::test]
    async fn test_execute_with_cache() {
        let cache = ToolResultCache::new();

        let params = json!({"a": 5, "b": 3});
        let mut call_count = 0;

        // First call - executes. The closure body runs synchronously when
        // `execute_with_cache` invokes it, so the counter is observable.
        let result1 = cache
            .execute_with_cache("calculate", &params, || {
                call_count += 1;
                async { json!(8) }
            })
            .await;

        // Second call - served from the cache, executor must not run.
        let result2 = cache
            .execute_with_cache("calculate", &params, || {
                call_count += 1;
                async { json!(8) }
            })
            .await;

        assert_eq!(result1.data, json!(8));
        assert_eq!(result2.data, json!(8));
        assert_eq!(call_count, 1);
    }

    #[tokio::test]
    async fn test_execute_non_deterministic_always_runs() {
        let cache = ToolResultCache::new();

        let params = json!({});
        let mut call_count = 0;

        for _ in 0..2 {
            let result = cache
                .execute_with_cache("random_tool", &params, || {
                    call_count += 1;
                    async { json!(null) }
                })
                .await;
            assert_eq!(result.data, json!(null));
        }

        assert_eq!(call_count, 2);
        let stats = cache.stats();
        assert_eq!(stats.cached_entries, 0);
        assert_eq!(stats.misses, 0);
    }
}