Skip to main content

heliosdb_proxy/distribcache/ai/
tools.rs

1//! Tool result cache
2//!
3//! Caches results of deterministic tool calls to avoid redundant execution.
4//! Useful for AI agents that may call the same tool with same parameters multiple times.
5
6use dashmap::DashMap;
7use std::collections::HashSet;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::{Duration, Instant};
10
11/// Tool call key (tool name + parameters hash)
12#[derive(Debug, Clone, Hash, PartialEq, Eq)]
13pub struct ToolCallKey {
14    /// Tool name
15    pub tool: String,
16    /// Parameter hash
17    pub param_hash: u64,
18}
19
20impl ToolCallKey {
21    /// Create a new tool call key
22    pub fn new(tool: &str, params: &serde_json::Value) -> Self {
23        use std::collections::hash_map::DefaultHasher;
24        use std::hash::{Hash, Hasher};
25
26        let mut hasher = DefaultHasher::new();
27        params.to_string().hash(&mut hasher);
28
29        Self {
30            tool: tool.to_string(),
31            param_hash: hasher.finish(),
32        }
33    }
34}
35
36/// Tool execution result
37#[derive(Debug, Clone)]
38pub struct ToolResult {
39    /// Result data
40    pub data: serde_json::Value,
41    /// Execution time
42    pub execution_time: Duration,
43    /// Timestamp
44    pub timestamp: Instant,
45    /// TTL
46    pub ttl: Duration,
47}
48
49impl ToolResult {
50    /// Create a new result
51    pub fn new(data: serde_json::Value, execution_time: Duration) -> Self {
52        Self {
53            data,
54            execution_time,
55            timestamp: Instant::now(),
56            ttl: Duration::from_secs(300), // Default 5 minutes
57        }
58    }
59
60    /// Set TTL
61    pub fn with_ttl(mut self, ttl: Duration) -> Self {
62        self.ttl = ttl;
63        self
64    }
65
66    /// Check if expired
67    pub fn is_expired(&self) -> bool {
68        self.timestamp.elapsed() > self.ttl
69    }
70
71    /// Approximate size
72    pub fn size(&self) -> usize {
73        self.data.to_string().len() + 32
74    }
75}
76
77/// Tool result cache
78pub struct ToolResultCache {
79    /// Cache storage
80    cache: DashMap<ToolCallKey, ToolResult>,
81
82    /// Deterministic tools (safe to cache)
83    deterministic_tools: HashSet<String>,
84
85    /// Custom TTLs per tool
86    tool_ttls: DashMap<String, Duration>,
87
88    /// Statistics
89    stats: ToolCacheStats,
90}
91
92/// Tool cache statistics
93#[derive(Debug, Default)]
94struct ToolCacheStats {
95    hits: AtomicU64,
96    misses: AtomicU64,
97    cached_executions: AtomicU64,
98    time_saved_ms: AtomicU64,
99}
100
101impl ToolResultCache {
102    /// Create a new tool cache
103    pub fn new() -> Self {
104        // Default deterministic tools
105        let mut deterministic = HashSet::new();
106        deterministic.insert("get_weather".to_string());
107        deterministic.insert("calculate".to_string());
108        deterministic.insert("lookup_definition".to_string());
109        deterministic.insert("search_knowledge_base".to_string());
110        deterministic.insert("get_stock_price".to_string());
111        deterministic.insert("convert_units".to_string());
112        deterministic.insert("translate".to_string());
113
114        Self {
115            cache: DashMap::new(),
116            deterministic_tools: deterministic,
117            tool_ttls: DashMap::new(),
118            stats: ToolCacheStats::default(),
119        }
120    }
121
122    /// Check if tool is deterministic (cacheable)
123    pub fn is_deterministic(&self, tool: &str) -> bool {
124        self.deterministic_tools.contains(tool)
125    }
126
127    /// Mark a tool as deterministic
128    pub fn mark_deterministic(&mut self, tool: impl Into<String>) {
129        self.deterministic_tools.insert(tool.into());
130    }
131
132    /// Mark a tool as non-deterministic
133    pub fn mark_non_deterministic(&mut self, tool: &str) {
134        self.deterministic_tools.remove(tool);
135    }
136
137    /// Set custom TTL for a tool
138    pub fn set_tool_ttl(&self, tool: impl Into<String>, ttl: Duration) {
139        self.tool_ttls.insert(tool.into(), ttl);
140    }
141
142    /// Get cached result
143    pub fn get(&self, key: &ToolCallKey) -> Option<ToolResult> {
144        // Check if tool is deterministic
145        if !self.is_deterministic(&key.tool) {
146            return None;
147        }
148
149        if let Some(result) = self.cache.get(key) {
150            if result.is_expired() {
151                drop(result);
152                self.cache.remove(key);
153                self.stats.misses.fetch_add(1, Ordering::Relaxed);
154                return None;
155            }
156
157            self.stats.hits.fetch_add(1, Ordering::Relaxed);
158            self.stats
159                .time_saved_ms
160                .fetch_add(result.execution_time.as_millis() as u64, Ordering::Relaxed);
161
162            Some(result.clone())
163        } else {
164            self.stats.misses.fetch_add(1, Ordering::Relaxed);
165            None
166        }
167    }
168
169    /// Cache a tool result
170    pub fn put(&self, key: ToolCallKey, result: ToolResult) {
171        // Only cache deterministic tools
172        if !self.is_deterministic(&key.tool) {
173            return;
174        }
175
176        // Apply custom TTL if configured
177        let result = if let Some(ttl) = self.tool_ttls.get(&key.tool) {
178            result.with_ttl(*ttl)
179        } else {
180            result
181        };
182
183        self.cache.insert(key, result);
184        self.stats.cached_executions.fetch_add(1, Ordering::Relaxed);
185    }
186
187    /// Execute with caching
188    pub async fn execute_with_cache<F, Fut>(
189        &self,
190        tool: &str,
191        params: &serde_json::Value,
192        executor: F,
193    ) -> ToolResult
194    where
195        F: FnOnce() -> Fut,
196        Fut: std::future::Future<Output = serde_json::Value>,
197    {
198        let key = ToolCallKey::new(tool, params);
199
200        // Check cache
201        if let Some(cached) = self.get(&key) {
202            return cached;
203        }
204
205        // Execute
206        let start = Instant::now();
207        let data = executor().await;
208        let execution_time = start.elapsed();
209
210        let result = ToolResult::new(data, execution_time);
211
212        // Cache result
213        self.put(key, result.clone());
214
215        result
216    }
217
218    /// Clear all cached results
219    pub fn clear(&self) {
220        self.cache.clear();
221    }
222
223    /// Clear cached results for a tool
224    pub fn clear_tool(&self, tool: &str) {
225        self.cache.retain(|k, _| k.tool != tool);
226    }
227
228    /// Remove expired entries
229    pub fn cleanup_expired(&self) {
230        self.cache.retain(|_, v| !v.is_expired());
231    }
232
233    /// Get statistics
234    pub fn stats(&self) -> ToolCacheStatsSnapshot {
235        let hits = self.stats.hits.load(Ordering::Relaxed);
236        let misses = self.stats.misses.load(Ordering::Relaxed);
237        let total = hits + misses;
238
239        ToolCacheStatsSnapshot {
240            cached_entries: self.cache.len(),
241            deterministic_tools: self.deterministic_tools.len(),
242            hits,
243            misses,
244            hit_rate: if total > 0 {
245                hits as f64 / total as f64
246            } else {
247                0.0
248            },
249            cached_executions: self.stats.cached_executions.load(Ordering::Relaxed),
250            time_saved_ms: self.stats.time_saved_ms.load(Ordering::Relaxed),
251        }
252    }
253}
254
255impl Default for ToolResultCache {
256    fn default() -> Self {
257        Self::new()
258    }
259}
260
/// Point-in-time copy of the tool cache counters, as returned by
/// `ToolResultCache::stats`.
#[derive(Debug, Clone)]
pub struct ToolCacheStatsSnapshot {
    /// Entries currently held (including any not yet evicted after expiry)
    pub cached_entries: usize,
    /// Number of tools registered as deterministic
    pub deterministic_tools: usize,
    /// Lookups served from cache
    pub hits: u64,
    /// Lookups for deterministic tools that found no live entry
    pub misses: u64,
    /// `hits / (hits + misses)`, or `0.0` when no lookups happened
    pub hit_rate: f64,
    /// Results accepted into the cache
    pub cached_executions: u64,
    /// Accumulated execution time avoided by hits, in milliseconds
    pub time_saved_ms: u64,
}
272
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_call_key() {
        let key1 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 2}));
        let key2 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 2}));
        let key3 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 3}));

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_deterministic_check() {
        let cache = ToolResultCache::new();

        assert!(cache.is_deterministic("calculate"));
        assert!(cache.is_deterministic("get_weather"));
        assert!(!cache.is_deterministic("random_function"));
    }

    #[test]
    fn test_cache_put_get() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("calculate", &json!({"expr": "2+2"}));
        let result = ToolResult::new(json!(4), Duration::from_millis(10));

        cache.put(key.clone(), result);

        let cached = cache.get(&key);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().data, json!(4));
    }

    #[test]
    fn test_non_deterministic_not_cached() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("random_tool", &json!({}));
        let result = ToolResult::new(json!("result"), Duration::from_millis(10));

        cache.put(key.clone(), result);

        // Should not be cached
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_expired_entries() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("calculate", &json!({}));
        let result =
            ToolResult::new(json!(1), Duration::from_millis(1)).with_ttl(Duration::from_millis(1));

        cache.put(key.clone(), result);

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(10));

        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_stats() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("calculate", &json!({}));
        let result = ToolResult::new(json!(1), Duration::from_millis(50));

        cache.put(key.clone(), result);
        cache.get(&key); // Hit
        cache.get(&key); // Hit

        let key2 = ToolCallKey::new("calculate", &json!({"x": 1}));
        cache.get(&key2); // Miss

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!(stats.time_saved_ms >= 100);
    }

    #[tokio::test]
    async fn test_execute_with_cache() {
        let cache = ToolResultCache::new();

        let params = json!({"a": 5, "b": 3});
        let mut call_count = 0;

        // First call - executes
        let result1 = cache
            .execute_with_cache("calculate", &params, || {
                call_count += 1;
                async { json!(8) }
            })
            .await;

        // Second call - cached
        let result2 = cache
            .execute_with_cache("calculate", &params, || {
                call_count += 1;
                async { json!(8) }
            })
            .await;

        assert_eq!(result1.data, json!(8));
        assert_eq!(result2.data, json!(8));
        // The executor is a plain `FnOnce` invoked synchronously before its
        // future is awaited, so the counter is updated normally. It must have
        // run only for the first (uncached) call.
        assert_eq!(call_count, 1);
    }

    #[tokio::test]
    async fn test_execute_with_cache_non_deterministic_always_runs() {
        let cache = ToolResultCache::new();

        let params = json!({});
        let mut call_count = 0;

        for _ in 0..2 {
            let result = cache
                .execute_with_cache("random_tool", &params, || {
                    call_count += 1;
                    async { json!(null) }
                })
                .await;
            assert_eq!(result.data, json!(null));
        }

        // Non-deterministic tools are never served from cache.
        assert_eq!(call_count, 2);
    }
}