1use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::hash::{Hash, Hasher};
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12
/// Tunable settings shared by every result map of a `QueryCache`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Master switch: when `false`, lookups always return `None` and inserts are ignored.
    pub enabled: bool,
    /// How long a freshly inserted entry stays valid, in seconds.
    pub ttl_seconds: u64,
    /// Capacity of each individual result map; when a map is full the oldest
    /// entry is evicted to make room for a new one.
    pub max_entries: usize,
}

impl Default for CacheConfig {
    /// Caching on, a seven-minute TTL, and room for 1,000 entries per map.
    fn default() -> Self {
        let seven_minutes_in_seconds = 7 * 60;
        CacheConfig {
            max_entries: 1_000,
            ttl_seconds: seven_minutes_in_seconds,
            enabled: true,
        }
    }
}
33
34#[derive(Debug, Serialize, Deserialize)]
36struct CacheEntry<T: Clone> {
37 data: Arc<T>,
39 created_at: SystemTime,
41 ttl: Duration,
43}
44
45impl<T: Clone> CacheEntry<T> {
46 fn new(data: T, ttl: Duration) -> Self {
48 Self {
49 data: Arc::new(data),
50 created_at: SystemTime::now(),
51 ttl,
52 }
53 }
54
55 fn is_expired(&self) -> bool {
57 self.created_at.elapsed().unwrap_or(Duration::MAX) > self.ttl
58 }
59
60 fn data_arc(&self) -> &Arc<T> {
62 &self.data
63 }
64}
65
/// Cache key that identifies a query-memory lookup by its query text, domain,
/// optional task type, and result limit.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct QueryMemoryKey {
    /// Text of the query being looked up.
    pub query: String,
    /// Domain the query is scoped to.
    pub domain: String,
    /// Optional task-type filter; `None` is a distinct key from any `Some(_)`.
    pub task_type: Option<String>,
    /// Maximum number of results requested.
    pub limit: usize,
}

impl QueryMemoryKey {
    /// Bundles the lookup parameters into a key (all arguments are moved in).
    pub fn new(query: String, domain: String, task_type: Option<String>, limit: usize) -> Self {
        QueryMemoryKey { query, domain, task_type, limit }
    }
}
85
/// Cache key for pattern-analysis results.
///
/// `min_success_rate` is stored as a whole-number percentage so the key can
/// derive `Eq` and `Hash`, which `f32` does not implement.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct AnalyzePatternsKey {
    /// Task type the analysis is restricted to.
    pub task_type: String,
    /// Minimum success rate, rounded to the nearest whole percent (e.g. `0.53` -> `53`).
    pub min_success_rate: u32,
    /// Maximum number of results requested.
    pub limit: usize,
}

impl AnalyzePatternsKey {
    /// Builds a key, quantising `min_success_rate` (expressed as a fraction; it is
    /// multiplied by 100) to the nearest whole percent.
    ///
    /// Rounding rather than truncating matters because many decimal fractions are
    /// not exactly representable in `f32`: `0.53_f32 * 100.0` evaluates to
    /// `52.999996`, which a bare `as u32` cast would turn into `52`. Negative or
    /// NaN rates saturate to `0` through the float-to-int `as` cast.
    pub fn new(task_type: String, min_success_rate: f32, limit: usize) -> Self {
        let percent = (min_success_rate * 100.0).round() as u32;
        Self {
            task_type,
            min_success_rate: percent,
            limit,
        }
    }
}
103
104#[derive(Debug, Clone, Hash, Eq, PartialEq)]
106pub struct ExecuteCodeKey {
107 pub code_hash: u64, pub context_task: String,
109 pub context_input_hash: u64, }
111
112impl ExecuteCodeKey {
113 pub fn new(code: &str, context: &super::ExecutionContext) -> Self {
114 let mut hasher = std::collections::hash_map::DefaultHasher::new();
115 code.hash(&mut hasher);
116 let code_hash = hasher.finish();
117
118 let mut hasher = std::collections::hash_map::DefaultHasher::new();
119 context.input.to_string().hash(&mut hasher);
120 let context_input_hash = hasher.finish();
121
122 Self {
123 code_hash,
124 context_task: context.task.clone(),
125 context_input_hash,
126 }
127 }
128}
129
130pub struct QueryCache {
132 config: CacheConfig,
133 query_memory_cache: RwLock<HashMap<QueryMemoryKey, CacheEntry<serde_json::Value>>>,
135 analyze_patterns_cache: RwLock<HashMap<AnalyzePatternsKey, CacheEntry<serde_json::Value>>>,
137 execute_code_cache: RwLock<HashMap<ExecuteCodeKey, CacheEntry<super::ExecutionResult>>>,
139 hits: RwLock<u64>,
141 misses: RwLock<u64>,
143}
144
145impl Default for QueryCache {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
151impl QueryCache {
152 pub fn new() -> Self {
154 Self::with_config(CacheConfig::default())
155 }
156
157 pub fn with_config(config: CacheConfig) -> Self {
159 Self {
160 config,
161 query_memory_cache: RwLock::new(HashMap::new()),
162 analyze_patterns_cache: RwLock::new(HashMap::new()),
163 execute_code_cache: RwLock::new(HashMap::new()),
164 hits: RwLock::new(0),
165 misses: RwLock::new(0),
166 }
167 }
168
169 pub fn get_query_memory(&self, key: &QueryMemoryKey) -> Option<serde_json::Value> {
171 if !self.config.enabled {
172 return None;
173 }
174
175 let cache = self.query_memory_cache.read();
176 if let Some(entry) = cache.get(key) {
177 if !entry.is_expired() {
178 *self.hits.write() += 1;
180 return Some((**entry.data_arc()).clone());
181 }
182 }
183 *self.misses.write() += 1;
184 None
185 }
186
187 pub fn put_query_memory(&self, key: QueryMemoryKey, result: serde_json::Value) {
189 if !self.config.enabled {
190 return;
191 }
192
193 let mut cache = self.query_memory_cache.write();
194 self.evict_expired_entries(&mut cache);
195
196 if cache.len() >= self.config.max_entries {
198 self.evict_oldest(&mut cache);
199 }
200
201 let ttl = Duration::from_secs(self.config.ttl_seconds);
202 cache.insert(key, CacheEntry::new(result, ttl));
203 }
204
205 pub fn get_analyze_patterns(&self, key: &AnalyzePatternsKey) -> Option<serde_json::Value> {
207 if !self.config.enabled {
208 return None;
209 }
210
211 let cache = self.analyze_patterns_cache.read();
212 if let Some(entry) = cache.get(key) {
213 if !entry.is_expired() {
214 *self.hits.write() += 1;
216 return Some((**entry.data_arc()).clone());
217 }
218 }
219 *self.misses.write() += 1;
220 None
221 }
222
223 pub fn put_analyze_patterns(&self, key: AnalyzePatternsKey, result: serde_json::Value) {
225 if !self.config.enabled {
226 return;
227 }
228
229 let mut cache = self.analyze_patterns_cache.write();
230 self.evict_expired_entries(&mut cache);
231
232 if cache.len() >= self.config.max_entries {
234 self.evict_oldest(&mut cache);
235 }
236
237 let ttl = Duration::from_secs(self.config.ttl_seconds);
238 cache.insert(key, CacheEntry::new(result, ttl));
239 }
240
241 pub fn get_execute_code(&self, key: &ExecuteCodeKey) -> Option<super::ExecutionResult> {
243 if !self.config.enabled {
244 return None;
245 }
246
247 let cache = self.execute_code_cache.read();
248 if let Some(entry) = cache.get(key) {
249 if !entry.is_expired() {
250 *self.hits.write() += 1;
252 return Some((**entry.data_arc()).clone());
253 }
254 }
255 *self.misses.write() += 1;
256 None
257 }
258
259 pub fn put_execute_code(&self, key: ExecuteCodeKey, result: super::ExecutionResult) {
261 if !self.config.enabled {
262 return;
263 }
264
265 let mut cache = self.execute_code_cache.write();
266 self.evict_expired_entries(&mut cache);
267
268 if cache.len() >= self.config.max_entries {
270 self.evict_oldest(&mut cache);
271 }
272
273 let ttl = Duration::from_secs(self.config.ttl_seconds);
274 cache.insert(key, CacheEntry::new(result, ttl));
275 }
276
277 pub fn clear(&self) {
279 self.query_memory_cache.write().clear();
280 self.analyze_patterns_cache.write().clear();
281 self.execute_code_cache.write().clear();
282 }
283
284 pub fn stats(&self) -> CacheStats {
286 let query_memory = self.query_memory_cache.read();
287 let analyze_patterns = self.analyze_patterns_cache.read();
288 let execute_code = self.execute_code_cache.read();
289
290 let hits = *self.hits.read();
291 let misses = *self.misses.read();
292 let total = hits + misses;
293 let hit_rate = if total > 0 {
294 (hits as f64 / total as f64) * 100.0
295 } else {
296 0.0
297 };
298
299 CacheStats {
300 query_memory_entries: query_memory.len(),
301 analyze_patterns_entries: analyze_patterns.len(),
302 execute_code_entries: execute_code.len(),
303 total_entries: query_memory.len() + analyze_patterns.len() + execute_code.len(),
304 max_entries: self.config.max_entries,
305 enabled: self.config.enabled,
306 ttl_seconds: self.config.ttl_seconds,
307 hits,
308 misses,
309 hit_rate,
310 }
311 }
312
313 fn evict_expired_entries<T, U>(&self, cache: &mut HashMap<T, CacheEntry<U>>)
315 where
316 T: Eq + Hash + Clone,
317 U: Clone,
318 {
319 cache.retain(|_, entry| !entry.is_expired());
320 }
321
322 fn evict_oldest<T, U>(&self, cache: &mut HashMap<T, CacheEntry<U>>)
324 where
325 T: Eq + Hash + Clone,
326 U: Clone,
327 {
328 if cache.is_empty() {
329 return;
330 }
331
332 let mut oldest_key = None;
334 let mut oldest_time = SystemTime::now();
335
336 for (key, entry) in cache.iter() {
337 if entry.created_at < oldest_time {
338 oldest_time = entry.created_at;
339 oldest_key = Some(key.clone());
340 }
341 }
342
343 if let Some(key) = oldest_key {
344 cache.remove(&key);
345 }
346 }
347}
348
349#[derive(Debug, Clone, Serialize, Deserialize)]
351pub struct CacheStats {
352 pub query_memory_entries: usize,
353 pub analyze_patterns_entries: usize,
354 pub execute_code_entries: usize,
355 pub total_entries: usize,
356 pub max_entries: usize,
357 pub enabled: bool,
358 pub ttl_seconds: u64,
359 pub hits: u64,
361 pub misses: u64,
363 pub hit_rate: f64,
365}
366
367#[cfg(test)]
368mod tests;