use std::sync::Arc;
use anyhow::anyhow;
use futures::stream::FuturesUnordered;
use futures::future::TryFutureExt;
use tokio::stream::StreamExt;
use tokio::sync::oneshot;
use tokio::time::{Duration, timeout};
use crate::{AppData, AppDataResponse, RaftNetwork, RaftStorage};
use crate::core::{LeaderState, State};
use crate::error::{ClientReadError, ClientWriteError, RaftError, RaftResult};
use crate::raft::{ClientWriteRequest, ClientWriteResponse, ClientReadResponseTx, ClientWriteResponseTx, Entry, EntryPayload};
use crate::raft::{AppendEntriesRequest};
use crate::replication::RaftEvent;
/// A log entry paired with the channel on which the outcome of that entry is reported.
pub(super) struct ClientRequestEntry<D: AppData, R: AppDataResponse> {
/// The log entry of the request, held behind an `Arc` so that it can be shared without copying.
pub entry: Arc<Entry<D>>,
/// The response channel for the request, either a client channel or an internal one.
pub tx: ClientOrInternalResponseTx<D, R>,
}
impl<D: AppData, R: AppDataResponse> ClientRequestEntry<D, R> {
    /// Build a new instance from an owned log entry and a response channel.
    ///
    /// The entry is wrapped in an `Arc`, and `tx` may be any type which converts into a
    /// `ClientOrInternalResponseTx`.
    pub(crate) fn from_entry<T: Into<ClientOrInternalResponseTx<D, R>>>(entry: Entry<D>, tx: T) -> Self {
        let entry = Arc::new(entry);
        let tx = tx.into();
        Self { entry, tx }
    }
}
/// The response channel of a request being replicated: either a client's or an internal one.
///
/// `From` is derived for each variant's inner type.
#[derive(derive_more::From)]
pub enum ClientOrInternalResponseTx<D: AppData, R: AppDataResponse> {
/// Channel used to answer a client write request.
Client(ClientWriteResponseTx<D, R>),
/// Channel used for internally generated entries; it receives the index of the entry on success.
Internal(oneshot::Sender<Result<u64, RaftError>>),
}
impl<'a, D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> LeaderState<'a, D, R, N, S> {
/// Append and replicate the first entry of this leader's term.
///
/// When the log is still empty (`last_log_index == 0`) the entry carries the current membership
/// config, otherwise it is a blank payload. If the log holds entries beyond the commit index,
/// the most recent config change or snapshot pointer among them decides whether the commit
/// notification of the new entry is registered as a joint consensus callback or as a uniform
/// consensus callback. Without such an entry no callback is registered.
///
/// Returns an error if reading from or appending to the log fails.
#[tracing::instrument(level="trace", skip(self))]
pub(super) async fn commit_initial_leader_entry(&mut self) -> RaftResult<()> {
    let req: ClientWriteRequest<D> = match self.core.last_log_index {
        0 => ClientWriteRequest::new_config(self.core.membership.clone()),
        _ => ClientWriteRequest::new_blank_payload(),
    };

    // `Some(flag)` when an uncommitted config is pending; the flag tells whether that config
    // is in joint consensus.
    let pending_config = if self.core.last_log_index > self.core.commit_index {
        let first_uncommitted = self.core.commit_index + 1;
        let stop = self.core.last_log_index + 1;
        let uncommitted = self.core.storage.get_log_entries(first_uncommitted, stop).await
            .map_err(|err| self.core.map_fatal_storage_error(err))?;
        // Walk backwards so that the most recent config wins.
        uncommitted.iter().rev().find_map(|entry| match &entry.payload {
            EntryPayload::ConfigChange(cfg) => Some(cfg.membership.is_in_joint_consensus()),
            EntryPayload::SnapshotPointer(cfg) => Some(cfg.membership.is_in_joint_consensus()),
            _ => None,
        })
    } else {
        None
    };

    let (tx_payload_committed, rx_payload_committed) = oneshot::channel();
    let entry = self.append_payload_to_log(req.entry).await?;
    // The first entry of the term is now in the log, so the last log term is the current term.
    self.core.last_log_term = self.core.current_term;
    let cr_entry = ClientRequestEntry::from_entry(entry, tx_payload_committed);
    self.replicate_client_request(cr_entry).await;
    self.core.report_metrics();

    match pending_config {
        Some(true) => self.joint_consensus_cb.push(rx_payload_committed),
        Some(false) => self.uniform_consensus_cb.push(rx_payload_committed),
        None => {}
    }
    Ok(())
}
/// Handle a client read request by confirming that this node is still the cluster leader.
///
/// An empty `AppendEntries` RPC is sent to every node in `self.nodes`, each bounded by the
/// configured heartbeat interval. `Ok(())` is sent on `tx` as soon as a strict majority of
/// `membership.members` and, if present, a strict majority of
/// `membership.members_after_consensus` have confirmed leadership. This node counts as a
/// confirmation for every config it is a member of, unless it is stepping down.
///
/// An error is sent on `tx` when:
/// - a peer responds with a term different from the current term, in which case this node also
///   updates its term and reverts to follower;
/// - all RPCs have completed without the required majorities being reached.
#[tracing::instrument(level="trace", skip(self, tx))]
pub(super) async fn handle_client_read_request(&mut self, tx: ClientReadResponseTx) {
    // Number of confirmations which form a strict majority of a config with `len` members.
    let majority_of = |len: usize| (len / 2) + 1;

    let c0_needed = majority_of(self.core.membership.members.len());
    // No second config means no second majority is required.
    let c1_needed = self.core.membership.members_after_consensus.as_ref()
        .map(|members| majority_of(members.len()))
        .unwrap_or(0);
    let mut c0_confirmed = 0usize;
    let mut c1_confirmed = 0usize;

    // Count this node's own confirmation, as long as it is not about to step down.
    if !self.is_stepping_down {
        if self.core.membership.members.contains(&self.core.id) {
            c0_confirmed += 1;
        }
        if self.core.membership.members_after_consensus.as_ref().map(|members| members.contains(&self.core.id)).unwrap_or(false) {
            c1_confirmed += 1;
        }
    }

    // A single node cluster already holds its majority; there is nobody to ask.
    if c0_confirmed >= c0_needed && c1_confirmed >= c1_needed {
        let _ = tx.send(Ok(()));
        return;
    }

    // Send the confirmation RPCs in parallel.
    let ttl = Duration::from_millis(self.core.config.heartbeat_interval);
    let mut pending = FuturesUnordered::new();
    for (id, node) in self.nodes.iter() {
        let rpc = AppendEntriesRequest{
            term: self.core.current_term,
            leader_id: self.core.id,
            prev_log_index: node.match_index,
            prev_log_term: node.match_term,
            entries: vec![],
            leader_commit: self.core.commit_index,
        };
        let target = *id;
        let network = self.core.network.clone();
        let task = tokio::spawn(async move {
            match timeout(ttl, network.append_entries(target, rpc)).await {
                Ok(Ok(data)) => Ok((target, data)),
                Ok(Err(err)) => Err((target, err)),
                Err(_timeout) => Err((target, anyhow!("timeout waiting for leadership confirmation"))),
            }
        }).map_err(move |err| (target, err));
        pending.push(task);
    }

    // Process the responses in the order in which they complete.
    while let Some(res) = pending.next().await {
        let (target, data) = match res {
            Ok(Ok(res)) => res,
            Ok(Err((target, err))) => {
                tracing::error!({target, error=%err}, "timeout while confirming leadership for read request");
                continue;
            }
            Err((target, err)) => {
                tracing::error!({target}, "{}", err);
                continue;
            }
        };

        // A response from another term means that leadership can not be confirmed. Step down
        // and fail the request instead of counting the response as a confirmation.
        if data.term != self.core.current_term {
            self.core.update_current_term(data.term, None);
            self.core.set_target_state(State::Follower);
            let _ = tx.send(Err(ClientReadError::RaftError(
                RaftError::RaftNetwork(anyhow!("leadership lost while confirming leadership for read request"))
            )));
            return;
        }

        if self.core.membership.members.contains(&target) {
            c0_confirmed += 1;
        }
        if self.core.membership.members_after_consensus.as_ref().map(|members| members.contains(&target)).unwrap_or(false) {
            c1_confirmed += 1;
        }
        if c0_confirmed >= c0_needed && c1_confirmed >= c1_needed {
            let _ = tx.send(Ok(()));
            return;
        }
    }

    // Every RPC has completed and the majorities were not reached.
    let _ = tx.send(Err(ClientReadError::RaftError(
        RaftError::RaftNetwork(anyhow!("too many requests failed, could not confirm leadership"))
    )));
}
/// Handle a client write request.
///
/// The payload of the request is appended to the log and then handed over for replication.
/// If appending fails, the error is sent on `tx` and nothing is replicated.
#[tracing::instrument(level="trace", skip(self, rpc, tx))]
pub(super) async fn handle_client_write_request(&mut self, rpc: ClientWriteRequest<D>, tx: ClientWriteResponseTx<D, R>) {
    let appended = self.append_payload_to_log(rpc.entry).await;
    let request = match appended {
        Err(err) => {
            // The receiver may already be gone, which is not an error for us.
            let _ = tx.send(Err(ClientWriteError::RaftError(err)));
            return;
        }
        Ok(entry) => ClientRequestEntry::from_entry(entry, tx),
    };
    self.replicate_client_request(request).await;
}
/// Append the given payload to the log as a new entry of the current term.
///
/// The new entry gets the index `last_log_index + 1`. After a successful write to storage,
/// `last_log_index` is advanced to that index and the entry is returned. A storage failure is
/// mapped through `map_fatal_storage_error` and returned as the error.
#[tracing::instrument(level="trace", skip(self, payload))]
pub(super) async fn append_payload_to_log(&mut self, payload: EntryPayload<D>) -> RaftResult<Entry<D>> {
    let next_index = self.core.last_log_index + 1;
    let entry = Entry {
        index: next_index,
        term: self.core.current_term,
        payload,
    };
    let written = self.core.storage.append_entry_to_log(&entry).await;
    written.map_err(|err| self.core.map_fatal_storage_error(err))?;
    self.core.last_log_index = next_index;
    Ok(entry)
}
#[tracing::instrument(level="trace", skip(self, req))]
pub(super) async fn replicate_client_request(&mut self, req: ClientRequestEntry<D, R>) {
let entry_arc = req.entry.clone();
if !self.nodes.is_empty() {
self.awaiting_committed.push(req);
for node in self.nodes.values() {
let _ = node.replstream.repltx.send(RaftEvent::Replicate{
entry: entry_arc.clone(),
commit_index: self.core.commit_index,
});
}
} else {
self.core.commit_index = entry_arc.index;
self.core.report_metrics();
self.client_request_post_commit(req).await;
}
if !self.non_voters.is_empty() {
for node in self.non_voters.values() {
let _ = node.state.replstream.repltx.send(RaftEvent::Replicate{
entry: entry_arc.clone(),
commit_index: self.core.commit_index,
});
}
}
}
#[tracing::instrument(level="trace", skip(self, req))]
pub(super) async fn client_request_post_commit(&mut self, req: ClientRequestEntry<D, R>) {
match req.tx {
ClientOrInternalResponseTx::Client(tx) => match &req.entry.payload {
EntryPayload::Normal(inner) => {
match self.apply_entry_to_state_machine(&req.entry.index, &inner.data).await {
Ok(data) => {
let _ = tx.send(Ok(ClientWriteResponse{index: req.entry.index, data}));
}
Err(err) => {
let _ = tx.send(Err(ClientWriteError::RaftError(RaftError::from(err))));
}
}
}
_ => {
tracing::error!("critical error in Raft, this is a programming bug, please open an issue");
self.core.set_target_state(State::Shutdown);
}
}
ClientOrInternalResponseTx::Internal(tx) => {
self.core.last_applied = req.entry.index;
self.core.report_metrics();
let _ = tx.send(Ok(req.entry.index));
}
}
self.core.trigger_log_compaction_if_needed();
}
/// Apply the entry at `index` to the state machine and return the state machine's response.
///
/// If `index` is not the entry directly after `last_applied`, the entries in between are
/// fetched from the log first. `last_applied` is moved to the last of them and those with a
/// normal payload are replicated to the state machine. After that the given entry is applied,
/// `last_applied` is set to `index` and metrics are reported.
///
/// Any storage failure is mapped through `map_fatal_storage_error` and returned as the error.
#[tracing::instrument(level="trace", skip(self, entry))]
pub(super) async fn apply_entry_to_state_machine(&mut self, index: &u64, entry: &D) -> RaftResult<R> {
    let next_expected = self.core.last_applied + 1;
    if *index != next_expected {
        // There is a gap of entries which have not been applied yet; catch up on those first.
        let backlog = self.core.storage.get_log_entries(next_expected, *index).await
            .map_err(|err| self.core.map_fatal_storage_error(err))?;
        if let Some(newest) = backlog.last() {
            self.core.last_applied = newest.index;
        }

        // Only entries with a normal payload carry data for the state machine.
        let mut data_entries = Vec::new();
        for item in backlog.iter() {
            if let EntryPayload::Normal(inner) = &item.payload {
                data_entries.push((&item.index, &inner.data));
            }
        }
        if !data_entries.is_empty() {
            self.core.storage.replicate_to_state_machine(&data_entries).await
                .map_err(|err| self.core.map_fatal_storage_error(err))?;
        }
    }

    let applied = self.core.storage.apply_entry_to_state_machine(index, entry).await;
    let res = applied.map_err(|err| self.core.map_fatal_storage_error(err))?;
    self.core.last_applied = *index;
    self.core.report_metrics();
    Ok(res)
}
}