1use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use std::time::Duration;
12
13#[derive(Debug, Clone)]
15pub struct CacheEntry<T> {
16 pub value: T,
17 pub created_at: DateTime<Utc>,
18 pub ttl: Duration,
19 pub hits: u64,
20}
21
22impl<T: Clone> CacheEntry<T> {
23 pub fn new(value: T, ttl: Duration) -> Self {
24 Self {
25 value,
26 created_at: Utc::now(),
27 ttl,
28 hits: 0,
29 }
30 }
31
32 pub fn is_expired(&self) -> bool {
33 let elapsed = Utc::now()
34 .signed_duration_since(self.created_at)
35 .to_std()
36 .unwrap_or(Duration::ZERO);
37 elapsed > self.ttl
38 }
39}
40
41#[derive(Debug, Clone, Default)]
43pub struct L1StructCache {
44 entries: Arc<RwLock<HashMap<String, CacheEntry<String>>>>,
45}
46
47impl L1StructCache {
48 pub fn new() -> Self {
49 Self::default()
50 }
51
52 pub fn set(&self, key: &str, value: String, ttl: Duration) {
54 if let Ok(mut entries) = self.entries.write() {
55 entries.insert(key.to_string(), CacheEntry::new(value, ttl));
56 }
57 }
58
59 pub fn get(&self, key: &str) -> Option<String> {
61 if let Ok(mut entries) = self.entries.write() {
62 if let Some(entry) = entries.get_mut(key) {
63 if entry.is_expired() {
64 entries.remove(key);
65 return None;
66 }
67 entry.hits += 1;
68 return Some(entry.value.clone());
69 }
70 }
71 None
72 }
73
74 pub fn invalidate(&self, key: &str) {
76 if let Ok(mut entries) = self.entries.write() {
77 entries.remove(key);
78 }
79 }
80
81 pub fn clear(&self) {
83 if let Ok(mut entries) = self.entries.write() {
84 entries.clear();
85 }
86 }
87
88 pub fn stats(&self) -> CacheStats {
90 let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
91 let total = entries.len();
92 let expired = entries.values().filter(|e| e.is_expired()).count();
93 let total_hits: u64 = entries.values().map(|e| e.hits).sum();
94 CacheStats {
95 entries: total,
96 expired,
97 hits: total_hits,
98 }
99 }
100}
101
102#[derive(Debug, Clone, Default)]
104pub struct L2QueryCache {
105 entries: Arc<RwLock<HashMap<String, CacheEntry<String>>>>,
106}
107
108impl L2QueryCache {
109 pub fn new() -> Self {
110 Self::default()
111 }
112
113 pub fn set(&self, query: &str, result: String, ttl: Duration) {
114 if let Ok(mut entries) = self.entries.write() {
115 entries.insert(query.to_string(), CacheEntry::new(result, ttl));
116 }
117 }
118
119 pub fn get(&self, query: &str) -> Option<String> {
120 if let Ok(mut entries) = self.entries.write() {
121 if let Some(entry) = entries.get_mut(query) {
122 if entry.is_expired() {
123 entries.remove(query);
124 return None;
125 }
126 entry.hits += 1;
127 return Some(entry.value.clone());
128 }
129 }
130 None
131 }
132
133 pub fn clear(&self) {
134 if let Ok(mut entries) = self.entries.write() {
135 entries.clear();
136 }
137 }
138
139 pub fn stats(&self) -> CacheStats {
140 let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
141 CacheStats {
142 entries: entries.len(),
143 expired: entries.values().filter(|e| e.is_expired()).count(),
144 hits: entries.values().map(|e| e.hits).sum(),
145 }
146 }
147}
148
149#[derive(Debug, Clone, Default)]
151pub struct L3LlmCache {
152 entries: Arc<RwLock<HashMap<String, CacheEntry<LlmCacheEntry>>>>,
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct LlmCacheEntry {
158 pub response: String,
159 pub model: String,
160 pub cost_saved: f64,
161}
162
163impl L3LlmCache {
164 pub fn new() -> Self {
165 Self::default()
166 }
167
168 pub fn set(&self, prompt_hash: &str, entry: LlmCacheEntry, ttl: Duration) {
170 if let Ok(mut entries) = self.entries.write() {
171 entries.insert(prompt_hash.to_string(), CacheEntry::new(entry, ttl));
172 }
173 }
174
175 pub fn get(&self, prompt_hash: &str) -> Option<LlmCacheEntry> {
177 if let Ok(mut entries) = self.entries.write() {
178 if let Some(entry) = entries.get_mut(prompt_hash) {
179 if entry.is_expired() {
180 entries.remove(prompt_hash);
181 return None;
182 }
183 entry.hits += 1;
184 return Some(entry.value.clone());
185 }
186 }
187 None
188 }
189
190 pub fn total_cost_saved(&self) -> f64 {
192 self.entries
193 .read()
194 .map(|entries| {
195 entries
196 .values()
197 .filter(|e| e.hits > 0)
198 .map(|e| e.value.cost_saved * e.hits as f64)
199 .sum()
200 })
201 .unwrap_or(0.0)
202 }
203
204 pub fn clear(&self) {
205 if let Ok(mut entries) = self.entries.write() {
206 entries.clear();
207 }
208 }
209
210 pub fn stats(&self) -> CacheStats {
211 let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
212 CacheStats {
213 entries: entries.len(),
214 expired: entries.values().filter(|e| e.is_expired()).count(),
215 hits: entries.values().map(|e| e.hits).sum(),
216 }
217 }
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct CacheStats {
223 pub entries: usize,
224 pub expired: usize,
225 pub hits: u64,
226}
227
228#[derive(Debug, Clone)]
230pub struct CacheSystem {
231 pub l1: L1StructCache,
232 pub l2: L2QueryCache,
233 pub l3: L3LlmCache,
234}
235
236impl CacheSystem {
237 pub fn new() -> Self {
238 Self {
239 l1: L1StructCache::new(),
240 l2: L2QueryCache::new(),
241 l3: L3LlmCache::new(),
242 }
243 }
244
245 pub fn all_stats(&self) -> (CacheStats, CacheStats, CacheStats) {
247 (self.l1.stats(), self.l2.stats(), self.l3.stats())
248 }
249
250 pub fn clear_all(&self) {
252 self.l1.clear();
253 self.l2.clear();
254 self.l3.clear();
255 }
256}
257
258impl Default for CacheSystem {
259 fn default() -> Self {
260 Self::new()
261 }
262}
263
264pub fn prompt_hash(prompt: &str) -> String {
266 use sha2::{Digest, Sha256};
267 let mut hasher = Sha256::new();
268 hasher.update(prompt.as_bytes());
269 format!("{:x}", hasher.finalize())
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an LLM cache entry with fixed text and the given per-hit saving.
    fn llm_entry(cost_saved: f64) -> LlmCacheEntry {
        LlmCacheEntry {
            response: "cached response".into(),
            model: "sonnet".into(),
            cost_saved,
        }
    }

    #[test]
    fn test_l1_set_get() {
        let cache = L1StructCache::new();
        cache.set("key1", "value1".into(), Duration::from_secs(60));
        assert_eq!(cache.get("key1"), Some("value1".into()));
        assert_eq!(cache.get("missing"), None);
    }

    #[test]
    fn test_l1_expiration() {
        let cache = L1StructCache::new();
        cache.set("key1", "value1".into(), Duration::from_secs(0));
        // Let the wall clock move past the zero TTL.
        std::thread::sleep(Duration::from_millis(10));
        assert_eq!(cache.get("key1"), None);
    }

    #[test]
    fn test_l1_invalidate() {
        let cache = L1StructCache::new();
        cache.set("key1", "value1".into(), Duration::from_secs(60));
        cache.invalidate("key1");
        assert_eq!(cache.get("key1"), None);
    }

    #[test]
    fn test_l1_stats() {
        let cache = L1StructCache::new();
        cache.set("k1", "v1".into(), Duration::from_secs(60));
        cache.set("k2", "v2".into(), Duration::from_secs(60));
        cache.get("k1");
        cache.get("k1");
        let stats = cache.stats();
        assert_eq!(stats.entries, 2);
        assert_eq!(stats.expired, 0);
        assert_eq!(stats.hits, 2);
    }

    #[test]
    fn test_l2_set_get() {
        let cache = L2QueryCache::new();
        cache.set("SELECT 1", "1".into(), Duration::from_secs(60));
        assert_eq!(cache.get("SELECT 1"), Some("1".into()));
        assert_eq!(cache.get("SELECT 2"), None);
        assert_eq!(cache.stats().hits, 1);
    }

    #[test]
    fn test_l2_expired_entry_is_evicted_on_get() {
        let cache = L2QueryCache::new();
        cache.set("fresh", "a".into(), Duration::from_secs(60));
        cache.set("stale", "b".into(), Duration::from_secs(0));
        std::thread::sleep(Duration::from_millis(10));
        assert_eq!(cache.stats().expired, 1);
        assert_eq!(cache.get("stale"), None);
        // The failed lookup removed the expired entry; the fresh one remains.
        let stats = cache.stats();
        assert_eq!(stats.entries, 1);
        assert_eq!(stats.expired, 0);
    }

    #[test]
    fn test_l3_get() {
        let cache = L3LlmCache::new();
        cache.set("hash1", llm_entry(0.05), Duration::from_secs(3600));
        let hit = cache.get("hash1").expect("entry was just inserted");
        assert_eq!(hit.response, "cached response");
        assert_eq!(hit.model, "sonnet");
        assert!(cache.get("other").is_none());
    }

    #[test]
    fn test_l3_cost_saved() {
        let cache = L3LlmCache::new();
        cache.set("hash1", llm_entry(0.05), Duration::from_secs(3600));
        // Two hits at 0.05 each.
        cache.get("hash1");
        cache.get("hash1");
        let saved = cache.total_cost_saved();
        assert!((saved - 0.10).abs() < 0.001);
    }

    #[test]
    fn test_l3_cost_saved_ignores_unhit_entries() {
        let cache = L3LlmCache::new();
        cache.set("hash1", llm_entry(1.0), Duration::from_secs(3600));
        assert!(cache.total_cost_saved().abs() < f64::EPSILON);
    }

    #[test]
    fn test_prompt_hash() {
        let h1 = prompt_hash("hello");
        let h2 = prompt_hash("hello");
        let h3 = prompt_hash("world");
        assert_eq!(h1, h2);
        assert_ne!(h1, h3);
        // SHA-256 is 32 bytes, i.e. 64 hex characters.
        assert_eq!(h1.len(), 64);
    }

    #[test]
    fn test_cache_system() {
        let system = CacheSystem::new();
        system.l1.set("k", "v".into(), Duration::from_secs(60));
        let (l1, l2, l3) = system.all_stats();
        assert_eq!(l1.entries, 1);
        assert_eq!(l2.entries, 0);
        assert_eq!(l3.entries, 0);

        system.l2.set("q", "r".into(), Duration::from_secs(60));
        system.l3.set("h", llm_entry(0.01), Duration::from_secs(60));
        let (l1, l2, l3) = system.all_stats();
        assert_eq!((l1.entries, l2.entries, l3.entries), (1, 1, 1));

        system.clear_all();
        let (l1, l2, l3) = system.all_stats();
        assert_eq!((l1.entries, l2.entries, l3.entries), (0, 0, 0));
    }
}