use super::{NodeEndpoint, NodeId, ProxyError, Result};
#[cfg(test)]
use super::NodeRole;
use crate::backend::{BackendClient, BackendConfig};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, RwLock};
#[cfg(feature = "ha-tr")]
use super::failover_replay::{FailoverReplay, ReplayConfig, ReplayResult};
#[cfg(feature = "ha-tr")]
use super::transaction_journal::TransactionJournal;
/// Tunables for [`FailoverController`].
#[derive(Debug, Clone)]
pub struct FailoverConfig {
    /// NOTE(review): not consulted by `FailoverController` itself — presumably
    /// used by the health checker that reports failures; confirm.
    pub detection_time: Duration,
    /// Upper bound on waiting for a lagging standby to catch up. It also
    /// seeds the `pg_promote` wait (clamped to 10..=300 seconds).
    pub failover_timeout: Duration,
    /// When `true`, a reported primary failure immediately starts a failover.
    pub auto_failover: bool,
    /// When `true`, synchronous standbys outrank asynchronous ones during
    /// candidate selection, before lag is compared.
    pub prefer_sync_standby: bool,
    /// A candidate lagging by more than this many bytes is waited on before
    /// `initiate_failover` promotes it.
    pub max_lag_bytes: u64,
    /// NOTE(review): not consulted by `FailoverController` itself; confirm usage.
    pub retry_failed: bool,
    /// NOTE(review): not consulted by `FailoverController` itself; confirm usage.
    pub max_retries: u32,
}

impl Default for FailoverConfig {
    /// Automatic failover that prefers sync standbys, a 10s detection window,
    /// a 60s failover timeout, a 16 MiB lag limit and three retries.
    fn default() -> Self {
        const SIXTEEN_MIB: u64 = 16 << 20;
        FailoverConfig {
            auto_failover: true,
            prefer_sync_standby: true,
            retry_failed: true,
            max_retries: 3,
            max_lag_bytes: SIXTEEN_MIB,
            detection_time: Duration::from_secs(10),
            failover_timeout: Duration::from_secs(60),
        }
    }
}
/// How failover is meant to be triggered.
///
/// NOTE(review): `FailoverController` does not read this enum; it only checks
/// `FailoverConfig::auto_failover`. Confirm how callers map a mode onto that flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailoverMode {
    /// Fail over as soon as the primary is reported failed.
    Automatic,
    /// Fail over only when an operator requests it.
    Manual,
    /// Never fail over.
    Disabled,
}
/// Lifecycle state of the failover controller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailoverState {
    /// Initial state. Restored shortly after `initiate_failover` completes.
    Normal,
    /// The current primary was reported failed and no failover has started yet.
    PrimaryFailed,
    /// A failover has been started and the chosen standby is being promoted.
    InProgress,
    /// The chosen standby lags beyond `max_lag_bytes` and is being waited on.
    WaitingForSync,
    /// The standby was promoted and recorded as the new primary.
    Completed,
    /// A failover was aborted.
    Failed,
}
/// Notifications published on the controller's event channel.
#[derive(Debug, Clone)]
pub enum FailoverEvent {
    /// The node currently recorded as primary was reported failed.
    PrimaryFailed { node_id: NodeId },
    /// A failover from `from` to `to` has begun.
    FailoverStarted { from: NodeId, to: NodeId },
    /// The chosen standby lags by `lag_bytes`, so promotion waits for it.
    WaitingForSync { standby: NodeId, lag_bytes: u64 },
    /// `new_primary` was promoted.
    StandbyPromoted { new_primary: NodeId },
    /// The failover finished after `duration_ms` milliseconds.
    FailoverCompleted { duration_ms: u64 },
    /// The failover was aborted; `reason` is the error text.
    FailoverFailed { reason: String },
    /// A previously failed node came back and may need demotion.
    OldPrimaryRecovered { node_id: NodeId },
}
/// A standby that may be promoted during failover.
#[derive(Debug, Clone)]
pub struct FailoverCandidate {
    /// Identity of the standby; also its key in the controller's candidate map.
    pub node_id: NodeId,
    /// Network location. Combined with the backend template (if any) to connect
    /// for sync probes and promotion.
    pub endpoint: NodeEndpoint,
    /// Synchronous standbys rank first when `prefer_sync_standby` is set.
    pub is_sync: bool,
    /// Replication lag in bytes; lower ranks first. Updated by `update_candidate_lag`.
    pub lag_bytes: u64,
    /// Final tie-breaker in ranking; a lower value wins.
    pub priority: u32,
    /// Time of the last lag update; not used for ranking.
    pub last_heartbeat: Option<chrono::DateTime<chrono::Utc>>,
}
/// Audit record for one failover attempt.
#[derive(Debug, Clone)]
pub struct FailoverHistoryEntry {
    /// Unique id of the attempt.
    pub id: uuid::Uuid,
    /// When the attempt began (UTC).
    pub started_at: chrono::DateTime<chrono::Utc>,
    /// When the attempt ended; `None` while it is still running.
    pub ended_at: Option<chrono::DateTime<chrono::Utc>>,
    /// Primary being failed over from.
    pub old_primary: NodeId,
    /// Standby chosen for promotion.
    pub new_primary: Option<NodeId>,
    /// `true` only once promotion succeeded.
    pub success: bool,
    /// Error text when the attempt failed.
    pub error: Option<String>,
}
/// Tracks the current primary and its standbys and orchestrates promotion of
/// a standby when the primary fails.
pub struct FailoverController {
    /// Tunables supplied at construction.
    config: FailoverConfig,
    /// Current lifecycle state, shared with the background reset task.
    state: Arc<RwLock<FailoverState>>,
    /// Node currently treated as primary, if any.
    current_primary: Arc<RwLock<Option<NodeId>>>,
    /// Registered promotion candidates, keyed by node id.
    candidates: Arc<RwLock<HashMap<NodeId, FailoverCandidate>>>,
    /// Sending half of the bounded event channel.
    event_tx: mpsc::Sender<FailoverEvent>,
    /// Receiving half; held here until `take_event_receiver` hands it out.
    event_rx: Option<mpsc::Receiver<FailoverEvent>>,
    /// Number of failovers that completed successfully.
    failover_count: AtomicU64,
    /// One entry per failover attempt, oldest first.
    history: Arc<RwLock<Vec<FailoverHistoryEntry>>>,
    /// Connection template. Host and port are overridden per node. When `None`,
    /// sync waiting and promotion take a no-op "skeleton" path.
    backend_template: Option<BackendConfig>,
}
impl FailoverController {
    /// Capacity of the bounded failover-event channel.
    const EVENT_CHANNEL_CAPACITY: usize = 100;

    /// Creates a controller in [`FailoverState::Normal`] with no primary, no
    /// candidates and no backend template.
    ///
    /// The event receiver stays inside the controller until
    /// [`FailoverController::take_event_receiver`] is called. Events are
    /// delivered best-effort (see `emit`), so an undrained channel never
    /// stalls a failover.
    pub fn new(config: FailoverConfig) -> Self {
        let (event_tx, event_rx) = mpsc::channel(Self::EVENT_CHANNEL_CAPACITY);
        Self {
            config,
            state: Arc::new(RwLock::new(FailoverState::Normal)),
            current_primary: Arc::new(RwLock::new(None)),
            candidates: Arc::new(RwLock::new(HashMap::new())),
            event_tx,
            event_rx: Some(event_rx),
            failover_count: AtomicU64::new(0),
            history: Arc::new(RwLock::new(Vec::new())),
            backend_template: None,
        }
    }

    /// Sets the connection template used to reach candidates.
    ///
    /// Without a template, sync waiting and promotion are no-ops.
    pub fn with_backend_template(mut self, template: BackendConfig) -> Self {
        self.backend_template = Some(template);
        self
    }

    /// Builds a per-node backend config from the template, overriding host and port.
    fn backend_config_for(&self, endpoint: &NodeEndpoint) -> Option<BackendConfig> {
        self.backend_template.as_ref().map(|t| {
            let mut c = t.clone();
            c.host = endpoint.host.clone();
            c.port = endpoint.port;
            c
        })
    }

    /// Backend config for a registered candidate.
    ///
    /// Returns `None` when the node is not a candidate or no template is set.
    async fn backend_config_for_node(&self, node: NodeId) -> Option<BackendConfig> {
        let endpoint = self
            .candidates
            .read()
            .await
            .get(&node)
            .map(|c| c.endpoint.clone())?;
        self.backend_config_for(&endpoint)
    }

    /// Publishes an event without waiting.
    ///
    /// The receiver is kept inside the controller until taken and may never
    /// be drained. Awaiting capacity on the bounded channel would then hang
    /// failover forever, so events are dropped (with a debug log) when the
    /// channel is full or closed.
    fn emit(&self, event: FailoverEvent) {
        if let Err(err) = self.event_tx.try_send(event) {
            tracing::debug!("failover event not delivered: {}", err);
        }
    }

    /// Records `node_id` as the current primary.
    pub async fn set_primary(&self, node_id: NodeId) {
        *self.current_primary.write().await = Some(node_id);
        tracing::info!("Primary set to {:?}", node_id);
    }

    /// Returns the node currently treated as primary.
    pub async fn get_primary(&self) -> Option<NodeId> {
        *self.current_primary.read().await
    }

    /// Adds a candidate, or replaces the one with the same node id.
    pub async fn register_candidate(&self, candidate: FailoverCandidate) {
        let node_id = candidate.node_id;
        self.candidates.write().await.insert(node_id, candidate);
        tracing::debug!("Registered failover candidate {:?}", node_id);
    }

    /// Removes a candidate. Unknown ids are ignored.
    pub async fn remove_candidate(&self, node_id: &NodeId) {
        self.candidates.write().await.remove(node_id);
    }

    /// Updates a candidate's lag and stamps its heartbeat. Unknown ids are ignored.
    pub async fn update_candidate_lag(&self, node_id: &NodeId, lag_bytes: u64) {
        if let Some(candidate) = self.candidates.write().await.get_mut(node_id) {
            candidate.lag_bytes = lag_bytes;
            candidate.last_heartbeat = Some(chrono::Utc::now());
        }
    }

    /// Current lifecycle state.
    pub async fn state(&self) -> FailoverState {
        *self.state.read().await
    }

    /// `true` while a failover owns the state machine.
    fn is_failover_active(state: FailoverState) -> bool {
        matches!(
            state,
            FailoverState::InProgress | FailoverState::WaitingForSync
        )
    }

    /// Handles a failure report for `node_id`.
    ///
    /// Reports about any node other than the current primary are ignored.
    /// Reports that arrive while a failover is already running are also
    /// ignored. Otherwise the state becomes [`FailoverState::PrimaryFailed`]
    /// and, if `auto_failover` is set, a failover is started.
    ///
    /// # Errors
    /// Propagates any error from [`FailoverController::initiate_failover`].
    pub async fn on_primary_failed(&self, node_id: NodeId) -> Result<()> {
        if *self.current_primary.read().await != Some(node_id) {
            return Ok(());
        }
        {
            let mut state = self.state.write().await;
            // Health checks keep reporting a dead primary while its failover is
            // still waiting for sync. Overwriting the state or starting a second
            // failover would race the one in flight.
            if Self::is_failover_active(*state) {
                tracing::debug!("Failure of primary {:?} is already being handled", node_id);
                return Ok(());
            }
            *state = FailoverState::PrimaryFailed;
        }
        self.emit(FailoverEvent::PrimaryFailed { node_id });
        tracing::warn!("Primary node {:?} failed", node_id);
        if self.config.auto_failover {
            self.initiate_failover().await?;
        }
        Ok(())
    }

    /// Fails over from the current primary to the best-ranked candidate.
    ///
    /// A candidate lagging beyond `max_lag_bytes` is waited on, bounded by
    /// `failover_timeout`, before it is promoted.
    ///
    /// # Errors
    /// - `FailoverFailed` when there is no primary or no eligible candidate.
    ///   The state is left untouched in that case.
    /// - `FailoverFailed` when another failover is already running.
    /// - The sync-wait or promotion error. The attempt is then recorded as
    ///   failed and the state becomes [`FailoverState::Failed`].
    pub async fn initiate_failover(&self) -> Result<()> {
        let old_primary = self
            .current_primary
            .read()
            .await
            .ok_or_else(|| ProxyError::FailoverFailed("No primary to failover from".to_string()))?;
        let candidate = self.select_best_candidate().await?;
        self.run_failover(old_primary, candidate, true).await
    }

    /// Shared failover body for automatic and manual failover.
    ///
    /// `enforce_lag_limit` selects whether a lagging candidate is waited on
    /// before promotion.
    async fn run_failover(
        &self,
        old_primary: NodeId,
        candidate: FailoverCandidate,
        enforce_lag_limit: bool,
    ) -> Result<()> {
        self.try_begin_failover().await?;
        let new_primary = candidate.node_id;
        self.emit(FailoverEvent::FailoverStarted {
            from: old_primary,
            to: new_primary,
        });

        // Monotonic clock for the duration; wall-clock time only for the record.
        let timer = std::time::Instant::now();
        let entry_id = uuid::Uuid::new_v4();
        self.history.write().await.push(FailoverHistoryEntry {
            id: entry_id,
            started_at: chrono::Utc::now(),
            ended_at: None,
            old_primary,
            new_primary: Some(new_primary),
            success: false,
            error: None,
        });

        if enforce_lag_limit && candidate.lag_bytes > self.config.max_lag_bytes {
            *self.state.write().await = FailoverState::WaitingForSync;
            self.emit(FailoverEvent::WaitingForSync {
                standby: new_primary,
                lag_bytes: candidate.lag_bytes,
            });
            if let Err(e) = self.wait_for_sync(new_primary).await {
                self.fail_failover(entry_id, &e.to_string()).await;
                return Err(e);
            }
        }

        // A promotion failure must not leave the state stuck at InProgress or
        // the history entry open.
        if let Err(e) = self.promote_standby(new_primary).await {
            self.fail_failover(entry_id, &e.to_string()).await;
            return Err(e);
        }

        *self.current_primary.write().await = Some(new_primary);
        *self.state.write().await = FailoverState::Completed;
        self.failover_count.fetch_add(1, Ordering::SeqCst);
        let duration_ms = u64::try_from(timer.elapsed().as_millis()).unwrap_or(u64::MAX);
        self.close_history_entry(entry_id, None).await;
        self.emit(FailoverEvent::StandbyPromoted { new_primary });
        self.emit(FailoverEvent::FailoverCompleted { duration_ms });
        tracing::info!(
            "Failover completed: {:?} -> {:?} in {}ms",
            old_primary,
            new_primary,
            duration_ms
        );
        self.schedule_state_reset();
        Ok(())
    }

    /// Atomically claims the state machine for a new failover.
    ///
    /// # Errors
    /// `FailoverFailed` when a failover is already running.
    async fn try_begin_failover(&self) -> Result<()> {
        let mut state = self.state.write().await;
        if Self::is_failover_active(*state) {
            return Err(ProxyError::FailoverFailed(
                "A failover is already in progress".to_string(),
            ));
        }
        *state = FailoverState::InProgress;
        Ok(())
    }

    /// Returns the state to `Normal` one second after a completed failover.
    fn schedule_state_reset(&self) {
        let state = Arc::clone(&self.state);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let mut current = state.write().await;
            // Only clear the Completed marker. A failover that started or failed
            // in the meantime keeps its own state.
            if *current == FailoverState::Completed {
                *current = FailoverState::Normal;
            }
        });
    }

    /// Ranks two candidates; `Less` means `a` is preferred.
    ///
    /// Ordering: sync first (when `prefer_sync_standby`), then lower lag,
    /// then lower `priority` value.
    fn compare_candidates(&self, a: &FailoverCandidate, b: &FailoverCandidate) -> std::cmp::Ordering {
        let sync_first = if self.config.prefer_sync_standby {
            b.is_sync.cmp(&a.is_sync)
        } else {
            std::cmp::Ordering::Equal
        };
        sync_first
            .then_with(|| a.lag_bytes.cmp(&b.lag_bytes))
            .then_with(|| a.priority.cmp(&b.priority))
    }

    /// Picks the best-ranked candidate, never the current primary.
    ///
    /// The current primary may itself be registered, for example a standby
    /// promoted by an earlier failover. Ties keep the map's iteration order.
    ///
    /// # Errors
    /// `FailoverFailed` when no candidates are registered or none is eligible.
    async fn select_best_candidate(&self) -> Result<FailoverCandidate> {
        let failed_primary = *self.current_primary.read().await;
        let candidates = self.candidates.read().await;
        if candidates.is_empty() {
            return Err(ProxyError::FailoverFailed(
                "No failover candidates available".to_string(),
            ));
        }
        let best = candidates
            .values()
            .filter(|c| Some(c.node_id) != failed_primary)
            .min_by(|a, b| self.compare_candidates(a, b))
            .cloned();
        best.ok_or_else(|| ProxyError::FailoverFailed("No eligible candidates".to_string()))
    }

    /// Waits, bounded by `failover_timeout`, for `standby` to finish replaying WAL.
    ///
    /// Without a backend config for the node, this only sleeps 50 ms and
    /// returns (skeleton path).
    ///
    /// # Errors
    /// `Timeout` when the deadline passes; otherwise the probe error.
    async fn wait_for_sync(&self, standby: NodeId) -> Result<()> {
        let cfg = match self.backend_config_for_node(standby).await {
            Some(cfg) => cfg,
            None => {
                tokio::time::sleep(Duration::from_millis(50)).await;
                return Ok(());
            }
        };
        tokio::time::timeout(self.config.failover_timeout, Self::poll_until_caught_up(cfg))
            .await
            .map_err(|_| ProxyError::Timeout("standby sync timeout".to_string()))?
    }

    /// Connects to the standby, polls its replay LSN until it is stable, then closes.
    async fn poll_until_caught_up(cfg: BackendConfig) -> Result<()> {
        let mut client = BackendClient::connect(&cfg)
            .await
            .map_err(|e| ProxyError::Failover(format!("connect to candidate: {}", e)))?;
        let outcome = Self::poll_replay_lsn(&mut client).await;
        client.close().await;
        outcome
    }

    /// Polls `pg_last_wal_replay_lsn()` every 200 ms.
    ///
    /// Returns once the same LSN has been observed on three consecutive polls.
    async fn poll_replay_lsn(client: &mut BackendClient) -> Result<()> {
        let mut last: Option<String> = None;
        let mut stable_polls = 0u32;
        loop {
            let lsn = client
                .query_scalar("SELECT pg_last_wal_replay_lsn()::text")
                .await
                .map_err(|e| ProxyError::Failover(format!("wal lsn probe: {}", e)))?
                .into_string()
                .ok_or_else(|| ProxyError::Failover("null WAL replay LSN".into()))?;
            if last.as_ref() == Some(&lsn) {
                stable_polls += 1;
                if stable_polls >= 2 {
                    tracing::info!(lsn = %lsn, "standby caught up");
                    return Ok(());
                }
            } else {
                stable_polls = 0;
                last = Some(lsn);
            }
            tokio::time::sleep(Duration::from_millis(200)).await;
        }
    }

    /// Promotes `standby` with `pg_promote`.
    ///
    /// The promotion waits between 10 and 300 seconds, derived from
    /// `failover_timeout`. Afterwards `pg_is_in_recovery()` is checked on a
    /// fresh connection to confirm the node left recovery. Without a backend
    /// config for the node this is a logged no-op.
    ///
    /// # Errors
    /// `FailoverFailed` on connection, query or decode failure, when
    /// `pg_promote` returns false, or when the node still reports recovery.
    async fn promote_standby(&self, standby: NodeId) -> Result<()> {
        let cfg = match self.backend_config_for_node(standby).await {
            Some(c) => c,
            None => {
                tracing::info!(
                    node = ?standby,
                    "promote_standby: skeleton path (no backend template) — no-op"
                );
                return Ok(());
            }
        };
        let wait_secs = self.config.failover_timeout.as_secs().clamp(10, 300);
        let sql = format!("SELECT pg_promote(true, {})", wait_secs);

        let mut client = BackendClient::connect(&cfg)
            .await
            .map_err(|e| ProxyError::FailoverFailed(format!("connect to promote: {}", e)))?;
        let reply = client.query_scalar(&sql).await;
        // Close before inspecting the reply so error paths do not leak the session.
        client.close().await;
        let promoted = reply
            .map_err(|e| ProxyError::FailoverFailed(format!("pg_promote: {}", e)))?
            .as_bool("pg_promote")
            .map_err(|e| ProxyError::FailoverFailed(format!("pg_promote result: {}", e)))?
            .unwrap_or(false);
        if !promoted {
            return Err(ProxyError::FailoverFailed(
                "pg_promote returned false".to_string(),
            ));
        }

        let mut verify = BackendClient::connect(&cfg)
            .await
            .map_err(|e| ProxyError::FailoverFailed(format!("connect to verify: {}", e)))?;
        let probe = verify.query_scalar("SELECT pg_is_in_recovery()").await;
        verify.close().await;
        // A NULL answer is treated as "still in recovery".
        let still_standby = probe
            .map_err(|e| ProxyError::FailoverFailed(format!("verify probe: {}", e)))?
            .as_bool("pg_is_in_recovery")
            .map_err(|e| ProxyError::FailoverFailed(format!("verify bool: {}", e)))?
            .unwrap_or(true);
        if still_standby {
            return Err(ProxyError::FailoverFailed(
                "post-promote pg_is_in_recovery still true".to_string(),
            ));
        }
        tracing::info!(node = ?standby, "standby promoted to primary");
        Ok(())
    }

    /// Marks the attempt `entry_id` as failed.
    ///
    /// Sets the state to `Failed`, closes the history entry, emits
    /// `FailoverFailed` and logs the reason.
    async fn fail_failover(&self, entry_id: uuid::Uuid, reason: &str) {
        *self.state.write().await = FailoverState::Failed;
        self.close_history_entry(entry_id, Some(reason)).await;
        self.emit(FailoverEvent::FailoverFailed {
            reason: reason.to_string(),
        });
        tracing::error!("Failover failed: {}", reason);
    }

    /// Stamps the end time and outcome on the history entry with `entry_id`.
    ///
    /// `error == None` means the attempt succeeded. The entry is looked up by
    /// id rather than position so concurrent attempts cannot overwrite each
    /// other's records.
    async fn close_history_entry(&self, entry_id: uuid::Uuid, error: Option<&str>) {
        let mut history = self.history.write().await;
        if let Some(entry) = history.iter_mut().rev().find(|entry| entry.id == entry_id) {
            entry.ended_at = Some(chrono::Utc::now());
            entry.success = error.is_none();
            entry.error = error.map(str::to_owned);
        }
    }

    /// Handles a previously failed node coming back online.
    ///
    /// Always emits `OldPrimaryRecovered`. If the node is still the designated
    /// primary, for example because no failover happened, nothing else is
    /// done. Otherwise a warning is logged and, when the node is a registered
    /// candidate with a backend config, it is probed. An error is logged if it
    /// still reports itself as primary (`pg_is_in_recovery = false`). Probe
    /// failures are only logged at debug level.
    pub async fn on_old_primary_recovered(&self, node_id: NodeId) {
        self.emit(FailoverEvent::OldPrimaryRecovered { node_id });
        if *self.current_primary.read().await == Some(node_id) {
            tracing::info!(
                "node {:?} recovered and is still the designated primary",
                node_id
            );
            return;
        }
        tracing::warn!(
            "old primary {:?} recovered — must be demoted out-of-band to prevent split-brain",
            node_id
        );
        let cfg = match self.backend_config_for_node(node_id).await {
            Some(c) => c,
            None => return,
        };
        let mut client = match BackendClient::connect(&cfg).await {
            Ok(client) => client,
            Err(e) => {
                tracing::debug!(
                    error = %e,
                    "could not connect to recovered node for split-brain probe"
                );
                return;
            }
        };
        let in_recovery_result = client.query_scalar("SELECT pg_is_in_recovery()").await;
        client.close().await;
        if let Ok(value) = in_recovery_result {
            if let Ok(Some(false)) = value.as_bool("pg_is_in_recovery") {
                tracing::error!(
                    "split-brain hazard: node {:?} recovered and still reports primary (pg_is_in_recovery=false). Shut it down or use pg_rewind before reintroducing.",
                    node_id
                );
            }
        }
    }

    /// Operator-requested failover to `target`, without the lag-limit wait.
    ///
    /// When no primary is known, a fresh placeholder `NodeId` is used as the
    /// `from` side of the event and history entry.
    ///
    /// # Errors
    /// - `FailoverFailed` when `target` is not a candidate, is already the
    ///   primary, or another failover is running.
    /// - The promotion error. The attempt is then recorded as failed.
    pub async fn manual_failover(&self, target: NodeId) -> Result<()> {
        let candidate = self
            .candidates
            .read()
            .await
            .get(&target)
            .cloned()
            .ok_or_else(|| {
                ProxyError::FailoverFailed(format!(
                    "Node {:?} is not a valid failover candidate",
                    target
                ))
            })?;
        let current = *self.current_primary.read().await;
        if current == Some(target) {
            return Err(ProxyError::FailoverFailed(format!(
                "Node {:?} is already the primary",
                target
            )));
        }
        let old_primary = current.unwrap_or_else(NodeId::new);
        self.run_failover(old_primary, candidate, false).await
    }

    /// Number of failovers that completed successfully.
    pub fn failover_count(&self) -> u64 {
        self.failover_count.load(Ordering::SeqCst)
    }

    /// Snapshot of all failover attempts, oldest first.
    pub async fn history(&self) -> Vec<FailoverHistoryEntry> {
        self.history.read().await.clone()
    }

    /// Hands out the event receiver. Returns `Some` only on the first call.
    pub fn take_event_receiver(&mut self) -> Option<mpsc::Receiver<FailoverEvent>> {
        self.event_rx.take()
    }

    /// Replays the journaled transactions of `failed_node` on the new primary.
    ///
    /// First waits for the new primary to catch up (see `wait_for_lsn_catchup`),
    /// then replays each transaction in turn and collects per-transaction
    /// results. A transaction that cannot be started or executed counts as a
    /// failed replay; it does not abort the run.
    ///
    /// # Errors
    /// Only the catch-up wait (`Timeout`) is propagated.
    #[cfg(feature = "ha-tr")]
    pub async fn coordinate_failover_replay(
        &self,
        journal: &TransactionJournal,
        failed_node: NodeId,
        new_primary_endpoint: &NodeEndpoint,
    ) -> Result<CoordinatedReplayResult> {
        let start = std::time::Instant::now();
        tracing::info!(
            "Starting coordinated replay: failed_node={:?}, new_primary={:?}",
            failed_node,
            new_primary_endpoint.id
        );
        let affected_txs = journal.get_transactions_for_node(failed_node).await;
        if affected_txs.is_empty() {
            tracing::info!("No active transactions to replay");
            return Ok(CoordinatedReplayResult {
                total_transactions: 0,
                successful_replays: 0,
                failed_replays: 0,
                transaction_results: vec![],
                duration_ms: start.elapsed().as_millis() as u64,
                new_primary: new_primary_endpoint.id,
            });
        }
        tracing::info!("Found {} active transactions to replay", affected_txs.len());

        let max_lsn = affected_txs.iter().map(|tx| tx.start_lsn).max().unwrap_or(0);
        self.wait_for_lsn_catchup(new_primary_endpoint.id, max_lsn).await?;

        let replay_manager = FailoverReplay::new(ReplayConfig {
            verify_results: true,
            statement_timeout_ms: 30000,
            retry_on_error: true,
            max_retries: 3,
            skip_read_only: false,
            wait_for_wal_sync: false,
            max_wal_lag_bytes: 0,
        });

        // Result recorded for a transaction whose replay never produced one.
        let failed_result = |tx_id, error: String| ReplayResult {
            tx_id,
            success: false,
            statements_replayed: 0,
            statements_skipped: 0,
            statements_failed: 0,
            verification_failures: 0,
            duration_ms: 0,
            error: Some(error),
            statement_results: vec![],
        };

        let mut transaction_results = Vec::with_capacity(affected_txs.len());
        let mut successful_replays = 0;
        let mut failed_replays = 0;
        for tx_journal in affected_txs {
            let tx_id = tx_journal.tx_id;
            tracing::debug!("Replaying transaction {:?} with {} entries", tx_id, tx_journal.entries.len());
            match replay_manager.start_replay(tx_journal, new_primary_endpoint.id).await {
                Ok(_) => match replay_manager.execute_replay(tx_id).await {
                    Ok(result) => {
                        if result.success {
                            successful_replays += 1;
                            tracing::debug!("Transaction {:?} replayed successfully", tx_id);
                        } else {
                            failed_replays += 1;
                            tracing::warn!(
                                "Transaction {:?} replay failed: {:?}",
                                tx_id,
                                result.error
                            );
                        }
                        transaction_results.push(result);
                    }
                    Err(e) => {
                        failed_replays += 1;
                        tracing::error!("Failed to execute replay for {:?}: {}", tx_id, e);
                        transaction_results.push(failed_result(tx_id, e.to_string()));
                    }
                },
                Err(e) => {
                    failed_replays += 1;
                    tracing::error!("Failed to start replay for {:?}: {}", tx_id, e);
                    transaction_results.push(failed_result(tx_id, e.to_string()));
                }
            }
        }

        let duration_ms = start.elapsed().as_millis() as u64;
        tracing::info!(
            "Coordinated replay completed: {}/{} successful in {}ms",
            successful_replays,
            successful_replays + failed_replays,
            duration_ms
        );
        Ok(CoordinatedReplayResult {
            total_transactions: successful_replays + failed_replays,
            successful_replays,
            failed_replays,
            transaction_results,
            duration_ms,
            new_primary: new_primary_endpoint.id,
        })
    }

    /// Waits, bounded by `failover_timeout`, until candidate `node` reports zero lag.
    ///
    /// Polls every 100 ms. NOTE(review): `target_lsn` is only used to skip the
    /// wait when it is zero; catch-up is judged by the candidate's `lag_bytes`.
    /// If `node` is not a registered candidate, this runs until the timeout.
    ///
    /// # Errors
    /// `Timeout` when the deadline passes.
    #[cfg(feature = "ha-tr")]
    async fn wait_for_lsn_catchup(&self, node: NodeId, target_lsn: u64) -> Result<()> {
        if target_lsn == 0 {
            return Ok(());
        }
        tracing::debug!("Waiting for node {:?} to catch up to LSN {}", node, target_lsn);
        let timeout = self.config.failover_timeout;
        let start = std::time::Instant::now();
        loop {
            if start.elapsed() >= timeout {
                return Err(ProxyError::Timeout(format!(
                    "Timeout waiting for node {:?} to catch up to LSN {}",
                    node, target_lsn
                )));
            }
            let lag = self
                .candidates
                .read()
                .await
                .get(&node)
                .map(|c| c.lag_bytes);
            if lag == Some(0) {
                tracing::debug!("Node {:?} has caught up", node);
                return Ok(());
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
    }
}
/// Outcome of `FailoverController::coordinate_failover_replay`.
#[cfg(feature = "ha-tr")]
#[derive(Debug, Clone)]
pub struct CoordinatedReplayResult {
    /// Number of transactions attempted (`successful_replays + failed_replays`).
    pub total_transactions: usize,
    /// Transactions whose replay reported success.
    pub successful_replays: usize,
    /// Transactions that failed to start, execute or replay.
    pub failed_replays: usize,
    /// Per-transaction results, in replay order.
    pub transaction_results: Vec<ReplayResult>,
    /// Wall time of the whole coordinated replay, in milliseconds.
    pub duration_ms: u64,
    /// Node the transactions were replayed on.
    pub new_primary: NodeId,
}
#[cfg(feature = "ha-tr")]
impl CoordinatedReplayResult {
    /// Returns `true` when no transaction failed to replay.
    pub fn all_successful(&self) -> bool {
        matches!(self.failed_replays, 0)
    }

    /// Percentage (0.0 to 100.0) of transactions that replayed successfully.
    ///
    /// An empty replay set counts as fully successful and yields `100.0`.
    pub fn success_rate(&self) -> f64 {
        match self.total_transactions {
            0 => 100.0,
            total => {
                let ratio = self.successful_replays as f64 / total as f64;
                ratio * 100.0
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate with a default port and no heartbeat.
    fn candidate(
        node_id: NodeId,
        host: &'static str,
        is_sync: bool,
        lag_bytes: u64,
        priority: u32,
    ) -> FailoverCandidate {
        FailoverCandidate {
            node_id,
            endpoint: NodeEndpoint::new(host, 5432),
            is_sync,
            lag_bytes,
            priority,
            last_heartbeat: None,
        }
    }

    #[test]
    fn test_config_default() {
        let config = FailoverConfig::default();
        assert!(config.auto_failover);
        assert!(config.prefer_sync_standby);
        assert_eq!(config.max_retries, 3);
    }

    #[tokio::test]
    async fn test_set_get_primary() {
        let controller = FailoverController::new(FailoverConfig::default());
        let node_id = NodeId::new();
        controller.set_primary(node_id).await;
        assert_eq!(controller.get_primary().await, Some(node_id));
    }

    #[tokio::test]
    async fn test_register_candidate() {
        let controller = FailoverController::new(FailoverConfig::default());
        let node_id = NodeId::new();
        let candidate = FailoverCandidate {
            node_id,
            endpoint: NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Standby),
            is_sync: true,
            lag_bytes: 0,
            priority: 1,
            last_heartbeat: None,
        };
        controller.register_candidate(candidate).await;
        let candidates = controller.candidates.read().await;
        assert!(candidates.contains_key(&node_id));
    }

    #[tokio::test]
    async fn test_update_and_remove_candidate() {
        let controller = FailoverController::new(FailoverConfig::default());
        let node = NodeId::new();
        controller
            .register_candidate(candidate(node, "standby", true, 500, 1))
            .await;
        controller.update_candidate_lag(&node, 42).await;
        {
            let candidates = controller.candidates.read().await;
            let updated = candidates.get(&node).expect("candidate was registered");
            assert_eq!(updated.lag_bytes, 42);
            assert!(updated.last_heartbeat.is_some());
        }
        controller.remove_candidate(&node).await;
        assert!(!controller.candidates.read().await.contains_key(&node));
    }

    #[test]
    fn test_take_event_receiver_only_once() {
        let mut controller = FailoverController::new(FailoverConfig::default());
        assert!(controller.take_event_receiver().is_some());
        assert!(controller.take_event_receiver().is_none());
    }

    #[tokio::test]
    async fn test_state_transitions() {
        let controller = FailoverController::new(FailoverConfig::default());
        assert_eq!(controller.state().await, FailoverState::Normal);
        *controller.state.write().await = FailoverState::PrimaryFailed;
        assert_eq!(controller.state().await, FailoverState::PrimaryFailed);
    }

    #[tokio::test]
    async fn test_on_primary_failed_ignores_other_nodes() {
        let controller = FailoverController::new(FailoverConfig::default());
        let primary = NodeId::new();
        controller.set_primary(primary).await;
        controller.on_primary_failed(NodeId::new()).await.unwrap();
        assert_eq!(controller.state().await, FailoverState::Normal);
        assert_eq!(controller.get_primary().await, Some(primary));
    }

    #[tokio::test]
    async fn test_select_best_candidate() {
        let controller = FailoverController::new(FailoverConfig::default());
        let sync_node = NodeId::new();
        let async_node = NodeId::new();
        controller
            .register_candidate(FailoverCandidate {
                node_id: async_node,
                endpoint: NodeEndpoint::new("async", 5432),
                is_sync: false,
                lag_bytes: 100,
                priority: 1,
                last_heartbeat: None,
            })
            .await;
        controller
            .register_candidate(FailoverCandidate {
                node_id: sync_node,
                endpoint: NodeEndpoint::new("sync", 5432),
                is_sync: true,
                lag_bytes: 50,
                priority: 2,
                last_heartbeat: None,
            })
            .await;
        let best = controller.select_best_candidate().await.unwrap();
        assert_eq!(best.node_id, sync_node);
    }

    #[tokio::test]
    async fn test_select_best_candidate_by_lag_when_sync_not_preferred() {
        let config = FailoverConfig {
            prefer_sync_standby: false,
            ..FailoverConfig::default()
        };
        let controller = FailoverController::new(config);
        let sync_node = NodeId::new();
        let async_node = NodeId::new();
        controller
            .register_candidate(candidate(sync_node, "sync", true, 50, 1))
            .await;
        controller
            .register_candidate(candidate(async_node, "async", false, 10, 1))
            .await;
        let best = controller.select_best_candidate().await.unwrap();
        assert_eq!(best.node_id, async_node);
    }

    #[tokio::test]
    async fn test_select_best_candidate_priority_tie_break_and_empty() {
        let controller = FailoverController::new(FailoverConfig::default());
        assert!(controller.select_best_candidate().await.is_err());
        let preferred = NodeId::new();
        let other = NodeId::new();
        controller
            .register_candidate(candidate(other, "other", true, 0, 2))
            .await;
        controller
            .register_candidate(candidate(preferred, "preferred", true, 0, 1))
            .await;
        let best = controller.select_best_candidate().await.unwrap();
        assert_eq!(best.node_id, preferred);
    }

    #[tokio::test]
    async fn test_initiate_failover_requires_primary() {
        let controller = FailoverController::new(FailoverConfig::default());
        controller
            .register_candidate(candidate(NodeId::new(), "standby", true, 0, 1))
            .await;
        assert!(controller.initiate_failover().await.is_err());
        assert_eq!(controller.get_primary().await, None);
        assert_eq!(controller.failover_count(), 0);
    }

    #[tokio::test]
    async fn test_initiate_failover_skeleton_path() {
        let controller = FailoverController::new(FailoverConfig::default());
        let old_primary = NodeId::new();
        let standby = NodeId::new();
        controller.set_primary(old_primary).await;
        controller
            .register_candidate(candidate(standby, "standby", true, 0, 1))
            .await;
        controller.initiate_failover().await.unwrap();
        assert_eq!(controller.get_primary().await, Some(standby));
        assert_eq!(controller.state().await, FailoverState::Completed);
        assert_eq!(controller.failover_count(), 1);
        let history = controller.history().await;
        assert_eq!(history.len(), 1);
        assert!(history[0].success);
        assert!(history[0].ended_at.is_some());
        assert_eq!(history[0].old_primary, old_primary);
        assert_eq!(history[0].new_primary, Some(standby));
    }

    #[tokio::test]
    async fn test_manual_failover_rejects_unknown_target() {
        let controller = FailoverController::new(FailoverConfig::default());
        assert!(controller.manual_failover(NodeId::new()).await.is_err());
        assert_eq!(controller.state().await, FailoverState::Normal);
        assert_eq!(controller.failover_count(), 0);
    }

    #[tokio::test]
    async fn test_manual_failover_skeleton_path() {
        let controller = FailoverController::new(FailoverConfig::default());
        let target = NodeId::new();
        controller
            .register_candidate(candidate(target, "target", false, 0, 1))
            .await;
        controller.manual_failover(target).await.unwrap();
        assert_eq!(controller.get_primary().await, Some(target));
        assert_eq!(controller.failover_count(), 1);
    }

    #[cfg(feature = "ha-tr")]
    #[tokio::test]
    async fn test_coordinate_failover_replay_empty() {
        use super::super::transaction_journal::TransactionJournal;
        let controller = FailoverController::new(FailoverConfig::default());
        let journal = TransactionJournal::new();
        let failed_node = NodeId::new();
        let new_primary = NodeEndpoint::new("new-primary", 5432).with_role(NodeRole::Primary);
        let result = controller
            .coordinate_failover_replay(&journal, failed_node, &new_primary)
            .await
            .unwrap();
        assert_eq!(result.total_transactions, 0);
        assert_eq!(result.successful_replays, 0);
        assert_eq!(result.failed_replays, 0);
        assert!(result.all_successful());
        assert_eq!(result.success_rate(), 100.0);
    }

    #[cfg(feature = "ha-tr")]
    #[tokio::test]
    async fn test_coordinate_failover_replay_with_transactions() {
        use super::super::transaction_journal::{JournalValue, TransactionJournal};
        use uuid::Uuid;
        let controller = FailoverController::new(FailoverConfig::default());
        let journal = TransactionJournal::new();
        let failed_node = NodeId::new();
        let new_primary = NodeEndpoint::new("new-primary", 5432).with_role(NodeRole::Primary);
        controller
            .register_candidate(FailoverCandidate {
                node_id: new_primary.id,
                endpoint: new_primary.clone(),
                is_sync: true,
                lag_bytes: 0,
                priority: 1,
                last_heartbeat: None,
            })
            .await;
        let tx_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        journal
            .begin_transaction(tx_id, session_id, failed_node, 100)
            .await
            .unwrap();
        journal
            .log_statement(
                tx_id,
                "INSERT INTO users (name) VALUES ('test')".to_string(),
                vec![JournalValue::Text("test".to_string())],
                Some(12345),
                Some(1),
                10,
            )
            .await
            .unwrap();
        let result = controller
            .coordinate_failover_replay(&journal, failed_node, &new_primary)
            .await
            .unwrap();
        assert_eq!(result.total_transactions, 1);
        assert_eq!(result.successful_replays, 1);
        assert_eq!(result.failed_replays, 0);
        assert!(result.all_successful());
    }

    #[cfg(feature = "ha-tr")]
    #[test]
    fn test_coordinated_replay_result_methods() {
        let result = CoordinatedReplayResult {
            total_transactions: 10,
            successful_replays: 8,
            failed_replays: 2,
            transaction_results: vec![],
            duration_ms: 1000,
            new_primary: NodeId::new(),
        };
        assert!(!result.all_successful());
        assert_eq!(result.success_rate(), 80.0);
        let perfect = CoordinatedReplayResult {
            total_transactions: 5,
            successful_replays: 5,
            failed_replays: 0,
            transaction_results: vec![],
            duration_ms: 500,
            new_primary: NodeId::new(),
        };
        assert!(perfect.all_successful());
        assert_eq!(perfect.success_rate(), 100.0);
    }
}