car_engine/cache.rs

//! Cross-proposal result cache for tool call results.
//!
//! Allows caching tool call results across proposals so that repeated calls
//! with identical parameters can skip re-execution within a configurable TTL.

6use serde_json::Value;
7use std::collections::hash_map::DefaultHasher;
8use std::collections::HashMap;
9use std::hash::{Hash, Hasher};
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::{Duration, Instant};
12use tokio::sync::Mutex;
13
14/// Cached tool call result with expiration.
15struct CacheEntry {
16    result: Value,
17    inserted_at: Instant,
18    ttl: Duration,
19}
20
21impl CacheEntry {
22    fn is_expired(&self) -> bool {
23        self.inserted_at.elapsed() > self.ttl
24    }
25}
26
/// Statistics for cache usage.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of `get` calls answered from a fresh cache entry.
    pub hits: u64,
    /// Number of `get` calls on cache-enabled tools that found no fresh entry.
    pub misses: u64,
    /// Number of entries currently stored, including entries that have
    /// expired but not yet been lazily evicted by a `get`.
    pub entries: usize,
}
34
/// Cross-proposal cache for tool call results.
pub struct ResultCache {
    /// Cached results keyed by `"<tool>:<param hash>"` (see `cache_key`).
    entries: Mutex<HashMap<String, CacheEntry>>,
    /// Per-tool TTL configuration. Tools not listed are not cached.
    tool_ttls: Mutex<HashMap<String, Duration>>,
    /// Count of lookups served from a fresh cache entry.
    hits: AtomicU64,
    /// Count of lookups on cache-enabled tools that found no fresh entry.
    misses: AtomicU64,
}
43
44impl ResultCache {
45    /// Create an empty cache with no tools enabled.
46    pub fn new() -> Self {
47        Self {
48            entries: Mutex::new(HashMap::new()),
49            tool_ttls: Mutex::new(HashMap::new()),
50            hits: AtomicU64::new(0),
51            misses: AtomicU64::new(0),
52        }
53    }
54
55    /// Mark a tool as cacheable with a given TTL in seconds.
56    pub async fn enable_caching(&self, tool: &str, ttl_secs: u64) {
57        let mut ttls = self.tool_ttls.lock().await;
58        ttls.insert(tool.to_string(), Duration::from_secs(ttl_secs));
59    }
60
61    /// Return a cached result if the tool is cacheable and the entry is fresh.
62    pub async fn get(&self, tool: &str, params: &Value) -> Option<Value> {
63        let ttls = self.tool_ttls.lock().await;
64        if !ttls.contains_key(tool) {
65            return None;
66        }
67        drop(ttls);
68
69        let key = cache_key(tool, params);
70        let mut entries = self.entries.lock().await;
71
72        if let Some(entry) = entries.get(&key) {
73            if entry.is_expired() {
74                entries.remove(&key);
75                self.misses.fetch_add(1, Ordering::Relaxed);
76                None
77            } else {
78                self.hits.fetch_add(1, Ordering::Relaxed);
79                Some(entry.result.clone())
80            }
81        } else {
82            self.misses.fetch_add(1, Ordering::Relaxed);
83            None
84        }
85    }
86
87    /// Store a result in the cache. Only stores if the tool has caching enabled.
88    pub async fn put(&self, tool: &str, params: &Value, result: Value) {
89        let ttls = self.tool_ttls.lock().await;
90        let ttl = match ttls.get(tool) {
91            Some(ttl) => *ttl,
92            None => return,
93        };
94        drop(ttls);
95
96        let key = cache_key(tool, params);
97        let mut entries = self.entries.lock().await;
98        entries.insert(
99            key,
100            CacheEntry {
101                result,
102                inserted_at: Instant::now(),
103                ttl,
104            },
105        );
106    }
107
108    /// Clear all cached entries for a specific tool.
109    pub async fn invalidate(&self, tool: &str) {
110        let prefix = format!("{}:", tool);
111        let mut entries = self.entries.lock().await;
112        entries.retain(|k, _| !k.starts_with(&prefix));
113    }
114
115    /// Clear all cached entries.
116    pub async fn invalidate_all(&self) {
117        let mut entries = self.entries.lock().await;
118        entries.clear();
119    }
120
121    /// Return hit/miss counts and current entry count.
122    pub async fn stats(&self) -> CacheStats {
123        let entries = self.entries.lock().await;
124        CacheStats {
125            hits: self.hits.load(Ordering::Relaxed),
126            misses: self.misses.load(Ordering::Relaxed),
127            entries: entries.len(),
128        }
129    }
130}
131
132impl Default for ResultCache {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
138/// Build a deterministic cache key from tool name and parameters.
139fn cache_key(tool: &str, params: &Value) -> String {
140    let serialized = serde_json::to_string(params).unwrap_or_default();
141    let mut hasher = DefaultHasher::new();
142    serialized.hash(&mut hasher);
143    let hash = hasher.finish();
144    format!("{}:{:x}", tool, hash)
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_cache_hit_returns_stored_result() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 60).await;

        let params = json!({"a": 1, "b": 2});
        let expected = json!(3);
        cache.put("add", &params, expected.clone()).await;

        // The stored value comes back, and counters reflect one hit.
        assert_eq!(cache.get("add", &params).await, Some(expected));

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.entries, 1);
    }

    #[tokio::test]
    async fn test_expired_entries_return_none() {
        let cache = ResultCache::new();
        // A 0-second TTL means every entry is already past its deadline.
        cache.enable_caching("add", 0).await;

        let params = json!({"a": 1, "b": 2});
        cache.put("add", &params, json!(3)).await;

        // The expired entry must not be served.
        assert_eq!(cache.get("add", &params).await, None);

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);
    }

    #[tokio::test]
    async fn test_different_params_produce_different_keys() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 60).await;

        let first = json!({"a": 1, "b": 2});
        let second = json!({"a": 3, "b": 4});
        cache.put("add", &first, json!(3)).await;
        cache.put("add", &second, json!(7)).await;

        // Distinct parameter sets occupy distinct slots.
        assert_eq!(cache.get("add", &first).await, Some(json!(3)));
        assert_eq!(cache.get("add", &second).await, Some(json!(7)));
        assert_eq!(cache.stats().await.entries, 2);
    }

    #[tokio::test]
    async fn test_invalidate_clears_tool_entries() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 60).await;
        cache.enable_caching("echo", 60).await;

        cache.put("add", &json!({"a": 1}), json!(1)).await;
        cache.put("echo", &json!({"msg": "hi"}), json!("hi")).await;

        cache.invalidate("add").await;

        // Only the "add" entry is gone; "echo" survives.
        assert_eq!(cache.get("add", &json!({"a": 1})).await, None);
        assert_eq!(
            cache.get("echo", &json!({"msg": "hi"})).await,
            Some(json!("hi"))
        );
    }

    #[tokio::test]
    async fn test_invalidate_all_clears_everything() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 60).await;
        cache.enable_caching("echo", 60).await;

        cache.put("add", &json!({"a": 1}), json!(1)).await;
        cache.put("echo", &json!({"msg": "hi"}), json!("hi")).await;

        cache.invalidate_all().await;

        assert_eq!(cache.stats().await.entries, 0);
    }

    #[tokio::test]
    async fn test_uncacheable_tool_returns_none() {
        let cache = ResultCache::new();
        // "add" was never enabled, so put is a no-op and get misses.
        cache.put("add", &json!({"a": 1}), json!(1)).await;
        assert_eq!(cache.get("add", &json!({"a": 1})).await, None);
    }
}