use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;
use std::time::{Duration, Instant};
use tracing::{debug, info};
use crate::error::RaftResult;
use crate::heartbeat::FailureDetector;
use crate::types::{FailureEvent, HeartbeatConfig, NodeId};
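/// Failover-level events surfaced by [`FailoverCoordinator::tick`].
///
/// Note that a recovering leader is reported as `PeerRecovered`; leaders are
/// distinguished from other peers only on the failure path.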
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FailoverEvent {
LeaderLost {
old_leader: NodeId,
election_triggered: bool,
},
LeaderElected {
new_leader: NodeId,
},
FailoverTimeout,
PeerFailed {
node_id: NodeId,
},
PeerRecovered {
node_id: NodeId,
},
}
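/// Failover tuning knobs: the randomized election jitter window
/// `[election_jitter_min_ms, election_jitter_max_ms)` and the number of
/// consecutive leader failures tolerated before an election is scheduled.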
#[derive(Debug, Clone)]
pub struct FailoverConfig {
pub election_jitter_min_ms: u64,
pub election_jitter_max_ms: u64,
pub max_consecutive_failures: u32,
}
impl FailoverConfig {
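    /// Creates a config without checking invariants; call
    /// [`Self::validate`] before use.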
pub fn new(
election_jitter_min_ms: u64,
election_jitter_max_ms: u64,
max_consecutive_failures: u32,
) -> Self {
Self {
election_jitter_min_ms,
election_jitter_max_ms,
max_consecutive_failures,
}
}
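    /// Checks the invariants relied on by [`Self::random_jitter`]: non-zero
    /// jitter bounds with `max > min`, and a non-zero failure threshold.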
pub fn validate(&self) -> Result<(), String> {
if self.election_jitter_min_ms == 0 {
return Err("election_jitter_min_ms must be > 0".to_string());
}
if self.election_jitter_max_ms <= self.election_jitter_min_ms {
return Err(format!(
"election_jitter_max_ms ({}) must be > election_jitter_min_ms ({})",
self.election_jitter_max_ms, self.election_jitter_min_ms,
));
}
if self.max_consecutive_failures == 0 {
return Err("max_consecutive_failures must be > 0".to_string());
}
Ok(())
}
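    /// Draws a pseudo-random jitter in `[min, max)` milliseconds.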
    fn random_jitter(&self) -> Duration {
        // Guard against unvalidated configs: if max <= min, the modulo below
        // would panic (division by zero) or the subtraction would overflow in
        // debug builds, so clamp the range to at least 1ms.
        let range = self
            .election_jitter_max_ms
            .saturating_sub(self.election_jitter_min_ms)
            .max(1);
        // Hash the current wall-clock nanoseconds with a freshly seeded
        // RandomState to get a cheap pseudo-random u64 without an RNG crate.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let random_value = RandomState::new().hash_one(now);
        let jitter_ms = self.election_jitter_min_ms + (random_value % range);
        Duration::from_millis(jitter_ms)
    }
}
impl Default for FailoverConfig {
fn default() -> Self {
Self {
election_jitter_min_ms: 150,
election_jitter_max_ms: 300,
max_consecutive_failures: 3,
}
}
}
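/// State machine for the randomized election timer: `Idle` until the leader
/// is declared lost, `Pending` while the jitter delay elapses, and `Fired`
/// once a [`FailoverEvent::FailoverTimeout`] has been emitted.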
#[derive(Debug)]
enum ElectionTimer {
Idle,
Pending {
started_at: Instant,
jitter: Duration,
},
Fired {
fired_at: Instant,
},
}
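/// Coordinates failover on top of a [`FailureDetector`]: tracks the current
/// leader hint, counts consecutive leader failures, and arms a jittered
/// election timer once `max_consecutive_failures` is reached.
///
/// Illustrative driver sketch (marked `ignore` because the crate paths and
/// call cadence are assumptions, not part of this module):
///
/// ```ignore
/// let mut coord = FailoverCoordinator::new(
///     HeartbeatConfig::default(),
///     FailoverConfig::default(),
///     1, // this node's id
/// );
/// coord.track_peer(2)?;
/// coord.set_leader(2);
/// // Feed heartbeats as they arrive, then poll tick() periodically:
/// coord.record_heartbeat(2)?;
/// for event in coord.tick()? {
///     if let FailoverEvent::FailoverTimeout = event {
///         // start a Raft election here
///     }
/// }
/// ```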
pub struct FailoverCoordinator {
detector: FailureDetector,
config: FailoverConfig,
self_id: NodeId,
current_leader: Option<NodeId>,
election_timer: ElectionTimer,
leader_failure_count: u32,
}
impl FailoverCoordinator {
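    /// Builds a coordinator around a fresh [`FailureDetector`] for
    /// `self_id`, with no known leader and an idle election timer.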
pub fn new(
heartbeat_config: HeartbeatConfig,
failover_config: FailoverConfig,
self_id: NodeId,
) -> Self {
Self {
detector: FailureDetector::new(heartbeat_config, self_id),
config: failover_config,
self_id,
current_leader: None,
election_timer: ElectionTimer::Idle,
leader_failure_count: 0,
}
}
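    /// Registers `peer_id` with the underlying failure detector.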
pub fn track_peer(&mut self, peer_id: NodeId) -> RaftResult<()> {
self.detector.track_peer(peer_id)
}
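    /// Stops tracking `peer_id`. If it was the current leader the hint is
    /// cleared, but the failure count and any pending election timer are
    /// left running, so an in-flight election can still proceed.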
pub fn remove_peer(&mut self, peer_id: NodeId) {
self.detector.remove_peer(peer_id);
if self.current_leader == Some(peer_id) {
self.current_leader = None;
}
}
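    /// Records a heartbeat from `peer_id`, refreshing its liveness deadline.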
pub fn record_heartbeat(&mut self, peer_id: NodeId) -> RaftResult<()> {
self.detector.record_heartbeat(peer_id)
}
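    /// Updates the leader hint. A change of leader resets the failure count
    /// and cancels any pending election timer.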
pub fn set_leader(&mut self, leader_id: NodeId) {
let changed = self.current_leader != Some(leader_id);
self.current_leader = Some(leader_id);
if changed {
self.leader_failure_count = 0;
self.election_timer = ElectionTimer::Idle;
debug!(
self_id = self.self_id,
leader_id = leader_id,
"FailoverCoordinator: leader updated"
);
}
}
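    /// Drops the leader hint and resets the failure count and election timer.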
pub fn clear_leader(&mut self) {
self.current_leader = None;
self.leader_failure_count = 0;
self.election_timer = ElectionTimer::Idle;
}
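    /// Returns the most recently observed leader, if any.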
pub fn leader_hint(&self) -> Option<NodeId> {
self.current_leader
}
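    /// Polls the failure detector and the election timer once, translating
    /// low-level [`FailureEvent`]s into [`FailoverEvent`]s. Leader failures
    /// count toward the election threshold; leader recovery cancels any
    /// pending election. Meant to be called periodically, e.g. (a sketch;
    /// the 10ms cadence is illustrative, not prescribed by this module):
    ///
    /// ```ignore
    /// loop {
    ///     for event in coord.tick()? {
    ///         // dispatch the event to the consensus layer
    ///     }
    ///     std::thread::sleep(std::time::Duration::from_millis(10));
    /// }
    /// ```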
pub fn tick(&mut self) -> RaftResult<Vec<FailoverEvent>> {
let failure_events = self.detector.check_timeouts()?;
let mut out = Vec::new();
for fe in &failure_events {
match fe {
FailureEvent::NodeFailed { node_id, .. } => {
if Some(*node_id) == self.current_leader {
self.leader_failure_count = self.leader_failure_count.saturating_add(1);
let should_trigger =
self.leader_failure_count >= self.config.max_consecutive_failures;
if should_trigger {
self.schedule_election();
}
info!(
self_id = self.self_id,
leader = node_id,
failure_count = self.leader_failure_count,
triggered = should_trigger,
"Leader failure detected"
);
out.push(FailoverEvent::LeaderLost {
old_leader: *node_id,
election_triggered: should_trigger,
});
} else {
out.push(FailoverEvent::PeerFailed { node_id: *node_id });
}
}
FailureEvent::NodeRecovered { node_id } => {
if Some(*node_id) == self.current_leader {
self.leader_failure_count = 0;
self.election_timer = ElectionTimer::Idle;
debug!(
self_id = self.self_id,
leader = node_id,
"Leader recovered, election timer cancelled"
);
}
out.push(FailoverEvent::PeerRecovered { node_id: *node_id });
}
}
}
match &self.election_timer {
ElectionTimer::Pending { started_at, jitter } => {
if started_at.elapsed() >= *jitter {
info!(
self_id = self.self_id,
jitter_ms = jitter.as_millis() as u64,
"Election jitter expired, triggering failover"
);
self.election_timer = ElectionTimer::Fired {
fired_at: Instant::now(),
};
out.push(FailoverEvent::FailoverTimeout);
}
}
ElectionTimer::Fired { .. } | ElectionTimer::Idle => {}
}
Ok(out)
}
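    /// Clears detector state, the failure count, and the election timer.
    /// Note that the leader hint is preserved.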
pub fn reset(&mut self) {
self.detector.reset_all();
self.leader_failure_count = 0;
self.election_timer = ElectionTimer::Idle;
}
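    /// Peers the detector currently considers failed.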
pub fn failed_peers(&self) -> Vec<NodeId> {
self.detector.failed_peers()
}
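    /// Peers the detector currently considers alive.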
pub fn alive_peers(&self) -> Vec<NodeId> {
self.detector.alive_peers()
}
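    /// Number of peers currently being tracked.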
pub fn peer_count(&self) -> usize {
self.detector.peer_count()
}
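    /// True while the jittered election timer is armed but has not fired.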
pub fn is_election_pending(&self) -> bool {
matches!(self.election_timer, ElectionTimer::Pending { .. })
}
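    /// True once the election timer has fired and not yet been reset.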
pub fn is_election_fired(&self) -> bool {
matches!(self.election_timer, ElectionTimer::Fired { .. })
}
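    /// Arms the election timer with fresh jitter; a no-op if a timer is
    /// already pending or has fired.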
fn schedule_election(&mut self) {
if matches!(
self.election_timer,
ElectionTimer::Pending { .. } | ElectionTimer::Fired { .. }
) {
return;
}
let jitter = self.config.random_jitter();
debug!(
self_id = self.self_id,
jitter_ms = jitter.as_millis() as u64,
"Scheduling election with jitter"
);
self.election_timer = ElectionTimer::Pending {
started_at: Instant::now(),
jitter,
};
}
}
impl std::fmt::Debug for FailoverCoordinator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FailoverCoordinator")
.field("self_id", &self.self_id)
.field("current_leader", &self.current_leader)
.field("leader_failure_count", &self.leader_failure_count)
.field("peer_count", &self.detector.peer_count())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
fn fast_heartbeat_config() -> HeartbeatConfig {
HeartbeatConfig::new(10, 30, 1)
}
fn fast_failover_config() -> FailoverConfig {
FailoverConfig {
election_jitter_min_ms: 10,
election_jitter_max_ms: 30,
max_consecutive_failures: 1,
}
}
#[test]
fn test_failover_config_default() {
let cfg = FailoverConfig::default();
assert_eq!(cfg.election_jitter_min_ms, 150);
assert_eq!(cfg.election_jitter_max_ms, 300);
assert_eq!(cfg.max_consecutive_failures, 3);
assert!(cfg.validate().is_ok());
}
#[test]
fn test_failover_config_validation() {
let bad1 = FailoverConfig::new(0, 300, 3);
assert!(bad1.validate().is_err());
let bad2 = FailoverConfig::new(300, 150, 3);
assert!(bad2.validate().is_err());
let bad3 = FailoverConfig::new(150, 300, 0);
assert!(bad3.validate().is_err());
let bad4 = FailoverConfig::new(150, 150, 3);
assert!(bad4.validate().is_err());
}
#[test]
fn test_failover_config_jitter_in_range() {
let cfg = FailoverConfig::new(100, 200, 3);
for _ in 0..20 {
let jitter = cfg.random_jitter();
assert!(jitter.as_millis() >= 100, "jitter too low: {:?}", jitter);
assert!(jitter.as_millis() < 200, "jitter too high: {:?}", jitter);
}
}
#[test]
fn test_coordinator_creation() {
let coord =
FailoverCoordinator::new(HeartbeatConfig::default(), FailoverConfig::default(), 1);
assert_eq!(coord.leader_hint(), None);
assert_eq!(coord.peer_count(), 0);
assert!(!coord.is_election_pending());
}
#[test]
fn test_leader_hint_tracking() {
let mut coord =
FailoverCoordinator::new(HeartbeatConfig::default(), FailoverConfig::default(), 1);
assert_eq!(coord.leader_hint(), None);
coord.set_leader(2);
assert_eq!(coord.leader_hint(), Some(2));
coord.set_leader(3);
assert_eq!(coord.leader_hint(), Some(3));
coord.clear_leader();
assert_eq!(coord.leader_hint(), None);
}
#[test]
fn test_leader_failure_triggers_election() {
let mut coord =
FailoverCoordinator::new(fast_heartbeat_config(), fast_failover_config(), 1);
coord.track_peer(2).expect("track peer 2");
coord.track_peer(3).expect("track peer 3");
coord.set_leader(2);
thread::sleep(Duration::from_millis(50));
let events = coord.tick().expect("tick");
let leader_lost = events.iter().any(|e| {
matches!(
e,
FailoverEvent::LeaderLost {
old_leader: 2,
election_triggered: true,
}
)
});
assert!(leader_lost, "Expected LeaderLost event, got: {:?}", events);
assert!(coord.is_election_pending());
}
#[test]
fn test_election_timer_fires_after_jitter() {
let mut coord =
FailoverCoordinator::new(fast_heartbeat_config(), fast_failover_config(), 1);
coord.track_peer(2).expect("track peer 2");
coord.set_leader(2);
thread::sleep(Duration::from_millis(50));
let _ = coord.tick().expect("tick 1");
thread::sleep(Duration::from_millis(50));
let events = coord.tick().expect("tick 2");
let timeout_fired = events
.iter()
.any(|e| matches!(e, FailoverEvent::FailoverTimeout));
assert!(
timeout_fired,
"Expected FailoverTimeout event, got: {:?}",
events
);
assert!(coord.is_election_fired());
}
#[test]
fn test_leader_recovery_cancels_election() {
let mut coord =
FailoverCoordinator::new(fast_heartbeat_config(), fast_failover_config(), 1);
coord.track_peer(2).expect("track peer 2");
coord.set_leader(2);
thread::sleep(Duration::from_millis(50));
let _ = coord.tick().expect("tick");
assert!(coord.is_election_pending());
coord.record_heartbeat(2).expect("record heartbeat");
let events = coord.tick().expect("tick after recovery");
let recovered = events
.iter()
.any(|e| matches!(e, FailoverEvent::PeerRecovered { node_id: 2 }));
assert!(recovered, "Expected PeerRecovered, got: {:?}", events);
assert!(!coord.is_election_pending());
assert!(!coord.is_election_fired());
}
#[test]
fn test_non_leader_failure_emits_peer_failed() {
let mut coord =
FailoverCoordinator::new(fast_heartbeat_config(), fast_failover_config(), 1);
coord.track_peer(2).expect("track peer 2");
coord.track_peer(3).expect("track peer 3");
coord.set_leader(2);
thread::sleep(Duration::from_millis(50));
coord.record_heartbeat(2).expect("leader heartbeat refresh");
let events = coord.tick().expect("tick");
let peer_failed = events
.iter()
.any(|e| matches!(e, FailoverEvent::PeerFailed { node_id: 3 }));
assert!(peer_failed, "Expected PeerFailed for 3, got: {:?}", events);
assert!(
!coord.is_election_pending(),
"Non-leader failure should not trigger election"
);
}
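    // Note: this test only asserts that both nodes arm *pending* timers;
    // distinct jitter draws make simultaneous firing unlikely, not impossible.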
#[test]
fn test_jitter_prevents_simultaneous_elections() {
let hb = fast_heartbeat_config();
let fo = FailoverConfig {
election_jitter_min_ms: 50,
election_jitter_max_ms: 200,
max_consecutive_failures: 1,
};
let mut c1 = FailoverCoordinator::new(hb.clone(), fo.clone(), 1);
let mut c2 = FailoverCoordinator::new(hb.clone(), fo.clone(), 3);
c1.track_peer(2).expect("c1 track 2");
c1.track_peer(3).expect("c1 track 3");
c1.set_leader(2);
c2.track_peer(1).expect("c2 track 1");
c2.track_peer(2).expect("c2 track 2");
c2.set_leader(2);
thread::sleep(Duration::from_millis(50));
let _ = c1.tick().expect("c1 tick");
let _ = c2.tick().expect("c2 tick");
assert!(c1.is_election_pending());
assert!(c2.is_election_pending());
}
#[test]
fn test_max_consecutive_failures_threshold() {
let mut coord = FailoverCoordinator::new(
fast_heartbeat_config(),
FailoverConfig {
election_jitter_min_ms: 10,
election_jitter_max_ms: 30,
max_consecutive_failures: 3,
},
1,
);
coord.track_peer(2).expect("track peer 2");
coord.set_leader(2);
thread::sleep(Duration::from_millis(50));
let events = coord.tick().expect("tick 1");
let triggered = events.iter().any(|e| {
matches!(
e,
FailoverEvent::LeaderLost {
election_triggered: true,
..
}
)
});
assert!(
!triggered,
"Should not trigger election after 1 failure, got: {:?}",
events
);
assert!(!coord.is_election_pending());
}
#[test]
fn test_set_new_leader_resets_state() {
let mut coord =
FailoverCoordinator::new(fast_heartbeat_config(), fast_failover_config(), 1);
coord.track_peer(2).expect("track peer 2");
coord.track_peer(3).expect("track peer 3");
coord.set_leader(2);
thread::sleep(Duration::from_millis(50));
let _ = coord.tick().expect("tick");
assert!(coord.is_election_pending());
coord.set_leader(3);
assert!(!coord.is_election_pending());
assert!(!coord.is_election_fired());
assert_eq!(coord.leader_hint(), Some(3));
}
#[test]
fn test_reset_clears_all() {
let mut coord =
FailoverCoordinator::new(fast_heartbeat_config(), fast_failover_config(), 1);
coord.track_peer(2).expect("track peer 2");
coord.set_leader(2);
thread::sleep(Duration::from_millis(50));
let _ = coord.tick().expect("tick");
coord.reset();
assert!(!coord.is_election_pending());
assert!(!coord.is_election_fired());
assert!(coord.failed_peers().is_empty());
}
#[test]
fn test_remove_leader_peer_clears_leader() {
let mut coord =
FailoverCoordinator::new(HeartbeatConfig::default(), FailoverConfig::default(), 1);
coord.track_peer(2).expect("track peer 2");
coord.set_leader(2);
assert_eq!(coord.leader_hint(), Some(2));
coord.remove_peer(2);
assert_eq!(coord.leader_hint(), None);
}
#[test]
fn test_debug_impl() {
let coord =
FailoverCoordinator::new(HeartbeatConfig::default(), FailoverConfig::default(), 1);
let dbg = format!("{:?}", coord);
assert!(dbg.contains("FailoverCoordinator"));
assert!(dbg.contains("self_id"));
}
}