use std::collections::BTreeSet;
use std::collections::HashMap;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
use tokio::sync::watch;
use tokio::time::Duration;
use tokio::time::Instant;
use crate::core::State;
use crate::error::Fatal;
use crate::raft_types::LogIdOptionExt;
use crate::EffectiveMembership;
use crate::LogId;
use crate::Membership;
use crate::MessageSummary;
use crate::NodeId;
use crate::ReplicationMetrics;
/// A snapshot of the observable state of one Raft node.
///
/// Values of this type are read through a `tokio::sync::watch` channel by [`Wait`],
/// which repeatedly tests predicates against the latest value.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct RaftMetrics {
/// `Ok(())` initially (see `new_initial`); an `Err` carries a `Fatal` error —
/// presumably set when the node has stopped because of that error (confirm with the writer).
pub running_state: Result<(), Fatal>,
/// Id of the node these metrics describe.
pub id: NodeId,
/// Current role of the node; `State::Follower` initially.
pub state: State,
/// Current term; `0` initially.
pub current_term: u64,
/// Index of the last log entry; `None` initially (presumably: no log entry yet).
pub last_log_index: Option<u64>,
/// Id of the last applied log entry; `None` initially (presumably: nothing applied yet).
pub last_applied: Option<LogId>,
/// The leader this node currently knows of, if any.
pub current_leader: Option<NodeId>,
/// Effective membership. `Wait::members` reads config 0 and `Wait::next_members`
/// reads config 1 of the contained `Membership`.
pub membership_config: EffectiveMembership,
/// Log id associated with the latest snapshot, if any — presumably the last log id it covers.
pub snapshot: Option<LogId>,
/// Per-node replication metrics; `None` initially (presumably only `Some` while leader).
pub leader_metrics: Option<LeaderMetrics>,
}
impl MessageSummary for RaftMetrics {
    /// Renders the metrics as a single log-friendly line wrapped in `Metrics{...}`.
    ///
    /// Membership and leader metrics use their own `summary()`. The replication part is
    /// empty when `leader_metrics` is `None`.
    fn summary(&self) -> String {
        // The trailing `}}` closes the `Metrics{{` opened at the start so the output is
        // balanced, e.g. `Metrics{id:1,Leader, ..., replication:LeaderMetrics{...}}`.
        format!(
            "Metrics{{id:{},{:?}, term:{}, last_log:{:?}, last_applied:{:?}, leader:{:?}, membership:{}, snapshot:{:?}, replication:{}}}",
            self.id,
            self.state,
            self.current_term,
            self.last_log_index,
            self.last_applied,
            self.current_leader,
            self.membership_config.summary(),
            self.snapshot,
            self.leader_metrics.as_ref().map(|x| x.summary()).unwrap_or_default(),
        )
    }
}
/// Metrics that a node reports about replication to other nodes.
///
/// Stored as `RaftMetrics::leader_metrics`; presumably only populated while the node is leader.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct LeaderMetrics {
/// Replication metrics keyed by node id (presumably the replication target's id).
pub replication: HashMap<NodeId, ReplicationMetrics>,
}
impl MessageSummary for LeaderMetrics {
    /// Renders the replication map as `LeaderMetrics{<id>:<summary>,<id>:<summary>}`.
    ///
    /// Entries are emitted in ascending node-id order. `HashMap` iteration order is
    /// unspecified and may differ between two maps with identical contents. Sorting makes
    /// equal metrics always produce the same summary, which keeps log lines comparable.
    fn summary(&self) -> String {
        let mut entries: Vec<_> = self.replication.iter().collect();
        // Keys of a map are unique, so an unstable sort still yields a deterministic order.
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        let body = entries
            .into_iter()
            .map(|(id, metrics)| format!("{}:{}", id, metrics.summary()))
            .collect::<Vec<_>>()
            .join(",");

        format!("LeaderMetrics{{{}}}", body)
    }
}
impl RaftMetrics {
    /// Builds the metrics reported by node `id` before it has done any work.
    ///
    /// The node starts as a running follower in term 0. It has no log index, nothing
    /// applied, no known leader, no snapshot and no leader metrics. Its effective membership
    /// is `Membership::new_initial(id)` recorded at `LogId::default()`.
    pub(crate) fn new_initial(id: NodeId) -> Self {
        let membership = Membership::new_initial(id);
        let membership_config = EffectiveMembership {
            log_id: LogId::default(),
            membership,
        };

        Self {
            id,
            running_state: Ok(()),
            state: State::Follower,
            current_term: 0,
            current_leader: None,
            last_log_index: None,
            last_applied: None,
            snapshot: None,
            membership_config,
            leader_metrics: None,
        }
    }
}
/// Failure modes of the [`Wait`] helpers.
#[derive(Debug, Error)]
pub enum WaitError {
/// The awaited condition did not hold within the given duration. The `String` describes
/// what was awaited, followed by the latest observed metrics.
#[error("timeout after {0:?} when {1}")]
Timeout(Duration, String),
/// The metrics watch channel reported an error from `changed()`, i.e. its sender was dropped.
#[error("raft is shutting down")]
ShuttingDown,
}
/// Helper for awaiting a condition on [`RaftMetrics`] published through a watch channel.
pub struct Wait {
/// Upper bound for a single `Wait::metrics` call. Helpers that chain several waits,
/// such as `log` and `log_at_least`, apply it to each wait separately.
pub timeout: Duration,
/// Receiver of metrics updates; cloned for every wait so this copy is not advanced.
pub rx: watch::Receiver<RaftMetrics>,
}
impl Wait {
    /// Waits until the latest [`RaftMetrics`] satisfy `func`, or until `self.timeout` elapses.
    ///
    /// The predicate is checked first against the current value of the watch channel. It is
    /// then re-checked each time the channel reports a change. `msg` only labels log output
    /// and error messages.
    ///
    /// Returns the first metrics value for which `func` returned `true`.
    ///
    /// # Errors
    ///
    /// - [`WaitError::Timeout`] if `func` has not returned `true` within `self.timeout`.
    /// - [`WaitError::ShuttingDown`] if the watch channel's sender has been dropped.
    #[tracing::instrument(level = "trace", skip(self, func), fields(msg=%msg.to_string()))]
    pub async fn metrics<T>(&self, func: T, msg: impl ToString) -> Result<RaftMetrics, WaitError>
    where T: Fn(&RaftMetrics) -> bool + Send {
        let timeout_at = Instant::now() + self.timeout;
        // Private clone: `changed()` bookkeeping must not advance `self.rx`.
        let mut rx = self.rx.clone();

        loop {
            // Clone out of the borrow guard immediately so the channel lock is released before
            // the `.await` in `select!` below.
            let latest = rx.borrow().clone();
            tracing::debug!(
                "id={} wait {:} latest: {}",
                latest.id,
                msg.to_string(),
                latest.summary()
            );

            if func(&latest) {
                tracing::debug!(
                    "id={} done wait {:} latest: {}",
                    latest.id,
                    msg.to_string(),
                    latest.summary()
                );
                return Ok(latest);
            }

            let now = Instant::now();
            if now >= timeout_at {
                return Err(WaitError::Timeout(
                    self.timeout,
                    format!("{} latest: {:?}", msg.to_string(), latest),
                ));
            }

            let sleep_time = timeout_at - now;
            tracing::debug!(?sleep_time, "wait timeout");
            let delay = tokio::time::sleep(sleep_time);

            tokio::select! {
                _ = delay => {
                    tracing::debug!( "id={} timeout wait {:} latest: {}", latest.id, msg.to_string(), latest.summary() );
                    return Err(WaitError::Timeout(self.timeout, format!("{} latest: {}", msg.to_string(), latest.summary())));
                }
                changed = rx.changed() => {
                    // `changed()` only fails once the sender is gone. On success, loop around
                    // and re-evaluate the predicate against the new value.
                    if let Err(err) = changed {
                        tracing::debug!(
                            "id={} error: {:?}; wait {:} latest: {:?}",
                            latest.id,
                            err,
                            msg.to_string(),
                            latest
                        );
                        return Err(WaitError::ShuttingDown);
                    }
                }
            }
        }
    }

    /// Waits until `current_leader` is `Some(leader_id)`.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn current_leader(&self, leader_id: NodeId, msg: impl ToString) -> Result<RaftMetrics, WaitError> {
        self.metrics(
            |x| x.current_leader == Some(leader_id),
            format!("{} .current_leader -> {}", msg.to_string(), leader_id),
        )
        .await
    }

    /// Waits until `last_log_index == want_log_index`, then until the index of
    /// `last_applied == want_log_index`.
    ///
    /// The two phases are separate `metrics` calls, each bounded by `self.timeout`. The total
    /// wait can therefore approach twice the timeout.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn log(&self, want_log_index: Option<u64>, msg: impl ToString) -> Result<RaftMetrics, WaitError> {
        self.metrics(
            |x| x.last_log_index == want_log_index,
            format!("{} .last_log_index -> {:?}", msg.to_string(), want_log_index),
        )
        .await?;

        self.metrics(
            |x| x.last_applied.index() == want_log_index,
            format!("{} .last_applied -> {:?}", msg.to_string(), want_log_index),
        )
        .await
    }

    /// Waits until `last_log_index >= Some(want_log)`, then until the index of
    /// `last_applied >= Some(want_log)`.
    ///
    /// As with [`Wait::log`], each phase is bounded by `self.timeout` separately.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn log_at_least(&self, want_log: u64, msg: impl ToString) -> Result<RaftMetrics, WaitError> {
        self.metrics(
            |x| x.last_log_index >= Some(want_log),
            format!("{} .last_log_index >= {}", msg.to_string(), want_log),
        )
        .await?;

        self.metrics(
            |x| x.last_applied.index() >= Some(want_log),
            format!("{} .last_applied >= {}", msg.to_string(), want_log),
        )
        .await
    }

    /// Waits until the node's `state` equals `want_state`.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn state(&self, want_state: State, msg: impl ToString) -> Result<RaftMetrics, WaitError> {
        self.metrics(
            |x| x.state == want_state,
            format!("{} .state -> {:?}", msg.to_string(), want_state),
        )
        .await
    }

    /// Waits until the first membership config (config 0) equals `want_members`.
    ///
    /// The config is compared by reference: nothing is cloned per check. A missing config 0
    /// simply does not match, instead of panicking inside the wait loop.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn members(&self, want_members: BTreeSet<NodeId>, msg: impl ToString) -> Result<RaftMetrics, WaitError> {
        self.metrics(
            |x| x.membership_config.membership.get_ith_config(0) == Some(&want_members),
            format!("{} .membership_config.members -> {:?}", msg.to_string(), want_members),
        )
        .await
    }

    /// Waits until the second membership config (config 1) equals `want_members`.
    ///
    /// `None` waits for config 1 to be absent.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn next_members(
        &self,
        want_members: Option<BTreeSet<NodeId>>,
        msg: impl ToString,
    ) -> Result<RaftMetrics, WaitError> {
        self.metrics(
            |x| x.membership_config.membership.get_ith_config(1) == want_members.as_ref(),
            format!("{} .membership_config.next -> {:?}", msg.to_string(), want_members),
        )
        .await
    }

    /// Waits until `snapshot` is `Some(want_snapshot)`.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn snapshot(&self, want_snapshot: LogId, msg: impl ToString) -> Result<RaftMetrics, WaitError> {
        self.metrics(
            |x| x.snapshot == Some(want_snapshot),
            format!("{} .snapshot -> {:?}", msg.to_string(), want_snapshot),
        )
        .await
    }
}