Skip to main content

heliosdb_proxy/cache/
invalidation.rs

1//! Cache Invalidation
2//!
3//! Manages cache invalidation through multiple strategies:
4//! - WAL-based: Subscribe to WAL events for real-time invalidation
5//! - TTL-based: Time-based expiration fallback
6//! - Manual: Explicit invalidation via API
7
8use std::collections::HashSet;
9use std::sync::RwLock;
10use std::time::{Duration, Instant};
11
12use dashmap::DashMap;
13use tokio::sync::broadcast;
14
15use super::config::InvalidationConfig;
16use super::result::CacheKey;
17
/// Strategy used to decide when cached query results stop being valid.
///
/// Defaults to [`InvalidationMode::Wal`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum InvalidationMode {
    /// Invalidate in real time from WAL change events.
    #[default]
    Wal,
    /// Rely on TTL expiry alone.
    TtlOnly,
    /// Invalidate only on explicit request.
    ManualOnly,
    /// Invalidate from WAL events, with TTL expiry as a fallback.
    WalWithTtlFallback,
}
34
/// Cache invalidation manager
///
/// Tracks table -> cache key mappings (plus the reverse index) and handles
/// invalidation events, which are published to subscribers over a tokio
/// broadcast channel.
#[derive(Debug)]
pub struct InvalidationManager {
    /// Configuration (the invalidation mode is read from here)
    config: InvalidationConfig,

    /// Table -> cache keys mapping
    table_keys: DashMap<String, HashSet<CacheKey>>,

    /// Cache key -> tables mapping (reverse index of `table_keys`)
    key_tables: DashMap<CacheKey, HashSet<String>>,

    /// Last invalidation time per table
    last_invalidation: DashMap<String, Instant>,

    /// Invalidation event sender; receivers are obtained via `subscribe`
    event_tx: broadcast::Sender<InvalidationEvent>,

    /// WAL subscription status (a standalone status flag, accessed with relaxed ordering)
    wal_connected: std::sync::atomic::AtomicBool,

    /// Tables queued for batched invalidation that have not been flushed yet
    pending_invalidations: RwLock<HashSet<String>>,

    /// Time of the last batch flush (initialised at construction); drives the
    /// age threshold for automatic flushing
    last_batch_flush: RwLock<Instant>,
}
64
/// Invalidation event
///
/// Published on the invalidation manager's broadcast channel.
#[derive(Debug, Clone)]
pub enum InvalidationEvent {
    /// Invalidate specific tables
    Tables(Vec<String>),

    /// Invalidate specific cache keys
    // NOTE(review): never sent from this file; presumably emitted elsewhere — confirm.
    Keys(Vec<CacheKey>),

    /// Invalidate all caches
    // NOTE(review): never sent from this file; presumably emitted elsewhere — confirm.
    All,

    /// WAL event received
    WalEvent {
        /// Table named in the WAL event
        table: String,
        /// Kind of change reported
        operation: WalOperation,
        /// Log sequence number reported with the event
        lsn: u64,
    },
}
84
/// Kind of change carried by a WAL event.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WalOperation {
    /// An `INSERT` was logged.
    Insert,
    /// An `UPDATE` was logged.
    Update,
    /// A `DELETE` was logged.
    Delete,
    /// A `TRUNCATE` was logged.
    Truncate,
}
93
94impl InvalidationManager {
95    /// Create a new invalidation manager
96    pub fn new(config: InvalidationConfig) -> Self {
97        let (event_tx, _) = broadcast::channel(1024);
98
99        Self {
100            config,
101            table_keys: DashMap::new(),
102            key_tables: DashMap::new(),
103            last_invalidation: DashMap::new(),
104            event_tx,
105            wal_connected: std::sync::atomic::AtomicBool::new(false),
106            pending_invalidations: RwLock::new(HashSet::new()),
107            last_batch_flush: RwLock::new(Instant::now()),
108        }
109    }
110
111    /// Register a cache key for table-based invalidation
112    pub fn register(&self, key: &CacheKey, table: &str) {
113        // Add to table -> keys mapping
114        self.table_keys
115            .entry(table.to_string())
116            .or_default()
117            .insert(key.clone());
118
119        // Add to key -> tables mapping
120        self.key_tables
121            .entry(key.clone())
122            .or_default()
123            .insert(table.to_string());
124    }
125
126    /// Unregister a cache key
127    pub fn unregister(&self, key: &CacheKey) {
128        if let Some((_, tables)) = self.key_tables.remove(key) {
129            for table in tables {
130                if let Some(mut keys) = self.table_keys.get_mut(&table) {
131                    keys.remove(key);
132                }
133            }
134        }
135    }
136
137    /// Get all cache keys associated with a table
138    pub fn get_keys_for_table(&self, table: &str) -> Vec<CacheKey> {
139        self.table_keys
140            .get(table)
141            .map(|keys| keys.iter().cloned().collect())
142            .unwrap_or_default()
143    }
144
145    /// Get all tables associated with a cache key
146    pub fn get_tables_for_key(&self, key: &CacheKey) -> Vec<String> {
147        self.key_tables
148            .get(key)
149            .map(|tables| tables.iter().cloned().collect())
150            .unwrap_or_default()
151    }
152
153    /// Invalidate all cache entries for a table
154    pub fn invalidate_table(&self, table: &str) {
155        // Record invalidation time
156        self.last_invalidation
157            .insert(table.to_string(), Instant::now());
158
159        // Send event
160        let _ = self
161            .event_tx
162            .send(InvalidationEvent::Tables(vec![table.to_string()]));
163
164        // Clear table -> keys mapping
165        if let Some((_, keys)) = self.table_keys.remove(table) {
166            for key in keys {
167                if let Some(mut tables) = self.key_tables.get_mut(&key) {
168                    tables.remove(table);
169                }
170            }
171        }
172    }
173
174    /// Invalidate multiple tables
175    pub fn invalidate_tables(&self, tables: &[String]) {
176        for table in tables {
177            self.invalidate_table(table);
178        }
179    }
180
181    /// Queue a table for batched invalidation
182    pub fn queue_invalidation(&self, table: &str) {
183        if let Ok(mut pending) = self.pending_invalidations.write() {
184            pending.insert(table.to_string());
185        }
186
187        // Check if we should flush the batch
188        self.maybe_flush_batch();
189    }
190
191    /// Flush pending invalidations
192    pub fn flush_pending(&self) {
193        let tables: Vec<String> = {
194            let mut pending = match self.pending_invalidations.write() {
195                Ok(p) => p,
196                Err(_) => return,
197            };
198
199            let tables: Vec<String> = pending.drain().collect();
200            tables
201        };
202
203        if !tables.is_empty() {
204            self.invalidate_tables(&tables);
205        }
206
207        if let Ok(mut last) = self.last_batch_flush.write() {
208            *last = Instant::now();
209        }
210    }
211
212    /// Check if batch should be flushed
213    fn maybe_flush_batch(&self) {
214        let should_flush = {
215            let last = match self.last_batch_flush.read() {
216                Ok(l) => *l,
217                Err(_) => return,
218            };
219
220            let pending_count = self
221                .pending_invalidations
222                .read()
223                .map(|p| p.len())
224                .unwrap_or(0);
225
226            // Flush if batch is large or time threshold exceeded
227            pending_count >= 100 || last.elapsed() > Duration::from_millis(50)
228        };
229
230        if should_flush {
231            self.flush_pending();
232        }
233    }
234
235    /// Handle a WAL event
236    pub fn on_wal_event(&self, table: &str, operation: WalOperation, lsn: u64) {
237        // Send WAL event
238        let _ = self.event_tx.send(InvalidationEvent::WalEvent {
239            table: table.to_string(),
240            operation,
241            lsn,
242        });
243
244        // Queue for batched invalidation
245        self.queue_invalidation(table);
246    }
247
248    /// Subscribe to invalidation events
249    pub fn subscribe(&self) -> broadcast::Receiver<InvalidationEvent> {
250        self.event_tx.subscribe()
251    }
252
253    /// Check if WAL is connected
254    pub fn is_wal_connected(&self) -> bool {
255        self.wal_connected
256            .load(std::sync::atomic::Ordering::Relaxed)
257    }
258
259    /// Set WAL connection status
260    pub fn set_wal_connected(&self, connected: bool) {
261        self.wal_connected
262            .store(connected, std::sync::atomic::Ordering::Relaxed);
263    }
264
265    /// Get invalidation mode
266    pub fn mode(&self) -> InvalidationMode {
267        self.config.mode
268    }
269
270    /// Get last invalidation time for a table
271    pub fn last_invalidation_time(&self, table: &str) -> Option<Instant> {
272        self.last_invalidation.get(table).map(|t| *t)
273    }
274
275    /// Get statistics
276    pub fn stats(&self) -> InvalidationStats {
277        let total_keys: usize = self.table_keys.iter().map(|e| e.value().len()).sum();
278
279        InvalidationStats {
280            tracked_tables: self.table_keys.len(),
281            tracked_keys: total_keys,
282            pending_invalidations: self
283                .pending_invalidations
284                .read()
285                .map(|p| p.len())
286                .unwrap_or(0),
287            wal_connected: self.is_wal_connected(),
288            mode: self.config.mode,
289        }
290    }
291
292    /// Clear all tracking data
293    pub fn clear(&self) {
294        self.table_keys.clear();
295        self.key_tables.clear();
296        self.last_invalidation.clear();
297
298        if let Ok(mut pending) = self.pending_invalidations.write() {
299            pending.clear();
300        }
301    }
302}
303
/// Invalidation statistics
///
/// Built by `InvalidationManager::stats`. Fields are read one at a time, so
/// under concurrent updates they may not be mutually consistent.
#[derive(Debug, Clone)]
pub struct InvalidationStats {
    /// Number of tracked tables
    pub tracked_tables: usize,

    /// Number of tracked cache keys
    pub tracked_keys: usize,

    /// Number of tables queued for batched invalidation but not yet flushed
    pub pending_invalidations: usize,

    /// Whether WAL is connected
    pub wal_connected: bool,

    /// Current invalidation mode (taken from the manager's configuration)
    pub mode: InvalidationMode,
}
322
323/// WAL event parser
324pub struct WalEventParser;
325
326impl WalEventParser {
327    /// Parse a WAL message into an invalidation event
328    pub fn parse(message: &[u8]) -> Option<(String, WalOperation, u64)> {
329        // Simple format: "OP:TABLE:LSN"
330        let text = std::str::from_utf8(message).ok()?;
331        let parts: Vec<&str> = text.split(':').collect();
332
333        if parts.len() < 3 {
334            return None;
335        }
336
337        let operation = match parts[0].to_uppercase().as_str() {
338            "I" | "INSERT" => WalOperation::Insert,
339            "U" | "UPDATE" => WalOperation::Update,
340            "D" | "DELETE" => WalOperation::Delete,
341            "T" | "TRUNCATE" => WalOperation::Truncate,
342            _ => return None,
343        };
344
345        let table = parts[1].to_string();
346        let lsn = parts[2].parse().ok()?;
347
348        Some((table, operation, lsn))
349    }
350
351    /// Extract affected tables from SQL
352    pub fn extract_affected_tables(sql: &str) -> Vec<String> {
353        let sql_upper = sql.to_uppercase();
354        let mut tables = Vec::new();
355
356        // Simple extraction (more sophisticated parsing would use sqlparser)
357        let patterns = [
358            (r"INSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_]*)", 1),
359            (r"UPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)", 1),
360            (r"DELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)", 1),
361            (r"TRUNCATE\s+(?:TABLE\s+)?([a-zA-Z_][a-zA-Z0-9_]*)", 1),
362        ];
363
364        for (pattern, group) in patterns {
365            if let Ok(re) = regex::Regex::new(pattern) {
366                for cap in re.captures_iter(&sql_upper) {
367                    if let Some(m) = cap.get(group) {
368                        tables.push(m.as_str().to_lowercase());
369                    }
370                }
371            }
372        }
373
374        tables
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a test cache key identified by `hash`.
    fn create_key(hash: u64) -> CacheKey {
        CacheKey::from_parts(hash, "test".to_string(), None, None)
    }

    #[test]
    fn test_register_and_lookup() {
        let config = InvalidationConfig::default();
        let manager = InvalidationManager::new(config);

        let key1 = create_key(111);
        let key2 = create_key(222);

        manager.register(&key1, "users");
        manager.register(&key2, "users");
        manager.register(&key1, "sessions");

        // Check table -> keys
        let user_keys = manager.get_keys_for_table("users");
        assert_eq!(user_keys.len(), 2);
        assert!(user_keys.contains(&key1));
        assert!(user_keys.contains(&key2));

        // Check key -> tables
        let key1_tables = manager.get_tables_for_key(&key1);
        assert_eq!(key1_tables.len(), 2);
        assert!(key1_tables.contains(&"users".to_string()));
        assert!(key1_tables.contains(&"sessions".to_string()));
    }

    #[test]
    fn test_unregister() {
        let config = InvalidationConfig::default();
        let manager = InvalidationManager::new(config);

        let key = create_key(111);
        manager.register(&key, "users");
        manager.register(&key, "sessions");

        manager.unregister(&key);

        assert!(manager.get_keys_for_table("users").is_empty());
        assert!(manager.get_keys_for_table("sessions").is_empty());
        assert!(manager.get_tables_for_key(&key).is_empty());
    }

    #[test]
    fn test_invalidate_table() {
        let config = InvalidationConfig::default();
        let manager = InvalidationManager::new(config);

        let key1 = create_key(111);
        let key2 = create_key(222);

        manager.register(&key1, "users");
        manager.register(&key2, "orders");

        manager.invalidate_table("users");

        assert!(manager.get_keys_for_table("users").is_empty());
        assert!(!manager.get_keys_for_table("orders").is_empty());
        assert!(manager.last_invalidation_time("users").is_some());
    }

    #[test]
    fn test_queue_and_flush() {
        let config = InvalidationConfig::default();
        let manager = InvalidationManager::new(config);

        let key = create_key(111);
        manager.register(&key, "users");

        // `queue_invalidation` flushes on its own once the batch is older than
        // 50ms. Restart the batch clock so a slow test machine cannot trigger
        // that flush before the pending-set check below.
        *manager.last_batch_flush.write().unwrap() = std::time::Instant::now();
        manager.queue_invalidation("users");

        {
            let pending = manager.pending_invalidations.read().unwrap();
            assert!(pending.contains("users"));
        }

        manager.flush_pending();

        {
            let pending = manager.pending_invalidations.read().unwrap();
            assert!(pending.is_empty());
        }

        assert!(manager.get_keys_for_table("users").is_empty());
    }

    #[test]
    fn test_wal_event() {
        let config = InvalidationConfig::default();
        let manager = InvalidationManager::new(config);

        let key = create_key(111);
        manager.register(&key, "users");

        let mut receiver = manager.subscribe();

        manager.on_wal_event("users", WalOperation::Update, 12345);

        // Flush to process the event
        manager.flush_pending();

        // Should have received the WAL event
        let event = receiver.try_recv();
        assert!(event.is_ok());
    }

    #[test]
    fn test_stats() {
        let config = InvalidationConfig::default();
        let manager = InvalidationManager::new(config);

        let key1 = create_key(111);
        let key2 = create_key(222);

        manager.register(&key1, "users");
        manager.register(&key2, "orders");

        let stats = manager.stats();
        assert_eq!(stats.tracked_tables, 2);
        assert_eq!(stats.tracked_keys, 2);
    }

    #[test]
    fn test_wal_event_parser() {
        // Test message parsing
        let (table, op, lsn) = WalEventParser::parse(b"INSERT:users:12345").unwrap();
        assert_eq!(table, "users");
        assert_eq!(op, WalOperation::Insert);
        assert_eq!(lsn, 12345);

        let (table, op, _) = WalEventParser::parse(b"U:orders:67890").unwrap();
        assert_eq!(table, "orders");
        assert_eq!(op, WalOperation::Update);
    }

    #[test]
    fn test_extract_affected_tables() {
        let tests = vec![
            ("INSERT INTO users VALUES (1)", vec!["users"]),
            ("UPDATE orders SET status = 'done'", vec!["orders"]),
            ("DELETE FROM sessions WHERE expired", vec!["sessions"]),
            ("TRUNCATE TABLE logs", vec!["logs"]),
            ("TRUNCATE products", vec!["products"]),
        ];

        for (sql, expected) in tests {
            let tables = WalEventParser::extract_affected_tables(sql);
            assert_eq!(tables, expected, "Failed for SQL: {}", sql);
        }
    }

    #[test]
    fn test_clear() {
        let config = InvalidationConfig::default();
        let manager = InvalidationManager::new(config);

        manager.register(&create_key(111), "users");
        manager.queue_invalidation("users");

        manager.clear();

        assert_eq!(manager.stats().tracked_tables, 0);
        assert_eq!(manager.stats().tracked_keys, 0);
        assert_eq!(manager.stats().pending_invalidations, 0);
    }
}