Skip to main content

heliosdb_proxy/cache/
mod.rs

1//! Query Caching Module
2//!
3//! Provides multi-tier query caching for HeliosProxy:
4//!
5//! - **L1 Hot Cache**: Per-connection, exact match, LRU eviction
6//! - **L2 Warm Cache**: Shared, normalized queries, configurable storage
7//! - **L3 Semantic Cache**: Vector similarity for AI workloads
8//!
9//! # Architecture
10//!
11//! ```text
12//!                     ┌─────────────────────────────────────────────────┐
13//!                     │                QUERY CACHE LAYER                 │
14//!                     │                                                  │
15//!   Query ───────────►│ ┌──────────────────────────────────────────────┐│
16//!                     ││ L1: Hot Cache (in-memory, <1ms)               ││
17//!                     │└──────────────────────────────────────────────┘│
18//!                     │         │ miss                                  │
19//!                     │         ▼                                       │
20//!                     │ ┌──────────────────────────────────────────────┐│
21//!                     ││ L2: Warm Cache (shared memory, <5ms)          ││
22//!                     │└──────────────────────────────────────────────┘│
23//!                     │         │ miss                                  │
24//!                     │         ▼                                       │
25//!                     │ ┌──────────────────────────────────────────────┐│
26//!                     ││ L3: Semantic Cache (vector similarity, <20ms) ││
27//!                     │└──────────────────────────────────────────────┘│
28//!                     │         │ miss                                  │
29//!                     │         ▼                                       │
30//!                     │       BACKEND                                   │
31//!                     └─────────────────────────────────────────────────┘
32//! ```
33//!
34//! # Usage
35//!
36//! ```rust,ignore
//! use heliosdb_lite::proxy::cache::{CacheConfig, CacheContext, CacheLookup, QueryCache};
//!
//! let config = CacheConfig::default();
//! let cache = QueryCache::new(config);
//! let context = CacheContext::default();
//!
//! // Check the cache hierarchy before executing the query
//! if let CacheLookup::Hit { result, .. } = cache.get(&query, &context).await {
//!     return Ok(result);
//! }
//!
//! // Execute the query, then cache its serialized result bytes
//! let started = std::time::Instant::now();
//! let (data, row_count) = execute_query(&query).await?;
//! cache.put(&query, &context, data, row_count, started.elapsed()).await;
50//! ```
51
52pub mod config;
53pub mod hints;
54pub mod invalidation;
55pub mod l1_hot;
56pub mod l2_warm;
57pub mod l3_semantic;
58pub mod metrics;
59pub mod normalizer;
60pub mod result;
61
62// Re-exports
63pub use config::{CacheConfig, L1Config, L2Config, L3Config, StorageBackend};
64pub use hints::{parse_cache_hints, CacheHint};
65pub use invalidation::{InvalidationManager, InvalidationMode};
66pub use l1_hot::L1HotCache;
67pub use l2_warm::L2WarmCache;
68pub use l3_semantic::L3SemanticCache;
69pub use metrics::{CacheMetrics, CacheStatsLevelSnapshot, CacheStatsSnapshot};
70pub use normalizer::{NormalizedQuery, QueryNormalizer};
71pub use result::{CacheKey, CachedResult};
72
73use bytes::Bytes;
74use dashmap::DashMap;
75use std::sync::Arc;
76use std::time::{Duration, Instant};
77
/// Session attributes supplied alongside a query when building cache keys.
///
/// `QueryCache` passes this to `CacheKey::new` for L2 lookups and stores,
/// and reads `connection_id` to pick the per-connection L1 cache.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheContext {
    /// Name of the target database.
    pub database: String,
    /// Authenticated user, when known (relevant for row-level security).
    pub user: Option<String>,
    /// HeliosDB branch the session is using, if any.
    pub branch: Option<String>,
    /// Proxy connection identifier; selects the L1 cache for this connection.
    pub connection_id: Option<u64>,
}

impl Default for CacheContext {
    /// A context for the `"default"` database with no user, branch, or
    /// connection attached.
    fn default() -> Self {
        CacheContext {
            database: String::from("default"),
            user: None,
            branch: None,
            connection_id: None,
        }
    }
}
101
102/// Cache lookup result
103#[derive(Debug)]
104pub enum CacheLookup {
105    /// Cache hit with result
106    Hit {
107        result: CachedResult,
108        level: CacheLevel,
109    },
110    /// Cache miss
111    Miss,
112}
113
/// Identifies which tier of the cache hierarchy produced a hit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheLevel {
    /// Per-connection, exact-match tier.
    L1Hot,
    /// Shared tier keyed by normalized query.
    L2Warm,
    /// Similarity-based tier.
    L3Semantic,
}

impl std::fmt::Display for CacheLevel {
    /// Writes the short tier label: `L1`, `L2` or `L3`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::L1Hot => "L1",
            Self::L2Warm => "L2",
            Self::L3Semantic => "L3",
        };
        f.write_str(label)
    }
}
131
132/// Main query cache implementation
133pub struct QueryCache {
134    /// Configuration
135    config: CacheConfig,
136
137    /// L1: Per-connection hot cache (exact match)
138    l1_caches: DashMap<u64, Arc<L1HotCache>>,
139
140    /// L2: Shared normalized cache
141    l2_cache: Option<Arc<L2WarmCache>>,
142
143    /// L3: Semantic similarity cache
144    l3_cache: Option<Arc<L3SemanticCache>>,
145
146    /// Query normalizer
147    normalizer: Arc<QueryNormalizer>,
148
149    /// Cache invalidation manager
150    invalidator: Arc<InvalidationManager>,
151
152    /// Metrics collector
153    metrics: Arc<CacheMetrics>,
154
155    /// Request coalescing for cache stampede prevention
156    #[allow(dead_code)]
157    pending_requests: DashMap<CacheKey, Arc<tokio::sync::Notify>>,
158}
159
160impl QueryCache {
161    /// Create a new query cache with the given configuration
162    pub fn new(config: CacheConfig) -> Self {
163        let l2_cache = if config.l2.enabled {
164            Some(Arc::new(L2WarmCache::new(config.l2.clone())))
165        } else {
166            None
167        };
168
169        let l3_cache = if config.l3.enabled {
170            Some(Arc::new(L3SemanticCache::new(config.l3.clone())))
171        } else {
172            None
173        };
174
175        let invalidator = Arc::new(InvalidationManager::new(config.invalidation.clone()));
176
177        Self {
178            config: config.clone(),
179            l1_caches: DashMap::new(),
180            l2_cache,
181            l3_cache,
182            normalizer: Arc::new(QueryNormalizer::new()),
183            invalidator,
184            metrics: Arc::new(CacheMetrics::new()),
185            pending_requests: DashMap::new(),
186        }
187    }
188
189    /// Get or create L1 cache for a connection
190    pub fn get_l1_cache(&self, connection_id: u64) -> Arc<L1HotCache> {
191        self.l1_caches
192            .entry(connection_id)
193            .or_insert_with(|| Arc::new(L1HotCache::new(self.config.l1.clone())))
194            .clone()
195    }
196
197    /// Remove L1 cache for a connection (on disconnect)
198    pub fn remove_l1_cache(&self, connection_id: u64) {
199        self.l1_caches.remove(&connection_id);
200    }
201
202    /// Look up a query in the cache hierarchy
203    pub async fn get(&self, query: &str, context: &CacheContext) -> CacheLookup {
204        // Parse cache hints
205        let hints = parse_cache_hints(query);
206
207        // Skip cache if hint says so
208        if hints.skip {
209            self.metrics.record_skip();
210            return CacheLookup::Miss;
211        }
212
213        let start = Instant::now();
214
215        // L1: Check hot cache (exact match)
216        if self.config.l1.enabled {
217            if let Some(conn_id) = context.connection_id {
218                let l1 = self.get_l1_cache(conn_id);
219                if let Some(result) = l1.get(query) {
220                    self.metrics.record_hit(CacheLevel::L1Hot, start.elapsed());
221                    return CacheLookup::Hit {
222                        result,
223                        level: CacheLevel::L1Hot,
224                    };
225                }
226            }
227        }
228
229        // Normalize query for L2/L3 lookup
230        let normalized = self.normalizer.normalize(query);
231        let cache_key = CacheKey::new(&normalized, context);
232
233        // L2: Check warm cache (normalized match)
234        if let Some(ref l2) = self.l2_cache {
235            if let Some(result) = l2.get(&cache_key).await {
236                self.metrics.record_hit(CacheLevel::L2Warm, start.elapsed());
237
238                // Promote to L1
239                if self.config.l1.enabled {
240                    if let Some(conn_id) = context.connection_id {
241                        let l1 = self.get_l1_cache(conn_id);
242                        l1.put(query.to_string(), result.clone());
243                    }
244                }
245
246                return CacheLookup::Hit {
247                    result,
248                    level: CacheLevel::L2Warm,
249                };
250            }
251        }
252
253        // L3: Check semantic cache (similarity match)
254        if hints.semantic_cache {
255            if let Some(ref l3) = self.l3_cache {
256                if let Some(result) = l3.get(query, context).await {
257                    self.metrics
258                        .record_hit(CacheLevel::L3Semantic, start.elapsed());
259                    return CacheLookup::Hit {
260                        result,
261                        level: CacheLevel::L3Semantic,
262                    };
263                }
264            }
265        }
266
267        self.metrics.record_miss(start.elapsed());
268        CacheLookup::Miss
269    }
270
271    /// Store a query result in the cache
272    pub async fn put(
273        &self,
274        query: &str,
275        context: &CacheContext,
276        data: Bytes,
277        row_count: usize,
278        execution_time: Duration,
279    ) {
280        // Parse cache hints
281        let hints = parse_cache_hints(query);
282
283        // Skip if hint says so
284        if hints.skip {
285            return;
286        }
287
288        // Normalize query
289        let normalized = self.normalizer.normalize(query);
290
291        // Determine TTL
292        let ttl = hints
293            .ttl
294            .unwrap_or_else(|| self.get_table_ttl(&normalized.tables));
295
296        // Check size limit
297        if data.len() > self.config.max_result_size {
298            self.metrics.record_size_exceeded();
299            return;
300        }
301
302        // Create cached result
303        let result = CachedResult {
304            data,
305            row_count,
306            cached_at: Instant::now(),
307            ttl,
308            tables: normalized.tables.clone(),
309            execution_time,
310        };
311
312        // Store in L1 (exact match)
313        if self.config.l1.enabled {
314            if let Some(conn_id) = context.connection_id {
315                let l1 = self.get_l1_cache(conn_id);
316                l1.put(query.to_string(), result.clone());
317            }
318        }
319
320        // Store in L2 (normalized)
321        if let Some(ref l2) = self.l2_cache {
322            let cache_key = CacheKey::new(&normalized, context);
323            l2.put(cache_key.clone(), result.clone()).await;
324
325            // Register for invalidation
326            for table in &normalized.tables {
327                self.invalidator.register(&cache_key, table);
328            }
329        }
330
331        // Store in L3 (semantic) if hint enabled
332        if hints.semantic_cache {
333            if let Some(ref l3) = self.l3_cache {
334                l3.put(query, context, result).await;
335            }
336        }
337
338        self.metrics.record_put();
339    }
340
341    /// Invalidate any cached results that reference a table written by `sql`.
342    /// Normalizes the (write) query to extract its tables, then drops their
343    /// cached entries.
344    pub async fn invalidate_query(&self, sql: &str) {
345        let normalized = self.normalizer.normalize(sql);
346        if !normalized.tables.is_empty() {
347            self.invalidate_tables(&normalized.tables).await;
348        }
349    }
350
351    /// Invalidate cache entries for specific tables
352    pub async fn invalidate_tables(&self, tables: &[String]) {
353        for table in tables {
354            let keys = self.invalidator.get_keys_for_table(table);
355
356            // Invalidate L2
357            if let Some(ref l2) = self.l2_cache {
358                for key in &keys {
359                    l2.remove(key).await;
360                }
361            }
362
363            self.invalidator.invalidate_table(table);
364        }
365
366        // L1 caches are invalidated on next access (TTL-based)
367        // L3 semantic cache has its own TTL handling
368
369        self.metrics.record_invalidation(tables.len());
370    }
371
372    /// Clear all caches
373    pub async fn clear(&self, levels: &[CacheLevel]) {
374        for level in levels {
375            match level {
376                CacheLevel::L1Hot => {
377                    self.l1_caches.clear();
378                }
379                CacheLevel::L2Warm => {
380                    if let Some(ref l2) = self.l2_cache {
381                        l2.clear().await;
382                    }
383                }
384                CacheLevel::L3Semantic => {
385                    if let Some(ref l3) = self.l3_cache {
386                        l3.clear().await;
387                    }
388                }
389            }
390        }
391
392        self.metrics.record_clear();
393    }
394
395    /// Get cache statistics
396    pub fn stats(&self) -> CacheStatsSnapshot {
397        self.metrics.snapshot()
398    }
399
400    /// Get configuration
401    pub fn config(&self) -> &CacheConfig {
402        &self.config
403    }
404
405    /// Get the invalidation manager (for WAL subscription)
406    pub fn invalidator(&self) -> Arc<InvalidationManager> {
407        self.invalidator.clone()
408    }
409
410    /// Get table-specific TTL or default
411    fn get_table_ttl(&self, tables: &[String]) -> Duration {
412        // Find shortest TTL among tables
413        let mut min_ttl = self.config.default_ttl;
414
415        for table in tables {
416            if let Some(table_config) = self.config.table_configs.get(table) {
417                if table_config.ttl < min_ttl {
418                    min_ttl = table_config.ttl;
419                }
420            }
421        }
422
423        min_ttl
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache from the default configuration.
    fn default_cache() -> QueryCache {
        QueryCache::new(CacheConfig::default())
    }

    #[test]
    fn test_cache_context_default() {
        let CacheContext {
            database,
            user,
            branch,
            connection_id,
        } = CacheContext::default();

        assert_eq!(database, "default");
        assert!(user.is_none());
        assert!(branch.is_none());
        assert!(connection_id.is_none());
    }

    #[test]
    fn test_cache_level_display() {
        let expectations = [
            (CacheLevel::L1Hot, "L1"),
            (CacheLevel::L2Warm, "L2"),
            (CacheLevel::L3Semantic, "L3"),
        ];
        for &(level, label) in &expectations {
            assert_eq!(level.to_string(), label);
        }
    }

    #[tokio::test]
    async fn test_query_cache_creation() {
        let cache = default_cache();
        let config = cache.config();

        assert!(config.l1.enabled);
        assert!(config.l2.enabled);
    }

    #[tokio::test]
    async fn test_l1_cache_per_connection() {
        let cache = default_cache();

        let conn_one = cache.get_l1_cache(1);
        let conn_two = cache.get_l1_cache(2);
        let conn_one_again = cache.get_l1_cache(1);

        // Repeated requests for one connection share a single cache...
        assert!(Arc::ptr_eq(&conn_one, &conn_one_again));
        // ...while distinct connections are isolated from each other.
        assert!(!Arc::ptr_eq(&conn_one, &conn_two));
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = default_cache();
        let context = CacheContext::default();

        let lookup = cache.get("SELECT * FROM users", &context).await;
        assert!(matches!(lookup, CacheLookup::Miss));
    }
}