use core::time::Duration;
use std::collections::BTreeSet;
use futures_util::FutureExt;
use crate::LogIdOptionExt;
use crate::OptionalSend;
use crate::RaftTypeConfig;
use crate::async_runtime::watch::WatchReceiver;
use crate::core::ServerState;
use crate::metrics::Condition;
use crate::metrics::Metric;
use crate::metrics::RaftMetrics;
use crate::type_config::TypeConfigExt;
use crate::type_config::alias::LogIdOf;
use crate::type_config::alias::VoteOf;
use crate::type_config::alias::WatchReceiverOf;
/// Error returned by the waiting methods of [`Wait`].
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum WaitError {
/// The awaited condition was not satisfied before the timeout elapsed.
///
/// The first field is the timeout that was configured on the [`Wait`]; the second is a
/// human-readable description of what was being waited for, which [`Wait::metrics`] builds
/// from the caller's message followed by the latest metrics seen.
#[error("timeout after {0:?} when {1}")]
Timeout(Duration, String),
/// Waiting for the next metrics change on the watch receiver returned an error;
/// [`Wait::metrics`] reports that as the raft node shutting down.
#[error("raft is shutting down")]
ShuttingDown,
}
/// Helper for blocking (asynchronously) until the watched [`RaftMetrics`] satisfy a condition.
pub struct Wait<C: RaftTypeConfig> {
/// Upper bound for a single wait: it is added to the current time to obtain the deadline,
/// and it is echoed back in [`WaitError::Timeout`].
pub timeout: Duration,
/// Watch receiver that yields the `RaftMetrics` the conditions are evaluated against.
/// Each wait operates on its own clone of this receiver.
pub rx: WatchReceiverOf<C, RaftMetrics<C>>,
}
impl<C> Wait<C>
where C: RaftTypeConfig
{
    /// Wait until `func` returns `true` for the latest metrics, or until the timeout elapses.
    ///
    /// The predicate is checked against the current metrics right away and then again after
    /// every change reported by the watch receiver.
    ///
    /// `msg` describes what is being waited for; it is used in log lines and in the
    /// description carried by [`WaitError::Timeout`].
    ///
    /// # Errors
    ///
    /// - [`WaitError::Timeout`] if `self.timeout` elapses before `func` is satisfied.
    /// - [`WaitError::ShuttingDown`] if waiting for a metrics change returns an error.
    #[tracing::instrument(level = "trace", skip(self, func), fields(msg=%msg.to_string()))]
    pub async fn metrics<T>(&self, func: T, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError>
    where T: Fn(&RaftMetrics<C>) -> bool + OptionalSend {
        // The deadline is fixed once, so repeated metrics changes cannot extend the wait.
        let deadline = C::now() + self.timeout;
        let mut watcher = self.rx.clone();

        loop {
            let current = watcher.borrow_watched().clone();
            tracing::debug!("id={} wait {:} latest: {}", current.id, msg.to_string(), current);

            if func(&current) {
                tracing::debug!("id={} done wait {:} latest: {}", current.id, msg.to_string(), current);
                return Ok(current);
            }

            let now = C::now();
            if now >= deadline {
                return Err(self.timeout_error(&msg, &current));
            }

            let sleep_time = deadline - now;
            tracing::debug!(?sleep_time, "wait timeout");
            let timer = C::sleep(sleep_time);

            // Biased selection: the timer arm is listed first, so it is polled first.
            futures_util::select_biased! {
                _ = timer.fuse() => {
                    tracing::debug!("id={} timeout wait {:} latest: {}", current.id, msg.to_string(), current);
                    return Err(self.timeout_error(&msg, &current));
                }
                changed = watcher.changed().fuse() => {
                    // On success there is nothing to do here: the next loop iteration
                    // re-reads the metrics and re-evaluates the predicate.
                    if let Err(err) = changed {
                        tracing::debug!(
                            "id={} error: {:?}; wait {:} latest: {:?}",
                            current.id,
                            err,
                            msg.to_string(),
                            current
                        );
                        return Err(WaitError::ShuttingDown);
                    }
                }
            }
        }
    }

    /// Build the [`WaitError::Timeout`] for a wait described by `msg` whose last observed
    /// metrics were `latest`.
    fn timeout_error(&self, msg: &impl ToString, latest: &RaftMetrics<C>) -> WaitError {
        let description = format!("{} latest: {}", msg.to_string(), latest);
        WaitError::Timeout(self.timeout, description)
    }

    /// Wait until the `Metric::Vote` metric equals `want`, or time out.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn vote(&self, want: VoteOf<C>, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError> {
        let target = Metric::Vote(want);
        self.eq(target, msg).await
    }

    /// Wait until `current_leader` in the metrics is `Some(leader_id)`, or time out.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn current_leader(&self, leader_id: C::NodeId, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError> {
        let description = format!("{} .current_leader == {}", msg.to_string(), leader_id);
        let is_leader = |m: &RaftMetrics<C>| m.current_leader.as_ref() == Some(&leader_id);

        self.metrics(is_leader, &description).await
    }

    /// Wait until `Metric::LastLogIndex` equals `want_log_index` and then until
    /// `Metric::AppliedIndex` equals it too.
    ///
    /// The two conditions are awaited one after the other, each with its own full timeout.
    /// The metrics returned are those observed by the second wait.
    #[deprecated(since = "0.9.0", note = "use `log_index()` and `applied_index()` instead")]
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn log(&self, want_log_index: Option<u64>, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError> {
        let last_log = Metric::LastLogIndex(want_log_index);
        self.eq(last_log, msg.to_string()).await?;

        let applied = Metric::AppliedIndex(want_log_index);
        self.eq(applied, msg.to_string()).await
    }

    /// Wait until `Metric::LastLogIndex` is at least `want_log` and then until
    /// `Metric::AppliedIndex` is at least `want_log` too.
    ///
    /// The two conditions are awaited one after the other, each with its own full timeout.
    /// The metrics returned are those observed by the second wait.
    #[deprecated(
        since = "0.9.0",
        note = "use `log_index_at_least()` and `applied_index_at_least()` instead"
    )]
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn log_at_least(&self, want_log: Option<u64>, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError> {
        let last_log = Metric::LastLogIndex(want_log);
        self.ge(last_log, msg.to_string()).await?;

        let applied = Metric::AppliedIndex(want_log);
        self.ge(applied, msg.to_string()).await
    }

    /// Wait until `Metric::LastLogIndex` equals `index`, or time out.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn log_index(&self, index: Option<u64>, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError> {
        let target = Metric::LastLogIndex(index);
        self.eq(target, msg).await
    }

    /// Wait until `Metric::LastLogIndex` is greater than or equal to `index`, or time out.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn log_index_at_least(
        &self,
        index: Option<u64>,
        msg: impl ToString,
    ) -> Result<RaftMetrics<C>, WaitError> {
        let lower_bound = Metric::LastLogIndex(index);
        self.ge(lower_bound, msg).await
    }

    /// Wait until `Metric::AppliedIndex` equals `index`, or time out.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn applied_index(&self, index: Option<u64>, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError> {
        let target = Metric::AppliedIndex(index);
        self.eq(target, msg).await
    }

    /// Wait until `Metric::AppliedIndex` is greater than or equal to `index`, or time out.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn applied_index_at_least(
        &self,
        index: Option<u64>,
        msg: impl ToString,
    ) -> Result<RaftMetrics<C>, WaitError> {
        let lower_bound = Metric::AppliedIndex(index);
        self.ge(lower_bound, msg).await
    }

    /// Wait until the `state` field of the metrics equals `want_state`, or time out.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn state(&self, want_state: ServerState, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError> {
        let description = format!("{} .state == {:?}", msg.to_string(), want_state);
        let in_state = |m: &RaftMetrics<C>| m.state == want_state;

        self.metrics(in_state, &description).await
    }

    /// Wait until the voter ids of the membership config in the metrics are exactly
    /// `want_members`, or time out.
    #[deprecated(since = "0.9.0", note = "use `voter_ids()` instead")]
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn members(
        &self,
        want_members: BTreeSet<C::NodeId>,
        msg: impl ToString,
    ) -> Result<RaftMetrics<C>, WaitError> {
        let description = format!("{} .members -> {:?}", msg.to_string(), want_members);
        let same_voters = |m: &RaftMetrics<C>| {
            let got: BTreeSet<_> = m.membership_config.membership().voter_ids().collect();
            want_members == got
        };

        self.metrics(same_voters, &description).await
    }

    /// Wait until the voter ids of the membership config in the metrics are exactly the ids
    /// yielded by `voter_ids`, or time out.
    ///
    /// The ids are collected into a set first, so their order and duplicates do not matter.
    #[tracing::instrument(level = "trace", skip_all, fields(msg=msg.to_string().as_str()))]
    pub async fn voter_ids(
        &self,
        voter_ids: impl IntoIterator<Item = C::NodeId>,
        msg: impl ToString,
    ) -> Result<RaftMetrics<C>, WaitError> {
        let want: BTreeSet<C::NodeId> = voter_ids.into_iter().collect();
        tracing::debug!("block until voter_ids == {:?}", want);

        let description = format!("{} .members == {:?}", msg.to_string(), want);
        let same_voters = |m: &RaftMetrics<C>| {
            let got: BTreeSet<_> = m.membership_config.membership().voter_ids().collect();
            want == got
        };

        self.metrics(same_voters, &description).await
    }

    /// Wait until `Metric::Snapshot` equals `Some(snapshot_last_log_id)`, or time out.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn snapshot(
        &self,
        snapshot_last_log_id: LogIdOf<C>,
        msg: impl ToString,
    ) -> Result<RaftMetrics<C>, WaitError> {
        let target = Metric::Snapshot(Some(snapshot_last_log_id));
        self.eq(target, msg).await
    }

    /// Wait until the index of `committed` in the metrics equals `index`, or time out.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn committed_index(&self, index: Option<u64>, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError> {
        let description = format!("{} .committed_index == {:?}", msg.to_string(), index);
        let reached = |m: &RaftMetrics<C>| m.committed.index() == index;

        self.metrics(reached, &description).await
    }

    /// Wait until the index of `committed` in the metrics is greater than or equal to `index`,
    /// or time out.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn committed_index_at_least(
        &self,
        index: Option<u64>,
        msg: impl ToString,
    ) -> Result<RaftMetrics<C>, WaitError> {
        let description = format!("{} .committed_index >= {:?}", msg.to_string(), index);
        let reached = |m: &RaftMetrics<C>| m.committed.index() >= index;

        self.metrics(reached, &description).await
    }

    /// Wait until `Metric::Purged` equals `want`, or time out.
    #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
    pub async fn purged(&self, want: Option<LogIdOf<C>>, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError> {
        let target = Metric::Purged(want);
        self.eq(target, msg).await
    }

    /// Wait until the metrics compare greater than or equal to `metric`, or time out.
    pub async fn ge(&self, metric: Metric<C>, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError> {
        let cond = Condition::ge(metric);
        self.until(cond, msg).await
    }

    /// Wait until the metrics compare equal to `metric`, or time out.
    pub async fn eq(&self, metric: Metric<C>, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError> {
        let cond = Condition::eq(metric);
        self.until(cond, msg).await
    }

    /// Wait until the metrics satisfy `cond`, or time out.
    ///
    /// The condition itself is appended to `msg` to form the description used for logging
    /// and for the timeout error.
    #[tracing::instrument(level = "trace", skip_all, fields(cond=cond.to_string(), msg=msg.to_string().as_str()))]
    pub(crate) async fn until(&self, cond: Condition<C>, msg: impl ToString) -> Result<RaftMetrics<C>, WaitError> {
        let description = format!("{} .{}", msg.to_string(), cond);
        let satisfied = |raft_metrics: &RaftMetrics<C>| match &cond {
            Condition::GE(expect) => raft_metrics >= expect,
            Condition::EQ(expect) => raft_metrics == expect,
        };

        self.metrics(satisfied, &description).await
    }
}
#[cfg(test)]
mod tests {
    /// Every `WaitError` variant must survive a JSON serialize/deserialize round trip unchanged.
    #[cfg(feature = "serde")]
    #[test]
    fn test_wait_error_serde() {
        use super::*;

        let samples = [
            WaitError::Timeout(Duration::from_millis(500), "waiting for leader".to_string()),
            WaitError::ShuttingDown,
        ];

        for err in &samples {
            let json = serde_json::to_string(err).unwrap();
            let decoded: WaitError = serde_json::from_str(&json).unwrap();
            assert_eq!(*err, decoded);
        }
    }
}