1use crate::{
27 infrastructure::{
28 observability::metrics::MetricsRegistry,
29 persistence::wal::WALEntry,
30 replication::protocol::{FollowerMessage, LeaderMessage},
31 },
32 store::EventStore,
33};
34use dashmap::DashMap;
35use std::{
36 sync::Arc,
37 time::{Duration, Instant},
38};
39use tokio::{
40 io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
41 net::{TcpListener, TcpStream},
42 sync::{Notify, broadcast},
43};
44use uuid::Uuid;
45
/// Durability guarantee applied when shipping WAL entries to followers.
///
/// Selected via [`WalShipper::set_replication_mode`] and consulted by
/// [`WalShipper::wait_for_ack`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ReplicationMode {
    /// Fire-and-forget: writes are acknowledged without waiting for any follower.
    Async,
    /// Wait until at least one connected follower acknowledges the target offset.
    SemiSync,
    /// Wait until every connected follower acknowledges the target offset.
    Sync,
}
58
59impl ReplicationMode {
60 pub fn from_str_value(s: &str) -> Self {
62 match s.to_lowercase().as_str() {
63 "semi-sync" | "semi_sync" | "semisync" => ReplicationMode::SemiSync,
64 "sync" => ReplicationMode::Sync,
65 _ => ReplicationMode::Async,
66 }
67 }
68}
69
70impl std::fmt::Display for ReplicationMode {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 match self {
73 ReplicationMode::Async => write!(f, "async"),
74 ReplicationMode::SemiSync => write!(f, "semi-sync"),
75 ReplicationMode::Sync => write!(f, "sync"),
76 }
77 }
78}
79
/// Size in bytes of each raw (pre-base64) snapshot chunk streamed to a follower.
const SNAPSHOT_CHUNK_SIZE: usize = 512 * 1024;
82
/// Per-connection bookkeeping for one connected follower.
struct FollowerState {
    // Highest WAL offset this follower has acknowledged.
    acked_offset: u64,
    // Instant the follower's connection was registered (not read within this
    // file — presumably kept for future diagnostics; verify before removing).
    connected_at: Instant,
}
90
/// Point-in-time replication summary, serializable for status endpoints.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ReplicationStatus {
    // Number of currently connected followers.
    pub followers: usize,
    // NOTE(review): despite the `_ms` suffix, `status()` computes these as
    // offset deltas (leader offset minus follower ack offset), not
    // milliseconds — confirm intended units with consumers.
    pub min_lag_ms: u64,
    pub max_lag_ms: u64,
    // Mode in effect when the snapshot was taken.
    pub replication_mode: ReplicationMode,
}
99
/// Leader-side WAL shipping service: accepts follower connections, streams
/// WAL entries (and Parquet snapshots for far-behind followers), and tracks
/// follower acknowledgements for semi-sync/sync durability waits.
pub struct WalShipper {
    // Broadcast channel WAL writers publish entries into; each follower
    // connection holds its own subscriber.
    entry_tx: broadcast::Sender<WALEntry>,
    // Live follower registry keyed by a per-connection UUID.
    followers: Arc<DashMap<Uuid, FollowerState>>,
    // Highest WAL offset shipped by this leader (relaxed atomics throughout).
    leader_offset: Arc<std::sync::atomic::AtomicU64>,
    // Event store; required only for snapshot catch-up.
    store: Option<Arc<EventStore>>,
    // Optional metrics registry for replication gauges/counters.
    metrics: Option<Arc<MetricsRegistry>>,
    // Current durability mode (defaults to Async).
    replication_mode: ReplicationMode,
    // Upper bound on how long `wait_for_ack` blocks in non-async modes.
    ack_timeout: Duration,
    // Woken whenever any follower's acked offset advances.
    ack_notify: Arc<Notify>,
}
119
120impl WalShipper {
121 pub fn new() -> (Self, broadcast::Sender<WALEntry>) {
124 let (entry_tx, _) = broadcast::channel(4096);
125 let tx_clone = entry_tx.clone();
126 (
127 Self {
128 entry_tx,
129 followers: Arc::new(DashMap::new()),
130 leader_offset: Arc::new(std::sync::atomic::AtomicU64::new(0)),
131 store: None,
132 metrics: None,
133 replication_mode: ReplicationMode::Async,
134 ack_timeout: Duration::from_millis(5000),
135 ack_notify: Arc::new(Notify::new()),
136 },
137 tx_clone,
138 )
139 }
140
141 pub fn set_replication_mode(&mut self, mode: ReplicationMode, ack_timeout: Duration) {
143 self.replication_mode = mode;
144 self.ack_timeout = ack_timeout;
145 }
146
147 pub fn replication_mode(&self) -> ReplicationMode {
149 self.replication_mode
150 }
151
152 pub fn current_leader_offset(&self) -> u64 {
154 self.leader_offset
155 .load(std::sync::atomic::Ordering::Relaxed)
156 }
157
158 pub async fn wait_for_ack(&self, target_offset: u64) -> bool {
167 match self.replication_mode {
168 ReplicationMode::Async => true,
169 ReplicationMode::SemiSync => self.wait_for_ack_inner(target_offset, false).await,
170 ReplicationMode::Sync => self.wait_for_ack_inner(target_offset, true).await,
171 }
172 }
173
174 async fn wait_for_ack_inner(&self, target_offset: u64, all_followers: bool) -> bool {
177 let start = Instant::now();
178 let timeout = self.ack_timeout;
179
180 loop {
181 let follower_count = self.followers.len();
183 if follower_count == 0 {
184 return false;
186 }
187
188 if all_followers {
189 let all_acked = self
191 .followers
192 .iter()
193 .all(|entry| entry.value().acked_offset >= target_offset);
194 if all_acked {
195 return true;
196 }
197 } else {
198 let any_acked = self
200 .followers
201 .iter()
202 .any(|entry| entry.value().acked_offset >= target_offset);
203 if any_acked {
204 return true;
205 }
206 }
207
208 let elapsed = start.elapsed();
210 if elapsed >= timeout {
211 return false;
212 }
213
214 let remaining = timeout - elapsed;
216 if tokio::time::timeout(remaining, self.ack_notify.notified())
217 .await
218 .is_err()
219 {
220 return false;
221 }
222 }
223 }
224
225 pub fn set_metrics(&mut self, metrics: Arc<MetricsRegistry>) {
227 self.metrics = Some(metrics);
228 }
229
230 pub fn set_store(&mut self, store: Arc<EventStore>) {
236 self.store = Some(store);
237 }
238
239 pub fn status(&self) -> ReplicationStatus {
241 let leader_offset = self
242 .leader_offset
243 .load(std::sync::atomic::Ordering::Relaxed);
244 let mut min_lag_ms = u64::MAX;
245 let mut max_lag_ms = 0u64;
246
247 for entry in self.followers.iter() {
248 let follower = entry.value();
249 let lag = leader_offset.saturating_sub(follower.acked_offset);
250 min_lag_ms = min_lag_ms.min(lag);
251 max_lag_ms = max_lag_ms.max(lag);
252 }
253
254 let follower_count = self.followers.len();
255 if follower_count == 0 {
256 min_lag_ms = 0;
257 }
258
259 ReplicationStatus {
260 followers: follower_count,
261 min_lag_ms,
262 max_lag_ms,
263 replication_mode: self.replication_mode,
264 }
265 }
266
267 pub async fn serve(self: Arc<Self>, port: u16) -> anyhow::Result<()> {
269 let addr = format!("0.0.0.0:{port}");
270 let listener = TcpListener::bind(&addr).await?;
271
272 tracing::info!(
273 "Replication server listening on {} (followers can connect)",
274 addr
275 );
276
277 loop {
278 match listener.accept().await {
279 Ok((stream, peer_addr)) => {
280 tracing::info!("Follower connected from {}", peer_addr);
281 let shipper = Arc::clone(&self);
282 tokio::spawn(async move {
283 if let Err(e) = shipper.handle_follower(stream).await {
284 tracing::warn!("Follower {} disconnected: {}", peer_addr, e);
285 }
286 });
287 }
288 Err(e) => {
289 tracing::error!("Failed to accept follower connection: {}", e);
290 }
291 }
292 }
293 }
294
295 fn needs_snapshot_catchup(&self, last_offset: u64) -> bool {
300 if last_offset == 0 {
302 if let Some(ref store) = self.store
303 && let Some(wal) = store.wal()
304 {
305 return wal.current_sequence() > 0;
306 }
307 return false;
308 }
309
310 if let Some(ref store) = self.store
311 && let Some(wal) = store.wal()
312 && let Some(oldest) = wal.oldest_sequence()
313 {
314 return last_offset < oldest;
315 }
316 false
317 }
318
319 async fn send_snapshot(
325 &self,
326 writer: &mut tokio::net::tcp::OwnedWriteHalf,
327 peer: std::net::SocketAddr,
328 ) -> anyhow::Result<u64> {
329 let store = self
330 .store
331 .as_ref()
332 .ok_or_else(|| anyhow::anyhow!("No store available for snapshot catch-up"))?;
333
334 if let Err(e) = store.flush_storage() {
336 tracing::warn!("Failed to flush storage before snapshot: {}", e);
337 }
338
339 let storage = store.parquet_storage().ok_or_else(|| {
340 anyhow::anyhow!("No Parquet storage configured for snapshot catch-up")
341 })?;
342
343 let parquet_files = {
345 let storage_guard = storage.read();
346 storage_guard.list_parquet_files()?
347 };
348
349 if parquet_files.is_empty() {
350 tracing::info!("No Parquet files to send for snapshot catch-up to {}", peer);
351 let current_offset = self
352 .leader_offset
353 .load(std::sync::atomic::Ordering::Relaxed);
354 return Ok(current_offset);
355 }
356
357 let filenames: Vec<String> = parquet_files
359 .iter()
360 .filter_map(|p| p.file_name().map(|n| n.to_string_lossy().to_string()))
361 .collect();
362
363 tracing::info!(
364 "Sending Parquet snapshot to {} ({} files: {:?})",
365 peer,
366 filenames.len(),
367 filenames,
368 );
369
370 let start_msg = LeaderMessage::SnapshotStart {
372 parquet_files: filenames,
373 };
374 send_message(writer, &start_msg).await?;
375
376 for file_path in &parquet_files {
378 let filename = file_path
379 .file_name()
380 .map(|n| n.to_string_lossy().to_string())
381 .unwrap_or_default();
382
383 let file_data = tokio::fs::read(file_path).await.map_err(|e| {
384 anyhow::anyhow!("Failed to read Parquet file {}: {}", file_path.display(), e)
385 })?;
386
387 let total_size = file_data.len();
388 let mut offset: usize = 0;
389
390 while offset < total_size {
391 let end = (offset + SNAPSHOT_CHUNK_SIZE).min(total_size);
392 let chunk = &file_data[offset..end];
393 let is_last = end >= total_size;
394
395 use base64::Engine;
396 let encoded = base64::engine::general_purpose::STANDARD.encode(chunk);
397
398 let chunk_msg = LeaderMessage::SnapshotChunk {
399 filename: filename.clone(),
400 data: encoded,
401 chunk_offset: offset as u64,
402 is_last,
403 };
404 send_message(writer, &chunk_msg).await?;
405
406 offset = end;
407 }
408
409 tracing::debug!(
410 "Sent Parquet file {} ({} bytes) to {}",
411 filename,
412 total_size,
413 peer,
414 );
415 }
416
417 let wal_offset_after_snapshot = self
419 .leader_offset
420 .load(std::sync::atomic::Ordering::Relaxed);
421
422 let end_msg = LeaderMessage::SnapshotEnd {
424 wal_offset_after_snapshot,
425 };
426 send_message(writer, &end_msg).await?;
427
428 tracing::info!(
429 "Snapshot transfer complete to {}, resuming WAL from offset {}",
430 peer,
431 wal_offset_after_snapshot,
432 );
433
434 Ok(wal_offset_after_snapshot)
435 }
436
437 async fn handle_follower(self: &Arc<Self>, stream: TcpStream) -> anyhow::Result<()> {
439 let peer = stream.peer_addr()?;
440 let (reader, mut writer) = stream.into_split();
441 let mut reader = BufReader::new(reader);
442
443 let mut line = String::new();
445 reader.read_line(&mut line).await?;
446
447 let subscribe_msg: FollowerMessage = serde_json::from_str(line.trim())?;
448 let last_offset = match subscribe_msg {
449 FollowerMessage::Subscribe { last_offset } => last_offset,
450 _ => {
451 anyhow::bail!("Expected Subscribe message, got: {:?}", subscribe_msg);
452 }
453 };
454
455 tracing::info!(
456 "Follower {} subscribed with last_offset={}",
457 peer,
458 last_offset
459 );
460
461 let follower_id = Uuid::new_v4();
463 self.followers.insert(
464 follower_id,
465 FollowerState {
466 acked_offset: last_offset,
467 connected_at: Instant::now(),
468 },
469 );
470
471 if let Some(ref m) = self.metrics {
473 m.replication_followers_connected
474 .set(self.followers.len() as i64);
475 }
476
477 let mut entry_rx = self.entry_tx.subscribe();
479
480 let resume_offset = if self.needs_snapshot_catchup(last_offset) {
482 tracing::info!(
484 "Follower {} needs snapshot catch-up (last_offset={}, behind WAL range)",
485 peer,
486 last_offset,
487 );
488 match self.send_snapshot(&mut writer, peer).await {
489 Ok(offset) => offset,
490 Err(e) => {
491 tracing::error!("Failed to send snapshot to {}: {}", peer, e);
492 self.followers.remove(&follower_id);
493 return Err(e);
494 }
495 }
496 } else {
497 last_offset
499 };
500
501 let current_offset = self
503 .leader_offset
504 .load(std::sync::atomic::Ordering::Relaxed);
505 let caught_up = LeaderMessage::CaughtUp { current_offset };
506 send_message(&mut writer, &caught_up).await?;
507
508 let followers = Arc::clone(&self.followers);
510 let leader_offset = Arc::clone(&self.leader_offset);
511
512 let followers_ack = Arc::clone(&followers);
514 let ack_metrics = self.metrics.clone();
515 let ack_leader_offset = Arc::clone(&leader_offset);
516 let ack_follower_id_str = follower_id.to_string();
517 let ack_notify = Arc::clone(&self.ack_notify);
518 let ack_task = tokio::spawn(async move {
519 let mut line = String::new();
520 loop {
521 line.clear();
522 match reader.read_line(&mut line).await {
523 Ok(0) => break, Ok(_) => {
525 if let Ok(FollowerMessage::Ack { offset }) =
526 serde_json::from_str(line.trim())
527 && let Some(mut f) = followers_ack.get_mut(&follower_id)
528 {
529 f.acked_offset = offset;
530 ack_notify.notify_waiters();
532 if let Some(ref m) = ack_metrics {
533 m.replication_acks_total.inc();
534 let leader_off =
535 ack_leader_offset.load(std::sync::atomic::Ordering::Relaxed);
536 let lag = leader_off.saturating_sub(offset);
537 m.replication_follower_lag_seconds
538 .with_label_values(&[&ack_follower_id_str])
539 .set(lag as i64);
540 }
541 }
542 }
543 Err(e) => {
544 tracing::debug!("Error reading ACK from follower: {}", e);
545 break;
546 }
547 }
548 }
549 });
550
551 let ship_metrics = self.metrics.clone();
553 let stream_result: anyhow::Result<()> = async {
554 loop {
555 match entry_rx.recv().await {
556 Ok(wal_entry) => {
557 let offset = wal_entry.sequence;
558 if offset > resume_offset {
560 leader_offset.store(offset, std::sync::atomic::Ordering::Relaxed);
561 let msg = LeaderMessage::WalEntry {
562 offset,
563 data: wal_entry,
564 };
565 let json = serde_json::to_string(&msg)?;
566 if let Some(ref m) = ship_metrics {
567 m.replication_wal_shipped_total.inc();
568 m.replication_wal_shipped_bytes_total
569 .inc_by(json.len() as u64);
570 }
571 send_message_raw(&mut writer, json).await?;
572 }
573 }
574 Err(broadcast::error::RecvError::Lagged(n)) => {
575 tracing::warn!(
576 "Follower {} lagged by {} entries, some may be missed",
577 peer,
578 n
579 );
580 }
581 Err(broadcast::error::RecvError::Closed) => {
582 tracing::info!(
583 "Broadcast channel closed, stopping replication to {}",
584 peer
585 );
586 break;
587 }
588 }
589 }
590 Ok(())
591 }
592 .await;
593
594 ack_task.abort();
596 self.followers.remove(&follower_id);
597 if let Some(ref m) = self.metrics {
598 m.replication_followers_connected
599 .set(self.followers.len() as i64);
600 }
601 tracing::info!("Follower {} removed from active set", peer);
602
603 stream_result
604 }
605}
606
607async fn send_message(
609 writer: &mut tokio::net::tcp::OwnedWriteHalf,
610 msg: &LeaderMessage,
611) -> anyhow::Result<()> {
612 let json = serde_json::to_string(msg)?;
613 send_message_raw(writer, json).await
614}
615
616async fn send_message_raw(
618 writer: &mut tokio::net::tcp::OwnedWriteHalf,
619 mut json: String,
620) -> anyhow::Result<()> {
621 json.push('\n');
622 writer.write_all(json.as_bytes()).await?;
623 writer.flush().await?;
624 Ok(())
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh shipper has no followers and reports zeroed lag.
    #[test]
    fn test_wal_shipper_creation() {
        let (shipper, _tx) = WalShipper::new();
        let status = shipper.status();
        assert_eq!(status.followers, 0);
        assert_eq!(status.min_lag_ms, 0);
        assert_eq!(status.max_lag_ms, 0);
    }

    // Status struct serializes field-for-field, with the mode in snake_case.
    #[test]
    fn test_replication_status_serialization() {
        let status = ReplicationStatus {
            followers: 2,
            min_lag_ms: 12,
            max_lag_ms: 45,
            replication_mode: ReplicationMode::Async,
        };
        let json = serde_json::to_value(&status).unwrap();
        assert_eq!(json["followers"], 2);
        assert_eq!(json["min_lag_ms"], 12);
        assert_eq!(json["max_lag_ms"], 45);
        assert_eq!(json["replication_mode"], "async");
    }

    // All accepted spellings map to the right variant; unknown falls back
    // to Async.
    #[test]
    fn test_replication_mode_from_str() {
        assert_eq!(
            ReplicationMode::from_str_value("async"),
            ReplicationMode::Async
        );
        assert_eq!(
            ReplicationMode::from_str_value("semi-sync"),
            ReplicationMode::SemiSync
        );
        assert_eq!(
            ReplicationMode::from_str_value("semi_sync"),
            ReplicationMode::SemiSync
        );
        assert_eq!(
            ReplicationMode::from_str_value("semisync"),
            ReplicationMode::SemiSync
        );
        assert_eq!(
            ReplicationMode::from_str_value("sync"),
            ReplicationMode::Sync
        );
        assert_eq!(
            ReplicationMode::from_str_value("unknown"),
            ReplicationMode::Async
        );
    }

    #[test]
    fn test_replication_mode_display() {
        assert_eq!(ReplicationMode::Async.to_string(), "async");
        assert_eq!(ReplicationMode::SemiSync.to_string(), "semi-sync");
        assert_eq!(ReplicationMode::Sync.to_string(), "sync");
    }

    // serde output uses snake_case, so SemiSync serializes as "semi_sync"
    // (unlike Display, which uses "semi-sync").
    #[test]
    fn test_replication_mode_serialization() {
        let json = serde_json::to_value(ReplicationMode::SemiSync).unwrap();
        assert_eq!(json, "semi_sync");
        let json = serde_json::to_value(ReplicationMode::Sync).unwrap();
        assert_eq!(json, "sync");
        let json = serde_json::to_value(ReplicationMode::Async).unwrap();
        assert_eq!(json, "async");
    }

    // Async mode never waits, even with no followers connected.
    #[tokio::test]
    async fn test_wait_for_ack_async_mode() {
        let (shipper, _tx) = WalShipper::new();
        assert!(shipper.wait_for_ack(100).await);
    }

    // Semi-sync with zero followers cannot reach quorum and returns false.
    #[tokio::test]
    async fn test_wait_for_ack_semi_sync_no_followers() {
        let (mut shipper, _tx) = WalShipper::new();
        shipper.set_replication_mode(ReplicationMode::SemiSync, Duration::from_millis(100));
        assert!(!shipper.wait_for_ack(1).await);
    }

    // Entries published via the returned sender reach a subscriber.
    #[tokio::test]
    async fn test_broadcast_channel_delivery() {
        let (shipper, tx) = WalShipper::new();
        let mut rx = shipper.entry_tx.subscribe();

        let event = crate::test_utils::test_event("test-entity", "test.event");
        let entry = WALEntry::new(1, event);

        tx.send(entry.clone()).unwrap();

        let received = rx.recv().await.unwrap();
        assert_eq!(received.sequence, 1);
    }

    // Without a store wired in, snapshot catch-up is never required.
    #[test]
    fn test_needs_snapshot_catchup_no_store() {
        let (shipper, _tx) = WalShipper::new();
        assert!(!shipper.needs_snapshot_catchup(0));
        assert!(!shipper.needs_snapshot_catchup(100));
    }

    // An empty store (no WAL entries) also requires no snapshot.
    #[test]
    fn test_needs_snapshot_catchup_with_empty_store() {
        let (mut shipper, _tx) = WalShipper::new();
        let store = Arc::new(EventStore::new());
        shipper.set_store(store);

        assert!(!shipper.needs_snapshot_catchup(0));
    }
}