1use std::collections::HashSet;
9use std::sync::RwLock;
10use std::time::{Duration, Instant};
11
12use dashmap::DashMap;
13use tokio::sync::broadcast;
14
15use super::config::InvalidationConfig;
16use super::result::CacheKey;
17
/// How cache entries are expected to be invalidated.
///
/// `InvalidationManager` stores this value and reports it through `mode()` and
/// `stats()`; the methods shown alongside it do not branch on it.
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names — confirm against the callers that act on the mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum InvalidationMode {
    /// Presumably: invalidate in response to WAL change events. This is the
    /// `Default` variant.
    #[default]
    Wal,

    /// Presumably: rely only on per-entry TTL expiry.
    TtlOnly,

    /// Presumably: invalidate only on explicit API calls.
    ManualOnly,

    /// Presumably: use WAL events, falling back to TTL expiry when the WAL
    /// stream is unavailable.
    WalWithTtlFallback,
}
34
/// Tracks which cache keys depend on which tables and turns table changes
/// (direct calls, batched queues, or WAL events) into invalidations.
///
/// The key/table dependency index is stored in both directions so either side
/// can be looked up directly. Invalidation events are broadcast to subscribers.
#[derive(Debug)]
pub struct InvalidationManager {
    /// Settings supplied at construction; the methods on this type read only
    /// its `mode` field.
    config: InvalidationConfig,

    /// Table name -> cache keys registered as depending on that table.
    table_keys: DashMap<String, HashSet<CacheKey>>,

    /// Cache key -> tables it was registered against (reverse of `table_keys`).
    key_tables: DashMap<CacheKey, HashSet<String>>,

    /// Table name -> time of its most recent `invalidate_table` call.
    last_invalidation: DashMap<String, Instant>,

    /// Sender side of the event channel; receivers are created by `subscribe`.
    event_tx: broadcast::Sender<InvalidationEvent>,

    /// WAL connection status as last reported through `set_wal_connected`.
    wal_connected: std::sync::atomic::AtomicBool,

    /// Tables queued by `queue_invalidation` that have not been flushed yet.
    pending_invalidations: RwLock<HashSet<String>>,

    /// When the pending batch was last flushed (set to "now" at construction).
    last_batch_flush: RwLock<Instant>,
}
64
/// Notification broadcast to subscribers of an `InvalidationManager`.
#[derive(Debug, Clone)]
pub enum InvalidationEvent {
    /// The listed tables were invalidated; `invalidate_table` sends this with
    /// a single table name.
    Tables(Vec<String>),

    /// Presumably: specific cache keys were invalidated.
    /// NOTE(review): not sent by `InvalidationManager`'s own methods.
    Keys(Vec<CacheKey>),

    /// Presumably: everything should be invalidated.
    /// NOTE(review): not sent by `InvalidationManager`'s own methods.
    All,

    /// A change reported by the WAL stream; `on_wal_event` sends this before
    /// queueing the table for invalidation.
    WalEvent {
        /// Table named in the WAL event, as supplied by the caller.
        table: String,
        /// Kind of row change.
        operation: WalOperation,
        /// Presumably the WAL log sequence number of the change; passed
        /// through unmodified.
        lsn: u64,
    },
}
84
/// Kind of row change carried by a WAL event.
///
/// `WalEventParser::parse` maps the tokens `I`/`INSERT`, `U`/`UPDATE`,
/// `D`/`DELETE` and `T`/`TRUNCATE` onto these variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WalOperation {
    /// Row(s) inserted.
    Insert,
    /// Row(s) updated.
    Update,
    /// Row(s) deleted.
    Delete,
    /// Table truncated.
    Truncate,
}
93
94impl InvalidationManager {
95 pub fn new(config: InvalidationConfig) -> Self {
97 let (event_tx, _) = broadcast::channel(1024);
98
99 Self {
100 config,
101 table_keys: DashMap::new(),
102 key_tables: DashMap::new(),
103 last_invalidation: DashMap::new(),
104 event_tx,
105 wal_connected: std::sync::atomic::AtomicBool::new(false),
106 pending_invalidations: RwLock::new(HashSet::new()),
107 last_batch_flush: RwLock::new(Instant::now()),
108 }
109 }
110
111 pub fn register(&self, key: &CacheKey, table: &str) {
113 self.table_keys
115 .entry(table.to_string())
116 .or_insert_with(HashSet::new)
117 .insert(key.clone());
118
119 self.key_tables
121 .entry(key.clone())
122 .or_insert_with(HashSet::new)
123 .insert(table.to_string());
124 }
125
126 pub fn unregister(&self, key: &CacheKey) {
128 if let Some((_, tables)) = self.key_tables.remove(key) {
129 for table in tables {
130 if let Some(mut keys) = self.table_keys.get_mut(&table) {
131 keys.remove(key);
132 }
133 }
134 }
135 }
136
137 pub fn get_keys_for_table(&self, table: &str) -> Vec<CacheKey> {
139 self.table_keys
140 .get(table)
141 .map(|keys| keys.iter().cloned().collect())
142 .unwrap_or_default()
143 }
144
145 pub fn get_tables_for_key(&self, key: &CacheKey) -> Vec<String> {
147 self.key_tables
148 .get(key)
149 .map(|tables| tables.iter().cloned().collect())
150 .unwrap_or_default()
151 }
152
153 pub fn invalidate_table(&self, table: &str) {
155 self.last_invalidation.insert(table.to_string(), Instant::now());
157
158 let _ = self.event_tx.send(InvalidationEvent::Tables(vec![table.to_string()]));
160
161 if let Some((_, keys)) = self.table_keys.remove(table) {
163 for key in keys {
164 if let Some(mut tables) = self.key_tables.get_mut(&key) {
165 tables.remove(table);
166 }
167 }
168 }
169 }
170
171 pub fn invalidate_tables(&self, tables: &[String]) {
173 for table in tables {
174 self.invalidate_table(table);
175 }
176 }
177
178 pub fn queue_invalidation(&self, table: &str) {
180 if let Ok(mut pending) = self.pending_invalidations.write() {
181 pending.insert(table.to_string());
182 }
183
184 self.maybe_flush_batch();
186 }
187
188 pub fn flush_pending(&self) {
190 let tables: Vec<String> = {
191 let mut pending = match self.pending_invalidations.write() {
192 Ok(p) => p,
193 Err(_) => return,
194 };
195
196 let tables: Vec<String> = pending.drain().collect();
197 tables
198 };
199
200 if !tables.is_empty() {
201 self.invalidate_tables(&tables);
202 }
203
204 if let Ok(mut last) = self.last_batch_flush.write() {
205 *last = Instant::now();
206 }
207 }
208
209 fn maybe_flush_batch(&self) {
211 let should_flush = {
212 let last = match self.last_batch_flush.read() {
213 Ok(l) => *l,
214 Err(_) => return,
215 };
216
217 let pending_count = self.pending_invalidations
218 .read()
219 .map(|p| p.len())
220 .unwrap_or(0);
221
222 pending_count >= 100 || last.elapsed() > Duration::from_millis(50)
224 };
225
226 if should_flush {
227 self.flush_pending();
228 }
229 }
230
231 pub fn on_wal_event(&self, table: &str, operation: WalOperation, lsn: u64) {
233 let _ = self.event_tx.send(InvalidationEvent::WalEvent {
235 table: table.to_string(),
236 operation,
237 lsn,
238 });
239
240 self.queue_invalidation(table);
242 }
243
244 pub fn subscribe(&self) -> broadcast::Receiver<InvalidationEvent> {
246 self.event_tx.subscribe()
247 }
248
249 pub fn is_wal_connected(&self) -> bool {
251 self.wal_connected.load(std::sync::atomic::Ordering::Relaxed)
252 }
253
254 pub fn set_wal_connected(&self, connected: bool) {
256 self.wal_connected.store(connected, std::sync::atomic::Ordering::Relaxed);
257 }
258
259 pub fn mode(&self) -> InvalidationMode {
261 self.config.mode
262 }
263
264 pub fn last_invalidation_time(&self, table: &str) -> Option<Instant> {
266 self.last_invalidation.get(table).map(|t| *t)
267 }
268
269 pub fn stats(&self) -> InvalidationStats {
271 let total_keys: usize = self.table_keys
272 .iter()
273 .map(|e| e.value().len())
274 .sum();
275
276 InvalidationStats {
277 tracked_tables: self.table_keys.len(),
278 tracked_keys: total_keys,
279 pending_invalidations: self.pending_invalidations
280 .read()
281 .map(|p| p.len())
282 .unwrap_or(0),
283 wal_connected: self.is_wal_connected(),
284 mode: self.config.mode,
285 }
286 }
287
288 pub fn clear(&self) {
290 self.table_keys.clear();
291 self.key_tables.clear();
292 self.last_invalidation.clear();
293
294 if let Ok(mut pending) = self.pending_invalidations.write() {
295 pending.clear();
296 }
297 }
298}
299
/// Point-in-time snapshot produced by `InvalidationManager::stats`.
#[derive(Debug, Clone)]
pub struct InvalidationStats {
    /// Number of tables in the table -> keys dependency index.
    pub tracked_tables: usize,

    /// Number of tracked cache keys, as computed by `InvalidationManager::stats`.
    pub tracked_keys: usize,

    /// Number of tables queued for batched invalidation but not yet flushed.
    pub pending_invalidations: usize,

    /// WAL connection status at the time of the snapshot.
    pub wal_connected: bool,

    /// The manager's configured invalidation mode.
    pub mode: InvalidationMode,
}
318
319pub struct WalEventParser;
321
322impl WalEventParser {
323 pub fn parse(message: &[u8]) -> Option<(String, WalOperation, u64)> {
325 let text = std::str::from_utf8(message).ok()?;
327 let parts: Vec<&str> = text.split(':').collect();
328
329 if parts.len() < 3 {
330 return None;
331 }
332
333 let operation = match parts[0].to_uppercase().as_str() {
334 "I" | "INSERT" => WalOperation::Insert,
335 "U" | "UPDATE" => WalOperation::Update,
336 "D" | "DELETE" => WalOperation::Delete,
337 "T" | "TRUNCATE" => WalOperation::Truncate,
338 _ => return None,
339 };
340
341 let table = parts[1].to_string();
342 let lsn = parts[2].parse().ok()?;
343
344 Some((table, operation, lsn))
345 }
346
347 pub fn extract_affected_tables(sql: &str) -> Vec<String> {
349 let sql_upper = sql.to_uppercase();
350 let mut tables = Vec::new();
351
352 let patterns = [
354 (r"INSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_]*)", 1),
355 (r"UPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)", 1),
356 (r"DELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)", 1),
357 (r"TRUNCATE\s+(?:TABLE\s+)?([a-zA-Z_][a-zA-Z0-9_]*)", 1),
358 ];
359
360 for (pattern, group) in patterns {
361 if let Ok(re) = regex::Regex::new(pattern) {
362 for cap in re.captures_iter(&sql_upper) {
363 if let Some(m) = cap.get(group) {
364 tables.push(m.as_str().to_lowercase());
365 }
366 }
367 }
368 }
369
370 tables
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache key that is distinct per `hash`; the other parts are fixed.
    fn create_key(hash: u64) -> CacheKey {
        CacheKey::from_parts(hash, "test".to_string(), None, None)
    }

    /// A manager with the default configuration.
    fn new_manager() -> InvalidationManager {
        InvalidationManager::new(InvalidationConfig::default())
    }

    #[test]
    fn test_register_and_lookup() {
        let manager = new_manager();

        let key1 = create_key(111);
        let key2 = create_key(222);

        manager.register(&key1, "users");
        manager.register(&key2, "users");
        manager.register(&key1, "sessions");

        let user_keys = manager.get_keys_for_table("users");
        assert_eq!(user_keys.len(), 2);
        assert!(user_keys.contains(&key1));
        assert!(user_keys.contains(&key2));

        let key1_tables = manager.get_tables_for_key(&key1);
        assert_eq!(key1_tables.len(), 2);
        assert!(key1_tables.contains(&"users".to_string()));
        assert!(key1_tables.contains(&"sessions".to_string()));
    }

    #[test]
    fn test_unregister() {
        let manager = new_manager();

        let key = create_key(111);
        manager.register(&key, "users");
        manager.register(&key, "sessions");

        manager.unregister(&key);

        assert!(manager.get_keys_for_table("users").is_empty());
        assert!(manager.get_keys_for_table("sessions").is_empty());
        assert!(manager.get_tables_for_key(&key).is_empty());
    }

    #[test]
    fn test_invalidate_table() {
        let manager = new_manager();

        let key1 = create_key(111);
        let key2 = create_key(222);

        manager.register(&key1, "users");
        manager.register(&key2, "orders");

        manager.invalidate_table("users");

        assert!(manager.get_keys_for_table("users").is_empty());
        assert!(manager.get_tables_for_key(&key1).is_empty());
        assert!(!manager.get_keys_for_table("orders").is_empty());
        assert!(manager.last_invalidation_time("users").is_some());
        assert!(manager.last_invalidation_time("orders").is_none());
    }

    #[test]
    fn test_queue_and_flush() {
        let manager = new_manager();

        let key = create_key(111);
        manager.register(&key, "users");

        // `queue_invalidation` flushes on its own once the previous flush is
        // older than the batch window. Restart the window so the entry is
        // reliably still queued, even on a slow test machine.
        *manager.last_batch_flush.write().unwrap() = Instant::now();
        manager.queue_invalidation("users");

        {
            let pending = manager.pending_invalidations.read().unwrap();
            assert!(pending.contains("users"));
        }

        manager.flush_pending();

        {
            let pending = manager.pending_invalidations.read().unwrap();
            assert!(pending.is_empty());
        }

        assert!(manager.get_keys_for_table("users").is_empty());
    }

    #[test]
    fn test_wal_event() {
        let manager = new_manager();

        let key = create_key(111);
        manager.register(&key, "users");

        let mut receiver = manager.subscribe();

        manager.on_wal_event("users", WalOperation::Update, 12345);
        manager.flush_pending();

        // The raw WAL event is broadcast first...
        match receiver.try_recv() {
            Ok(InvalidationEvent::WalEvent { table, operation, lsn }) => {
                assert_eq!(table, "users");
                assert_eq!(operation, WalOperation::Update);
                assert_eq!(lsn, 12345);
            }
            other => panic!("expected WalEvent, got {:?}", other),
        }

        // ...followed by exactly one table invalidation. It is flushed either
        // by the batching logic or by the explicit `flush_pending` above.
        match receiver.try_recv() {
            Ok(InvalidationEvent::Tables(tables)) => assert_eq!(tables, vec!["users"]),
            other => panic!("expected Tables, got {:?}", other),
        }
        assert!(receiver.try_recv().is_err());

        assert!(manager.get_keys_for_table("users").is_empty());
    }

    #[test]
    fn test_stats() {
        let manager = new_manager();

        let key1 = create_key(111);
        let key2 = create_key(222);

        manager.register(&key1, "users");
        manager.register(&key2, "orders");

        let stats = manager.stats();
        assert_eq!(stats.tracked_tables, 2);
        assert_eq!(stats.tracked_keys, 2);
    }

    #[test]
    fn test_wal_event_parser() {
        let (table, op, lsn) = WalEventParser::parse(b"INSERT:users:12345").unwrap();
        assert_eq!(table, "users");
        assert_eq!(op, WalOperation::Insert);
        assert_eq!(lsn, 12345);

        let (table, op, _) = WalEventParser::parse(b"U:orders:67890").unwrap();
        assert_eq!(table, "orders");
        assert_eq!(op, WalOperation::Update);

        assert!(WalEventParser::parse(b"X:users:1").is_none());
        assert!(WalEventParser::parse(b"I:users").is_none());
        assert!(WalEventParser::parse(b"I:users:not-a-number").is_none());
    }

    #[test]
    fn test_extract_affected_tables() {
        let tests = vec![
            ("INSERT INTO users VALUES (1)", vec!["users"]),
            ("UPDATE orders SET status = 'done'", vec!["orders"]),
            ("DELETE FROM sessions WHERE expired", vec!["sessions"]),
            ("TRUNCATE TABLE logs", vec!["logs"]),
            ("TRUNCATE products", vec!["products"]),
        ];

        for (sql, expected) in tests {
            let tables = WalEventParser::extract_affected_tables(sql);
            assert_eq!(tables, expected, "Failed for SQL: {}", sql);
        }
    }

    #[test]
    fn test_clear() {
        let manager = new_manager();

        manager.register(&create_key(111), "users");
        manager.queue_invalidation("users");

        manager.clear();

        let stats = manager.stats();
        assert_eq!(stats.tracked_tables, 0);
        assert_eq!(stats.tracked_keys, 0);
        assert_eq!(stats.pending_invalidations, 0);
    }
}