Skip to main content

heliosdb_proxy/distribcache/
invalidator.rs

1//! WAL-based cache invalidator
2//!
3//! Subscribes to WAL stream for real-time cache coherency.
4//! Invalidates cached entries when underlying data changes.
5//!
6//! # Protocol
7//!
//! The WAL streaming protocol uses TCP with the following frame format
//! (all integers are big-endian):
//! ```text
//! [1 byte: message type][4 bytes: payload length (u32)][payload]
//! ```
12//!
//! Message types:
//! - 0x01: WAL entry
//! - 0x02: Heartbeat
//! - 0x03: Subscription request
//! - 0x04: Subscription ack
//! - 0xFF: Error (e.g. subscription rejected by the server)
18
use dashmap::DashMap;
use std::collections::HashSet;
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use super::{CacheError, CacheResult, DistribCacheConfig, QueryFingerprint};
28
/// Frame type tag: the first byte of every message on the WAL stream.
///
/// Discriminants are the on-wire byte values, so `variant as u8` yields the
/// tag to write or compare against.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
enum WalMessageType {
    /// A WAL record frame (`0x01`).
    Entry = 0x01,
    /// Keep-alive frame (`0x02`).
    Heartbeat = 0x02,
    /// Client request to start streaming from a given LSN (`0x03`).
    Subscribe = 0x03,
    /// Server acknowledgement of a subscription (`0x04`).
    SubscribeAck = 0x04,
    /// Server-reported failure, e.g. a rejected subscription (`0xFF`).
    Error = 0xFF,
}
39
/// Kind of change carried by a single WAL record.
#[derive(Clone, Debug)]
pub enum WalOperation {
    /// A row was inserted or updated under `key`; `value` is the payload that followed it.
    Put { key: Vec<u8>, value: Vec<u8> },
    /// The row stored under `key` was removed.
    Delete { key: Vec<u8> },
    /// A counter update for `table_name`, carrying `counter`.
    UpdateCounter { table_name: String, counter: u64 },
    /// The schema of `table_name` changed.
    SchemaChange { table_name: String },
    /// Transaction `txn_id` committed.
    Commit { txn_id: u64 },
}
54
55/// WAL entry
56#[derive(Debug, Clone)]
57pub struct WalEntry {
58    /// Log sequence number
59    pub lsn: u64,
60    /// Timestamp
61    pub timestamp: u64,
62    /// Operation
63    pub operation: WalOperation,
64}
65
/// Client side of the WAL streaming protocol: owns the TCP connection until a
/// subscription is created from it.
pub struct WalStreamer {
    /// Endpoint the streamer connects to, written as `host:port`.
    #[allow(dead_code)]
    endpoint: String,
    /// Open socket, if any; moved into the `WalSubscription` by `subscribe`.
    connection: Option<TcpStream>,
    /// Stop flag shared with subscriptions created from this streamer.
    running: Arc<AtomicBool>,
    /// LSN tracker (not read anywhere in this file).
    #[allow(dead_code)]
    current_lsn: AtomicU64,
    /// Heartbeat timestamp; in this file it is only refreshed when a
    /// connection is (re)established.
    last_heartbeat: Instant,
    /// Reconnect attempts made since the last successful connection.
    #[allow(dead_code)]
    reconnect_attempts: u32,
    /// Upper bound on `reconnect_attempts` before giving up.
    #[allow(dead_code)]
    max_reconnect_attempts: u32,
    /// Base back-off, multiplied by the attempt number between reconnects.
    #[allow(dead_code)]
    reconnect_delay: Duration,
}
90
91impl WalStreamer {
92    fn new(endpoint: &str) -> Self {
93        Self {
94            endpoint: endpoint.to_string(),
95            connection: None,
96            running: Arc::new(AtomicBool::new(false)),
97            current_lsn: AtomicU64::new(0),
98            last_heartbeat: Instant::now(),
99            reconnect_attempts: 0,
100            max_reconnect_attempts: 10,
101            reconnect_delay: Duration::from_secs(1),
102        }
103    }
104
105    /// Connect to the WAL streaming endpoint
106    async fn connect(endpoint: &str) -> CacheResult<Self> {
107        let mut streamer = Self::new(endpoint);
108
109        // Attempt TCP connection
110        match TcpStream::connect_timeout(
111            &endpoint
112                .parse()
113                .map_err(|_| CacheError::ConnectionError("Invalid endpoint address".to_string()))?,
114            Duration::from_secs(5),
115        ) {
116            Ok(stream) => {
117                // Set TCP options
118                stream.set_read_timeout(Some(Duration::from_secs(30))).ok();
119                stream.set_write_timeout(Some(Duration::from_secs(5))).ok();
120                stream.set_nodelay(true).ok();
121
122                streamer.connection = Some(stream);
123                streamer.last_heartbeat = Instant::now();
124                Ok(streamer)
125            }
126            Err(e) => {
127                // Return a disconnected streamer - can be reconnected later
128                tracing::warn!("Failed to connect to WAL endpoint {}: {}", endpoint, e);
129                Ok(streamer)
130            }
131        }
132    }
133
134    /// Send subscription request and start receiving WAL entries
135    async fn subscribe(&mut self, start_lsn: Option<u64>) -> CacheResult<WalSubscription> {
136        self.running.store(true, Ordering::SeqCst);
137
138        if let Some(ref mut stream) = self.connection {
139            // Build subscription request
140            let lsn = start_lsn.unwrap_or(0);
141            let mut request = vec![WalMessageType::Subscribe as u8];
142            request.extend_from_slice(&(8u32).to_be_bytes()); // payload length
143            request.extend_from_slice(&lsn.to_be_bytes()); // start LSN
144
145            // Send subscription request
146            if let Err(e) = stream.write_all(&request) {
147                tracing::error!("Failed to send subscription request: {}", e);
148                return Err(CacheError::ConnectionError(format!(
149                    "Subscription failed: {}",
150                    e
151                )));
152            }
153
154            // Wait for subscription ack
155            let mut header = [0u8; 5];
156            match stream.read_exact(&mut header) {
157                Ok(_) => {
158                    if header[0] == WalMessageType::SubscribeAck as u8 {
159                        tracing::info!("WAL subscription acknowledged");
160                    } else if header[0] == WalMessageType::Error as u8 {
161                        return Err(CacheError::ConnectionError(
162                            "Subscription rejected by server".to_string(),
163                        ));
164                    }
165                }
166                Err(e) => {
167                    tracing::warn!("No subscription ack received: {}", e);
168                }
169            }
170        }
171
172        Ok(WalSubscription {
173            running: self.running.clone(),
174            connection: self.connection.take(),
175            current_lsn: 0,
176            buffer: Vec::with_capacity(64 * 1024),
177        })
178    }
179
180    /// Attempt to reconnect to the WAL endpoint
181    #[allow(dead_code)]
182    async fn reconnect(&mut self) -> CacheResult<bool> {
183        if self.reconnect_attempts >= self.max_reconnect_attempts {
184            return Ok(false);
185        }
186
187        self.reconnect_attempts += 1;
188        let delay = self.reconnect_delay * self.reconnect_attempts;
189        tokio::time::sleep(delay).await;
190
191        tracing::info!(
192            "Attempting WAL reconnection (attempt {})",
193            self.reconnect_attempts
194        );
195
196        match TcpStream::connect_timeout(
197            &self
198                .endpoint
199                .parse()
200                .map_err(|_| CacheError::ConnectionError("Invalid endpoint".to_string()))?,
201            Duration::from_secs(5),
202        ) {
203            Ok(stream) => {
204                stream.set_read_timeout(Some(Duration::from_secs(30))).ok();
205                stream.set_write_timeout(Some(Duration::from_secs(5))).ok();
206                stream.set_nodelay(true).ok();
207
208                self.connection = Some(stream);
209                self.reconnect_attempts = 0;
210                self.last_heartbeat = Instant::now();
211                tracing::info!("WAL reconnection successful");
212                Ok(true)
213            }
214            Err(e) => {
215                tracing::warn!("WAL reconnection failed: {}", e);
216                Ok(false)
217            }
218        }
219    }
220
221    #[allow(dead_code)]
222    fn disconnect(&mut self) {
223        self.running.store(false, Ordering::SeqCst);
224        if let Some(stream) = self.connection.take() {
225            drop(stream);
226        }
227    }
228
229    /// Check if connected
230    #[allow(dead_code)]
231    fn is_connected(&self) -> bool {
232        self.connection.is_some()
233    }
234}
235
/// Active WAL subscription created by `WalStreamer::subscribe`.
///
/// Decoded entries are pulled from it with `next()`.
pub struct WalSubscription {
    /// Stop flag shared with the streamer; cleared on fatal read errors.
    running: Arc<AtomicBool>,
    /// Socket entries are read from (`None` if subscribed while disconnected).
    connection: Option<TcpStream>,
    /// LSN of the most recently returned entry (0 until one is returned).
    current_lsn: u64,
    /// Reusable scratch buffer for entry payloads.
    buffer: Vec<u8>,
}
243
244impl WalSubscription {
245    /// Receive next WAL entry from the stream (non-recursive loop-based implementation)
246    pub async fn next(&mut self) -> Option<WalEntry> {
247        loop {
248            if !self.running.load(Ordering::SeqCst) {
249                return None;
250            }
251
252            let stream = match self.connection.as_mut() {
253                Some(s) => s,
254                None => {
255                    // No connection - sleep and return None
256                    tokio::time::sleep(Duration::from_millis(100)).await;
257                    return None;
258                }
259            };
260
261            // Read message header: [type: 1 byte][length: 4 bytes]
262            let mut header = [0u8; 5];
263            match stream.read_exact(&mut header) {
264                Ok(_) => {}
265                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
266                    // No data available, yield and retry
267                    tokio::time::sleep(Duration::from_millis(10)).await;
268                    continue;
269                }
270                Err(e) if e.kind() == std::io::ErrorKind::TimedOut => {
271                    // Timeout - this is normal, just continue
272                    return None;
273                }
274                Err(_) => {
275                    // Connection error
276                    self.running.store(false, Ordering::SeqCst);
277                    return None;
278                }
279            }
280
281            let msg_type = header[0];
282            let payload_len =
283                u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
284
285            // Handle heartbeat messages
286            if msg_type == WalMessageType::Heartbeat as u8 {
287                // Heartbeat has no payload or small status payload
288                if payload_len > 0 {
289                    let mut _payload = vec![0u8; payload_len];
290                    let _ = stream.read_exact(&mut _payload);
291                }
292                continue; // Skip heartbeats
293            }
294
295            // Only process WAL entries
296            if msg_type != WalMessageType::Entry as u8 {
297                // Skip unknown message types
298                if payload_len > 0 && payload_len < 1024 * 1024 {
299                    let mut skip = vec![0u8; payload_len];
300                    let _ = stream.read_exact(&mut skip);
301                }
302                continue;
303            }
304
305            // Read WAL entry payload
306            if payload_len == 0 || payload_len > 10 * 1024 * 1024 {
307                // Invalid payload size
308                return None;
309            }
310
311            self.buffer.resize(payload_len, 0);
312            if stream.read_exact(&mut self.buffer).is_err() {
313                self.running.store(false, Ordering::SeqCst);
314                return None;
315            }
316
317            // Parse WAL entry from payload
318            // Format: [lsn: 8 bytes][timestamp: 8 bytes][op_type: 1 byte][data...]
319            if self.buffer.len() < 17 {
320                continue; // Invalid entry, skip
321            }
322
323            let lsn = u64::from_be_bytes([
324                self.buffer[0],
325                self.buffer[1],
326                self.buffer[2],
327                self.buffer[3],
328                self.buffer[4],
329                self.buffer[5],
330                self.buffer[6],
331                self.buffer[7],
332            ]);
333            let timestamp = u64::from_be_bytes([
334                self.buffer[8],
335                self.buffer[9],
336                self.buffer[10],
337                self.buffer[11],
338                self.buffer[12],
339                self.buffer[13],
340                self.buffer[14],
341                self.buffer[15],
342            ]);
343            let op_type = self.buffer[16];
344            let data = &self.buffer[17..];
345
346            let operation = match op_type {
347                0x01 => {
348                    // Put operation: [key_len: 4][key][value]
349                    if data.len() < 4 {
350                        continue;
351                    }
352                    let key_len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
353                    if data.len() < 4 + key_len {
354                        continue;
355                    }
356                    let key = data[4..4 + key_len].to_vec();
357                    let value = data[4 + key_len..].to_vec();
358                    WalOperation::Put { key, value }
359                }
360                0x02 => {
361                    // Delete operation: [key_len: 4][key]
362                    if data.len() < 4 {
363                        continue;
364                    }
365                    let key_len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
366                    if data.len() < 4 + key_len {
367                        continue;
368                    }
369                    let key = data[4..4 + key_len].to_vec();
370                    WalOperation::Delete { key }
371                }
372                0x03 => {
373                    // Counter update: [table_name_len: 2][table_name][counter: 8]
374                    if data.len() < 10 {
375                        continue;
376                    }
377                    let name_len = u16::from_be_bytes([data[0], data[1]]) as usize;
378                    if data.len() < 2 + name_len + 8 {
379                        continue;
380                    }
381                    let table_name = String::from_utf8_lossy(&data[2..2 + name_len]).to_string();
382                    let counter_offset = 2 + name_len;
383                    let counter = u64::from_be_bytes([
384                        data[counter_offset],
385                        data[counter_offset + 1],
386                        data[counter_offset + 2],
387                        data[counter_offset + 3],
388                        data[counter_offset + 4],
389                        data[counter_offset + 5],
390                        data[counter_offset + 6],
391                        data[counter_offset + 7],
392                    ]);
393                    WalOperation::UpdateCounter {
394                        table_name,
395                        counter,
396                    }
397                }
398                0x04 => {
399                    // Schema change: [table_name_len: 2][table_name]
400                    if data.len() < 2 {
401                        continue;
402                    }
403                    let name_len = u16::from_be_bytes([data[0], data[1]]) as usize;
404                    if data.len() < 2 + name_len {
405                        continue;
406                    }
407                    let table_name = String::from_utf8_lossy(&data[2..2 + name_len]).to_string();
408                    WalOperation::SchemaChange { table_name }
409                }
410                0x05 => {
411                    // Commit: [txn_id: 8]
412                    if data.len() < 8 {
413                        continue;
414                    }
415                    let txn_id = u64::from_be_bytes([
416                        data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
417                    ]);
418                    WalOperation::Commit { txn_id }
419                }
420                _ => {
421                    // Unknown operation type, skip
422                    continue;
423                }
424            };
425
426            self.current_lsn = lsn;
427            return Some(WalEntry {
428                lsn,
429                timestamp,
430                operation,
431            });
432        }
433    }
434
435    /// Get current LSN position
436    pub fn current_lsn(&self) -> u64 {
437        self.current_lsn
438    }
439
440    /// Check if subscription is active
441    pub fn is_active(&self) -> bool {
442        self.running.load(Ordering::SeqCst) && self.connection.is_some()
443    }
444}
445
/// What to invalidate: a whole table, or a single row within it.
#[derive(Clone, Debug)]
pub struct InvalidationTarget {
    /// Name of the affected table.
    pub table: String,
    /// Key of the changed row, when the change is row-scoped.
    pub row_key: Option<Vec<u8>>,
    /// `true` when every cached entry for `table` must be dropped
    /// (e.g. after a schema change).
    pub invalidate_all: bool,
}

/// Shared, thread-safe hook invoked once per [`InvalidationTarget`].
pub type InvalidationCallback = Arc<dyn Fn(&InvalidationTarget) + Sync + Send>;
459
460/// WAL-based cache invalidator
461pub struct WalInvalidator {
462    /// Configuration
463    #[allow(dead_code)]
464    config: DistribCacheConfig,
465
466    /// WAL stream
467    wal_stream: Option<WalStreamer>,
468
469    /// Active WAL subscription
470    subscription: tokio::sync::RwLock<Option<WalSubscription>>,
471
472    /// Table to fingerprint index
473    table_index: DashMap<String, HashSet<QueryFingerprint>>,
474
475    /// Invalidation callbacks
476    callbacks: Arc<tokio::sync::RwLock<Vec<InvalidationCallback>>>,
477
478    /// Running flag
479    running: Arc<AtomicBool>,
480
481    /// Last processed LSN
482    last_lsn: AtomicU64,
483
484    /// Statistics
485    stats: InvalidatorStats,
486}
487
/// Lock-free counters backing the public `InvalidatorStatsSnapshot`.
#[derive(Default, Debug)]
struct InvalidatorStats {
    /// WAL entries handed to the dispatcher.
    wal_entries_processed: AtomicU64,
    /// Table-level invalidations issued.
    tables_invalidated: AtomicU64,
    /// Registered fingerprints covered by table-level invalidations.
    entries_invalidated: AtomicU64,
    /// Milliseconds spent dispatching the most recent WAL entry.
    invalidation_lag_ms: AtomicU64,
}
496
497impl WalInvalidator {
498    /// Create a new invalidator
499    pub fn new(config: DistribCacheConfig) -> Self {
500        Self {
501            config,
502            wal_stream: None,
503            subscription: tokio::sync::RwLock::new(None),
504            table_index: DashMap::new(),
505            callbacks: Arc::new(tokio::sync::RwLock::new(Vec::new())),
506            running: Arc::new(AtomicBool::new(false)),
507            last_lsn: AtomicU64::new(0),
508            stats: InvalidatorStats::default(),
509        }
510    }
511
512    /// Start the invalidator - connects to WAL endpoint and begins processing
513    pub async fn start(&self, wal_endpoint: &str) -> CacheResult<()> {
514        if self.running.load(Ordering::SeqCst) {
515            return Ok(()); // Already running
516        }
517
518        self.running.store(true, Ordering::SeqCst);
519
520        // Connect to WAL streaming endpoint
521        let mut streamer = WalStreamer::connect(wal_endpoint).await?;
522
523        // Start subscription from last known LSN (for recovery)
524        let start_lsn = self.last_lsn.load(Ordering::Relaxed);
525        let start_lsn = if start_lsn > 0 { Some(start_lsn) } else { None };
526
527        match streamer.subscribe(start_lsn).await {
528            Ok(sub) => {
529                *self.subscription.write().await = Some(sub);
530                tracing::info!("WAL invalidator started, connected to {}", wal_endpoint);
531            }
532            Err(e) => {
533                tracing::warn!(
534                    "Failed to subscribe to WAL stream: {}. Running in degraded mode.",
535                    e
536                );
537                // Still mark as running - can accept manual invalidations
538            }
539        }
540
541        Ok(())
542    }
543
544    /// Start the WAL processing loop in a background task
545    pub fn start_processing(&self) -> tokio::task::JoinHandle<()> {
546        let running = self.running.clone();
547        let _subscription = self.subscription.write();
548        let _callbacks = self.callbacks.clone();
549        let _stats = InvalidatorStats {
550            wal_entries_processed: AtomicU64::new(0),
551            tables_invalidated: AtomicU64::new(0),
552            entries_invalidated: AtomicU64::new(0),
553            invalidation_lag_ms: AtomicU64::new(0),
554        };
555        let _table_index = self.table_index.clone();
556        let _last_lsn = AtomicU64::new(self.last_lsn.load(Ordering::Relaxed));
557
558        tokio::spawn(async move {
559            tracing::info!("WAL processing loop started");
560
561            // Note: This is a simplified version - in production you'd pass
562            // the subscription handle properly
563            while running.load(Ordering::SeqCst) {
564                // Sleep to avoid busy loop when no subscription
565                tokio::time::sleep(Duration::from_millis(100)).await;
566            }
567
568            tracing::info!("WAL processing loop stopped");
569        })
570    }
571
572    /// Process WAL entries in the current task (blocking)
573    pub async fn process_loop(&self) {
574        while self.running.load(Ordering::SeqCst) {
575            let entry = {
576                let mut sub_guard = self.subscription.write().await;
577                if let Some(ref mut sub) = *sub_guard {
578                    sub.next().await
579                } else {
580                    None
581                }
582            };
583
584            match entry {
585                Some(wal_entry) => {
586                    let start = Instant::now();
587                    self.process_wal_entry(wal_entry.clone()).await;
588                    self.last_lsn.store(wal_entry.lsn, Ordering::Relaxed);
589
590                    // Track invalidation lag
591                    let lag = start.elapsed().as_millis() as u64;
592                    self.stats.invalidation_lag_ms.store(lag, Ordering::Relaxed);
593                }
594                None => {
595                    // No entry available, brief sleep
596                    tokio::time::sleep(Duration::from_millis(10)).await;
597                }
598            }
599        }
600    }
601
602    /// Stop the invalidator
603    pub async fn stop(&self) {
604        self.running.store(false, Ordering::SeqCst);
605
606        // Close subscription
607        let mut sub = self.subscription.write().await;
608        *sub = None;
609
610        if let Some(_stream) = self.wal_stream.as_ref().map(|_| ()) {
611            // Stream cleanup handled by drop
612        }
613
614        tracing::info!("WAL invalidator stopped");
615    }
616
617    /// Check if running
618    pub fn is_running(&self) -> bool {
619        self.running.load(Ordering::SeqCst)
620    }
621
622    /// Register a cache fingerprint for a table
623    pub fn register(&self, table: &str, fingerprint: QueryFingerprint) {
624        self.table_index
625            .entry(table.to_string())
626            .or_default()
627            .insert(fingerprint);
628    }
629
630    /// Unregister a fingerprint
631    pub fn unregister(&self, table: &str, fingerprint: &QueryFingerprint) {
632        if let Some(mut set) = self.table_index.get_mut(table) {
633            set.remove(fingerprint);
634        }
635    }
636
637    /// Add invalidation callback
638    pub async fn add_callback(&self, callback: InvalidationCallback) {
639        self.callbacks.write().await.push(callback);
640    }
641
642    /// Add callback (sync version for compatibility)
643    pub fn add_callback_sync(&self, callback: InvalidationCallback) {
644        // Use blocking for sync context
645        if let Ok(handle) = tokio::runtime::Handle::try_current() {
646            handle.block_on(async {
647                self.callbacks.write().await.push(callback);
648            });
649        }
650    }
651
652    /// Process a WAL entry
653    async fn process_wal_entry(&self, entry: WalEntry) {
654        self.stats
655            .wal_entries_processed
656            .fetch_add(1, Ordering::Relaxed);
657
658        let (table, row_key) = match &entry.operation {
659            WalOperation::Put { key, .. } => (self.extract_table(key), Some(key.clone())),
660            WalOperation::Delete { key } => (self.extract_table(key), Some(key.clone())),
661            WalOperation::UpdateCounter { table_name, .. } => (Some(table_name.clone()), None),
662            WalOperation::SchemaChange { table_name } => {
663                // Schema changes invalidate all entries for the table
664                self.invalidate_table(table_name, true).await;
665                return;
666            }
667            WalOperation::Commit { .. } => return,
668        };
669
670        if let Some(table) = table {
671            // Use fine-grained invalidation if row key available
672            if let Some(key) = row_key {
673                self.invalidate_row(&table, &key).await;
674            } else {
675                self.invalidate_table(&table, false).await;
676            }
677        }
678    }
679
680    /// Invalidate entries for a table
681    async fn invalidate_table(&self, table: &str, all_entries: bool) {
682        self.stats
683            .tables_invalidated
684            .fetch_add(1, Ordering::Relaxed);
685
686        let target = InvalidationTarget {
687            table: table.to_string(),
688            row_key: None,
689            invalidate_all: all_entries,
690        };
691
692        // Notify callbacks
693        let callbacks = self.callbacks.read().await;
694        for callback in callbacks.iter() {
695            callback(&target);
696        }
697
698        // Track invalidated entries
699        if let Some(entries) = self.table_index.get(table) {
700            self.stats
701                .entries_invalidated
702                .fetch_add(entries.len() as u64, Ordering::Relaxed);
703        }
704    }
705
706    /// Fine-grained row invalidation
707    async fn invalidate_row(&self, table: &str, row_key: &[u8]) {
708        let target = InvalidationTarget {
709            table: table.to_string(),
710            row_key: Some(row_key.to_vec()),
711            invalidate_all: false,
712        };
713
714        let callbacks = self.callbacks.read().await;
715        for callback in callbacks.iter() {
716            callback(&target);
717        }
718    }
719
720    /// Manually invalidate a table (public API)
721    pub async fn invalidate_table_manual(&self, table: &str, all_entries: bool) {
722        self.invalidate_table(table, all_entries).await;
723    }
724
725    /// Extract table name from key
726    fn extract_table(&self, key: &[u8]) -> Option<String> {
727        // Key format assumed: "table:primary_key"
728        let key_str = String::from_utf8_lossy(key);
729        key_str.split(':').next().map(|s| s.to_string())
730    }
731
732    /// Get invalidator statistics
733    pub fn stats(&self) -> InvalidatorStatsSnapshot {
734        InvalidatorStatsSnapshot {
735            wal_entries_processed: self.stats.wal_entries_processed.load(Ordering::Relaxed),
736            tables_invalidated: self.stats.tables_invalidated.load(Ordering::Relaxed),
737            entries_invalidated: self.stats.entries_invalidated.load(Ordering::Relaxed),
738            invalidation_lag_ms: self.stats.invalidation_lag_ms.load(Ordering::Relaxed),
739            registered_tables: self.table_index.len(),
740            is_running: self.running.load(Ordering::Relaxed),
741        }
742    }
743}
744
/// Point-in-time copy of the invalidator's counters, as returned by `stats()`.
#[derive(Clone, Debug)]
pub struct InvalidatorStatsSnapshot {
    /// WAL entries handed to the dispatcher.
    pub wal_entries_processed: u64,
    /// Table-level invalidations issued.
    pub tables_invalidated: u64,
    /// Registered fingerprints covered by table-level invalidations.
    pub entries_invalidated: u64,
    /// Milliseconds spent dispatching the most recent WAL entry.
    pub invalidation_lag_ms: u64,
    /// Number of tables present in the fingerprint index.
    pub registered_tables: usize,
    /// Whether the invalidator was running when the snapshot was taken.
    pub is_running: bool,
}
755
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an invalidator with the default cache configuration.
    fn fresh_invalidator() -> WalInvalidator {
        WalInvalidator::new(DistribCacheConfig::default())
    }

    #[test]
    fn test_register_fingerprint() {
        let invalidator = fresh_invalidator();
        let fingerprint = QueryFingerprint::from_query("SELECT * FROM users");

        invalidator.register("users", fingerprint.clone());

        assert!(invalidator.table_index.contains_key("users"));
        let registered = invalidator.table_index.get("users").unwrap();
        assert!(registered.contains(&fingerprint));
    }

    #[test]
    fn test_unregister_fingerprint() {
        let invalidator = fresh_invalidator();
        let fingerprint = QueryFingerprint::from_query("SELECT * FROM users");

        invalidator.register("users", fingerprint.clone());
        invalidator.unregister("users", &fingerprint);

        // The table entry remains; only the fingerprint is removed from its set.
        let registered = invalidator.table_index.get("users").unwrap();
        assert!(!registered.contains(&fingerprint));
    }

    #[tokio::test]
    async fn test_callback_invocation() {
        let invalidator = fresh_invalidator();
        let fired = Arc::new(AtomicBool::new(false));
        let fired_in_callback = Arc::clone(&fired);

        let callback: InvalidationCallback = Arc::new(move |target: &InvalidationTarget| {
            if target.table == "users" {
                fired_in_callback.store(true, Ordering::SeqCst);
            }
        });
        invalidator.add_callback(callback).await;
        invalidator.invalidate_table_manual("users", false).await;

        assert!(fired.load(Ordering::SeqCst));
    }

    #[test]
    fn test_extract_table() {
        let table = fresh_invalidator().extract_table(b"users:123");
        assert_eq!(table, Some("users".to_string()));
    }

    #[tokio::test]
    async fn test_process_wal_entry() {
        let invalidator = fresh_invalidator();
        invalidator.register("users", QueryFingerprint::from_query("SELECT * FROM users"));

        let put = WalEntry {
            lsn: 1,
            timestamp: 0,
            operation: WalOperation::Put {
                key: b"users:123".to_vec(),
                value: b"data".to_vec(),
            },
        };
        invalidator.process_wal_entry(put).await;

        assert_eq!(invalidator.stats().wal_entries_processed, 1);
    }

    #[tokio::test]
    async fn test_start_stop() {
        let invalidator = fresh_invalidator();

        // Nothing listens on this port: the connect fails, but start still succeeds.
        invalidator.start("127.0.0.1:59999").await.unwrap();
        assert!(invalidator.is_running());

        invalidator.stop().await;
        assert!(!invalidator.is_running());
    }

    #[test]
    fn test_wal_entry_parsing() {
        let ops = [
            WalOperation::Put {
                key: b"users:1".to_vec(),
                value: b"data".to_vec(),
            },
            WalOperation::Delete {
                key: b"users:1".to_vec(),
            },
            WalOperation::UpdateCounter {
                table_name: "users".to_string(),
                counter: 100,
            },
            WalOperation::SchemaChange {
                table_name: "users".to_string(),
            },
            WalOperation::Commit { txn_id: 12345 },
        ];

        assert!(matches!(ops[0], WalOperation::Put { .. }));
        assert!(matches!(ops[1], WalOperation::Delete { .. }));
        assert!(matches!(ops[2], WalOperation::UpdateCounter { .. }));
        assert!(matches!(ops[3], WalOperation::SchemaChange { .. }));
        assert!(matches!(ops[4], WalOperation::Commit { .. }));
    }

    #[tokio::test]
    async fn test_invalidation_stats() {
        let invalidator = fresh_invalidator();

        // Two distinct query templates, so normalization does not merge them.
        let queries = [
            "SELECT * FROM users WHERE id = ?",
            "SELECT name FROM users WHERE email = ?",
        ];
        for query in queries.iter() {
            invalidator.register("users", QueryFingerprint::from_query(query));
        }

        invalidator.invalidate_table_manual("users", false).await;

        let stats = invalidator.stats();
        assert_eq!(stats.tables_invalidated, 1);
        assert_eq!(stats.entries_invalidated, 2);
    }
}