Skip to main content

heliosdb_proxy/cache/
l2_warm.rs

//! L2 Warm Cache
//!
//! Shared cache with normalized queries and configurable storage backend.
//! Supports both in-memory and memory-mapped file storage.
5
6use std::collections::HashMap;
7use std::fs::{File, OpenOptions};
8use std::io::{Read, Write};
9use std::path::PathBuf;
10use std::sync::RwLock;
11use std::time::Instant;
12
13use bytes::Bytes;
14use dashmap::DashMap;
15
16use super::config::{L2Config, StorageBackend};
17use super::result::{CachedResult, CacheKey, L2Entry};
18
19/// L2 warm cache (shared across connections)
20///
21/// This cache stores normalized query results shared across all connections.
22/// It supports two storage backends:
23/// - Memory: Fast, volatile storage
24/// - Mmap: Memory-mapped file that survives restarts
25#[derive(Debug)]
26pub struct L2WarmCache {
27    /// Cache configuration
28    config: L2Config,
29
30    /// Cache entries (in-memory storage)
31    memory_entries: DashMap<u64, L2Entry>,
32
33    /// Memory-mapped storage (if enabled)
34    mmap_storage: Option<RwLock<MmapStorage>>,
35
36    /// Current memory usage in bytes
37    memory_usage: std::sync::atomic::AtomicUsize,
38}
39
40/// Memory-mapped storage for persistent caching
41#[derive(Debug)]
42struct MmapStorage {
43    /// File path
44    path: PathBuf,
45
46    /// File handle
47    file: Option<File>,
48
49    /// Cached entries index (hash -> offset)
50    index: HashMap<u64, MmapEntry>,
51
52    /// Total file size
53    file_size: usize,
54}
55
/// Index record describing one payload stored in the backing file.
#[derive(Debug, Clone)]
struct MmapEntry {
    /// Byte position of the payload's first byte within the file.
    offset: usize,

    /// Length of the serialized payload in bytes.
    size: usize,

    /// Wall-clock expiry, expressed as whole seconds since the Unix epoch.
    expires_at: u64,
}
68
69impl L2WarmCache {
70    /// Create a new L2 warm cache
71    pub fn new(config: L2Config) -> Self {
72        let mmap_storage = if config.storage == StorageBackend::Mmap {
73            config.mmap_path.as_ref().map(|path| {
74                RwLock::new(MmapStorage::new(path.clone()))
75            })
76        } else {
77            None
78        };
79
80        Self {
81            config,
82            memory_entries: DashMap::new(),
83            mmap_storage,
84            memory_usage: std::sync::atomic::AtomicUsize::new(0),
85        }
86    }
87
88    /// Look up a cache key
89    pub async fn get(&self, key: &CacheKey) -> Option<CachedResult> {
90        if !self.config.enabled {
91            return None;
92        }
93
94        let hash = key.hash_value();
95
96        // Try memory storage first
97        if let Some(mut entry) = self.memory_entries.get_mut(&hash) {
98            if entry.is_expired() {
99                drop(entry);
100                self.memory_entries.remove(&hash);
101                return None;
102            }
103
104            entry.touch();
105            return Some(entry.result.clone());
106        }
107
108        // Try mmap storage
109        if let Some(ref mmap) = self.mmap_storage {
110            if let Ok(storage) = mmap.read() {
111                if let Some(result) = storage.get(hash) {
112                    // Promote to memory cache
113                    self.promote_to_memory(key, result.clone());
114                    return Some(result);
115                }
116            }
117        }
118
119        None
120    }
121
122    /// Store a result in the cache
123    pub async fn put(&self, key: CacheKey, result: CachedResult) {
124        if !self.config.enabled {
125            return;
126        }
127
128        let entry_size = result.size() + std::mem::size_of::<L2Entry>();
129
130        // Check size limit
131        let max_bytes = self.config.size_mb * 1024 * 1024;
132        let current_usage = self.memory_usage.load(std::sync::atomic::Ordering::Relaxed);
133
134        if current_usage + entry_size > max_bytes {
135            self.evict_to_fit(entry_size).await;
136        }
137
138        let hash = key.hash_value();
139        let fingerprint = format!("{:016x}", hash);
140        let entry = L2Entry::new(key, fingerprint, result);
141        let entry_memory = entry.memory_size;
142
143        self.memory_entries.insert(hash, entry);
144        self.memory_usage.fetch_add(entry_memory, std::sync::atomic::Ordering::Relaxed);
145    }
146
147    /// Remove an entry from the cache
148    pub async fn remove(&self, key: &CacheKey) {
149        let hash = key.hash_value();
150
151        if let Some((_, entry)) = self.memory_entries.remove(&hash) {
152            self.memory_usage.fetch_sub(entry.memory_size, std::sync::atomic::Ordering::Relaxed);
153        }
154
155        // Also remove from mmap if present
156        if let Some(ref mmap) = self.mmap_storage {
157            if let Ok(mut storage) = mmap.write() {
158                storage.remove(hash);
159            }
160        }
161    }
162
163    /// Clear all entries
164    pub async fn clear(&self) {
165        self.memory_entries.clear();
166        self.memory_usage.store(0, std::sync::atomic::Ordering::Relaxed);
167
168        if let Some(ref mmap) = self.mmap_storage {
169            if let Ok(mut storage) = mmap.write() {
170                storage.clear();
171            }
172        }
173    }
174
175    /// Get current entry count
176    pub fn len(&self) -> usize {
177        self.memory_entries.len()
178    }
179
180    /// Check if cache is empty
181    pub fn is_empty(&self) -> bool {
182        self.memory_entries.is_empty()
183    }
184
185    /// Get current memory usage in bytes
186    pub fn memory_usage(&self) -> usize {
187        self.memory_usage.load(std::sync::atomic::Ordering::Relaxed)
188    }
189
190    /// Get cache statistics
191    pub fn stats(&self) -> L2CacheStats {
192        let total_access: u64 = self.memory_entries
193            .iter()
194            .map(|e| e.access_count)
195            .sum();
196
197        L2CacheStats {
198            entry_count: self.memory_entries.len(),
199            memory_usage_bytes: self.memory_usage(),
200            max_memory_bytes: self.config.size_mb * 1024 * 1024,
201            total_accesses: total_access,
202            storage_backend: self.config.storage.clone(),
203        }
204    }
205
206    /// Evict entries to fit new data
207    async fn evict_to_fit(&self, required_bytes: usize) {
208        let max_bytes = self.config.size_mb * 1024 * 1024;
209        let target = max_bytes.saturating_sub(required_bytes);
210
211        // First, evict expired entries
212        let expired: Vec<u64> = self.memory_entries
213            .iter()
214            .filter(|e| e.is_expired())
215            .map(|e| *e.key())
216            .collect();
217
218        for hash in expired {
219            if let Some((_, entry)) = self.memory_entries.remove(&hash) {
220                self.memory_usage.fetch_sub(entry.memory_size, std::sync::atomic::Ordering::Relaxed);
221            }
222        }
223
224        // If still over limit, evict LRU entries
225        while self.memory_usage.load(std::sync::atomic::Ordering::Relaxed) > target {
226            // Find LRU entry
227            let lru_hash = self.memory_entries
228                .iter()
229                .min_by_key(|e| e.last_access)
230                .map(|e| *e.key());
231
232            if let Some(hash) = lru_hash {
233                // Optionally move to mmap before evicting
234                if self.mmap_storage.is_some() {
235                    if let Some(entry) = self.memory_entries.get(&hash) {
236                        self.demote_to_mmap(&entry);
237                    }
238                }
239
240                if let Some((_, entry)) = self.memory_entries.remove(&hash) {
241                    self.memory_usage.fetch_sub(entry.memory_size, std::sync::atomic::Ordering::Relaxed);
242                }
243            } else {
244                break;
245            }
246        }
247    }
248
249    /// Promote an entry from mmap to memory
250    fn promote_to_memory(&self, key: &CacheKey, result: CachedResult) {
251        let hash = key.hash_value();
252        let fingerprint = format!("{:016x}", hash);
253        let entry = L2Entry::new(key.clone(), fingerprint, result);
254        let entry_memory = entry.memory_size;
255
256        self.memory_entries.insert(hash, entry);
257        self.memory_usage.fetch_add(entry_memory, std::sync::atomic::Ordering::Relaxed);
258    }
259
260    /// Demote an entry to mmap storage
261    fn demote_to_mmap(&self, entry: &dashmap::mapref::one::Ref<u64, L2Entry>) {
262        if let Some(ref mmap) = self.mmap_storage {
263            if let Ok(mut storage) = mmap.write() {
264                storage.put(*entry.key(), &entry.result);
265            }
266        }
267    }
268
269    /// Flush memory entries to mmap (for graceful shutdown)
270    pub fn flush_to_disk(&self) -> Result<usize, std::io::Error> {
271        let Some(ref mmap) = self.mmap_storage else {
272            return Ok(0);
273        };
274
275        let mut storage = mmap.write()
276            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Lock poisoned"))?;
277
278        let mut count = 0;
279        for entry in self.memory_entries.iter() {
280            if !entry.is_expired() {
281                storage.put(*entry.key(), &entry.result);
282                count += 1;
283            }
284        }
285
286        storage.sync()?;
287        Ok(count)
288    }
289
290    /// Load entries from mmap on startup
291    pub fn load_from_disk(&self) -> Result<usize, std::io::Error> {
292        let Some(ref mmap) = self.mmap_storage else {
293            return Ok(0);
294        };
295
296        let storage = mmap.read()
297            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Lock poisoned"))?;
298
299        Ok(storage.entry_count())
300    }
301}
302
303impl MmapStorage {
304    fn new(path: PathBuf) -> Self {
305        Self {
306            path,
307            file: None,
308            index: HashMap::new(),
309            file_size: 0,
310        }
311    }
312
313    fn get(&self, hash: u64) -> Option<CachedResult> {
314        let entry = self.index.get(&hash)?;
315
316        // Check expiration
317        let now = std::time::SystemTime::now()
318            .duration_since(std::time::UNIX_EPOCH)
319            .ok()?
320            .as_secs();
321
322        if now > entry.expires_at {
323            return None;
324        }
325
326        // Read from file
327        let mut file = File::open(&self.path).ok()?;
328        let mut buffer = vec![0u8; entry.size];
329
330        use std::io::Seek;
331        file.seek(std::io::SeekFrom::Start(entry.offset as u64)).ok()?;
332        file.read_exact(&mut buffer).ok()?;
333
334        // Deserialize (simple format: ttl_secs:row_count:data)
335        deserialize_result(&buffer)
336    }
337
338    fn put(&mut self, hash: u64, result: &CachedResult) {
339        let data = serialize_result(result);
340
341        // Open or create file
342        let file = match &mut self.file {
343            Some(f) => f,
344            None => {
345                self.file = OpenOptions::new()
346                    .create(true)
347                    .read(true)
348                    .write(true)
349                    .open(&self.path)
350                    .ok();
351                match &mut self.file {
352                    Some(f) => f,
353                    None => return,
354                }
355            }
356        };
357
358        // Append to file
359        use std::io::Seek;
360        if file.seek(std::io::SeekFrom::End(0)).is_err() {
361            return;
362        }
363
364        let offset = self.file_size;
365        if file.write_all(&data).is_ok() {
366            let expires_at = std::time::SystemTime::now()
367                .duration_since(std::time::UNIX_EPOCH)
368                .map(|d| d.as_secs() + result.ttl.as_secs())
369                .unwrap_or(0);
370
371            self.index.insert(hash, MmapEntry {
372                offset,
373                size: data.len(),
374                expires_at,
375            });
376            self.file_size += data.len();
377        }
378    }
379
380    fn remove(&mut self, hash: u64) {
381        self.index.remove(&hash);
382        // Note: This doesn't reclaim space, just marks as removed
383    }
384
385    fn clear(&mut self) {
386        self.index.clear();
387        self.file_size = 0;
388
389        // Truncate file
390        if let Some(ref mut file) = self.file {
391            let _ = file.set_len(0);
392        }
393    }
394
395    fn sync(&mut self) -> Result<(), std::io::Error> {
396        if let Some(ref file) = self.file {
397            file.sync_all()?;
398        }
399        Ok(())
400    }
401
402    fn entry_count(&self) -> usize {
403        self.index.len()
404    }
405}
406
407/// Serialize a cached result for mmap storage
408fn serialize_result(result: &CachedResult) -> Vec<u8> {
409    let mut buffer = Vec::new();
410
411    // Write TTL (8 bytes)
412    buffer.extend_from_slice(&result.ttl.as_secs().to_le_bytes());
413
414    // Write row count (8 bytes)
415    buffer.extend_from_slice(&(result.row_count as u64).to_le_bytes());
416
417    // Write data length (8 bytes) + data
418    buffer.extend_from_slice(&(result.data.len() as u64).to_le_bytes());
419    buffer.extend_from_slice(&result.data);
420
421    buffer
422}
423
424/// Deserialize a cached result from mmap storage
425fn deserialize_result(buffer: &[u8]) -> Option<CachedResult> {
426    if buffer.len() < 24 {
427        return None;
428    }
429
430    let ttl_secs = u64::from_le_bytes(buffer[0..8].try_into().ok()?);
431    let row_count = u64::from_le_bytes(buffer[8..16].try_into().ok()?) as usize;
432    let data_len = u64::from_le_bytes(buffer[16..24].try_into().ok()?) as usize;
433
434    if buffer.len() < 24 + data_len {
435        return None;
436    }
437
438    let data = Bytes::copy_from_slice(&buffer[24..24 + data_len]);
439
440    Some(CachedResult {
441        data,
442        row_count,
443        cached_at: Instant::now(),
444        ttl: std::time::Duration::from_secs(ttl_secs),
445        tables: Vec::new(), // Tables are not persisted
446        execution_time: std::time::Duration::from_millis(0),
447    })
448}
449
450/// L2 cache statistics
451#[derive(Debug, Clone)]
452pub struct L2CacheStats {
453    /// Number of entries in cache
454    pub entry_count: usize,
455
456    /// Current memory usage in bytes
457    pub memory_usage_bytes: usize,
458
459    /// Maximum memory in bytes
460    pub max_memory_bytes: usize,
461
462    /// Total number of accesses
463    pub total_accesses: u64,
464
465    /// Storage backend type
466    pub storage_backend: StorageBackend,
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Build a one-row result with a 60 second TTL carrying `data` as payload.
    fn create_result(data: &str) -> CachedResult {
        CachedResult::new(
            Bytes::from(data.to_string()),
            1,
            Duration::from_secs(60),
            vec!["test".to_string()],
            Duration::from_millis(5),
        )
    }

    /// Build a cache key from a query hash with fixed remaining parts.
    fn create_key(query_hash: u64) -> CacheKey {
        CacheKey::from_parts(query_hash, "test".to_string(), None, None)
    }

    #[tokio::test]
    async fn test_basic_get_put() {
        let config = L2Config::default();
        let cache = L2WarmCache::new(config);

        let key = create_key(12345);
        let result = create_result("test data");

        // Initially empty
        assert!(cache.get(&key).await.is_none());

        // Put and get
        cache.put(key.clone(), result.clone()).await;
        let cached = cache.get(&key).await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().data, result.data);
    }

    #[tokio::test]
    async fn test_different_keys() {
        let config = L2Config::default();
        let cache = L2WarmCache::new(config);

        let key1 = create_key(11111);
        let key2 = create_key(22222);
        let result = create_result("data");

        cache.put(key1.clone(), result.clone()).await;

        assert!(cache.get(&key1).await.is_some());
        assert!(cache.get(&key2).await.is_none());
    }

    #[tokio::test]
    async fn test_expiration() {
        let config = L2Config {
            ttl: Duration::from_millis(10),
            ..Default::default()
        };
        let cache = L2WarmCache::new(config);

        let key = create_key(12345);
        let mut result = create_result("data");
        result.ttl = Duration::from_millis(10);

        cache.put(key.clone(), result).await;
        assert!(cache.get(&key).await.is_some());

        // A short blocking sleep; long enough for the 10 ms TTL to lapse.
        std::thread::sleep(Duration::from_millis(15));
        assert!(cache.get(&key).await.is_none());
    }

    #[tokio::test]
    async fn test_remove() {
        let config = L2Config::default();
        let cache = L2WarmCache::new(config);

        let key = create_key(12345);
        let result = create_result("data");

        cache.put(key.clone(), result).await;
        assert!(cache.get(&key).await.is_some());

        cache.remove(&key).await;
        assert!(cache.get(&key).await.is_none());
    }

    #[tokio::test]
    async fn test_clear() {
        let config = L2Config::default();
        let cache = L2WarmCache::new(config);

        cache.put(create_key(111), create_result("1")).await;
        cache.put(create_key(222), create_result("2")).await;

        assert_eq!(cache.len(), 2);

        cache.clear().await;

        assert!(cache.is_empty());
        // Clearing also resets the byte counter.
        assert_eq!(cache.memory_usage(), 0);
    }

    #[tokio::test]
    async fn test_memory_eviction() {
        let config = L2Config {
            size_mb: 1, // 1 MB limit
            ..Default::default()
        };
        let cache = L2WarmCache::new(config);

        // Add entries until eviction kicks in
        let large_data = "x".repeat(100 * 1024); // 100 KB per entry
        for i in 0..15 {
            cache.put(create_key(i), create_result(&large_data)).await;
        }

        // Should have evicted some entries
        assert!(cache.memory_usage() <= 1024 * 1024 + 100 * 1024);
    }

    #[tokio::test]
    async fn test_stats() {
        let config = L2Config::default();
        let cache = L2WarmCache::new(config);

        cache.put(create_key(111), create_result("1")).await;
        cache.put(create_key(222), create_result("2")).await;

        cache.get(&create_key(111)).await;
        cache.get(&create_key(111)).await;

        let stats = cache.stats();
        assert_eq!(stats.entry_count, 2);
        assert!(stats.memory_usage_bytes > 0);
        assert_eq!(stats.storage_backend, StorageBackend::Memory);
    }

    #[tokio::test]
    async fn test_disabled_cache() {
        let config = L2Config {
            enabled: false,
            ..Default::default()
        };
        let cache = L2WarmCache::new(config);

        let key = create_key(12345);
        cache.put(key.clone(), create_result("data")).await;

        // A disabled cache neither stores nor returns anything.
        assert!(cache.is_empty());
        assert!(cache.get(&key).await.is_none());
    }

    #[test]
    fn test_serialize_deserialize() {
        let result = create_result("test data for serialization");
        let serialized = serialize_result(&result);
        let deserialized = deserialize_result(&serialized).unwrap();

        assert_eq!(deserialized.data, result.data);
        assert_eq!(deserialized.row_count, result.row_count);
        assert_eq!(deserialized.ttl.as_secs(), result.ttl.as_secs());

        // Truncated input is rejected: first a cut inside the header, then a
        // payload one byte shorter than the header announces.
        assert!(deserialize_result(&serialized[..10]).is_none());
        assert!(deserialize_result(&serialized[..serialized.len() - 1]).is_none());
    }
}