use tokio::sync::oneshot;
use crate::config::SnapshotPolicy;
use crate::core::{ConsensusState, LeaderState, ReplicationState, SnapshotState, State, UpdateCurrentLeader};
use crate::error::RaftResult;
use crate::replication::{RaftEvent, ReplicaEvent, ReplicationStream};
use crate::storage::CurrentSnapshotData;
use crate::{AppData, AppDataResponse, NodeId, RaftNetwork, RaftStorage};
impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> LeaderState<'a, D, R, N, S> {
#[tracing::instrument(level = "trace", skip(self))]
pub(super) fn spawn_replication_stream(&self, target: NodeId) -> ReplicationState<D> {
let replstream = ReplicationStream::new(
self.core.id,
target,
self.core.current_term,
self.core.config.clone(),
self.core.last_log_index,
self.core.last_log_term,
self.core.commit_index,
self.core.network.clone(),
self.core.storage.clone(),
self.replicationtx.clone(),
);
ReplicationState {
match_index: self.core.last_log_index,
match_term: self.core.current_term,
is_at_line_rate: false,
replstream,
remove_after_commit: None,
}
}
#[tracing::instrument(level = "trace", skip(self, event))]
pub(super) async fn handle_replica_event(&mut self, event: ReplicaEvent<S::Snapshot>) {
let res = match event {
ReplicaEvent::RateUpdate { target, is_line_rate } => self.handle_rate_update(target, is_line_rate).await,
ReplicaEvent::RevertToFollower { target, term } => self.handle_revert_to_follower(target, term).await,
ReplicaEvent::UpdateMatchIndex {
target,
match_index,
match_term,
} => self.handle_update_match_index(target, match_index, match_term).await,
ReplicaEvent::NeedsSnapshot { target, tx } => self.handle_needs_snapshot(target, tx).await,
ReplicaEvent::Shutdown => {
self.core.set_target_state(State::Shutdown);
return;
}
};
if let Err(err) = res {
tracing::error!({error=%err}, "error while processing event from replication stream");
}
}
#[tracing::instrument(level = "trace", skip(self, target, is_line_rate))]
async fn handle_rate_update(&mut self, target: NodeId, is_line_rate: bool) -> RaftResult<()> {
if let Some(state) = self.nodes.get_mut(&target) {
state.is_at_line_rate = is_line_rate;
return Ok(());
}
if let Some(state) = self.non_voters.get_mut(&target) {
state.state.is_at_line_rate = is_line_rate;
state.is_ready_to_join = is_line_rate;
if state.is_ready_to_join {
if let Some(tx) = state.tx.take() {
let _ = tx.send(Ok(()));
}
match std::mem::replace(&mut self.consensus_state, ConsensusState::Uniform) {
ConsensusState::NonVoterSync { mut awaiting, members, tx } => {
awaiting.remove(&target);
if awaiting.is_empty() {
self.consensus_state = ConsensusState::Uniform;
self.change_membership(members, tx).await;
} else {
self.consensus_state = ConsensusState::NonVoterSync { awaiting, members, tx };
}
}
other => self.consensus_state = other, }
}
}
Ok(())
}
#[tracing::instrument(level = "trace", skip(self, term))]
async fn handle_revert_to_follower(&mut self, _: NodeId, term: u64) -> RaftResult<()> {
if term > self.core.current_term {
self.core.update_current_term(term, None);
self.core.save_hard_state().await?;
self.core.update_current_leader(UpdateCurrentLeader::Unknown);
self.core.set_target_state(State::Follower);
}
Ok(())
}
#[tracing::instrument(level = "trace", skip(self, target, match_index))]
async fn handle_update_match_index(&mut self, target: NodeId, match_index: u64, match_term: u64) -> RaftResult<()> {
if let Some(state) = self.non_voters.get_mut(&target) {
state.state.match_index = match_index;
state.state.match_term = match_term;
return Ok(());
}
let mut needs_removal = false;
match self.nodes.get_mut(&target) {
Some(state) => {
state.match_index = match_index;
state.match_term = match_term;
if let Some(threshold) = &state.remove_after_commit {
if &match_index >= threshold {
needs_removal = true;
}
}
}
_ => return Ok(()), }
if needs_removal {
if let Some(node) = self.nodes.remove(&target) {
let _ = node.replstream.repltx.send(RaftEvent::Terminate);
}
}
let mut indices_c0 = self
.nodes
.iter()
.filter(|(id, _)| self.core.membership.members.contains(id))
.map(|(_, node)| node.match_index)
.collect::<Vec<_>>();
if !self.is_stepping_down {
indices_c0.push(self.core.last_log_index);
}
let commit_index_c0 = calculate_new_commit_index(indices_c0, self.core.commit_index);
let mut commit_index_c1 = commit_index_c0; if let Some(members) = &self.core.membership.members_after_consensus {
let indices_c1 = self
.nodes
.iter()
.filter(|(id, _)| members.contains(id))
.map(|(_, node)| node.match_index)
.collect();
commit_index_c1 = calculate_new_commit_index(indices_c1, self.core.commit_index);
}
let has_new_commit_index = commit_index_c0 > self.core.commit_index && commit_index_c1 > self.core.commit_index;
if has_new_commit_index {
self.core.commit_index = std::cmp::min(commit_index_c0, commit_index_c1);
for node in self.nodes.values() {
let _ = node.replstream.repltx.send(RaftEvent::UpdateCommitIndex {
commit_index: self.core.commit_index,
});
}
for node in self.non_voters.values() {
let _ = node.state.replstream.repltx.send(RaftEvent::UpdateCommitIndex {
commit_index: self.core.commit_index,
});
}
let filter = self
.awaiting_committed
.iter()
.enumerate()
.take_while(|(_idx, elem)| elem.entry.index <= self.core.commit_index)
.last()
.map(|(idx, _)| idx);
if let Some(offset) = filter {
for request in self.awaiting_committed.drain(..=offset).collect::<Vec<_>>() {
self.client_request_post_commit(request).await;
}
}
self.core.report_metrics();
}
Ok(())
}
#[tracing::instrument(level = "trace", skip(self, tx))]
async fn handle_needs_snapshot(&mut self, _: NodeId, tx: oneshot::Sender<CurrentSnapshotData<S::Snapshot>>) -> RaftResult<()> {
let threshold = match &self.core.config.snapshot_policy {
SnapshotPolicy::LogsSinceLast(threshold) => *threshold,
};
let current_snapshot_opt = self
.core
.storage
.get_current_snapshot()
.await
.map_err(|err| self.core.map_fatal_storage_error(err))?;
if let Some(snapshot) = current_snapshot_opt {
if snapshot_is_within_half_of_threshold(&snapshot.index, &self.core.last_log_index, &threshold) {
let _ = tx.send(snapshot);
return Ok(());
}
}
if let Some(SnapshotState::Snapshotting { handle, sender }) = self.core.snapshot_state.take() {
let mut chan = sender.subscribe();
tokio::spawn(async move {
let _ = chan.recv().await;
drop(tx);
});
self.core.snapshot_state = Some(SnapshotState::Snapshotting { handle, sender });
return Ok(());
}
self.core.trigger_log_compaction_if_needed();
Ok(())
}
}
/// Determine the highest log index replicated on a majority of the given nodes.
///
/// `entries` holds the match index of each node that counts toward the quorum. After sorting
/// ascending, the element at position `(len - 1) / 2` is the largest value that a majority of
/// the entries are greater than or equal to. The result never moves backwards: if that value is
/// below `current_commit`, `current_commit` is returned instead. An empty `entries` also yields
/// `current_commit`.
fn calculate_new_commit_index(mut entries: Vec<u64>, current_commit: u64) -> u64 {
    if entries.is_empty() {
        return current_commit;
    }
    entries.sort_unstable();
    // `(len - 1) / 2` equals `len / 2 - 1` for even lengths and `len / 2` for odd lengths: the
    // entries at or above this position always form a strict majority. Non-empty, so in bounds.
    let majority_index = entries[(entries.len() - 1) / 2];
    majority_index.max(current_commit)
}
/// Report whether a snapshot is recent enough to be reused.
///
/// Returns `true` when `last_log_index` is no more than `threshold / 2` (integer division)
/// entries ahead of `snapshot_last_index`. A snapshot at or beyond the log's last index is
/// treated as zero entries behind, so the distance computation cannot underflow.
fn snapshot_is_within_half_of_threshold(snapshot_last_index: &u64, last_log_index: &u64, threshold: &u64) -> bool {
    let lag = last_log_index.saturating_sub(*snapshot_last_index);
    lag <= *threshold / 2
}
#[cfg(test)]
mod tests {
    use super::*;

    mod snapshot_is_within_half_of_threshold {
        use super::*;

        // Threshold 500 (half = 250); the snapshot trails the log by 50 entries.
        #[test]
        fn happy_path_true_when_within_half_threshold() {
            assert!(snapshot_is_within_half_of_threshold(&50, &100, &500));
        }

        // Threshold 100 (half = 50); the snapshot trails the log by 499 entries.
        #[test]
        fn happy_path_false_when_above_half_threshold() {
            assert!(!snapshot_is_within_half_of_threshold(&1, &500, &100));
        }

        // A snapshot ahead of the log must not underflow the distance computation.
        #[test]
        fn guards_against_underflow() {
            assert!(snapshot_is_within_half_of_threshold(&200, &100, &500));
        }
    }

    mod calculate_new_commit_index {
        use super::*;

        /// Assert the computed commit index, printing the sorted inputs on failure.
        fn check(expected: u64, current: u64, entries: Vec<u64>) {
            let output = calculate_new_commit_index(entries.clone(), current);
            let mut sorted = entries;
            sorted.sort_unstable();
            assert_eq!(output, expected, "Sorted values: {:?}", sorted);
        }

        #[test]
        fn basic_values() {
            check(10, 5, vec![20, 5, 0, 15, 10]);
        }

        #[test]
        fn len_zero_should_return_current_commit() {
            check(20, 20, vec![]);
        }

        #[test]
        fn len_one_where_greater_than_current() {
            check(100, 0, vec![100]);
        }

        #[test]
        fn len_one_where_less_than_current() {
            check(100, 100, vec![50]);
        }

        #[test]
        fn even_number_of_nodes() {
            check(0, 0, vec![0, 100, 0, 100, 0, 100]);
        }

        #[test]
        fn majority_wins() {
            check(100, 0, vec![0, 100, 0, 100, 0, 100, 100]);
        }
    }
}