1use std::collections::HashMap;
12use std::time::Instant;
13
14use parking_lot::RwLock;
15
16use super::config::L1Config;
17use super::result::{CachedResult, L1Entry};
18
19#[derive(Debug)]
24pub struct L1HotCache {
25 config: L1Config,
27
28 entries: RwLock<HashMap<String, L1Entry>>,
30
31 lru_order: RwLock<Vec<(String, Instant)>>,
33}
34
35impl L1HotCache {
36 pub fn new(config: L1Config) -> Self {
38 let size = config.size;
39 Self {
40 config,
41 entries: RwLock::new(HashMap::with_capacity(size)),
42 lru_order: RwLock::new(Vec::with_capacity(size)),
43 }
44 }
45
46 pub fn get(&self, query: &str) -> Option<CachedResult> {
53 if !self.config.enabled {
54 return None;
55 }
56
57 let (result, expired) = {
59 let entries = self.entries.read();
60 match entries.get(query) {
61 None => return None,
62 Some(entry) if entry.is_expired() => (None, true),
63 Some(entry) => {
64 entry.touch();
65 (Some(entry.result.clone()), false)
66 }
67 }
68 };
69
70 if expired {
71 let mut entries = self.entries.write();
73 entries.remove(query);
74 drop(entries);
75 self.remove_from_lru(query);
76 return None;
77 }
78
79 self.update_lru(query);
82 result
83 }
84
85 pub fn put(&self, query: String, result: CachedResult) {
87 if !self.config.enabled {
88 return;
89 }
90
91 let mut entries = self.entries.write();
92
93 if entries.len() >= self.config.size && !entries.contains_key(&query) {
95 self.evict_lru(&mut entries);
96 }
97
98 let mut adjusted_result = result;
100 if adjusted_result.ttl > self.config.ttl {
101 adjusted_result.ttl = self.config.ttl;
102 }
103
104 let entry = L1Entry::new(query.clone(), adjusted_result);
106 entries.insert(query.clone(), entry);
107 drop(entries);
108 self.update_lru(&query);
109 }
110
111 pub fn remove(&self, query: &str) {
113 self.entries.write().remove(query);
114 self.remove_from_lru(query);
115 }
116
117 pub fn clear(&self) {
119 self.entries.write().clear();
120 self.lru_order.write().clear();
121 }
122
123 pub fn len(&self) -> usize {
125 self.entries.read().len()
126 }
127
128 pub fn is_empty(&self) -> bool {
130 self.len() == 0
131 }
132
133 pub fn capacity(&self) -> usize {
135 self.config.size
136 }
137
138 pub fn stats(&self) -> L1CacheStats {
140 let entries = self.entries.read();
141 let total_size: usize = entries.values().map(|e| e.result.size()).sum();
142 let total_access: u64 = entries.values().map(|e| e.access_count()).sum();
143
144 L1CacheStats {
145 entry_count: entries.len(),
146 capacity: self.config.size,
147 total_size_bytes: total_size,
148 total_accesses: total_access,
149 }
150 }
151
152 pub fn evict_expired(&self) {
154 let mut entries = self.entries.write();
155 let expired: Vec<String> = entries
156 .iter()
157 .filter(|(_, entry)| entry.is_expired())
158 .map(|(key, _)| key.clone())
159 .collect();
160
161 for key in &expired {
162 entries.remove(key);
163 }
164 drop(entries);
165
166 for key in &expired {
167 self.remove_from_lru(key);
168 }
169 }
170
171 fn update_lru(&self, query: &str) {
173 let mut lru = self.lru_order.write();
174 lru.retain(|(q, _)| q != query);
175 lru.push((query.to_string(), Instant::now()));
176 }
177
178 fn remove_from_lru(&self, query: &str) {
180 self.lru_order.write().retain(|(q, _)| q != query);
181 }
182
183 fn evict_lru(&self, entries: &mut HashMap<String, L1Entry>) {
185 let mut lru = self.lru_order.write();
186
187 let expired: Vec<String> = lru
189 .iter()
190 .filter(|(q, _)| entries.get(q).map(|e| e.is_expired()).unwrap_or(true))
191 .map(|(q, _)| q.clone())
192 .collect();
193
194 for key in expired {
195 entries.remove(&key);
196 lru.retain(|(q, _)| q != &key);
197 }
198
199 if entries.len() >= self.config.size {
201 if let Some((key, _)) = lru.first().cloned() {
202 entries.remove(&key);
203 lru.remove(0);
204 }
205 }
206 }
207}
208
/// Point-in-time snapshot of an L1 hot cache's contents.
///
/// `Default` yields an all-zero snapshot. `PartialEq`/`Eq` allow snapshots
/// to be compared directly, e.g. in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct L1CacheStats {
    /// Number of stored entries at snapshot time.
    pub entry_count: usize,

    /// Configured maximum number of entries.
    pub capacity: usize,

    /// Sum of the stored results' reported sizes.
    pub total_size_bytes: usize,

    /// Sum of the per-entry access counters.
    pub total_accesses: u64,
}
224
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use std::time::Duration;

    /// Builds a small `CachedResult` whose `data` is `data`.
    fn create_result(data: &str) -> CachedResult {
        CachedResult::new(
            Bytes::from(data.to_string()),
            1,
            Duration::from_secs(60),
            vec!["test".to_string()],
            Duration::from_millis(5),
        )
    }

    /// An enabled cache configuration with the given capacity and TTL cap.
    fn enabled_config(size: usize, ttl: Duration) -> L1Config {
        L1Config {
            enabled: true,
            size,
            ttl,
        }
    }

    #[test]
    fn test_basic_get_put() {
        let cache = L1HotCache::new(enabled_config(100, Duration::from_secs(60)));

        let query = "SELECT * FROM users WHERE id = 1";
        let result = create_result("user data");

        assert!(cache.get(query).is_none());

        cache.put(query.to_string(), result.clone());
        let cached = cache.get(query);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().data, result.data);
    }

    #[test]
    fn test_exact_match() {
        let cache = L1HotCache::new(enabled_config(100, Duration::from_secs(60)));

        let query1 = "SELECT * FROM users WHERE id = 1";
        let query2 = "SELECT * FROM users WHERE id = 2";

        cache.put(query1.to_string(), create_result("user data"));

        assert!(cache.get(query1).is_some());
        assert!(cache.get(query2).is_none());
    }

    #[test]
    fn test_expiration() {
        let cache = L1HotCache::new(enabled_config(100, Duration::from_millis(10)));

        let query = "SELECT 1";
        cache.put(query.to_string(), create_result("1"));
        assert!(cache.get(query).is_some());

        std::thread::sleep(Duration::from_millis(15));
        assert!(cache.get(query).is_none());
    }

    #[test]
    fn test_lru_eviction() {
        let cache = L1HotCache::new(enabled_config(3, Duration::from_secs(60)));

        cache.put("query1".to_string(), create_result("1"));
        cache.put("query2".to_string(), create_result("2"));
        cache.put("query3".to_string(), create_result("3"));

        // Hit query1 so that query2 becomes the least recently used entry.
        assert!(cache.get("query1").is_some());

        cache.put("query4".to_string(), create_result("4"));

        assert!(cache.get("query1").is_some());
        assert!(cache.get("query2").is_none());
        assert!(cache.get("query3").is_some());
        assert!(cache.get("query4").is_some());
    }

    #[test]
    fn test_update_existing_at_capacity_does_not_evict() {
        let cache = L1HotCache::new(enabled_config(2, Duration::from_secs(60)));

        cache.put("query1".to_string(), create_result("1"));
        cache.put("query2".to_string(), create_result("2"));
        // Overwriting an existing key in a full cache must not evict anything.
        cache.put("query1".to_string(), create_result("1b"));

        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get("query1").unwrap().data, Bytes::from("1b"));
        assert!(cache.get("query2").is_some());
    }

    #[test]
    fn test_lru_order_tracks_entries() {
        let cache = L1HotCache::new(enabled_config(2, Duration::from_secs(60)));

        cache.put("a".to_string(), create_result("a"));
        cache.put("b".to_string(), create_result("b"));
        assert!(cache.get("a").is_some());
        // Evicts "b", the least recently used key.
        cache.put("c".to_string(), create_result("c"));
        cache.remove("a");
        cache.remove("missing");

        assert_eq!(cache.len(), 1);
        // Every stored entry must have exactly one recency record.
        assert_eq!(cache.lru_order.read().len(), cache.len());
        assert!(cache.get("c").is_some());
    }

    #[test]
    fn test_clear() {
        let cache = L1HotCache::new(enabled_config(100, Duration::from_secs(60)));

        cache.put("query1".to_string(), create_result("1"));
        cache.put("query2".to_string(), create_result("2"));
        assert_eq!(cache.len(), 2);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_remove() {
        let cache = L1HotCache::new(enabled_config(100, Duration::from_secs(60)));

        cache.put("query1".to_string(), create_result("1"));
        cache.put("query2".to_string(), create_result("2"));

        cache.remove("query1");

        assert!(cache.get("query1").is_none());
        assert!(cache.get("query2").is_some());
    }

    #[test]
    fn test_disabled_cache() {
        let cache = L1HotCache::new(L1Config {
            enabled: false,
            size: 100,
            ttl: Duration::from_secs(60),
        });

        cache.put("query".to_string(), create_result("data"));
        assert!(cache.get("query").is_none());
    }

    #[test]
    fn test_stats() {
        let cache = L1HotCache::new(enabled_config(100, Duration::from_secs(60)));

        cache.put("query1".to_string(), create_result("1"));
        cache.put("query2".to_string(), create_result("2"));

        cache.get("query1");
        cache.get("query1");
        cache.get("query2");

        let stats = cache.stats();
        assert_eq!(stats.entry_count, 2);
        assert_eq!(stats.capacity, 100);
        assert!(stats.total_size_bytes > 0);
        // 3 hits plus one counted access per inserted entry.
        assert_eq!(stats.total_accesses, 5);
    }

    #[test]
    fn test_evict_expired() {
        let cache = L1HotCache::new(enabled_config(100, Duration::from_millis(10)));

        cache.put("query1".to_string(), create_result("1"));
        cache.put("query2".to_string(), create_result("2"));

        std::thread::sleep(Duration::from_millis(15));

        cache.evict_expired();

        assert!(cache.is_empty());
    }

    #[test]
    fn test_update_existing() {
        let cache = L1HotCache::new(enabled_config(100, Duration::from_secs(60)));

        cache.put("query".to_string(), create_result("old"));
        cache.put("query".to_string(), create_result("new"));

        let cached = cache.get("query").unwrap();
        assert_eq!(cached.data, Bytes::from("new"));
    }

    /// Many threads hitting the same key concurrently must all see the
    /// value, and every hit must be counted.
    #[test]
    fn test_concurrent_hits_read_lock_only() {
        use std::sync::Arc;
        use std::thread;

        let cache = Arc::new(L1HotCache::new(enabled_config(
            100,
            Duration::from_secs(60),
        )));
        cache.put("hot-query".to_string(), create_result("hot data"));

        const THREADS: usize = 16;
        const ITERS_PER_THREAD: usize = 500;

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    for _ in 0..ITERS_PER_THREAD {
                        let r = cache.get("hot-query").expect("hit expected");
                        assert_eq!(r.data, Bytes::from("hot data"));
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }

        let stats = cache.stats();
        assert_eq!(
            stats.total_accesses,
            1 + (THREADS * ITERS_PER_THREAD) as u64
        );
    }
}