1use dashmap::DashMap;
20use std::collections::HashSet;
21use std::io::{Read, Write};
22use std::net::TcpStream;
23use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26
27use super::{CacheError, CacheResult, DistribCacheConfig, QueryFingerprint};
28
/// Message-type tag stored in the first byte of a WAL frame header.
///
/// Each discriminant is the byte value used on the wire: the code in this
/// file casts a variant with `as u8` and compares it with the header byte.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum WalMessageType {
    /// Frame whose payload is decoded into a `WalEntry`.
    Entry = 0x01,
    /// Frame that the subscription reads and then ignores.
    Heartbeat = 0x02,
    /// Request sent by the streamer to begin streaming from an LSN.
    Subscribe = 0x03,
    /// Reply that confirms a `Subscribe` request.
    SubscribeAck = 0x04,
    /// Reply that makes a `Subscribe` request fail.
    Error = 0xFF,
}
39
/// A single operation decoded from a WAL entry payload.
#[derive(Clone, Debug)]
pub enum WalOperation {
    /// A write of `value` under `key`.
    Put {
        /// Raw key bytes.
        key: Vec<u8>,
        /// Raw value bytes.
        value: Vec<u8>,
    },
    /// Removal of `key`.
    Delete {
        /// Raw key bytes.
        key: Vec<u8>,
    },
    /// A counter update that names its table explicitly.
    UpdateCounter {
        /// Name of the affected table.
        table_name: String,
        /// Counter value carried by the entry.
        counter: u64,
    },
    /// A schema change on `table_name`.
    SchemaChange {
        /// Name of the affected table.
        table_name: String,
    },
    /// A commit marker carrying the transaction id.
    Commit {
        /// Transaction identifier.
        txn_id: u64,
    },
}
54
55#[derive(Debug, Clone)]
57pub struct WalEntry {
58 pub lsn: u64,
60 pub timestamp: u64,
62 pub operation: WalOperation,
64}
65
/// Client side of a WAL streaming connection.
///
/// Owns the TCP connection until a subscription is created, plus the
/// bookkeeping used for reconnection attempts.
pub struct WalStreamer {
    /// Socket address of the WAL endpoint, kept as text and parsed on connect.
    endpoint: String,
    /// The open connection, or `None` while disconnected.
    connection: Option<TcpStream>,
    /// Run flag shared with the subscriptions created from this streamer.
    running: Arc<AtomicBool>,
    /// LSN bookkeeping for the streamer.
    current_lsn: AtomicU64,
    /// Instant recorded at construction and on each successful (re)connect.
    last_heartbeat: Instant,
    /// Number of reconnect attempts made since the last successful connect.
    reconnect_attempts: u32,
    /// Reconnect attempts allowed before `reconnect` gives up.
    max_reconnect_attempts: u32,
    /// Base delay; `reconnect` multiplies it by the attempt number.
    reconnect_delay: Duration,
}
85
86impl WalStreamer {
87 fn new(endpoint: &str) -> Self {
88 Self {
89 endpoint: endpoint.to_string(),
90 connection: None,
91 running: Arc::new(AtomicBool::new(false)),
92 current_lsn: AtomicU64::new(0),
93 last_heartbeat: Instant::now(),
94 reconnect_attempts: 0,
95 max_reconnect_attempts: 10,
96 reconnect_delay: Duration::from_secs(1),
97 }
98 }
99
100 async fn connect(endpoint: &str) -> CacheResult<Self> {
102 let mut streamer = Self::new(endpoint);
103
104 match TcpStream::connect_timeout(
106 &endpoint.parse().map_err(|_| CacheError::ConnectionError("Invalid endpoint address".to_string()))?,
107 Duration::from_secs(5),
108 ) {
109 Ok(stream) => {
110 stream.set_read_timeout(Some(Duration::from_secs(30))).ok();
112 stream.set_write_timeout(Some(Duration::from_secs(5))).ok();
113 stream.set_nodelay(true).ok();
114
115 streamer.connection = Some(stream);
116 streamer.last_heartbeat = Instant::now();
117 Ok(streamer)
118 }
119 Err(e) => {
120 tracing::warn!("Failed to connect to WAL endpoint {}: {}", endpoint, e);
122 Ok(streamer)
123 }
124 }
125 }
126
127 async fn subscribe(&mut self, start_lsn: Option<u64>) -> CacheResult<WalSubscription> {
129 self.running.store(true, Ordering::SeqCst);
130
131 if let Some(ref mut stream) = self.connection {
132 let lsn = start_lsn.unwrap_or(0);
134 let mut request = vec![WalMessageType::Subscribe as u8];
135 request.extend_from_slice(&(8u32).to_be_bytes()); request.extend_from_slice(&lsn.to_be_bytes()); if let Err(e) = stream.write_all(&request) {
140 tracing::error!("Failed to send subscription request: {}", e);
141 return Err(CacheError::ConnectionError(format!("Subscription failed: {}", e)));
142 }
143
144 let mut header = [0u8; 5];
146 match stream.read_exact(&mut header) {
147 Ok(_) => {
148 if header[0] == WalMessageType::SubscribeAck as u8 {
149 tracing::info!("WAL subscription acknowledged");
150 } else if header[0] == WalMessageType::Error as u8 {
151 return Err(CacheError::ConnectionError("Subscription rejected by server".to_string()));
152 }
153 }
154 Err(e) => {
155 tracing::warn!("No subscription ack received: {}", e);
156 }
157 }
158 }
159
160 Ok(WalSubscription {
161 running: self.running.clone(),
162 connection: self.connection.take(),
163 current_lsn: 0,
164 buffer: Vec::with_capacity(64 * 1024),
165 })
166 }
167
168 async fn reconnect(&mut self) -> CacheResult<bool> {
170 if self.reconnect_attempts >= self.max_reconnect_attempts {
171 return Ok(false);
172 }
173
174 self.reconnect_attempts += 1;
175 let delay = self.reconnect_delay * self.reconnect_attempts;
176 tokio::time::sleep(delay).await;
177
178 tracing::info!("Attempting WAL reconnection (attempt {})", self.reconnect_attempts);
179
180 match TcpStream::connect_timeout(
181 &self.endpoint.parse().map_err(|_| CacheError::ConnectionError("Invalid endpoint".to_string()))?,
182 Duration::from_secs(5),
183 ) {
184 Ok(stream) => {
185 stream.set_read_timeout(Some(Duration::from_secs(30))).ok();
186 stream.set_write_timeout(Some(Duration::from_secs(5))).ok();
187 stream.set_nodelay(true).ok();
188
189 self.connection = Some(stream);
190 self.reconnect_attempts = 0;
191 self.last_heartbeat = Instant::now();
192 tracing::info!("WAL reconnection successful");
193 Ok(true)
194 }
195 Err(e) => {
196 tracing::warn!("WAL reconnection failed: {}", e);
197 Ok(false)
198 }
199 }
200 }
201
202 fn disconnect(&mut self) {
203 self.running.store(false, Ordering::SeqCst);
204 if let Some(stream) = self.connection.take() {
205 drop(stream);
206 }
207 }
208
209 fn is_connected(&self) -> bool {
211 self.connection.is_some()
212 }
213}
214
/// Receiving half of a WAL stream, produced by subscribing a streamer.
pub struct WalSubscription {
    /// Run flag shared with the streamer that created this subscription.
    running: Arc<AtomicBool>,
    /// Connection taken over from the streamer; `None` if it was disconnected.
    connection: Option<TcpStream>,
    /// LSN bookkeeping, updated whenever an entry is returned.
    current_lsn: u64,
    /// Scratch buffer that holds the payload of the frame being decoded.
    buffer: Vec<u8>,
}
222
223impl WalSubscription {
224 pub async fn next(&mut self) -> Option<WalEntry> {
226 loop {
227 if !self.running.load(Ordering::SeqCst) {
228 return None;
229 }
230
231 let stream = match self.connection.as_mut() {
232 Some(s) => s,
233 None => {
234 tokio::time::sleep(Duration::from_millis(100)).await;
236 return None;
237 }
238 };
239
240 let mut header = [0u8; 5];
242 match stream.read_exact(&mut header) {
243 Ok(_) => {}
244 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
245 tokio::time::sleep(Duration::from_millis(10)).await;
247 continue;
248 }
249 Err(e) if e.kind() == std::io::ErrorKind::TimedOut => {
250 return None;
252 }
253 Err(_) => {
254 self.running.store(false, Ordering::SeqCst);
256 return None;
257 }
258 }
259
260 let msg_type = header[0];
261 let payload_len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
262
263 if msg_type == WalMessageType::Heartbeat as u8 {
265 if payload_len > 0 {
267 let mut _payload = vec![0u8; payload_len];
268 let _ = stream.read_exact(&mut _payload);
269 }
270 continue; }
272
273 if msg_type != WalMessageType::Entry as u8 {
275 if payload_len > 0 && payload_len < 1024 * 1024 {
277 let mut skip = vec![0u8; payload_len];
278 let _ = stream.read_exact(&mut skip);
279 }
280 continue;
281 }
282
283 if payload_len == 0 || payload_len > 10 * 1024 * 1024 {
285 return None;
287 }
288
289 self.buffer.resize(payload_len, 0);
290 if stream.read_exact(&mut self.buffer).is_err() {
291 self.running.store(false, Ordering::SeqCst);
292 return None;
293 }
294
295 if self.buffer.len() < 17 {
298 continue; }
300
301 let lsn = u64::from_be_bytes([
302 self.buffer[0], self.buffer[1], self.buffer[2], self.buffer[3],
303 self.buffer[4], self.buffer[5], self.buffer[6], self.buffer[7],
304 ]);
305 let timestamp = u64::from_be_bytes([
306 self.buffer[8], self.buffer[9], self.buffer[10], self.buffer[11],
307 self.buffer[12], self.buffer[13], self.buffer[14], self.buffer[15],
308 ]);
309 let op_type = self.buffer[16];
310 let data = &self.buffer[17..];
311
312 let operation = match op_type {
313 0x01 => {
314 if data.len() < 4 {
316 continue;
317 }
318 let key_len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
319 if data.len() < 4 + key_len {
320 continue;
321 }
322 let key = data[4..4 + key_len].to_vec();
323 let value = data[4 + key_len..].to_vec();
324 WalOperation::Put { key, value }
325 }
326 0x02 => {
327 if data.len() < 4 {
329 continue;
330 }
331 let key_len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
332 if data.len() < 4 + key_len {
333 continue;
334 }
335 let key = data[4..4 + key_len].to_vec();
336 WalOperation::Delete { key }
337 }
338 0x03 => {
339 if data.len() < 10 {
341 continue;
342 }
343 let name_len = u16::from_be_bytes([data[0], data[1]]) as usize;
344 if data.len() < 2 + name_len + 8 {
345 continue;
346 }
347 let table_name = String::from_utf8_lossy(&data[2..2 + name_len]).to_string();
348 let counter_offset = 2 + name_len;
349 let counter = u64::from_be_bytes([
350 data[counter_offset], data[counter_offset + 1],
351 data[counter_offset + 2], data[counter_offset + 3],
352 data[counter_offset + 4], data[counter_offset + 5],
353 data[counter_offset + 6], data[counter_offset + 7],
354 ]);
355 WalOperation::UpdateCounter { table_name, counter }
356 }
357 0x04 => {
358 if data.len() < 2 {
360 continue;
361 }
362 let name_len = u16::from_be_bytes([data[0], data[1]]) as usize;
363 if data.len() < 2 + name_len {
364 continue;
365 }
366 let table_name = String::from_utf8_lossy(&data[2..2 + name_len]).to_string();
367 WalOperation::SchemaChange { table_name }
368 }
369 0x05 => {
370 if data.len() < 8 {
372 continue;
373 }
374 let txn_id = u64::from_be_bytes([
375 data[0], data[1], data[2], data[3],
376 data[4], data[5], data[6], data[7],
377 ]);
378 WalOperation::Commit { txn_id }
379 }
380 _ => {
381 continue;
383 }
384 };
385
386 self.current_lsn = lsn;
387 return Some(WalEntry {
388 lsn,
389 timestamp,
390 operation,
391 });
392 }
393 }
394
395 pub fn current_lsn(&self) -> u64 {
397 self.current_lsn
398 }
399
400 pub fn is_active(&self) -> bool {
402 self.running.load(Ordering::SeqCst) && self.connection.is_some()
403 }
404}
405
/// Describes what an invalidation callback is asked to invalidate.
#[derive(Clone, Debug)]
pub struct InvalidationTarget {
    /// Name of the affected table.
    pub table: String,
    /// Key of the affected row; `None` for a table-level invalidation.
    pub row_key: Option<Vec<u8>>,
    /// Flag passed through from the table-level invalidation request;
    /// always `false` for row-level targets.
    pub invalidate_all: bool,
}

/// Shareable, thread-safe callback that is invoked with every
/// [`InvalidationTarget`].
pub type InvalidationCallback = Arc<dyn Fn(&InvalidationTarget) + Send + Sync>;
419
420pub struct WalInvalidator {
422 config: DistribCacheConfig,
424
425 wal_stream: Option<WalStreamer>,
427
428 subscription: tokio::sync::RwLock<Option<WalSubscription>>,
430
431 table_index: DashMap<String, HashSet<QueryFingerprint>>,
433
434 callbacks: Arc<tokio::sync::RwLock<Vec<InvalidationCallback>>>,
436
437 running: Arc<AtomicBool>,
439
440 last_lsn: AtomicU64,
442
443 stats: InvalidatorStats,
445}
446
/// Live counters of the invalidator; every field starts at zero.
#[derive(Default, Debug)]
struct InvalidatorStats {
    /// Number of WAL entries handed to entry processing.
    wal_entries_processed: AtomicU64,
    /// Number of table-level invalidations performed.
    tables_invalidated: AtomicU64,
    /// Sum of the registered fingerprints counted by table-level invalidations.
    entries_invalidated: AtomicU64,
    /// Milliseconds spent processing the most recent WAL entry.
    invalidation_lag_ms: AtomicU64,
}
455
456impl WalInvalidator {
457 pub fn new(config: DistribCacheConfig) -> Self {
459 Self {
460 config,
461 wal_stream: None,
462 subscription: tokio::sync::RwLock::new(None),
463 table_index: DashMap::new(),
464 callbacks: Arc::new(tokio::sync::RwLock::new(Vec::new())),
465 running: Arc::new(AtomicBool::new(false)),
466 last_lsn: AtomicU64::new(0),
467 stats: InvalidatorStats::default(),
468 }
469 }
470
471 pub async fn start(&self, wal_endpoint: &str) -> CacheResult<()> {
473 if self.running.load(Ordering::SeqCst) {
474 return Ok(()); }
476
477 self.running.store(true, Ordering::SeqCst);
478
479 let mut streamer = WalStreamer::connect(wal_endpoint).await?;
481
482 let start_lsn = self.last_lsn.load(Ordering::Relaxed);
484 let start_lsn = if start_lsn > 0 { Some(start_lsn) } else { None };
485
486 match streamer.subscribe(start_lsn).await {
487 Ok(sub) => {
488 *self.subscription.write().await = Some(sub);
489 tracing::info!("WAL invalidator started, connected to {}", wal_endpoint);
490 }
491 Err(e) => {
492 tracing::warn!("Failed to subscribe to WAL stream: {}. Running in degraded mode.", e);
493 }
495 }
496
497 Ok(())
498 }
499
500 pub fn start_processing(&self) -> tokio::task::JoinHandle<()> {
502 let running = self.running.clone();
503 let _subscription = self.subscription.write();
504 let _callbacks = self.callbacks.clone();
505 let _stats = InvalidatorStats {
506 wal_entries_processed: AtomicU64::new(0),
507 tables_invalidated: AtomicU64::new(0),
508 entries_invalidated: AtomicU64::new(0),
509 invalidation_lag_ms: AtomicU64::new(0),
510 };
511 let _table_index = self.table_index.clone();
512 let _last_lsn = AtomicU64::new(self.last_lsn.load(Ordering::Relaxed));
513
514 tokio::spawn(async move {
515 tracing::info!("WAL processing loop started");
516
517 while running.load(Ordering::SeqCst) {
520 tokio::time::sleep(Duration::from_millis(100)).await;
522 }
523
524 tracing::info!("WAL processing loop stopped");
525 })
526 }
527
528 pub async fn process_loop(&self) {
530 while self.running.load(Ordering::SeqCst) {
531 let entry = {
532 let mut sub_guard = self.subscription.write().await;
533 if let Some(ref mut sub) = *sub_guard {
534 sub.next().await
535 } else {
536 None
537 }
538 };
539
540 match entry {
541 Some(wal_entry) => {
542 let start = Instant::now();
543 self.process_wal_entry(wal_entry.clone()).await;
544 self.last_lsn.store(wal_entry.lsn, Ordering::Relaxed);
545
546 let lag = start.elapsed().as_millis() as u64;
548 self.stats.invalidation_lag_ms.store(lag, Ordering::Relaxed);
549 }
550 None => {
551 tokio::time::sleep(Duration::from_millis(10)).await;
553 }
554 }
555 }
556 }
557
558 pub async fn stop(&self) {
560 self.running.store(false, Ordering::SeqCst);
561
562 let mut sub = self.subscription.write().await;
564 *sub = None;
565
566 if let Some(_stream) = self.wal_stream.as_ref().map(|_| ()) {
567 }
569
570 tracing::info!("WAL invalidator stopped");
571 }
572
573 pub fn is_running(&self) -> bool {
575 self.running.load(Ordering::SeqCst)
576 }
577
578 pub fn register(&self, table: &str, fingerprint: QueryFingerprint) {
580 self.table_index
581 .entry(table.to_string())
582 .or_default()
583 .insert(fingerprint);
584 }
585
586 pub fn unregister(&self, table: &str, fingerprint: &QueryFingerprint) {
588 if let Some(mut set) = self.table_index.get_mut(table) {
589 set.remove(fingerprint);
590 }
591 }
592
593 pub async fn add_callback(&self, callback: InvalidationCallback) {
595 self.callbacks.write().await.push(callback);
596 }
597
598 pub fn add_callback_sync(&self, callback: InvalidationCallback) {
600 if let Ok(handle) = tokio::runtime::Handle::try_current() {
602 handle.block_on(async {
603 self.callbacks.write().await.push(callback);
604 });
605 }
606 }
607
608 async fn process_wal_entry(&self, entry: WalEntry) {
610 self.stats.wal_entries_processed.fetch_add(1, Ordering::Relaxed);
611
612 let (table, row_key) = match &entry.operation {
613 WalOperation::Put { key, .. } => (self.extract_table(key), Some(key.clone())),
614 WalOperation::Delete { key } => (self.extract_table(key), Some(key.clone())),
615 WalOperation::UpdateCounter { table_name, .. } => (Some(table_name.clone()), None),
616 WalOperation::SchemaChange { table_name } => {
617 self.invalidate_table(table_name, true).await;
619 return;
620 }
621 WalOperation::Commit { .. } => return,
622 };
623
624 if let Some(table) = table {
625 if let Some(key) = row_key {
627 self.invalidate_row(&table, &key).await;
628 } else {
629 self.invalidate_table(&table, false).await;
630 }
631 }
632 }
633
634 async fn invalidate_table(&self, table: &str, all_entries: bool) {
636 self.stats.tables_invalidated.fetch_add(1, Ordering::Relaxed);
637
638 let target = InvalidationTarget {
639 table: table.to_string(),
640 row_key: None,
641 invalidate_all: all_entries,
642 };
643
644 let callbacks = self.callbacks.read().await;
646 for callback in callbacks.iter() {
647 callback(&target);
648 }
649
650 if let Some(entries) = self.table_index.get(table) {
652 self.stats.entries_invalidated.fetch_add(
653 entries.len() as u64,
654 Ordering::Relaxed,
655 );
656 }
657 }
658
659 async fn invalidate_row(&self, table: &str, row_key: &[u8]) {
661 let target = InvalidationTarget {
662 table: table.to_string(),
663 row_key: Some(row_key.to_vec()),
664 invalidate_all: false,
665 };
666
667 let callbacks = self.callbacks.read().await;
668 for callback in callbacks.iter() {
669 callback(&target);
670 }
671 }
672
673 pub async fn invalidate_table_manual(&self, table: &str, all_entries: bool) {
675 self.invalidate_table(table, all_entries).await;
676 }
677
678 fn extract_table(&self, key: &[u8]) -> Option<String> {
680 let key_str = String::from_utf8_lossy(key);
682 key_str.split(':').next().map(|s| s.to_string())
683 }
684
685 pub fn stats(&self) -> InvalidatorStatsSnapshot {
687 InvalidatorStatsSnapshot {
688 wal_entries_processed: self.stats.wal_entries_processed.load(Ordering::Relaxed),
689 tables_invalidated: self.stats.tables_invalidated.load(Ordering::Relaxed),
690 entries_invalidated: self.stats.entries_invalidated.load(Ordering::Relaxed),
691 invalidation_lag_ms: self.stats.invalidation_lag_ms.load(Ordering::Relaxed),
692 registered_tables: self.table_index.len(),
693 is_running: self.running.load(Ordering::Relaxed),
694 }
695 }
696}
697
/// Plain-value copy of the invalidator's statistics at one point in time.
#[derive(Clone, Debug)]
pub struct InvalidatorStatsSnapshot {
    /// Number of WAL entries processed so far.
    pub wal_entries_processed: u64,
    /// Number of table-level invalidations performed so far.
    pub tables_invalidated: u64,
    /// Sum of the registered fingerprints counted by table-level invalidations.
    pub entries_invalidated: u64,
    /// Milliseconds spent processing the most recent WAL entry.
    pub invalidation_lag_ms: u64,
    /// Number of tables present in the table index.
    pub registered_tables: usize,
    /// Whether the invalidator was running when the snapshot was taken.
    pub is_running: bool,
}
708
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an invalidator with the default configuration.
    fn default_invalidator() -> WalInvalidator {
        WalInvalidator::new(DistribCacheConfig::default())
    }

    /// Registering a fingerprint makes it visible in the table index.
    #[test]
    fn test_register_fingerprint() {
        let inv = default_invalidator();
        let fingerprint = QueryFingerprint::from_query("SELECT * FROM users");

        inv.register("users", fingerprint.clone());

        assert!(inv.table_index.contains_key("users"));
        let registered = inv.table_index.get("users").unwrap();
        assert!(registered.contains(&fingerprint));
    }

    /// Unregistering removes the fingerprint but keeps the table's set.
    #[test]
    fn test_unregister_fingerprint() {
        let inv = default_invalidator();
        let fingerprint = QueryFingerprint::from_query("SELECT * FROM users");

        inv.register("users", fingerprint.clone());
        inv.unregister("users", &fingerprint);

        let registered = inv.table_index.get("users").unwrap();
        assert!(!registered.contains(&fingerprint));
    }

    /// A manual table invalidation reaches the registered callback.
    #[tokio::test]
    async fn test_callback_invocation() {
        let inv = default_invalidator();
        let fired = Arc::new(AtomicBool::new(false));
        let fired_in_callback = Arc::clone(&fired);

        let callback: InvalidationCallback = Arc::new(move |target: &InvalidationTarget| {
            if target.table == "users" {
                fired_in_callback.store(true, Ordering::SeqCst);
            }
        });
        inv.add_callback(callback).await;

        inv.invalidate_table_manual("users", false).await;

        assert!(fired.load(Ordering::SeqCst));
    }

    /// The table name is the part of the key before the first colon.
    #[test]
    fn test_extract_table() {
        let inv = default_invalidator();

        let extracted = inv.extract_table(b"users:123");

        assert_eq!(extracted, Some("users".to_string()));
    }

    /// Processing an entry bumps the processed-entries counter.
    #[tokio::test]
    async fn test_process_wal_entry() {
        let inv = default_invalidator();
        inv.register("users", QueryFingerprint::from_query("SELECT * FROM users"));

        let operation = WalOperation::Put {
            key: b"users:123".to_vec(),
            value: b"data".to_vec(),
        };
        inv.process_wal_entry(WalEntry {
            lsn: 1,
            timestamp: 0,
            operation,
        })
        .await;

        assert_eq!(inv.stats().wal_entries_processed, 1);
    }

    /// `start` reports running even when the endpoint is unreachable, and
    /// `stop` clears the flag again.
    #[tokio::test]
    async fn test_start_stop() {
        let inv = default_invalidator();

        inv.start("127.0.0.1:59999").await.unwrap();
        assert!(inv.is_running());

        inv.stop().await;
        assert!(!inv.is_running());
    }

    /// Every operation variant can be constructed and matched.
    #[test]
    fn test_wal_entry_parsing() {
        let key = b"users:1".to_vec();

        let put = WalOperation::Put {
            key: key.clone(),
            value: b"data".to_vec(),
        };
        assert!(matches!(put, WalOperation::Put { .. }));

        let delete = WalOperation::Delete { key };
        assert!(matches!(delete, WalOperation::Delete { .. }));

        let counter = WalOperation::UpdateCounter {
            table_name: "users".to_string(),
            counter: 100,
        };
        assert!(matches!(counter, WalOperation::UpdateCounter { .. }));

        let schema = WalOperation::SchemaChange {
            table_name: "users".to_string(),
        };
        assert!(matches!(schema, WalOperation::SchemaChange { .. }));

        let commit = WalOperation::Commit { txn_id: 12345 };
        assert!(matches!(commit, WalOperation::Commit { .. }));
    }

    /// A table invalidation counts the table once and each registered
    /// fingerprint once.
    #[tokio::test]
    async fn test_invalidation_stats() {
        let inv = default_invalidator();
        let queries = [
            "SELECT * FROM users WHERE id = ?",
            "SELECT name FROM users WHERE email = ?",
        ];
        for query in queries.iter() {
            inv.register("users", QueryFingerprint::from_query(query));
        }

        inv.invalidate_table_manual("users", false).await;

        let snapshot = inv.stats();
        assert_eq!(snapshot.tables_invalidated, 1);
        assert_eq!(snapshot.entries_invalidated, 2);
    }
}