use std::time::Duration;
use crate::serde_json::{self, Value as JsonValue};
/// The role a member plays in the replica set.
///
/// Only `Data` members are accepted by `Member::is_electable`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemberKind {
    Data,
    Witness,
}

/// Whether a member currently participates in quorum.
///
/// `Voting` members are counted by `quorum_threshold`; `CatchingUp` members are not.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VotingState {
    Voting,
    CatchingUp,
}

/// One entry in the cluster membership.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Member {
    pub id: String,
    pub kind: MemberKind,
    pub state: VotingState,
}

impl Member {
    /// Shared constructor used by the named constructors below.
    fn with(id: impl Into<String>, kind: MemberKind, state: VotingState) -> Self {
        Member {
            id: id.into(),
            kind,
            state,
        }
    }

    /// A `Data` member that is already `Voting`.
    pub fn data_voting(id: impl Into<String>) -> Self {
        Self::with(id, MemberKind::Data, VotingState::Voting)
    }

    /// A `Data` member that is still `CatchingUp` (not yet a voter).
    pub fn data_catching_up(id: impl Into<String>) -> Self {
        Self::with(id, MemberKind::Data, VotingState::CatchingUp)
    }

    /// A `Witness` member; witnesses vote but are never electable.
    pub fn witness(id: impl Into<String>) -> Self {
        Self::with(id, MemberKind::Witness, VotingState::Voting)
    }

    /// True when this member's vote counts toward quorum.
    pub fn is_voter(&self) -> bool {
        self.state == VotingState::Voting
    }

    /// True only for a `Data` member in the `Voting` state.
    pub fn is_electable(&self) -> bool {
        matches!(
            (self.kind, self.state),
            (MemberKind::Data, VotingState::Voting)
        )
    }
}

/// Number of votes that make a strict majority of the voting members.
///
/// Non-voting (`CatchingUp`) members are ignored. An empty or voter-less
/// membership yields 1.
pub fn quorum_threshold(members: &[Member]) -> usize {
    let voting = members.iter().filter(|member| member.is_voter()).count();
    1 + voting / 2
}
/// A voter's persisted ballot record: the term it last acknowledged and,
/// if any, the candidate it voted for in that term.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct LastVote {
    pub term: u64,
    pub voted_for: Option<String>,
}

impl LastVote {
    /// Encodes the record as `{"term": <number>, "voted_for": <string|null>}`.
    ///
    /// NOTE(review): `term` is written as an `f64` number, so terms above 2^53
    /// would not round-trip exactly.
    fn to_json(&self) -> JsonValue {
        let voted_for = self
            .voted_for
            .as_ref()
            .map_or(JsonValue::Null, |id| JsonValue::String(id.clone()));
        let mut fields = serde_json::Map::new();
        fields.insert(String::from("term"), JsonValue::Number(self.term as f64));
        fields.insert(String::from("voted_for"), voted_for);
        JsonValue::Object(fields)
    }

    /// Decodes a record produced by `to_json`.
    ///
    /// A missing `voted_for` key is treated the same as `null`.
    ///
    /// # Errors
    /// Returns `LastVoteError::InvalidFormat` if `value` is not an object, if
    /// `term` is absent or not readable as `u64`, or if `voted_for` is neither
    /// a string nor null.
    fn from_json(value: &JsonValue) -> Result<Self, LastVoteError> {
        let fields = match value.as_object() {
            Some(fields) => fields,
            None => {
                return Err(LastVoteError::InvalidFormat(
                    "last-vote json is not an object".into(),
                ))
            }
        };
        let term = match fields.get("term").and_then(JsonValue::as_u64) {
            Some(term) => term,
            None => return Err(LastVoteError::InvalidFormat("missing term".into())),
        };
        let voted_for = match fields.get("voted_for") {
            Some(JsonValue::String(id)) => Some(id.clone()),
            Some(JsonValue::Null) | None => None,
            Some(_) => {
                return Err(LastVoteError::InvalidFormat(
                    "voted_for must be a string or null".into(),
                ))
            }
        };
        Ok(LastVote { term, voted_for })
    }
}
/// Errors raised while loading or persisting a [`LastVote`].
#[derive(Debug)]
pub enum LastVoteError {
    /// The underlying file-system operation failed.
    Io(std::io::Error),
    /// The stored record could not be encoded or decoded; the message says why.
    InvalidFormat(String),
}

impl std::fmt::Display for LastVoteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(err) => write!(f, "last-vote io error: {err}"),
            Self::InvalidFormat(msg) => write!(f, "invalid last-vote format: {msg}"),
        }
    }
}

impl std::error::Error for LastVoteError {
    /// Exposes the wrapped `io::Error` so error-chain reporters can walk to the
    /// root cause. `InvalidFormat` carries only a message and has no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::InvalidFormat(_) => None,
        }
    }
}
/// Persistence backend for a voter's [`LastVote`].
pub trait LastVoteStore {
    /// Returns the most recently stored vote.
    ///
    /// # Errors
    /// Returns a [`LastVoteError`] if the record cannot be read or decoded.
    fn load(&self) -> Result<LastVote, LastVoteError>;

    /// Stores `vote` as the current record.
    ///
    /// `Voter::consider` calls this before it reports a non-dry-run grant, so
    /// an `Ok` return should mean the vote has been recorded.
    ///
    /// # Errors
    /// Returns a [`LastVoteError`] if the record cannot be encoded or written.
    fn persist(&self, vote: &LastVote) -> Result<(), LastVoteError>;
}
#[derive(Debug, Default)]
pub struct MemoryLastVoteStore {
inner: std::sync::Mutex<LastVote>,
}
impl MemoryLastVoteStore {
pub fn new() -> Self {
Self::default()
}
pub fn seeded(vote: LastVote) -> Self {
Self {
inner: std::sync::Mutex::new(vote),
}
}
}
impl LastVoteStore for MemoryLastVoteStore {
fn load(&self) -> Result<LastVote, LastVoteError> {
Ok(self.inner.lock().expect("last-vote mutex").clone())
}
fn persist(&self, vote: &LastVote) -> Result<(), LastVoteError> {
*self.inner.lock().expect("last-vote mutex") = vote.clone();
Ok(())
}
}
/// [`LastVoteStore`] backed by a single JSON file.
///
/// `persist` writes a sibling temporary file, fsyncs it, renames it over the
/// target and then fsyncs the containing directory. After `persist` returns
/// `Ok`, the new vote should survive a crash.
pub struct FileLastVoteStore {
    path: std::path::PathBuf,
}

impl FileLastVoteStore {
    /// Creates a store for `path`. No I/O happens until the first `load`/`persist`.
    pub fn new(path: impl Into<std::path::PathBuf>) -> Self {
        Self { path: path.into() }
    }

    /// Writes `bytes` to `tmp` through a writable handle and forces them to disk.
    fn write_synced(tmp: &std::path::Path, bytes: &[u8]) -> std::io::Result<()> {
        use std::io::Write as _;
        let mut file = std::fs::File::create(tmp)?;
        file.write_all(bytes)?;
        // Must go through the write handle: syncing a read-only handle fails on
        // some platforms (e.g. Windows), and the error must not be ignored.
        file.sync_all()
    }

    /// Fsyncs the directory containing `path` so the rename itself is durable.
    fn sync_parent_dir(path: &std::path::Path) -> std::io::Result<()> {
        // `Path::parent` returns `Some("")` for a bare file name; that means
        // the current directory.
        let dir = match path.parent() {
            Some(parent) if !parent.as_os_str().is_empty() => parent,
            _ => std::path::Path::new("."),
        };
        match std::fs::File::open(dir).and_then(|handle| handle.sync_all()) {
            Ok(()) => Ok(()),
            // Directory handles cannot be opened/fsynced on non-unix platforms
            // such as Windows, so there this step stays best-effort.
            Err(_) if !cfg!(unix) => Ok(()),
            Err(err) => Err(err),
        }
    }
}

impl LastVoteStore for FileLastVoteStore {
    /// Reads and decodes the vote file.
    ///
    /// A missing file yields `LastVote::default()`.
    ///
    /// # Errors
    /// `Io` for any other read failure; `InvalidFormat` if the content is not
    /// valid last-vote JSON.
    fn load(&self) -> Result<LastVote, LastVoteError> {
        match std::fs::read(&self.path) {
            Ok(bytes) => {
                let json: JsonValue = serde_json::from_slice(&bytes)
                    .map_err(|err| LastVoteError::InvalidFormat(format!("parse: {err}")))?;
                LastVote::from_json(&json)
            }
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(LastVote::default()),
            Err(err) => Err(LastVoteError::Io(err)),
        }
    }

    /// Atomically replaces the vote file with `vote`.
    ///
    /// # Errors
    /// `InvalidFormat` if encoding fails. `Io` if the directory cannot be
    /// created, or if the temp file cannot be written or fsynced, or if the
    /// rename fails. On unix, `Io` is also returned if the directory fsync fails.
    fn persist(&self, vote: &LastVote) -> Result<(), LastVoteError> {
        let bytes = serde_json::to_vec(&vote.to_json())
            .map_err(|err| LastVoteError::InvalidFormat(format!("serialize: {err}")))?;
        if let Some(parent) = self.path.parent() {
            std::fs::create_dir_all(parent).map_err(LastVoteError::Io)?;
        }
        let tmp = self.path.with_extension("lastvote.tmp");
        let staged =
            Self::write_synced(&tmp, &bytes).and_then(|()| std::fs::rename(&tmp, &self.path));
        if let Err(err) = staged {
            // Don't leave a half-written temp file behind; its removal is best-effort.
            let _ = std::fs::remove_file(&tmp);
            return Err(LastVoteError::Io(err));
        }
        Self::sync_parent_dir(&self.path).map_err(LastVoteError::Io)
    }
}
/// A candidate's request for a vote in `term`.
///
/// `dry_run` marks a probe: `Voter::consider` evaluates it but does not
/// persist a grant.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VoteRequest {
    pub candidate_id: String,
    pub term: u64,
    pub last_log_lsn: u64,
    pub dry_run: bool,
}

impl VoteRequest {
    /// Builds a non-binding probe request (`dry_run == true`).
    pub fn probe(candidate_id: impl Into<String>, term: u64, last_log_lsn: u64) -> Self {
        VoteRequest {
            dry_run: true,
            term,
            last_log_lsn,
            candidate_id: candidate_id.into(),
        }
    }

    /// Builds a binding request (`dry_run == false`) with the same fields as a probe.
    pub fn real(candidate_id: impl Into<String>, term: u64, last_log_lsn: u64) -> Self {
        VoteRequest {
            dry_run: false,
            ..Self::probe(candidate_id, term, last_log_lsn)
        }
    }
}
/// Why a voter declined a [`VoteRequest`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RefusalReason {
    /// The candidate's `last_log_lsn` is below the commit watermark.
    WatermarkNotCovered { candidate_lsn: u64, watermark: u64 },
    /// The voter already voted for `voted_for` in `term`.
    AlreadyVoted { term: u64, voted_for: String },
    /// The candidate's term is older than the voter's stored term.
    StaleTerm { candidate_term: u64, voter_term: u64 },
}

/// A voter's answer to a [`VoteRequest`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum VoteDecision {
    Granted,
    Refused(RefusalReason),
}

impl VoteDecision {
    /// True for `Granted`, false for any refusal.
    pub fn is_granted(&self) -> bool {
        match self {
            VoteDecision::Granted => true,
            VoteDecision::Refused(_) => false,
        }
    }
}
/// A node's vote-granting logic, backed by a [`LastVoteStore`].
pub struct Voter<S: LastVoteStore> {
    id: String,
    store: S,
}

impl<S: LastVoteStore> Voter<S> {
    /// Creates a voter named `id` that records its ballots in `store`.
    pub fn new(id: impl Into<String>, store: S) -> Self {
        let id = id.into();
        Voter { id, store }
    }

    /// This voter's identifier.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// The term of the last stored vote, read from the store on every call.
    ///
    /// # Errors
    /// Propagates any error from `LastVoteStore::load`.
    pub fn current_term(&self) -> Result<u64, LastVoteError> {
        self.store.load().map(|last| last.term)
    }

    /// Decides whether to grant `req`.
    ///
    /// The checks are applied in this order:
    /// 1. `req.last_log_lsn` below `commit_watermark` is refused, before the
    ///    store is read.
    /// 2. A term older than the stored term is refused as stale.
    /// 3. In the stored term, a vote already cast for another candidate is
    ///    refused. A repeat request from the same candidate is granted without
    ///    persisting again.
    /// 4. Otherwise the request is granted. For a non-dry-run request the new
    ///    vote is persisted first.
    ///
    /// NOTE(review): `load` and `persist` are separate store calls with nothing
    /// held in between, so concurrent `consider` calls on one shared `Voter` are
    /// not serialized here.
    ///
    /// # Errors
    /// Propagates store `load`/`persist` failures.
    pub fn consider(
        &self,
        req: &VoteRequest,
        commit_watermark: u64,
    ) -> Result<VoteDecision, LastVoteError> {
        let covers_watermark = req.last_log_lsn >= commit_watermark;
        if !covers_watermark {
            return Ok(VoteDecision::Refused(RefusalReason::WatermarkNotCovered {
                candidate_lsn: req.last_log_lsn,
                watermark: commit_watermark,
            }));
        }

        let last = self.store.load()?;
        match req.term.cmp(&last.term) {
            std::cmp::Ordering::Less => {
                return Ok(VoteDecision::Refused(RefusalReason::StaleTerm {
                    candidate_term: req.term,
                    voter_term: last.term,
                }));
            }
            std::cmp::Ordering::Equal => match last.voted_for.as_deref() {
                None => {}
                Some(prior) if prior == req.candidate_id => return Ok(VoteDecision::Granted),
                Some(prior) => {
                    return Ok(VoteDecision::Refused(RefusalReason::AlreadyVoted {
                        term: last.term,
                        voted_for: prior.to_owned(),
                    }));
                }
            },
            std::cmp::Ordering::Greater => {}
        }

        if req.dry_run {
            return Ok(VoteDecision::Granted);
        }
        let ballot = LastVote {
            term: req.term,
            voted_for: Some(req.candidate_id.clone()),
        };
        self.store.persist(&ballot)?;
        Ok(VoteDecision::Granted)
    }
}
/// Returns `base` plus a deterministic offset derived from `seed`.
///
/// The offset is `seed % jitter_ms` milliseconds, so it lies in `[0, jitter)`
/// at millisecond granularity. A zero `jitter` returns `base` unchanged. A
/// non-zero jitter shorter than 1 ms is rounded up to a 1 ms window, which
/// always gives a zero offset.
pub fn randomized_election_timeout(base: Duration, jitter: Duration, seed: u64) -> Duration {
    if jitter == Duration::ZERO {
        return base;
    }
    let window_ms = std::cmp::max(jitter.as_millis(), 1) as u64;
    let offset = Duration::from_millis(seed % window_ms);
    base + offset
}
/// Input to `ElectionCoordinator::run`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ElectionRequest {
    /// The member standing for election.
    pub candidate: Member,
    /// The term the candidate currently knows; the election runs in the next one.
    pub current_term: u64,
    /// The candidate's last log LSN; it must be at least `commit_watermark`.
    pub last_log_lsn: u64,
    /// Lower bound that `last_log_lsn` has to reach for the candidate to run.
    pub commit_watermark: u64,
}

impl ElectionRequest {
    /// The term this election is held in: `current_term + 1`.
    pub fn new_term(&self) -> u64 {
        let current = self.current_term;
        current + 1
    }
}
/// Result of `ElectionCoordinator::run`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ElectionOutcome {
    /// The binding ballot reached quorum in `term`.
    Elected { term: u64, votes: usize, needed: usize },
    /// The dry-run probe ran out of peers before reaching quorum.
    ProbeFailed { votes: usize, needed: usize },
    /// The binding ballot in `term` ran out of peers before reaching quorum.
    Lost { term: u64, votes: usize, needed: usize },
    /// The candidate was rejected before any peer was contacted.
    NotElectable,
    /// The timeout elapsed during the probe or the ballot.
    TimedOut { votes: usize, needed: usize },
}

impl ElectionOutcome {
    /// True only for [`ElectionOutcome::Elected`].
    pub fn is_elected(&self) -> bool {
        match self {
            ElectionOutcome::Elected { .. } => true,
            ElectionOutcome::ProbeFailed { .. }
            | ElectionOutcome::Lost { .. }
            | ElectionOutcome::NotElectable
            | ElectionOutcome::TimedOut { .. } => false,
        }
    }
}
/// The environment an election runs against: membership, peer RPCs, a clock,
/// and hooks for term changes.
pub trait ElectionTransport {
    /// Current cluster membership. The coordinator derives both the quorum
    /// size and the list of peers to ask from it.
    fn members(&self) -> Vec<Member>;

    /// Sends `req` to `peer_id` and returns that peer's decision.
    fn request_vote(&mut self, peer_id: &str, req: &VoteRequest) -> VoteDecision;

    /// Time elapsed so far. The coordinator compares it against its timeout
    /// before each vote request (presumably measured from the start of the
    /// election; confirm against implementations).
    fn elapsed(&self) -> Duration;

    /// Called once the dry-run probe reaches quorum, before binding ballots
    /// are sent in `new_term`.
    fn bump_term(&mut self, new_term: u64);

    /// Called once the binding ballot reaches quorum in `new_term`.
    fn promote(&mut self, new_term: u64);
}
/// Runs a two-round election: a dry-run probe round, then a binding ballot
/// round in the new term.
pub struct ElectionCoordinator;

impl ElectionCoordinator {
    /// Attempts to elect `req.candidate` in term `req.new_term()`.
    ///
    /// Returns [`ElectionOutcome::NotElectable`] without contacting anyone when
    /// the candidate is not electable or `last_log_lsn < commit_watermark`.
    /// Otherwise the election proceeds in two rounds:
    /// 1. It sends dry-run probes to every voting peer. Running out of peers
    ///    gives `ProbeFailed`.
    /// 2. On probe quorum, it calls `tx.bump_term` and sends binding ballots.
    ///    Quorum calls `tx.promote` and gives `Elected`; running out of peers
    ///    gives `Lost`.
    ///
    /// If `timeout` elapses in either round, the result is `TimedOut`.
    ///
    /// The quorum size comes from `tx.members()`. The candidate's own vote
    /// therefore only counts when that same membership lists it as a voter.
    /// Otherwise a candidate could win without a real majority.
    pub fn run(
        req: &ElectionRequest,
        tx: &mut dyn ElectionTransport,
        timeout: Duration,
    ) -> ElectionOutcome {
        if !req.candidate.is_electable() || req.last_log_lsn < req.commit_watermark {
            return ElectionOutcome::NotElectable;
        }
        let members = tx.members();
        let needed = quorum_threshold(&members);
        let new_term = req.new_term();
        // A self-vote from a candidate that is absent from, or not voting in,
        // `members` is not part of the majority `needed` was computed over.
        let self_votes = usize::from(
            members
                .iter()
                .any(|m| m.id == req.candidate.id && m.is_voter()),
        );
        let peers: Vec<String> = members
            .iter()
            .filter(|m| m.is_voter() && m.id != req.candidate.id)
            .map(|m| m.id.clone())
            .collect();

        let probe = VoteRequest::probe(&req.candidate.id, new_term, req.last_log_lsn);
        let probe_votes = match Self::collect(tx, &peers, &probe, needed, timeout, self_votes) {
            CollectResult::Reached(votes) => votes,
            CollectResult::Exhausted(votes) => {
                return ElectionOutcome::ProbeFailed { votes, needed }
            }
            CollectResult::TimedOut(votes) => return ElectionOutcome::TimedOut { votes, needed },
        };
        debug_assert!(probe_votes >= needed);

        tx.bump_term(new_term);
        let ballot = VoteRequest::real(&req.candidate.id, new_term, req.last_log_lsn);
        match Self::collect(tx, &peers, &ballot, needed, timeout, self_votes) {
            CollectResult::Reached(votes) => {
                tx.promote(new_term);
                ElectionOutcome::Elected {
                    term: new_term,
                    votes,
                    needed,
                }
            }
            CollectResult::Exhausted(votes) => ElectionOutcome::Lost {
                term: new_term,
                votes,
                needed,
            },
            CollectResult::TimedOut(votes) => ElectionOutcome::TimedOut { votes, needed },
        }
    }

    /// Asks each peer in order, starting the tally at `initial_votes`.
    ///
    /// Stops as soon as `needed` votes are reached. It also stops when
    /// `tx.elapsed()` reaches `timeout`, which is checked before each request,
    /// or when every peer has been asked.
    fn collect(
        tx: &mut dyn ElectionTransport,
        peers: &[String],
        req: &VoteRequest,
        needed: usize,
        timeout: Duration,
        initial_votes: usize,
    ) -> CollectResult {
        let mut votes = initial_votes;
        if votes >= needed {
            return CollectResult::Reached(votes);
        }
        for peer in peers {
            if tx.elapsed() >= timeout {
                return CollectResult::TimedOut(votes);
            }
            if tx.request_vote(peer, req).is_granted() {
                votes += 1;
                if votes >= needed {
                    return CollectResult::Reached(votes);
                }
            }
        }
        CollectResult::Exhausted(votes)
    }
}

/// Tally of one `ElectionCoordinator::collect` round; each variant carries
/// the number of votes counted.
enum CollectResult {
    /// Quorum was reached.
    Reached(usize),
    /// Every peer was asked without reaching quorum.
    Exhausted(usize),
    /// The timeout elapsed before quorum.
    TimedOut(usize),
}
#[cfg(test)]
mod tests;