use crate::{NodeId, Term};
use rand::Rng;
use std::time::Duration;
use tokio::time::Instant;
/// Randomized election timeout.
///
/// Every time the timer is armed (at construction and on [`ElectionTimer::reset`])
/// a fresh timeout is drawn uniformly from `[min_timeout_ms, max_timeout_ms]`
/// milliseconds, both bounds inclusive.
#[derive(Debug)]
pub struct ElectionTimer {
    /// Instant at which the timer was last armed.
    last_reset: Instant,
    /// Timeout drawn for the current arming.
    timeout: Duration,
    /// Inclusive lower bound for drawn timeouts, in milliseconds (always <= `max_timeout_ms`).
    min_timeout_ms: u64,
    /// Inclusive upper bound for drawn timeouts, in milliseconds.
    max_timeout_ms: u64,
}

impl ElectionTimer {
    /// Creates a timer armed now with a random timeout in
    /// `[min_timeout_ms, max_timeout_ms]` milliseconds.
    ///
    /// If the bounds are passed in the wrong order they are swapped instead of
    /// panicking, so `new(300, 150)` behaves like `new(150, 300)`.
    pub fn new(min_timeout_ms: u64, max_timeout_ms: u64) -> Self {
        // `Rng::gen_range(a..=b)` panics when `a > b`. Normalising here (and
        // storing the normalised pair) keeps every later `reset()` safe too.
        let (min_timeout_ms, max_timeout_ms) = (
            min_timeout_ms.min(max_timeout_ms),
            min_timeout_ms.max(max_timeout_ms),
        );
        let timeout = Self::random_timeout(min_timeout_ms, max_timeout_ms);
        Self {
            last_reset: Instant::now(),
            timeout,
            min_timeout_ms,
            max_timeout_ms,
        }
    }

    /// Creates a timer with a 150–300 ms timeout range.
    pub fn with_defaults() -> Self {
        Self::new(150, 300)
    }

    /// Re-arms the timer from now and draws a new random timeout.
    pub fn reset(&mut self) {
        self.last_reset = Instant::now();
        self.timeout = Self::random_timeout(self.min_timeout_ms, self.max_timeout_ms);
    }

    /// Returns `true` once at least the current timeout has passed since the last arming.
    pub fn is_elapsed(&self) -> bool {
        self.last_reset.elapsed() >= self.timeout
    }

    /// Time left before the timer elapses; `Duration::ZERO` once it has elapsed.
    pub fn time_remaining(&self) -> Duration {
        self.timeout.saturating_sub(self.last_reset.elapsed())
    }

    /// Draws a timeout uniformly from `[min_ms, max_ms]` milliseconds.
    ///
    /// Callers must ensure `min_ms <= max_ms`; `new` normalises the stored bounds.
    fn random_timeout(min_ms: u64, max_ms: u64) -> Duration {
        let mut rng = rand::thread_rng();
        let timeout_ms = rng.gen_range(min_ms..=max_ms);
        Duration::from_millis(timeout_ms)
    }

    /// The timeout drawn for the current arming.
    pub fn timeout(&self) -> Duration {
        self.timeout
    }
}
/// Tally of distinct nodes that granted a vote, with a strict-majority threshold.
#[derive(Debug)]
pub struct VoteTracker {
    /// Voters recorded so far, each at most once, in arrival order.
    votes_received: Vec<NodeId>,
    /// Number of voting members this tracker was built for.
    cluster_size: usize,
    /// Votes needed for a strict majority: `cluster_size / 2 + 1`.
    quorum_size: usize,
}

impl VoteTracker {
    /// Builds an empty tally for a cluster of `cluster_size` voters.
    pub fn new(cluster_size: usize) -> Self {
        Self {
            votes_received: Vec::new(),
            cluster_size,
            quorum_size: cluster_size / 2 + 1,
        }
    }

    /// Records a vote from `node_id`; repeated votes from the same node are ignored.
    pub fn record_vote(&mut self, node_id: NodeId) {
        if self.votes_received.contains(&node_id) {
            return;
        }
        self.votes_received.push(node_id);
    }

    /// `true` when the number of distinct voters has reached the quorum size.
    pub fn has_quorum(&self) -> bool {
        self.vote_count() >= self.quorum_size
    }

    /// Number of distinct voters recorded.
    pub fn vote_count(&self) -> usize {
        self.votes_received.len()
    }

    /// Votes required for a strict majority.
    pub fn quorum_size(&self) -> usize {
        self.quorum_size
    }

    /// Forgets every recorded vote; the cluster size and quorum are kept.
    pub fn reset(&mut self) {
        self.votes_received.clear();
    }
}
/// Election bookkeeping for one node: its timeout timer, the votes collected
/// for its current candidacy, and the term that candidacy was started in.
#[derive(Debug)]
pub struct ElectionState {
    /// Randomized election timeout.
    pub timer: ElectionTimer,
    /// Votes gathered since the last `start_election`.
    pub votes: VoteTracker,
    /// Term passed to the most recent `start_election` (0 before any election).
    pub current_term: Term,
}

impl ElectionState {
    /// Creates state for a cluster of `cluster_size` voters with a timer drawing
    /// timeouts from `[min_timeout_ms, max_timeout_ms]` milliseconds.
    pub fn new(cluster_size: usize, min_timeout_ms: u64, max_timeout_ms: u64) -> Self {
        let timer = ElectionTimer::new(min_timeout_ms, max_timeout_ms);
        let votes = VoteTracker::new(cluster_size);
        Self {
            timer,
            votes,
            current_term: 0,
        }
    }

    /// Begins a candidacy in `term`: records the term, starts a fresh tally
    /// containing only this node's own vote, and re-arms the timer.
    pub fn start_election(&mut self, term: Term, self_id: &NodeId) {
        self.current_term = term;
        self.votes.reset();
        self.votes.record_vote(self_id.clone());
        self.reset_timer();
    }

    /// Re-arms the election timer with a freshly drawn timeout.
    pub fn reset_timer(&mut self) {
        self.timer.reset();
    }

    /// `true` once the election timeout has elapsed.
    pub fn should_start_election(&self) -> bool {
        self.timer.is_elapsed()
    }

    /// Records a vote from `node_id` and reports whether a quorum is now reached.
    pub fn record_vote(&mut self, node_id: NodeId) -> bool {
        let tally = &mut self.votes;
        tally.record_vote(node_id);
        tally.has_quorum()
    }

    /// Replaces the vote tally with an empty one sized for `cluster_size` voters.
    ///
    /// NOTE(review): every vote collected so far, including the self-vote recorded
    /// by `start_election`, is discarded and not re-recorded. This errs on the side
    /// of never over-counting after a membership change.
    pub fn update_cluster_size(&mut self, cluster_size: usize) {
        let fresh_tally = VoteTracker::new(cluster_size);
        self.votes = fresh_tally;
    }
}
/// Stateless checks deciding whether a node should grant a vote to a candidate.
pub struct VoteValidator;

impl VoteValidator {
    /// Returns `true` when the receiver should grant its vote to `candidate_id`.
    ///
    /// The vote is granted only if all of the following hold:
    /// 1. `candidate_term` is not older than `receiver_term`;
    /// 2. the receiver has not voted (`receiver_voted_for` is `None`) or already
    ///    voted for this same candidate;
    /// 3. the candidate's log (last term, last index) is at least as up to date
    ///    as the receiver's.
    ///
    /// NOTE(review): rule 2 uses `receiver_voted_for` as given, even when
    /// `candidate_term > receiver_term`. This presumably relies on the caller having
    /// already adopted the higher term and cleared its vote before calling;
    /// confirm against the RequestVote handler.
    pub fn should_grant_vote(
        receiver_term: Term,
        receiver_voted_for: &Option<NodeId>,
        receiver_last_log_index: u64,
        receiver_last_log_term: Term,
        candidate_id: &NodeId,
        candidate_term: Term,
        candidate_last_log_index: u64,
        candidate_last_log_term: Term,
    ) -> bool {
        let stale_candidate = candidate_term < receiver_term;
        if stale_candidate {
            return false;
        }

        // The vote is spoken for only if it went to somebody else.
        let voted_for_someone_else =
            matches!(receiver_voted_for, Some(previous) if previous != candidate_id);

        !voted_for_someone_else
            && Self::is_log_up_to_date(
                candidate_last_log_term,
                candidate_last_log_index,
                receiver_last_log_term,
                receiver_last_log_index,
            )
    }

    /// Returns `true` if the candidate's log is at least as up to date as the receiver's.
    ///
    /// The comparison is lexicographic on `(last_term, last_index)`. A later last term
    /// wins outright. With equal last terms the longer log wins. Identical positions
    /// count as up to date.
    fn is_log_up_to_date(
        candidate_last_term: Term,
        candidate_last_index: u64,
        receiver_last_term: Term,
        receiver_last_index: u64,
    ) -> bool {
        (candidate_last_term, candidate_last_index) >= (receiver_last_term, receiver_last_index)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    #[test]
    fn test_election_timer() {
        let mut timer = ElectionTimer::new(50, 100);
        // Freshly armed: not even the shortest possible timeout (50 ms) has passed.
        assert!(!timer.is_elapsed());

        // Sleep past the longest possible timeout (100 ms).
        sleep(Duration::from_millis(150));
        assert!(timer.is_elapsed());

        // Re-arming starts a new countdown.
        timer.reset();
        assert!(!timer.is_elapsed());
    }

    #[test]
    fn test_vote_tracker() {
        let mut tracker = VoteTracker::new(5);
        assert_eq!(tracker.quorum_size(), 3);
        assert!(!tracker.has_quorum());

        // A majority of 5 is 3: the first two votes are not enough, the third is.
        for (voter, expect_quorum) in [("node1", false), ("node2", false), ("node3", true)] {
            tracker.record_vote(voter.to_string());
            assert_eq!(tracker.has_quorum(), expect_quorum);
        }
    }

    #[test]
    fn test_election_state() {
        let mut state = ElectionState::new(5, 50, 100);
        let self_id = "node1".to_string();

        state.start_election(1, &self_id);
        assert_eq!(state.current_term, 1);
        // Only the self-vote so far.
        assert_eq!(state.votes.vote_count(), 1);

        // Self + node2 = 2 of 5: no majority yet.
        assert!(!state.record_vote("node2".to_string()));
        // Self + node2 + node3 = 3 of 5: majority reached.
        assert!(state.record_vote("node3".to_string()));
    }

    #[test]
    fn test_vote_validation() {
        // All cases share identical logs (last index 10, last term 1) and the same
        // candidate, so only the terms and the receiver's prior vote vary.
        let candidate = "candidate".to_string();
        let grant = |receiver_term: Term, voted_for: Option<&str>, candidate_term: Term| {
            VoteValidator::should_grant_vote(
                receiver_term,
                &voted_for.map(String::from),
                10,
                1,
                &candidate,
                candidate_term,
                10,
                1,
            )
        };

        assert!(grant(1, None, 2));
        assert!(!grant(2, None, 1));
        assert!(!grant(1, Some("other"), 1));
        assert!(grant(1, Some("candidate"), 1));
    }

    #[test]
    fn test_log_up_to_date() {
        // (candidate_term, candidate_index, receiver_term, receiver_index, expected)
        let cases = [
            (2, 5, 1, 10, true),
            (1, 10, 2, 5, false),
            (1, 10, 1, 5, true),
            (1, 5, 1, 10, false),
            (1, 10, 1, 10, true),
        ];
        for (i, (ct, ci, rt, ri, expected)) in cases.into_iter().enumerate() {
            assert_eq!(
                VoteValidator::is_log_up_to_date(ct, ci, rt, ri),
                expected,
                "case #{i}"
            );
        }
    }
}