use {
crate::{
PeerId,
groups::{
LeadershipPreference,
StateMachine,
Storage,
Term,
raft::{
Message,
candidate::Candidate,
follower::Follower,
leader::Leader,
protocol::{AppendEntries, RequestVoteResponse, Vote},
shared::Shared,
},
},
primitives::Short,
},
core::{
ops::ControlFlow,
task::{Context, Poll},
},
derive_more::{Display, From},
};
/// The Raft role this node currently plays within a group.
///
/// `Display` (derived via `derive_more`) renders only the variant name
/// ("Follower", "Candidate" or "Leader"); it is used for the `local_role`
/// log field and in log messages. `From` lets a concrete role be turned into
/// a `Role` with `.into()` during role transitions.
#[derive(Display, From)]
// NOTE(review): this lint is suppressed without a stated reason. Presumably
// the variants differ noticeably in size and boxing was judged not worth the
// extra indirection — confirm and record the actual reason here.
#[allow(clippy::large_enum_variant)]
pub enum Role<M: StateMachine> {
/// Following a (possibly unknown) leader.
#[display("Follower")]
Follower(Follower<M>),
/// Running an election for leadership.
#[display("Candidate")]
Candidate(Candidate<M>),
/// Acting as the group leader.
#[display("Leader")]
Leader(Leader<M>),
}
impl<M: StateMachine> Role<M> {
/// Creates the initial role: a follower at `Term::zero()` with no known
/// leader.
pub fn new<S: Storage<M::Command>>(shared: &Shared<S, M>) -> Self {
Self::Follower(Follower::new(Term::zero(), None, shared))
}
/// Drives the current role forward.
///
/// Polls the active role (`Follower`, `Candidate` or `Leader`). When the
/// role completes with `ControlFlow::Break(next_role)`, `self` is replaced
/// by `next_role`. Any `Ready` result from the role is reported as
/// `Poll::Ready(())`. A `Pending` result registers `cx`'s waker with
/// `shared` via `add_waker` and is reported as `Poll::Pending`.
///
/// The state-sync provider is always polled too, whatever the role
/// returned. If the provider is ready, this also returns `Poll::Ready(())`.
pub fn poll<S: Storage<M::Command>>(
&mut self,
cx: &mut Context<'_>,
shared: &mut Shared<S, M>,
) -> Poll<()> {
let next_step = match self {
Self::Follower(follower) => follower.poll(cx, shared),
Self::Candidate(candidate) => candidate.poll(cx, shared),
Self::Leader(leader) => leader.poll(cx, shared),
};
let readiness = match next_step {
// `Continue` and `Break` are both reported as `Ready(())`. Presumably
// this makes the caller poll again, so a freshly installed role gets
// polled promptly.
Poll::Ready(next) => {
if let ControlFlow::Break(next_role) = next {
*self = next_role;
}
Poll::Ready(())
}
Poll::Pending => {
shared.add_waker(cx.waker().clone());
Poll::Pending
}
};
// The sync provider is polled on every call, including right after a
// role transition. Its readiness overrides a pending role result.
if shared.poll_state_sync_provider(cx).is_ready() {
return Poll::Ready(());
}
readiness
}
/// Handles a Raft protocol `message` received from `sender`.
///
/// Processing happens in a fixed order:
/// 1. A message whose term is older than the local term is logged at trace
///    level and dropped.
/// 2. A `StateSync` message is handed to the state-sync provider. If the
///    provider consumes it, processing stops.
/// 3. A message carrying a newer term makes this node step down to a
///    follower at that term (`maybe_step_down`).
/// 4. A `RequestVote` message is answered directly (`maybe_cast_vote`).
/// 5. Any other message is dispatched to the current role's handler, and
///    the handler's `RoleHandlerError` (if any) is acted on below.
///
/// Step 1 upholds the `>=` term assertions in `maybe_step_down` and
/// `maybe_cast_vote`. This assumes `message.term()` reports `request.term`
/// for `RequestVote`, which is not visible here.
pub fn receive_protocol_message<S: Storage<M::Command>>(
&mut self,
message: Message<M>,
sender: PeerId,
shared: &mut Shared<S, M>,
) {
if let Some(message_term) = message.term()
&& message_term < self.term()
{
tracing::trace!(
local_term = %self.term(),
message_term = %message_term,
group = %Short(shared.group_id()),
network = %Short(shared.network_id()),
sender = %Short(sender),
message = %message,
local_role = %self,
"ignoring stale raft message"
);
return;
}
// `Ok(())` means the sync provider consumed the message. Otherwise keep
// processing the (possibly handed-back) message.
let Err(message) = Self::maybe_state_sync(message, sender, shared) else {
return;
};
self.maybe_step_down(&message, shared);
if self.maybe_cast_vote(&message, sender, shared) {
return;
}
let result = match self {
Self::Follower(follower) => {
follower.receive_protocol_message(message, sender, shared)
}
Self::Candidate(candidate) => {
candidate.receive_protocol_message(message, sender, shared)
}
Self::Leader(leader) => {
leader.receive_protocol_message(message, sender, shared)
}
};
match result {
Ok(()) => {}
// The current role does not handle this message type: log and drop.
Err(RoleHandlerError::Unexpected(message)) => {
tracing::trace!(
local_term = %self.term(),
message_term = ?message.term(),
group = %Short(shared.group_id()),
network = %Short(shared.network_id()),
sender = %Short(sender),
message = %message,
"unexpected message type received as {self}",
);
}
// The role asked to yield to the leader that sent `request`. Become
// its follower at `request.term` and publish that leader. Then
// re-deliver the same `AppendEntries` so the new follower processes it.
// NOTE(review): this recursion terminates only if a `Follower` never
// answers with `StepDown` itself — confirm in the follower's handler.
Err(RoleHandlerError::StepDown(request)) => {
*self = Follower::<M>::new(
request.term, Some(request.leader),
shared,
)
.into();
shared.update_leader(Some(request.leader));
tracing::debug!(
leader = %Short(request.leader),
term = %self.term(),
group = %Short(shared.group_id()),
network = %Short(shared.network_id()),
"stepping down and following",
);
self.receive_protocol_message(
Message::AppendEntries(request),
sender,
shared,
);
}
// Another node claims leadership (presumably in our term). Forget the
// known leader, start a new election at the next term, and call
// `shared.wake_all()` so the new candidate role gets polled.
Err(RoleHandlerError::RivalLeader(request)) => {
tracing::warn!(
term = %request.term,
other_leader = %Short(request.leader),
other_leader_log = %request.prev_log_position,
local_log = %shared.storage.last(),
group = %Short(shared.group_id()),
network = %Short(shared.network_id()),
"rival group leader detected",
);
*self = Candidate::<M>::new(self.term().next(), shared).into();
shared.update_leader(None);
shared.wake_all();
}
}
}
/// Steps down to a follower if `message` carries a term newer than ours.
///
/// Messages without a term are ignored. For a newer term, `self` becomes a
/// `Follower` at that term. The follower adopts the message's leader, if it
/// names one, and that leader (or `None`) is published via `update_leader`.
/// A message at the current term leaves the role unchanged.
///
/// # Panics
/// Panics if the message term is older than the local term. Callers must
/// filter stale messages first, as `receive_protocol_message` does.
fn maybe_step_down<S: Storage<M::Command>>(
&mut self,
message: &Message<M>,
shared: &Shared<S, M>,
) {
let Some(message_term) = message.term() else {
return;
};
assert!(message_term >= self.term());
if message_term > self.term() {
if let Some(leader) = message.leader() {
tracing::debug!(
leader = %Short(leader),
old_term = %self.term(),
new_term = %message_term,
group = %Short(shared.group_id()),
network = %Short(shared.network_id()),
"following",
);
} else {
tracing::debug!(
group = %Short(shared.group_id()),
network = %Short(shared.network_id()),
old_term = %self.term(),
new_term = %message_term,
"stepping down to follower",
);
}
*self = Follower::<M>::new::<S>(
message_term, message.leader(),
shared,
)
.into();
shared.update_leader(message.leader());
}
}
/// Answers a `RequestVote` message.
///
/// Returns `false` without doing anything for any other message type. For a
/// `RequestVote`, returns `true` once a response has been sent.
///
/// The response goes to `sender` for `request.term` and is:
/// - `Denied` if `shared.can_vote` rejects this term/candidate pair;
/// - `Denied` if the candidate's log is behind ours;
/// - `Abstained` if our log is behind the candidate's;
/// - otherwise the vote is recorded with `save_vote`. The response is then
///   `Granted`, or `Abstained` when the local state machine's leadership
///   preference is `Observer`.
///
/// A follower resets its election timeout on the abstain and grant paths.
/// Both denial paths return before that reset.
///
/// # Panics
/// Panics if `request.term` is older than the local term (callers filter
/// stale messages first), or if sending the response fails
/// (`expect("infallible serialization")`).
fn maybe_cast_vote<S: Storage<M::Command>>(
&mut self,
message: &Message<M>,
sender: PeerId,
shared: &mut Shared<S, M>,
) -> bool {
let Message::RequestVote(request) = message else {
return false;
};
assert!(request.term >= self.term());
let local_cursor = shared.storage.last();
tracing::debug!(
candidate = %Short(request.candidate),
term = %request.term,
candidate_log = %request.log_position,
local_log = %local_cursor,
group = %Short(shared.group_id()),
network = %Short(shared.network_id()),
"new leader elections started by",
);
// Cloned so `vote_with` does not borrow `shared`, which `save_vote`
// borrows mutably below.
let bonds = shared.group.bonds.clone();
let vote_with = |vote: Vote| {
bonds
.send_raft_to(
Message::RequestVoteResponse(RequestVoteResponse {
vote,
term: request.term,
}),
sender,
)
.expect("infallible serialization");
};
if !shared.can_vote(request.term, request.candidate) {
vote_with(Vote::Denied);
tracing::debug!(
candidate = %Short(request.candidate),
term = %request.term,
group = %Short(shared.group_id()),
network = %Short(shared.network_id()),
"denying vote, already voted in this term",
);
return true;
}
if request.log_position.is_behind(&local_cursor) {
vote_with(Vote::Denied);
tracing::debug!(
candidate = %Short(request.candidate),
term = %request.term,
our_log = %local_cursor,
candidate_log = %request.log_position,
group = %Short(shared.group_id()),
network = %Short(shared.network_id()),
"denying vote because our log is ahead",
);
return true;
}
if local_cursor.is_behind(&request.log_position) {
vote_with(Vote::Abstained);
tracing::debug!(
candidate = %Short(request.candidate),
term = %request.term,
candidate_log = %request.log_position,
local_log = %local_cursor,
group = %Short(shared.group_id()),
network = %Short(shared.network_id()),
"abstained from voting because we are behind their log",
);
} else {
// Neither log is behind the other (presumably equal positions).
// NOTE(review): eligibility above was checked for `request.candidate`,
// but the vote is recorded for (and sent to) `sender`. The two agree
// only when `sender == request.candidate` — confirm this is
// guaranteed upstream.
// The vote is saved even when an observer abstains, which presumably
// uses up this node's vote for `request.term`.
shared.save_vote(request.term, sender);
let is_observer = shared.state_machine.leadership_preference()
== LeadershipPreference::Observer;
if is_observer {
vote_with(Vote::Abstained);
} else {
vote_with(Vote::Granted);
}
tracing::debug!(
candidate = %Short(request.candidate),
term = %request.term,
candidate_log = %request.log_position,
local_log = %local_cursor,
group = %Short(shared.group_id()),
network = %Short(shared.network_id()),
"{}",
if is_observer {
"abstained from voting as observer"
} else {
"granting vote to candidate"
},
);
}
if let Self::Follower(follower) = self {
follower.reset_election_timeout(shared);
}
true
}
/// Routes `StateSync` messages to the state-sync provider.
///
/// Returns `Ok(())` when the provider consumed the message. Returns
/// `Err(message)` in two cases, and the caller keeps processing the
/// returned message in both:
/// - for any non-`StateSync` message, returned unchanged;
/// - for a `StateSync` payload the provider handed back through `Err`,
///   re-wrapped as `Message::StateSync`. Presumably this later surfaces
///   as `RoleHandlerError::Unexpected` in the role handler — unverified
///   here.
fn maybe_state_sync<S: Storage<M::Command>>(
message: Message<M>,
sender: PeerId,
shared: &mut Shared<S, M>,
) -> Result<(), Message<M>> {
let Message::StateSync(message) = message else {
return Err(message);
};
shared
.sync_provider_receive(message, sender)
.map_err(Message::StateSync)
}
}
impl<M: StateMachine> Role<M> {
pub const fn term(&self) -> Term {
match self {
Self::Follower(follower) => follower.term(),
Self::Candidate(candidate) => candidate.term(),
Self::Leader(leader) => leader.term(),
}
}
}
/// Non-success outcomes a role's `receive_protocol_message` handler reports
/// back to `Role::receive_protocol_message`, which reacts to each variant.
pub(super) enum RoleHandlerError<M: StateMachine> {
/// The current role does not handle this message. It is logged at trace
/// level and dropped.
Unexpected(Message<M>),
/// The role should yield to the sender of this `AppendEntries`. The node
/// becomes a follower of `request.leader` at `request.term`, and the
/// request is then re-delivered to that follower.
StepDown(AppendEntries<M::Command>),
/// Another node claims leadership. A warning is logged and the node becomes
/// a candidate for the next term with no known leader.
RivalLeader(AppendEntries<M::Command>),
}