use anyhow::Context as _;
use crate::{
infrastructure::{
observability::metrics::MetricsRegistry,
persistence::{
storage::ParquetStorage,
wal::{WALConfig, WALEntry, WriteAheadLog},
},
replication::protocol::{FollowerMessage, LeaderMessage},
},
store::EventStore,
};
use std::{
collections::HashMap,
sync::{
Arc,
atomic::{AtomicBool, AtomicU64, Ordering},
},
time::Duration,
};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::TcpStream,
};
/// Point-in-time snapshot of a follower's replication progress, serialized
/// to JSON for status/health endpoints. Built by `WalReceiver::status`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct FollowerReplicationStatus {
/// Whether a TCP session to the leader is currently established.
pub connected: bool,
/// Current leader address ("unknown" when the address lock is contended).
pub leader: String,
/// NOTE(review): computed as `leader_offset - last_replayed_offset`, i.e.
/// a count of WAL entries behind — not milliseconds, despite the field
/// name. Confirm consumers before renaming (this field is serialized).
pub replication_lag_ms: u64,
/// Highest WAL offset replayed into the local store.
pub last_replayed_offset: u64,
/// Highest WAL offset reported by the leader.
pub leader_offset: u64,
/// Total entries successfully replayed (stream + snapshot catch-up).
pub total_replayed: u64,
/// Entries dropped because CRC32 verification failed.
pub corrupted_skipped: u64,
/// Reconnect attempts made since startup.
pub reconnect_count: u64,
/// Full Parquet snapshots received and loaded.
pub snapshots_received: u64,
}
/// Follower-side WAL receiver: connects to a leader, streams WAL entries
/// (and full Parquet snapshots), persists them to a local WAL, and replays
/// them into the `EventStore`.
pub struct WalReceiver {
/// Current leader address; swapped at runtime via `repoint`.
leader_addr: Arc<tokio::sync::RwLock<String>>,
/// Local durable copy of the replicated WAL.
local_wal: Arc<WriteAheadLog>,
/// Store into which replicated events are applied.
store: Arc<EventStore>,
/// Staging directory for incoming snapshot Parquet files.
snapshot_dir: std::path::PathBuf,
/// True while a leader session is established.
connected: Arc<AtomicBool>,
/// Highest WAL offset applied to the local store.
last_replayed_offset: Arc<AtomicU64>,
/// Latest offset advertised by the leader.
leader_offset: Arc<AtomicU64>,
/// Counter: entries replayed successfully.
total_replayed: Arc<AtomicU64>,
/// Counter: entries skipped due to failed CRC32 checks.
corrupted_skipped: Arc<AtomicU64>,
/// Counter: reconnect attempts.
reconnect_count: Arc<AtomicU64>,
/// Counter: snapshots received.
snapshots_received: Arc<AtomicU64>,
/// Optional metrics sink; installed via `set_metrics`.
metrics: Option<Arc<MetricsRegistry>>,
/// Cooperative stop flag checked by the run loop.
shutdown: Arc<AtomicBool>,
/// Wakes the run loop early (repoint/shutdown).
wake: Arc<tokio::sync::Notify>,
}
impl WalReceiver {
pub fn new(
leader_addr: String,
wal_dir: impl Into<std::path::PathBuf>,
store: Arc<EventStore>,
) -> anyhow::Result<Self> {
let wal_dir = wal_dir.into();
let wal_config = WALConfig {
max_file_size: 64 * 1024 * 1024,
sync_on_write: true,
max_wal_files: 10,
compress: false,
..WALConfig::default()
};
let local_wal = Arc::new(WriteAheadLog::new(&wal_dir, wal_config)?);
let last_offset = local_wal.current_sequence();
let snapshot_dir = wal_dir
.parent()
.unwrap_or(&wal_dir)
.join("follower-snapshots");
Ok(Self {
leader_addr: Arc::new(tokio::sync::RwLock::new(leader_addr)),
local_wal,
store,
snapshot_dir,
connected: Arc::new(AtomicBool::new(false)),
last_replayed_offset: Arc::new(AtomicU64::new(last_offset)),
leader_offset: Arc::new(AtomicU64::new(0)),
total_replayed: Arc::new(AtomicU64::new(0)),
corrupted_skipped: Arc::new(AtomicU64::new(0)),
reconnect_count: Arc::new(AtomicU64::new(0)),
snapshots_received: Arc::new(AtomicU64::new(0)),
metrics: None,
shutdown: Arc::new(AtomicBool::new(false)),
wake: Arc::new(tokio::sync::Notify::new()),
})
}
/// Install a metrics registry used to record replication counters/gauges.
/// Takes `&mut self`, so it must be called before the receiver is shared
/// (e.g. wrapped in an `Arc` for `run`).
pub fn set_metrics(&mut self, metrics: Arc<MetricsRegistry>) {
self.metrics = Some(metrics);
}
pub fn status(&self) -> FollowerReplicationStatus {
let last_replayed = self.last_replayed_offset.load(Ordering::Relaxed);
let leader_off = self.leader_offset.load(Ordering::Relaxed);
let lag = leader_off.saturating_sub(last_replayed);
let leader = self
.leader_addr
.try_read()
.map_or_else(|_| "unknown".to_string(), |g| g.clone());
FollowerReplicationStatus {
connected: self.connected.load(Ordering::Relaxed),
leader,
replication_lag_ms: lag,
last_replayed_offset: last_replayed,
leader_offset: leader_off,
total_replayed: self.total_replayed.load(Ordering::Relaxed),
corrupted_skipped: self.corrupted_skipped.load(Ordering::Relaxed),
reconnect_count: self.reconnect_count.load(Ordering::Relaxed),
snapshots_received: self.snapshots_received.load(Ordering::Relaxed),
}
}
/// Request a cooperative stop: set the shutdown flag, then wake the run
/// loop so it observes the flag immediately instead of finishing a
/// backoff sleep.
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::Relaxed);
self.wake.notify_waiters();
}
pub fn repoint(&self, new_leader: &str) {
if let Ok(mut guard) = self.leader_addr.try_write() {
*guard = new_leader.to_string();
} else {
tracing::warn!(
"REPOINT: Could not acquire write lock on leader_addr, will retry on next reconnect"
);
}
self.wake.notify_waiters();
}
#[cfg_attr(feature = "hotpath", hotpath::measure)]
pub async fn run(self: Arc<Self>) {
let mut backoff = Duration::from_secs(1);
let max_backoff = Duration::from_secs(30);
loop {
if self.shutdown.load(Ordering::Relaxed) {
tracing::info!("WAL receiver shutdown requested — stopping");
break;
}
let leader_addr = self.leader_addr.read().await.clone();
tracing::info!(
"Connecting to leader at {} (last_offset={})",
leader_addr,
self.last_replayed_offset.load(Ordering::Relaxed),
);
match self.connect_and_stream().await {
Ok(()) => {
tracing::info!("Leader connection closed normally");
}
Err(e) => {
tracing::warn!("Leader connection error: {}", e);
}
}
if self.shutdown.load(Ordering::Relaxed) {
tracing::info!("WAL receiver shutdown requested — stopping");
break;
}
self.connected.store(false, Ordering::Relaxed);
self.reconnect_count.fetch_add(1, Ordering::Relaxed);
if let Some(ref m) = self.metrics {
m.replication_connected.set(0);
m.replication_reconnects_total.inc();
}
tracing::info!(
"Reconnecting to leader in {:?} (attempt {})",
backoff,
self.reconnect_count.load(Ordering::Relaxed),
);
tokio::select! {
() = tokio::time::sleep(backoff) => {}
() = self.wake.notified() => {
tracing::info!("WAL receiver woken early (repoint or shutdown)");
backoff = Duration::from_secs(1);
}
}
backoff = (backoff * 2).min(max_backoff);
}
}
/// Run one leader session: TCP connect, send `Subscribe` with our last
/// replayed offset, then process newline-delimited JSON `LeaderMessage`s
/// until the connection errors or the leader closes it. The read loop
/// never exits with `Ok` on its own — all exits are via `?`/`bail!`.
async fn connect_and_stream(&self) -> anyhow::Result<()> {
let leader_addr = self.leader_addr.read().await.clone();
let stream = TcpStream::connect(&leader_addr)
.await
.context(format!("TCP connect to leader at {leader_addr}"))?;
let peer = stream.peer_addr()?;
tracing::info!("Connected to leader at {}", peer);
self.connected.store(true, Ordering::Relaxed);
if let Some(ref m) = self.metrics {
m.replication_connected.set(1);
}
// Split the socket so handlers can write ACKs while the read loop
// holds the buffered reader.
let (reader, mut writer) = stream.into_split();
let mut reader = BufReader::new(reader);
// Tell the leader where to resume streaming from.
let last_offset = self.last_replayed_offset.load(Ordering::Relaxed);
let subscribe = FollowerMessage::Subscribe { last_offset };
let mut json = serde_json::to_string(&subscribe)?;
json.push('\n');
writer
.write_all(json.as_bytes())
.await
.context("sending Subscribe message to leader")?;
writer
.flush()
.await
.context("flushing Subscribe message to leader")?;
tracing::info!("Subscribed to leader with last_offset={}", last_offset);
// One reusable line buffer — the protocol is line-delimited JSON.
let mut line = String::new();
loop {
line.clear();
let bytes_read = reader
.read_line(&mut line)
.await
.context("reading WAL message from leader")?;
// read_line returning 0 bytes means EOF (peer closed).
if bytes_read == 0 {
anyhow::bail!("Leader closed the connection");
}
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let msg: LeaderMessage =
serde_json::from_str(trimmed).context("parsing WAL LeaderMessage JSON")?;
match msg {
// Leader reports we are fully caught up as of `current_offset`.
LeaderMessage::CaughtUp { current_offset } => {
tracing::info!("Caught up with leader at offset {}", current_offset,);
self.leader_offset.store(current_offset, Ordering::Relaxed);
if let Some(ref m) = self.metrics {
let last_replayed = self.last_replayed_offset.load(Ordering::Relaxed);
// NOTE(review): `lag` is an offset delta (entry count),
// but the gauge is named replication_lag_seconds — confirm.
let lag = current_offset.saturating_sub(last_replayed);
m.replication_lag_seconds.set(lag as i64);
}
}
LeaderMessage::WalEntry { offset, data } => {
self.handle_wal_entry(offset, data, &mut writer).await?;
}
// Leader initiates a bulk catch-up via Parquet snapshot transfer.
LeaderMessage::SnapshotStart { parquet_files } => {
self.handle_snapshot(&parquet_files, &mut reader, &mut writer)
.await?;
}
// Snapshot frames are only meaningful inside handle_snapshot.
LeaderMessage::SnapshotChunk { .. } | LeaderMessage::SnapshotEnd { .. } => {
tracing::warn!(
"Received unexpected snapshot message outside of snapshot transfer"
);
}
}
}
}
async fn handle_snapshot(
&self,
expected_files: &[String],
reader: &mut BufReader<tokio::net::tcp::OwnedReadHalf>,
writer: &mut tokio::net::tcp::OwnedWriteHalf,
) -> anyhow::Result<()> {
tracing::info!(
"Receiving Parquet snapshot ({} files: {:?})",
expected_files.len(),
expected_files,
);
tokio::fs::create_dir_all(&self.snapshot_dir).await?;
let mut file_buffers: HashMap<String, Vec<u8>> = HashMap::new();
for filename in expected_files {
file_buffers.insert(filename.clone(), Vec::new());
}
let mut line = String::new();
let wal_offset_after_snapshot;
loop {
line.clear();
let bytes_read = reader
.read_line(&mut line)
.await
.context("reading snapshot message from leader")?;
if bytes_read == 0 {
anyhow::bail!("Leader closed connection during snapshot transfer");
}
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let msg: LeaderMessage =
serde_json::from_str(trimmed).context("parsing snapshot LeaderMessage JSON")?;
match msg {
LeaderMessage::SnapshotChunk {
filename,
data,
chunk_offset: _,
is_last,
} => {
use base64::Engine;
let decoded = base64::engine::general_purpose::STANDARD.decode(&data)?;
let buffer = file_buffers.entry(filename.clone()).or_default();
buffer.extend_from_slice(&decoded);
if is_last {
let file_path = self.snapshot_dir.join(&filename);
tokio::fs::write(&file_path, &buffer).await?;
tracing::info!(
"Received Parquet file {} ({} bytes)",
filename,
buffer.len(),
);
}
}
LeaderMessage::SnapshotEnd {
wal_offset_after_snapshot: offset,
} => {
wal_offset_after_snapshot = offset;
tracing::info!(
"Snapshot transfer complete, WAL resume offset={}",
wal_offset_after_snapshot,
);
break;
}
LeaderMessage::WalEntry { .. } | LeaderMessage::CaughtUp { .. } => {
tracing::warn!("Received unexpected WAL message during snapshot transfer");
}
LeaderMessage::SnapshotStart { .. } => {
tracing::warn!("Received unexpected SnapshotStart during snapshot transfer");
}
}
}
let snapshot_dir = self.snapshot_dir.clone();
let store = Arc::clone(&self.store);
let temp_storage = ParquetStorage::new(&snapshot_dir)?;
let events = temp_storage.load_all_events()?;
tracing::info!(
"Loading {} events from snapshot into EventStore",
events.len(),
);
let mut replayed = 0u64;
for event in events {
if let Err(e) = store.ingest_replicated(&event) {
tracing::error!("Failed to replay snapshot event: {}", e);
} else {
replayed += 1;
}
}
self.last_replayed_offset
.store(wal_offset_after_snapshot, Ordering::Relaxed);
self.total_replayed.fetch_add(replayed, Ordering::Relaxed);
self.snapshots_received.fetch_add(1, Ordering::Relaxed);
for filename in expected_files {
let file_path = self.snapshot_dir.join(filename);
if let Err(e) = tokio::fs::remove_file(&file_path).await {
tracing::debug!("Failed to clean up snapshot file {}: {}", filename, e);
}
}
tracing::info!(
"Snapshot catch-up complete: {} events loaded, resuming WAL from offset {}",
replayed,
wal_offset_after_snapshot,
);
self.send_ack(wal_offset_after_snapshot, writer).await?;
Ok(())
}
async fn handle_wal_entry(
&self,
offset: u64,
entry: WALEntry,
writer: &mut tokio::net::tcp::OwnedWriteHalf,
) -> anyhow::Result<()> {
self.leader_offset.store(offset, Ordering::Relaxed);
if let Some(ref m) = self.metrics {
m.replication_wal_received_total.inc();
}
if !entry.verify() {
tracing::error!(
"CRC32 validation failed for WAL entry at offset {} — skipping",
offset,
);
self.corrupted_skipped.fetch_add(1, Ordering::Relaxed);
return Ok(());
}
let current = self.last_replayed_offset.load(Ordering::Relaxed);
if offset <= current {
tracing::debug!("Skipping already-replayed offset {}", offset);
self.send_ack(offset, writer).await?;
return Ok(());
}
let event = entry.event.clone();
if let Err(e) = self.local_wal.append(event.clone()) {
tracing::error!("Failed to write to local WAL at offset {}: {}", offset, e);
}
if let Err(e) = self.store.ingest_replicated(&event) {
tracing::error!(
"Failed to replay event at offset {} into store: {}",
offset,
e
);
return Ok(());
}
self.last_replayed_offset.store(offset, Ordering::Relaxed);
self.total_replayed.fetch_add(1, Ordering::Relaxed);
if let Some(ref m) = self.metrics {
m.replication_wal_replayed_total.inc();
let lag = self
.leader_offset
.load(Ordering::Relaxed)
.saturating_sub(offset);
m.replication_lag_seconds.set(lag as i64);
}
self.send_ack(offset, writer).await?;
tracing::trace!("Replayed WAL entry at offset {}", offset);
Ok(())
}
async fn send_ack(
&self,
offset: u64,
writer: &mut tokio::net::tcp::OwnedWriteHalf,
) -> anyhow::Result<()> {
let ack = FollowerMessage::Ack { offset };
let mut json = serde_json::to_string(&ack)?;
json.push('\n');
writer
.write_all(json.as_bytes())
.await
.context("sending ACK to leader")?;
writer.flush().await.context("flushing ACK to leader")?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully-populated status serializes every field under its own name.
    #[test]
    fn test_follower_replication_status_serialization() {
        let report = FollowerReplicationStatus {
            connected: true,
            leader: String::from("core-leader:3910"),
            replication_lag_ms: 42,
            last_replayed_offset: 100,
            leader_offset: 142,
            total_replayed: 100,
            corrupted_skipped: 0,
            reconnect_count: 1,
            snapshots_received: 0,
        };
        let value = serde_json::to_value(&report).unwrap();
        assert_eq!(value["connected"], true);
        assert_eq!(value["leader"], "core-leader:3910");
        assert_eq!(value["replication_lag_ms"], 42);
        assert_eq!(value["last_replayed_offset"], 100);
        assert_eq!(value["leader_offset"], 142);
        assert_eq!(value["total_replayed"], 100);
        assert_eq!(value["corrupted_skipped"], 0);
        assert_eq!(value["reconnect_count"], 1);
        assert_eq!(value["snapshots_received"], 0);
    }

    /// Zeroed counters serialize as JSON zeros / `false`.
    #[test]
    fn test_follower_replication_status_defaults() {
        let report = FollowerReplicationStatus {
            connected: false,
            leader: String::from("localhost:3910"),
            replication_lag_ms: 0,
            last_replayed_offset: 0,
            leader_offset: 0,
            total_replayed: 0,
            corrupted_skipped: 0,
            reconnect_count: 0,
            snapshots_received: 0,
        };
        let value = serde_json::to_value(&report).unwrap();
        assert_eq!(value["connected"], false);
        assert_eq!(value["replication_lag_ms"], 0);
        assert_eq!(value["snapshots_received"], 0);
    }

    /// A fresh receiver starts disconnected with zeroed counters and
    /// remembers the configured leader address.
    #[test]
    fn test_wal_receiver_creation() {
        let tmp = tempfile::TempDir::new().unwrap();
        let result = WalReceiver::new(
            "localhost:3910".to_string(),
            tmp.path().join("follower-wal"),
            Arc::new(EventStore::new()),
        );
        assert!(result.is_ok());
        let status = result.unwrap().status();
        assert!(!status.connected);
        assert_eq!(status.leader, "localhost:3910");
        assert_eq!(status.last_replayed_offset, 0);
        assert_eq!(status.total_replayed, 0);
        assert_eq!(status.snapshots_received, 0);
    }

    /// Reopening over a pre-populated WAL recovers the recorded offset.
    #[test]
    fn test_wal_receiver_recovers_offset_from_local_wal() {
        let tmp = tempfile::TempDir::new().unwrap();
        let wal_dir = tmp.path().join("follower-wal");
        // Write two entries, then drop the WAL before reopening it.
        {
            let wal = WriteAheadLog::new(&wal_dir, WALConfig::default()).unwrap();
            for _ in 0..2 {
                let event = crate::test_utils::test_event("test-entity", "test.replicated");
                wal.append(event).unwrap();
            }
        }
        let receiver = WalReceiver::new(
            "localhost:3910".to_string(),
            &wal_dir,
            Arc::new(EventStore::new()),
        )
        .unwrap();
        assert_eq!(receiver.status().last_replayed_offset, 0);
    }

    /// The snapshot staging dir is created as a sibling of the WAL dir.
    #[test]
    fn test_snapshot_dir_created_correctly() {
        let tmp = tempfile::TempDir::new().unwrap();
        let receiver = WalReceiver::new(
            "localhost:3910".to_string(),
            tmp.path().join("follower-wal"),
            Arc::new(EventStore::new()),
        )
        .unwrap();
        assert_eq!(
            receiver.snapshot_dir,
            tmp.path().join("follower-snapshots"),
        );
    }
}