1use std::sync::Mutex;
10use std::time::{Duration, Instant};
11
/// A cached value plus the bookkeeping used for TTL expiry and hit statistics.
struct CacheEntry<V> {
    /// The cached payload; cloned out on every successful lookup.
    value: V,
    /// Insertion time, compared against the tier's TTL on lookup.
    created_at: Instant,
    /// Number of successful lookups served by this entry since it was inserted.
    hits: u64,
}

/// A bounded cache tier with least-recently-used eviction and a fixed TTL.
///
/// Entries live in a `Vec` ordered from least recently used (front) to most
/// recently used (back). Lookups and key replacement are linear scans.
struct LruTier<V> {
    /// `(key, entry)` pairs, least recently used first.
    entries: Vec<(String, CacheEntry<V>)>,
    /// Maximum number of entries held at once; `0` disables storage entirely.
    capacity: usize,
    /// Entries strictly older than this are treated as expired on lookup.
    ttl: Duration,
}

impl<V: Clone> LruTier<V> {
    /// Creates an empty tier holding at most `capacity` entries, each valid for `ttl`.
    ///
    /// Storage for `capacity` entries is reserved up front.
    fn new(capacity: usize, ttl: Duration) -> Self {
        Self {
            entries: Vec::with_capacity(capacity),
            capacity,
            ttl,
        }
    }

    /// Returns a clone of the value stored under `key`, if present and not expired.
    ///
    /// A hit increments the entry's hit counter and marks it most recently used.
    /// An entry older than the TTL is removed and reported as a miss.
    fn get(&mut self, key: &str) -> Option<V> {
        let now = Instant::now();
        let pos = self.entries.iter().position(|(k, _)| k == key)?;
        let entry = &mut self.entries[pos].1;
        if now.duration_since(entry.created_at) > self.ttl {
            self.entries.remove(pos);
            return None;
        }
        entry.hits += 1;
        let value = entry.value.clone();
        // Move the entry at `pos` to the back (most recently used) in place,
        // keeping the relative order of everything behind it.
        self.entries[pos..].rotate_left(1);
        Some(value)
    }

    /// Stores `value` under `key` as the most recently used entry.
    ///
    /// An existing entry with the same key is replaced, which resets its creation
    /// time and hit counter. If the tier is full, the least recently used entry is
    /// evicted first. A tier created with capacity `0` stores nothing.
    fn insert(&mut self, key: String, value: V) {
        // Without this guard the eviction below would call `remove(0)` on an
        // empty vector and panic.
        if self.capacity == 0 {
            return;
        }
        self.entries.retain(|(k, _)| *k != key);
        if self.entries.len() >= self.capacity {
            self.entries.remove(0);
        }
        self.entries.push((
            key,
            CacheEntry {
                value,
                created_at: Instant::now(),
                hits: 0,
            },
        ));
    }

    /// Removes every entry from the tier.
    fn clear(&mut self) {
        self.entries.clear();
    }

    /// Number of entries currently stored, including expired ones not yet looked up.
    fn len(&self) -> usize {
        self.entries.len()
    }

    /// Sum of the hit counters of the entries currently stored.
    ///
    /// Hits on entries that were evicted, expired, replaced or cleared are not included.
    fn total_hits(&self) -> u64 {
        self.entries.iter().map(|(_, e)| e.hits).sum()
    }
}
81
82pub struct QueryCache {
84 tier1: Mutex<LruTier<String>>,
86 tier2: Mutex<LruTier<String>>,
88}
89
/// Point-in-time snapshot of both cache tiers.
///
/// The hit fields are sums over the entries stored at snapshot time; hits on
/// entries that have since been removed are not counted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CacheStats {
    /// Number of entries currently stored in tier 1.
    pub tier1_entries: usize,
    /// Sum of the hit counters of the entries currently stored in tier 1.
    pub tier1_hits: u64,
    /// Number of entries currently stored in tier 2.
    pub tier2_entries: usize,
    /// Sum of the hit counters of the entries currently stored in tier 2.
    pub tier2_hits: u64,
}
97
98impl QueryCache {
99 pub fn new(tier1_capacity: usize, tier2_capacity: usize, ttl_secs: u64) -> Self {
101 let ttl = Duration::from_secs(ttl_secs);
102 Self {
103 tier1: Mutex::new(LruTier::new(tier1_capacity, ttl)),
104 tier2: Mutex::new(LruTier::new(tier2_capacity, ttl)),
105 }
106 }
107
108 pub fn get_query(&self, key: &str) -> Option<String> {
112 self.tier1.lock().ok()?.get(key)
113 }
114
115 pub fn insert_query(&self, key: String, json_response: String) {
117 if let Ok(mut t) = self.tier1.lock() {
118 t.insert(key, json_response);
119 }
120 }
121
122 pub fn get_block_text(&self, block_idx: usize) -> Option<String> {
126 let key = block_idx.to_string();
127 self.tier2.lock().ok()?.get(&key)
128 }
129
130 pub fn insert_block_text(&self, block_idx: usize, text: String) {
132 if let Ok(mut t) = self.tier2.lock() {
133 t.insert(block_idx.to_string(), text);
134 }
135 }
136
137 pub fn invalidate_all(&self) {
141 if let Ok(mut t) = self.tier1.lock() {
142 t.clear();
143 }
144 if let Ok(mut t) = self.tier2.lock() {
145 t.clear();
146 }
147 }
148
149 pub fn stats(&self) -> CacheStats {
151 let (t1_entries, t1_hits) = self
152 .tier1
153 .lock()
154 .map(|t| (t.len(), t.total_hits()))
155 .unwrap_or((0, 0));
156 let (t2_entries, t2_hits) = self
157 .tier2
158 .lock()
159 .map(|t| (t.len(), t.total_hits()))
160 .unwrap_or((0, 0));
161 CacheStats {
162 tier1_entries: t1_entries,
163 tier1_hits: t1_hits,
164 tier2_entries: t2_entries,
165 tier2_hits: t2_hits,
166 }
167 }
168
169 pub fn make_key(endpoint: &str, query: &str, k: usize) -> String {
171 format!("{}:{}:{}", endpoint, query.to_lowercase().trim(), k)
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored query is returned verbatim; an unknown key is a miss.
    #[test]
    fn test_basic_insert_get() {
        let cache = QueryCache::new(4, 4, 300);
        cache.insert_query("recall:test:10".into(), "{\"results\":[]}".into());
        assert_eq!(
            cache.get_query("recall:test:10"),
            Some("{\"results\":[]}".into())
        );
        assert_eq!(cache.get_query("recall:missing:10"), None);
    }

    /// With no intervening reads, the oldest insert is evicted first.
    #[test]
    fn test_lru_eviction() {
        let cache = QueryCache::new(2, 2, 300);
        cache.insert_query("a".into(), "1".into());
        cache.insert_query("b".into(), "2".into());
        // Third insert exceeds the capacity of 2 and pushes out "a".
        cache.insert_query("c".into(), "3".into());
        assert_eq!(cache.get_query("a"), None);
        assert_eq!(cache.get_query("b"), Some("2".into()));
        assert_eq!(cache.get_query("c"), Some("3".into()));
    }

    /// Reading an entry protects it from the next eviction; this is what
    /// distinguishes LRU from plain FIFO.
    #[test]
    fn test_lru_promotes_on_get() {
        let cache = QueryCache::new(2, 2, 300);
        cache.insert_query("a".into(), "1".into());
        cache.insert_query("b".into(), "2".into());
        assert_eq!(cache.get_query("a"), Some("1".into()));
        // "b" is now the least recently used entry and must be the one evicted.
        cache.insert_query("c".into(), "3".into());
        assert_eq!(cache.get_query("b"), None);
        assert_eq!(cache.get_query("a"), Some("1".into()));
        assert_eq!(cache.get_query("c"), Some("3".into()));
    }

    /// Inserting under an existing key replaces the value without adding an entry.
    #[test]
    fn test_reinsert_replaces_value() {
        let cache = QueryCache::new(2, 2, 300);
        cache.insert_query("a".into(), "1".into());
        cache.insert_query("a".into(), "2".into());
        assert_eq!(cache.get_query("a"), Some("2".into()));
        assert_eq!(cache.stats().tier1_entries, 1);
    }

    /// `invalidate_all` empties both tiers.
    #[test]
    fn test_invalidate_all() {
        let cache = QueryCache::new(4, 4, 300);
        cache.insert_query("x".into(), "1".into());
        cache.insert_block_text(42, "hello".into());
        cache.invalidate_all();
        assert_eq!(cache.get_query("x"), None);
        assert_eq!(cache.get_block_text(42), None);
    }

    /// Block text is stored and retrieved by block index.
    #[test]
    fn test_tier2_block_text() {
        let cache = QueryCache::new(4, 4, 300);
        cache.insert_block_text(0, "block zero".into());
        cache.insert_block_text(99, "block 99".into());
        assert_eq!(cache.get_block_text(0), Some("block zero".into()));
        assert_eq!(cache.get_block_text(99), Some("block 99".into()));
        assert_eq!(cache.get_block_text(1), None);
    }

    /// Stats report per-tier entry counts and summed hits.
    #[test]
    fn test_stats() {
        let cache = QueryCache::new(4, 4, 300);
        cache.insert_query("a".into(), "1".into());
        cache.insert_query("b".into(), "2".into());
        // Two hits on "a", none on "b".
        let _ = cache.get_query("a");
        let _ = cache.get_query("a");
        let stats = cache.stats();
        assert_eq!(stats.tier1_entries, 2);
        assert_eq!(stats.tier1_hits, 2);
        // Tier 2 was never touched.
        assert_eq!(stats.tier2_entries, 0);
        assert_eq!(stats.tier2_hits, 0);
    }

    /// An entry older than the TTL is reported as a miss.
    #[test]
    fn test_ttl_expiry() {
        // A TTL of zero expires an entry as soon as any time has elapsed.
        let cache = QueryCache::new(4, 4, 0);
        cache.insert_query("x".into(), "1".into());
        // Sleep long enough for the elapsed time to be measurably non-zero.
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert_eq!(cache.get_query("x"), None);
    }

    /// `make_key` lower-cases and trims the query part only.
    #[test]
    fn test_make_key_normalizes_query() {
        assert_eq!(
            QueryCache::make_key("recall", "  Hello World ", 10),
            "recall:hello world:10"
        );
        assert_eq!(
            QueryCache::make_key("recall", "hello world", 10),
            QueryCache::make_key("recall", "HELLO WORLD", 10)
        );
    }
}