use crate::{
infrastructure::{
observability::metrics::MetricsRegistry,
persistence::wal::WALEntry,
replication::protocol::{FollowerMessage, LeaderMessage},
},
store::EventStore,
};
use dashmap::DashMap;
use std::{
sync::Arc,
time::{Duration, Instant},
};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::{TcpListener, TcpStream},
sync::{Notify, broadcast},
};
use uuid::Uuid;
/// Durability mode governing how long a write waits on follower
/// acknowledgements (see `WalShipper::wait_for_ack`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ReplicationMode {
/// Ship WAL entries without waiting for any follower ack.
Async,
/// Wait until at least one follower acks the target offset.
SemiSync,
/// Wait until every connected follower acks the target offset.
Sync,
}
impl ReplicationMode {
    /// Parse a mode from its textual form (case-insensitive).
    ///
    /// Recognizes `"sync"` plus the `"semi-sync"` / `"semi_sync"` /
    /// `"semisync"` spellings; any other input falls back to `Async`.
    pub fn from_str_value(s: &str) -> Self {
        let normalized = s.to_lowercase();
        if normalized == "sync" {
            ReplicationMode::Sync
        } else if matches!(normalized.as_str(), "semi-sync" | "semi_sync" | "semisync") {
            ReplicationMode::SemiSync
        } else {
            ReplicationMode::Async
        }
    }
}
impl std::fmt::Display for ReplicationMode {
    /// Render the mode in its human-readable, dash-separated form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ReplicationMode::Async => "async",
            ReplicationMode::SemiSync => "semi-sync",
            ReplicationMode::Sync => "sync",
        };
        f.write_str(label)
    }
}
// Maximum raw bytes per snapshot chunk shipped to a catching-up follower
// (each chunk is base64-encoded before being framed as JSON).
const SNAPSHOT_CHUNK_SIZE: usize = 512 * 1024;
/// Per-follower bookkeeping kept by the leader while a connection is live.
struct FollowerState {
/// Highest WAL offset this follower has acknowledged.
acked_offset: u64,
/// When the follower subscribed.
/// NOTE(review): write-only in the visible code — presumably kept for
/// future diagnostics; confirm it is read elsewhere.
connected_at: Instant,
}
/// Snapshot of replication health, produced by `WalShipper::status`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ReplicationStatus {
/// Number of currently connected followers.
pub followers: usize,
/// Smallest follower lag.
/// NOTE(review): populated with a WAL *offset* delta, not milliseconds,
/// despite the field name — confirm intended units with consumers.
pub min_lag_ms: u64,
/// Largest follower lag (same offset-delta caveat as `min_lag_ms`).
pub max_lag_ms: u64,
/// Active replication durability mode.
pub replication_mode: ReplicationMode,
}
/// Leader-side WAL replication service.
///
/// Accepts follower TCP connections, streams WAL entries as
/// newline-delimited JSON, performs Parquet snapshot catch-up for stale
/// followers, and tracks per-follower acknowledgements so writes can wait
/// for semi-sync / sync durability.
pub struct WalShipper {
/// Fan-out channel of newly appended WAL entries; each follower
/// connection holds its own subscription.
entry_tx: broadcast::Sender<WALEntry>,
/// Currently connected followers, keyed by a per-connection UUID.
followers: Arc<DashMap<Uuid, FollowerState>>,
/// Highest WAL offset shipped to followers (updated in the ship loop).
leader_offset: Arc<std::sync::atomic::AtomicU64>,
/// Event store used for snapshot catch-up; `None` disables snapshots.
store: Option<Arc<EventStore>>,
/// Optional metrics registry for replication counters and gauges.
metrics: Option<Arc<MetricsRegistry>>,
/// Durability mode consulted by `wait_for_ack`.
replication_mode: ReplicationMode,
/// How long `wait_for_ack` blocks before giving up (non-async modes).
ack_timeout: Duration,
/// Signalled whenever a follower ack advances, waking ack waiters.
ack_notify: Arc<Notify>,
}
impl WalShipper {
/// Construct a shipper with default settings (async replication, 5 s ack
/// timeout) and return it together with a second handle to the WAL-entry
/// broadcast channel for producers to publish into.
pub fn new() -> (Self, broadcast::Sender<WALEntry>) {
    let (entry_tx, _rx) = broadcast::channel(4096);
    let producer_tx = entry_tx.clone();
    let shipper = Self {
        entry_tx,
        followers: Arc::new(DashMap::new()),
        leader_offset: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        store: None,
        metrics: None,
        replication_mode: ReplicationMode::Async,
        ack_timeout: Duration::from_millis(5000),
        ack_notify: Arc::new(Notify::new()),
    };
    (shipper, producer_tx)
}
/// Configure the durability mode and the timeout applied by
/// `wait_for_ack` in non-async modes.
pub fn set_replication_mode(&mut self, mode: ReplicationMode, ack_timeout: Duration) {
self.replication_mode = mode;
self.ack_timeout = ack_timeout;
}
/// Currently configured replication durability mode.
pub fn replication_mode(&self) -> ReplicationMode {
self.replication_mode
}
/// Highest WAL offset the leader has shipped so far (relaxed read).
pub fn current_leader_offset(&self) -> u64 {
self.leader_offset
.load(std::sync::atomic::Ordering::Relaxed)
}
/// Wait until `target_offset` is replicated per the configured mode:
/// `Async` succeeds immediately, `SemiSync` waits for any one follower,
/// `Sync` waits for all followers. Returns `false` on timeout or when no
/// followers are connected (non-async modes only).
#[cfg_attr(feature = "hotpath", hotpath::measure)]
pub async fn wait_for_ack(&self, target_offset: u64) -> bool {
match self.replication_mode {
ReplicationMode::Async => true,
ReplicationMode::SemiSync => self.wait_for_ack_inner(target_offset, false).await,
ReplicationMode::Sync => self.wait_for_ack_inner(target_offset, true).await,
}
}
/// Block until the ack condition for `target_offset` is met or until
/// `self.ack_timeout` elapses.
///
/// * `all_followers == false` (semi-sync): succeeds once ANY follower has
///   acknowledged `target_offset`.
/// * `all_followers == true` (sync): succeeds once EVERY follower has
///   acknowledged `target_offset`.
///
/// Returns `false` immediately when no followers are connected, and
/// `false` on timeout.
async fn wait_for_ack_inner(&self, target_offset: u64, all_followers: bool) -> bool {
    let start = Instant::now();
    let timeout = self.ack_timeout;
    // Evaluate the ack predicate against the current follower set.
    // `None` means no followers are connected (we can never succeed).
    let check = || -> Option<bool> {
        if self.followers.is_empty() {
            return None;
        }
        let satisfied = if all_followers {
            self.followers
                .iter()
                .all(|entry| entry.value().acked_offset >= target_offset)
        } else {
            self.followers
                .iter()
                .any(|entry| entry.value().acked_offset >= target_offset)
        };
        Some(satisfied)
    };
    loop {
        match check() {
            None => return false,
            Some(true) => return true,
            Some(false) => {}
        }
        let elapsed = start.elapsed();
        if elapsed >= timeout {
            return false;
        }
        let remaining = timeout - elapsed;
        // `Notify::notify_waiters` only wakes tasks already parked in
        // `notified()`, so an ack landing between the check above and this
        // await is a lost wakeup. If the wait times out, perform one final
        // predicate check instead of returning false outright; this turns
        // a lost wakeup into (at worst) a full-timeout wait rather than a
        // false negative for a write that actually replicated in time.
        if tokio::time::timeout(remaining, self.ack_notify.notified())
            .await
            .is_err()
        {
            return matches!(check(), Some(true));
        }
    }
}
/// Attach a metrics registry; replication counters/gauges are updated
/// only when this is set.
pub fn set_metrics(&mut self, metrics: Arc<MetricsRegistry>) {
self.metrics = Some(metrics);
}
/// Attach the event store used for snapshot catch-up; without it,
/// followers behind the WAL range cannot be served a snapshot.
pub fn set_store(&mut self, store: Arc<EventStore>) {
self.store = Some(store);
}
/// Report current replication health: follower count, min/max follower
/// lag, and the active mode.
///
/// NOTE(review): the `*_lag_ms` fields are filled with WAL *offset*
/// deltas (leader offset minus acked offset), not milliseconds — confirm
/// the intended units.
pub fn status(&self) -> ReplicationStatus {
let leader_offset = self
.leader_offset
.load(std::sync::atomic::Ordering::Relaxed);
let mut min_lag_ms = u64::MAX;
let mut max_lag_ms = 0u64;
for entry in self.followers.iter() {
let follower = entry.value();
let lag = leader_offset.saturating_sub(follower.acked_offset);
min_lag_ms = min_lag_ms.min(lag);
max_lag_ms = max_lag_ms.max(lag);
}
let follower_count = self.followers.len();
// With no followers the u64::MAX scan sentinel never got replaced;
// report 0 instead.
if follower_count == 0 {
min_lag_ms = 0;
}
ReplicationStatus {
followers: follower_count,
min_lag_ms,
max_lag_ms,
replication_mode: self.replication_mode,
}
}
/// Bind the replication listener on `0.0.0.0:<port>` and accept follower
/// connections forever; each follower is handled on its own task. Accept
/// errors are logged and the loop continues.
#[cfg_attr(feature = "hotpath", hotpath::measure)]
pub async fn serve(self: Arc<Self>, port: u16) -> anyhow::Result<()> {
    let addr = format!("0.0.0.0:{port}");
    let listener = TcpListener::bind(&addr).await?;
    tracing::info!(
        "Replication server listening on {} (followers can connect)",
        addr
    );
    loop {
        // Guard-style: bail out of this iteration on accept failure.
        let (stream, peer_addr) = match listener.accept().await {
            Ok(pair) => pair,
            Err(e) => {
                tracing::error!("Failed to accept follower connection: {}", e);
                continue;
            }
        };
        tracing::info!("Follower connected from {}", peer_addr);
        let shipper = Arc::clone(&self);
        tokio::spawn(async move {
            if let Err(e) = shipper.handle_follower(stream).await {
                tracing::warn!("Follower {} disconnected: {}", peer_addr, e);
            }
        });
    }
}
fn needs_snapshot_catchup(&self, last_offset: u64) -> bool {
if last_offset == 0 {
if let Some(ref store) = self.store
&& let Some(wal) = store.wal()
{
return wal.current_sequence() > 0;
}
return false;
}
if let Some(ref store) = self.store
&& let Some(wal) = store.wal()
&& let Some(oldest) = wal.oldest_sequence()
{
return last_offset < oldest;
}
false
}
/// Stream a full Parquet snapshot to a follower that is too far behind
/// the retained WAL.
///
/// Sequence: best-effort storage flush → `SnapshotStart` listing the
/// files → base64-encoded `SnapshotChunk`s of `SNAPSHOT_CHUNK_SIZE` raw
/// bytes each → `SnapshotEnd` carrying the leader offset from which WAL
/// streaming should resume.
///
/// Returns the WAL offset the follower should resume from.
///
/// # Errors
/// Fails when no store or Parquet storage is configured, or on file-read
/// or socket-write errors.
async fn send_snapshot(
&self,
writer: &mut tokio::net::tcp::OwnedWriteHalf,
peer: std::net::SocketAddr,
) -> anyhow::Result<u64> {
let store = self
.store
.as_ref()
.ok_or_else(|| anyhow::anyhow!("No store available for snapshot catch-up"))?;
// Best-effort flush so the Parquet files reflect recent writes; a
// failure is logged but does not abort the transfer.
if let Err(e) = store.flush_storage() {
tracing::warn!("Failed to flush storage before snapshot: {}", e);
}
let storage = store.parquet_storage().ok_or_else(|| {
anyhow::anyhow!("No Parquet storage configured for snapshot catch-up")
})?;
// Hold the storage read lock only long enough to list the files.
let parquet_files = {
let storage_guard = storage.read();
storage_guard.list_parquet_files()?
};
// Nothing persisted yet: the follower can resume from the current
// leader offset without any file transfer.
if parquet_files.is_empty() {
tracing::info!("No Parquet files to send for snapshot catch-up to {}", peer);
let current_offset = self
.leader_offset
.load(std::sync::atomic::Ordering::Relaxed);
return Ok(current_offset);
}
let filenames: Vec<String> = parquet_files
.iter()
.filter_map(|p| p.file_name().map(|n| n.to_string_lossy().to_string()))
.collect();
tracing::info!(
"Sending Parquet snapshot to {} ({} files: {:?})",
peer,
filenames.len(),
filenames,
);
let start_msg = LeaderMessage::SnapshotStart {
parquet_files: filenames,
};
send_message(writer, &start_msg).await?;
for file_path in &parquet_files {
let filename = file_path
.file_name()
.map(|n| n.to_string_lossy().to_string())
.unwrap_or_default();
// The whole file is read into memory, then shipped in fixed-size
// chunks, each base64-encoded for the JSON framing.
let file_data = tokio::fs::read(file_path).await.map_err(|e| {
anyhow::anyhow!("Failed to read Parquet file {}: {}", file_path.display(), e)
})?;
let total_size = file_data.len();
let mut offset: usize = 0;
while offset < total_size {
let end = (offset + SNAPSHOT_CHUNK_SIZE).min(total_size);
let chunk = &file_data[offset..end];
let is_last = end >= total_size;
use base64::Engine;
let encoded = base64::engine::general_purpose::STANDARD.encode(chunk);
let chunk_msg = LeaderMessage::SnapshotChunk {
filename: filename.clone(),
data: encoded,
chunk_offset: offset as u64,
is_last,
};
send_message(writer, &chunk_msg).await?;
offset = end;
}
tracing::debug!(
"Sent Parquet file {} ({} bytes) to {}",
filename,
total_size,
peer,
);
}
// NOTE(review): this offset is sampled *after* the file transfer, so WAL
// entries appended after flush_storage() but before this read are neither
// in the snapshot files nor replayed later (streaming skips offsets <=
// the resume offset) — confirm this window is acceptable or closed
// elsewhere.
let wal_offset_after_snapshot = self
.leader_offset
.load(std::sync::atomic::Ordering::Relaxed);
let end_msg = LeaderMessage::SnapshotEnd {
wal_offset_after_snapshot,
};
send_message(writer, &end_msg).await?;
tracing::info!(
"Snapshot transfer complete to {}, resuming WAL from offset {}",
peer,
wal_offset_after_snapshot,
);
Ok(wal_offset_after_snapshot)
}
/// Serve one follower connection end-to-end.
///
/// Protocol (newline-delimited JSON over TCP):
/// 1. Read the follower's `Subscribe { last_offset }` handshake.
/// 2. Register the follower and, if it is behind the retained WAL range,
///    ship a full Parquet snapshot first.
/// 3. Send `CaughtUp`, then stream broadcast WAL entries until the
///    channel closes or the socket errors; a side task consumes `Ack`
///    frames to advance the follower's acked offset.
///
/// On every exit path the follower is removed from the active set and
/// the connected-followers gauge is refreshed.
async fn handle_follower(self: &Arc<Self>, stream: TcpStream) -> anyhow::Result<()> {
let peer = stream.peer_addr()?;
let (reader, mut writer) = stream.into_split();
let mut reader = BufReader::new(reader);
let mut line = String::new();
// Handshake: the first line from the follower must be Subscribe.
reader.read_line(&mut line).await?;
let subscribe_msg: FollowerMessage = serde_json::from_str(line.trim())?;
let FollowerMessage::Subscribe { last_offset } = subscribe_msg else {
anyhow::bail!("Expected Subscribe message, got: {subscribe_msg:?}");
};
tracing::info!(
"Follower {} subscribed with last_offset={}",
peer,
last_offset
);
// Register before any streaming so the ack task can find this entry.
let follower_id = Uuid::new_v4();
self.followers.insert(
follower_id,
FollowerState {
acked_offset: last_offset,
connected_at: Instant::now(),
},
);
if let Some(ref m) = self.metrics {
m.replication_followers_connected
.set(self.followers.len() as i64);
}
// Subscribe to the live WAL broadcast *before* any (possibly long)
// snapshot transfer, so entries published meanwhile are buffered by the
// channel (up to its capacity) instead of never seen.
let mut entry_rx = self.entry_tx.subscribe();
// Resume point for live streaming: either the follower's own offset, or
// the leader offset captured at the end of a snapshot transfer.
let resume_offset = if self.needs_snapshot_catchup(last_offset) {
tracing::info!(
"Follower {} needs snapshot catch-up (last_offset={}, behind WAL range)",
peer,
last_offset,
);
match self.send_snapshot(&mut writer, peer).await {
Ok(offset) => offset,
Err(e) => {
tracing::error!("Failed to send snapshot to {}: {}", peer, e);
self.followers.remove(&follower_id);
return Err(e);
}
}
} else {
last_offset
};
let current_offset = self
.leader_offset
.load(std::sync::atomic::Ordering::Relaxed);
let caught_up = LeaderMessage::CaughtUp { current_offset };
send_message(&mut writer, &caught_up).await?;
let followers = Arc::clone(&self.followers);
let leader_offset = Arc::clone(&self.leader_offset);
let followers_ack = Arc::clone(&followers);
let ack_metrics = self.metrics.clone();
let ack_leader_offset = Arc::clone(&leader_offset);
let ack_follower_id_str = follower_id.to_string();
let ack_notify = Arc::clone(&self.ack_notify);
// Side task: consume Ack frames from the follower, advance its acked
// offset, wake wait_for_ack waiters, and export lag metrics.
// NOTE(review): replication_follower_lag_seconds is set to an offset
// delta, not seconds — confirm the metric's intended units.
let ack_task = tokio::spawn(async move {
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
// Ok(0) = EOF: follower closed its write side.
Ok(0) => break, Ok(_) => {
if let Ok(FollowerMessage::Ack { offset }) =
serde_json::from_str(line.trim())
&& let Some(mut f) = followers_ack.get_mut(&follower_id)
{
f.acked_offset = offset;
ack_notify.notify_waiters();
if let Some(ref m) = ack_metrics {
m.replication_acks_total.inc();
let leader_off =
ack_leader_offset.load(std::sync::atomic::Ordering::Relaxed);
let lag = leader_off.saturating_sub(offset);
m.replication_follower_lag_seconds
.with_label_values(&[&ack_follower_id_str])
.set(lag as i64);
}
}
}
Err(e) => {
tracing::debug!("Error reading ACK from follower: {}", e);
break;
}
}
}
});
let ship_metrics = self.metrics.clone();
// Ship loop: forward broadcast WAL entries to this follower.
let stream_result: anyhow::Result<()> = async {
loop {
match entry_rx.recv().await {
Ok(wal_entry) => {
let offset = wal_entry.sequence;
// Only entries past the resume point; earlier ones are covered by
// the snapshot or the follower's own WAL.
if offset > resume_offset {
leader_offset.store(offset, std::sync::atomic::Ordering::Relaxed);
let msg = LeaderMessage::WalEntry {
offset,
data: wal_entry,
};
let json = serde_json::to_string(&msg)?;
if let Some(ref m) = ship_metrics {
m.replication_wal_shipped_total.inc();
m.replication_wal_shipped_bytes_total
.inc_by(json.len() as u64);
}
send_message_raw(&mut writer, json).await?;
}
}
// The broadcast buffer overflowed for this subscriber; skipped
// entries are logged but not re-fetched here.
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!(
"Follower {} lagged by {} entries, some may be missed",
peer,
n
);
}
Err(broadcast::error::RecvError::Closed) => {
tracing::info!(
"Broadcast channel closed, stopping replication to {}",
peer
);
break;
}
}
}
Ok(())
}
.await;
// Teardown: stop the ack reader and unregister the follower regardless
// of how the ship loop ended.
ack_task.abort();
self.followers.remove(&follower_id);
if let Some(ref m) = self.metrics {
m.replication_followers_connected
.set(self.followers.len() as i64);
}
tracing::info!("Follower {} removed from active set", peer);
stream_result
}
}
/// Serialize `msg` to JSON and ship it as one newline-delimited frame.
async fn send_message(
    writer: &mut tokio::net::tcp::OwnedWriteHalf,
    msg: &LeaderMessage,
) -> anyhow::Result<()> {
    send_message_raw(writer, serde_json::to_string(msg)?).await
}
/// Append the newline frame delimiter to a pre-serialized JSON payload,
/// write it to the follower socket, and flush so the frame is delivered
/// immediately rather than sitting in a buffer.
async fn send_message_raw(
    writer: &mut tokio::net::tcp::OwnedWriteHalf,
    mut json: String,
) -> anyhow::Result<()> {
    json.push('\n');
    let frame = json.into_bytes();
    writer.write_all(&frame).await?;
    writer.flush().await?;
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
// A fresh shipper has no followers and reports zero lag in both fields.
#[test]
fn test_wal_shipper_creation() {
let (shipper, _tx) = WalShipper::new();
let status = shipper.status();
assert_eq!(status.followers, 0);
assert_eq!(status.min_lag_ms, 0);
assert_eq!(status.max_lag_ms, 0);
}
// ReplicationStatus serializes field-for-field, with the mode rendered in
// snake_case per the serde attribute.
#[test]
fn test_replication_status_serialization() {
let status = ReplicationStatus {
followers: 2,
min_lag_ms: 12,
max_lag_ms: 45,
replication_mode: ReplicationMode::Async,
};
let json = serde_json::to_value(&status).unwrap();
assert_eq!(json["followers"], 2);
assert_eq!(json["min_lag_ms"], 12);
assert_eq!(json["max_lag_ms"], 45);
assert_eq!(json["replication_mode"], "async");
}
// All accepted spellings parse; unknown input falls back to Async.
#[test]
fn test_replication_mode_from_str() {
assert_eq!(
ReplicationMode::from_str_value("async"),
ReplicationMode::Async
);
assert_eq!(
ReplicationMode::from_str_value("semi-sync"),
ReplicationMode::SemiSync
);
assert_eq!(
ReplicationMode::from_str_value("semi_sync"),
ReplicationMode::SemiSync
);
assert_eq!(
ReplicationMode::from_str_value("semisync"),
ReplicationMode::SemiSync
);
assert_eq!(
ReplicationMode::from_str_value("sync"),
ReplicationMode::Sync
);
assert_eq!(
ReplicationMode::from_str_value("unknown"),
ReplicationMode::Async
);
}
// Display uses the dash-separated human-readable form.
#[test]
fn test_replication_mode_display() {
assert_eq!(ReplicationMode::Async.to_string(), "async");
assert_eq!(ReplicationMode::SemiSync.to_string(), "semi-sync");
assert_eq!(ReplicationMode::Sync.to_string(), "sync");
}
// Serde serialization uses snake_case (distinct from Display's dashes).
#[test]
fn test_replication_mode_serialization() {
let json = serde_json::to_value(ReplicationMode::SemiSync).unwrap();
assert_eq!(json, "semi_sync");
let json = serde_json::to_value(ReplicationMode::Sync).unwrap();
assert_eq!(json, "sync");
let json = serde_json::to_value(ReplicationMode::Async).unwrap();
assert_eq!(json, "async");
}
// Async mode never waits: any offset is immediately "acked".
#[tokio::test]
async fn test_wait_for_ack_async_mode() {
let (shipper, _tx) = WalShipper::new();
assert!(shipper.wait_for_ack(100).await);
}
// Semi-sync with zero followers fails fast rather than waiting out the
// full timeout.
#[tokio::test]
async fn test_wait_for_ack_semi_sync_no_followers() {
let (mut shipper, _tx) = WalShipper::new();
shipper.set_replication_mode(ReplicationMode::SemiSync, Duration::from_millis(100));
assert!(!shipper.wait_for_ack(1).await);
}
// Entries published through the returned sender reach subscribers of the
// shipper's internal broadcast channel.
#[tokio::test]
async fn test_broadcast_channel_delivery() {
let (shipper, tx) = WalShipper::new();
let mut rx = shipper.entry_tx.subscribe();
let event = crate::test_utils::test_event("test-entity", "test.event");
let entry = WALEntry::new(1, event);
tx.send(entry.clone()).unwrap();
let received = rx.recv().await.unwrap();
assert_eq!(received.sequence, 1);
}
// Without a store configured, no snapshot catch-up is ever needed.
#[test]
fn test_needs_snapshot_catchup_no_store() {
let (shipper, _tx) = WalShipper::new();
assert!(!shipper.needs_snapshot_catchup(0));
assert!(!shipper.needs_snapshot_catchup(100));
}
// A store with no WAL history also requires no snapshot.
#[test]
fn test_needs_snapshot_catchup_with_empty_store() {
let (mut shipper, _tx) = WalShipper::new();
let store = Arc::new(EventStore::new());
shipper.set_store(store);
assert!(!shipper.needs_snapshot_catchup(0));
}
}