// heliosdb_proxy/cache/mod.rs

//! Query Caching Module
//!
//! Provides multi-tier query caching for HeliosProxy:
//!
//! - **L1 Hot Cache**: Per-connection, exact match, LRU eviction
//! - **L2 Warm Cache**: Shared, normalized queries, configurable storage
//! - **L3 Semantic Cache**: Vector similarity for AI workloads
//!
//! # Architecture
//!
//! ```text
//!                     ┌────────────────────────────────────────────────────┐
//!                     │                 QUERY CACHE LAYER                  │
//!                     │                                                    │
//!   Query ───────────►│ ┌────────────────────────────────────────────────┐ │
//!                     │ │ L1: Hot Cache (in-memory, <1ms)                │ │
//!                     │ └────────────────────────────────────────────────┘ │
//!                     │         │ miss                                     │
//!                     │         ▼                                          │
//!                     │ ┌────────────────────────────────────────────────┐ │
//!                     │ │ L2: Warm Cache (shared memory, <5ms)           │ │
//!                     │ └────────────────────────────────────────────────┘ │
//!                     │         │ miss                                     │
//!                     │         ▼                                          │
//!                     │ ┌────────────────────────────────────────────────┐ │
//!                     │ │ L3: Semantic Cache (vector similarity, <20ms)  │ │
//!                     │ └────────────────────────────────────────────────┘ │
//!                     │         │ miss                                     │
//!                     │         ▼                                          │
//!                     │       BACKEND                                      │
//!                     └────────────────────────────────────────────────────┘
//! ```
//!
//! # Usage
//!
//! ```rust,ignore
//! use std::time::Instant;
//!
//! use heliosdb_lite::proxy::cache::{CacheConfig, CacheLookup, QueryCache};
//!
//! let config = CacheConfig::default();
//! let cache = QueryCache::new(config);
//!
//! // Check cache before executing query
//! if let CacheLookup::Hit { result, .. } = cache.get(&query, &context).await {
//!     return result;
//! }
//!
//! // Execute query and cache result (payload bytes, row count, execution time)
//! let started = Instant::now();
//! let (data, row_count) = execute_query(&query).await?;
//! cache.put(&query, &context, data, row_count, started.elapsed()).await;
//! ```

52pub mod config;
53pub mod l1_hot;
54pub mod l2_warm;
55pub mod l3_semantic;
56pub mod normalizer;
57pub mod invalidation;
58pub mod metrics;
59pub mod hints;
60pub mod result;
61
62// Re-exports
63pub use config::{CacheConfig, L1Config, L2Config, L3Config, StorageBackend};
64pub use l1_hot::L1HotCache;
65pub use l2_warm::L2WarmCache;
66pub use l3_semantic::L3SemanticCache;
67pub use normalizer::{QueryNormalizer, NormalizedQuery};
68pub use invalidation::{InvalidationManager, InvalidationMode};
69pub use metrics::{CacheMetrics, CacheStatsSnapshot, CacheStatsLevelSnapshot};
70pub use hints::{CacheHint, parse_cache_hints};
71pub use result::{CachedResult, CacheKey};
72
73use bytes::Bytes;
74use dashmap::DashMap;
75use std::sync::Arc;
76use std::time::{Duration, Instant};
77
/// Query cache context (for cache key generation).
///
/// Two contexts compare equal, and hash identically, only when every field
/// matches.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheContext {
    /// Database name
    pub database: String,
    /// Username (for RLS)
    pub user: Option<String>,
    /// Branch name (for HeliosDB branching)
    pub branch: Option<String>,
    /// Connection ID (for L1 cache); `None` means no per-connection cache applies
    pub connection_id: Option<u64>,
}

91impl Default for CacheContext {
92    fn default() -> Self {
93        Self {
94            database: "default".to_string(),
95            user: None,
96            branch: None,
97            connection_id: None,
98        }
99    }
100}
101
102/// Cache lookup result
103#[derive(Debug)]
104pub enum CacheLookup {
105    /// Cache hit with result
106    Hit {
107        result: CachedResult,
108        level: CacheLevel,
109    },
110    /// Cache miss
111    Miss,
112}
113
/// Cache level indicator
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheLevel {
    /// Per-connection hot cache
    L1Hot,
    /// Shared warm cache
    L2Warm,
    /// Semantic similarity cache
    L3Semantic,
}

122impl std::fmt::Display for CacheLevel {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        match self {
125            CacheLevel::L1Hot => write!(f, "L1"),
126            CacheLevel::L2Warm => write!(f, "L2"),
127            CacheLevel::L3Semantic => write!(f, "L3"),
128        }
129    }
130}
131
132/// Main query cache implementation
133pub struct QueryCache {
134    /// Configuration
135    config: CacheConfig,
136
137    /// L1: Per-connection hot cache (exact match)
138    l1_caches: DashMap<u64, Arc<L1HotCache>>,
139
140    /// L2: Shared normalized cache
141    l2_cache: Option<Arc<L2WarmCache>>,
142
143    /// L3: Semantic similarity cache
144    l3_cache: Option<Arc<L3SemanticCache>>,
145
146    /// Query normalizer
147    normalizer: Arc<QueryNormalizer>,
148
149    /// Cache invalidation manager
150    invalidator: Arc<InvalidationManager>,
151
152    /// Metrics collector
153    metrics: Arc<CacheMetrics>,
154
155    /// Request coalescing for cache stampede prevention
156    pending_requests: DashMap<CacheKey, Arc<tokio::sync::Notify>>,
157}
158
159impl QueryCache {
160    /// Create a new query cache with the given configuration
161    pub fn new(config: CacheConfig) -> Self {
162        let l2_cache = if config.l2.enabled {
163            Some(Arc::new(L2WarmCache::new(config.l2.clone())))
164        } else {
165            None
166        };
167
168        let l3_cache = if config.l3.enabled {
169            Some(Arc::new(L3SemanticCache::new(config.l3.clone())))
170        } else {
171            None
172        };
173
174        let invalidator = Arc::new(InvalidationManager::new(config.invalidation.clone()));
175
176        Self {
177            config: config.clone(),
178            l1_caches: DashMap::new(),
179            l2_cache,
180            l3_cache,
181            normalizer: Arc::new(QueryNormalizer::new()),
182            invalidator,
183            metrics: Arc::new(CacheMetrics::new()),
184            pending_requests: DashMap::new(),
185        }
186    }
187
188    /// Get or create L1 cache for a connection
189    pub fn get_l1_cache(&self, connection_id: u64) -> Arc<L1HotCache> {
190        self.l1_caches
191            .entry(connection_id)
192            .or_insert_with(|| Arc::new(L1HotCache::new(self.config.l1.clone())))
193            .clone()
194    }
195
196    /// Remove L1 cache for a connection (on disconnect)
197    pub fn remove_l1_cache(&self, connection_id: u64) {
198        self.l1_caches.remove(&connection_id);
199    }
200
201    /// Look up a query in the cache hierarchy
202    pub async fn get(&self, query: &str, context: &CacheContext) -> CacheLookup {
203        // Parse cache hints
204        let hints = parse_cache_hints(query);
205
206        // Skip cache if hint says so
207        if hints.skip {
208            self.metrics.record_skip();
209            return CacheLookup::Miss;
210        }
211
212        let start = Instant::now();
213
214        // L1: Check hot cache (exact match)
215        if self.config.l1.enabled {
216            if let Some(conn_id) = context.connection_id {
217                let l1 = self.get_l1_cache(conn_id);
218                if let Some(result) = l1.get(query) {
219                    self.metrics.record_hit(CacheLevel::L1Hot, start.elapsed());
220                    return CacheLookup::Hit {
221                        result,
222                        level: CacheLevel::L1Hot,
223                    };
224                }
225            }
226        }
227
228        // Normalize query for L2/L3 lookup
229        let normalized = self.normalizer.normalize(query);
230        let cache_key = CacheKey::new(&normalized, context);
231
232        // L2: Check warm cache (normalized match)
233        if let Some(ref l2) = self.l2_cache {
234            if let Some(result) = l2.get(&cache_key).await {
235                self.metrics.record_hit(CacheLevel::L2Warm, start.elapsed());
236
237                // Promote to L1
238                if self.config.l1.enabled {
239                    if let Some(conn_id) = context.connection_id {
240                        let l1 = self.get_l1_cache(conn_id);
241                        l1.put(query.to_string(), result.clone());
242                    }
243                }
244
245                return CacheLookup::Hit {
246                    result,
247                    level: CacheLevel::L2Warm,
248                };
249            }
250        }
251
252        // L3: Check semantic cache (similarity match)
253        if hints.semantic_cache {
254            if let Some(ref l3) = self.l3_cache {
255                if let Some(result) = l3.get(query, context).await {
256                    self.metrics.record_hit(CacheLevel::L3Semantic, start.elapsed());
257                    return CacheLookup::Hit {
258                        result,
259                        level: CacheLevel::L3Semantic,
260                    };
261                }
262            }
263        }
264
265        self.metrics.record_miss(start.elapsed());
266        CacheLookup::Miss
267    }
268
269    /// Store a query result in the cache
270    pub async fn put(
271        &self,
272        query: &str,
273        context: &CacheContext,
274        data: Bytes,
275        row_count: usize,
276        execution_time: Duration,
277    ) {
278        // Parse cache hints
279        let hints = parse_cache_hints(query);
280
281        // Skip if hint says so
282        if hints.skip {
283            return;
284        }
285
286        // Normalize query
287        let normalized = self.normalizer.normalize(query);
288
289        // Determine TTL
290        let ttl = hints.ttl.unwrap_or_else(|| {
291            self.get_table_ttl(&normalized.tables)
292        });
293
294        // Check size limit
295        if data.len() > self.config.max_result_size {
296            self.metrics.record_size_exceeded();
297            return;
298        }
299
300        // Create cached result
301        let result = CachedResult {
302            data,
303            row_count,
304            cached_at: Instant::now(),
305            ttl,
306            tables: normalized.tables.clone(),
307            execution_time,
308        };
309
310        // Store in L1 (exact match)
311        if self.config.l1.enabled {
312            if let Some(conn_id) = context.connection_id {
313                let l1 = self.get_l1_cache(conn_id);
314                l1.put(query.to_string(), result.clone());
315            }
316        }
317
318        // Store in L2 (normalized)
319        if let Some(ref l2) = self.l2_cache {
320            let cache_key = CacheKey::new(&normalized, context);
321            l2.put(cache_key.clone(), result.clone()).await;
322
323            // Register for invalidation
324            for table in &normalized.tables {
325                self.invalidator.register(&cache_key, table);
326            }
327        }
328
329        // Store in L3 (semantic) if hint enabled
330        if hints.semantic_cache {
331            if let Some(ref l3) = self.l3_cache {
332                l3.put(query, context, result).await;
333            }
334        }
335
336        self.metrics.record_put();
337    }
338
339    /// Invalidate cache entries for specific tables
340    pub async fn invalidate_tables(&self, tables: &[String]) {
341        for table in tables {
342            let keys = self.invalidator.get_keys_for_table(table);
343
344            // Invalidate L2
345            if let Some(ref l2) = self.l2_cache {
346                for key in &keys {
347                    l2.remove(key).await;
348                }
349            }
350
351            self.invalidator.invalidate_table(table);
352        }
353
354        // L1 caches are invalidated on next access (TTL-based)
355        // L3 semantic cache has its own TTL handling
356
357        self.metrics.record_invalidation(tables.len());
358    }
359
360    /// Clear all caches
361    pub async fn clear(&self, levels: &[CacheLevel]) {
362        for level in levels {
363            match level {
364                CacheLevel::L1Hot => {
365                    self.l1_caches.clear();
366                }
367                CacheLevel::L2Warm => {
368                    if let Some(ref l2) = self.l2_cache {
369                        l2.clear().await;
370                    }
371                }
372                CacheLevel::L3Semantic => {
373                    if let Some(ref l3) = self.l3_cache {
374                        l3.clear().await;
375                    }
376                }
377            }
378        }
379
380        self.metrics.record_clear();
381    }
382
383    /// Get cache statistics
384    pub fn stats(&self) -> CacheStatsSnapshot {
385        self.metrics.snapshot()
386    }
387
388    /// Get configuration
389    pub fn config(&self) -> &CacheConfig {
390        &self.config
391    }
392
393    /// Get the invalidation manager (for WAL subscription)
394    pub fn invalidator(&self) -> Arc<InvalidationManager> {
395        self.invalidator.clone()
396    }
397
398    /// Get table-specific TTL or default
399    fn get_table_ttl(&self, tables: &[String]) -> Duration {
400        // Find shortest TTL among tables
401        let mut min_ttl = self.config.default_ttl;
402
403        for table in tables {
404            if let Some(table_config) = self.config.table_configs.get(table) {
405                if table_config.ttl < min_ttl {
406                    min_ttl = table_config.ttl;
407                }
408            }
409        }
410
411        min_ttl
412    }
413}
414
/// Unit tests for the context defaults, level labels, and basic cache wiring.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_context_default() {
        let ctx = CacheContext::default();
        assert_eq!(ctx.database, "default");
        assert!(ctx.user.is_none());
        assert!(ctx.branch.is_none());
        assert!(ctx.connection_id.is_none());
    }

    #[test]
    fn test_cache_level_display() {
        assert_eq!(format!("{}", CacheLevel::L1Hot), "L1");
        assert_eq!(format!("{}", CacheLevel::L2Warm), "L2");
        assert_eq!(format!("{}", CacheLevel::L3Semantic), "L3");
    }

    #[tokio::test]
    async fn test_query_cache_creation() {
        let config = CacheConfig::default();
        let cache = QueryCache::new(config);

        assert!(cache.config.l1.enabled);
        assert!(cache.config.l2.enabled);
    }

    #[tokio::test]
    async fn test_l1_cache_per_connection() {
        let config = CacheConfig::default();
        let cache = QueryCache::new(config);

        let l1_a = cache.get_l1_cache(1);
        let l1_b = cache.get_l1_cache(2);
        let l1_a2 = cache.get_l1_cache(1);

        // Same connection should get same cache
        assert!(Arc::ptr_eq(&l1_a, &l1_a2));
        // Different connections should get different caches
        assert!(!Arc::ptr_eq(&l1_a, &l1_b));
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let config = CacheConfig::default();
        let cache = QueryCache::new(config);
        // The default context has no connection ID, so L1 is bypassed.
        let context = CacheContext::default();

        let result = cache.get("SELECT * FROM users", &context).await;
        assert!(matches!(result, CacheLookup::Miss));
    }
}