1use std::collections::HashSet;
9use std::sync::RwLock;
10use std::time::{Duration, Instant};
11
12use dashmap::DashMap;
13use tokio::sync::broadcast;
14
15use super::config::InvalidationConfig;
16use super::result::CacheKey;
17
/// Strategy selector for cache invalidation.
///
/// The mode is stored in `InvalidationConfig` and reported back by
/// `InvalidationManager::mode` and `InvalidationStats::mode`; nothing in this
/// file branches on it. NOTE(review): the per-variant descriptions below are
/// inferred from the variant names — confirm against the code that consumes
/// the mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum InvalidationMode {
    /// Presumably: invalidate from WAL change events. This is the `Default`
    /// variant.
    #[default]
    Wal,

    /// Presumably: rely only on time-to-live expiry of cache entries.
    TtlOnly,

    /// Presumably: invalidate only when explicitly requested by callers.
    ManualOnly,

    /// Presumably: WAL-driven invalidation, with TTL expiry as a fallback
    /// (e.g. while the WAL stream is unavailable).
    WalWithTtlFallback,
}
34
/// Tracks which cached results depend on which tables and fans out
/// invalidation notifications to subscribers.
///
/// Dependencies are kept in two indexes (`table_keys` and `key_tables`) that
/// are intended to mirror each other. Invalidations can be applied
/// immediately (`invalidate_table`) or queued and flushed in batches
/// (`queue_invalidation` / `flush_pending`).
#[derive(Debug)]
pub struct InvalidationManager {
    // Configuration passed to `new`; only `mode` is read in this file.
    config: InvalidationConfig,

    // Forward index: table name -> cache keys that depend on it.
    table_keys: DashMap<String, HashSet<CacheKey>>,

    // Reverse index: cache key -> table names it depends on.
    key_tables: DashMap<CacheKey, HashSet<String>>,

    // Table name -> time of its most recent `invalidate_table` call.
    last_invalidation: DashMap<String, Instant>,

    // Broadcast sender for invalidation events; receivers come from `subscribe`.
    event_tx: broadcast::Sender<InvalidationEvent>,

    // Set externally via `set_wal_connected`; only reported, never acted on here.
    wal_connected: std::sync::atomic::AtomicBool,

    // Tables queued by `queue_invalidation` / `on_wal_event`, awaiting a flush.
    pending_invalidations: RwLock<HashSet<String>>,

    // When `flush_pending` last ran (construction time until the first flush).
    last_batch_flush: RwLock<Instant>,
}
64
/// Notification broadcast to receivers obtained from
/// `InvalidationManager::subscribe`.
#[derive(Debug, Clone)]
pub enum InvalidationEvent {
    /// The listed tables were invalidated. `InvalidationManager::invalidate_table`
    /// sends this with exactly one table name per event.
    Tables(Vec<String>),

    /// Specific cache keys were invalidated. Not emitted anywhere in this
    /// file — NOTE(review): presumably produced by other components; confirm.
    Keys(Vec<CacheKey>),

    /// Everything was invalidated. Not emitted anywhere in this file —
    /// NOTE(review): presumably produced by other components; confirm.
    All,

    /// A raw WAL change notification, forwarded by
    /// `InvalidationManager::on_wal_event` before the table is queued for
    /// invalidation.
    WalEvent {
        /// Table named in the WAL message.
        table: String,
        /// Kind of change.
        operation: WalOperation,
        /// Log sequence number carried by the WAL message.
        lsn: u64,
    },
}
84
/// Kind of row change reported by a WAL message.
///
/// `WalEventParser::parse` maps `I`/`INSERT`, `U`/`UPDATE`, `D`/`DELETE` and
/// `T`/`TRUNCATE` (case-insensitively) onto these variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WalOperation {
    Insert,
    Update,
    Delete,
    Truncate,
}
93
94impl InvalidationManager {
95 pub fn new(config: InvalidationConfig) -> Self {
97 let (event_tx, _) = broadcast::channel(1024);
98
99 Self {
100 config,
101 table_keys: DashMap::new(),
102 key_tables: DashMap::new(),
103 last_invalidation: DashMap::new(),
104 event_tx,
105 wal_connected: std::sync::atomic::AtomicBool::new(false),
106 pending_invalidations: RwLock::new(HashSet::new()),
107 last_batch_flush: RwLock::new(Instant::now()),
108 }
109 }
110
111 pub fn register(&self, key: &CacheKey, table: &str) {
113 self.table_keys
115 .entry(table.to_string())
116 .or_default()
117 .insert(key.clone());
118
119 self.key_tables
121 .entry(key.clone())
122 .or_default()
123 .insert(table.to_string());
124 }
125
126 pub fn unregister(&self, key: &CacheKey) {
128 if let Some((_, tables)) = self.key_tables.remove(key) {
129 for table in tables {
130 if let Some(mut keys) = self.table_keys.get_mut(&table) {
131 keys.remove(key);
132 }
133 }
134 }
135 }
136
137 pub fn get_keys_for_table(&self, table: &str) -> Vec<CacheKey> {
139 self.table_keys
140 .get(table)
141 .map(|keys| keys.iter().cloned().collect())
142 .unwrap_or_default()
143 }
144
145 pub fn get_tables_for_key(&self, key: &CacheKey) -> Vec<String> {
147 self.key_tables
148 .get(key)
149 .map(|tables| tables.iter().cloned().collect())
150 .unwrap_or_default()
151 }
152
153 pub fn invalidate_table(&self, table: &str) {
155 self.last_invalidation
157 .insert(table.to_string(), Instant::now());
158
159 let _ = self
161 .event_tx
162 .send(InvalidationEvent::Tables(vec![table.to_string()]));
163
164 if let Some((_, keys)) = self.table_keys.remove(table) {
166 for key in keys {
167 if let Some(mut tables) = self.key_tables.get_mut(&key) {
168 tables.remove(table);
169 }
170 }
171 }
172 }
173
174 pub fn invalidate_tables(&self, tables: &[String]) {
176 for table in tables {
177 self.invalidate_table(table);
178 }
179 }
180
181 pub fn queue_invalidation(&self, table: &str) {
183 if let Ok(mut pending) = self.pending_invalidations.write() {
184 pending.insert(table.to_string());
185 }
186
187 self.maybe_flush_batch();
189 }
190
191 pub fn flush_pending(&self) {
193 let tables: Vec<String> = {
194 let mut pending = match self.pending_invalidations.write() {
195 Ok(p) => p,
196 Err(_) => return,
197 };
198
199 let tables: Vec<String> = pending.drain().collect();
200 tables
201 };
202
203 if !tables.is_empty() {
204 self.invalidate_tables(&tables);
205 }
206
207 if let Ok(mut last) = self.last_batch_flush.write() {
208 *last = Instant::now();
209 }
210 }
211
212 fn maybe_flush_batch(&self) {
214 let should_flush = {
215 let last = match self.last_batch_flush.read() {
216 Ok(l) => *l,
217 Err(_) => return,
218 };
219
220 let pending_count = self
221 .pending_invalidations
222 .read()
223 .map(|p| p.len())
224 .unwrap_or(0);
225
226 pending_count >= 100 || last.elapsed() > Duration::from_millis(50)
228 };
229
230 if should_flush {
231 self.flush_pending();
232 }
233 }
234
235 pub fn on_wal_event(&self, table: &str, operation: WalOperation, lsn: u64) {
237 let _ = self.event_tx.send(InvalidationEvent::WalEvent {
239 table: table.to_string(),
240 operation,
241 lsn,
242 });
243
244 self.queue_invalidation(table);
246 }
247
248 pub fn subscribe(&self) -> broadcast::Receiver<InvalidationEvent> {
250 self.event_tx.subscribe()
251 }
252
253 pub fn is_wal_connected(&self) -> bool {
255 self.wal_connected
256 .load(std::sync::atomic::Ordering::Relaxed)
257 }
258
259 pub fn set_wal_connected(&self, connected: bool) {
261 self.wal_connected
262 .store(connected, std::sync::atomic::Ordering::Relaxed);
263 }
264
265 pub fn mode(&self) -> InvalidationMode {
267 self.config.mode
268 }
269
270 pub fn last_invalidation_time(&self, table: &str) -> Option<Instant> {
272 self.last_invalidation.get(table).map(|t| *t)
273 }
274
275 pub fn stats(&self) -> InvalidationStats {
277 let total_keys: usize = self.table_keys.iter().map(|e| e.value().len()).sum();
278
279 InvalidationStats {
280 tracked_tables: self.table_keys.len(),
281 tracked_keys: total_keys,
282 pending_invalidations: self
283 .pending_invalidations
284 .read()
285 .map(|p| p.len())
286 .unwrap_or(0),
287 wal_connected: self.is_wal_connected(),
288 mode: self.config.mode,
289 }
290 }
291
292 pub fn clear(&self) {
294 self.table_keys.clear();
295 self.key_tables.clear();
296 self.last_invalidation.clear();
297
298 if let Ok(mut pending) = self.pending_invalidations.write() {
299 pending.clear();
300 }
301 }
302}
303
/// Point-in-time snapshot returned by `InvalidationManager::stats`.
#[derive(Debug, Clone)]
pub struct InvalidationStats {
    /// Number of tables that currently have a dependency entry.
    pub tracked_tables: usize,

    /// Number of tracked cache keys (see `InvalidationManager::stats` for how
    /// this is counted).
    pub tracked_keys: usize,

    /// Number of tables queued for batched invalidation but not yet flushed.
    pub pending_invalidations: usize,

    /// Last WAL connection state reported via `set_wal_connected`.
    pub wal_connected: bool,

    /// Configured invalidation mode.
    pub mode: InvalidationMode,
}
322
/// Stateless helpers for turning WAL messages and SQL text into table names.
pub struct WalEventParser;
325
326impl WalEventParser {
327 pub fn parse(message: &[u8]) -> Option<(String, WalOperation, u64)> {
329 let text = std::str::from_utf8(message).ok()?;
331 let parts: Vec<&str> = text.split(':').collect();
332
333 if parts.len() < 3 {
334 return None;
335 }
336
337 let operation = match parts[0].to_uppercase().as_str() {
338 "I" | "INSERT" => WalOperation::Insert,
339 "U" | "UPDATE" => WalOperation::Update,
340 "D" | "DELETE" => WalOperation::Delete,
341 "T" | "TRUNCATE" => WalOperation::Truncate,
342 _ => return None,
343 };
344
345 let table = parts[1].to_string();
346 let lsn = parts[2].parse().ok()?;
347
348 Some((table, operation, lsn))
349 }
350
351 pub fn extract_affected_tables(sql: &str) -> Vec<String> {
353 let sql_upper = sql.to_uppercase();
354 let mut tables = Vec::new();
355
356 let patterns = [
358 (r"INSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_]*)", 1),
359 (r"UPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)", 1),
360 (r"DELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)", 1),
361 (r"TRUNCATE\s+(?:TABLE\s+)?([a-zA-Z_][a-zA-Z0-9_]*)", 1),
362 ];
363
364 for (pattern, group) in patterns {
365 if let Ok(re) = regex::Regex::new(pattern) {
366 for cap in re.captures_iter(&sql_upper) {
367 if let Some(m) = cap.get(group) {
368 tables.push(m.as_str().to_lowercase());
369 }
370 }
371 }
372 }
373
374 tables
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache key that differs from other test keys only by `hash`.
    fn create_key(hash: u64) -> CacheKey {
        CacheKey::from_parts(hash, "test".to_string(), None, None)
    }

    /// Manager built from the default configuration.
    fn fresh_manager() -> InvalidationManager {
        InvalidationManager::new(InvalidationConfig::default())
    }

    #[test]
    fn test_register_and_lookup() {
        let mgr = fresh_manager();
        let first = create_key(111);
        let second = create_key(222);

        for (key, table) in [(&first, "users"), (&second, "users"), (&first, "sessions")] {
            mgr.register(key, table);
        }

        let in_users = mgr.get_keys_for_table("users");
        assert_eq!(in_users.len(), 2);
        for expected in [&first, &second] {
            assert!(in_users.contains(expected));
        }

        let first_tables = mgr.get_tables_for_key(&first);
        assert_eq!(first_tables.len(), 2);
        for expected in ["users", "sessions"] {
            assert!(first_tables.iter().any(|t| t.as_str() == expected));
        }
    }

    #[test]
    fn test_unregister() {
        let mgr = fresh_manager();
        let key = create_key(111);

        for table in ["users", "sessions"] {
            mgr.register(&key, table);
        }
        mgr.unregister(&key);

        for table in ["users", "sessions"] {
            assert!(mgr.get_keys_for_table(table).is_empty());
        }
        assert!(mgr.get_tables_for_key(&key).is_empty());
    }

    #[test]
    fn test_invalidate_table() {
        let mgr = fresh_manager();
        mgr.register(&create_key(111), "users");
        mgr.register(&create_key(222), "orders");

        mgr.invalidate_table("users");

        assert!(mgr.get_keys_for_table("users").is_empty());
        assert!(!mgr.get_keys_for_table("orders").is_empty());
        assert!(mgr.last_invalidation_time("users").is_some());
    }

    #[test]
    fn test_queue_and_flush() {
        let mgr = fresh_manager();
        mgr.register(&create_key(111), "users");

        mgr.queue_invalidation("users");
        // Each read guard is a temporary released at the end of its statement,
        // so it cannot deadlock with the write lock taken by `flush_pending`.
        assert!(mgr.pending_invalidations.read().unwrap().contains("users"));

        mgr.flush_pending();
        assert!(mgr.pending_invalidations.read().unwrap().is_empty());

        assert!(mgr.get_keys_for_table("users").is_empty());
    }

    #[test]
    fn test_wal_event() {
        let mgr = fresh_manager();
        mgr.register(&create_key(111), "users");

        let mut rx = mgr.subscribe();
        mgr.on_wal_event("users", WalOperation::Update, 12345);
        mgr.flush_pending();

        assert!(rx.try_recv().is_ok());
    }

    #[test]
    fn test_stats() {
        let mgr = fresh_manager();
        mgr.register(&create_key(111), "users");
        mgr.register(&create_key(222), "orders");

        let snapshot = mgr.stats();
        assert_eq!(snapshot.tracked_tables, 2);
        assert_eq!(snapshot.tracked_keys, 2);
    }

    #[test]
    fn test_wal_event_parser() {
        let (table, op, lsn) = WalEventParser::parse(b"INSERT:users:12345").unwrap();
        assert_eq!(table, "users");
        assert_eq!(op, WalOperation::Insert);
        assert_eq!(lsn, 12345);

        let (table, op, _) = WalEventParser::parse(b"U:orders:67890").unwrap();
        assert_eq!(table, "orders");
        assert_eq!(op, WalOperation::Update);
    }

    #[test]
    fn test_extract_affected_tables() {
        let cases = [
            ("INSERT INTO users VALUES (1)", "users"),
            ("UPDATE orders SET status = 'done'", "orders"),
            ("DELETE FROM sessions WHERE expired", "sessions"),
            ("TRUNCATE TABLE logs", "logs"),
            ("TRUNCATE products", "products"),
        ];

        for (sql, table) in cases {
            let found = WalEventParser::extract_affected_tables(sql);
            assert_eq!(found, vec![table], "Failed for SQL: {}", sql);
        }
    }

    #[test]
    fn test_clear() {
        let mgr = fresh_manager();
        mgr.register(&create_key(111), "users");
        mgr.queue_invalidation("users");

        mgr.clear();

        let snapshot = mgr.stats();
        assert_eq!(snapshot.tracked_tables, 0);
        assert_eq!(snapshot.tracked_keys, 0);
        assert_eq!(snapshot.pending_invalidations, 0);
    }
}