use super::config::{StandbyConfig, WalStreamingConfig};
use super::{ReplicationError, Result};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::task::JoinHandle;
use uuid::Uuid;
/// Log sequence number: the position type used for WAL entries, segment
/// bounds, and standby acknowledgements throughout this module.
pub type Lsn = u64;
/// Metadata describing one WAL segment.
///
/// NOTE(review): nothing visible in this file constructs or reads this type,
/// so the precise field semantics (whether `end_lsn` is inclusive, the unit of
/// `size`, the checksum algorithm) should be confirmed against its producer.
#[derive(Debug, Clone)]
pub struct WalSegment {
/// Identifier of the segment.
pub segment_id: u64,
/// LSN at which the segment starts.
pub start_lsn: Lsn,
/// LSN at which the segment ends.
pub end_lsn: Lsn,
/// Size of the segment (presumably in bytes; TODO confirm).
pub size: usize,
/// Checksum of the segment contents.
pub checksum: u32,
/// UTC timestamp recorded when the segment was created.
pub created_at: chrono::DateTime<chrono::Utc>,
}
/// A single WAL record as accepted by `WalReplicator::append` and delivered
/// to subscribers of the replicator's broadcast channel.
#[derive(Debug, Clone)]
pub struct WalEntry {
/// LSN of this entry; `append` stores it as the replicator's current LSN.
pub lsn: Lsn,
/// Transaction identifier, if the entry is associated with one.
pub tx_id: Option<u64>,
/// Kind of operation this entry records.
pub entry_type: WalEntryType,
/// Opaque payload bytes; this module never inspects them.
pub data: Vec<u8>,
/// Checksum carried with the entry (not verified anywhere in this file).
pub checksum: u32,
}
/// Kind of operation recorded by a [`WalEntry`].
///
/// This module only carries the tag along; it never branches on the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WalEntryType {
// Data modifications.
Insert,
Update,
Delete,
// Transaction boundaries.
TxBegin,
TxCommit,
TxRollback,
// Other record kinds. NOTE(review): their exact meaning is defined by the
// code that produces and replays entries, which is not visible here.
Checkpoint,
SchemaChange,
BranchOp,
}
/// Replication progress and liveness tracked for one standby node.
#[derive(Debug, Clone)]
pub struct StandbyState {
/// Identifier under which this state is stored in the replicator's map.
pub node_id: Uuid,
/// Last LSN reported through `WalReplicator::acknowledge` (0 until then).
pub ack_lsn: Lsn,
/// Liveness flag. Starts `false`; the heartbeat monitor clears it after a
/// timeout and `stop` clears it for every standby. Nothing visible in this
/// file sets it to `true`.
pub connected: bool,
/// UTC time of registration or of the most recent acknowledgement.
pub last_heartbeat: chrono::DateTime<chrono::Utc>,
/// Current LSN minus `ack_lsn`, saturating at zero. The name says bytes,
/// but the value is computed purely from LSNs.
pub lag_bytes: u64,
}
/// Primary-side WAL replicator: tracks the current LSN, fans appended entries
/// out over a broadcast channel, and keeps per-standby progress and liveness.
pub struct WalReplicator {
// Streaming settings; this file reads `batch_size`, `heartbeat_interval`
// and `heartbeat_timeout` from it.
config: WalStreamingConfig,
// Standbys to register when `start` is called.
standbys: Vec<StandbyConfig>,
// LSN of the most recently appended entry (0 before the first append).
current_lsn: Arc<RwLock<Lsn>>,
// Per-standby state, keyed by node id.
standby_states: Arc<RwLock<HashMap<Uuid, StandbyState>>>,
// Sending half of the channel that delivers entries to `subscribe`rs.
wal_broadcast: broadcast::Sender<WalEntry>,
// Signalled once by `stop`.
shutdown_tx: mpsc::Sender<()>,
// Receiving half of the shutdown channel. NOTE(review): nothing visible in
// this file ever takes or polls it.
shutdown_rx: Arc<RwLock<Option<mpsc::Receiver<()>>>>,
// True between a successful `start` and the next `stop`.
running: Arc<AtomicBool>,
// Handles of background tasks spawned by `start`, awaited by `stop`.
task_handles: Arc<RwLock<Vec<JoinHandle<()>>>>,
}
impl WalReplicator {
    /// Creates a stopped replicator with LSN 0 and no registered standby state.
    ///
    /// `config.batch_size` is used as the broadcast channel capacity. It is
    /// clamped to at least 1 because `tokio::sync::broadcast::channel` panics
    /// on a zero capacity.
    pub fn new(config: WalStreamingConfig, standbys: Vec<StandbyConfig>) -> Self {
        let (wal_broadcast, _) = broadcast::channel(config.batch_size.max(1));
        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
        Self {
            config,
            standbys,
            current_lsn: Arc::new(RwLock::new(0)),
            standby_states: Arc::new(RwLock::new(HashMap::new())),
            wal_broadcast,
            shutdown_tx,
            shutdown_rx: Arc::new(RwLock::new(Some(shutdown_rx))),
            running: Arc::new(AtomicBool::new(false)),
            task_handles: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Registers every configured standby and spawns the heartbeat monitor.
    ///
    /// Each standby's state is stored under its configured `node_id`, so the
    /// same id can later be passed to [`acknowledge`](Self::acknowledge) and
    /// [`get_lag`](Self::get_lag). Starting again after a `stop` replaces the
    /// state left over from the previous run instead of adding duplicates.
    ///
    /// # Errors
    ///
    /// Returns `ReplicationError::WalStreaming` if the replicator is already
    /// running.
    pub async fn start(&self) -> Result<()> {
        if self.running.swap(true, Ordering::SeqCst) {
            return Err(ReplicationError::WalStreaming(
                "WAL Replicator already running".to_string(),
            ));
        }
        tracing::info!(
            "Starting WAL Replicator with {} standbys",
            self.standbys.len()
        );
        {
            let mut states = self.standby_states.write().await;
            for standby in &self.standbys {
                // Key by the configured id; a freshly generated id could never
                // be matched by a standby acknowledging under its own id.
                let node_id = standby.node_id;
                states.insert(
                    node_id,
                    StandbyState {
                        node_id,
                        ack_lsn: 0,
                        connected: false,
                        last_heartbeat: chrono::Utc::now(),
                        lag_bytes: 0,
                    },
                );
                tracing::info!(
                    "Registered standby at {}:{} (id: {})",
                    standby.host,
                    standby.port,
                    node_id
                );
            }
        }

        let running = Arc::clone(&self.running);
        let standby_states = Arc::clone(&self.standby_states);
        let heartbeat_interval = self.config.heartbeat_interval;
        // Convert once, outside the loop. A timeout too large for chrono to
        // represent can never elapse, so `None` means "never time out" rather
        // than collapsing to a zero timeout.
        let heartbeat_timeout = chrono::Duration::from_std(self.config.heartbeat_timeout).ok();
        let heartbeat_handle = tokio::spawn(async move {
            let mut interval = tokio::time::interval(heartbeat_interval);
            while running.load(Ordering::SeqCst) {
                interval.tick().await;
                let now = chrono::Utc::now();
                let mut states = standby_states.write().await;
                for (id, state) in states.iter_mut() {
                    if !state.connected {
                        continue;
                    }
                    let timed_out = match heartbeat_timeout {
                        Some(limit) => now - state.last_heartbeat > limit,
                        None => false,
                    };
                    if timed_out {
                        tracing::warn!("Standby {} heartbeat timeout, marking disconnected", id);
                        state.connected = false;
                    }
                }
            }
            tracing::debug!("Heartbeat monitor task stopped");
        });
        {
            let mut handles = self.task_handles.write().await;
            handles.push(heartbeat_handle);
        }
        tracing::info!("WAL Replicator started successfully");
        Ok(())
    }

    /// Stops the replicator: signals shutdown, waits up to five seconds for
    /// each background task (aborting any that overrun), and marks every
    /// standby as disconnected.
    ///
    /// Calling `stop` on a replicator that is not running is a no-op.
    pub async fn stop(&self) -> Result<()> {
        if !self.running.swap(false, Ordering::SeqCst) {
            return Ok(());
        }
        tracing::info!("Stopping WAL Replicator...");
        // Best-effort signal. Nothing visible in this file drains the
        // receiver, so awaiting `send` on the capacity-1 channel would block
        // forever on the second stop after a restart. A full channel just
        // means a shutdown signal is already pending.
        let _ = self.shutdown_tx.try_send(());
        {
            let mut handles = self.task_handles.write().await;
            for mut handle in handles.drain(..) {
                let finished =
                    tokio::time::timeout(std::time::Duration::from_secs(5), &mut handle).await;
                if finished.is_err() {
                    // Dropping a JoinHandle only detaches the task; abort it
                    // so it cannot keep running after the replicator stopped.
                    handle.abort();
                }
            }
        }
        {
            let mut states = self.standby_states.write().await;
            for state in states.values_mut() {
                state.connected = false;
            }
        }
        tracing::info!("WAL Replicator stopped successfully");
        Ok(())
    }

    /// Returns whether the replicator is between `start` and `stop`.
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::SeqCst)
    }

    /// Records `entry.lsn` as the current LSN, refreshes every standby's lag,
    /// and broadcasts the entry to subscribers. Returns the entry's LSN.
    ///
    /// The current-LSN write lock is held for the whole call, so concurrent
    /// appends are serialized and broadcast in the order they update the LSN.
    ///
    /// # Errors
    ///
    /// Returns `ReplicationError::WalStreaming` if the broadcast send fails,
    /// which tokio reports when there are no active receivers. The current LSN
    /// and lag values have already been updated at that point.
    pub async fn append(&self, entry: WalEntry) -> Result<Lsn> {
        // Copy the LSN up front so the entry can be moved into the channel
        // without cloning its payload.
        let lsn = entry.lsn;
        let mut current = self.current_lsn.write().await;
        *current = lsn;
        {
            let mut states = self.standby_states.write().await;
            for state in states.values_mut() {
                state.lag_bytes = lsn.saturating_sub(state.ack_lsn);
            }
        }
        self.wal_broadcast
            .send(entry)
            .map_err(|e| ReplicationError::WalStreaming(e.to_string()))?;
        Ok(lsn)
    }

    /// Returns the LSN of the most recently appended entry (0 if none).
    pub async fn current_lsn(&self) -> Lsn {
        *self.current_lsn.read().await
    }

    /// Returns a snapshot copy of all standby states.
    pub async fn standby_states(&self) -> HashMap<Uuid, StandbyState> {
        self.standby_states.read().await.clone()
    }

    /// Returns the recorded lag for `standby_id`, or `None` if it is unknown.
    pub async fn get_lag(&self, standby_id: &Uuid) -> Option<u64> {
        let states = self.standby_states.read().await;
        states.get(standby_id).map(|s| s.lag_bytes)
    }

    /// Creates a receiver that observes entries appended after this call.
    pub fn subscribe(&self) -> broadcast::Receiver<WalEntry> {
        self.wal_broadcast.subscribe()
    }

    /// Records that `standby_id` has acknowledged WAL up to `ack_lsn`,
    /// refreshing its heartbeat timestamp and lag.
    ///
    /// An unknown `standby_id` is ignored and still yields `Ok(())`.
    pub async fn acknowledge(&self, standby_id: Uuid, ack_lsn: Lsn) -> Result<()> {
        // Read the LSN before taking the state lock. `append` locks
        // `current_lsn` and then `standby_states`; acquiring them here in the
        // opposite, nested order could deadlock the two calls.
        let current = *self.current_lsn.read().await;
        let mut states = self.standby_states.write().await;
        if let Some(state) = states.get_mut(&standby_id) {
            state.ack_lsn = ack_lsn;
            state.last_heartbeat = chrono::Utc::now();
            state.lag_bytes = current.saturating_sub(ack_lsn);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed replicator reports LSN zero.
    #[tokio::test]
    async fn test_wal_replicator_creation() {
        let replicator = WalReplicator::new(WalStreamingConfig::default(), Vec::new());
        let lsn = replicator.current_lsn().await;
        assert_eq!(lsn, 0);
    }

    /// An appended entry reaches a subscriber that was registered beforehand.
    #[tokio::test]
    async fn test_wal_entry_broadcast() {
        let replicator = WalReplicator::new(WalStreamingConfig::default(), Vec::new());
        let mut subscriber = replicator.subscribe();

        let entry = WalEntry {
            lsn: 1,
            tx_id: None,
            entry_type: WalEntryType::Insert,
            data: vec![1, 2, 3],
            checksum: 0,
        };
        replicator.append(entry).await.expect("append failed");

        let delivered = subscriber.recv().await.expect("recv failed");
        assert_eq!(delivered.lsn, 1);
    }

    /// Start/stop lifecycle: a double start is rejected, a double stop is
    /// harmless.
    #[tokio::test]
    async fn test_start_stop() {
        use super::super::config::SyncMode;

        let standby = StandbyConfig {
            node_id: Uuid::new_v4(),
            host: "standby1.example.com".to_string(),
            port: 5433,
            sync_mode: SyncMode::Async,
            priority: 1,
        };
        let replicator = WalReplicator::new(WalStreamingConfig::default(), vec![standby]);
        assert!(!replicator.is_running());

        replicator.start().await.expect("start failed");
        assert!(replicator.is_running());
        assert_eq!(replicator.standby_states().await.len(), 1);

        // A second start while running must fail.
        assert!(replicator.start().await.is_err());

        replicator.stop().await.expect("stop failed");
        assert!(!replicator.is_running());

        // Stopping an already stopped replicator is a no-op.
        replicator.stop().await.expect("stop failed");
    }
}