1use std::collections::HashMap;
7use std::fs::{File, OpenOptions};
8use std::io::{Read, Write};
9use std::path::PathBuf;
10use std::sync::RwLock;
11use std::time::Instant;
12
13use bytes::Bytes;
14use dashmap::DashMap;
15
16use super::config::{L2Config, StorageBackend};
17use super::result::{CachedResult, CacheKey, L2Entry};
18
19#[derive(Debug)]
26pub struct L2WarmCache {
27 config: L2Config,
29
30 memory_entries: DashMap<u64, L2Entry>,
32
33 mmap_storage: Option<RwLock<MmapStorage>>,
35
36 memory_usage: std::sync::atomic::AtomicUsize,
38}
39
40#[derive(Debug)]
42struct MmapStorage {
43 path: PathBuf,
45
46 file: Option<File>,
48
49 index: HashMap<u64, MmapEntry>,
51
52 file_size: usize,
54}
55
/// Where one serialized record lives in the backing file, and when it expires.
#[derive(Debug, Clone)]
struct MmapEntry {
    /// Byte offset of the record's first byte within the file.
    offset: usize,

    /// Length of the serialized record, in bytes.
    size: usize,

    /// Expiry deadline, in whole seconds since the Unix epoch.
    expires_at: u64,
}
68
69impl L2WarmCache {
70 pub fn new(config: L2Config) -> Self {
72 let mmap_storage = if config.storage == StorageBackend::Mmap {
73 config.mmap_path.as_ref().map(|path| {
74 RwLock::new(MmapStorage::new(path.clone()))
75 })
76 } else {
77 None
78 };
79
80 Self {
81 config,
82 memory_entries: DashMap::new(),
83 mmap_storage,
84 memory_usage: std::sync::atomic::AtomicUsize::new(0),
85 }
86 }
87
88 pub async fn get(&self, key: &CacheKey) -> Option<CachedResult> {
90 if !self.config.enabled {
91 return None;
92 }
93
94 let hash = key.hash_value();
95
96 if let Some(mut entry) = self.memory_entries.get_mut(&hash) {
98 if entry.is_expired() {
99 drop(entry);
100 self.memory_entries.remove(&hash);
101 return None;
102 }
103
104 entry.touch();
105 return Some(entry.result.clone());
106 }
107
108 if let Some(ref mmap) = self.mmap_storage {
110 if let Ok(storage) = mmap.read() {
111 if let Some(result) = storage.get(hash) {
112 self.promote_to_memory(key, result.clone());
114 return Some(result);
115 }
116 }
117 }
118
119 None
120 }
121
122 pub async fn put(&self, key: CacheKey, result: CachedResult) {
124 if !self.config.enabled {
125 return;
126 }
127
128 let entry_size = result.size() + std::mem::size_of::<L2Entry>();
129
130 let max_bytes = self.config.size_mb * 1024 * 1024;
132 let current_usage = self.memory_usage.load(std::sync::atomic::Ordering::Relaxed);
133
134 if current_usage + entry_size > max_bytes {
135 self.evict_to_fit(entry_size).await;
136 }
137
138 let hash = key.hash_value();
139 let fingerprint = format!("{:016x}", hash);
140 let entry = L2Entry::new(key, fingerprint, result);
141 let entry_memory = entry.memory_size;
142
143 self.memory_entries.insert(hash, entry);
144 self.memory_usage.fetch_add(entry_memory, std::sync::atomic::Ordering::Relaxed);
145 }
146
147 pub async fn remove(&self, key: &CacheKey) {
149 let hash = key.hash_value();
150
151 if let Some((_, entry)) = self.memory_entries.remove(&hash) {
152 self.memory_usage.fetch_sub(entry.memory_size, std::sync::atomic::Ordering::Relaxed);
153 }
154
155 if let Some(ref mmap) = self.mmap_storage {
157 if let Ok(mut storage) = mmap.write() {
158 storage.remove(hash);
159 }
160 }
161 }
162
163 pub async fn clear(&self) {
165 self.memory_entries.clear();
166 self.memory_usage.store(0, std::sync::atomic::Ordering::Relaxed);
167
168 if let Some(ref mmap) = self.mmap_storage {
169 if let Ok(mut storage) = mmap.write() {
170 storage.clear();
171 }
172 }
173 }
174
175 pub fn len(&self) -> usize {
177 self.memory_entries.len()
178 }
179
180 pub fn is_empty(&self) -> bool {
182 self.memory_entries.is_empty()
183 }
184
185 pub fn memory_usage(&self) -> usize {
187 self.memory_usage.load(std::sync::atomic::Ordering::Relaxed)
188 }
189
190 pub fn stats(&self) -> L2CacheStats {
192 let total_access: u64 = self.memory_entries
193 .iter()
194 .map(|e| e.access_count)
195 .sum();
196
197 L2CacheStats {
198 entry_count: self.memory_entries.len(),
199 memory_usage_bytes: self.memory_usage(),
200 max_memory_bytes: self.config.size_mb * 1024 * 1024,
201 total_accesses: total_access,
202 storage_backend: self.config.storage.clone(),
203 }
204 }
205
206 async fn evict_to_fit(&self, required_bytes: usize) {
208 let max_bytes = self.config.size_mb * 1024 * 1024;
209 let target = max_bytes.saturating_sub(required_bytes);
210
211 let expired: Vec<u64> = self.memory_entries
213 .iter()
214 .filter(|e| e.is_expired())
215 .map(|e| *e.key())
216 .collect();
217
218 for hash in expired {
219 if let Some((_, entry)) = self.memory_entries.remove(&hash) {
220 self.memory_usage.fetch_sub(entry.memory_size, std::sync::atomic::Ordering::Relaxed);
221 }
222 }
223
224 while self.memory_usage.load(std::sync::atomic::Ordering::Relaxed) > target {
226 let lru_hash = self.memory_entries
228 .iter()
229 .min_by_key(|e| e.last_access)
230 .map(|e| *e.key());
231
232 if let Some(hash) = lru_hash {
233 if self.mmap_storage.is_some() {
235 if let Some(entry) = self.memory_entries.get(&hash) {
236 self.demote_to_mmap(&entry);
237 }
238 }
239
240 if let Some((_, entry)) = self.memory_entries.remove(&hash) {
241 self.memory_usage.fetch_sub(entry.memory_size, std::sync::atomic::Ordering::Relaxed);
242 }
243 } else {
244 break;
245 }
246 }
247 }
248
249 fn promote_to_memory(&self, key: &CacheKey, result: CachedResult) {
251 let hash = key.hash_value();
252 let fingerprint = format!("{:016x}", hash);
253 let entry = L2Entry::new(key.clone(), fingerprint, result);
254 let entry_memory = entry.memory_size;
255
256 self.memory_entries.insert(hash, entry);
257 self.memory_usage.fetch_add(entry_memory, std::sync::atomic::Ordering::Relaxed);
258 }
259
260 fn demote_to_mmap(&self, entry: &dashmap::mapref::one::Ref<u64, L2Entry>) {
262 if let Some(ref mmap) = self.mmap_storage {
263 if let Ok(mut storage) = mmap.write() {
264 storage.put(*entry.key(), &entry.result);
265 }
266 }
267 }
268
269 pub fn flush_to_disk(&self) -> Result<usize, std::io::Error> {
271 let Some(ref mmap) = self.mmap_storage else {
272 return Ok(0);
273 };
274
275 let mut storage = mmap.write()
276 .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Lock poisoned"))?;
277
278 let mut count = 0;
279 for entry in self.memory_entries.iter() {
280 if !entry.is_expired() {
281 storage.put(*entry.key(), &entry.result);
282 count += 1;
283 }
284 }
285
286 storage.sync()?;
287 Ok(count)
288 }
289
290 pub fn load_from_disk(&self) -> Result<usize, std::io::Error> {
292 let Some(ref mmap) = self.mmap_storage else {
293 return Ok(0);
294 };
295
296 let storage = mmap.read()
297 .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Lock poisoned"))?;
298
299 Ok(storage.entry_count())
300 }
301}
302
303impl MmapStorage {
304 fn new(path: PathBuf) -> Self {
305 Self {
306 path,
307 file: None,
308 index: HashMap::new(),
309 file_size: 0,
310 }
311 }
312
313 fn get(&self, hash: u64) -> Option<CachedResult> {
314 let entry = self.index.get(&hash)?;
315
316 let now = std::time::SystemTime::now()
318 .duration_since(std::time::UNIX_EPOCH)
319 .ok()?
320 .as_secs();
321
322 if now > entry.expires_at {
323 return None;
324 }
325
326 let mut file = File::open(&self.path).ok()?;
328 let mut buffer = vec![0u8; entry.size];
329
330 use std::io::Seek;
331 file.seek(std::io::SeekFrom::Start(entry.offset as u64)).ok()?;
332 file.read_exact(&mut buffer).ok()?;
333
334 deserialize_result(&buffer)
336 }
337
338 fn put(&mut self, hash: u64, result: &CachedResult) {
339 let data = serialize_result(result);
340
341 let file = match &mut self.file {
343 Some(f) => f,
344 None => {
345 self.file = OpenOptions::new()
346 .create(true)
347 .read(true)
348 .write(true)
349 .open(&self.path)
350 .ok();
351 match &mut self.file {
352 Some(f) => f,
353 None => return,
354 }
355 }
356 };
357
358 use std::io::Seek;
360 if file.seek(std::io::SeekFrom::End(0)).is_err() {
361 return;
362 }
363
364 let offset = self.file_size;
365 if file.write_all(&data).is_ok() {
366 let expires_at = std::time::SystemTime::now()
367 .duration_since(std::time::UNIX_EPOCH)
368 .map(|d| d.as_secs() + result.ttl.as_secs())
369 .unwrap_or(0);
370
371 self.index.insert(hash, MmapEntry {
372 offset,
373 size: data.len(),
374 expires_at,
375 });
376 self.file_size += data.len();
377 }
378 }
379
380 fn remove(&mut self, hash: u64) {
381 self.index.remove(&hash);
382 }
384
385 fn clear(&mut self) {
386 self.index.clear();
387 self.file_size = 0;
388
389 if let Some(ref mut file) = self.file {
391 let _ = file.set_len(0);
392 }
393 }
394
395 fn sync(&mut self) -> Result<(), std::io::Error> {
396 if let Some(ref file) = self.file {
397 file.sync_all()?;
398 }
399 Ok(())
400 }
401
402 fn entry_count(&self) -> usize {
403 self.index.len()
404 }
405}
406
407fn serialize_result(result: &CachedResult) -> Vec<u8> {
409 let mut buffer = Vec::new();
410
411 buffer.extend_from_slice(&result.ttl.as_secs().to_le_bytes());
413
414 buffer.extend_from_slice(&(result.row_count as u64).to_le_bytes());
416
417 buffer.extend_from_slice(&(result.data.len() as u64).to_le_bytes());
419 buffer.extend_from_slice(&result.data);
420
421 buffer
422}
423
424fn deserialize_result(buffer: &[u8]) -> Option<CachedResult> {
426 if buffer.len() < 24 {
427 return None;
428 }
429
430 let ttl_secs = u64::from_le_bytes(buffer[0..8].try_into().ok()?);
431 let row_count = u64::from_le_bytes(buffer[8..16].try_into().ok()?) as usize;
432 let data_len = u64::from_le_bytes(buffer[16..24].try_into().ok()?) as usize;
433
434 if buffer.len() < 24 + data_len {
435 return None;
436 }
437
438 let data = Bytes::copy_from_slice(&buffer[24..24 + data_len]);
439
440 Some(CachedResult {
441 data,
442 row_count,
443 cached_at: Instant::now(),
444 ttl: std::time::Duration::from_secs(ttl_secs),
445 tables: Vec::new(), execution_time: std::time::Duration::from_millis(0),
447 })
448}
449
450#[derive(Debug, Clone)]
452pub struct L2CacheStats {
453 pub entry_count: usize,
455
456 pub memory_usage_bytes: usize,
458
459 pub max_memory_bytes: usize,
461
462 pub total_accesses: u64,
464
465 pub storage_backend: StorageBackend,
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a one-row result wrapping `payload`, with a 60 s TTL.
    fn create_result(payload: &str) -> CachedResult {
        CachedResult::new(
            Bytes::from(payload.to_string()),
            1,
            Duration::from_secs(60),
            vec!["test".to_string()],
            Duration::from_millis(5),
        )
    }

    /// Builds a key from `query_hash`, with `"test"` and `None`s for the other parts.
    fn create_key(query_hash: u64) -> CacheKey {
        CacheKey::from_parts(query_hash, "test".to_string(), None, None)
    }

    /// A cache built from `L2Config::default()`.
    fn default_cache() -> L2WarmCache {
        L2WarmCache::new(L2Config::default())
    }

    #[tokio::test]
    async fn test_basic_get_put() {
        let cache = default_cache();
        let key = create_key(12345);
        let expected = create_result("test data");

        assert!(cache.get(&key).await.is_none(), "fresh cache must miss");

        cache.put(key.clone(), expected.clone()).await;
        let hit = cache.get(&key).await.expect("value was just stored");
        assert_eq!(hit.data, expected.data);
    }

    #[tokio::test]
    async fn test_different_keys() {
        let cache = default_cache();
        let (stored, other) = (create_key(11111), create_key(22222));

        cache.put(stored.clone(), create_result("data")).await;

        assert!(cache.get(&stored).await.is_some());
        assert!(cache.get(&other).await.is_none());
    }

    #[tokio::test]
    async fn test_overwrite_keeps_latest_value() {
        let cache = default_cache();
        let key = create_key(4242);

        cache.put(key.clone(), create_result("first")).await;
        cache.put(key.clone(), create_result("second")).await;

        assert_eq!(cache.len(), 1);
        let hit = cache.get(&key).await.expect("overwritten key still present");
        assert_eq!(hit.data, Bytes::from("second"));
    }

    #[tokio::test]
    async fn test_expiration() {
        let ttl = Duration::from_millis(10);
        let cache = L2WarmCache::new(L2Config {
            ttl,
            ..Default::default()
        });

        let key = create_key(12345);
        let mut short_lived = create_result("data");
        short_lived.ttl = ttl;

        cache.put(key.clone(), short_lived).await;
        assert!(cache.get(&key).await.is_some());

        // A blocking sleep is fine here: this test's runtime runs nothing else.
        std::thread::sleep(Duration::from_millis(15));
        assert!(cache.get(&key).await.is_none());
    }

    #[tokio::test]
    async fn test_remove() {
        let cache = default_cache();
        let key = create_key(12345);

        cache.put(key.clone(), create_result("data")).await;
        assert!(cache.get(&key).await.is_some());

        cache.remove(&key).await;
        assert!(cache.get(&key).await.is_none());
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = default_cache();
        for (hash, payload) in [(111, "1"), (222, "2")] {
            cache.put(create_key(hash), create_result(payload)).await;
        }
        assert_eq!(cache.len(), 2);

        cache.clear().await;
        assert!(cache.is_empty());
    }

    #[tokio::test]
    async fn test_memory_eviction() {
        const MIB: usize = 1024 * 1024;
        const PAYLOAD_BYTES: usize = 100 * 1024;

        let cache = L2WarmCache::new(L2Config {
            size_mb: 1,
            ..Default::default()
        });

        // ~1.5 MiB of payload into a 1 MiB budget forces eviction.
        let payload = "x".repeat(PAYLOAD_BYTES);
        for i in 0..15 {
            cache.put(create_key(i), create_result(&payload)).await;
        }

        // One payload of slack covers per-entry bookkeeping overhead.
        assert!(cache.memory_usage() <= MIB + PAYLOAD_BYTES);
    }

    #[tokio::test]
    async fn test_stats() {
        let cache = default_cache();
        cache.put(create_key(111), create_result("1")).await;
        cache.put(create_key(222), create_result("2")).await;

        for _ in 0..2 {
            cache.get(&create_key(111)).await;
        }

        let stats = cache.stats();
        assert_eq!(stats.entry_count, 2);
        assert!(stats.memory_usage_bytes > 0);
        assert_eq!(stats.storage_backend, StorageBackend::Memory);
    }

    #[tokio::test]
    async fn test_disabled_cache() {
        let cache = L2WarmCache::new(L2Config {
            enabled: false,
            ..Default::default()
        });

        let key = create_key(12345);
        cache.put(key.clone(), create_result("data")).await;

        assert!(cache.get(&key).await.is_none());
    }

    #[test]
    fn test_serialize_deserialize() {
        let original = create_result("test data for serialization");
        let round_tripped = deserialize_result(&serialize_result(&original))
            .expect("freshly serialized record must decode");

        assert_eq!(round_tripped.data, original.data);
        assert_eq!(round_tripped.row_count, original.row_count);
        assert_eq!(round_tripped.ttl.as_secs(), original.ttl.as_secs());
    }
}