1use crate::{
27 infrastructure::{
28 observability::metrics::MetricsRegistry,
29 persistence::wal::WALEntry,
30 replication::protocol::{FollowerMessage, LeaderMessage},
31 },
32 store::EventStore,
33};
34use dashmap::DashMap;
35use std::{
36 sync::Arc,
37 time::{Duration, Instant},
38};
39use tokio::{
40 io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
41 net::{TcpListener, TcpStream},
42 sync::{Notify, broadcast},
43};
44use uuid::Uuid;
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
48#[serde(rename_all = "snake_case")]
49pub enum ReplicationMode {
50 Async,
52 SemiSync,
55 Sync,
57}
58
59impl ReplicationMode {
60 pub fn from_str_value(s: &str) -> Self {
62 match s.to_lowercase().as_str() {
63 "semi-sync" | "semi_sync" | "semisync" => ReplicationMode::SemiSync,
64 "sync" => ReplicationMode::Sync,
65 _ => ReplicationMode::Async,
66 }
67 }
68}
69
70impl std::fmt::Display for ReplicationMode {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 match self {
73 ReplicationMode::Async => write!(f, "async"),
74 ReplicationMode::SemiSync => write!(f, "semi-sync"),
75 ReplicationMode::Sync => write!(f, "sync"),
76 }
77 }
78}
79
/// Raw bytes per snapshot chunk (512 KiB); each chunk is base64-encoded
/// before being framed as a `SnapshotChunk` message.
const SNAPSHOT_CHUNK_SIZE: usize = 512 * 1024;
82
/// Per-connection bookkeeping for one attached follower.
struct FollowerState {
    // Highest WAL sequence number this follower has acknowledged.
    acked_offset: u64,
    // When the follower connected. NOTE(review): written on connect but never
    // read anywhere in this file — confirm it is used elsewhere or remove it.
    connected_at: Instant,
}
90
/// Point-in-time view of replication health, serializable for status
/// endpoints.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ReplicationStatus {
    // Number of currently connected followers.
    pub followers: usize,
    // Smallest follower lag. NOTE(review): despite the `_ms` suffix, this is
    // computed in `status()` as an offset delta (leader offset minus acked
    // offset), not milliseconds — confirm the intended units.
    pub min_lag_ms: u64,
    // Largest follower lag; same offset-delta caveat as `min_lag_ms`.
    pub max_lag_ms: u64,
    // Replication mode currently in effect on the leader.
    pub replication_mode: ReplicationMode,
}
99
/// Leader-side WAL shipper: accepts follower TCP connections, streams WAL
/// entries to them, and tracks each follower's acknowledged offset.
pub struct WalShipper {
    // Broadcast channel on which new WAL entries are published to all
    // per-follower streaming tasks.
    entry_tx: broadcast::Sender<WALEntry>,
    // Connected followers keyed by a per-connection UUID.
    followers: Arc<DashMap<Uuid, FollowerState>>,
    // Highest WAL offset shipped so far (updated in the streaming loop).
    leader_offset: Arc<std::sync::atomic::AtomicU64>,
    // Event store used for snapshot catch-up; `None` until `set_store`.
    store: Option<Arc<EventStore>>,
    // Metrics sink; `None` until `set_metrics`.
    metrics: Option<Arc<MetricsRegistry>>,
    // Durability mode consulted by `wait_for_ack`.
    replication_mode: ReplicationMode,
    // How long `wait_for_ack` waits before giving up.
    ack_timeout: Duration,
    // Woken whenever a follower ack advances, so `wait_for_ack` can re-check.
    ack_notify: Arc<Notify>,
}
119
120impl WalShipper {
121 pub fn new() -> (Self, broadcast::Sender<WALEntry>) {
124 let (entry_tx, _) = broadcast::channel(4096);
125 let tx_clone = entry_tx.clone();
126 (
127 Self {
128 entry_tx,
129 followers: Arc::new(DashMap::new()),
130 leader_offset: Arc::new(std::sync::atomic::AtomicU64::new(0)),
131 store: None,
132 metrics: None,
133 replication_mode: ReplicationMode::Async,
134 ack_timeout: Duration::from_millis(5000),
135 ack_notify: Arc::new(Notify::new()),
136 },
137 tx_clone,
138 )
139 }
140
141 pub fn set_replication_mode(&mut self, mode: ReplicationMode, ack_timeout: Duration) {
143 self.replication_mode = mode;
144 self.ack_timeout = ack_timeout;
145 }
146
147 pub fn replication_mode(&self) -> ReplicationMode {
149 self.replication_mode
150 }
151
152 pub fn current_leader_offset(&self) -> u64 {
154 self.leader_offset
155 .load(std::sync::atomic::Ordering::Relaxed)
156 }
157
158 #[cfg_attr(feature = "hotpath", hotpath::measure)]
167 pub async fn wait_for_ack(&self, target_offset: u64) -> bool {
168 match self.replication_mode {
169 ReplicationMode::Async => true,
170 ReplicationMode::SemiSync => self.wait_for_ack_inner(target_offset, false).await,
171 ReplicationMode::Sync => self.wait_for_ack_inner(target_offset, true).await,
172 }
173 }
174
175 async fn wait_for_ack_inner(&self, target_offset: u64, all_followers: bool) -> bool {
178 let start = Instant::now();
179 let timeout = self.ack_timeout;
180
181 loop {
182 let follower_count = self.followers.len();
184 if follower_count == 0 {
185 return false;
187 }
188
189 if all_followers {
190 let all_acked = self
192 .followers
193 .iter()
194 .all(|entry| entry.value().acked_offset >= target_offset);
195 if all_acked {
196 return true;
197 }
198 } else {
199 let any_acked = self
201 .followers
202 .iter()
203 .any(|entry| entry.value().acked_offset >= target_offset);
204 if any_acked {
205 return true;
206 }
207 }
208
209 let elapsed = start.elapsed();
211 if elapsed >= timeout {
212 return false;
213 }
214
215 let remaining = timeout - elapsed;
217 if tokio::time::timeout(remaining, self.ack_notify.notified())
218 .await
219 .is_err()
220 {
221 return false;
222 }
223 }
224 }
225
226 pub fn set_metrics(&mut self, metrics: Arc<MetricsRegistry>) {
228 self.metrics = Some(metrics);
229 }
230
231 pub fn set_store(&mut self, store: Arc<EventStore>) {
237 self.store = Some(store);
238 }
239
240 pub fn status(&self) -> ReplicationStatus {
242 let leader_offset = self
243 .leader_offset
244 .load(std::sync::atomic::Ordering::Relaxed);
245 let mut min_lag_ms = u64::MAX;
246 let mut max_lag_ms = 0u64;
247
248 for entry in self.followers.iter() {
249 let follower = entry.value();
250 let lag = leader_offset.saturating_sub(follower.acked_offset);
251 min_lag_ms = min_lag_ms.min(lag);
252 max_lag_ms = max_lag_ms.max(lag);
253 }
254
255 let follower_count = self.followers.len();
256 if follower_count == 0 {
257 min_lag_ms = 0;
258 }
259
260 ReplicationStatus {
261 followers: follower_count,
262 min_lag_ms,
263 max_lag_ms,
264 replication_mode: self.replication_mode,
265 }
266 }
267
268 #[cfg_attr(feature = "hotpath", hotpath::measure)]
270 pub async fn serve(self: Arc<Self>, port: u16) -> anyhow::Result<()> {
271 let addr = format!("0.0.0.0:{port}");
272 let listener = TcpListener::bind(&addr).await?;
273
274 tracing::info!(
275 "Replication server listening on {} (followers can connect)",
276 addr
277 );
278
279 loop {
280 match listener.accept().await {
281 Ok((stream, peer_addr)) => {
282 tracing::info!("Follower connected from {}", peer_addr);
283 let shipper = Arc::clone(&self);
284 tokio::spawn(async move {
285 if let Err(e) = shipper.handle_follower(stream).await {
286 tracing::warn!("Follower {} disconnected: {}", peer_addr, e);
287 }
288 });
289 }
290 Err(e) => {
291 tracing::error!("Failed to accept follower connection: {}", e);
292 }
293 }
294 }
295 }
296
297 fn needs_snapshot_catchup(&self, last_offset: u64) -> bool {
302 if last_offset == 0 {
304 if let Some(ref store) = self.store
305 && let Some(wal) = store.wal()
306 {
307 return wal.current_sequence() > 0;
308 }
309 return false;
310 }
311
312 if let Some(ref store) = self.store
313 && let Some(wal) = store.wal()
314 && let Some(oldest) = wal.oldest_sequence()
315 {
316 return last_offset < oldest;
317 }
318 false
319 }
320
321 async fn send_snapshot(
327 &self,
328 writer: &mut tokio::net::tcp::OwnedWriteHalf,
329 peer: std::net::SocketAddr,
330 ) -> anyhow::Result<u64> {
331 let store = self
332 .store
333 .as_ref()
334 .ok_or_else(|| anyhow::anyhow!("No store available for snapshot catch-up"))?;
335
336 if let Err(e) = store.flush_storage() {
338 tracing::warn!("Failed to flush storage before snapshot: {}", e);
339 }
340
341 let storage = store.parquet_storage().ok_or_else(|| {
342 anyhow::anyhow!("No Parquet storage configured for snapshot catch-up")
343 })?;
344
345 let parquet_files = {
347 let storage_guard = storage.read();
348 storage_guard.list_parquet_files()?
349 };
350
351 if parquet_files.is_empty() {
352 tracing::info!("No Parquet files to send for snapshot catch-up to {}", peer);
353 let current_offset = self
354 .leader_offset
355 .load(std::sync::atomic::Ordering::Relaxed);
356 return Ok(current_offset);
357 }
358
359 let filenames: Vec<String> = parquet_files
361 .iter()
362 .filter_map(|p| p.file_name().map(|n| n.to_string_lossy().to_string()))
363 .collect();
364
365 tracing::info!(
366 "Sending Parquet snapshot to {} ({} files: {:?})",
367 peer,
368 filenames.len(),
369 filenames,
370 );
371
372 let start_msg = LeaderMessage::SnapshotStart {
374 parquet_files: filenames,
375 };
376 send_message(writer, &start_msg).await?;
377
378 for file_path in &parquet_files {
380 let filename = file_path
381 .file_name()
382 .map(|n| n.to_string_lossy().to_string())
383 .unwrap_or_default();
384
385 let file_data = tokio::fs::read(file_path).await.map_err(|e| {
386 anyhow::anyhow!("Failed to read Parquet file {}: {}", file_path.display(), e)
387 })?;
388
389 let total_size = file_data.len();
390 let mut offset: usize = 0;
391
392 while offset < total_size {
393 let end = (offset + SNAPSHOT_CHUNK_SIZE).min(total_size);
394 let chunk = &file_data[offset..end];
395 let is_last = end >= total_size;
396
397 use base64::Engine;
398 let encoded = base64::engine::general_purpose::STANDARD.encode(chunk);
399
400 let chunk_msg = LeaderMessage::SnapshotChunk {
401 filename: filename.clone(),
402 data: encoded,
403 chunk_offset: offset as u64,
404 is_last,
405 };
406 send_message(writer, &chunk_msg).await?;
407
408 offset = end;
409 }
410
411 tracing::debug!(
412 "Sent Parquet file {} ({} bytes) to {}",
413 filename,
414 total_size,
415 peer,
416 );
417 }
418
419 let wal_offset_after_snapshot = self
421 .leader_offset
422 .load(std::sync::atomic::Ordering::Relaxed);
423
424 let end_msg = LeaderMessage::SnapshotEnd {
426 wal_offset_after_snapshot,
427 };
428 send_message(writer, &end_msg).await?;
429
430 tracing::info!(
431 "Snapshot transfer complete to {}, resuming WAL from offset {}",
432 peer,
433 wal_offset_after_snapshot,
434 );
435
436 Ok(wal_offset_after_snapshot)
437 }
438
439 async fn handle_follower(self: &Arc<Self>, stream: TcpStream) -> anyhow::Result<()> {
441 let peer = stream.peer_addr()?;
442 let (reader, mut writer) = stream.into_split();
443 let mut reader = BufReader::new(reader);
444
445 let mut line = String::new();
447 reader.read_line(&mut line).await?;
448
449 let subscribe_msg: FollowerMessage = serde_json::from_str(line.trim())?;
450 let FollowerMessage::Subscribe { last_offset } = subscribe_msg else {
451 anyhow::bail!("Expected Subscribe message, got: {subscribe_msg:?}");
452 };
453
454 tracing::info!(
455 "Follower {} subscribed with last_offset={}",
456 peer,
457 last_offset
458 );
459
460 let follower_id = Uuid::new_v4();
462 self.followers.insert(
463 follower_id,
464 FollowerState {
465 acked_offset: last_offset,
466 connected_at: Instant::now(),
467 },
468 );
469
470 if let Some(ref m) = self.metrics {
472 m.replication_followers_connected
473 .set(self.followers.len() as i64);
474 }
475
476 let mut entry_rx = self.entry_tx.subscribe();
478
479 let resume_offset = if self.needs_snapshot_catchup(last_offset) {
481 tracing::info!(
483 "Follower {} needs snapshot catch-up (last_offset={}, behind WAL range)",
484 peer,
485 last_offset,
486 );
487 match self.send_snapshot(&mut writer, peer).await {
488 Ok(offset) => offset,
489 Err(e) => {
490 tracing::error!("Failed to send snapshot to {}: {}", peer, e);
491 self.followers.remove(&follower_id);
492 return Err(e);
493 }
494 }
495 } else {
496 last_offset
498 };
499
500 let current_offset = self
502 .leader_offset
503 .load(std::sync::atomic::Ordering::Relaxed);
504 let caught_up = LeaderMessage::CaughtUp { current_offset };
505 send_message(&mut writer, &caught_up).await?;
506
507 let followers = Arc::clone(&self.followers);
509 let leader_offset = Arc::clone(&self.leader_offset);
510
511 let followers_ack = Arc::clone(&followers);
513 let ack_metrics = self.metrics.clone();
514 let ack_leader_offset = Arc::clone(&leader_offset);
515 let ack_follower_id_str = follower_id.to_string();
516 let ack_notify = Arc::clone(&self.ack_notify);
517 let ack_task = tokio::spawn(async move {
518 let mut line = String::new();
519 loop {
520 line.clear();
521 match reader.read_line(&mut line).await {
522 Ok(0) => break, Ok(_) => {
524 if let Ok(FollowerMessage::Ack { offset }) =
525 serde_json::from_str(line.trim())
526 && let Some(mut f) = followers_ack.get_mut(&follower_id)
527 {
528 f.acked_offset = offset;
529 ack_notify.notify_waiters();
531 if let Some(ref m) = ack_metrics {
532 m.replication_acks_total.inc();
533 let leader_off =
534 ack_leader_offset.load(std::sync::atomic::Ordering::Relaxed);
535 let lag = leader_off.saturating_sub(offset);
536 m.replication_follower_lag_seconds
537 .with_label_values(&[&ack_follower_id_str])
538 .set(lag as i64);
539 }
540 }
541 }
542 Err(e) => {
543 tracing::debug!("Error reading ACK from follower: {}", e);
544 break;
545 }
546 }
547 }
548 });
549
550 let ship_metrics = self.metrics.clone();
552 let stream_result: anyhow::Result<()> = async {
553 loop {
554 match entry_rx.recv().await {
555 Ok(wal_entry) => {
556 let offset = wal_entry.sequence;
557 if offset > resume_offset {
559 leader_offset.store(offset, std::sync::atomic::Ordering::Relaxed);
560 let msg = LeaderMessage::WalEntry {
561 offset,
562 data: wal_entry,
563 };
564 let json = serde_json::to_string(&msg)?;
565 if let Some(ref m) = ship_metrics {
566 m.replication_wal_shipped_total.inc();
567 m.replication_wal_shipped_bytes_total
568 .inc_by(json.len() as u64);
569 }
570 send_message_raw(&mut writer, json).await?;
571 }
572 }
573 Err(broadcast::error::RecvError::Lagged(n)) => {
574 tracing::warn!(
575 "Follower {} lagged by {} entries, some may be missed",
576 peer,
577 n
578 );
579 }
580 Err(broadcast::error::RecvError::Closed) => {
581 tracing::info!(
582 "Broadcast channel closed, stopping replication to {}",
583 peer
584 );
585 break;
586 }
587 }
588 }
589 Ok(())
590 }
591 .await;
592
593 ack_task.abort();
595 self.followers.remove(&follower_id);
596 if let Some(ref m) = self.metrics {
597 m.replication_followers_connected
598 .set(self.followers.len() as i64);
599 }
600 tracing::info!("Follower {} removed from active set", peer);
601
602 stream_result
603 }
604}
605
606async fn send_message(
608 writer: &mut tokio::net::tcp::OwnedWriteHalf,
609 msg: &LeaderMessage,
610) -> anyhow::Result<()> {
611 let json = serde_json::to_string(msg)?;
612 send_message_raw(writer, json).await
613}
614
615async fn send_message_raw(
617 writer: &mut tokio::net::tcp::OwnedWriteHalf,
618 mut json: String,
619) -> anyhow::Result<()> {
620 json.push('\n');
621 writer.write_all(json.as_bytes()).await?;
622 writer.flush().await?;
623 Ok(())
624}
625
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh shipper reports zero followers and zero lag.
    #[test]
    fn test_wal_shipper_creation() {
        let (shipper, _tx) = WalShipper::new();
        let status = shipper.status();
        assert_eq!(status.followers, 0);
        assert_eq!(status.min_lag_ms, 0);
        assert_eq!(status.max_lag_ms, 0);
    }

    // Status serializes field-for-field, with the mode in snake_case.
    #[test]
    fn test_replication_status_serialization() {
        let status = ReplicationStatus {
            followers: 2,
            min_lag_ms: 12,
            max_lag_ms: 45,
            replication_mode: ReplicationMode::Async,
        };
        let json = serde_json::to_value(&status).unwrap();
        assert_eq!(json["followers"], 2);
        assert_eq!(json["min_lag_ms"], 12);
        assert_eq!(json["max_lag_ms"], 45);
        assert_eq!(json["replication_mode"], "async");
    }

    // All accepted spellings parse; unknown strings fall back to Async.
    #[test]
    fn test_replication_mode_from_str() {
        assert_eq!(
            ReplicationMode::from_str_value("async"),
            ReplicationMode::Async
        );
        assert_eq!(
            ReplicationMode::from_str_value("semi-sync"),
            ReplicationMode::SemiSync
        );
        assert_eq!(
            ReplicationMode::from_str_value("semi_sync"),
            ReplicationMode::SemiSync
        );
        assert_eq!(
            ReplicationMode::from_str_value("semisync"),
            ReplicationMode::SemiSync
        );
        assert_eq!(
            ReplicationMode::from_str_value("sync"),
            ReplicationMode::Sync
        );
        assert_eq!(
            ReplicationMode::from_str_value("unknown"),
            ReplicationMode::Async
        );
    }

    // Display uses human-facing labels ("semi-sync", not "semi_sync").
    #[test]
    fn test_replication_mode_display() {
        assert_eq!(ReplicationMode::Async.to_string(), "async");
        assert_eq!(ReplicationMode::SemiSync.to_string(), "semi-sync");
        assert_eq!(ReplicationMode::Sync.to_string(), "sync");
    }

    // serde uses snake_case variant names, distinct from Display.
    #[test]
    fn test_replication_mode_serialization() {
        let json = serde_json::to_value(ReplicationMode::SemiSync).unwrap();
        assert_eq!(json, "semi_sync");
        let json = serde_json::to_value(ReplicationMode::Sync).unwrap();
        assert_eq!(json, "sync");
        let json = serde_json::to_value(ReplicationMode::Async).unwrap();
        assert_eq!(json, "async");
    }

    // Async mode never blocks, even with zero followers.
    #[tokio::test]
    async fn test_wait_for_ack_async_mode() {
        let (shipper, _tx) = WalShipper::new();
        assert!(shipper.wait_for_ack(100).await);
    }

    // Semi-sync with no followers fails (fast) rather than succeeding.
    #[tokio::test]
    async fn test_wait_for_ack_semi_sync_no_followers() {
        let (mut shipper, _tx) = WalShipper::new();
        shipper.set_replication_mode(ReplicationMode::SemiSync, Duration::from_millis(100));
        assert!(!shipper.wait_for_ack(1).await);
    }

    // Entries sent via the returned sender reach broadcast subscribers.
    #[tokio::test]
    async fn test_broadcast_channel_delivery() {
        let (shipper, tx) = WalShipper::new();
        let mut rx = shipper.entry_tx.subscribe();

        let event = crate::test_utils::test_event("test-entity", "test.event");
        let entry = WALEntry::new(1, event);

        tx.send(entry.clone()).unwrap();

        let received = rx.recv().await.unwrap();
        assert_eq!(received.sequence, 1);
    }

    // Without a store attached, snapshot catch-up is never required.
    #[test]
    fn test_needs_snapshot_catchup_no_store() {
        let (shipper, _tx) = WalShipper::new();
        assert!(!shipper.needs_snapshot_catchup(0));
        assert!(!shipper.needs_snapshot_catchup(100));
    }

    // An empty store (no WAL entries) also requires no snapshot.
    #[test]
    fn test_needs_snapshot_catchup_with_empty_store() {
        let (mut shipper, _tx) = WalShipper::new();
        let store = Arc::new(EventStore::new());
        shipper.set_store(store);

        assert!(!shipper.needs_snapshot_catchup(0));
    }
}