Skip to main content

mur_core/
cache.rs

1//! Three-layer cache system for MUR Commander.
2//!
3//! - L1: Struct cache — pre-loaded workflows, patterns, config, constitution
4//! - L2: Query cache — search results and injection results
5//! - L3: LLM response cache — prompt→response pairs to save API costs
6
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use std::time::Duration;
12
13/// Cache entry with TTL tracking.
14#[derive(Debug, Clone)]
15pub struct CacheEntry<T> {
16    pub value: T,
17    pub created_at: DateTime<Utc>,
18    pub ttl: Duration,
19    pub hits: u64,
20}
21
22impl<T: Clone> CacheEntry<T> {
23    pub fn new(value: T, ttl: Duration) -> Self {
24        Self {
25            value,
26            created_at: Utc::now(),
27            ttl,
28            hits: 0,
29        }
30    }
31
32    pub fn is_expired(&self) -> bool {
33        let elapsed = Utc::now()
34            .signed_duration_since(self.created_at)
35            .to_std()
36            .unwrap_or(Duration::ZERO);
37        elapsed > self.ttl
38    }
39}
40
41/// L1: Struct cache for frequently accessed data.
42#[derive(Debug, Clone, Default)]
43pub struct L1StructCache {
44    entries: Arc<RwLock<HashMap<String, CacheEntry<String>>>>,
45}
46
47impl L1StructCache {
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Store a value with a TTL.
53    pub fn set(&self, key: &str, value: String, ttl: Duration) {
54        if let Ok(mut entries) = self.entries.write() {
55            entries.insert(key.to_string(), CacheEntry::new(value, ttl));
56        }
57    }
58
59    /// Get a value if it exists and hasn't expired.
60    pub fn get(&self, key: &str) -> Option<String> {
61        if let Ok(mut entries) = self.entries.write() {
62            if let Some(entry) = entries.get_mut(key) {
63                if entry.is_expired() {
64                    entries.remove(key);
65                    return None;
66                }
67                entry.hits += 1;
68                return Some(entry.value.clone());
69            }
70        }
71        None
72    }
73
74    /// Invalidate a specific key.
75    pub fn invalidate(&self, key: &str) {
76        if let Ok(mut entries) = self.entries.write() {
77            entries.remove(key);
78        }
79    }
80
81    /// Invalidate all entries.
82    pub fn clear(&self) {
83        if let Ok(mut entries) = self.entries.write() {
84            entries.clear();
85        }
86    }
87
88    /// Get cache statistics.
89    pub fn stats(&self) -> CacheStats {
90        let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
91        let total = entries.len();
92        let expired = entries.values().filter(|e| e.is_expired()).count();
93        let total_hits: u64 = entries.values().map(|e| e.hits).sum();
94        CacheStats {
95            entries: total,
96            expired,
97            hits: total_hits,
98        }
99    }
100}
101
102/// L2: Query result cache.
103#[derive(Debug, Clone, Default)]
104pub struct L2QueryCache {
105    entries: Arc<RwLock<HashMap<String, CacheEntry<String>>>>,
106}
107
108impl L2QueryCache {
109    pub fn new() -> Self {
110        Self::default()
111    }
112
113    pub fn set(&self, query: &str, result: String, ttl: Duration) {
114        if let Ok(mut entries) = self.entries.write() {
115            entries.insert(query.to_string(), CacheEntry::new(result, ttl));
116        }
117    }
118
119    pub fn get(&self, query: &str) -> Option<String> {
120        if let Ok(mut entries) = self.entries.write() {
121            if let Some(entry) = entries.get_mut(query) {
122                if entry.is_expired() {
123                    entries.remove(query);
124                    return None;
125                }
126                entry.hits += 1;
127                return Some(entry.value.clone());
128            }
129        }
130        None
131    }
132
133    pub fn clear(&self) {
134        if let Ok(mut entries) = self.entries.write() {
135            entries.clear();
136        }
137    }
138
139    pub fn stats(&self) -> CacheStats {
140        let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
141        CacheStats {
142            entries: entries.len(),
143            expired: entries.values().filter(|e| e.is_expired()).count(),
144            hits: entries.values().map(|e| e.hits).sum(),
145        }
146    }
147}
148
149/// L3: LLM response cache — exact prompt→response to avoid duplicate API calls.
150#[derive(Debug, Clone, Default)]
151pub struct L3LlmCache {
152    entries: Arc<RwLock<HashMap<String, CacheEntry<LlmCacheEntry>>>>,
153}
154
155/// Cached LLM response.
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct LlmCacheEntry {
158    pub response: String,
159    pub model: String,
160    pub cost_saved: f64,
161}
162
163impl L3LlmCache {
164    pub fn new() -> Self {
165        Self::default()
166    }
167
168    /// Cache a prompt→response pair.
169    pub fn set(&self, prompt_hash: &str, entry: LlmCacheEntry, ttl: Duration) {
170        if let Ok(mut entries) = self.entries.write() {
171            entries.insert(prompt_hash.to_string(), CacheEntry::new(entry, ttl));
172        }
173    }
174
175    /// Look up a cached response by prompt hash.
176    pub fn get(&self, prompt_hash: &str) -> Option<LlmCacheEntry> {
177        if let Ok(mut entries) = self.entries.write() {
178            if let Some(entry) = entries.get_mut(prompt_hash) {
179                if entry.is_expired() {
180                    entries.remove(prompt_hash);
181                    return None;
182                }
183                entry.hits += 1;
184                return Some(entry.value.clone());
185            }
186        }
187        None
188    }
189
190    /// Get total cost saved by cache hits.
191    pub fn total_cost_saved(&self) -> f64 {
192        self.entries
193            .read()
194            .map(|entries| {
195                entries
196                    .values()
197                    .filter(|e| e.hits > 0)
198                    .map(|e| e.value.cost_saved * e.hits as f64)
199                    .sum()
200            })
201            .unwrap_or(0.0)
202    }
203
204    pub fn clear(&self) {
205        if let Ok(mut entries) = self.entries.write() {
206            entries.clear();
207        }
208    }
209
210    pub fn stats(&self) -> CacheStats {
211        let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
212        CacheStats {
213            entries: entries.len(),
214            expired: entries.values().filter(|e| e.is_expired()).count(),
215            hits: entries.values().map(|e| e.hits).sum(),
216        }
217    }
218}
219
220/// Cache statistics.
221#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct CacheStats {
223    pub entries: usize,
224    pub expired: usize,
225    pub hits: u64,
226}
227
228/// Combined three-layer cache.
229#[derive(Debug, Clone)]
230pub struct CacheSystem {
231    pub l1: L1StructCache,
232    pub l2: L2QueryCache,
233    pub l3: L3LlmCache,
234}
235
236impl CacheSystem {
237    pub fn new() -> Self {
238        Self {
239            l1: L1StructCache::new(),
240            l2: L2QueryCache::new(),
241            l3: L3LlmCache::new(),
242        }
243    }
244
245    /// Get combined stats for all layers.
246    pub fn all_stats(&self) -> (CacheStats, CacheStats, CacheStats) {
247        (self.l1.stats(), self.l2.stats(), self.l3.stats())
248    }
249
250    /// Clear all cache layers.
251    pub fn clear_all(&self) {
252        self.l1.clear();
253        self.l2.clear();
254        self.l3.clear();
255    }
256}
257
258impl Default for CacheSystem {
259    fn default() -> Self {
260        Self::new()
261    }
262}
263
264/// Compute a simple hash of a prompt for L3 cache keys.
265pub fn prompt_hash(prompt: &str) -> String {
266    use sha2::{Digest, Sha256};
267    let mut hasher = Sha256::new();
268    hasher.update(prompt.as_bytes());
269    format!("{:x}", hasher.finalize())
270}
271
#[cfg(test)]
mod tests {
    //! Unit tests for the three cache layers and the combined `CacheSystem`.
    //! Expiry tests use a zero TTL plus a short sleep so that elapsed time is
    //! strictly greater than the TTL.

    use super::*;

    fn sample_llm_entry() -> LlmCacheEntry {
        LlmCacheEntry {
            response: "cached response".into(),
            model: "sonnet".into(),
            cost_saved: 0.05,
        }
    }

    #[test]
    fn test_l1_set_get() {
        let cache = L1StructCache::new();
        cache.set("key1", "value1".into(), Duration::from_secs(60));
        assert_eq!(cache.get("key1"), Some("value1".into()));
        assert_eq!(cache.get("missing"), None);
    }

    #[test]
    fn test_l1_expiration() {
        let cache = L1StructCache::new();
        cache.set("key1", "value1".into(), Duration::from_secs(0));
        // Expired immediately
        std::thread::sleep(Duration::from_millis(10));
        assert_eq!(cache.get("key1"), None);
    }

    #[test]
    fn test_l1_expired_entry_is_removed_on_get() {
        let cache = L1StructCache::new();
        cache.set("key1", "value1".into(), Duration::from_secs(0));
        std::thread::sleep(Duration::from_millis(10));
        let before = cache.stats();
        assert_eq!(before.entries, 1);
        assert_eq!(before.expired, 1);
        assert_eq!(cache.get("key1"), None);
        assert_eq!(cache.stats().entries, 0);
    }

    #[test]
    fn test_l1_invalidate() {
        let cache = L1StructCache::new();
        cache.set("key1", "value1".into(), Duration::from_secs(60));
        cache.invalidate("key1");
        assert_eq!(cache.get("key1"), None);
    }

    #[test]
    fn test_l1_stats() {
        let cache = L1StructCache::new();
        cache.set("k1", "v1".into(), Duration::from_secs(60));
        cache.set("k2", "v2".into(), Duration::from_secs(60));
        cache.get("k1");
        cache.get("k1");
        let stats = cache.stats();
        assert_eq!(stats.entries, 2);
        assert_eq!(stats.hits, 2);
    }

    #[test]
    fn test_l1_set_overwrites_and_resets_hits() {
        let cache = L1StructCache::new();
        cache.set("k", "old".into(), Duration::from_secs(60));
        cache.get("k");
        cache.set("k", "new".into(), Duration::from_secs(60));
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.get("k"), Some("new".into()));
        assert_eq!(cache.stats().hits, 1);
    }

    #[test]
    fn test_l1_clones_share_storage() {
        let cache = L1StructCache::new();
        let handle = cache.clone();
        handle.set("k", "v".into(), Duration::from_secs(60));
        assert_eq!(cache.get("k"), Some("v".into()));
    }

    #[test]
    fn test_l2_set_get_clear() {
        let cache = L2QueryCache::new();
        cache.set("q", "r".into(), Duration::from_secs(60));
        assert_eq!(cache.get("q"), Some("r".into()));
        assert_eq!(cache.get("other"), None);
        assert_eq!(cache.stats().hits, 1);
        cache.clear();
        assert_eq!(cache.get("q"), None);
        assert_eq!(cache.stats().entries, 0);
    }

    #[test]
    fn test_l2_expiration() {
        let cache = L2QueryCache::new();
        cache.set("q", "r".into(), Duration::from_secs(0));
        std::thread::sleep(Duration::from_millis(10));
        assert_eq!(cache.get("q"), None);
        assert_eq!(cache.stats().entries, 0);
    }

    #[test]
    fn test_l3_cost_saved() {
        let cache = L3LlmCache::new();
        cache.set("hash1", sample_llm_entry(), Duration::from_secs(3600));
        // Hit it twice
        cache.get("hash1");
        cache.get("hash1");
        let saved = cache.total_cost_saved();
        assert!((saved - 0.10).abs() < 0.001);
    }

    #[test]
    fn test_l3_miss_and_expiration() {
        let cache = L3LlmCache::new();
        assert!(cache.get("nope").is_none());
        cache.set("hash1", sample_llm_entry(), Duration::from_secs(0));
        std::thread::sleep(Duration::from_millis(10));
        assert!(cache.get("hash1").is_none());
        assert_eq!(cache.stats().entries, 0);
        assert_eq!(cache.total_cost_saved(), 0.0);
    }

    #[test]
    fn test_l3_hit_returns_stored_entry() {
        let cache = L3LlmCache::new();
        cache.set("hash1", sample_llm_entry(), Duration::from_secs(3600));
        let hit = cache.get("hash1").expect("entry was just inserted");
        assert_eq!(hit.response, "cached response");
        assert_eq!(hit.model, "sonnet");
    }

    #[test]
    fn test_prompt_hash() {
        let h1 = prompt_hash("hello");
        let h2 = prompt_hash("hello");
        let h3 = prompt_hash("world");
        assert_eq!(h1, h2);
        assert_ne!(h1, h3);
        assert_eq!(h1.len(), 64); // SHA-256 hex
    }

    #[test]
    fn test_cache_system() {
        let system = CacheSystem::new();
        system.l1.set("k", "v".into(), Duration::from_secs(60));
        let (l1, l2, l3) = system.all_stats();
        assert_eq!(l1.entries, 1);
        assert_eq!(l2.entries, 0);
        assert_eq!(l3.entries, 0);

        system.clear_all();
        let (l1, _, _) = system.all_stats();
        assert_eq!(l1.entries, 0);
    }
}