//! Failover coordination: tracks the current leader on top of heartbeat-based
//! failure detection and schedules a jittered election timeout when the
//! leader is considered lost.

use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;
use std::time::{Duration, Instant};

use tracing::{debug, info};

use crate::error::RaftResult;
use crate::heartbeat::FailureDetector;
use crate::types::{FailureEvent, HeartbeatConfig, NodeId};

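/// Failover-related events reported to the caller, e.g. from
/// [`FailoverCoordinator::tick`].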
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FailoverEvent {
    /// The current leader has been declared failed.
    LeaderLost {
        /// The node that was leader when the failure was detected.
        old_leader: NodeId,
        /// Whether this failure crossed `max_consecutive_failures` and
        /// scheduled an election.
        election_triggered: bool,
    },
    /// A new leader has been elected.
    LeaderElected {
        /// The newly elected leader.
        new_leader: NodeId,
    },
    /// The jittered election timer expired; the caller should start an election.
    FailoverTimeout,
    /// A non-leader peer stopped responding to heartbeats.
    PeerFailed {
        /// The failed peer.
        node_id: NodeId,
    },
    /// A previously failed peer resumed sending heartbeats.
    PeerRecovered {
        /// The recovered peer.
        node_id: NodeId,
    },
}

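/// Tuning knobs for failover behaviour.
///
/// A minimal construction sketch (the values here mirror the defaults and are
/// illustrative, not a recommendation):
///
/// ```ignore
/// let cfg = FailoverConfig::new(150, 300, 3);
/// assert!(cfg.validate().is_ok());
/// ```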
#[derive(Debug, Clone)]
pub struct FailoverConfig {
    /// Lower bound (inclusive) of the random election jitter, in milliseconds.
    pub election_jitter_min_ms: u64,
    /// Upper bound (exclusive) of the random election jitter, in milliseconds.
    pub election_jitter_max_ms: u64,
    /// Number of consecutive leader heartbeat failures required before an
    /// election is scheduled.
    pub max_consecutive_failures: u32,
}

impl FailoverConfig {
    /// Builds a config from raw values; call `validate` to check them.
    pub fn new(
        election_jitter_min_ms: u64,
        election_jitter_max_ms: u64,
        max_consecutive_failures: u32,
    ) -> Self {
        Self {
            election_jitter_min_ms,
            election_jitter_max_ms,
            max_consecutive_failures,
        }
    }

    /// Checks that the jitter window is non-empty and the failure threshold
    /// is non-zero.
    pub fn validate(&self) -> Result<(), String> {
        if self.election_jitter_min_ms == 0 {
            return Err("election_jitter_min_ms must be > 0".to_string());
        }
        if self.election_jitter_max_ms <= self.election_jitter_min_ms {
            return Err(format!(
                "election_jitter_max_ms ({}) must be > election_jitter_min_ms ({})",
                self.election_jitter_max_ms, self.election_jitter_min_ms,
            ));
        }
        if self.max_consecutive_failures == 0 {
            return Err("max_consecutive_failures must be > 0".to_string());
        }
        Ok(())
    }

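    /// Draws a jitter duration in `[election_jitter_min_ms, election_jitter_max_ms)`.
    ///
    /// Uses the hash of the current wall-clock time as a cheap source of
    /// randomness; this avoids an RNG dependency but is not suitable for
    /// anything security-sensitive.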
    fn random_jitter(&self) -> Duration {
        // Guard against a zero-width (or inverted) window so the modulo below
        // cannot panic, even if `validate` was never called.
        let range = self
            .election_jitter_max_ms
            .saturating_sub(self.election_jitter_min_ms)
            .max(1);
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let random_value = RandomState::new().hash_one(now);
        let jitter_ms = self.election_jitter_min_ms + (random_value % range);
        Duration::from_millis(jitter_ms)
    }
}

impl Default for FailoverConfig {
    fn default() -> Self {
        Self {
            election_jitter_min_ms: 150,
            election_jitter_max_ms: 300,
            max_consecutive_failures: 3,
        }
    }
}

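/// Internal state machine for the jittered election timeout.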
#[derive(Debug)]
enum ElectionTimer {
    /// No election is being considered.
    Idle,
    /// A leader loss crossed the failure threshold; waiting out the jitter.
    Pending {
        /// When the timer was started.
        started_at: Instant,
        /// Randomized delay before the timeout fires.
        jitter: Duration,
    },
    /// The timeout fired and a `FailoverTimeout` event was emitted.
    Fired {
        /// When the timeout fired.
        fired_at: Instant,
    },
}

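/// Coordinates leader-failure detection and election scheduling for a node.
///
/// Wraps a [`FailureDetector`]: callers feed in heartbeats and leadership
/// changes, then poll `tick()` from their driver loop to collect
/// [`FailoverEvent`]s.
///
/// A rough usage sketch (the driver loop and node ids are illustrative):
///
/// ```ignore
/// let mut coord = FailoverCoordinator::new(
///     HeartbeatConfig::default(),
///     FailoverConfig::default(),
///     1, // this node's id
/// );
/// coord.track_peer(2)?;
/// coord.set_leader(2);
/// for event in coord.tick()? {
///     // react to LeaderLost / FailoverTimeout / PeerFailed / ...
/// }
/// ```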
pub struct FailoverCoordinator {
    /// Underlying per-peer heartbeat failure detector.
    detector: FailureDetector,
    /// Failover tuning parameters.
    config: FailoverConfig,
    /// This node's id.
    self_id: NodeId,
    /// The leader this node currently believes in, if any.
    current_leader: Option<NodeId>,
    /// State of the jittered election timeout.
    election_timer: ElectionTimer,
    /// Consecutive failure detections for the current leader.
    leader_failure_count: u32,
}

impl FailoverCoordinator {
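    /// Creates a coordinator for `self_id` with no tracked peers and no
    /// known leader.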
    pub fn new(
        heartbeat_config: HeartbeatConfig,
        failover_config: FailoverConfig,
        self_id: NodeId,
    ) -> Self {
        Self {
            detector: FailureDetector::new(heartbeat_config, self_id),
            config: failover_config,
            self_id,
            current_leader: None,
            election_timer: ElectionTimer::Idle,
            leader_failure_count: 0,
        }
    }

    /// Starts monitoring heartbeats from `peer_id`.
    pub fn track_peer(&mut self, peer_id: NodeId) -> RaftResult<()> {
        self.detector.track_peer(peer_id)
    }

    /// Stops monitoring `peer_id`, clearing the leader hint if it pointed at
    /// the removed peer.
    pub fn remove_peer(&mut self, peer_id: NodeId) {
        self.detector.remove_peer(peer_id);
        if self.current_leader == Some(peer_id) {
            self.current_leader = None;
        }
    }

    /// Records a heartbeat from `peer_id`, refreshing its liveness deadline.
    pub fn record_heartbeat(&mut self, peer_id: NodeId) -> RaftResult<()> {
        self.detector.record_heartbeat(peer_id)
    }

    /// Updates the believed leader. A change of leader resets the failure
    /// count and cancels any pending or fired election timer.
    pub fn set_leader(&mut self, leader_id: NodeId) {
        let changed = self.current_leader != Some(leader_id);
        self.current_leader = Some(leader_id);
        if changed {
            self.leader_failure_count = 0;
            self.election_timer = ElectionTimer::Idle;
            debug!(
                self_id = self.self_id,
                leader_id = leader_id,
                "FailoverCoordinator: leader updated"
            );
        }
    }

    /// Forgets the current leader and cancels any election in progress.
    pub fn clear_leader(&mut self) {
        self.current_leader = None;
        self.leader_failure_count = 0;
        self.election_timer = ElectionTimer::Idle;
    }

    /// The leader this node currently believes in, if any.
    pub fn leader_hint(&self) -> Option<NodeId> {
        self.current_leader
    }

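    /// Advances failure detection by one step and returns any failover events.
    ///
    /// Runs in two phases: first, raw detector events are mapped onto
    /// failover semantics (a failed leader bumps the consecutive-failure
    /// count and may schedule an election; other peers yield `PeerFailed` /
    /// `PeerRecovered`); second, a pending election timer whose jitter has
    /// elapsed fires a single `FailoverTimeout`. Intended to be called
    /// periodically from the node's driver loop.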
    pub fn tick(&mut self) -> RaftResult<Vec<FailoverEvent>> {
        // Phase 1: translate raw detector events into failover events.
        let failure_events = self.detector.check_timeouts()?;
        let mut out = Vec::new();

        for fe in &failure_events {
            match fe {
                FailureEvent::NodeFailed { node_id, .. } => {
                    if Some(*node_id) == self.current_leader {
                        self.leader_failure_count = self.leader_failure_count.saturating_add(1);
                        let should_trigger =
                            self.leader_failure_count >= self.config.max_consecutive_failures;

                        if should_trigger {
                            self.schedule_election();
                        }

                        info!(
                            self_id = self.self_id,
                            leader = node_id,
                            failure_count = self.leader_failure_count,
                            triggered = should_trigger,
                            "Leader failure detected"
                        );

                        out.push(FailoverEvent::LeaderLost {
                            old_leader: *node_id,
                            election_triggered: should_trigger,
                        });
                    } else {
                        out.push(FailoverEvent::PeerFailed { node_id: *node_id });
                    }
                }
                FailureEvent::NodeRecovered { node_id } => {
                    if Some(*node_id) == self.current_leader {
                        self.leader_failure_count = 0;
                        self.election_timer = ElectionTimer::Idle;
                        debug!(
                            self_id = self.self_id,
                            leader = node_id,
                            "Leader recovered, election timer cancelled"
                        );
                    }
                    out.push(FailoverEvent::PeerRecovered { node_id: *node_id });
                }
            }
        }

        // Phase 2: fire the election timer once its jitter has elapsed.
        match &self.election_timer {
            ElectionTimer::Pending { started_at, jitter } => {
                if started_at.elapsed() >= *jitter {
                    info!(
                        self_id = self.self_id,
                        jitter_ms = jitter.as_millis() as u64,
                        "Election jitter expired, triggering failover"
                    );
                    self.election_timer = ElectionTimer::Fired {
                        fired_at: Instant::now(),
                    };
                    out.push(FailoverEvent::FailoverTimeout);
                }
            }
            ElectionTimer::Fired { .. } | ElectionTimer::Idle => {}
        }

        Ok(out)
    }

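    /// Resets the failure detector and cancels any election in progress.
    /// The leader hint is left unchanged.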
    pub fn reset(&mut self) {
        self.detector.reset_all();
        self.leader_failure_count = 0;
        self.election_timer = ElectionTimer::Idle;
    }

    /// Peers currently considered failed by the detector.
    pub fn failed_peers(&self) -> Vec<NodeId> {
        self.detector.failed_peers()
    }

    /// Peers currently considered alive by the detector.
    pub fn alive_peers(&self) -> Vec<NodeId> {
        self.detector.alive_peers()
    }

    /// Number of peers being tracked.
    pub fn peer_count(&self) -> usize {
        self.detector.peer_count()
    }

    /// True while an election is scheduled but its jitter has not yet elapsed.
    pub fn is_election_pending(&self) -> bool {
        matches!(self.election_timer, ElectionTimer::Pending { .. })
    }

    /// True once the election timer has fired and `FailoverTimeout` was emitted.
    pub fn is_election_fired(&self) -> bool {
        matches!(self.election_timer, ElectionTimer::Fired { .. })
    }

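    /// Arms the election timer with fresh random jitter. Idempotent while a
    /// timer is already pending or has fired, so repeated leader-failure
    /// detections do not restart the countdown.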
    fn schedule_election(&mut self) {
        if matches!(
            self.election_timer,
            ElectionTimer::Pending { .. } | ElectionTimer::Fired { .. }
        ) {
            return;
        }
        let jitter = self.config.random_jitter();
        debug!(
            self_id = self.self_id,
            jitter_ms = jitter.as_millis() as u64,
            "Scheduling election with jitter"
        );
        self.election_timer = ElectionTimer::Pending {
            started_at: Instant::now(),
            jitter,
        };
    }
}

impl std::fmt::Debug for FailoverCoordinator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FailoverCoordinator")
            .field("self_id", &self.self_id)
            .field("current_leader", &self.current_leader)
            .field("leader_failure_count", &self.leader_failure_count)
            .field("peer_count", &self.detector.peer_count())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

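    // Helper configs tuned so heartbeat timeouts and election jitter resolve
    // within tens of milliseconds, keeping the sleeps in these tests short.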
    fn fast_heartbeat_config() -> HeartbeatConfig {
        HeartbeatConfig::new(10, 30, 1)
    }

    fn fast_failover_config() -> FailoverConfig {
        FailoverConfig {
            election_jitter_min_ms: 10,
            election_jitter_max_ms: 30,
            max_consecutive_failures: 1,
        }
    }

    #[test]
    fn test_failover_config_default() {
        let cfg = FailoverConfig::default();
        assert_eq!(cfg.election_jitter_min_ms, 150);
        assert_eq!(cfg.election_jitter_max_ms, 300);
        assert_eq!(cfg.max_consecutive_failures, 3);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_failover_config_validation() {
        let bad1 = FailoverConfig::new(0, 300, 3);
        assert!(bad1.validate().is_err());

        let bad2 = FailoverConfig::new(300, 150, 3);
        assert!(bad2.validate().is_err());

        let bad3 = FailoverConfig::new(150, 300, 0);
        assert!(bad3.validate().is_err());

        let bad4 = FailoverConfig::new(150, 150, 3);
        assert!(bad4.validate().is_err());
    }

    #[test]
    fn test_failover_config_jitter_in_range() {
        let cfg = FailoverConfig::new(100, 200, 3);
        for _ in 0..20 {
            let jitter = cfg.random_jitter();
            assert!(jitter.as_millis() >= 100, "jitter too low: {:?}", jitter);
            assert!(jitter.as_millis() < 200, "jitter too high: {:?}", jitter);
        }
    }

    #[test]
    fn test_coordinator_creation() {
        let coord =
            FailoverCoordinator::new(HeartbeatConfig::default(), FailoverConfig::default(), 1);
        assert_eq!(coord.leader_hint(), None);
        assert_eq!(coord.peer_count(), 0);
        assert!(!coord.is_election_pending());
    }

    #[test]
    fn test_leader_hint_tracking() {
        let mut coord =
            FailoverCoordinator::new(HeartbeatConfig::default(), FailoverConfig::default(), 1);
        assert_eq!(coord.leader_hint(), None);

        coord.set_leader(2);
        assert_eq!(coord.leader_hint(), Some(2));

        coord.set_leader(3);
        assert_eq!(coord.leader_hint(), Some(3));

        coord.clear_leader();
        assert_eq!(coord.leader_hint(), None);
    }

    #[test]
    fn test_leader_failure_triggers_election() {
        let mut coord =
            FailoverCoordinator::new(fast_heartbeat_config(), fast_failover_config(), 1);
        coord.track_peer(2).expect("track peer 2");
        coord.track_peer(3).expect("track peer 3");
        coord.set_leader(2);

        // Sleep past the 30 ms heartbeat timeout so the leader is declared failed.
        thread::sleep(Duration::from_millis(50));

        let events = coord.tick().expect("tick");
        let leader_lost = events.iter().any(|e| {
            matches!(
                e,
                FailoverEvent::LeaderLost {
                    old_leader: 2,
                    election_triggered: true,
                }
            )
        });
        assert!(leader_lost, "Expected LeaderLost event, got: {:?}", events);
        assert!(coord.is_election_pending());
    }

    #[test]
    fn test_election_timer_fires_after_jitter() {
        let mut coord =
            FailoverCoordinator::new(fast_heartbeat_config(), fast_failover_config(), 1);
        coord.track_peer(2).expect("track peer 2");
        coord.set_leader(2);

        // First tick detects the failed leader and arms the election timer.
        thread::sleep(Duration::from_millis(50));
        let _ = coord.tick().expect("tick 1");

        // Second tick runs after the 10-30 ms jitter has certainly elapsed.
        thread::sleep(Duration::from_millis(50));
        let events = coord.tick().expect("tick 2");

        let timeout_fired = events
            .iter()
            .any(|e| matches!(e, FailoverEvent::FailoverTimeout));
        assert!(
            timeout_fired,
            "Expected FailoverTimeout event, got: {:?}",
            events
        );
        assert!(coord.is_election_fired());
    }

    #[test]
    fn test_leader_recovery_cancels_election() {
        let mut coord =
            FailoverCoordinator::new(fast_heartbeat_config(), fast_failover_config(), 1);
        coord.track_peer(2).expect("track peer 2");
        coord.set_leader(2);

        // Let the leader time out and arm the election timer.
        thread::sleep(Duration::from_millis(50));
        let _ = coord.tick().expect("tick");
        assert!(coord.is_election_pending());

        // A fresh heartbeat from the leader should cancel the pending election.
        coord.record_heartbeat(2).expect("record heartbeat");
        let events = coord.tick().expect("tick after recovery");

        let recovered = events
            .iter()
            .any(|e| matches!(e, FailoverEvent::PeerRecovered { node_id: 2 }));
        assert!(recovered, "Expected PeerRecovered, got: {:?}", events);

        assert!(!coord.is_election_pending());
        assert!(!coord.is_election_fired());
    }

    #[test]
    fn test_non_leader_failure_emits_peer_failed() {
        let mut coord =
            FailoverCoordinator::new(fast_heartbeat_config(), fast_failover_config(), 1);
        coord.track_peer(2).expect("track peer 2");
        coord.track_peer(3).expect("track peer 3");
        coord.set_leader(2);

        // Refresh the leader right before checking, so only peer 3 times out.
        thread::sleep(Duration::from_millis(50));
        coord.record_heartbeat(2).expect("leader heartbeat refresh");

        let events = coord.tick().expect("tick");
        let peer_failed = events
            .iter()
            .any(|e| matches!(e, FailoverEvent::PeerFailed { node_id: 3 }));
        assert!(peer_failed, "Expected PeerFailed for 3, got: {:?}", events);
        assert!(
            !coord.is_election_pending(),
            "Non-leader failure should not trigger election"
        );
    }

    #[test]
    fn test_jitter_prevents_simultaneous_elections() {
        // Two coordinators observing the same failed leader each draw an
        // independent jitter, making simultaneous candidacies unlikely. Here
        // we only verify that both schedule (rather than fire) an election.
        let hb = fast_heartbeat_config();
        let fo = FailoverConfig {
            election_jitter_min_ms: 50,
            election_jitter_max_ms: 200,
            max_consecutive_failures: 1,
        };

        let mut c1 = FailoverCoordinator::new(hb.clone(), fo.clone(), 1);
        let mut c2 = FailoverCoordinator::new(hb.clone(), fo.clone(), 3);

        c1.track_peer(2).expect("c1 track 2");
        c1.track_peer(3).expect("c1 track 3");
        c1.set_leader(2);

        c2.track_peer(1).expect("c2 track 1");
        c2.track_peer(2).expect("c2 track 2");
        c2.set_leader(2);

        // Both nodes see the leader time out...
        thread::sleep(Duration::from_millis(50));
        let _ = c1.tick().expect("c1 tick");
        let _ = c2.tick().expect("c2 tick");

        // ...and both end up waiting on their own 50-200 ms jitter.
        assert!(c1.is_election_pending());
        assert!(c2.is_election_pending());
    }

    #[test]
    fn test_max_consecutive_failures_threshold() {
        let mut coord = FailoverCoordinator::new(
            fast_heartbeat_config(),
            FailoverConfig {
                election_jitter_min_ms: 10,
                election_jitter_max_ms: 30,
                max_consecutive_failures: 3,
            },
            1,
        );
        coord.track_peer(2).expect("track peer 2");
        coord.set_leader(2);

        // A single detected failure is below the threshold of 3, so
        // LeaderLost is reported without triggering an election.
        thread::sleep(Duration::from_millis(50));
        let events = coord.tick().expect("tick 1");
        let triggered = events.iter().any(|e| {
            matches!(
                e,
                FailoverEvent::LeaderLost {
                    election_triggered: true,
                    ..
                }
            )
        });
        assert!(
            !triggered,
            "Should not trigger election after 1 failure, got: {:?}",
            events
        );

        assert!(!coord.is_election_pending());
    }

    #[test]
    fn test_set_new_leader_resets_state() {
        let mut coord =
            FailoverCoordinator::new(fast_heartbeat_config(), fast_failover_config(), 1);
        coord.track_peer(2).expect("track peer 2");
        coord.track_peer(3).expect("track peer 3");
        coord.set_leader(2);

        thread::sleep(Duration::from_millis(50));
        let _ = coord.tick().expect("tick");
        assert!(coord.is_election_pending());

        coord.set_leader(3);
        assert!(!coord.is_election_pending());
        assert!(!coord.is_election_fired());
        assert_eq!(coord.leader_hint(), Some(3));
    }

    #[test]
    fn test_reset_clears_all() {
        let mut coord =
            FailoverCoordinator::new(fast_heartbeat_config(), fast_failover_config(), 1);
        coord.track_peer(2).expect("track peer 2");
        coord.set_leader(2);

        thread::sleep(Duration::from_millis(50));
        let _ = coord.tick().expect("tick");

        coord.reset();
        assert!(!coord.is_election_pending());
        assert!(!coord.is_election_fired());
        assert!(coord.failed_peers().is_empty());
    }

    #[test]
    fn test_remove_leader_peer_clears_leader() {
        let mut coord =
            FailoverCoordinator::new(HeartbeatConfig::default(), FailoverConfig::default(), 1);
        coord.track_peer(2).expect("track peer 2");
        coord.set_leader(2);
        assert_eq!(coord.leader_hint(), Some(2));

        coord.remove_peer(2);
        assert_eq!(coord.leader_hint(), None);
    }

    #[test]
    fn test_debug_impl() {
        let coord =
            FailoverCoordinator::new(HeartbeatConfig::default(), FailoverConfig::default(), 1);
        let dbg = format!("{:?}", coord);
        assert!(dbg.contains("FailoverCoordinator"));
        assert!(dbg.contains("self_id"));
    }
}