1use std::collections::HashMap;
7use std::fs::{File, OpenOptions};
8use std::io::{Read, Write};
9use std::path::PathBuf;
10use std::sync::RwLock;
11use std::time::Instant;
12
13use bytes::Bytes;
14use dashmap::DashMap;
15
16use super::config::{L2Config, StorageBackend};
17use super::result::{CacheKey, CachedResult, L2Entry};
18
/// Second-level ("warm") cache: a concurrent in-memory map with an optional
/// file-backed overflow tier.
#[derive(Debug)]
pub struct L2WarmCache {
    /// Cache settings. The code in this file reads `enabled`, `size_mb`,
    /// `storage` and `mmap_path`.
    config: L2Config,

    /// In-memory entries, keyed by the value of `CacheKey::hash_value()`.
    memory_entries: DashMap<u64, L2Entry>,

    /// File-backed tier. `new` only populates this when the configured storage
    /// backend is `StorageBackend::Mmap` and a path is configured.
    mmap_storage: Option<RwLock<MmapStorage>>,

    /// Running total of `L2Entry::memory_size` for the entries tracked in
    /// `memory_entries`; compared against `size_mb * 1024 * 1024` on insert.
    memory_usage: std::sync::atomic::AtomicUsize,
}
39
/// Append-only file store used as the overflow tier of [`L2WarmCache`].
///
/// Despite the name, the code in this file uses ordinary file reads and writes;
/// no memory mapping is performed here.
#[derive(Debug)]
struct MmapStorage {
    /// Location of the backing file.
    path: PathBuf,

    /// Write handle, opened lazily by the first `put`; `None` until then.
    file: Option<File>,

    /// Maps a key hash to the location and expiry of its serialized record.
    /// Held in memory only.
    index: HashMap<u64, MmapEntry>,

    /// Number of bytes appended so far; used as the offset of the next record.
    file_size: usize,
}
55
/// Index record describing one serialized result inside the backing file.
#[derive(Debug, Clone)]
struct MmapEntry {
    /// Byte offset of the record from the start of the file.
    offset: usize,

    /// Length of the serialized record in bytes.
    size: usize,

    /// Expiry time in whole seconds since the Unix epoch.
    expires_at: u64,
}
68
69impl L2WarmCache {
70 pub fn new(config: L2Config) -> Self {
72 let mmap_storage = if config.storage == StorageBackend::Mmap {
73 config
74 .mmap_path
75 .as_ref()
76 .map(|path| RwLock::new(MmapStorage::new(path.clone())))
77 } else {
78 None
79 };
80
81 Self {
82 config,
83 memory_entries: DashMap::new(),
84 mmap_storage,
85 memory_usage: std::sync::atomic::AtomicUsize::new(0),
86 }
87 }
88
89 pub async fn get(&self, key: &CacheKey) -> Option<CachedResult> {
91 if !self.config.enabled {
92 return None;
93 }
94
95 let hash = key.hash_value();
96
97 if let Some(mut entry) = self.memory_entries.get_mut(&hash) {
99 if entry.is_expired() {
100 drop(entry);
101 self.memory_entries.remove(&hash);
102 return None;
103 }
104
105 entry.touch();
106 return Some(entry.result.clone());
107 }
108
109 if let Some(ref mmap) = self.mmap_storage {
111 if let Ok(storage) = mmap.read() {
112 if let Some(result) = storage.get(hash) {
113 self.promote_to_memory(key, result.clone());
115 return Some(result);
116 }
117 }
118 }
119
120 None
121 }
122
123 pub async fn put(&self, key: CacheKey, result: CachedResult) {
125 if !self.config.enabled {
126 return;
127 }
128
129 let entry_size = result.size() + std::mem::size_of::<L2Entry>();
130
131 let max_bytes = self.config.size_mb * 1024 * 1024;
133 let current_usage = self.memory_usage.load(std::sync::atomic::Ordering::Relaxed);
134
135 if current_usage + entry_size > max_bytes {
136 self.evict_to_fit(entry_size).await;
137 }
138
139 let hash = key.hash_value();
140 let fingerprint = format!("{:016x}", hash);
141 let entry = L2Entry::new(key, fingerprint, result);
142 let entry_memory = entry.memory_size;
143
144 self.memory_entries.insert(hash, entry);
145 self.memory_usage
146 .fetch_add(entry_memory, std::sync::atomic::Ordering::Relaxed);
147 }
148
149 pub async fn remove(&self, key: &CacheKey) {
151 let hash = key.hash_value();
152
153 if let Some((_, entry)) = self.memory_entries.remove(&hash) {
154 self.memory_usage
155 .fetch_sub(entry.memory_size, std::sync::atomic::Ordering::Relaxed);
156 }
157
158 if let Some(ref mmap) = self.mmap_storage {
160 if let Ok(mut storage) = mmap.write() {
161 storage.remove(hash);
162 }
163 }
164 }
165
166 pub async fn clear(&self) {
168 self.memory_entries.clear();
169 self.memory_usage
170 .store(0, std::sync::atomic::Ordering::Relaxed);
171
172 if let Some(ref mmap) = self.mmap_storage {
173 if let Ok(mut storage) = mmap.write() {
174 storage.clear();
175 }
176 }
177 }
178
179 pub fn len(&self) -> usize {
181 self.memory_entries.len()
182 }
183
184 pub fn is_empty(&self) -> bool {
186 self.memory_entries.is_empty()
187 }
188
189 pub fn memory_usage(&self) -> usize {
191 self.memory_usage.load(std::sync::atomic::Ordering::Relaxed)
192 }
193
194 pub fn stats(&self) -> L2CacheStats {
196 let total_access: u64 = self.memory_entries.iter().map(|e| e.access_count).sum();
197
198 L2CacheStats {
199 entry_count: self.memory_entries.len(),
200 memory_usage_bytes: self.memory_usage(),
201 max_memory_bytes: self.config.size_mb * 1024 * 1024,
202 total_accesses: total_access,
203 storage_backend: self.config.storage.clone(),
204 }
205 }
206
207 async fn evict_to_fit(&self, required_bytes: usize) {
209 let max_bytes = self.config.size_mb * 1024 * 1024;
210 let target = max_bytes.saturating_sub(required_bytes);
211
212 let expired: Vec<u64> = self
214 .memory_entries
215 .iter()
216 .filter(|e| e.is_expired())
217 .map(|e| *e.key())
218 .collect();
219
220 for hash in expired {
221 if let Some((_, entry)) = self.memory_entries.remove(&hash) {
222 self.memory_usage
223 .fetch_sub(entry.memory_size, std::sync::atomic::Ordering::Relaxed);
224 }
225 }
226
227 while self.memory_usage.load(std::sync::atomic::Ordering::Relaxed) > target {
229 let lru_hash = self
231 .memory_entries
232 .iter()
233 .min_by_key(|e| e.last_access)
234 .map(|e| *e.key());
235
236 if let Some(hash) = lru_hash {
237 if self.mmap_storage.is_some() {
239 if let Some(entry) = self.memory_entries.get(&hash) {
240 self.demote_to_mmap(&entry);
241 }
242 }
243
244 if let Some((_, entry)) = self.memory_entries.remove(&hash) {
245 self.memory_usage
246 .fetch_sub(entry.memory_size, std::sync::atomic::Ordering::Relaxed);
247 }
248 } else {
249 break;
250 }
251 }
252 }
253
254 fn promote_to_memory(&self, key: &CacheKey, result: CachedResult) {
256 let hash = key.hash_value();
257 let fingerprint = format!("{:016x}", hash);
258 let entry = L2Entry::new(key.clone(), fingerprint, result);
259 let entry_memory = entry.memory_size;
260
261 self.memory_entries.insert(hash, entry);
262 self.memory_usage
263 .fetch_add(entry_memory, std::sync::atomic::Ordering::Relaxed);
264 }
265
266 fn demote_to_mmap(&self, entry: &dashmap::mapref::one::Ref<u64, L2Entry>) {
268 if let Some(ref mmap) = self.mmap_storage {
269 if let Ok(mut storage) = mmap.write() {
270 storage.put(*entry.key(), &entry.result);
271 }
272 }
273 }
274
275 pub fn flush_to_disk(&self) -> Result<usize, std::io::Error> {
277 let Some(ref mmap) = self.mmap_storage else {
278 return Ok(0);
279 };
280
281 let mut storage = mmap
282 .write()
283 .map_err(|_| std::io::Error::other("Lock poisoned"))?;
284
285 let mut count = 0;
286 for entry in self.memory_entries.iter() {
287 if !entry.is_expired() {
288 storage.put(*entry.key(), &entry.result);
289 count += 1;
290 }
291 }
292
293 storage.sync()?;
294 Ok(count)
295 }
296
297 pub fn load_from_disk(&self) -> Result<usize, std::io::Error> {
299 let Some(ref mmap) = self.mmap_storage else {
300 return Ok(0);
301 };
302
303 let storage = mmap
304 .read()
305 .map_err(|_| std::io::Error::other("Lock poisoned"))?;
306
307 Ok(storage.entry_count())
308 }
309}
310
311impl MmapStorage {
312 fn new(path: PathBuf) -> Self {
313 Self {
314 path,
315 file: None,
316 index: HashMap::new(),
317 file_size: 0,
318 }
319 }
320
321 fn get(&self, hash: u64) -> Option<CachedResult> {
322 let entry = self.index.get(&hash)?;
323
324 let now = std::time::SystemTime::now()
326 .duration_since(std::time::UNIX_EPOCH)
327 .ok()?
328 .as_secs();
329
330 if now > entry.expires_at {
331 return None;
332 }
333
334 let mut file = File::open(&self.path).ok()?;
336 let mut buffer = vec![0u8; entry.size];
337
338 use std::io::Seek;
339 file.seek(std::io::SeekFrom::Start(entry.offset as u64))
340 .ok()?;
341 file.read_exact(&mut buffer).ok()?;
342
343 deserialize_result(&buffer)
345 }
346
347 fn put(&mut self, hash: u64, result: &CachedResult) {
348 let data = serialize_result(result);
349
350 let file = match &mut self.file {
352 Some(f) => f,
353 None => {
354 self.file = OpenOptions::new()
355 .create(true)
356 .truncate(true)
357 .read(true)
358 .write(true)
359 .open(&self.path)
360 .ok();
361 match &mut self.file {
362 Some(f) => f,
363 None => return,
364 }
365 }
366 };
367
368 use std::io::Seek;
370 if file.seek(std::io::SeekFrom::End(0)).is_err() {
371 return;
372 }
373
374 let offset = self.file_size;
375 if file.write_all(&data).is_ok() {
376 let expires_at = std::time::SystemTime::now()
377 .duration_since(std::time::UNIX_EPOCH)
378 .map(|d| d.as_secs() + result.ttl.as_secs())
379 .unwrap_or(0);
380
381 self.index.insert(
382 hash,
383 MmapEntry {
384 offset,
385 size: data.len(),
386 expires_at,
387 },
388 );
389 self.file_size += data.len();
390 }
391 }
392
393 fn remove(&mut self, hash: u64) {
394 self.index.remove(&hash);
395 }
397
398 fn clear(&mut self) {
399 self.index.clear();
400 self.file_size = 0;
401
402 if let Some(ref mut file) = self.file {
404 let _ = file.set_len(0);
405 }
406 }
407
408 fn sync(&mut self) -> Result<(), std::io::Error> {
409 if let Some(ref file) = self.file {
410 file.sync_all()?;
411 }
412 Ok(())
413 }
414
415 fn entry_count(&self) -> usize {
416 self.index.len()
417 }
418}
419
420fn serialize_result(result: &CachedResult) -> Vec<u8> {
422 let mut buffer = Vec::new();
423
424 buffer.extend_from_slice(&result.ttl.as_secs().to_le_bytes());
426
427 buffer.extend_from_slice(&(result.row_count as u64).to_le_bytes());
429
430 buffer.extend_from_slice(&(result.data.len() as u64).to_le_bytes());
432 buffer.extend_from_slice(&result.data);
433
434 buffer
435}
436
437fn deserialize_result(buffer: &[u8]) -> Option<CachedResult> {
439 if buffer.len() < 24 {
440 return None;
441 }
442
443 let ttl_secs = u64::from_le_bytes(buffer[0..8].try_into().ok()?);
444 let row_count = u64::from_le_bytes(buffer[8..16].try_into().ok()?) as usize;
445 let data_len = u64::from_le_bytes(buffer[16..24].try_into().ok()?) as usize;
446
447 if buffer.len() < 24 + data_len {
448 return None;
449 }
450
451 let data = Bytes::copy_from_slice(&buffer[24..24 + data_len]);
452
453 Some(CachedResult {
454 data,
455 row_count,
456 cached_at: Instant::now(),
457 ttl: std::time::Duration::from_secs(ttl_secs),
458 tables: Vec::new(), execution_time: std::time::Duration::from_millis(0),
460 })
461}
462
/// Point-in-time statistics for the memory tier, as produced by
/// `L2WarmCache::stats`.
#[derive(Debug, Clone)]
pub struct L2CacheStats {
    /// Number of entries held in memory.
    pub entry_count: usize,

    /// Bytes currently attributed to in-memory entries.
    pub memory_usage_bytes: usize,

    /// Configured memory budget in bytes (`size_mb` × 1024 × 1024).
    pub max_memory_bytes: usize,

    /// Sum of the `access_count` of every in-memory entry.
    pub total_accesses: u64,

    /// Storage backend taken from the cache configuration.
    pub storage_backend: StorageBackend,
}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::normalizer::NormalizedQuery;
    use crate::cache::CacheContext;
    use std::time::Duration;

    /// Builds a one-row result with a 60 s TTL whose payload is `data`.
    fn create_result(data: &str) -> CachedResult {
        let payload = Bytes::from(data.to_string());
        let tables = vec!["test".to_string()];

        CachedResult::new(
            payload,
            1,
            Duration::from_secs(60),
            tables,
            Duration::from_millis(5),
        )
    }

    /// Builds a key for the "test" database from a raw query hash.
    fn create_key(query_hash: u64) -> CacheKey {
        CacheKey::from_parts(query_hash, "test".to_string(), None, None)
    }

    /// A value is absent before `put` and returned unchanged after it.
    #[tokio::test]
    async fn test_basic_get_put() {
        let cache = L2WarmCache::new(L2Config::default());
        let key = create_key(12345);
        let stored = create_result("test data");

        assert!(cache.get(&key).await.is_none());

        cache.put(key.clone(), stored.clone()).await;

        let fetched = cache.get(&key).await;
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().data, stored.data);
    }

    /// Storing under one key does not make another key visible.
    #[tokio::test]
    async fn test_different_keys() {
        let cache = L2WarmCache::new(L2Config::default());
        let present = create_key(11111);
        let absent = create_key(22222);

        cache.put(present.clone(), create_result("data").clone()).await;

        assert!(cache.get(&present).await.is_some());
        assert!(cache.get(&absent).await.is_none());
    }

    /// An entry with a 10 ms TTL is gone after 15 ms.
    #[tokio::test]
    async fn test_expiration() {
        let short_ttl = Duration::from_millis(10);
        let cache = L2WarmCache::new(L2Config {
            ttl: short_ttl,
            ..Default::default()
        });

        let key = create_key(12345);
        let mut short_lived = create_result("data");
        short_lived.ttl = short_ttl;

        cache.put(key.clone(), short_lived).await;
        assert!(cache.get(&key).await.is_some());

        std::thread::sleep(Duration::from_millis(15));
        assert!(cache.get(&key).await.is_none());
    }

    /// `remove` makes a previously stored key unreachable.
    #[tokio::test]
    async fn test_remove() {
        let cache = L2WarmCache::new(L2Config::default());
        let key = create_key(12345);

        cache.put(key.clone(), create_result("data")).await;
        assert!(cache.get(&key).await.is_some());

        cache.remove(&key).await;
        assert!(cache.get(&key).await.is_none());
    }

    /// `clear` empties a cache that held two entries.
    #[tokio::test]
    async fn test_clear() {
        let cache = L2WarmCache::new(L2Config::default());

        cache.put(create_key(111), create_result("1")).await;
        cache.put(create_key(222), create_result("2")).await;
        assert_eq!(cache.len(), 2);

        cache.clear().await;
        assert!(cache.is_empty());
    }

    /// Inserting ~1.5 MB into a 1 MB cache keeps usage within one entry of the budget.
    #[tokio::test]
    async fn test_memory_eviction() {
        let cache = L2WarmCache::new(L2Config {
            size_mb: 1,
            ..Default::default()
        });

        let large_data = "x".repeat(100 * 1024);
        for i in 0..15 {
            cache.put(create_key(i), create_result(&large_data)).await;
        }

        let budget_plus_one_entry = 1024 * 1024 + 100 * 1024;
        assert!(cache.memory_usage() <= budget_plus_one_entry);
    }

    /// Stats report the entry count, non-zero usage and the default backend.
    #[tokio::test]
    async fn test_stats() {
        let cache = L2WarmCache::new(L2Config::default());

        cache.put(create_key(111), create_result("1")).await;
        cache.put(create_key(222), create_result("2")).await;

        // Two reads of the same key; the results themselves are not needed.
        cache.get(&create_key(111)).await;
        cache.get(&create_key(111)).await;

        let snapshot = cache.stats();
        assert_eq!(snapshot.entry_count, 2);
        assert!(snapshot.memory_usage_bytes > 0);
        assert_eq!(snapshot.storage_backend, StorageBackend::Memory);
    }

    /// A disabled cache accepts `put` but never returns anything.
    #[tokio::test]
    async fn test_disabled_cache() {
        let cache = L2WarmCache::new(L2Config {
            enabled: false,
            ..Default::default()
        });
        let key = create_key(12345);

        cache.put(key.clone(), create_result("data")).await;

        assert!(cache.get(&key).await.is_none());
    }

    /// Payload, row count and whole-second TTL survive a serialize/deserialize round trip.
    #[test]
    fn test_serialize_deserialize() {
        let original = create_result("test data for serialization");

        let bytes = serialize_result(&original);
        let restored = deserialize_result(&bytes).unwrap();

        assert_eq!(restored.data, original.data);
        assert_eq!(restored.row_count, original.row_count);
        assert_eq!(restored.ttl.as_secs(), original.ttl.as_secs());
    }
}