Skip to main content

do_memory_mcp/
cache.rs

//! Query result caching for MCP operations.
//!
//! This module caches the results of expensive MCP operations (`query_memory`,
//! `analyze_patterns`, and `execute_agent_code`) in memory, with per-entry expiry,
//! to improve performance and avoid redundant computation.
5
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::hash::{Hash, Hasher};
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12
/// Settings controlling whether and how [`QueryCache`] stores results.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Master switch: when `false`, lookups return `None` and nothing is stored.
    pub enabled: bool,
    /// Lifetime of each cached entry, in seconds (default: 420, i.e. 7 minutes).
    pub ttl_seconds: u64,
    /// Capacity of each per-operation cache, checked whenever a result is stored
    /// (default: 1000).
    pub max_entries: usize,
}

impl Default for CacheConfig {
    /// Caching enabled, 7-minute TTL, at most 1000 entries per cache.
    fn default() -> Self {
        const SEVEN_MINUTES_IN_SECONDS: u64 = 7 * 60;

        Self {
            max_entries: 1000,
            ttl_seconds: SEVEN_MINUTES_IN_SECONDS,
            enabled: true,
        }
    }
}
33
34/// Cache entry with timestamp and data (uses Arc for zero-copy sharing)
35#[derive(Debug, Serialize, Deserialize)]
36struct CacheEntry<T: Clone> {
37    /// Cached data (wrapped in Arc for cheap cloning on cache hit)
38    data: Arc<T>,
39    /// Timestamp when entry was created
40    created_at: SystemTime,
41    /// TTL for this entry
42    ttl: Duration,
43}
44
45impl<T: Clone> CacheEntry<T> {
46    /// Create a new cache entry
47    fn new(data: T, ttl: Duration) -> Self {
48        Self {
49            data: Arc::new(data),
50            created_at: SystemTime::now(),
51            ttl,
52        }
53    }
54
55    /// Check if entry is expired
56    fn is_expired(&self) -> bool {
57        self.created_at.elapsed().unwrap_or(Duration::MAX) > self.ttl
58    }
59
60    /// Get data as Arc for cheap sharing across threads
61    fn data_arc(&self) -> &Arc<T> {
62        &self.data
63    }
64}
65
/// Cache key identifying a `query_memory` request.
///
/// Two requests share a cache entry only when all four fields compare equal.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct QueryMemoryKey {
    /// Query string of the request.
    pub query: String,
    /// Domain the request was scoped to.
    pub domain: String,
    /// Optional task-type filter of the request.
    pub task_type: Option<String>,
    /// `limit` argument of the request.
    pub limit: usize,
}

impl QueryMemoryKey {
    /// Assemble a key from the request parameters; every argument is moved in unchanged.
    pub fn new(query: String, domain: String, task_type: Option<String>, limit: usize) -> Self {
        Self { limit, task_type, domain, query }
    }
}
85
/// Cache key for `analyze_patterns` operations.
///
/// The floating-point success-rate threshold is stored as a whole percentage so that the
/// key can derive `Hash`/`Eq` (floats implement neither).
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct AnalyzePatternsKey {
    /// Task type the analysis is restricted to.
    pub task_type: String,
    /// Minimum success rate as a whole percentage, rounded to nearest (`0.59` -> `59`).
    pub min_success_rate: u32,
    /// `limit` argument of the request.
    pub limit: usize,
}

impl AnalyzePatternsKey {
    /// Build a key, converting `min_success_rate` (expected to be a fraction such as `0.7`)
    /// to a whole percentage.
    ///
    /// The percentage is rounded rather than truncated. Most decimal fractions are not
    /// exactly representable in `f32`, so `0.59_f32 * 100.0` evaluates to `58.999996`.
    /// Truncating that gives `58`, which collides with the key for `0.58` and serves one
    /// threshold's cached result for the other. The product is computed in `f64` to avoid
    /// a second `f32` rounding step.
    ///
    /// Negative and NaN inputs map to `0`, and out-of-range values saturate at `u32::MAX`
    /// (Rust's saturating float-to-integer `as` semantics).
    pub fn new(task_type: String, min_success_rate: f32, limit: usize) -> Self {
        let percent = (f64::from(min_success_rate) * 100.0).round();
        Self {
            task_type,
            min_success_rate: percent as u32,
            limit,
        }
    }
}
103
104/// Cache key for execute_agent_code operations
105#[derive(Debug, Clone, Hash, Eq, PartialEq)]
106pub struct ExecuteCodeKey {
107    pub code_hash: u64, // Hash of the code for caching
108    pub context_task: String,
109    pub context_input_hash: u64, // Hash of input JSON
110}
111
112impl ExecuteCodeKey {
113    pub fn new(code: &str, context: &super::ExecutionContext) -> Self {
114        let mut hasher = std::collections::hash_map::DefaultHasher::new();
115        code.hash(&mut hasher);
116        let code_hash = hasher.finish();
117
118        let mut hasher = std::collections::hash_map::DefaultHasher::new();
119        context.input.to_string().hash(&mut hasher);
120        let context_input_hash = hasher.finish();
121
122        Self {
123            code_hash,
124            context_task: context.task.clone(),
125            context_input_hash,
126        }
127    }
128}
129
130/// Query result cache for MCP operations
131pub struct QueryCache {
132    config: CacheConfig,
133    /// Cache for query_memory results
134    query_memory_cache: RwLock<HashMap<QueryMemoryKey, CacheEntry<serde_json::Value>>>,
135    /// Cache for analyze_patterns results
136    analyze_patterns_cache: RwLock<HashMap<AnalyzePatternsKey, CacheEntry<serde_json::Value>>>,
137    /// Cache for execute_agent_code results
138    execute_code_cache: RwLock<HashMap<ExecuteCodeKey, CacheEntry<super::ExecutionResult>>>,
139    /// Cache hit count
140    hits: RwLock<u64>,
141    /// Cache miss count
142    misses: RwLock<u64>,
143}
144
145impl Default for QueryCache {
146    fn default() -> Self {
147        Self::new()
148    }
149}
150
151impl QueryCache {
152    /// Create a new query cache with default configuration
153    pub fn new() -> Self {
154        Self::with_config(CacheConfig::default())
155    }
156
157    /// Create a new query cache with custom configuration
158    pub fn with_config(config: CacheConfig) -> Self {
159        Self {
160            config,
161            query_memory_cache: RwLock::new(HashMap::new()),
162            analyze_patterns_cache: RwLock::new(HashMap::new()),
163            execute_code_cache: RwLock::new(HashMap::new()),
164            hits: RwLock::new(0),
165            misses: RwLock::new(0),
166        }
167    }
168
169    /// Get cached query_memory result
170    pub fn get_query_memory(&self, key: &QueryMemoryKey) -> Option<serde_json::Value> {
171        if !self.config.enabled {
172            return None;
173        }
174
175        let cache = self.query_memory_cache.read();
176        if let Some(entry) = cache.get(key) {
177            if !entry.is_expired() {
178                // Clone from Arc (cheaper than deep clone for shared entries)
179                *self.hits.write() += 1;
180                return Some((**entry.data_arc()).clone());
181            }
182        }
183        *self.misses.write() += 1;
184        None
185    }
186
187    /// Cache query_memory result
188    pub fn put_query_memory(&self, key: QueryMemoryKey, result: serde_json::Value) {
189        if !self.config.enabled {
190            return;
191        }
192
193        let mut cache = self.query_memory_cache.write();
194        self.evict_expired_entries(&mut cache);
195
196        // Evict oldest entries if at capacity
197        if cache.len() >= self.config.max_entries {
198            self.evict_oldest(&mut cache);
199        }
200
201        let ttl = Duration::from_secs(self.config.ttl_seconds);
202        cache.insert(key, CacheEntry::new(result, ttl));
203    }
204
205    /// Get cached analyze_patterns result
206    pub fn get_analyze_patterns(&self, key: &AnalyzePatternsKey) -> Option<serde_json::Value> {
207        if !self.config.enabled {
208            return None;
209        }
210
211        let cache = self.analyze_patterns_cache.read();
212        if let Some(entry) = cache.get(key) {
213            if !entry.is_expired() {
214                // Clone from Arc (cheaper than deep clone for shared entries)
215                *self.hits.write() += 1;
216                return Some((**entry.data_arc()).clone());
217            }
218        }
219        *self.misses.write() += 1;
220        None
221    }
222
223    /// Cache analyze_patterns result
224    pub fn put_analyze_patterns(&self, key: AnalyzePatternsKey, result: serde_json::Value) {
225        if !self.config.enabled {
226            return;
227        }
228
229        let mut cache = self.analyze_patterns_cache.write();
230        self.evict_expired_entries(&mut cache);
231
232        // Evict oldest entries if at capacity
233        if cache.len() >= self.config.max_entries {
234            self.evict_oldest(&mut cache);
235        }
236
237        let ttl = Duration::from_secs(self.config.ttl_seconds);
238        cache.insert(key, CacheEntry::new(result, ttl));
239    }
240
241    /// Get cached execute_agent_code result
242    pub fn get_execute_code(&self, key: &ExecuteCodeKey) -> Option<super::ExecutionResult> {
243        if !self.config.enabled {
244            return None;
245        }
246
247        let cache = self.execute_code_cache.read();
248        if let Some(entry) = cache.get(key) {
249            if !entry.is_expired() {
250                // Clone from Arc (cheaper than deep clone for shared entries)
251                *self.hits.write() += 1;
252                return Some((**entry.data_arc()).clone());
253            }
254        }
255        *self.misses.write() += 1;
256        None
257    }
258
259    /// Cache execute_agent_code result
260    pub fn put_execute_code(&self, key: ExecuteCodeKey, result: super::ExecutionResult) {
261        if !self.config.enabled {
262            return;
263        }
264
265        let mut cache = self.execute_code_cache.write();
266        self.evict_expired_entries(&mut cache);
267
268        // Evict oldest entries if at capacity
269        if cache.len() >= self.config.max_entries {
270            self.evict_oldest(&mut cache);
271        }
272
273        let ttl = Duration::from_secs(self.config.ttl_seconds);
274        cache.insert(key, CacheEntry::new(result, ttl));
275    }
276
277    /// Clear all cached entries
278    pub fn clear(&self) {
279        self.query_memory_cache.write().clear();
280        self.analyze_patterns_cache.write().clear();
281        self.execute_code_cache.write().clear();
282    }
283
284    /// Get cache statistics
285    pub fn stats(&self) -> CacheStats {
286        let query_memory = self.query_memory_cache.read();
287        let analyze_patterns = self.analyze_patterns_cache.read();
288        let execute_code = self.execute_code_cache.read();
289
290        let hits = *self.hits.read();
291        let misses = *self.misses.read();
292        let total = hits + misses;
293        let hit_rate = if total > 0 {
294            (hits as f64 / total as f64) * 100.0
295        } else {
296            0.0
297        };
298
299        CacheStats {
300            query_memory_entries: query_memory.len(),
301            analyze_patterns_entries: analyze_patterns.len(),
302            execute_code_entries: execute_code.len(),
303            total_entries: query_memory.len() + analyze_patterns.len() + execute_code.len(),
304            max_entries: self.config.max_entries,
305            enabled: self.config.enabled,
306            ttl_seconds: self.config.ttl_seconds,
307            hits,
308            misses,
309            hit_rate,
310        }
311    }
312
313    /// Evict expired entries from a cache
314    fn evict_expired_entries<T, U>(&self, cache: &mut HashMap<T, CacheEntry<U>>)
315    where
316        T: Eq + Hash + Clone,
317        U: Clone,
318    {
319        cache.retain(|_, entry| !entry.is_expired());
320    }
321
322    /// Evict oldest entries when at capacity (LRU-style)
323    fn evict_oldest<T, U>(&self, cache: &mut HashMap<T, CacheEntry<U>>)
324    where
325        T: Eq + Hash + Clone,
326        U: Clone,
327    {
328        if cache.is_empty() {
329            return;
330        }
331
332        // Find the oldest entry
333        let mut oldest_key = None;
334        let mut oldest_time = SystemTime::now();
335
336        for (key, entry) in cache.iter() {
337            if entry.created_at < oldest_time {
338                oldest_time = entry.created_at;
339                oldest_key = Some(key.clone());
340            }
341        }
342
343        if let Some(key) = oldest_key {
344            cache.remove(&key);
345        }
346    }
347}
348
/// Point-in-time snapshot of cache sizes, configuration, and hit/miss counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Entries currently held for `query_memory` (may include expired, not-yet-evicted ones).
    pub query_memory_entries: usize,
    /// Entries currently held for `analyze_patterns`.
    pub analyze_patterns_entries: usize,
    /// Entries currently held for `execute_agent_code`.
    pub execute_code_entries: usize,
    /// Sum of the three per-operation entry counts.
    pub total_entries: usize,
    /// Configured `CacheConfig::max_entries`. The limit applies to each per-operation cache
    /// separately, so `total_entries` can exceed it.
    pub max_entries: usize,
    /// Whether caching is enabled.
    pub enabled: bool,
    /// Configured entry lifetime, in seconds.
    pub ttl_seconds: u64,
    /// Total cache hits
    pub hits: u64,
    /// Total cache misses
    pub misses: u64,
    /// Cache hit rate as a percentage (0–100) of counted lookups; 0.0 when none were counted.
    pub hit_rate: f64,
}
366
// Unit tests for this module are compiled only in test builds and loaded from a
// separate file (`cache/tests.rs`).
#[cfg(test)]
mod tests;