heliosdb_proxy/distribcache/ai/
tools.rs1use dashmap::DashMap;
7use std::collections::HashSet;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::{Duration, Instant};
10
/// Cache key identifying a single tool invocation.
///
/// The parameters themselves are not stored; only a 64-bit hash of them is
/// kept, so two keys compare equal when both the tool name and the hash match.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ToolCallKey {
    /// Name of the tool that was invoked.
    pub tool: String,

    /// Hash of the invocation parameters.
    pub param_hash: u64,
}
19
20impl ToolCallKey {
21 pub fn new(tool: &str, params: &serde_json::Value) -> Self {
23 use std::collections::hash_map::DefaultHasher;
24 use std::hash::{Hash, Hasher};
25
26 let mut hasher = DefaultHasher::new();
27 params.to_string().hash(&mut hasher);
28
29 Self {
30 tool: tool.to_string(),
31 param_hash: hasher.finish(),
32 }
33 }
34}
35
36#[derive(Debug, Clone)]
38pub struct ToolResult {
39 pub data: serde_json::Value,
41 pub execution_time: Duration,
43 pub timestamp: Instant,
45 pub ttl: Duration,
47}
48
49impl ToolResult {
50 pub fn new(data: serde_json::Value, execution_time: Duration) -> Self {
52 Self {
53 data,
54 execution_time,
55 timestamp: Instant::now(),
56 ttl: Duration::from_secs(300), }
58 }
59
60 pub fn with_ttl(mut self, ttl: Duration) -> Self {
62 self.ttl = ttl;
63 self
64 }
65
66 pub fn is_expired(&self) -> bool {
68 self.timestamp.elapsed() > self.ttl
69 }
70
71 pub fn size(&self) -> usize {
73 self.data.to_string().len() + 32
74 }
75}
76
77pub struct ToolResultCache {
79 cache: DashMap<ToolCallKey, ToolResult>,
81
82 deterministic_tools: HashSet<String>,
84
85 tool_ttls: DashMap<String, Duration>,
87
88 stats: ToolCacheStats,
90}
91
/// Atomic counters backing the cache statistics; every counter starts at zero.
#[derive(Default, Debug)]
struct ToolCacheStats {
    /// Number of lookups that returned a cached result.
    hits: AtomicU64,
    /// Number of lookups that found nothing usable.
    misses: AtomicU64,
    /// Number of results stored in the cache.
    cached_executions: AtomicU64,
    /// Accumulated execution time, in milliseconds, of results served from cache.
    time_saved_ms: AtomicU64,
}
100
101impl ToolResultCache {
102 pub fn new() -> Self {
104 let mut deterministic = HashSet::new();
106 deterministic.insert("get_weather".to_string());
107 deterministic.insert("calculate".to_string());
108 deterministic.insert("lookup_definition".to_string());
109 deterministic.insert("search_knowledge_base".to_string());
110 deterministic.insert("get_stock_price".to_string());
111 deterministic.insert("convert_units".to_string());
112 deterministic.insert("translate".to_string());
113
114 Self {
115 cache: DashMap::new(),
116 deterministic_tools: deterministic,
117 tool_ttls: DashMap::new(),
118 stats: ToolCacheStats::default(),
119 }
120 }
121
122 pub fn is_deterministic(&self, tool: &str) -> bool {
124 self.deterministic_tools.contains(tool)
125 }
126
127 pub fn mark_deterministic(&mut self, tool: impl Into<String>) {
129 self.deterministic_tools.insert(tool.into());
130 }
131
132 pub fn mark_non_deterministic(&mut self, tool: &str) {
134 self.deterministic_tools.remove(tool);
135 }
136
137 pub fn set_tool_ttl(&self, tool: impl Into<String>, ttl: Duration) {
139 self.tool_ttls.insert(tool.into(), ttl);
140 }
141
142 pub fn get(&self, key: &ToolCallKey) -> Option<ToolResult> {
144 if !self.is_deterministic(&key.tool) {
146 return None;
147 }
148
149 if let Some(result) = self.cache.get(key) {
150 if result.is_expired() {
151 drop(result);
152 self.cache.remove(key);
153 self.stats.misses.fetch_add(1, Ordering::Relaxed);
154 return None;
155 }
156
157 self.stats.hits.fetch_add(1, Ordering::Relaxed);
158 self.stats
159 .time_saved_ms
160 .fetch_add(result.execution_time.as_millis() as u64, Ordering::Relaxed);
161
162 Some(result.clone())
163 } else {
164 self.stats.misses.fetch_add(1, Ordering::Relaxed);
165 None
166 }
167 }
168
169 pub fn put(&self, key: ToolCallKey, result: ToolResult) {
171 if !self.is_deterministic(&key.tool) {
173 return;
174 }
175
176 let result = if let Some(ttl) = self.tool_ttls.get(&key.tool) {
178 result.with_ttl(*ttl)
179 } else {
180 result
181 };
182
183 self.cache.insert(key, result);
184 self.stats.cached_executions.fetch_add(1, Ordering::Relaxed);
185 }
186
187 pub async fn execute_with_cache<F, Fut>(
189 &self,
190 tool: &str,
191 params: &serde_json::Value,
192 executor: F,
193 ) -> ToolResult
194 where
195 F: FnOnce() -> Fut,
196 Fut: std::future::Future<Output = serde_json::Value>,
197 {
198 let key = ToolCallKey::new(tool, params);
199
200 if let Some(cached) = self.get(&key) {
202 return cached;
203 }
204
205 let start = Instant::now();
207 let data = executor().await;
208 let execution_time = start.elapsed();
209
210 let result = ToolResult::new(data, execution_time);
211
212 self.put(key, result.clone());
214
215 result
216 }
217
218 pub fn clear(&self) {
220 self.cache.clear();
221 }
222
223 pub fn clear_tool(&self, tool: &str) {
225 self.cache.retain(|k, _| k.tool != tool);
226 }
227
228 pub fn cleanup_expired(&self) {
230 self.cache.retain(|_, v| !v.is_expired());
231 }
232
233 pub fn stats(&self) -> ToolCacheStatsSnapshot {
235 let hits = self.stats.hits.load(Ordering::Relaxed);
236 let misses = self.stats.misses.load(Ordering::Relaxed);
237 let total = hits + misses;
238
239 ToolCacheStatsSnapshot {
240 cached_entries: self.cache.len(),
241 deterministic_tools: self.deterministic_tools.len(),
242 hits,
243 misses,
244 hit_rate: if total > 0 {
245 hits as f64 / total as f64
246 } else {
247 0.0
248 },
249 cached_executions: self.stats.cached_executions.load(Ordering::Relaxed),
250 time_saved_ms: self.stats.time_saved_ms.load(Ordering::Relaxed),
251 }
252 }
253}
254
255impl Default for ToolResultCache {
256 fn default() -> Self {
257 Self::new()
258 }
259}
260
/// Point-in-time copy of the cache statistics.
#[derive(Clone, Debug)]
pub struct ToolCacheStatsSnapshot {
    /// Number of entries held in the cache when the snapshot was taken.
    pub cached_entries: usize,
    /// Number of tools registered as deterministic.
    pub deterministic_tools: usize,
    /// Lookups that returned a cached result.
    pub hits: u64,
    /// Lookups that found nothing usable.
    pub misses: u64,
    /// `hits / (hits + misses)`, or `0.0` when no lookup has been counted.
    pub hit_rate: f64,
    /// Number of results stored in the cache.
    pub cached_executions: u64,
    /// Accumulated execution time, in milliseconds, of results served from cache.
    pub time_saved_ms: u64,
}
272
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Identical tool/parameter pairs yield equal keys; differing parameters do not.
    #[test]
    fn test_tool_call_key() {
        let key1 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 2}));
        let key2 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 2}));
        let key3 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 3}));

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    /// Built-in tools are registered as deterministic; unknown tools are not.
    #[test]
    fn test_deterministic_check() {
        let cache = ToolResultCache::new();

        assert!(cache.is_deterministic("calculate"));
        assert!(cache.is_deterministic("get_weather"));
        assert!(!cache.is_deterministic("random_function"));
    }

    /// A stored result for a deterministic tool can be read back.
    #[test]
    fn test_cache_put_get() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("calculate", &json!({"expr": "2+2"}));
        let result = ToolResult::new(json!(4), Duration::from_millis(10));

        cache.put(key.clone(), result);

        let cached = cache.get(&key);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().data, json!(4));
    }

    /// Results of tools that are not registered as deterministic are never stored.
    #[test]
    fn test_non_deterministic_not_cached() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("random_tool", &json!({}));
        let result = ToolResult::new(json!("result"), Duration::from_millis(10));

        cache.put(key.clone(), result);

        assert!(cache.get(&key).is_none());
    }

    /// An entry whose TTL has elapsed is not returned.
    #[test]
    fn test_expired_entries() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("calculate", &json!({}));
        let result =
            ToolResult::new(json!(1), Duration::from_millis(1)).with_ttl(Duration::from_millis(1));

        cache.put(key.clone(), result);

        // Wait comfortably past the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(10));

        assert!(cache.get(&key).is_none());
    }

    /// Two hits and one miss are counted, and each hit adds the stored
    /// execution time (50 ms) to the time-saved counter.
    #[test]
    fn test_stats() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("calculate", &json!({}));
        let result = ToolResult::new(json!(1), Duration::from_millis(50));

        cache.put(key.clone(), result);
        cache.get(&key);
        cache.get(&key);

        let key2 = ToolCallKey::new("calculate", &json!({"x": 1}));
        cache.get(&key2);

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!(stats.time_saved_ms >= 100);
    }

    /// The executor runs on the first call only; the second call with the
    /// same tool and parameters is served from the cache.
    #[tokio::test]
    async fn test_execute_with_cache() {
        let cache = ToolResultCache::new();

        let params = json!({"a": 5, "b": 3});
        let mut call_count = 0;

        let result1 = cache
            .execute_with_cache("calculate", &params, || {
                call_count += 1;
                async { json!(8) }
            })
            .await;

        let result2 = cache
            .execute_with_cache("calculate", &params, || {
                call_count += 1;
                async { json!(8) }
            })
            .await;

        assert_eq!(result1.data, json!(8));
        assert_eq!(result2.data, json!(8));
        // Without this check the test would pass even if nothing were cached.
        assert_eq!(call_count, 1);
    }
}