Skip to main content

heliosdb_proxy/distribcache/
invalidator.rs

//! WAL-based cache invalidator
//!
//! Subscribes to a WAL stream to keep the cache coherent in real time:
//! cached entries are invalidated when their underlying data changes.
//!
//! # Protocol
//!
//! The WAL streaming protocol runs over TCP. Every message is framed as:
//! ```text
//! [1 byte: message type][4 bytes: big-endian payload length][payload]
//! ```
//!
//! Message types:
//! - 0x01: WAL entry
//! - 0x02: Heartbeat
//! - 0x03: Subscription request (payload: 8-byte big-endian start LSN)
//! - 0x04: Subscription ack
//! - 0xFF: Error (e.g. subscription rejected by the server)

19use dashmap::DashMap;
20use std::collections::HashSet;
21use std::io::{Read, Write};
22use std::net::TcpStream;
23use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26
27use super::{CacheError, CacheResult, DistribCacheConfig, QueryFingerprint};
28
/// Frame type tag carried in the first byte of every WAL protocol message.
///
/// The discriminants are the on-the-wire encoding, so they must match the
/// values the WAL server sends and expects.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
enum WalMessageType {
    /// A WAL record.
    Entry = 0x01,
    /// Keep-alive frame; its payload (if any) is skipped by the reader.
    Heartbeat = 0x02,
    /// Client request to start streaming from a given LSN.
    Subscribe = 0x03,
    /// Server acknowledgement of a subscription request.
    SubscribeAck = 0x04,
    /// Server-reported error; treated as a rejected subscription.
    Error = 0xFF,
}
39
/// A single change decoded from the WAL stream.
#[derive(Debug, Clone)]
pub enum WalOperation {
    /// A row was inserted or updated. `key` presumably has the
    /// `table:primary_key` shape the invalidator parses; `value` is the new data.
    Put {
        key: Vec<u8>,
        value: Vec<u8>,
    },
    /// The row identified by `key` was deleted.
    Delete {
        key: Vec<u8>,
    },
    /// The counter associated with `table_name` moved to `counter`.
    UpdateCounter {
        table_name: String,
        counter: u64,
    },
    /// The schema of `table_name` changed.
    SchemaChange {
        table_name: String,
    },
    /// Transaction `txn_id` committed.
    Commit {
        txn_id: u64,
    },
}

/// One WAL record: its log position, its timestamp, and the change it carries.
#[derive(Debug, Clone)]
pub struct WalEntry {
    /// Log sequence number of this record.
    pub lsn: u64,
    /// Timestamp supplied by the server (its unit is not defined in this module).
    pub timestamp: u64,
    /// The change this record describes.
    pub operation: WalOperation,
}
65
/// Client side of a WAL streaming connection.
///
/// Holds the TCP socket until `subscribe` hands it to a `WalSubscription`,
/// and keeps the bookkeeping used for reconnect back-off.
pub struct WalStreamer {
    /// Server address supplied by the caller (`host:port`).
    endpoint: String,
    /// Open socket, or `None` while disconnected.
    connection: Option<TcpStream>,
    /// Shared "keep streaming" flag; also cloned into the subscription.
    running: Arc<AtomicBool>,
    /// LSN position tracked by the streamer.
    current_lsn: AtomicU64,
    /// When the connection was last (re)established.
    last_heartbeat: Instant,
    /// Reconnect attempts made since the last successful connection.
    reconnect_attempts: u32,
    /// Attempts allowed before reconnecting gives up.
    max_reconnect_attempts: u32,
    /// Base back-off delay; the wait grows linearly with the attempt number.
    reconnect_delay: Duration,
}
85
86impl WalStreamer {
87    fn new(endpoint: &str) -> Self {
88        Self {
89            endpoint: endpoint.to_string(),
90            connection: None,
91            running: Arc::new(AtomicBool::new(false)),
92            current_lsn: AtomicU64::new(0),
93            last_heartbeat: Instant::now(),
94            reconnect_attempts: 0,
95            max_reconnect_attempts: 10,
96            reconnect_delay: Duration::from_secs(1),
97        }
98    }
99
100    /// Connect to the WAL streaming endpoint
101    async fn connect(endpoint: &str) -> CacheResult<Self> {
102        let mut streamer = Self::new(endpoint);
103
104        // Attempt TCP connection
105        match TcpStream::connect_timeout(
106            &endpoint.parse().map_err(|_| CacheError::ConnectionError("Invalid endpoint address".to_string()))?,
107            Duration::from_secs(5),
108        ) {
109            Ok(stream) => {
110                // Set TCP options
111                stream.set_read_timeout(Some(Duration::from_secs(30))).ok();
112                stream.set_write_timeout(Some(Duration::from_secs(5))).ok();
113                stream.set_nodelay(true).ok();
114
115                streamer.connection = Some(stream);
116                streamer.last_heartbeat = Instant::now();
117                Ok(streamer)
118            }
119            Err(e) => {
120                // Return a disconnected streamer - can be reconnected later
121                tracing::warn!("Failed to connect to WAL endpoint {}: {}", endpoint, e);
122                Ok(streamer)
123            }
124        }
125    }
126
127    /// Send subscription request and start receiving WAL entries
128    async fn subscribe(&mut self, start_lsn: Option<u64>) -> CacheResult<WalSubscription> {
129        self.running.store(true, Ordering::SeqCst);
130
131        if let Some(ref mut stream) = self.connection {
132            // Build subscription request
133            let lsn = start_lsn.unwrap_or(0);
134            let mut request = vec![WalMessageType::Subscribe as u8];
135            request.extend_from_slice(&(8u32).to_be_bytes()); // payload length
136            request.extend_from_slice(&lsn.to_be_bytes());     // start LSN
137
138            // Send subscription request
139            if let Err(e) = stream.write_all(&request) {
140                tracing::error!("Failed to send subscription request: {}", e);
141                return Err(CacheError::ConnectionError(format!("Subscription failed: {}", e)));
142            }
143
144            // Wait for subscription ack
145            let mut header = [0u8; 5];
146            match stream.read_exact(&mut header) {
147                Ok(_) => {
148                    if header[0] == WalMessageType::SubscribeAck as u8 {
149                        tracing::info!("WAL subscription acknowledged");
150                    } else if header[0] == WalMessageType::Error as u8 {
151                        return Err(CacheError::ConnectionError("Subscription rejected by server".to_string()));
152                    }
153                }
154                Err(e) => {
155                    tracing::warn!("No subscription ack received: {}", e);
156                }
157            }
158        }
159
160        Ok(WalSubscription {
161            running: self.running.clone(),
162            connection: self.connection.take(),
163            current_lsn: 0,
164            buffer: Vec::with_capacity(64 * 1024),
165        })
166    }
167
168    /// Attempt to reconnect to the WAL endpoint
169    async fn reconnect(&mut self) -> CacheResult<bool> {
170        if self.reconnect_attempts >= self.max_reconnect_attempts {
171            return Ok(false);
172        }
173
174        self.reconnect_attempts += 1;
175        let delay = self.reconnect_delay * self.reconnect_attempts;
176        tokio::time::sleep(delay).await;
177
178        tracing::info!("Attempting WAL reconnection (attempt {})", self.reconnect_attempts);
179
180        match TcpStream::connect_timeout(
181            &self.endpoint.parse().map_err(|_| CacheError::ConnectionError("Invalid endpoint".to_string()))?,
182            Duration::from_secs(5),
183        ) {
184            Ok(stream) => {
185                stream.set_read_timeout(Some(Duration::from_secs(30))).ok();
186                stream.set_write_timeout(Some(Duration::from_secs(5))).ok();
187                stream.set_nodelay(true).ok();
188
189                self.connection = Some(stream);
190                self.reconnect_attempts = 0;
191                self.last_heartbeat = Instant::now();
192                tracing::info!("WAL reconnection successful");
193                Ok(true)
194            }
195            Err(e) => {
196                tracing::warn!("WAL reconnection failed: {}", e);
197                Ok(false)
198            }
199        }
200    }
201
202    fn disconnect(&mut self) {
203        self.running.store(false, Ordering::SeqCst);
204        if let Some(stream) = self.connection.take() {
205            drop(stream);
206        }
207    }
208
209    /// Check if connected
210    fn is_connected(&self) -> bool {
211        self.connection.is_some()
212    }
213}
214
/// Receiving end of a WAL subscription.
///
/// Created by `WalStreamer::subscribe`, which hands over its socket and shares
/// its `running` flag with this value.
pub struct WalSubscription {
    /// Shared flag; cleared when reading fails or the streamer disconnects.
    running: Arc<AtomicBool>,
    /// Socket that frames are read from; `None` if there was no connection.
    connection: Option<TcpStream>,
    /// LSN of the most recently returned entry (initial value set by the creator).
    current_lsn: u64,
    /// Reusable buffer for entry payloads.
    buffer: Vec<u8>,
}
222
223impl WalSubscription {
224    /// Receive next WAL entry from the stream (non-recursive loop-based implementation)
225    pub async fn next(&mut self) -> Option<WalEntry> {
226        loop {
227            if !self.running.load(Ordering::SeqCst) {
228                return None;
229            }
230
231            let stream = match self.connection.as_mut() {
232                Some(s) => s,
233                None => {
234                    // No connection - sleep and return None
235                    tokio::time::sleep(Duration::from_millis(100)).await;
236                    return None;
237                }
238            };
239
240            // Read message header: [type: 1 byte][length: 4 bytes]
241            let mut header = [0u8; 5];
242            match stream.read_exact(&mut header) {
243                Ok(_) => {}
244                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
245                    // No data available, yield and retry
246                    tokio::time::sleep(Duration::from_millis(10)).await;
247                    continue;
248                }
249                Err(e) if e.kind() == std::io::ErrorKind::TimedOut => {
250                    // Timeout - this is normal, just continue
251                    return None;
252                }
253                Err(_) => {
254                    // Connection error
255                    self.running.store(false, Ordering::SeqCst);
256                    return None;
257                }
258            }
259
260            let msg_type = header[0];
261            let payload_len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
262
263            // Handle heartbeat messages
264            if msg_type == WalMessageType::Heartbeat as u8 {
265                // Heartbeat has no payload or small status payload
266                if payload_len > 0 {
267                    let mut _payload = vec![0u8; payload_len];
268                    let _ = stream.read_exact(&mut _payload);
269                }
270                continue; // Skip heartbeats
271            }
272
273            // Only process WAL entries
274            if msg_type != WalMessageType::Entry as u8 {
275                // Skip unknown message types
276                if payload_len > 0 && payload_len < 1024 * 1024 {
277                    let mut skip = vec![0u8; payload_len];
278                    let _ = stream.read_exact(&mut skip);
279                }
280                continue;
281            }
282
283            // Read WAL entry payload
284            if payload_len == 0 || payload_len > 10 * 1024 * 1024 {
285                // Invalid payload size
286                return None;
287            }
288
289            self.buffer.resize(payload_len, 0);
290            if stream.read_exact(&mut self.buffer).is_err() {
291                self.running.store(false, Ordering::SeqCst);
292                return None;
293            }
294
295            // Parse WAL entry from payload
296            // Format: [lsn: 8 bytes][timestamp: 8 bytes][op_type: 1 byte][data...]
297            if self.buffer.len() < 17 {
298                continue; // Invalid entry, skip
299            }
300
301            let lsn = u64::from_be_bytes([
302                self.buffer[0], self.buffer[1], self.buffer[2], self.buffer[3],
303                self.buffer[4], self.buffer[5], self.buffer[6], self.buffer[7],
304            ]);
305            let timestamp = u64::from_be_bytes([
306                self.buffer[8], self.buffer[9], self.buffer[10], self.buffer[11],
307                self.buffer[12], self.buffer[13], self.buffer[14], self.buffer[15],
308            ]);
309            let op_type = self.buffer[16];
310            let data = &self.buffer[17..];
311
312            let operation = match op_type {
313                0x01 => {
314                    // Put operation: [key_len: 4][key][value]
315                    if data.len() < 4 {
316                        continue;
317                    }
318                    let key_len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
319                    if data.len() < 4 + key_len {
320                        continue;
321                    }
322                    let key = data[4..4 + key_len].to_vec();
323                    let value = data[4 + key_len..].to_vec();
324                    WalOperation::Put { key, value }
325                }
326                0x02 => {
327                    // Delete operation: [key_len: 4][key]
328                    if data.len() < 4 {
329                        continue;
330                    }
331                    let key_len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
332                    if data.len() < 4 + key_len {
333                        continue;
334                    }
335                    let key = data[4..4 + key_len].to_vec();
336                    WalOperation::Delete { key }
337                }
338                0x03 => {
339                    // Counter update: [table_name_len: 2][table_name][counter: 8]
340                    if data.len() < 10 {
341                        continue;
342                    }
343                    let name_len = u16::from_be_bytes([data[0], data[1]]) as usize;
344                    if data.len() < 2 + name_len + 8 {
345                        continue;
346                    }
347                    let table_name = String::from_utf8_lossy(&data[2..2 + name_len]).to_string();
348                    let counter_offset = 2 + name_len;
349                    let counter = u64::from_be_bytes([
350                        data[counter_offset], data[counter_offset + 1],
351                        data[counter_offset + 2], data[counter_offset + 3],
352                        data[counter_offset + 4], data[counter_offset + 5],
353                        data[counter_offset + 6], data[counter_offset + 7],
354                    ]);
355                    WalOperation::UpdateCounter { table_name, counter }
356                }
357                0x04 => {
358                    // Schema change: [table_name_len: 2][table_name]
359                    if data.len() < 2 {
360                        continue;
361                    }
362                    let name_len = u16::from_be_bytes([data[0], data[1]]) as usize;
363                    if data.len() < 2 + name_len {
364                        continue;
365                    }
366                    let table_name = String::from_utf8_lossy(&data[2..2 + name_len]).to_string();
367                    WalOperation::SchemaChange { table_name }
368                }
369                0x05 => {
370                    // Commit: [txn_id: 8]
371                    if data.len() < 8 {
372                        continue;
373                    }
374                    let txn_id = u64::from_be_bytes([
375                        data[0], data[1], data[2], data[3],
376                        data[4], data[5], data[6], data[7],
377                    ]);
378                    WalOperation::Commit { txn_id }
379                }
380                _ => {
381                    // Unknown operation type, skip
382                    continue;
383                }
384            };
385
386            self.current_lsn = lsn;
387            return Some(WalEntry {
388                lsn,
389                timestamp,
390                operation,
391            });
392        }
393    }
394
395    /// Get current LSN position
396    pub fn current_lsn(&self) -> u64 {
397        self.current_lsn
398    }
399
400    /// Check if subscription is active
401    pub fn is_active(&self) -> bool {
402        self.running.load(Ordering::SeqCst) && self.connection.is_some()
403    }
404}
405
/// Describes what a cache should drop in response to a change.
#[derive(Debug, Clone)]
pub struct InvalidationTarget {
    /// Name of the affected table.
    pub table: String,
    /// Key of the affected row, when the invalidation is row-scoped.
    pub row_key: Option<Vec<u8>>,
    /// `true` when every cached entry for `table` must be dropped
    /// (the WAL processor sets this for schema changes).
    pub invalidate_all: bool,
}

/// Listener notified of every invalidation; shareable across threads.
pub type InvalidationCallback = Arc<dyn Fn(&InvalidationTarget) + Send + Sync>;
419
420/// WAL-based cache invalidator
421pub struct WalInvalidator {
422    /// Configuration
423    config: DistribCacheConfig,
424
425    /// WAL stream
426    wal_stream: Option<WalStreamer>,
427
428    /// Active WAL subscription
429    subscription: tokio::sync::RwLock<Option<WalSubscription>>,
430
431    /// Table to fingerprint index
432    table_index: DashMap<String, HashSet<QueryFingerprint>>,
433
434    /// Invalidation callbacks
435    callbacks: Arc<tokio::sync::RwLock<Vec<InvalidationCallback>>>,
436
437    /// Running flag
438    running: Arc<AtomicBool>,
439
440    /// Last processed LSN
441    last_lsn: AtomicU64,
442
443    /// Statistics
444    stats: InvalidatorStats,
445}
446
/// Internal atomic counters behind the public statistics snapshot.
#[derive(Default, Debug)]
struct InvalidatorStats {
    /// WAL entries handed to the invalidation logic.
    wal_entries_processed: AtomicU64,
    /// Table-level invalidations performed.
    tables_invalidated: AtomicU64,
    /// Registered fingerprints covered by table-level invalidations.
    entries_invalidated: AtomicU64,
    /// Processing time of the most recent WAL entry, in milliseconds.
    invalidation_lag_ms: AtomicU64,
}
455
456impl WalInvalidator {
457    /// Create a new invalidator
458    pub fn new(config: DistribCacheConfig) -> Self {
459        Self {
460            config,
461            wal_stream: None,
462            subscription: tokio::sync::RwLock::new(None),
463            table_index: DashMap::new(),
464            callbacks: Arc::new(tokio::sync::RwLock::new(Vec::new())),
465            running: Arc::new(AtomicBool::new(false)),
466            last_lsn: AtomicU64::new(0),
467            stats: InvalidatorStats::default(),
468        }
469    }
470
471    /// Start the invalidator - connects to WAL endpoint and begins processing
472    pub async fn start(&self, wal_endpoint: &str) -> CacheResult<()> {
473        if self.running.load(Ordering::SeqCst) {
474            return Ok(()); // Already running
475        }
476
477        self.running.store(true, Ordering::SeqCst);
478
479        // Connect to WAL streaming endpoint
480        let mut streamer = WalStreamer::connect(wal_endpoint).await?;
481
482        // Start subscription from last known LSN (for recovery)
483        let start_lsn = self.last_lsn.load(Ordering::Relaxed);
484        let start_lsn = if start_lsn > 0 { Some(start_lsn) } else { None };
485
486        match streamer.subscribe(start_lsn).await {
487            Ok(sub) => {
488                *self.subscription.write().await = Some(sub);
489                tracing::info!("WAL invalidator started, connected to {}", wal_endpoint);
490            }
491            Err(e) => {
492                tracing::warn!("Failed to subscribe to WAL stream: {}. Running in degraded mode.", e);
493                // Still mark as running - can accept manual invalidations
494            }
495        }
496
497        Ok(())
498    }
499
500    /// Start the WAL processing loop in a background task
501    pub fn start_processing(&self) -> tokio::task::JoinHandle<()> {
502        let running = self.running.clone();
503        let _subscription = self.subscription.write();
504        let _callbacks = self.callbacks.clone();
505        let _stats = InvalidatorStats {
506            wal_entries_processed: AtomicU64::new(0),
507            tables_invalidated: AtomicU64::new(0),
508            entries_invalidated: AtomicU64::new(0),
509            invalidation_lag_ms: AtomicU64::new(0),
510        };
511        let _table_index = self.table_index.clone();
512        let _last_lsn = AtomicU64::new(self.last_lsn.load(Ordering::Relaxed));
513
514        tokio::spawn(async move {
515            tracing::info!("WAL processing loop started");
516
517            // Note: This is a simplified version - in production you'd pass
518            // the subscription handle properly
519            while running.load(Ordering::SeqCst) {
520                // Sleep to avoid busy loop when no subscription
521                tokio::time::sleep(Duration::from_millis(100)).await;
522            }
523
524            tracing::info!("WAL processing loop stopped");
525        })
526    }
527
528    /// Process WAL entries in the current task (blocking)
529    pub async fn process_loop(&self) {
530        while self.running.load(Ordering::SeqCst) {
531            let entry = {
532                let mut sub_guard = self.subscription.write().await;
533                if let Some(ref mut sub) = *sub_guard {
534                    sub.next().await
535                } else {
536                    None
537                }
538            };
539
540            match entry {
541                Some(wal_entry) => {
542                    let start = Instant::now();
543                    self.process_wal_entry(wal_entry.clone()).await;
544                    self.last_lsn.store(wal_entry.lsn, Ordering::Relaxed);
545
546                    // Track invalidation lag
547                    let lag = start.elapsed().as_millis() as u64;
548                    self.stats.invalidation_lag_ms.store(lag, Ordering::Relaxed);
549                }
550                None => {
551                    // No entry available, brief sleep
552                    tokio::time::sleep(Duration::from_millis(10)).await;
553                }
554            }
555        }
556    }
557
558    /// Stop the invalidator
559    pub async fn stop(&self) {
560        self.running.store(false, Ordering::SeqCst);
561
562        // Close subscription
563        let mut sub = self.subscription.write().await;
564        *sub = None;
565
566        if let Some(_stream) = self.wal_stream.as_ref().map(|_| ()) {
567            // Stream cleanup handled by drop
568        }
569
570        tracing::info!("WAL invalidator stopped");
571    }
572
573    /// Check if running
574    pub fn is_running(&self) -> bool {
575        self.running.load(Ordering::SeqCst)
576    }
577
578    /// Register a cache fingerprint for a table
579    pub fn register(&self, table: &str, fingerprint: QueryFingerprint) {
580        self.table_index
581            .entry(table.to_string())
582            .or_default()
583            .insert(fingerprint);
584    }
585
586    /// Unregister a fingerprint
587    pub fn unregister(&self, table: &str, fingerprint: &QueryFingerprint) {
588        if let Some(mut set) = self.table_index.get_mut(table) {
589            set.remove(fingerprint);
590        }
591    }
592
593    /// Add invalidation callback
594    pub async fn add_callback(&self, callback: InvalidationCallback) {
595        self.callbacks.write().await.push(callback);
596    }
597
598    /// Add callback (sync version for compatibility)
599    pub fn add_callback_sync(&self, callback: InvalidationCallback) {
600        // Use blocking for sync context
601        if let Ok(handle) = tokio::runtime::Handle::try_current() {
602            handle.block_on(async {
603                self.callbacks.write().await.push(callback);
604            });
605        }
606    }
607
608    /// Process a WAL entry
609    async fn process_wal_entry(&self, entry: WalEntry) {
610        self.stats.wal_entries_processed.fetch_add(1, Ordering::Relaxed);
611
612        let (table, row_key) = match &entry.operation {
613            WalOperation::Put { key, .. } => (self.extract_table(key), Some(key.clone())),
614            WalOperation::Delete { key } => (self.extract_table(key), Some(key.clone())),
615            WalOperation::UpdateCounter { table_name, .. } => (Some(table_name.clone()), None),
616            WalOperation::SchemaChange { table_name } => {
617                // Schema changes invalidate all entries for the table
618                self.invalidate_table(table_name, true).await;
619                return;
620            }
621            WalOperation::Commit { .. } => return,
622        };
623
624        if let Some(table) = table {
625            // Use fine-grained invalidation if row key available
626            if let Some(key) = row_key {
627                self.invalidate_row(&table, &key).await;
628            } else {
629                self.invalidate_table(&table, false).await;
630            }
631        }
632    }
633
634    /// Invalidate entries for a table
635    async fn invalidate_table(&self, table: &str, all_entries: bool) {
636        self.stats.tables_invalidated.fetch_add(1, Ordering::Relaxed);
637
638        let target = InvalidationTarget {
639            table: table.to_string(),
640            row_key: None,
641            invalidate_all: all_entries,
642        };
643
644        // Notify callbacks
645        let callbacks = self.callbacks.read().await;
646        for callback in callbacks.iter() {
647            callback(&target);
648        }
649
650        // Track invalidated entries
651        if let Some(entries) = self.table_index.get(table) {
652            self.stats.entries_invalidated.fetch_add(
653                entries.len() as u64,
654                Ordering::Relaxed,
655            );
656        }
657    }
658
659    /// Fine-grained row invalidation
660    async fn invalidate_row(&self, table: &str, row_key: &[u8]) {
661        let target = InvalidationTarget {
662            table: table.to_string(),
663            row_key: Some(row_key.to_vec()),
664            invalidate_all: false,
665        };
666
667        let callbacks = self.callbacks.read().await;
668        for callback in callbacks.iter() {
669            callback(&target);
670        }
671    }
672
673    /// Manually invalidate a table (public API)
674    pub async fn invalidate_table_manual(&self, table: &str, all_entries: bool) {
675        self.invalidate_table(table, all_entries).await;
676    }
677
678    /// Extract table name from key
679    fn extract_table(&self, key: &[u8]) -> Option<String> {
680        // Key format assumed: "table:primary_key"
681        let key_str = String::from_utf8_lossy(key);
682        key_str.split(':').next().map(|s| s.to_string())
683    }
684
685    /// Get invalidator statistics
686    pub fn stats(&self) -> InvalidatorStatsSnapshot {
687        InvalidatorStatsSnapshot {
688            wal_entries_processed: self.stats.wal_entries_processed.load(Ordering::Relaxed),
689            tables_invalidated: self.stats.tables_invalidated.load(Ordering::Relaxed),
690            entries_invalidated: self.stats.entries_invalidated.load(Ordering::Relaxed),
691            invalidation_lag_ms: self.stats.invalidation_lag_ms.load(Ordering::Relaxed),
692            registered_tables: self.table_index.len(),
693            is_running: self.running.load(Ordering::Relaxed),
694        }
695    }
696}
697
/// Point-in-time copy of the invalidator's counters, as returned by
/// `WalInvalidator::stats`.
#[derive(Debug, Clone)]
pub struct InvalidatorStatsSnapshot {
    /// Total WAL entries processed.
    pub wal_entries_processed: u64,
    /// Table-level invalidations performed.
    pub tables_invalidated: u64,
    /// Registered fingerprints covered by table-level invalidations.
    pub entries_invalidated: u64,
    /// Processing time of the most recent WAL entry, in milliseconds.
    pub invalidation_lag_ms: u64,
    /// Tables present in the fingerprint index (including ones whose set is empty).
    pub registered_tables: usize,
    /// Whether the invalidator was running when the snapshot was taken.
    pub is_running: bool,
}
708
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_register_fingerprint() {
        let config = DistribCacheConfig::default();
        let invalidator = WalInvalidator::new(config);

        let fp = QueryFingerprint::from_query("SELECT * FROM users");
        invalidator.register("users", fp.clone());

        assert!(invalidator.table_index.contains_key("users"));
        let entries = invalidator.table_index.get("users").unwrap();
        assert!(entries.contains(&fp));
    }

    #[test]
    fn test_unregister_fingerprint() {
        let config = DistribCacheConfig::default();
        let invalidator = WalInvalidator::new(config);

        let fp = QueryFingerprint::from_query("SELECT * FROM users");
        invalidator.register("users", fp.clone());
        invalidator.unregister("users", &fp);

        let entries = invalidator.table_index.get("users").unwrap();
        assert!(!entries.contains(&fp));
    }

    #[tokio::test]
    async fn test_callback_invocation() {
        let config = DistribCacheConfig::default();
        let invalidator = WalInvalidator::new(config);

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        invalidator
            .add_callback(Arc::new(move |target| {
                if target.table == "users" {
                    called_clone.store(true, Ordering::SeqCst);
                }
            }))
            .await;

        invalidator.invalidate_table_manual("users", false).await;

        assert!(called.load(Ordering::SeqCst));
    }

    #[test]
    fn test_extract_table() {
        let config = DistribCacheConfig::default();
        let invalidator = WalInvalidator::new(config);

        let key = b"users:123";
        let table = invalidator.extract_table(key);
        assert_eq!(table, Some("users".to_string()));
    }

    #[tokio::test]
    async fn test_process_wal_entry() {
        let config = DistribCacheConfig::default();
        let invalidator = WalInvalidator::new(config);

        let fp = QueryFingerprint::from_query("SELECT * FROM users");
        invalidator.register("users", fp);

        let entry = WalEntry {
            lsn: 1,
            timestamp: 0,
            operation: WalOperation::Put {
                key: b"users:123".to_vec(),
                value: b"data".to_vec(),
            },
        };

        invalidator.process_wal_entry(entry).await;

        let stats = invalidator.stats();
        assert_eq!(stats.wal_entries_processed, 1);
    }

    #[tokio::test]
    async fn test_delete_invalidates_row() {
        let invalidator = WalInvalidator::new(DistribCacheConfig::default());

        let seen = Arc::new(std::sync::Mutex::new(Vec::new()));
        let sink = seen.clone();
        invalidator
            .add_callback(Arc::new(move |target: &InvalidationTarget| {
                sink.lock()
                    .unwrap()
                    .push((target.table.clone(), target.row_key.clone()));
            }))
            .await;

        let entry = WalEntry {
            lsn: 2,
            timestamp: 0,
            operation: WalOperation::Delete {
                key: b"orders:42".to_vec(),
            },
        };
        invalidator.process_wal_entry(entry).await;

        assert_eq!(
            *seen.lock().unwrap(),
            vec![("orders".to_string(), Some(b"orders:42".to_vec()))]
        );
    }

    #[tokio::test]
    async fn test_schema_change_invalidates_whole_table() {
        let invalidator = WalInvalidator::new(DistribCacheConfig::default());

        let seen = Arc::new(std::sync::Mutex::new(Vec::new()));
        let sink = seen.clone();
        invalidator
            .add_callback(Arc::new(move |target: &InvalidationTarget| {
                sink.lock()
                    .unwrap()
                    .push((target.table.clone(), target.invalidate_all));
            }))
            .await;

        let entry = WalEntry {
            lsn: 3,
            timestamp: 0,
            operation: WalOperation::SchemaChange {
                table_name: "users".to_string(),
            },
        };
        invalidator.process_wal_entry(entry).await;

        assert_eq!(*seen.lock().unwrap(), vec![("users".to_string(), true)]);
        assert_eq!(invalidator.stats().tables_invalidated, 1);
    }

    #[tokio::test]
    async fn test_start_stop() {
        let config = DistribCacheConfig::default();
        let invalidator = WalInvalidator::new(config);

        // Reserve a free local port, then release it so nothing listens there.
        // A hard-coded port could be in use by another process, in which case
        // the subscription handshake would block until the read timeout.
        let addr = {
            let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
            listener.local_addr().unwrap()
        };

        // Unreachable endpoint: the connection fails, but `start` still succeeds
        // in degraded mode.
        invalidator.start(&addr.to_string()).await.unwrap();
        assert!(invalidator.is_running());

        // Stop
        invalidator.stop().await;
        assert!(!invalidator.is_running());
    }

    #[test]
    fn test_wal_entry_parsing() {
        // Test WalOperation variants
        let put = WalOperation::Put {
            key: b"users:1".to_vec(),
            value: b"data".to_vec(),
        };
        assert!(matches!(put, WalOperation::Put { .. }));

        let delete = WalOperation::Delete {
            key: b"users:1".to_vec(),
        };
        assert!(matches!(delete, WalOperation::Delete { .. }));

        let counter = WalOperation::UpdateCounter {
            table_name: "users".to_string(),
            counter: 100,
        };
        assert!(matches!(counter, WalOperation::UpdateCounter { .. }));

        let schema = WalOperation::SchemaChange {
            table_name: "users".to_string(),
        };
        assert!(matches!(schema, WalOperation::SchemaChange { .. }));

        let commit = WalOperation::Commit { txn_id: 12345 };
        assert!(matches!(commit, WalOperation::Commit { .. }));
    }

    #[tokio::test]
    async fn test_invalidation_stats() {
        let config = DistribCacheConfig::default();
        let invalidator = WalInvalidator::new(config);

        // Register some fingerprints (different query templates to avoid normalization collision)
        let fp1 = QueryFingerprint::from_query("SELECT * FROM users WHERE id = ?");
        let fp2 = QueryFingerprint::from_query("SELECT name FROM users WHERE email = ?");
        invalidator.register("users", fp1);
        invalidator.register("users", fp2);

        // Invalidate table
        invalidator.invalidate_table_manual("users", false).await;

        let stats = invalidator.stats();
        assert_eq!(stats.tables_invalidated, 1);
        assert_eq!(stats.entries_invalidated, 2);
    }
}