use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use tokio::sync::{broadcast, mpsc, oneshot};
use uuid::Uuid;
use crate::{Result, Error};
use super::role_manager::{RoleManager, NodeRole, RoleChangeReason, SwitchoverPhase};
use super::ha_state::HAStateRegistry;
/// Limits and thresholds that govern a planned primary switchover.
#[derive(Debug, Clone)]
pub struct SwitchoverConfig {
    /// Longest wait for the target standby to apply the primary's final LSN.
    pub sync_timeout: Duration,
    /// Overall deadline for the switchover, checked during the drain and sync waits.
    pub total_timeout: Duration,
    /// Number of near-caught-up standbys the pre-flight check expects to find.
    pub min_synced_standbys: usize,
    /// When `true`, too few synced standbys is reported as a warning instead of a blocker.
    pub allow_partial_sync: bool,
    /// Longest wait for in-flight transactions to finish once writes are blocked.
    pub drain_timeout: Duration,
    /// Pause between polls of the target standby while waiting for it to catch up.
    pub health_check_interval: Duration,
}

impl Default for SwitchoverConfig {
    /// Conservative defaults: 30s sync, 60s total, 10s drain, 100ms polling,
    /// one synced standby required and no partial sync.
    fn default() -> Self {
        let sync_timeout = Duration::from_secs(30);
        let total_timeout = Duration::from_secs(60);
        let drain_timeout = Duration::from_secs(10);
        let health_check_interval = Duration::from_millis(100);

        SwitchoverConfig {
            sync_timeout,
            total_timeout,
            min_synced_standbys: 1,
            allow_partial_sync: false,
            drain_timeout,
            health_check_interval,
        }
    }
}
/// Snapshot of a standby's replication progress.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this type; the field notes below are inferred from names and types only —
/// confirm against the actual producer.
#[derive(Debug, Clone)]
pub struct StandbyStatus {
/// Presumably the standby this status describes.
pub node_id: Uuid,
/// Presumably the LSN the standby has reached so far.
pub current_lsn: u64,
/// Presumably the LSN the standby is expected to reach.
pub target_lsn: u64,
/// Presumably whether the standby is considered caught up — TODO confirm the exact criterion.
pub is_synced: bool,
/// Presumably when this standby was last heard from.
pub last_seen: Instant,
/// Presumably replication lag in milliseconds, going by the `_ms` suffix.
pub replication_lag_ms: u64,
}
/// Result of the pre-flight check that runs before a switchover is started.
#[derive(Debug, Clone)]
pub struct SwitchoverCheck {
/// `false` as soon as at least one entry has been added to `blockers`.
pub can_proceed: bool,
/// `true` when the target node was found among the registry's standbys.
pub target_healthy: bool,
/// Apply LSN reported for the target standby (0 when the target was not found).
pub target_lsn: u64,
/// LSN read from the HA registry for this node when the check ran.
pub primary_lsn: u64,
/// `primary_lsn - target_lsn`, saturating at zero.
/// NOTE(review): computed as an LSN difference; whether that equals bytes
/// depends on the LSN encoding — confirm.
pub lag_bytes: u64,
/// Standbys whose apply LSN is within 100 of `primary_lsn`.
pub synced_standbys: Vec<Uuid>,
/// Conditions worth surfacing that do not prevent the switchover.
pub warnings: Vec<String>,
/// Conditions that prevent the switchover from starting.
pub blockers: Vec<String>,
}
/// Messages consumed by `SwitchoverCoordinator::run`.
#[derive(Debug)]
pub enum SwitchoverCommand {
/// Run a switchover to `target_node`; `response` receives the switchover id
/// on success or the error that stopped it.
Initiate {
target_node: Uuid,
response: oneshot::Sender<Result<Uuid>>,
},
/// Cancel the current switchover and roll this node back.
Cancel {
response: oneshot::Sender<Result<()>>,
},
/// Run only the pre-flight check for `target_node`.
Check {
target_node: Uuid,
response: oneshot::Sender<Result<SwitchoverCheck>>,
},
/// Progress report from a standby; the coordinator currently only logs when
/// the switchover target has reached the recorded target LSN.
StandbyProgress {
node_id: Uuid,
lsn: u64,
},
/// Readiness report from a standby; the coordinator currently only logs it.
StandbyReady {
node_id: Uuid,
},
/// Stop the coordinator's command loop.
Shutdown,
}
/// Notifications broadcast by the coordinator while a switchover runs.
#[derive(Debug, Clone)]
pub enum SwitchoverEvent {
/// A switchover was registered with the role manager; `source` is the
/// coordinating node and `target` the standby to be promoted.
Started {
switchover_id: Uuid,
source: Uuid,
target: Uuid,
},
/// The switchover advanced to `phase`.
PhaseChanged {
switchover_id: Uuid,
phase: SwitchoverPhase,
},
/// Sent after the old primary has been demoted: the node named
/// `new_primary` promotes itself, other listeners re-point to `new_primary_addr`.
PrepareNewPrimary {
switchover_id: Uuid,
new_primary: Uuid,
new_primary_addr: String,
},
/// The switchover finished; `duration_ms` is measured from just after it was registered.
Completed {
switchover_id: Uuid,
new_primary: Uuid,
duration_ms: u64,
},
/// The switchover aborted with `error`; the coordinator rolls back after sending this.
Failed {
switchover_id: Uuid,
error: String,
},
/// The switchover was cancelled on request.
Cancelled {
switchover_id: Uuid,
},
}
/// Primary-side driver for planned switchovers.
///
/// Commands arrive over an mpsc queue and are handled by `run`; progress is
/// published on a broadcast channel obtained through `subscribe`.
pub struct SwitchoverCoordinator {
/// Id of the node this coordinator runs on; reported as the switchover source.
node_id: Uuid,
/// Owns role and switchover-phase transitions for this node.
role_manager: Arc<RoleManager>,
/// Source of this node's LSN and of the standby list with their apply LSNs.
ha_registry: Arc<HAStateRegistry>,
/// Timeouts and thresholds applied during checks and execution.
config: SwitchoverConfig,
/// Sending half of the command queue; cloned out by `command_sender`.
command_tx: mpsc::Sender<SwitchoverCommand>,
/// Broadcast channel on which `SwitchoverEvent`s are published.
event_tx: broadcast::Sender<SwitchoverEvent>,
/// Known network addresses per node, used when announcing the new primary.
node_addresses: Arc<RwLock<HashMap<Uuid, String>>>,
/// Number of live `TransactionGuard`s handed out by `begin_transaction`.
in_flight_transactions: Arc<std::sync::atomic::AtomicU64>,
/// Set while a switchover has writes blocked; checked by `begin_transaction`.
writes_blocked: Arc<std::sync::atomic::AtomicBool>,
}
impl SwitchoverCoordinator {
/// Creates a coordinator and the receiving half of its command queue.
///
/// The returned receiver is what `run` consumes. Both the command queue and
/// the event broadcast channel are created with a capacity of 64. The
/// coordinator starts with no known node addresses, zero in-flight
/// transactions and writes unblocked.
pub fn new(
    node_id: Uuid,
    role_manager: Arc<RoleManager>,
    ha_registry: Arc<HAStateRegistry>,
    config: SwitchoverConfig,
) -> (Self, mpsc::Receiver<SwitchoverCommand>) {
    use std::sync::atomic::{AtomicBool, AtomicU64};

    let (command_tx, command_rx) = mpsc::channel(64);
    // The initial broadcast receiver is discarded; listeners attach via `subscribe`.
    let (event_tx, _) = broadcast::channel(64);

    let node_addresses = Arc::new(RwLock::new(HashMap::new()));
    let in_flight_transactions = Arc::new(AtomicU64::new(0));
    let writes_blocked = Arc::new(AtomicBool::new(false));

    (
        Self {
            node_id,
            role_manager,
            ha_registry,
            config,
            command_tx,
            event_tx,
            node_addresses,
            in_flight_transactions,
            writes_blocked,
        },
        command_rx,
    )
}
/// Returns another sending handle for the coordinator's command queue.
pub fn command_sender(&self) -> mpsc::Sender<SwitchoverCommand> {
    mpsc::Sender::clone(&self.command_tx)
}
/// Opens a new receiver on the switchover event broadcast channel.
pub fn subscribe(&self) -> broadcast::Receiver<SwitchoverEvent> {
    let events = self.event_tx.subscribe();
    events
}
/// Records (or replaces) the network address known for `node_id`.
pub fn register_node_address(&self, node_id: Uuid, address: String) {
    let mut addresses = self.node_addresses.write();
    addresses.insert(node_id, address);
}
/// Reports whether a switchover currently has writes blocked on this node.
pub fn are_writes_blocked(&self) -> bool {
    use std::sync::atomic::Ordering;

    self.writes_blocked.load(Ordering::SeqCst)
}
/// Registers a new in-flight transaction and returns a guard that
/// un-registers it when dropped.
///
/// # Errors
///
/// Returns an HA error when writes are currently blocked by a switchover.
pub fn begin_transaction(&self) -> Result<TransactionGuard> {
    use std::sync::atomic::Ordering;

    // Register first, check second. Checking the flag before incrementing
    // leaves a window in which the switchover can set `writes_blocked`,
    // observe a count of zero and proceed while this transaction is about to
    // start. The switchover stores the flag and then loads the counter; this
    // path increments the counter and then loads the flag. With `SeqCst` on
    // both sides, at least one of them is guaranteed to see the other.
    self.in_flight_transactions.fetch_add(1, Ordering::SeqCst);
    let guard = TransactionGuard {
        counter: Arc::clone(&self.in_flight_transactions),
    };

    if self.writes_blocked.load(Ordering::SeqCst) {
        // Dropping `guard` on this return path undoes the increment above.
        return Err(Error::ha("Writes blocked during switchover"));
    }
    Ok(guard)
}
pub async fn check_switchover(&self, target_node: Uuid) -> Result<SwitchoverCheck> {
let (response_tx, response_rx) = oneshot::channel();
self.command_tx
.send(SwitchoverCommand::Check {
target_node,
response: response_tx,
})
.await
.map_err(|_| Error::ha("Coordinator channel closed"))?;
response_rx.await.map_err(|_| Error::ha("Response channel closed"))?
}
pub async fn initiate_switchover(&self, target_node: Uuid) -> Result<Uuid> {
let (response_tx, response_rx) = oneshot::channel();
self.command_tx
.send(SwitchoverCommand::Initiate {
target_node,
response: response_tx,
})
.await
.map_err(|_| Error::ha("Coordinator channel closed"))?;
response_rx.await.map_err(|_| Error::ha("Response channel closed"))?
}
pub async fn cancel_switchover(&self) -> Result<()> {
let (response_tx, response_rx) = oneshot::channel();
self.command_tx
.send(SwitchoverCommand::Cancel {
response: response_tx,
})
.await
.map_err(|_| Error::ha("Coordinator channel closed"))?;
response_rx.await.map_err(|_| Error::ha("Response channel closed"))?
}
/// Command loop of the coordinator.
///
/// Processes commands one at a time until `Shutdown` is received or every
/// sender has been dropped. Replies are best-effort: a caller that stopped
/// waiting is ignored.
///
/// NOTE(review): each command is awaited to completion before the next one
/// is read, so a `Cancel` queued while an `Initiate` is executing is only
/// seen after that switchover has already finished or failed.
pub async fn run(&self, mut command_rx: mpsc::Receiver<SwitchoverCommand>) {
    tracing::info!("Switchover coordinator started on node {}", self.node_id);
    loop {
        let command = match command_rx.recv().await {
            Some(command) => command,
            None => break,
        };
        match command {
            SwitchoverCommand::Shutdown => {
                tracing::info!("Switchover coordinator shutting down");
                break;
            }
            SwitchoverCommand::Initiate { target_node, response } => {
                let outcome = self.handle_initiate(target_node).await;
                let _ = response.send(outcome);
            }
            SwitchoverCommand::Cancel { response } => {
                let outcome = self.handle_cancel().await;
                let _ = response.send(outcome);
            }
            SwitchoverCommand::Check { target_node, response } => {
                let outcome = self.handle_check(target_node).await;
                let _ = response.send(outcome);
            }
            SwitchoverCommand::StandbyProgress { node_id, lsn } => {
                self.handle_standby_progress(node_id, lsn).await;
            }
            SwitchoverCommand::StandbyReady { node_id } => {
                self.handle_standby_ready(node_id).await;
            }
        }
    }
}
async fn handle_initiate(&self, target_node: Uuid) -> Result<Uuid> {
if !self.role_manager.is_primary() {
return Err(Error::ha("Only primary can initiate switchover"));
}
let check = self.handle_check(target_node).await?;
if !check.can_proceed {
return Err(Error::ha(format!(
"Switchover blocked: {}",
check.blockers.join(", ")
)));
}
let switchover_id = self.role_manager.begin_switchover(target_node)?;
let start_time = Instant::now();
let _ = self.event_tx.send(SwitchoverEvent::Started {
switchover_id,
source: self.node_id,
target: target_node,
});
let result = self.execute_switchover(switchover_id, target_node, check.primary_lsn).await;
match result {
Ok(()) => {
let duration = start_time.elapsed();
let _ = self.event_tx.send(SwitchoverEvent::Completed {
switchover_id,
new_primary: target_node,
duration_ms: duration.as_millis() as u64,
});
tracing::info!(
"Switchover {} completed in {}ms",
switchover_id,
duration.as_millis()
);
Ok(switchover_id)
}
Err(e) => {
let _ = self.event_tx.send(SwitchoverEvent::Failed {
switchover_id,
error: e.to_string(),
});
self.rollback_switchover().await;
Err(e)
}
}
}
/// Drives one switchover through its phases on the current primary.
///
/// Sequence as implemented below: Preparation, then Synchronization (block
/// writes, enter `Draining`, wait for in-flight transactions, wait for the
/// target to apply the final LSN), RoleChange (demote this node),
/// Reconfiguration (broadcast `PrepareNewPrimary`), Resumption (unblock
/// writes) and finally Completed.
///
/// Returns the first error from a phase or role transition, or a timeout
/// error from the drain or sync wait. Nothing is undone here; the caller is
/// expected to roll back when this returns `Err`.
///
/// NOTE(review): `primary_lsn` is not read in this body — the LSN the target
/// must reach is re-read from the registry once the drain has completed.
async fn execute_switchover(
&self,
switchover_id: Uuid,
target_node: Uuid,
primary_lsn: u64,
) -> Result<()> {
// Overall deadline; checked in addition to the per-step deadlines below.
let timeout = Instant::now() + self.config.total_timeout;
self.advance_phase(SwitchoverPhase::Preparation)?;
let _ = self.event_tx.send(SwitchoverEvent::PhaseChanged {
switchover_id,
phase: SwitchoverPhase::Preparation,
});
self.advance_phase(SwitchoverPhase::Synchronization)?;
let _ = self.event_tx.send(SwitchoverEvent::PhaseChanged {
switchover_id,
phase: SwitchoverPhase::Synchronization,
});
// From here on `begin_transaction` rejects new transactions.
self.writes_blocked.store(true, std::sync::atomic::Ordering::SeqCst);
tracing::info!("Writes blocked for switchover synchronization");
self.role_manager.change_role(NodeRole::Draining, RoleChangeReason::Switchover)?;
// Wait for transactions that started before the block to drop their guards.
let drain_deadline = Instant::now() + self.config.drain_timeout;
while self.in_flight_transactions.load(std::sync::atomic::Ordering::SeqCst) > 0 {
if Instant::now() > drain_deadline {
return Err(Error::ha("Drain timeout: in-flight transactions not completing"));
}
if Instant::now() > timeout {
return Err(Error::ha("Switchover timeout during drain phase"));
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
tracing::info!("Primary drained, no in-flight transactions");
// LSN read after the drain; this is what the target has to apply.
let final_lsn = self.ha_registry.get_lsn();
self.role_manager.set_switchover_target_lsn(final_lsn)?;
let sync_deadline = Instant::now() + self.config.sync_timeout;
// Poll the registry until the target has applied `final_lsn`. A target that
// is missing from the standby list is treated the same as one still behind.
loop {
let standby_info = self.ha_registry.get_standbys()
.into_iter()
.find(|s| s.node_id == target_node);
if let Some(info) = standby_info {
if info.apply_lsn >= final_lsn {
tracing::info!(
"Target standby {} caught up to LSN {}",
target_node,
final_lsn
);
break;
}
}
if Instant::now() > sync_deadline {
return Err(Error::ha("Sync timeout: target standby not caught up"));
}
if Instant::now() > timeout {
return Err(Error::ha("Switchover timeout during sync phase"));
}
tokio::time::sleep(self.config.health_check_interval).await;
}
self.advance_phase(SwitchoverPhase::RoleChange)?;
let _ = self.event_tx.send(SwitchoverEvent::PhaseChanged {
switchover_id,
phase: SwitchoverPhase::RoleChange,
});
// Demote this node before the target is told to promote.
self.role_manager.change_role(
NodeRole::TransitioningToStandby,
RoleChangeReason::Switchover,
)?;
self.role_manager.demote_to_standby(RoleChangeReason::Switchover)?;
tracing::info!("Signaling target {} to promote", target_node);
self.advance_phase(SwitchoverPhase::Reconfiguration)?;
let _ = self.event_tx.send(SwitchoverEvent::PhaseChanged {
switchover_id,
phase: SwitchoverPhase::Reconfiguration,
});
// Falls back to "<node id>:5433" when no address was registered for the target.
let target_addr = self.node_addresses.read().get(&target_node).cloned()
.unwrap_or_else(|| format!("{}:5433", target_node));
// This broadcast is the promotion signal for the target.
let _ = self.event_tx.send(SwitchoverEvent::PrepareNewPrimary {
switchover_id,
new_primary: target_node,
new_primary_addr: target_addr,
});
self.role_manager.set_current_primary(Some(target_node));
self.advance_phase(SwitchoverPhase::Resumption)?;
let _ = self.event_tx.send(SwitchoverEvent::PhaseChanged {
switchover_id,
phase: SwitchoverPhase::Resumption,
});
// NOTE(review): writes are unblocked although this node has just been
// demoted — presumably writes on a standby are rejected elsewhere; confirm.
self.writes_blocked.store(false, std::sync::atomic::Ordering::SeqCst);
// No `PhaseChanged` event is sent for the final phase.
self.advance_phase(SwitchoverPhase::Completed)?;
Ok(())
}
/// Cancels the current switchover, rolls this node back and broadcasts
/// `Cancelled`.
///
/// # Errors
///
/// Returns the role manager's error when the cancellation is refused; in
/// that case nothing is rolled back and no event is sent.
async fn handle_cancel(&self) -> Result<()> {
    // Read the id before cancelling: if cancelling clears the switchover
    // state, looking it up afterwards finds nothing and `Cancelled` would
    // never be announced.
    let cancelled_id = self
        .role_manager
        .switchover_state()
        .map(|state| state.switchover_id);

    self.role_manager.cancel_switchover()?;
    self.rollback_switchover().await;

    if let Some(switchover_id) = cancelled_id {
        // Best-effort: having no subscribers is not an error.
        let _ = self.event_tx.send(SwitchoverEvent::Cancelled { switchover_id });
    }
    Ok(())
}
/// Undoes the local effects of an aborted switchover.
///
/// Unblocks writes and, when this node is still in one of the intermediate
/// roles (`Draining`, `TransitioningToStandby`), tries to return it to
/// `Primary`. The node records itself as the current primary only when it
/// actually holds the `Primary` role afterwards.
async fn rollback_switchover(&self) {
    tracing::warn!("Rolling back switchover");
    self.writes_blocked.store(false, std::sync::atomic::Ordering::SeqCst);

    let current_role = self.role_manager.role();
    if matches!(
        current_role,
        NodeRole::Draining | NodeRole::TransitioningToStandby
    ) {
        if let Err(e) = self.role_manager.change_role(NodeRole::Primary, RoleChangeReason::Switchover) {
            tracing::error!("Failed to rollback to primary: {}", e);
        }
    }

    // Do not advertise this node as primary unless it really is one. If the
    // role could not be restored, or the node had already been demoted, an
    // unconditional update would leave the primary pointer contradicting the
    // node's own role (and possibly overwrite a target that was already told
    // to promote).
    if matches!(self.role_manager.role(), NodeRole::Primary) {
        self.role_manager.set_current_primary(Some(self.node_id));
    } else {
        tracing::warn!(
            "Node {} is not primary after rollback; current primary left unchanged",
            self.node_id
        );
    }
}
/// Evaluates whether a switchover to `target_node` may start right now.
///
/// Blockers are recorded when this node is not primary, when a switchover is
/// already in progress, when the target is not among the registry's
/// standbys, or when too few standbys are close to the primary's LSN and
/// partial sync is not allowed. A target that is merely behind only produces
/// a warning, since it is synchronized during the switchover itself.
///
/// Always returns `Ok`; the verdict is carried in `can_proceed`.
async fn handle_check(&self, target_node: Uuid) -> Result<SwitchoverCheck> {
    /// Largest LSN distance from the primary at which a standby still counts as synced.
    const SYNCED_LSN_TOLERANCE: u64 = 100;

    let primary_lsn = self.ha_registry.get_lsn();
    let mut check = SwitchoverCheck {
        can_proceed: true,
        target_healthy: false,
        target_lsn: 0,
        primary_lsn,
        lag_bytes: 0,
        synced_standbys: vec![],
        warnings: vec![],
        blockers: vec![],
    };

    if !self.role_manager.is_primary() {
        check.can_proceed = false;
        check.blockers.push("This node is not the primary".to_string());
    }
    if self.role_manager.is_switchover_in_progress() {
        check.can_proceed = false;
        check.blockers.push("Switchover already in progress".to_string());
    }

    // One snapshot serves both the target lookup and the synced count, so the
    // two results cannot disagree about the state of the same standby.
    let standbys: Vec<_> = self.ha_registry.get_standbys().into_iter().collect();

    match standbys.iter().find(|s| s.node_id == target_node) {
        Some(info) => {
            check.target_healthy = true;
            check.target_lsn = info.apply_lsn;
            check.lag_bytes = primary_lsn.saturating_sub(info.apply_lsn);
            if check.lag_bytes > 0 {
                check.warnings.push(format!(
                    "Target standby is {} LSN behind (will sync during switchover)",
                    check.lag_bytes
                ));
            }
        }
        None => {
            check.can_proceed = false;
            check.blockers.push(format!(
                "Target node {} not found or not healthy",
                target_node
            ));
        }
    }

    let synced_floor = primary_lsn.saturating_sub(SYNCED_LSN_TOLERANCE);
    check.synced_standbys.extend(
        standbys
            .iter()
            .filter(|s| s.apply_lsn >= synced_floor)
            .map(|s| s.node_id),
    );

    if check.synced_standbys.len() < self.config.min_synced_standbys {
        if self.config.allow_partial_sync {
            check.warnings.push(format!(
                "Only {} standbys synced (minimum: {})",
                check.synced_standbys.len(),
                self.config.min_synced_standbys
            ));
        } else {
            check.can_proceed = false;
            check.blockers.push(format!(
                "Insufficient synced standbys: {} (need {})",
                check.synced_standbys.len(),
                self.config.min_synced_standbys
            ));
        }
    }
    Ok(check)
}
/// Handles a progress report from a standby.
///
/// Only logs: a message is emitted when a switchover is active, the report
/// comes from its target node, a target LSN has been recorded and the
/// reported `lsn` has reached it.
async fn handle_standby_progress(&self, node_id: Uuid, lsn: u64) {
    let reached = self
        .role_manager
        .switchover_state()
        .filter(|state| node_id == state.target_node)
        .and_then(|state| state.target_lsn)
        .filter(|&goal| lsn >= goal);

    if let Some(target_lsn) = reached {
        tracing::info!(
            "Target standby {} reached target LSN {}",
            node_id,
            target_lsn
        );
    }
}
/// Handles a readiness report from a standby; currently this is only logged.
async fn handle_standby_ready(&self, node_id: Uuid) {
    let standby = node_id;
    tracing::info!("Standby {} reports ready", standby);
}
/// Moves the active switchover to `phase` by delegating to the role manager.
fn advance_phase(&self, phase: SwitchoverPhase) -> Result<()> {
    let role_manager = &self.role_manager;
    role_manager.advance_switchover_phase(phase)
}
}
/// Marks one in-flight transaction for as long as it is alive.
///
/// The shared counter is decremented when the guard is dropped, so the guard
/// has to be kept for the whole duration of the transaction.
#[must_use = "dropping the guard immediately un-registers the transaction"]
pub struct TransactionGuard {
    /// Shared count of in-flight transactions; incremented by whoever creates the guard.
    counter: Arc<std::sync::atomic::AtomicU64>,
}

impl Drop for TransactionGuard {
    /// Un-registers the transaction from the shared counter.
    fn drop(&mut self) {
        self.counter.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
    }
}
/// Tells the consumer of the reconnect channel which primary to follow next.
#[derive(Debug, Clone)]
pub struct ReconnectTarget {
/// Id of the node that became primary.
pub node_id: Uuid,
/// Address announced for that node in `PrepareNewPrimary`.
pub address: String,
}
/// Standby-side listener that reacts to `SwitchoverEvent`s.
pub struct StandbySwitchoverHandler {
/// Id of the node this handler runs on; compared against event targets.
node_id: Uuid,
/// Used to change this node's role and to record the current primary.
role_manager: Arc<RoleManager>,
/// NOTE(review): not read by any method visible in this file.
ha_registry: Arc<HAStateRegistry>,
/// Subscription to the switchover event broadcast channel.
event_rx: broadcast::Receiver<SwitchoverEvent>,
/// Optional channel used to request a reconnect to a new primary.
reconnect_tx: Option<mpsc::Sender<ReconnectTarget>>,
}
impl StandbySwitchoverHandler {
/// Creates a handler without a reconnect channel; a change of primary is
/// then only recorded with the role manager.
pub fn new(
    node_id: Uuid,
    role_manager: Arc<RoleManager>,
    ha_registry: Arc<HAStateRegistry>,
    event_rx: broadcast::Receiver<SwitchoverEvent>,
) -> Self {
    let reconnect_tx = None;
    Self {
        node_id,
        role_manager,
        ha_registry,
        event_rx,
        reconnect_tx,
    }
}
pub fn with_reconnect_channel(
node_id: Uuid,
role_manager: Arc<RoleManager>,
ha_registry: Arc<HAStateRegistry>,
event_rx: broadcast::Receiver<SwitchoverEvent>,
reconnect_tx: mpsc::Sender<ReconnectTarget>,
) -> Self {
Self {
node_id,
role_manager,
ha_registry,
event_rx,
reconnect_tx: Some(reconnect_tx),
}
}
pub async fn run(mut self) {
tracing::info!("Standby switchover handler started on node {}", self.node_id);
loop {
match self.event_rx.recv().await {
Ok(event) => self.handle_event(event).await,
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Standby switchover handler lagged {} events", n);
}
Err(broadcast::error::RecvError::Closed) => {
tracing::info!("Switchover event channel closed");
break;
}
}
}
}
async fn handle_event(&self, event: SwitchoverEvent) {
match event {
SwitchoverEvent::Started { switchover_id, source, target } => {
tracing::info!(
"Switchover {} started: {} -> {}",
switchover_id,
source,
target
);
if target == self.node_id {
tracing::info!("This node is the switchover target - preparing for promotion");
if let Err(e) = self.role_manager.change_role(
NodeRole::CatchingUp,
RoleChangeReason::Switchover,
) {
tracing::error!("Failed to enter catching up state: {}", e);
}
}
}
SwitchoverEvent::PrepareNewPrimary { switchover_id, new_primary, new_primary_addr } => {
tracing::info!(
"Switchover {}: new primary is {} at {}",
switchover_id,
new_primary,
new_primary_addr
);
if new_primary == self.node_id {
if let Err(e) = self.role_manager.promote_to_primary(RoleChangeReason::Switchover) {
tracing::error!("Failed to promote to primary: {}", e);
} else {
tracing::info!("Successfully promoted to primary");
}
} else {
self.role_manager.set_current_primary(Some(new_primary));
tracing::info!("Reconfigured to follow new primary {}", new_primary);
if let Some(ref tx) = self.reconnect_tx {
let target = ReconnectTarget {
node_id: new_primary,
address: new_primary_addr.clone(),
};
if let Err(e) = tx.try_send(target) {
tracing::error!("Failed to signal WAL replicator reconnection: {}", e);
} else {
tracing::info!("Signaled WAL replicator to reconnect to {}", new_primary_addr);
}
}
}
}
SwitchoverEvent::Completed { switchover_id, new_primary, duration_ms } => {
tracing::info!(
"Switchover {} completed in {}ms, new primary: {}",
switchover_id,
duration_ms,
new_primary
);
}
SwitchoverEvent::Failed { switchover_id, error } => {
tracing::error!("Switchover {} failed: {}", switchover_id, error);
if self.role_manager.role().is_transitioning() {
let _ = self.role_manager.change_role(
NodeRole::Standby,
RoleChangeReason::Switchover,
);
}
}
SwitchoverEvent::Cancelled { switchover_id } => {
tracing::info!("Switchover {} cancelled", switchover_id);
if self.role_manager.role().is_transitioning() {
let _ = self.role_manager.change_role(
NodeRole::Standby,
RoleChangeReason::Switchover,
);
}
}
_ => {}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
// Placeholder: the body is empty, so this test currently asserts nothing.
#[tokio::test]
async fn test_switchover_check() {
}
}