heliosdb_proxy/distribcache/ai/
tools.rs1use dashmap::DashMap;
7use std::collections::HashSet;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::{Duration, Instant};
10
/// Cache key identifying one tool invocation.
///
/// Two keys compare equal when both the tool name and the parameter digest
/// match; the original parameters themselves are not stored.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ToolCallKey {
    /// Name of the tool that was invoked.
    pub tool: String,
    /// 64-bit digest of the call parameters (see `ToolCallKey::new`).
    pub param_hash: u64,
}
19
20impl ToolCallKey {
21 pub fn new(tool: &str, params: &serde_json::Value) -> Self {
23 use std::hash::{Hash, Hasher};
24 use std::collections::hash_map::DefaultHasher;
25
26 let mut hasher = DefaultHasher::new();
27 params.to_string().hash(&mut hasher);
28
29 Self {
30 tool: tool.to_string(),
31 param_hash: hasher.finish(),
32 }
33 }
34}
35
36#[derive(Debug, Clone)]
38pub struct ToolResult {
39 pub data: serde_json::Value,
41 pub execution_time: Duration,
43 pub timestamp: Instant,
45 pub ttl: Duration,
47}
48
49impl ToolResult {
50 pub fn new(data: serde_json::Value, execution_time: Duration) -> Self {
52 Self {
53 data,
54 execution_time,
55 timestamp: Instant::now(),
56 ttl: Duration::from_secs(300), }
58 }
59
60 pub fn with_ttl(mut self, ttl: Duration) -> Self {
62 self.ttl = ttl;
63 self
64 }
65
66 pub fn is_expired(&self) -> bool {
68 self.timestamp.elapsed() > self.ttl
69 }
70
71 pub fn size(&self) -> usize {
73 self.data.to_string().len() + 32
74 }
75}
76
77pub struct ToolResultCache {
79 cache: DashMap<ToolCallKey, ToolResult>,
81
82 deterministic_tools: HashSet<String>,
84
85 tool_ttls: DashMap<String, Duration>,
87
88 stats: ToolCacheStats,
90}
91
/// Internal atomic counters backing the cache statistics.
///
/// All counters start at zero via `Default`.
#[derive(Default, Debug)]
struct ToolCacheStats {
    /// Lookups answered from a live cached entry.
    hits: AtomicU64,
    /// Lookups for a deterministic tool that found no entry or an expired one.
    misses: AtomicU64,
    /// Number of results stored into the cache.
    cached_executions: AtomicU64,
    /// Sum, in milliseconds, of the execution times of results served from cache.
    time_saved_ms: AtomicU64,
}
100
101impl ToolResultCache {
102 pub fn new() -> Self {
104 let mut deterministic = HashSet::new();
106 deterministic.insert("get_weather".to_string());
107 deterministic.insert("calculate".to_string());
108 deterministic.insert("lookup_definition".to_string());
109 deterministic.insert("search_knowledge_base".to_string());
110 deterministic.insert("get_stock_price".to_string());
111 deterministic.insert("convert_units".to_string());
112 deterministic.insert("translate".to_string());
113
114 Self {
115 cache: DashMap::new(),
116 deterministic_tools: deterministic,
117 tool_ttls: DashMap::new(),
118 stats: ToolCacheStats::default(),
119 }
120 }
121
122 pub fn is_deterministic(&self, tool: &str) -> bool {
124 self.deterministic_tools.contains(tool)
125 }
126
127 pub fn mark_deterministic(&mut self, tool: impl Into<String>) {
129 self.deterministic_tools.insert(tool.into());
130 }
131
132 pub fn mark_non_deterministic(&mut self, tool: &str) {
134 self.deterministic_tools.remove(tool);
135 }
136
137 pub fn set_tool_ttl(&self, tool: impl Into<String>, ttl: Duration) {
139 self.tool_ttls.insert(tool.into(), ttl);
140 }
141
142 pub fn get(&self, key: &ToolCallKey) -> Option<ToolResult> {
144 if !self.is_deterministic(&key.tool) {
146 return None;
147 }
148
149 if let Some(result) = self.cache.get(key) {
150 if result.is_expired() {
151 drop(result);
152 self.cache.remove(key);
153 self.stats.misses.fetch_add(1, Ordering::Relaxed);
154 return None;
155 }
156
157 self.stats.hits.fetch_add(1, Ordering::Relaxed);
158 self.stats.time_saved_ms.fetch_add(
159 result.execution_time.as_millis() as u64,
160 Ordering::Relaxed,
161 );
162
163 Some(result.clone())
164 } else {
165 self.stats.misses.fetch_add(1, Ordering::Relaxed);
166 None
167 }
168 }
169
170 pub fn put(&self, key: ToolCallKey, result: ToolResult) {
172 if !self.is_deterministic(&key.tool) {
174 return;
175 }
176
177 let result = if let Some(ttl) = self.tool_ttls.get(&key.tool) {
179 result.with_ttl(*ttl)
180 } else {
181 result
182 };
183
184 self.cache.insert(key, result);
185 self.stats.cached_executions.fetch_add(1, Ordering::Relaxed);
186 }
187
188 pub async fn execute_with_cache<F, Fut>(
190 &self,
191 tool: &str,
192 params: &serde_json::Value,
193 executor: F,
194 ) -> ToolResult
195 where
196 F: FnOnce() -> Fut,
197 Fut: std::future::Future<Output = serde_json::Value>,
198 {
199 let key = ToolCallKey::new(tool, params);
200
201 if let Some(cached) = self.get(&key) {
203 return cached;
204 }
205
206 let start = Instant::now();
208 let data = executor().await;
209 let execution_time = start.elapsed();
210
211 let result = ToolResult::new(data, execution_time);
212
213 self.put(key, result.clone());
215
216 result
217 }
218
219 pub fn clear(&self) {
221 self.cache.clear();
222 }
223
224 pub fn clear_tool(&self, tool: &str) {
226 self.cache.retain(|k, _| k.tool != tool);
227 }
228
229 pub fn cleanup_expired(&self) {
231 self.cache.retain(|_, v| !v.is_expired());
232 }
233
234 pub fn stats(&self) -> ToolCacheStatsSnapshot {
236 let hits = self.stats.hits.load(Ordering::Relaxed);
237 let misses = self.stats.misses.load(Ordering::Relaxed);
238 let total = hits + misses;
239
240 ToolCacheStatsSnapshot {
241 cached_entries: self.cache.len(),
242 deterministic_tools: self.deterministic_tools.len(),
243 hits,
244 misses,
245 hit_rate: if total > 0 { hits as f64 / total as f64 } else { 0.0 },
246 cached_executions: self.stats.cached_executions.load(Ordering::Relaxed),
247 time_saved_ms: self.stats.time_saved_ms.load(Ordering::Relaxed),
248 }
249 }
250}
251
252impl Default for ToolResultCache {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
/// Point-in-time copy of the cache statistics, as returned by
/// `ToolResultCache::stats`.
#[derive(Clone, Debug)]
pub struct ToolCacheStatsSnapshot {
    /// Number of entries held in the cache when the snapshot was taken.
    pub cached_entries: usize,
    /// Number of tools registered as deterministic.
    pub deterministic_tools: usize,
    /// Lookups served from a live cached entry.
    pub hits: u64,
    /// Counted lookups that found no entry or an expired one.
    pub misses: u64,
    /// `hits / (hits + misses)`, or `0.0` when both are zero.
    pub hit_rate: f64,
    /// Number of results stored into the cache.
    pub cached_executions: u64,
    /// Total execution time, in milliseconds, of results served from cache.
    pub time_saved_ms: u64,
}
269
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Identical tool + params produce equal keys; differing params do not.
    #[test]
    fn test_tool_call_key() {
        let key1 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 2}));
        let key2 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 2}));
        let key3 = ToolCallKey::new("calculate", &json!({"a": 1, "b": 3}));

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    /// Built-in tools are deterministic; unknown tools are not.
    #[test]
    fn test_deterministic_check() {
        let cache = ToolResultCache::new();

        assert!(cache.is_deterministic("calculate"));
        assert!(cache.is_deterministic("get_weather"));
        assert!(!cache.is_deterministic("random_function"));
    }

    /// A stored result for a deterministic tool can be read back.
    #[test]
    fn test_cache_put_get() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("calculate", &json!({"expr": "2+2"}));
        let result = ToolResult::new(json!(4), Duration::from_millis(10));

        cache.put(key.clone(), result);

        let cached = cache.get(&key);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().data, json!(4));
    }

    /// Results for tools outside the deterministic set are never stored.
    #[test]
    fn test_non_deterministic_not_cached() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("random_tool", &json!({}));
        let result = ToolResult::new(json!("result"), Duration::from_millis(10));

        cache.put(key.clone(), result);

        assert!(cache.get(&key).is_none());
    }

    /// An entry whose TTL has elapsed is not returned.
    #[test]
    fn test_expired_entries() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("calculate", &json!({}));
        let result = ToolResult::new(json!(1), Duration::from_millis(1))
            .with_ttl(Duration::from_millis(1));

        cache.put(key.clone(), result);

        // Wait comfortably past the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(10));

        assert!(cache.get(&key).is_none());
    }

    /// Two hits on a 50 ms result and one miss are reflected in the counters.
    #[test]
    fn test_stats() {
        let cache = ToolResultCache::new();

        let key = ToolCallKey::new("calculate", &json!({}));
        let result = ToolResult::new(json!(1), Duration::from_millis(50));

        cache.put(key.clone(), result);
        cache.get(&key);
        cache.get(&key);

        let key2 = ToolCallKey::new("calculate", &json!({"x": 1}));
        cache.get(&key2);

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!(stats.time_saved_ms >= 100);
    }

    /// The executor runs on the first call only; the second is served from cache.
    #[tokio::test]
    async fn test_execute_with_cache() {
        let cache = ToolResultCache::new();

        let params = json!({"a": 5, "b": 3});
        let mut call_count = 0;

        let result1 = cache
            .execute_with_cache("calculate", &params, || {
                call_count += 1;
                async { json!(8) }
            })
            .await;

        let result2 = cache
            .execute_with_cache("calculate", &params, || {
                call_count += 1;
                async { json!(8) }
            })
            .await;

        assert_eq!(result1.data, json!(8));
        assert_eq!(result2.data, json!(8));
        // Without this the test would pass even if caching were broken.
        assert_eq!(call_count, 1);
    }
}