1use serde_json::Value;
7use std::collections::hash_map::DefaultHasher;
8use std::collections::HashMap;
9use std::hash::{Hash, Hasher};
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::{Duration, Instant};
12use tokio::sync::Mutex;
13
14struct CacheEntry {
16 result: Value,
17 inserted_at: Instant,
18 ttl: Duration,
19}
20
21impl CacheEntry {
22 fn is_expired(&self) -> bool {
23 self.inserted_at.elapsed() > self.ttl
24 }
25}
26
/// Point-in-time snapshot of cache counters.
///
/// Besides `Debug` and `Clone`, the type is comparable (`PartialEq`/`Eq`) and
/// has an all-zero `Default`, so callers and tests can compare whole
/// snapshots in one assertion.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CacheStats {
    /// Number of lookups that returned a stored, unexpired result.
    pub hits: u64,
    /// Number of lookups that found nothing usable.
    pub misses: u64,
    /// Number of entries held in the cache when the snapshot was taken.
    pub entries: usize,
}
34
35pub struct ResultCache {
37 entries: Mutex<HashMap<String, CacheEntry>>,
38 tool_ttls: Mutex<HashMap<String, Duration>>,
40 hits: AtomicU64,
41 misses: AtomicU64,
42}
43
44impl ResultCache {
45 pub fn new() -> Self {
47 Self {
48 entries: Mutex::new(HashMap::new()),
49 tool_ttls: Mutex::new(HashMap::new()),
50 hits: AtomicU64::new(0),
51 misses: AtomicU64::new(0),
52 }
53 }
54
55 pub async fn enable_caching(&self, tool: &str, ttl_secs: u64) {
57 let mut ttls = self.tool_ttls.lock().await;
58 ttls.insert(tool.to_string(), Duration::from_secs(ttl_secs));
59 }
60
61 pub async fn get(&self, tool: &str, params: &Value) -> Option<Value> {
63 let ttls = self.tool_ttls.lock().await;
64 if !ttls.contains_key(tool) {
65 return None;
66 }
67 drop(ttls);
68
69 let key = cache_key(tool, params);
70 let mut entries = self.entries.lock().await;
71
72 if let Some(entry) = entries.get(&key) {
73 if entry.is_expired() {
74 entries.remove(&key);
75 self.misses.fetch_add(1, Ordering::Relaxed);
76 None
77 } else {
78 self.hits.fetch_add(1, Ordering::Relaxed);
79 Some(entry.result.clone())
80 }
81 } else {
82 self.misses.fetch_add(1, Ordering::Relaxed);
83 None
84 }
85 }
86
87 pub async fn put(&self, tool: &str, params: &Value, result: Value) {
89 let ttls = self.tool_ttls.lock().await;
90 let ttl = match ttls.get(tool) {
91 Some(ttl) => *ttl,
92 None => return,
93 };
94 drop(ttls);
95
96 let key = cache_key(tool, params);
97 let mut entries = self.entries.lock().await;
98 entries.insert(
99 key,
100 CacheEntry {
101 result,
102 inserted_at: Instant::now(),
103 ttl,
104 },
105 );
106 }
107
108 pub async fn invalidate(&self, tool: &str) {
110 let prefix = format!("{}:", tool);
111 let mut entries = self.entries.lock().await;
112 entries.retain(|k, _| !k.starts_with(&prefix));
113 }
114
115 pub async fn invalidate_all(&self) {
117 let mut entries = self.entries.lock().await;
118 entries.clear();
119 }
120
121 pub async fn stats(&self) -> CacheStats {
123 let entries = self.entries.lock().await;
124 CacheStats {
125 hits: self.hits.load(Ordering::Relaxed),
126 misses: self.misses.load(Ordering::Relaxed),
127 entries: entries.len(),
128 }
129 }
130}
131
132impl Default for ResultCache {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138fn cache_key(tool: &str, params: &Value) -> String {
140 let serialized = serde_json::to_string(params).unwrap_or_default();
141 let mut hasher = DefaultHasher::new();
142 serialized.hash(&mut hasher);
143 let hash = hasher.finish();
144 format!("{}:{:x}", tool, hash)
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A stored result is returned on lookup and recorded as a single hit.
    #[tokio::test]
    async fn test_cache_hit_returns_stored_result() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 60).await;

        let params = json!({"a": 1, "b": 2});
        let result = json!(3);

        cache.put("add", &params, result.clone()).await;

        let cached = cache.get("add", &params).await;
        assert_eq!(cached, Some(result));

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.entries, 1);
    }

    /// An entry stored with a zero-second TTL is not served and counts as a miss.
    #[tokio::test]
    async fn test_expired_entries_return_none() {
        let cache = ResultCache::new();
        // A zero-second TTL leaves no window in which the entry is fresh.
        cache.enable_caching("add", 0).await;

        let params = json!({"a": 1, "b": 2});
        cache.put("add", &params, json!(3)).await;

        let cached = cache.get("add", &params).await;
        assert_eq!(cached, None);

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);
    }

    /// Different parameter values for the same tool are cached independently.
    #[tokio::test]
    async fn test_different_params_produce_different_keys() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 60).await;

        let params_a = json!({"a": 1, "b": 2});
        let params_b = json!({"a": 3, "b": 4});

        cache.put("add", &params_a, json!(3)).await;
        cache.put("add", &params_b, json!(7)).await;

        assert_eq!(cache.get("add", &params_a).await, Some(json!(3)));
        assert_eq!(cache.get("add", &params_b).await, Some(json!(7)));

        let stats = cache.stats().await;
        assert_eq!(stats.entries, 2);
    }

    /// Invalidating one tool removes its entries and leaves other tools alone.
    #[tokio::test]
    async fn test_invalidate_clears_tool_entries() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 60).await;
        cache.enable_caching("echo", 60).await;

        cache.put("add", &json!({"a": 1}), json!(1)).await;
        cache.put("echo", &json!({"msg": "hi"}), json!("hi")).await;

        cache.invalidate("add").await;

        assert_eq!(cache.get("add", &json!({"a": 1})).await, None);
        assert_eq!(
            cache.get("echo", &json!({"msg": "hi"})).await,
            Some(json!("hi"))
        );
    }

    /// `invalidate_all` empties the cache regardless of tool.
    #[tokio::test]
    async fn test_invalidate_all_clears_everything() {
        let cache = ResultCache::new();
        cache.enable_caching("add", 60).await;
        cache.enable_caching("echo", 60).await;

        cache.put("add", &json!({"a": 1}), json!(1)).await;
        cache.put("echo", &json!({"msg": "hi"}), json!("hi")).await;

        cache.invalidate_all().await;

        let stats = cache.stats().await;
        assert_eq!(stats.entries, 0);
    }

    /// Without `enable_caching`, `put` stores nothing and `get` returns `None`.
    #[tokio::test]
    async fn test_uncacheable_tool_returns_none() {
        let cache = ResultCache::new();
        // Caching was never enabled for "add", so this store is ignored.
        cache.put("add", &json!({"a": 1}), json!(1)).await;
        assert_eq!(cache.get("add", &json!({"a": 1})).await, None);
    }
}