use {
crate::{
PeerId,
groups::{
CommandError,
Index,
IndexRange,
LeadershipPreference,
QueryError,
QueryResultAt,
StateMachine,
StateSync,
StateSyncSession,
Storage,
Term,
raft::{
Message,
candidate::Candidate,
protocol::{AppendEntries, AppendEntriesResponse, Forward, Vote},
role::{Role, RoleHandlerError},
shared::Shared,
},
},
primitives::{
BoxPinFut,
Encoded,
InternalFutureExt,
Short,
UnboundedChannel,
},
},
core::{
future::ready,
marker::PhantomData,
ops::ControlFlow,
pin::Pin,
task::{Context, Poll},
},
std::{collections::HashMap, time::Instant},
tokio::{
sync::oneshot,
time::{Sleep, sleep, timeout},
},
};
/// Raft follower role: tracks the current leader, waits for AppendEntries,
/// and converts to a candidate when the election timeout fires.
pub struct Follower<M: StateMachine> {
// Term this follower is currently operating in.
term: Term,
// Peer currently believed to be leader; `None` until the first
// AppendEntries (or an explicit leader hint) arrives.
leader: Option<PeerId>,
// Cached preference from the state machine; Observers never stand for
// election, Reluctant nodes stretch their election timeout by a factor.
leadership_preference: LeadershipPreference,
// Pinned timer polled in `poll`; expiry triggers candidacy (or a rearm
// for observers).
election_timeout: Pin<Box<Sleep>>,
// Forwarded-command request id -> sender that resolves the caller's
// future with the leader-assigned index range.
pending_commands: HashMap<u64, oneshot::Sender<IndexRange>>,
// Forwarded-query request id -> sender that resolves the caller's future
// with the leader's query result.
pending_queries: HashMap<u64, oneshot::Sender<QueryResultAt<M>>>,
// Ids of timed-out forwarded commands, reported back by the returned
// futures so the pending map can be reaped lazily in `poll`.
expired_commands: UnboundedChannel<u64>,
// Same as above, for forwarded queries.
expired_queries: UnboundedChannel<u64>,
// Active state-sync session while this follower's log is behind the
// leader's; `None` when in sync.
catchup: Option<<M::StateSync as StateSync>::Session>,
#[doc(hidden)]
_marker: PhantomData<M>,
}
impl<M: StateMachine> Follower<M> {
    /// Builds a fresh follower for `term`, optionally already tracking a
    /// known `leader`.
    ///
    /// The initial election timeout is padded with the bootstrap delay on a
    /// zero term, then stretched by the reluctance factor when the state
    /// machine prefers not to lead. The node starts marked offline until the
    /// leader brings it back in sync.
    pub fn new<S: Storage<M::Command>>(
        term: Term,
        leader: Option<PeerId>,
        shared: &Shared<S, M>,
    ) -> Self {
        let leadership_preference = shared.state_machine.leadership_preference();

        // Base timeout, then bootstrap padding, then reluctance scaling —
        // applied in that order.
        let mut election_timeout = shared.config().consensus().election_timeout();
        if term.is_zero() {
            election_timeout += shared.config().consensus().bootstrap_delay;
        }
        match leadership_preference {
            LeadershipPreference::Reluctant { factor } => election_timeout *= factor,
            _ => {}
        }

        // Propagate a known leader before flagging ourselves offline.
        if leader.is_some() {
            shared.update_leader(leader);
        }
        shared.set_offline();

        Self {
            term,
            leader,
            leadership_preference,
            election_timeout: Box::pin(sleep(election_timeout)),
            pending_commands: HashMap::default(),
            pending_queries: HashMap::default(),
            expired_commands: UnboundedChannel::default(),
            expired_queries: UnboundedChannel::default(),
            catchup: None,
            _marker: PhantomData,
        }
    }

    /// The term this follower is operating in.
    pub const fn term(&self) -> Term {
        self.term
    }

    /// The peer currently believed to be leader, if any.
    pub const fn leader(&self) -> Option<PeerId> {
        self.leader
    }
}
impl<M: StateMachine> Follower<M> {
    /// Drives the follower's background work: reaps expired forwarded
    /// requests, advances an in-progress state-sync session, and checks the
    /// election timer.
    ///
    /// Returns `Break(Candidate)` when the election timeout elapses (unless
    /// this node is an `Observer`, which only rearms the timer); otherwise
    /// `Pending`. While a catch-up session is active the timer is not
    /// polled, so no election starts mid-sync.
    pub fn poll<S: Storage<M::Command>>(
        &mut self,
        cx: &mut Context<'_>,
        shared: &mut Shared<S, M>,
    ) -> Poll<ControlFlow<Role<M>>> {
        self.clean_expired_forwarded_commands(cx);
        if let Some(catchup) = self.catchup.as_mut() {
            match catchup.poll(cx, shared) {
                Poll::Ready(cursor) => {
                    // The session must land exactly on our local log tail,
                    // otherwise replication state has diverged.
                    assert_eq!(cursor, shared.storage.last());
                    self.catchup = None;
                    tracing::info!(
                        log_at = %cursor,
                        term = %self.term(),
                        group = %Short(shared.group_id()),
                        network = %Short(shared.network_id()),
                        "state sync complete, back online"
                    );
                    shared.update_log_pos(cursor);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
        if self.election_timeout.as_mut().poll(cx).is_ready() {
            if self.leadership_preference == LeadershipPreference::Observer {
                // Observers never stand for election; just rearm the timer.
                self.reset_election_timeout(shared);
            } else {
                return Poll::Ready(ControlFlow::Break(
                    Candidate::new(self.term.next(), shared).into(),
                ));
            }
        }
        Poll::Pending
    }

    /// Dispatches an inbound Raft `message` from `sender`.
    ///
    /// Handles leader appends, forwarded command/query replies, and
    /// state-sync traffic; any other variant is returned as
    /// [`RoleHandlerError::Unexpected`] for the caller to act on.
    pub fn receive_protocol_message<S: Storage<M::Command>>(
        &mut self,
        message: Message<M>,
        sender: PeerId,
        shared: &mut Shared<S, M>,
    ) -> Result<(), RoleHandlerError<M>> {
        match message {
            Message::AppendEntries(request) => {
                self.on_append_entries(request, sender, shared);
            }
            Message::Forward(Forward::CommandAck {
                request_id,
                assigned: assigned_indices,
            }) => {
                self.on_forward_command_response(request_id, assigned_indices);
            }
            Message::Forward(Forward::QueryResponse {
                request_id,
                result,
                position,
            }) => {
                self.on_forward_query_response(request_id, result.0, position);
            }
            Message::StateSync(message) => {
                // Sync traffic is only meaningful while a catch-up session is
                // active; stale chunks are silently dropped.
                if let Some(catchup) = self.catchup.as_mut() {
                    catchup.receive(message, sender, shared);
                }
            }
            message => {
                return Err(RoleHandlerError::<M>::Unexpected(message));
            }
        }
        Ok(())
    }

    /// Rearms the election timer from now, scaling the configured timeout by
    /// the reluctance factor when this node prefers not to lead.
    pub fn reset_election_timeout<S: Storage<M::Command>>(
        &mut self,
        shared: &Shared<S, M>,
    ) {
        let mut timeout = shared.consensus().election_timeout();
        if let LeadershipPreference::Reluctant { factor } =
            self.leadership_preference
        {
            timeout *= factor;
        }
        let next_election_timeout = Instant::now() + timeout;
        self.election_timeout
            .as_mut()
            .reset(next_election_timeout.into());
    }

    /// Forwards client `commands` to the current leader; the returned future
    /// resolves to the index range the leader assigned.
    ///
    /// Fails fast with [`CommandError::Offline`] when no leader is known and
    /// with [`CommandError::Encoding`] when the message cannot be sent. The
    /// future times out after the configured forward timeout, in which case
    /// the pending entry is reaped lazily via the expiry channel.
    pub fn forward_commands<S: Storage<M::Command>>(
        &mut self,
        commands: Vec<M::Command>,
        shared: &Shared<S, M>,
    ) -> BoxPinFut<Result<IndexRange, CommandError<M>>> {
        let Some(leader) = self.leader() else {
            return ready(Err(CommandError::Offline(commands))).pin();
        };
        // Draw a request id not already in flight.
        let request_id: u64 = loop {
            let id = rand::random();
            if self.pending_commands.contains_key(&id) {
                continue;
            }
            break id;
        };
        let message = Message::Forward(Forward::Command {
            commands: commands.iter().cloned().map(Encoded).collect(),
            request_id: Some(request_id),
        });
        let (forward_ack_tx, forward_ack_rx) = oneshot::channel();
        self.pending_commands.insert(request_id, forward_ack_tx);
        if let Err(e) = shared.bonds().send_raft_to(message, leader) {
            // Fix: unregister on send failure. The timeout future below is
            // never returned on this path, so without this removal the entry
            // would linger in `pending_commands` forever.
            self.pending_commands.remove(&request_id);
            return ready(Err(CommandError::Encoding(commands, e))).pin();
        }
        let expired_sender = self.expired_commands.sender().clone();
        let forward_timeout = shared.consensus().forward_timeout;
        async move {
            if let Ok(Ok(assigned)) = timeout(forward_timeout, forward_ack_rx).await {
                Ok(assigned)
            } else {
                // Timed out (or the sender was dropped): queue the id for
                // lazy removal and report offline to the caller.
                expired_sender.send(request_id).ok();
                Err(CommandError::Offline(commands))
            }
        }
        .pin()
    }

    /// Forwards a read `query` to the current leader; the returned future
    /// resolves to the leader's result and the log position it was read at.
    ///
    /// Error and timeout behavior mirror [`Self::forward_commands`].
    pub fn forward_query<S: Storage<M::Command>>(
        &mut self,
        query: M::Query,
        shared: &Shared<S, M>,
    ) -> BoxPinFut<Result<QueryResultAt<M>, QueryError<M>>> {
        let Some(leader) = self.leader() else {
            return ready(Err(QueryError::Offline(query))).pin();
        };
        // Draw a request id not already in flight.
        let request_id: u64 = loop {
            let id = rand::random();
            if self.pending_queries.contains_key(&id) {
                continue;
            }
            break id;
        };
        let message = Message::Forward(Forward::Query {
            query: Encoded(query.clone()),
            request_id,
        });
        let (response_tx, response_rx) = oneshot::channel();
        self.pending_queries.insert(request_id, response_tx);
        if let Err(e) = shared.bonds().send_raft_to(message, leader) {
            // Fix: unregister on send failure; see `forward_commands` — the
            // reaping future is never awaited on this path.
            self.pending_queries.remove(&request_id);
            return ready(Err(QueryError::Encoding(query, e))).pin();
        }
        let expired_sender = self.expired_queries.sender().clone();
        let query_timeout = shared.consensus().query_timeout;
        async move {
            if let Ok(Ok(response)) = timeout(query_timeout, response_rx).await {
                Ok(response)
            } else {
                expired_sender.send(request_id).ok();
                Err(QueryError::Offline(query))
            }
        }
        .pin()
    }

    /// Handles a leader `AppendEntries`: adopts the sender's leader, rearms
    /// the election timer, runs the Raft log-consistency check at
    /// `prev_log_position`, and either applies entries in-line or falls back
    /// to a state-sync session.
    fn on_append_entries<S: Storage<M::Command>>(
        &mut self,
        request: AppendEntries<M::Command>,
        sender: PeerId,
        shared: &mut Shared<S, M>,
    ) {
        self.leader = Some(request.leader);
        shared.update_leader(Some(request.leader));
        self.reset_election_timeout(shared);
        // Consistent iff our entry at prev_log_position carries the leader's
        // term for that position.
        let consistent =
            match shared.storage.term_at(request.prev_log_position.index()) {
                Some(local_term) if local_term == request.prev_log_position.term() => {
                    if let Some(first) = request.entries.first() {
                        let next_index = request.prev_log_position.index().next();
                        // A conflicting term right after the match point means
                        // our suffix diverged; drop it before appending.
                        if let Some(existing_term) = shared.storage.term_at(next_index)
                            && existing_term != first.term
                        {
                            shared.storage.truncate(next_index);
                        }
                    }
                    true
                }
                Some(_) => {
                    // Term mismatch at the match point: truncate the
                    // conflicting suffix; a sync session will refill it.
                    shared.storage.truncate(request.prev_log_position.index());
                    false
                }
                // We do not even have prev_log_position — we are behind.
                None => false,
            };
        if consistent {
            shared.set_online();
            self.accept_in_sync_entries(request, sender, shared);
        } else {
            shared.set_offline();
            let entries = request
                .entries
                .into_iter()
                .map(|entry| (entry.command.0, entry.term))
                .collect();
            if let Some(catchup) = self.catchup.as_mut() {
                // Already syncing: buffer the live entries so the session can
                // splice them in once it catches up.
                catchup.buffer(request.prev_log_position, entries, shared);
            } else {
                tracing::info!(
                    leader_pos = %request.prev_log_position,
                    local_pos = %shared.storage.last(),
                    term = %self.term(),
                    group = %Short(shared.group_id()),
                    network = %Short(shared.network_id()),
                    "starting state sync"
                );
                self.catchup = Some(shared.create_sync_session(
                    request.prev_log_position,
                    request.leader_commit,
                    entries,
                ));
            }
        }
    }

    /// Resolves a pending forwarded-command future with the index range the
    /// leader assigned; unknown/expired ids are ignored.
    fn on_forward_command_response(
        &mut self,
        request_id: u64,
        assigned: IndexRange,
    ) {
        if let Some(ack) = self.pending_commands.remove(&request_id) {
            // The receiver may already have timed out and been dropped.
            let _ = ack.send(assigned);
        }
    }

    /// Resolves a pending forwarded-query future with the leader's result
    /// and the position it was read at; unknown/expired ids are ignored.
    fn on_forward_query_response(
        &mut self,
        request_id: u64,
        result: M::QueryResult,
        at_position: Index,
    ) {
        if let Some(response) = self.pending_queries.remove(&request_id) {
            // The receiver may already have timed out and been dropped.
            let _ = response.send(QueryResultAt {
                result,
                at_position,
            });
        }
    }

    /// Drops pending entries for forwarded commands *and* queries whose
    /// reply futures timed out (their ids arrive over the expiry channels).
    fn clean_expired_forwarded_commands(&mut self, cx: &mut Context<'_>) {
        if !self.expired_commands.is_empty() {
            let count = self.expired_commands.len();
            let mut ids = Vec::with_capacity(count);
            if self
                .expired_commands
                .poll_recv_many(cx, &mut ids, count)
                .is_ready()
            {
                for id in ids {
                    self.pending_commands.remove(&id);
                }
            }
        }
        if !self.expired_queries.is_empty() {
            let count = self.expired_queries.len();
            let mut ids = Vec::with_capacity(count);
            if self
                .expired_queries
                .poll_recv_many(cx, &mut ids, count)
                .is_ready()
            {
                for id in ids {
                    self.pending_queries.remove(&id);
                }
            }
        }
    }

    /// Applies a consistency-checked `AppendEntries`: appends new entries
    /// idempotently, acknowledges the leader, and advances the commit index
    /// up to `min(leader_commit, local tail)`.
    fn accept_in_sync_entries<S: Storage<M::Command>>(
        &self,
        request: AppendEntries<M::Command>,
        sender: PeerId,
        shared: &mut Shared<S, M>,
    ) {
        let mut appended_count = 0;
        if !request.entries.is_empty() {
            let start_index = request.prev_log_position.index().next();
            for (i, entry) in request.entries.into_iter().enumerate() {
                let index = start_index + i;
                // Idempotency: skip entries we already hold with the same
                // term (duplicate or overlapping AppendEntries).
                if shared.storage.term_at(index) == Some(entry.term) {
                    continue;
                }
                shared.storage.append(entry.command.0, entry.term);
                appended_count += 1;
            }
            let local_position = shared.storage.last();
            // Observers acknowledge replication but abstain from voting.
            let vote = if self.leadership_preference == LeadershipPreference::Observer
            {
                Vote::Abstained
            } else {
                Vote::Granted
            };
            // NOTE(review): empty-entry heartbeats take no branch here and so
            // produce no AppendEntriesResponse — confirm the leader tracks
            // follower progress through another channel.
            // NOTE(review): forwarding paths surface send errors to callers,
            // while this path treats them as impossible — confirm intended.
            shared
                .bonds()
                .send_raft_to(
                    Message::AppendEntriesResponse(AppendEntriesResponse {
                        term: self.term(),
                        vote,
                        last_log_index: local_position.index(),
                    }),
                    sender,
                )
                .expect("infallible serialization");
        }
        let local_log_pos = shared.storage.last();
        let prev_committed = shared.committed();
        // Never commit past what we actually have locally.
        let leader_committed = request.leader_commit.min(local_log_pos.index());
        let mut new_committed = prev_committed;
        if prev_committed < leader_committed {
            new_committed = shared.commit_up_to(leader_committed);
            if prev_committed < new_committed {
                shared.update_committed(new_committed);
                shared.prune_safe_prefix();
            }
        }
        if prev_committed != new_committed || appended_count > 0 {
            tracing::trace!(
                committed_ix = %new_committed,
                new_entries = appended_count,
                local_log = %shared.storage.last(),
                term = %self.term(),
                group = %Short(shared.group_id()),
                network = %Short(shared.network_id()),
            );
        }
    }
}