Skip to main content

heliosdb_proxy/cache/
invalidation.rs

//! Cache Invalidation
//!
//! Manages cache invalidation through multiple strategies:
//! - WAL-based: Subscribe to WAL events for real-time invalidation
//! - TTL-based: Time-based expiration fallback
//! - Manual: Explicit invalidation via API
7
8use std::collections::HashSet;
9use std::sync::RwLock;
10use std::time::{Duration, Instant};
11
12use dashmap::DashMap;
13use tokio::sync::broadcast;
14
15use super::config::InvalidationConfig;
16use super::result::CacheKey;
17
/// Strategy used to decide when cached results stop being valid.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum InvalidationMode {
    /// Real-time invalidation driven by WAL events (the default).
    #[default]
    Wal,

    /// Rely on TTL expiry alone.
    TtlOnly,

    /// Invalidate only when explicitly requested.
    ManualOnly,

    /// WAL-driven invalidation with TTL expiry as a fallback.
    WalWithTtlFallback,
}
34
35/// Cache invalidation manager
36///
37/// Tracks table -> cache key mappings and handles invalidation events.
38#[derive(Debug)]
39pub struct InvalidationManager {
40    /// Configuration
41    config: InvalidationConfig,
42
43    /// Table -> cache keys mapping
44    table_keys: DashMap<String, HashSet<CacheKey>>,
45
46    /// Cache key -> tables mapping (reverse index)
47    key_tables: DashMap<CacheKey, HashSet<String>>,
48
49    /// Last invalidation time per table
50    last_invalidation: DashMap<String, Instant>,
51
52    /// Invalidation event sender
53    event_tx: broadcast::Sender<InvalidationEvent>,
54
55    /// WAL subscription status
56    wal_connected: std::sync::atomic::AtomicBool,
57
58    /// Pending invalidations (batched)
59    pending_invalidations: RwLock<HashSet<String>>,
60
61    /// Batch timer
62    last_batch_flush: RwLock<Instant>,
63}
64
65/// Invalidation event
66#[derive(Debug, Clone)]
67pub enum InvalidationEvent {
68    /// Invalidate specific tables
69    Tables(Vec<String>),
70
71    /// Invalidate specific cache keys
72    Keys(Vec<CacheKey>),
73
74    /// Invalidate all caches
75    All,
76
77    /// WAL event received
78    WalEvent {
79        table: String,
80        operation: WalOperation,
81        lsn: u64,
82    },
83}
84
/// Kind of change carried by a WAL record.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WalOperation {
    /// Rows inserted.
    Insert,
    /// Rows updated.
    Update,
    /// Rows deleted.
    Delete,
    /// Table truncated.
    Truncate,
}
93
94impl InvalidationManager {
95    /// Create a new invalidation manager
96    pub fn new(config: InvalidationConfig) -> Self {
97        let (event_tx, _) = broadcast::channel(1024);
98
99        Self {
100            config,
101            table_keys: DashMap::new(),
102            key_tables: DashMap::new(),
103            last_invalidation: DashMap::new(),
104            event_tx,
105            wal_connected: std::sync::atomic::AtomicBool::new(false),
106            pending_invalidations: RwLock::new(HashSet::new()),
107            last_batch_flush: RwLock::new(Instant::now()),
108        }
109    }
110
111    /// Register a cache key for table-based invalidation
112    pub fn register(&self, key: &CacheKey, table: &str) {
113        // Add to table -> keys mapping
114        self.table_keys
115            .entry(table.to_string())
116            .or_insert_with(HashSet::new)
117            .insert(key.clone());
118
119        // Add to key -> tables mapping
120        self.key_tables
121            .entry(key.clone())
122            .or_insert_with(HashSet::new)
123            .insert(table.to_string());
124    }
125
126    /// Unregister a cache key
127    pub fn unregister(&self, key: &CacheKey) {
128        if let Some((_, tables)) = self.key_tables.remove(key) {
129            for table in tables {
130                if let Some(mut keys) = self.table_keys.get_mut(&table) {
131                    keys.remove(key);
132                }
133            }
134        }
135    }
136
137    /// Get all cache keys associated with a table
138    pub fn get_keys_for_table(&self, table: &str) -> Vec<CacheKey> {
139        self.table_keys
140            .get(table)
141            .map(|keys| keys.iter().cloned().collect())
142            .unwrap_or_default()
143    }
144
145    /// Get all tables associated with a cache key
146    pub fn get_tables_for_key(&self, key: &CacheKey) -> Vec<String> {
147        self.key_tables
148            .get(key)
149            .map(|tables| tables.iter().cloned().collect())
150            .unwrap_or_default()
151    }
152
153    /// Invalidate all cache entries for a table
154    pub fn invalidate_table(&self, table: &str) {
155        // Record invalidation time
156        self.last_invalidation.insert(table.to_string(), Instant::now());
157
158        // Send event
159        let _ = self.event_tx.send(InvalidationEvent::Tables(vec![table.to_string()]));
160
161        // Clear table -> keys mapping
162        if let Some((_, keys)) = self.table_keys.remove(table) {
163            for key in keys {
164                if let Some(mut tables) = self.key_tables.get_mut(&key) {
165                    tables.remove(table);
166                }
167            }
168        }
169    }
170
171    /// Invalidate multiple tables
172    pub fn invalidate_tables(&self, tables: &[String]) {
173        for table in tables {
174            self.invalidate_table(table);
175        }
176    }
177
178    /// Queue a table for batched invalidation
179    pub fn queue_invalidation(&self, table: &str) {
180        if let Ok(mut pending) = self.pending_invalidations.write() {
181            pending.insert(table.to_string());
182        }
183
184        // Check if we should flush the batch
185        self.maybe_flush_batch();
186    }
187
188    /// Flush pending invalidations
189    pub fn flush_pending(&self) {
190        let tables: Vec<String> = {
191            let mut pending = match self.pending_invalidations.write() {
192                Ok(p) => p,
193                Err(_) => return,
194            };
195
196            let tables: Vec<String> = pending.drain().collect();
197            tables
198        };
199
200        if !tables.is_empty() {
201            self.invalidate_tables(&tables);
202        }
203
204        if let Ok(mut last) = self.last_batch_flush.write() {
205            *last = Instant::now();
206        }
207    }
208
209    /// Check if batch should be flushed
210    fn maybe_flush_batch(&self) {
211        let should_flush = {
212            let last = match self.last_batch_flush.read() {
213                Ok(l) => *l,
214                Err(_) => return,
215            };
216
217            let pending_count = self.pending_invalidations
218                .read()
219                .map(|p| p.len())
220                .unwrap_or(0);
221
222            // Flush if batch is large or time threshold exceeded
223            pending_count >= 100 || last.elapsed() > Duration::from_millis(50)
224        };
225
226        if should_flush {
227            self.flush_pending();
228        }
229    }
230
231    /// Handle a WAL event
232    pub fn on_wal_event(&self, table: &str, operation: WalOperation, lsn: u64) {
233        // Send WAL event
234        let _ = self.event_tx.send(InvalidationEvent::WalEvent {
235            table: table.to_string(),
236            operation,
237            lsn,
238        });
239
240        // Queue for batched invalidation
241        self.queue_invalidation(table);
242    }
243
244    /// Subscribe to invalidation events
245    pub fn subscribe(&self) -> broadcast::Receiver<InvalidationEvent> {
246        self.event_tx.subscribe()
247    }
248
249    /// Check if WAL is connected
250    pub fn is_wal_connected(&self) -> bool {
251        self.wal_connected.load(std::sync::atomic::Ordering::Relaxed)
252    }
253
254    /// Set WAL connection status
255    pub fn set_wal_connected(&self, connected: bool) {
256        self.wal_connected.store(connected, std::sync::atomic::Ordering::Relaxed);
257    }
258
259    /// Get invalidation mode
260    pub fn mode(&self) -> InvalidationMode {
261        self.config.mode
262    }
263
264    /// Get last invalidation time for a table
265    pub fn last_invalidation_time(&self, table: &str) -> Option<Instant> {
266        self.last_invalidation.get(table).map(|t| *t)
267    }
268
269    /// Get statistics
270    pub fn stats(&self) -> InvalidationStats {
271        let total_keys: usize = self.table_keys
272            .iter()
273            .map(|e| e.value().len())
274            .sum();
275
276        InvalidationStats {
277            tracked_tables: self.table_keys.len(),
278            tracked_keys: total_keys,
279            pending_invalidations: self.pending_invalidations
280                .read()
281                .map(|p| p.len())
282                .unwrap_or(0),
283            wal_connected: self.is_wal_connected(),
284            mode: self.config.mode,
285        }
286    }
287
288    /// Clear all tracking data
289    pub fn clear(&self) {
290        self.table_keys.clear();
291        self.key_tables.clear();
292        self.last_invalidation.clear();
293
294        if let Ok(mut pending) = self.pending_invalidations.write() {
295            pending.clear();
296        }
297    }
298}
299
300/// Invalidation statistics
301#[derive(Debug, Clone)]
302pub struct InvalidationStats {
303    /// Number of tracked tables
304    pub tracked_tables: usize,
305
306    /// Number of tracked cache keys
307    pub tracked_keys: usize,
308
309    /// Number of pending invalidations
310    pub pending_invalidations: usize,
311
312    /// Whether WAL is connected
313    pub wal_connected: bool,
314
315    /// Current invalidation mode
316    pub mode: InvalidationMode,
317}
318
319/// WAL event parser
320pub struct WalEventParser;
321
322impl WalEventParser {
323    /// Parse a WAL message into an invalidation event
324    pub fn parse(message: &[u8]) -> Option<(String, WalOperation, u64)> {
325        // Simple format: "OP:TABLE:LSN"
326        let text = std::str::from_utf8(message).ok()?;
327        let parts: Vec<&str> = text.split(':').collect();
328
329        if parts.len() < 3 {
330            return None;
331        }
332
333        let operation = match parts[0].to_uppercase().as_str() {
334            "I" | "INSERT" => WalOperation::Insert,
335            "U" | "UPDATE" => WalOperation::Update,
336            "D" | "DELETE" => WalOperation::Delete,
337            "T" | "TRUNCATE" => WalOperation::Truncate,
338            _ => return None,
339        };
340
341        let table = parts[1].to_string();
342        let lsn = parts[2].parse().ok()?;
343
344        Some((table, operation, lsn))
345    }
346
347    /// Extract affected tables from SQL
348    pub fn extract_affected_tables(sql: &str) -> Vec<String> {
349        let sql_upper = sql.to_uppercase();
350        let mut tables = Vec::new();
351
352        // Simple extraction (more sophisticated parsing would use sqlparser)
353        let patterns = [
354            (r"INSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_]*)", 1),
355            (r"UPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)", 1),
356            (r"DELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)", 1),
357            (r"TRUNCATE\s+(?:TABLE\s+)?([a-zA-Z_][a-zA-Z0-9_]*)", 1),
358        ];
359
360        for (pattern, group) in patterns {
361            if let Ok(re) = regex::Regex::new(pattern) {
362                for cap in re.captures_iter(&sql_upper) {
363                    if let Some(m) = cap.get(group) {
364                        tables.push(m.as_str().to_lowercase());
365                    }
366                }
367            }
368        }
369
370        tables
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager built from the default configuration; shared setup for the tests.
    fn new_manager() -> InvalidationManager {
        InvalidationManager::new(InvalidationConfig::default())
    }

    /// Cache key distinguished only by its hash.
    fn create_key(hash: u64) -> CacheKey {
        CacheKey::from_parts(hash, "test".to_string(), None, None)
    }

    #[test]
    fn test_register_and_lookup() {
        let manager = new_manager();
        let shared = create_key(111);
        let users_only = create_key(222);

        manager.register(&shared, "users");
        manager.register(&users_only, "users");
        manager.register(&shared, "sessions");

        // Forward index: table -> keys.
        let keys = manager.get_keys_for_table("users");
        assert_eq!(keys.len(), 2);
        assert!(keys.contains(&shared));
        assert!(keys.contains(&users_only));

        // Reverse index: key -> tables.
        let tables = manager.get_tables_for_key(&shared);
        assert_eq!(tables.len(), 2);
        assert!(tables.contains(&"users".to_string()));
        assert!(tables.contains(&"sessions".to_string()));
    }

    #[test]
    fn test_unregister() {
        let manager = new_manager();
        let key = create_key(111);
        for table in ["users", "sessions"].iter() {
            manager.register(&key, table);
        }

        manager.unregister(&key);

        assert!(manager.get_keys_for_table("users").is_empty());
        assert!(manager.get_keys_for_table("sessions").is_empty());
        assert!(manager.get_tables_for_key(&key).is_empty());
    }

    #[test]
    fn test_invalidate_table() {
        let manager = new_manager();
        manager.register(&create_key(111), "users");
        manager.register(&create_key(222), "orders");

        manager.invalidate_table("users");

        assert!(manager.get_keys_for_table("users").is_empty());
        assert!(!manager.get_keys_for_table("orders").is_empty());
        assert!(manager.last_invalidation_time("users").is_some());
    }

    #[test]
    fn test_queue_and_flush() {
        let manager = new_manager();
        manager.register(&create_key(111), "users");

        manager.queue_invalidation("users");
        assert!(manager.pending_invalidations.read().unwrap().contains("users"));

        manager.flush_pending();
        assert!(manager.pending_invalidations.read().unwrap().is_empty());

        assert!(manager.get_keys_for_table("users").is_empty());
    }

    #[test]
    fn test_wal_event() {
        let manager = new_manager();
        manager.register(&create_key(111), "users");
        let mut receiver = manager.subscribe();

        manager.on_wal_event("users", WalOperation::Update, 12345);
        // Flush so the queued table is processed too.
        manager.flush_pending();

        // The WAL event itself must have been broadcast.
        assert!(receiver.try_recv().is_ok());
    }

    #[test]
    fn test_stats() {
        let manager = new_manager();
        manager.register(&create_key(111), "users");
        manager.register(&create_key(222), "orders");

        let stats = manager.stats();
        assert_eq!(stats.tracked_tables, 2);
        assert_eq!(stats.tracked_keys, 2);
    }

    #[test]
    fn test_wal_event_parser() {
        // Long-form operation name.
        let (table, op, lsn) = WalEventParser::parse(b"INSERT:users:12345").unwrap();
        assert_eq!(table, "users");
        assert_eq!(op, WalOperation::Insert);
        assert_eq!(lsn, 12345);

        // Short-form operation code.
        let (table, op, _) = WalEventParser::parse(b"U:orders:67890").unwrap();
        assert_eq!(table, "orders");
        assert_eq!(op, WalOperation::Update);
    }

    #[test]
    fn test_extract_affected_tables() {
        let cases: [(&str, &[&str]); 5] = [
            ("INSERT INTO users VALUES (1)", &["users"]),
            ("UPDATE orders SET status = 'done'", &["orders"]),
            ("DELETE FROM sessions WHERE expired", &["sessions"]),
            ("TRUNCATE TABLE logs", &["logs"]),
            ("TRUNCATE products", &["products"]),
        ];

        for &(sql, expected) in &cases {
            let tables = WalEventParser::extract_affected_tables(sql);
            assert_eq!(tables, expected, "Failed for SQL: {}", sql);
        }
    }

    #[test]
    fn test_clear() {
        let manager = new_manager();
        manager.register(&create_key(111), "users");
        manager.queue_invalidation("users");

        manager.clear();

        let stats = manager.stats();
        assert_eq!(stats.tracked_tables, 0);
        assert_eq!(stats.tracked_keys, 0);
        assert_eq!(stats.pending_invalidations, 0);
    }
}