1use dashmap::DashMap;
20use std::collections::HashSet;
21use std::io::{Read, Write};
22use std::net::TcpStream;
23use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26
27use super::{CacheError, CacheResult, DistribCacheConfig, QueryFingerprint};
28
/// One-byte tag at the start of every frame exchanged with the WAL server.
///
/// A frame is laid out as: tag (1 byte), payload length (big-endian `u32`),
/// then `payload length` bytes of payload.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum WalMessageType {
    /// Carries one encoded WAL entry.
    Entry = 0x01,
    /// Keep-alive frame; its payload is ignored.
    Heartbeat = 0x02,
    /// Client request to start streaming from a given LSN.
    Subscribe = 0x03,
    /// Server acknowledgement of a subscribe request.
    SubscribeAck = 0x04,
    /// Server-side error reply.
    Error = 0xFF,
}
39
/// A single decoded write-ahead-log operation.
///
/// On the wire each operation is identified by a one-byte op code:
/// `Put` = 0x01, `Delete` = 0x02, `UpdateCounter` = 0x03,
/// `SchemaChange` = 0x04, `Commit` = 0x05.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WalOperation {
    /// A key was written together with its new value.
    Put { key: Vec<u8>, value: Vec<u8> },
    /// A key was removed.
    Delete { key: Vec<u8> },
    /// A table-level counter update; the meaning of `counter` is defined by the server.
    UpdateCounter { table_name: String, counter: u64 },
    /// The schema of `table_name` changed.
    SchemaChange { table_name: String },
    /// Transaction `txn_id` committed.
    Commit { txn_id: u64 },
}

/// A decoded WAL record together with its position in the log.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WalEntry {
    /// Log sequence number of this record.
    pub lsn: u64,
    /// Timestamp sent by the server with the record (unit not specified here).
    pub timestamp: u64,
    /// The operation carried by the record.
    pub operation: WalOperation,
}
65
/// Client side of the WAL streaming protocol.
///
/// Owns the TCP connection to the WAL server until `subscribe` moves it
/// into a `WalSubscription`.
pub struct WalStreamer {
    /// Endpoint string the streamer was created for; only read by the
    /// (currently unused) reconnect path.
    #[allow(dead_code)]
    endpoint: String,
    /// Live socket, or `None` when not connected.
    connection: Option<TcpStream>,
    /// Flag shared with the subscription created by `subscribe`.
    running: Arc<AtomicBool>,
    /// Not read anywhere in this module yet.
    #[allow(dead_code)]
    current_lsn: AtomicU64,
    /// Refreshed whenever a connection is established.
    last_heartbeat: Instant,
    /// Attempts made since the last successful reconnect (unused path).
    #[allow(dead_code)]
    reconnect_attempts: u32,
    /// Attempt limit for reconnecting (unused path).
    #[allow(dead_code)]
    max_reconnect_attempts: u32,
    /// Base delay multiplied by the attempt number between reconnects (unused path).
    #[allow(dead_code)]
    reconnect_delay: Duration,
}
90
91impl WalStreamer {
92 fn new(endpoint: &str) -> Self {
93 Self {
94 endpoint: endpoint.to_string(),
95 connection: None,
96 running: Arc::new(AtomicBool::new(false)),
97 current_lsn: AtomicU64::new(0),
98 last_heartbeat: Instant::now(),
99 reconnect_attempts: 0,
100 max_reconnect_attempts: 10,
101 reconnect_delay: Duration::from_secs(1),
102 }
103 }
104
105 async fn connect(endpoint: &str) -> CacheResult<Self> {
107 let mut streamer = Self::new(endpoint);
108
109 match TcpStream::connect_timeout(
111 &endpoint
112 .parse()
113 .map_err(|_| CacheError::ConnectionError("Invalid endpoint address".to_string()))?,
114 Duration::from_secs(5),
115 ) {
116 Ok(stream) => {
117 stream.set_read_timeout(Some(Duration::from_secs(30))).ok();
119 stream.set_write_timeout(Some(Duration::from_secs(5))).ok();
120 stream.set_nodelay(true).ok();
121
122 streamer.connection = Some(stream);
123 streamer.last_heartbeat = Instant::now();
124 Ok(streamer)
125 }
126 Err(e) => {
127 tracing::warn!("Failed to connect to WAL endpoint {}: {}", endpoint, e);
129 Ok(streamer)
130 }
131 }
132 }
133
134 async fn subscribe(&mut self, start_lsn: Option<u64>) -> CacheResult<WalSubscription> {
136 self.running.store(true, Ordering::SeqCst);
137
138 if let Some(ref mut stream) = self.connection {
139 let lsn = start_lsn.unwrap_or(0);
141 let mut request = vec![WalMessageType::Subscribe as u8];
142 request.extend_from_slice(&(8u32).to_be_bytes()); request.extend_from_slice(&lsn.to_be_bytes()); if let Err(e) = stream.write_all(&request) {
147 tracing::error!("Failed to send subscription request: {}", e);
148 return Err(CacheError::ConnectionError(format!(
149 "Subscription failed: {}",
150 e
151 )));
152 }
153
154 let mut header = [0u8; 5];
156 match stream.read_exact(&mut header) {
157 Ok(_) => {
158 if header[0] == WalMessageType::SubscribeAck as u8 {
159 tracing::info!("WAL subscription acknowledged");
160 } else if header[0] == WalMessageType::Error as u8 {
161 return Err(CacheError::ConnectionError(
162 "Subscription rejected by server".to_string(),
163 ));
164 }
165 }
166 Err(e) => {
167 tracing::warn!("No subscription ack received: {}", e);
168 }
169 }
170 }
171
172 Ok(WalSubscription {
173 running: self.running.clone(),
174 connection: self.connection.take(),
175 current_lsn: 0,
176 buffer: Vec::with_capacity(64 * 1024),
177 })
178 }
179
180 #[allow(dead_code)]
182 async fn reconnect(&mut self) -> CacheResult<bool> {
183 if self.reconnect_attempts >= self.max_reconnect_attempts {
184 return Ok(false);
185 }
186
187 self.reconnect_attempts += 1;
188 let delay = self.reconnect_delay * self.reconnect_attempts;
189 tokio::time::sleep(delay).await;
190
191 tracing::info!(
192 "Attempting WAL reconnection (attempt {})",
193 self.reconnect_attempts
194 );
195
196 match TcpStream::connect_timeout(
197 &self
198 .endpoint
199 .parse()
200 .map_err(|_| CacheError::ConnectionError("Invalid endpoint".to_string()))?,
201 Duration::from_secs(5),
202 ) {
203 Ok(stream) => {
204 stream.set_read_timeout(Some(Duration::from_secs(30))).ok();
205 stream.set_write_timeout(Some(Duration::from_secs(5))).ok();
206 stream.set_nodelay(true).ok();
207
208 self.connection = Some(stream);
209 self.reconnect_attempts = 0;
210 self.last_heartbeat = Instant::now();
211 tracing::info!("WAL reconnection successful");
212 Ok(true)
213 }
214 Err(e) => {
215 tracing::warn!("WAL reconnection failed: {}", e);
216 Ok(false)
217 }
218 }
219 }
220
221 #[allow(dead_code)]
222 fn disconnect(&mut self) {
223 self.running.store(false, Ordering::SeqCst);
224 if let Some(stream) = self.connection.take() {
225 drop(stream);
226 }
227 }
228
229 #[allow(dead_code)]
231 fn is_connected(&self) -> bool {
232 self.connection.is_some()
233 }
234}
235
/// An active WAL subscription created by `WalStreamer::subscribe`.
///
/// Decoded entries are pulled one at a time with `next`.
pub struct WalSubscription {
    /// Running flag shared with the streamer that created this subscription.
    running: Arc<AtomicBool>,
    /// Socket the frames are read from; `None` in degraded mode.
    connection: Option<TcpStream>,
    /// LSN of the most recent entry returned by `next` (starts at 0).
    current_lsn: u64,
    /// Payload buffer reused across frames.
    buffer: Vec<u8>,
}
243
244impl WalSubscription {
245 pub async fn next(&mut self) -> Option<WalEntry> {
247 loop {
248 if !self.running.load(Ordering::SeqCst) {
249 return None;
250 }
251
252 let stream = match self.connection.as_mut() {
253 Some(s) => s,
254 None => {
255 tokio::time::sleep(Duration::from_millis(100)).await;
257 return None;
258 }
259 };
260
261 let mut header = [0u8; 5];
263 match stream.read_exact(&mut header) {
264 Ok(_) => {}
265 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
266 tokio::time::sleep(Duration::from_millis(10)).await;
268 continue;
269 }
270 Err(e) if e.kind() == std::io::ErrorKind::TimedOut => {
271 return None;
273 }
274 Err(_) => {
275 self.running.store(false, Ordering::SeqCst);
277 return None;
278 }
279 }
280
281 let msg_type = header[0];
282 let payload_len =
283 u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
284
285 if msg_type == WalMessageType::Heartbeat as u8 {
287 if payload_len > 0 {
289 let mut _payload = vec![0u8; payload_len];
290 let _ = stream.read_exact(&mut _payload);
291 }
292 continue; }
294
295 if msg_type != WalMessageType::Entry as u8 {
297 if payload_len > 0 && payload_len < 1024 * 1024 {
299 let mut skip = vec![0u8; payload_len];
300 let _ = stream.read_exact(&mut skip);
301 }
302 continue;
303 }
304
305 if payload_len == 0 || payload_len > 10 * 1024 * 1024 {
307 return None;
309 }
310
311 self.buffer.resize(payload_len, 0);
312 if stream.read_exact(&mut self.buffer).is_err() {
313 self.running.store(false, Ordering::SeqCst);
314 return None;
315 }
316
317 if self.buffer.len() < 17 {
320 continue; }
322
323 let lsn = u64::from_be_bytes([
324 self.buffer[0],
325 self.buffer[1],
326 self.buffer[2],
327 self.buffer[3],
328 self.buffer[4],
329 self.buffer[5],
330 self.buffer[6],
331 self.buffer[7],
332 ]);
333 let timestamp = u64::from_be_bytes([
334 self.buffer[8],
335 self.buffer[9],
336 self.buffer[10],
337 self.buffer[11],
338 self.buffer[12],
339 self.buffer[13],
340 self.buffer[14],
341 self.buffer[15],
342 ]);
343 let op_type = self.buffer[16];
344 let data = &self.buffer[17..];
345
346 let operation = match op_type {
347 0x01 => {
348 if data.len() < 4 {
350 continue;
351 }
352 let key_len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
353 if data.len() < 4 + key_len {
354 continue;
355 }
356 let key = data[4..4 + key_len].to_vec();
357 let value = data[4 + key_len..].to_vec();
358 WalOperation::Put { key, value }
359 }
360 0x02 => {
361 if data.len() < 4 {
363 continue;
364 }
365 let key_len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
366 if data.len() < 4 + key_len {
367 continue;
368 }
369 let key = data[4..4 + key_len].to_vec();
370 WalOperation::Delete { key }
371 }
372 0x03 => {
373 if data.len() < 10 {
375 continue;
376 }
377 let name_len = u16::from_be_bytes([data[0], data[1]]) as usize;
378 if data.len() < 2 + name_len + 8 {
379 continue;
380 }
381 let table_name = String::from_utf8_lossy(&data[2..2 + name_len]).to_string();
382 let counter_offset = 2 + name_len;
383 let counter = u64::from_be_bytes([
384 data[counter_offset],
385 data[counter_offset + 1],
386 data[counter_offset + 2],
387 data[counter_offset + 3],
388 data[counter_offset + 4],
389 data[counter_offset + 5],
390 data[counter_offset + 6],
391 data[counter_offset + 7],
392 ]);
393 WalOperation::UpdateCounter {
394 table_name,
395 counter,
396 }
397 }
398 0x04 => {
399 if data.len() < 2 {
401 continue;
402 }
403 let name_len = u16::from_be_bytes([data[0], data[1]]) as usize;
404 if data.len() < 2 + name_len {
405 continue;
406 }
407 let table_name = String::from_utf8_lossy(&data[2..2 + name_len]).to_string();
408 WalOperation::SchemaChange { table_name }
409 }
410 0x05 => {
411 if data.len() < 8 {
413 continue;
414 }
415 let txn_id = u64::from_be_bytes([
416 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
417 ]);
418 WalOperation::Commit { txn_id }
419 }
420 _ => {
421 continue;
423 }
424 };
425
426 self.current_lsn = lsn;
427 return Some(WalEntry {
428 lsn,
429 timestamp,
430 operation,
431 });
432 }
433 }
434
435 pub fn current_lsn(&self) -> u64 {
437 self.current_lsn
438 }
439
440 pub fn is_active(&self) -> bool {
442 self.running.load(Ordering::SeqCst) && self.connection.is_some()
443 }
444}
445
/// Describes what a registered invalidation callback should drop.
///
/// Derives `PartialEq`/`Eq` so targets can be compared directly, e.g. when
/// asserting on what a callback received.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidationTarget {
    /// Table whose cached results are affected.
    pub table: String,
    /// Raw WAL key of the changed row for row-level invalidations; `None`
    /// for table-level invalidations.
    pub row_key: Option<Vec<u8>>,
    /// `true` when every cached entry for the table must be dropped (for
    /// example after a schema change).
    pub invalidate_all: bool,
}

/// Callback invoked for every invalidation.
///
/// Shared via `Arc` and `Send + Sync`, so it can be called from any task.
pub type InvalidationCallback = Arc<dyn Fn(&InvalidationTarget) + Send + Sync>;
459
460pub struct WalInvalidator {
462 #[allow(dead_code)]
464 config: DistribCacheConfig,
465
466 wal_stream: Option<WalStreamer>,
468
469 subscription: tokio::sync::RwLock<Option<WalSubscription>>,
471
472 table_index: DashMap<String, HashSet<QueryFingerprint>>,
474
475 callbacks: Arc<tokio::sync::RwLock<Vec<InvalidationCallback>>>,
477
478 running: Arc<AtomicBool>,
480
481 last_lsn: AtomicU64,
483
484 stats: InvalidatorStats,
486}
487
/// Atomic counters behind `WalInvalidator::stats`; all start at zero.
#[derive(Debug, Default)]
struct InvalidatorStats {
    /// WAL entries handed to the invalidator.
    wal_entries_processed: AtomicU64,
    /// Table-level invalidations dispatched.
    tables_invalidated: AtomicU64,
    /// Registered fingerprints covered by table-level invalidations.
    entries_invalidated: AtomicU64,
    /// Milliseconds spent handling the most recently processed entry.
    invalidation_lag_ms: AtomicU64,
}
496
497impl WalInvalidator {
498 pub fn new(config: DistribCacheConfig) -> Self {
500 Self {
501 config,
502 wal_stream: None,
503 subscription: tokio::sync::RwLock::new(None),
504 table_index: DashMap::new(),
505 callbacks: Arc::new(tokio::sync::RwLock::new(Vec::new())),
506 running: Arc::new(AtomicBool::new(false)),
507 last_lsn: AtomicU64::new(0),
508 stats: InvalidatorStats::default(),
509 }
510 }
511
512 pub async fn start(&self, wal_endpoint: &str) -> CacheResult<()> {
514 if self.running.load(Ordering::SeqCst) {
515 return Ok(()); }
517
518 self.running.store(true, Ordering::SeqCst);
519
520 let mut streamer = WalStreamer::connect(wal_endpoint).await?;
522
523 let start_lsn = self.last_lsn.load(Ordering::Relaxed);
525 let start_lsn = if start_lsn > 0 { Some(start_lsn) } else { None };
526
527 match streamer.subscribe(start_lsn).await {
528 Ok(sub) => {
529 *self.subscription.write().await = Some(sub);
530 tracing::info!("WAL invalidator started, connected to {}", wal_endpoint);
531 }
532 Err(e) => {
533 tracing::warn!(
534 "Failed to subscribe to WAL stream: {}. Running in degraded mode.",
535 e
536 );
537 }
539 }
540
541 Ok(())
542 }
543
544 pub fn start_processing(&self) -> tokio::task::JoinHandle<()> {
546 let running = self.running.clone();
547 let _subscription = self.subscription.write();
548 let _callbacks = self.callbacks.clone();
549 let _stats = InvalidatorStats {
550 wal_entries_processed: AtomicU64::new(0),
551 tables_invalidated: AtomicU64::new(0),
552 entries_invalidated: AtomicU64::new(0),
553 invalidation_lag_ms: AtomicU64::new(0),
554 };
555 let _table_index = self.table_index.clone();
556 let _last_lsn = AtomicU64::new(self.last_lsn.load(Ordering::Relaxed));
557
558 tokio::spawn(async move {
559 tracing::info!("WAL processing loop started");
560
561 while running.load(Ordering::SeqCst) {
564 tokio::time::sleep(Duration::from_millis(100)).await;
566 }
567
568 tracing::info!("WAL processing loop stopped");
569 })
570 }
571
572 pub async fn process_loop(&self) {
574 while self.running.load(Ordering::SeqCst) {
575 let entry = {
576 let mut sub_guard = self.subscription.write().await;
577 if let Some(ref mut sub) = *sub_guard {
578 sub.next().await
579 } else {
580 None
581 }
582 };
583
584 match entry {
585 Some(wal_entry) => {
586 let start = Instant::now();
587 self.process_wal_entry(wal_entry.clone()).await;
588 self.last_lsn.store(wal_entry.lsn, Ordering::Relaxed);
589
590 let lag = start.elapsed().as_millis() as u64;
592 self.stats.invalidation_lag_ms.store(lag, Ordering::Relaxed);
593 }
594 None => {
595 tokio::time::sleep(Duration::from_millis(10)).await;
597 }
598 }
599 }
600 }
601
602 pub async fn stop(&self) {
604 self.running.store(false, Ordering::SeqCst);
605
606 let mut sub = self.subscription.write().await;
608 *sub = None;
609
610 if let Some(_stream) = self.wal_stream.as_ref().map(|_| ()) {
611 }
613
614 tracing::info!("WAL invalidator stopped");
615 }
616
617 pub fn is_running(&self) -> bool {
619 self.running.load(Ordering::SeqCst)
620 }
621
622 pub fn register(&self, table: &str, fingerprint: QueryFingerprint) {
624 self.table_index
625 .entry(table.to_string())
626 .or_default()
627 .insert(fingerprint);
628 }
629
630 pub fn unregister(&self, table: &str, fingerprint: &QueryFingerprint) {
632 if let Some(mut set) = self.table_index.get_mut(table) {
633 set.remove(fingerprint);
634 }
635 }
636
637 pub async fn add_callback(&self, callback: InvalidationCallback) {
639 self.callbacks.write().await.push(callback);
640 }
641
642 pub fn add_callback_sync(&self, callback: InvalidationCallback) {
644 if let Ok(handle) = tokio::runtime::Handle::try_current() {
646 handle.block_on(async {
647 self.callbacks.write().await.push(callback);
648 });
649 }
650 }
651
652 async fn process_wal_entry(&self, entry: WalEntry) {
654 self.stats
655 .wal_entries_processed
656 .fetch_add(1, Ordering::Relaxed);
657
658 let (table, row_key) = match &entry.operation {
659 WalOperation::Put { key, .. } => (self.extract_table(key), Some(key.clone())),
660 WalOperation::Delete { key } => (self.extract_table(key), Some(key.clone())),
661 WalOperation::UpdateCounter { table_name, .. } => (Some(table_name.clone()), None),
662 WalOperation::SchemaChange { table_name } => {
663 self.invalidate_table(table_name, true).await;
665 return;
666 }
667 WalOperation::Commit { .. } => return,
668 };
669
670 if let Some(table) = table {
671 if let Some(key) = row_key {
673 self.invalidate_row(&table, &key).await;
674 } else {
675 self.invalidate_table(&table, false).await;
676 }
677 }
678 }
679
680 async fn invalidate_table(&self, table: &str, all_entries: bool) {
682 self.stats
683 .tables_invalidated
684 .fetch_add(1, Ordering::Relaxed);
685
686 let target = InvalidationTarget {
687 table: table.to_string(),
688 row_key: None,
689 invalidate_all: all_entries,
690 };
691
692 let callbacks = self.callbacks.read().await;
694 for callback in callbacks.iter() {
695 callback(&target);
696 }
697
698 if let Some(entries) = self.table_index.get(table) {
700 self.stats
701 .entries_invalidated
702 .fetch_add(entries.len() as u64, Ordering::Relaxed);
703 }
704 }
705
706 async fn invalidate_row(&self, table: &str, row_key: &[u8]) {
708 let target = InvalidationTarget {
709 table: table.to_string(),
710 row_key: Some(row_key.to_vec()),
711 invalidate_all: false,
712 };
713
714 let callbacks = self.callbacks.read().await;
715 for callback in callbacks.iter() {
716 callback(&target);
717 }
718 }
719
720 pub async fn invalidate_table_manual(&self, table: &str, all_entries: bool) {
722 self.invalidate_table(table, all_entries).await;
723 }
724
725 fn extract_table(&self, key: &[u8]) -> Option<String> {
727 let key_str = String::from_utf8_lossy(key);
729 key_str.split(':').next().map(|s| s.to_string())
730 }
731
732 pub fn stats(&self) -> InvalidatorStatsSnapshot {
734 InvalidatorStatsSnapshot {
735 wal_entries_processed: self.stats.wal_entries_processed.load(Ordering::Relaxed),
736 tables_invalidated: self.stats.tables_invalidated.load(Ordering::Relaxed),
737 entries_invalidated: self.stats.entries_invalidated.load(Ordering::Relaxed),
738 invalidation_lag_ms: self.stats.invalidation_lag_ms.load(Ordering::Relaxed),
739 registered_tables: self.table_index.len(),
740 is_running: self.running.load(Ordering::Relaxed),
741 }
742 }
743}
744
/// Point-in-time copy of the invalidator's counters, as returned by
/// `WalInvalidator::stats`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly, and
/// `Default` for an all-zero, not-running snapshot.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct InvalidatorStatsSnapshot {
    /// WAL entries handed to the invalidator.
    pub wal_entries_processed: u64,
    /// Table-level invalidations dispatched.
    pub tables_invalidated: u64,
    /// Registered query fingerprints covered by table-level invalidations.
    pub entries_invalidated: u64,
    /// Milliseconds spent handling the most recently processed entry.
    pub invalidation_lag_ms: u64,
    /// Number of tables present in the fingerprint index.
    pub registered_tables: usize,
    /// Whether the invalidator was running when the snapshot was taken.
    pub is_running: bool,
}
755
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an invalidator from the default configuration.
    fn fresh_invalidator() -> WalInvalidator {
        WalInvalidator::new(DistribCacheConfig::default())
    }

    #[test]
    fn test_register_fingerprint() {
        let inv = fresh_invalidator();
        let fingerprint = QueryFingerprint::from_query("SELECT * FROM users");

        inv.register("users", fingerprint.clone());

        assert!(inv.table_index.contains_key("users"));
        let registered = inv.table_index.get("users").unwrap();
        assert!(registered.contains(&fingerprint));
    }

    #[test]
    fn test_unregister_fingerprint() {
        let inv = fresh_invalidator();
        let fingerprint = QueryFingerprint::from_query("SELECT * FROM users");

        inv.register("users", fingerprint.clone());
        inv.unregister("users", &fingerprint);

        // The table entry stays; only the fingerprint leaves its set.
        let registered = inv.table_index.get("users").unwrap();
        assert!(!registered.contains(&fingerprint));
    }

    #[tokio::test]
    async fn test_callback_invocation() {
        let inv = fresh_invalidator();
        let fired = Arc::new(AtomicBool::new(false));
        let fired_in_callback = Arc::clone(&fired);

        let callback: InvalidationCallback = Arc::new(move |target: &InvalidationTarget| {
            if target.table == "users" {
                fired_in_callback.store(true, Ordering::SeqCst);
            }
        });
        inv.add_callback(callback).await;

        inv.invalidate_table_manual("users", false).await;

        assert!(fired.load(Ordering::SeqCst));
    }

    #[test]
    fn test_extract_table() {
        let inv = fresh_invalidator();

        assert_eq!(inv.extract_table(b"users:123"), Some("users".to_string()));
    }

    #[tokio::test]
    async fn test_process_wal_entry() {
        let inv = fresh_invalidator();
        inv.register("users", QueryFingerprint::from_query("SELECT * FROM users"));

        inv.process_wal_entry(WalEntry {
            lsn: 1,
            timestamp: 0,
            operation: WalOperation::Put {
                key: b"users:123".to_vec(),
                value: b"data".to_vec(),
            },
        })
        .await;

        assert_eq!(inv.stats().wal_entries_processed, 1);
    }

    #[tokio::test]
    async fn test_start_stop() {
        let inv = fresh_invalidator();

        // No WAL server is expected on this port; start() should still
        // succeed (degraded mode) and report running.
        inv.start("127.0.0.1:59999").await.unwrap();
        assert!(inv.is_running());

        inv.stop().await;
        assert!(!inv.is_running());
    }

    #[test]
    fn test_wal_entry_parsing() {
        let put = WalOperation::Put {
            key: b"users:1".to_vec(),
            value: b"data".to_vec(),
        };
        let delete = WalOperation::Delete {
            key: b"users:1".to_vec(),
        };
        let counter = WalOperation::UpdateCounter {
            table_name: "users".to_string(),
            counter: 100,
        };
        let schema = WalOperation::SchemaChange {
            table_name: "users".to_string(),
        };
        let commit = WalOperation::Commit { txn_id: 12345 };

        assert!(matches!(put, WalOperation::Put { .. }));
        assert!(matches!(delete, WalOperation::Delete { .. }));
        assert!(matches!(counter, WalOperation::UpdateCounter { .. }));
        assert!(matches!(schema, WalOperation::SchemaChange { .. }));
        assert!(matches!(commit, WalOperation::Commit { .. }));
    }

    #[tokio::test]
    async fn test_invalidation_stats() {
        let inv = fresh_invalidator();
        for sql in [
            "SELECT * FROM users WHERE id = ?",
            "SELECT name FROM users WHERE email = ?",
        ] {
            inv.register("users", QueryFingerprint::from_query(sql));
        }

        inv.invalidate_table_manual("users", false).await;

        let stats = inv.stats();
        assert_eq!(stats.tables_invalidated, 1);
        assert_eq!(stats.entries_invalidated, 2);
    }
}
901}