pub(crate) mod api;
#[cfg(test)]
mod declare_raft_types_test;
mod impl_raft_blocking_write;
pub mod linearizable_read;
pub(crate) mod message;
mod raft_inner;
pub mod responder;
mod runtime_config_handle;
pub(crate) mod stream_append;
pub mod trigger;
mod watch_handle;
pub(crate) use api::app::AppApi;
pub(crate) use api::management::ManagementApi;
pub(crate) use api::protocol::ProtocolApi;
pub(in crate::raft) mod core_state;
mod leader;
use std::fmt::Debug;
use std::future::Future;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use core_state::CoreState;
use derive_more::Display;
use futures_util::FutureExt;
use linearizable_read::Linearizer;
pub use message::AppendEntriesRequest;
pub use message::AppendEntriesResponse;
pub use message::ClientWriteResponse;
pub use message::ClientWriteResult;
pub use message::InstallSnapshotRequest;
pub use message::InstallSnapshotResponse;
pub use message::LogSegment;
pub use message::SnapshotResponse;
pub use message::StreamAppendError;
pub use message::TransferLeaderRequest;
pub use message::VoteRequest;
pub use message::VoteResponse;
pub use message::WriteRequest;
pub use message::WriteResponse;
pub use message::WriteResult;
use openraft_macros::since;
pub use stream_append::StreamAppendResult;
use tracing::Instrument;
use tracing::Level;
use tracing::trace_span;
pub use self::leader::Leader;
pub use self::watch_handle::WatchChangeHandle;
use crate::Extensions;
use crate::OptionalSend;
use crate::RaftNetworkFactory;
use crate::RaftState;
pub use crate::RaftTypeConfig;
use crate::StorageError;
use crate::StorageHelper;
use crate::async_runtime::MpscWeakSender;
use crate::async_runtime::OneshotSender;
use crate::async_runtime::mpsc::MpscSender;
use crate::async_runtime::watch::WatchReceiver;
use crate::base::BoxFuture;
use crate::base::BoxOnce;
use crate::base::BoxStream;
use crate::config::Config;
use crate::config::RuntimeConfig;
use crate::core::ClientResponderQueue;
use crate::core::RaftCore;
use crate::core::SharedReplicateBatch;
use crate::core::Tick;
use crate::core::heartbeat::handle::HeartbeatWorkersHandle;
use crate::core::io_flush_tracking::AppliedProgress;
use crate::core::io_flush_tracking::CommitProgress;
pub use crate::core::io_flush_tracking::FlushPoint;
use crate::core::io_flush_tracking::IoProgressWatcher;
use crate::core::io_flush_tracking::LogProgress;
use crate::core::io_flush_tracking::SnapshotProgress;
use crate::core::io_flush_tracking::VoteProgress;
use crate::core::merged_raft_msg_receiver::BatchRaftMsgReceiver;
use crate::core::notification::Notification;
use crate::core::raft_msg::RaftMsg;
use crate::core::raft_msg::external_command::ExternalCommand;
use crate::core::runtime_stats::RuntimeStats;
use crate::core::sm;
use crate::core::sm::worker;
use crate::engine::Engine;
use crate::engine::EngineConfig;
use crate::entry::EntryPayload;
use crate::errors::ClientWriteError;
use crate::errors::Fatal;
use crate::errors::ForwardToLeader;
use crate::errors::InitializeError;
use crate::errors::LinearizableReadError;
use crate::errors::RaftError;
use crate::errors::into_raft_result::IntoRaftResult;
use crate::membership::IntoNodes;
use crate::metrics::MetricsRecorder;
use crate::metrics::RaftDataMetrics;
use crate::metrics::RaftMetrics;
use crate::metrics::RaftServerMetrics;
use crate::metrics::Wait;
use crate::raft::raft_inner::RaftInner;
pub use crate::raft::runtime_config_handle::RuntimeConfigHandle;
use crate::raft::trigger::Trigger;
use crate::raft_state::IOId;
use crate::storage::RaftLogStorage;
use crate::storage::RaftStateMachine;
use crate::type_config::TypeConfigExt;
use crate::type_config::alias::JoinErrorOf;
use crate::type_config::alias::LogIdOf;
use crate::type_config::alias::MpscWeakSenderOf;
use crate::type_config::alias::NodeIdOf;
use crate::type_config::alias::SnapshotDataOf;
use crate::type_config::alias::SnapshotOf;
use crate::type_config::alias::VoteOf;
use crate::type_config::alias::WatchReceiverOf;
use crate::type_config::alias::WriteResponderOf;
use crate::vote::Vote;
use crate::vote::leader_id::raft_leader_id::RaftLeaderId;
use crate::vote::leader_id::raft_leader_id::RaftLeaderIdExt;
use crate::vote::non_committed::UncommittedVote;
use crate::vote::raft_vote::RaftVote;
use crate::vote::raft_vote::RaftVoteExt;
/// Declare a type that implements [`RaftTypeConfig`] with sensible defaults.
///
/// Two forms are accepted:
/// - `declare_raft_types!(pub TypeConfig)` — every associated type uses the
///   defaults listed below.
/// - `declare_raft_types!(pub TypeConfig: D = MyData, R = MyResp, ...)` —
///   override any subset; the rest fall back to the defaults.
#[macro_export]
macro_rules! declare_raft_types {
    // Shorthand form without overrides: forward to the full form with an empty list.
    ($(#[$outer:meta])* $visibility:vis $id:ident) => {
        $crate::declare_raft_types!($(#[$outer])* $visibility $id:);
    };
    ($(#[$outer:meta])* $visibility:vis $id:ident: $($(#[$inner:meta])* $type_id:ident = $type:ty),* $(,)? ) => {
        $(#[$outer])*
        #[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd)]
        $visibility struct $id {}
        impl $crate::RaftTypeConfig for $id {
            // `expand!(KEYED, ...)` deduplicates entries by key; user-supplied
            // overrides are listed before the defaults, so presumably the
            // override wins — TODO confirm against `macros::expand!`.
            $crate::macros::expand!(
                KEYED,
                (T, ATTR, V) => {ATTR type T = V;},
                $(($type_id, $(#[$inner])*, $type),)*
                // Default associated types used when not overridden by the caller:
                (D , , String ),
                (R , , String ),
                (NodeId , , u64 ),
                (Node , , $crate::impls::BasicNode ),
                (Term , , u64 ),
                (LeaderId , , $crate::impls::leader_id_adv::LeaderId<Self::Term, Self::NodeId> ),
                (Vote , , $crate::impls::Vote<Self::LeaderId> ),
                (Entry , , $crate::Entry<<Self::LeaderId as $crate::vote::RaftLeaderId>::Committed, Self::D, Self::NodeId, Self::Node> ),
                (SnapshotData , , std::io::Cursor<Vec<u8>> ),
                (Responder<T> , , $crate::impls::ProgressResponder<Self, T> where T: $crate::OptionalSend + 'static ),
                (Batch<T> , , $crate::impls::InlineBatch<T> where T: $crate::OptionalSend + 'static ),
                (AsyncRuntime , , $crate::impls::TokioRuntime ),
                (ErrorSource , , $crate::impls::BoxedErrorSource ),
            );
        }
    };
}
/// Policy used by the read APIs to establish linearizability.
#[derive(Clone, Debug, Display, PartialEq, Eq)]
pub enum ReadPolicy {
    /// Lease-based read. NOTE(review): presumably avoids a quorum round-trip by
    /// relying on the leader's lease — confirm in `linearizable_read`.
    LeaseRead,
    /// Read-index read. NOTE(review): presumably confirms leadership with a
    /// quorum before serving — confirm in `linearizable_read`.
    ReadIndex,
}
/// The public handle to a running Raft node.
///
/// Cloning is cheap: all clones share the same `RaftInner` via `Arc`.
pub struct Raft<C, SM = ()>
where C: RaftTypeConfig
{
    // Shared core state: channels, metrics receivers, shutdown handle.
    inner: Arc<RaftInner<C>>,
    // Weak sender to the state-machine worker; `upgrade()` fails once the
    // worker has stopped (see `external_state_machine_request`).
    sm_cmd_tx: MpscWeakSenderOf<C, sm::Command<C, SM>>,
}
impl<C, SM> Clone for Raft<C, SM>
where C: RaftTypeConfig
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
sm_cmd_tx: self.sm_cmd_tx.clone(),
}
}
}
/// Debug output shows only the node id; the rest of the state lives behind
/// channels and is not meaningfully printable here.
impl<C, SM> Debug for Raft<C, SM>
where C: RaftTypeConfig
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut s = f.debug_struct("Raft");
        s.field("id", &self.inner.id);
        s.finish()
    }
}
/// Forwards IO-completion results from the watch channel `rx_io` into the
/// notification channel, micro-batching bursts of completions into a single
/// notification.
///
/// Exits when either channel is closed. Holding only a weak sender to the
/// notification channel means this task does not keep the core alive.
async fn io_completion_forwarder<C>(
    mut rx_io: WatchReceiverOf<C, Result<IOId<C>, StorageError<C>>>,
    weak_tx_notify: MpscWeakSenderOf<C, Notification<C>>,
) where
    C: RaftTypeConfig,
{
    // Window during which rapid successive completions are coalesced.
    const BATCH_INTERVAL: Duration = Duration::from_micros(1);
    loop {
        // NOTE(review): the deadline is captured BEFORE awaiting a change, so
        // if the wait itself takes longer than BATCH_INTERVAL the batching
        // sleep below is skipped — confirm this is the intended behavior.
        let deadline = C::now() + BATCH_INTERVAL;
        if rx_io.changed().await.is_err() {
            tracing::debug!("IO completion watch channel closed, forwarder exiting");
            break;
        }
        let now = C::now();
        if now < deadline {
            // Wait out the batch window, then clear any change flag that
            // accumulated meanwhile; `now_or_never()` never blocks.
            C::sleep_until(deadline).await;
            let _ = rx_io.changed().now_or_never();
        }
        // Only the newest watched value is forwarded; intermediate values
        // within the batch window are intentionally dropped.
        let result = {
            let borrowed = rx_io.borrow_watched();
            borrowed.clone()
        };
        let Some(tx) = weak_tx_notify.upgrade() else {
            tracing::debug!("Notification channel closed, forwarder exiting");
            break;
        };
        let notification = match result {
            Ok(io_id) => Notification::LocalIO { io_id },
            Err(storage_error) => Notification::StorageError { error: storage_error },
        };
        if let Err(e) = tx.send(notification).await {
            tracing::warn!("failed to forward IO completion: {}", e.0);
            break;
        }
    }
}
impl<C, SM> Raft<C, SM>
where
    C: RaftTypeConfig,
    SM: RaftStateMachine<C>,
{
    /// Create and start a new Raft node.
    ///
    /// Loads the initial state from storage, then spawns the long-running
    /// tasks: the tick loop, the state-machine worker, the IO-completion
    /// forwarder and the `RaftCore` main loop. Returns a cloneable handle
    /// to the running node.
    ///
    /// # Errors
    /// Returns `Fatal` if reading the initial state from storage fails.
    #[tracing::instrument(level="debug", skip_all, fields(cluster=%config.cluster_name))]
    pub async fn new<LS, N>(
        id: C::NodeId,
        config: Arc<Config>,
        network: N,
        mut log_store: LS,
        mut state_machine: SM,
    ) -> Result<Self, Fatal<C>>
    where
        N: RaftNetworkFactory<C>,
        LS: RaftLogStorage<C>,
    {
        // Channels between the API handle, the core loop and its workers.
        let api_channel_size = config.api_channel_size();
        let notification_channel_size = config.notification_channel_size();
        let (tx_api, rx_api) = C::mpsc(api_channel_size);
        let (tx_notify, rx_notify) = C::mpsc(notification_channel_size);
        // Watch channels publishing the three metrics views to callers.
        let (tx_metrics, rx_metrics) = C::watch_channel(RaftMetrics::new_initial(id.clone()));
        let (tx_data_metrics, rx_data_metrics) = C::watch_channel(RaftDataMetrics::default());
        let (tx_server_metrics, rx_server_metrics) = C::watch_channel(RaftServerMetrics::new_initial(id.clone()));
        // The IO-completion watch needs an initial value; use a placeholder
        // vote IO id for this node.
        let leader_id = C::LeaderId::new_with_default_term(id.clone());
        let dummy_io_id = IOId::Vote(UncommittedVote::new(leader_id));
        let (tx_io_completed, rx_io_completed) = C::watch_channel(Ok(dummy_io_id));
        // Weak sender: the forwarder task must not keep the core alive.
        let weak_tx_notify = tx_notify.downgrade();
        let (tx_progress, progress_watcher) = IoProgressWatcher::new();
        let (tx_shutdown, rx_shutdown) = C::oneshot();
        // Tick period is 1.5x the heartbeat interval.
        let tick_handle = Tick::spawn(
            Duration::from_millis(config.heartbeat_interval * 3 / 2),
            tx_notify.clone(),
            config.enable_tick,
        );
        let runtime_config = Arc::new(RuntimeConfig::new(&config));
        let core_span = tracing::span!(
            parent: tracing::Span::current(),
            Level::DEBUG,
            "RaftCore",
            id = display(&id),
            cluster = display(&config.cluster_name)
        );
        let eng_config = EngineConfig::new(id.clone(), config.as_ref());
        // Load persisted state (log, vote, membership, ...) from storage.
        let state = {
            let mut helper = StorageHelper::new(&mut log_store, &mut state_machine).with_id(id.clone());
            helper.get_initial_state().await?
        };
        let engine = Engine::new(state, eng_config);
        // The state machine runs in its own worker task; the handle owns the
        // machine from here on.
        let sm_span = tracing::span!(parent: &core_span, Level::DEBUG, "sm_worker");
        let sm_handle = worker::Worker::spawn(
            state_machine,
            log_store.get_log_reader().await,
            tx_notify.clone(),
            config.state_machine_channel_size(),
            sm_span,
        );
        let sm_cmd_tx = sm_handle.downgrade_sender();
        // Initial values for the accepted/submitted/committed progress channels.
        let default_io_id = IOId::new_vote_io(UncommittedVote::new_with_default_term(id.clone()));
        let (io_accepted_tx, _io_accepted_rx) = C::watch_channel(default_io_id.clone());
        let (io_submitted_tx, _io_submitted_rx) = C::watch_channel(default_io_id);
        let (committed_tx, _committed_rx) = C::watch_channel(None);
        let shared_replicate_batch = SharedReplicateBatch::new();
        let core: RaftCore<C, N, LS, SM> = RaftCore {
            id: id.clone(),
            config: config.clone(),
            runtime_config: runtime_config.clone(),
            core_state: Default::default(),
            network_factory: network,
            log_store,
            sm_handle,
            engine,
            client_responders: ClientResponderQueue::with_capacity(1024 * 8),
            replications: Default::default(),
            heartbeat_handle: HeartbeatWorkersHandle::new(id.clone(), config.clone()),
            tx_api: tx_api.clone(),
            rx_api: BatchRaftMsgReceiver::new(rx_api),
            tx_notification: tx_notify,
            rx_notification: rx_notify,
            tx_io_completed,
            io_accepted_tx,
            io_submitted_tx,
            committed_tx,
            tx_metrics,
            tx_data_metrics,
            tx_server_metrics,
            tx_progress,
            runtime_stats: RuntimeStats::new(&config),
            shared_replicate_batch,
            metrics_recorder: None,
            span: core_span,
        };
        // The forwarder handle is intentionally dropped: the task exits on its
        // own when either of its channels closes.
        let _forwarder_handle = C::spawn(io_completion_forwarder::<C>(rx_io_completed, weak_tx_notify));
        let core_handle = C::spawn(core.main(rx_shutdown).instrument(trace_span!("spawn").or_current()));
        let inner = RaftInner {
            id,
            config,
            runtime_config,
            tick_handle,
            tx_api,
            rx_metrics,
            rx_data_metrics,
            rx_server_metrics,
            progress_watcher,
            tx_shutdown: Mutex::new(Some(tx_shutdown)),
            core_state: Mutex::new(CoreState::Running(core_handle)),
            extensions: Extensions::default(),
        };
        Ok(Self {
            inner: Arc::new(inner),
            sm_cmd_tx,
        })
    }
}
impl<C, SM> Raft<C, SM>
where C: RaftTypeConfig
{
    /// Handle for changing runtime-tunable config flags on the running node.
    pub fn runtime_config(&self) -> RuntimeConfigHandle<'_, C> {
        RuntimeConfigHandle::new(self.inner.as_ref())
    }
    /// The immutable config this node was created with.
    pub fn config(&self) -> &Arc<Config> {
        &self.inner.config
    }
    /// All user-attached extensions.
    #[since(version = "0.10.0")]
    pub fn extensions(&self) -> &Extensions {
        &self.inner.extensions
    }
    /// Fetch an extension value by type; falls back to `T::default()` per
    /// `Extensions::get` — TODO confirm against `Extensions`.
    #[since(version = "0.10.0")]
    pub fn extension<T>(&self) -> T
    where T: OptionalSend + Clone + Default + 'static {
        self.inner.extensions.get::<T>()
    }
    /// Snapshot of the core loop's runtime statistics.
    #[cfg(feature = "runtime-stats")]
    pub async fn runtime_stats(&self) -> Result<RuntimeStats<C>, Fatal<C>> {
        let (tx, rx) = C::oneshot();
        self.inner.call_core(RaftMsg::GetRuntimeStats { tx }, rx).await
    }
    /// Whether this node currently believes itself to be the leader,
    /// according to the latest published metrics.
    #[since(version = "0.10.0")]
    pub fn is_leader(&self) -> bool {
        self.inner.rx_metrics.borrow_watched().state.is_leader()
    }
    /// Return a `Leader` handle if this node is the committed leader,
    /// otherwise a `ForwardToLeader` naming the known leader (if any).
    #[since(version = "0.10.0")]
    pub fn as_leader(&self) -> Result<Leader<C, SM>, ForwardToLeader<C>> {
        let metrics = self.inner.rx_metrics.borrow_watched();
        // No committed vote means no established leader at all.
        let Some(committed_vote) = metrics.vote.try_to_committed() else {
            return Err(ForwardToLeader::empty());
        };
        let leader_id = committed_vote.leader_id();
        let node_id = leader_id.node_id();
        if node_id == &self.inner.id {
            Ok(Leader {
                raft: self.clone(),
                leader_id: leader_id.clone(),
                last_quorum_acked: metrics.last_quorum_acked.map(|s| s.into_inner()),
            })
        } else {
            // Another node leads: include its Node info so the caller can forward.
            let node = metrics.membership_config.membership().get_node(node_id).cloned();
            Err(ForwardToLeader {
                leader_id: Some(node_id.clone()),
                leader_node: node,
            })
        }
    }
    /// This node's id.
    #[since(version = "0.10.0")]
    pub fn node_id(&self) -> &C::NodeId {
        &self.inner.id
    }
    /// Ids of all voters in the current membership (a point-in-time snapshot).
    #[since(version = "0.10.0")]
    pub fn voter_ids(&self) -> impl Iterator<Item = C::NodeId> {
        let membership = self.inner.rx_metrics.borrow_watched().membership_config.clone();
        // Collect so the returned iterator does not borrow the metrics watch.
        membership.voter_ids().collect::<Vec<_>>().into_iter()
    }
    /// Ids of all learners in the current membership (a point-in-time snapshot).
    #[since(version = "0.10.0")]
    pub fn learner_ids(&self) -> impl Iterator<Item = C::NodeId> {
        let membership = self.inner.rx_metrics.borrow_watched().membership_config.clone();
        membership.membership().learner_ids().collect::<Vec<_>>().into_iter()
    }
    /// Raft-protocol RPC surface (append-entries, vote, snapshot).
    pub(crate) fn protocol_api(&self) -> ProtocolApi<C> {
        ProtocolApi::new(self.inner.clone())
    }
    /// Application surface (writes, linearizable reads).
    pub(crate) fn app_api(&self) -> AppApi<'_, C> {
        AppApi::new(&self.inner)
    }
    /// Cluster-management surface (initialize, membership change).
    pub(crate) fn management_api(&self) -> ManagementApi<'_, C> {
        ManagementApi::new(self.inner.as_ref())
    }
    /// Handle for manually triggering core actions (election, snapshot, ...).
    pub fn trigger(&self) -> Trigger<'_, C> {
        Trigger::new(self.inner.as_ref())
    }
    /// Install or remove (with `None`) a metrics recorder in the core loop.
    pub async fn set_metrics_recorder(&self, recorder: Option<Arc<dyn MetricsRecorder>>) -> Result<(), Fatal<C>> {
        self.inner.send_external_command(ExternalCommand::SetMetricsRecorder { recorder }).await
    }
    /// Handle an append-entries RPC from the current leader.
    #[tracing::instrument(level = "debug", skip_all)]
    pub async fn append_entries(&self, rpc: AppendEntriesRequest<C>) -> Result<AppendEntriesResponse<C>, RaftError<C>> {
        self.protocol_api().append_entries(rpc).await.into_raft_result()
    }
    /// Handle a stream of append-entries RPCs, yielding one result per request.
    // NOTE(review): two `#[since]` attributes with the same version — looks
    // duplicated; confirm whether both are intended.
    #[since(version = "0.10.0", change = "stream item contains Fatal")]
    #[since(version = "0.10.0")]
    pub fn stream_append<S>(
        &self,
        stream: S,
    ) -> impl futures_util::Stream<Item = Result<StreamAppendResult<C>, Fatal<C>>> + OptionalSend + 'static
    where
        S: futures_util::Stream<Item = AppendEntriesRequest<C>> + OptionalSend + 'static,
    {
        self.protocol_api().stream_append(stream)
    }
    /// Handle a request-vote RPC from a candidate.
    #[tracing::instrument(level = "debug", skip_all)]
    pub async fn vote(&self, rpc: VoteRequest<C>) -> Result<VoteResponse<C>, RaftError<C>> {
        self.protocol_api().vote(rpc).await.into_raft_result()
    }
    /// Fetch the current snapshot, if one exists.
    #[tracing::instrument(level = "debug", skip_all)]
    pub async fn get_snapshot(&self) -> Result<Option<SnapshotOf<C>>, RaftError<C>> {
        self.protocol_api().get_snapshot().await.into_raft_result()
    }
    /// Prepare a writable snapshot-data object for receiving a snapshot.
    #[since(version = "0.10.0", change = "SnapshotData without Box")]
    #[tracing::instrument(level = "debug", skip_all)]
    pub async fn begin_receiving_snapshot(&self) -> Result<SnapshotDataOf<C>, RaftError<C>> {
        self.protocol_api().begin_receiving_snapshot().await.into_raft_result()
    }
    /// Install a fully received snapshot sent by the leader identified by `vote`.
    #[since(version = "0.9.0")]
    #[tracing::instrument(level = "debug", skip_all)]
    pub async fn install_full_snapshot(
        &self,
        vote: VoteOf<C>,
        snapshot: SnapshotOf<C>,
    ) -> Result<SnapshotResponse<C>, Fatal<C>> {
        self.protocol_api().install_full_snapshot(vote, snapshot).await
    }
    /// The leader this node currently knows of, per the latest metrics.
    #[tracing::instrument(level = "debug", skip(self))]
    pub async fn current_leader(&self) -> Option<C::NodeId> {
        self.metrics().borrow_watched().current_leader.clone()
    }
    /// Block until a state-machine read at this point is linearizable, and
    /// return the read log id (always `Some` on success).
    #[since(version = "0.9.0")]
    #[tracing::instrument(level = "debug", skip(self))]
    pub async fn ensure_linearizable(
        &self,
        read_policy: ReadPolicy,
    ) -> Result<Option<LogIdOf<C>>, RaftError<C, LinearizableReadError<C>>> {
        let linearizer = self.app_api().get_read_linearizer(read_policy).await.into_raft_result()?;
        // Wait until the applied log catches up to the read log id.
        let state = linearizer.await_ready(self).await?;
        Ok(Some(state.read_log_id().clone()))
    }
    /// Return `(read_log_id, last_applied)` without waiting for apply to catch up.
    #[since(version = "0.9.0")]
    #[tracing::instrument(level = "debug", skip(self))]
    pub async fn get_read_log_id(
        &self,
        read_policy: ReadPolicy,
    ) -> Result<(Option<LogIdOf<C>>, Option<LogIdOf<C>>), RaftError<C, LinearizableReadError<C>>> {
        let linearizer = self.app_api().get_read_linearizer(read_policy).await.into_raft_result()?;
        let read_log_id = linearizer.read_log_id();
        let applied = linearizer.applied();
        Ok((Some(read_log_id.clone()), applied.cloned()))
    }
    /// Obtain a `Linearizer` the caller can await at its convenience.
    #[since(version = "0.10.0")]
    #[tracing::instrument(level = "debug", skip_all)]
    pub async fn get_read_linearizer(
        &self,
        read_policy: ReadPolicy,
    ) -> Result<Linearizer<C>, RaftError<C, LinearizableReadError<C>>> {
        self.app_api().get_read_linearizer(read_policy).await.into_raft_result()
    }
    /// Submit application data as a new log entry and await the apply result.
    #[tracing::instrument(level = "debug", skip(self, app_data))]
    pub async fn client_write(
        &self,
        app_data: C::D,
    ) -> Result<ClientWriteResponse<C>, RaftError<C, ClientWriteError<C>>> {
        self.app_api().client_write(EntryPayload::Normal(app_data)).await.into_raft_result()
    }
    /// Write a blank (no-op) log entry and await the result.
    #[since(version = "0.10.0")]
    #[tracing::instrument(level = "debug", skip(self))]
    pub async fn write_blank(&self) -> Result<ClientWriteResponse<C>, RaftError<C, ClientWriteError<C>>> {
        self.app_api().client_write(EntryPayload::Blank).await.into_raft_result()
    }
    /// Fire-and-forget write: submit the entry and return without waiting for
    /// apply; an optional responder receives the eventual result.
    // NOTE(review): two `#[since]` attributes with the same version — looks
    // duplicated; confirm whether both are intended.
    #[since(version = "0.10.0", date = "2025-10-27", change = "add responder arg")]
    #[since(version = "0.10.0")]
    pub async fn client_write_ff(
        &self,
        app_data: C::D,
        responder: Option<WriteResponderOf<C>>,
    ) -> Result<(), Fatal<C>> {
        self.app_api().client_write_ff(EntryPayload::Normal(app_data), responder).await
    }
    /// Submit a batch of writes; returns a stream yielding one result per entry.
    #[since(version = "0.10.0")]
    #[tracing::instrument(level = "debug", skip_all)]
    pub async fn client_write_many(
        &self,
        app_data: impl IntoIterator<Item = C::D>,
    ) -> Result<BoxStream<'static, Result<WriteResult<C>, Fatal<C>>>, Fatal<C>> {
        self.app_api().client_write_many(app_data.into_iter().map(EntryPayload::Normal)).await
    }
    /// Builder-style write request; configure it and then submit.
    #[since(version = "0.10.0")]
    pub fn write(&self, app_data: C::D) -> WriteRequest<'_, C> {
        WriteRequest {
            inner: &self.inner,
            app_data,
            responder: None,
            expected_leader: None,
        }
    }
    /// Handle a leadership-transfer request sent by the current leader.
    #[since(version = "0.10.0")]
    #[tracing::instrument(level = "debug", skip_all)]
    pub async fn handle_transfer_leader(&self, req: TransferLeaderRequest<C>) -> Result<(), Fatal<C>> {
        self.protocol_api().handle_transfer_leader(req).await
    }
    /// Whether this node has been initialized with a membership config.
    #[since(version = "0.10.0")]
    pub async fn is_initialized(&self) -> Result<bool, Fatal<C>> {
        let initialized = self.with_raft_state(|st| st.is_initialized()).await?;
        Ok(initialized)
    }
    /// Initialize a pristine node with the given initial membership.
    #[tracing::instrument(level = "debug", skip(self))]
    pub async fn initialize<T>(&self, members: T) -> Result<(), RaftError<C, InitializeError<C>>>
    where T: IntoNodes<C::NodeId, C::Node> + Debug {
        self.management_api().initialize(members).await.into_raft_result()
    }
    /// Run `func` against the core's `RaftState` and return its result.
    ///
    /// # Errors
    /// Returns the core's stop error if the core has exited before replying.
    pub async fn with_raft_state<F, V>(&self, func: F) -> Result<V, Fatal<C>>
    where
        F: FnOnce(&RaftState<C>) -> V + OptionalSend + 'static,
        V: OptionalSend + 'static,
    {
        let (tx, rx) = C::oneshot();
        self.external_request(|st| {
            let result = func(st);
            if let Err(_err) = tx.send(result) {
                // The caller dropped `rx`; nothing to deliver to.
                tracing::error!("{}: to-Raft tx send error", func_name!());
            }
        })
        .await?;
        match rx.await {
            Ok(res) => Ok(res),
            Err(err) => {
                // The core dropped `tx` without answering: report why it stopped.
                tracing::error!("{}: rx recv error: {}", func_name!(), err);
                let fatal = self.inner.get_core_stop_error().await;
                Err(fatal)
            }
        }
    }
    /// Send a closure to the core loop to be run against the `RaftState`.
    pub async fn external_request<F>(&self, req: F) -> Result<(), Fatal<C>>
    where F: FnOnce(&RaftState<C>) + OptionalSend + 'static {
        let req: BoxOnce<'static, RaftState<C>> = Box::new(req);
        self.inner.send_msg(RaftMsg::WithRaftState { req }).await
    }
    /// Watch receiver for the full metrics view.
    pub fn metrics(&self) -> WatchReceiverOf<C, RaftMetrics<C>> {
        self.inner.rx_metrics.clone()
    }
    /// Watch receiver for the data-plane metrics view.
    pub fn data_metrics(&self) -> WatchReceiverOf<C, RaftDataMetrics<C>> {
        self.inner.rx_data_metrics.clone()
    }
    /// Watch receiver for the server metrics view.
    pub fn server_metrics(&self) -> WatchReceiverOf<C, RaftServerMetrics<C>> {
        self.inner.rx_server_metrics.clone()
    }
    /// Watch handle for log-flush progress.
    #[since(version = "0.10.0")]
    #[must_use = "progress handle should be stored to track I/O progress"]
    pub fn watch_log_progress(&self) -> LogProgress<C> {
        self.inner.progress_watcher.log_progress()
    }
    /// Watch handle for vote-persistence progress.
    #[since(version = "0.10.0")]
    #[must_use = "progress handle should be stored to track vote progress"]
    pub fn watch_vote_progress(&self) -> VoteProgress<C> {
        self.inner.progress_watcher.vote_progress()
    }
    /// Watch handle for commit progress.
    #[since(version = "0.10.0")]
    #[must_use = "progress handle should be stored to track commit progress"]
    pub fn watch_commit_progress(&self) -> CommitProgress<C> {
        self.inner.progress_watcher.commit_progress()
    }
    /// Watch handle for snapshot progress.
    #[since(version = "0.10.0")]
    #[must_use = "progress handle should be stored to track snapshot progress"]
    pub fn watch_snapshot_progress(&self) -> SnapshotProgress<C> {
        self.inner.progress_watcher.snapshot_progress()
    }
    /// Watch handle for apply progress.
    #[since(version = "0.10.0")]
    #[must_use = "progress handle should be stored to track applied progress"]
    pub fn watch_apply_progress(&self) -> AppliedProgress<C> {
        self.inner.progress_watcher.apply_progress()
    }
    /// Invoke `callback(old, new)` whenever the cluster's leader id changes.
    ///
    /// Each state is `(leader_id, is_committed)`; `old` is `None` on the first
    /// observed vote. The callback does NOT fire when only the committed flag
    /// flips for the same leader — only on a leader-id change.
    #[since(version = "0.10.0")]
    #[must_use = "handle must be held to keep the watch task running"]
    pub fn on_cluster_leader_change<F, Fut>(&self, mut callback: F) -> WatchChangeHandle<C>
    where
        F: FnMut(Option<(C::LeaderId, bool)>, (C::LeaderId, bool)) -> Fut + OptionalSend + 'static,
        Fut: Future<Output = ()> + OptionalSend + 'static,
    {
        let mut prev_vote: Option<Vote<C::LeaderId>> = None;
        self.watch_vote_change(move |new_vote, _my_node_id| {
            let old_leader = prev_vote.as_ref().map(|v| v.leader_id().clone());
            let new_leader = new_vote.leader_id().clone();
            let fut = if old_leader.as_ref() != Some(&new_leader) {
                let old_state = prev_vote.as_ref().map(|v| (v.leader_id().clone(), v.is_committed()));
                let new_state = (new_vote.leader_id().clone(), new_vote.is_committed());
                Some(callback(old_state, new_state))
            } else {
                None
            };
            prev_vote = Some(new_vote);
            // Return a single future so the watch loop can await it either way.
            async move {
                if let Some(f) = fut {
                    f.await;
                }
            }
        })
    }
    /// Invoke `start` when THIS node becomes the committed leader and `stop`
    /// when it ceases to be; `stop` always runs before the next `start`.
    #[since(version = "0.10.0")]
    #[must_use = "handle must be held to keep the watch task running"]
    pub fn on_leader_change<F1, F2, Fut1, Fut2>(&self, start: F1, stop: F2) -> WatchChangeHandle<C>
    where
        F1: Fn(C::LeaderId) -> Fut1 + OptionalSend + 'static,
        F2: Fn(C::LeaderId) -> Fut2 + OptionalSend + 'static,
        Fut1: Future<Output = ()> + OptionalSend + 'static,
        Fut2: Future<Output = ()> + OptionalSend + 'static,
    {
        // Set only while this node is the committed leader.
        let mut prev_leader_id = None;
        self.watch_vote_change(move |vote, my_node_id| {
            let leader_id = vote.leader_id().clone();
            #[allow(clippy::collapsible_else_if)]
            let (stop_fut, start_fut) = if leader_id.node_id() == my_node_id {
                // Fire `start` only for a newly committed local leadership.
                if vote.is_committed() && prev_leader_id.as_ref() != Some(&leader_id) {
                    let stop_fut = prev_leader_id.take().map(&stop);
                    let start_fut = Some(start(leader_id.clone()));
                    prev_leader_id = Some(leader_id);
                    (stop_fut, start_fut)
                } else {
                    (None, None)
                }
            } else {
                // Another node leads now: stop our own leadership if it was active.
                let stop_fut = prev_leader_id.take().map(&stop);
                (stop_fut, None)
            };
            async move {
                if let Some(f) = stop_fut {
                    f.await;
                }
                if let Some(f) = start_fut {
                    f.await;
                }
            }
        })
    }
    /// Spawn a background task that awaits vote-progress changes and calls
    /// `callback(vote, my_node_id)` for each observed vote. Canceled when the
    /// returned handle is dropped or the progress channel closes.
    fn watch_vote_change<F, Fut>(&self, mut callback: F) -> WatchChangeHandle<C>
    where
        F: FnMut(Vote<C::LeaderId>, &NodeIdOf<C>) -> Fut + OptionalSend + 'static,
        Fut: Future<Output = ()> + OptionalSend + 'static,
    {
        use futures_util::FutureExt;
        let my_node_id = self.inner.id().clone();
        let mut vote_progress = self.watch_vote_progress();
        let (cancel_tx, cancel_rx) = C::oneshot::<()>();
        let handle = C::spawn(async move {
            let mut cancel_rx = cancel_rx.fuse();
            loop {
                futures_util::select! {
                    _ = cancel_rx => break,
                    res = vote_progress.changed().fuse() => {
                        if res.is_err() {
                            break;
                        }
                        // `get()` may be None before any vote is recorded; skip.
                        let Some(vote) = vote_progress.get() else {
                            continue;
                        };
                        callback(vote, &my_node_id).await;
                    }
                }
            }
        });
        WatchChangeHandle {
            cancel_tx: Some(cancel_tx),
            join_handle: Some(handle),
        }
    }
    /// Builder for awaiting metrics conditions, with an optional timeout.
    pub fn wait(&self, timeout: Option<Duration>) -> Wait<C> {
        self.inner.wait(timeout)
    }
    /// Signal the core to stop and wait for the core and tick tasks to exit.
    /// Safe to call more than once: the shutdown sender is taken only once.
    pub async fn shutdown(&self) -> Result<(), JoinErrorOf<C>> {
        if let Some(tx) = self.inner.tx_shutdown.lock().unwrap().take() {
            let send_res = tx.send(());
            tracing::info!("sending shutdown signal to RaftCore, sending res: {:?}", send_res);
        }
        self.inner.join_core_task().await;
        if let Some(join_handle) = self.inner.tick_handle.shutdown() {
            join_handle.await.ok();
        }
        Ok(())
    }
    /// Run `func` against the state machine inside the SM worker task and
    /// return its result.
    ///
    /// # Errors
    /// Returns the core's stop error if the worker is gone before replying.
    #[since(version = "0.10.0")]
    pub async fn with_state_machine<F, V>(&self, func: F) -> Result<V, Fatal<C>>
    where
        SM: OptionalSend + 'static,
        F: FnOnce(&mut SM) -> BoxFuture<V> + OptionalSend + 'static,
        V: OptionalSend + 'static,
    {
        let (tx, rx) = C::oneshot();
        self.external_state_machine_request(|sm| {
            Box::pin(async move {
                let resp = func(sm).await;
                if let Err(_err) = tx.send(resp) {
                    // The caller dropped `rx`; nothing to deliver to.
                    tracing::error!("{}: failed to send response to user tx", func_name!());
                }
            })
        })
        .await?;
        let recv_res = rx.await;
        tracing::debug!("{}: receives result is error: {:?}", func_name!(), recv_res.is_err());
        let Ok(v) = recv_res else {
            // The worker dropped `tx` without answering: report why it stopped.
            let fatal = self.inner.get_core_stop_error().await;
            tracing::error!("{}: error: {}", func_name!(), fatal);
            return Err(fatal);
        };
        Ok(v)
    }
    /// Send a closure to the state-machine worker to run against the SM.
    ///
    /// # Errors
    /// `Fatal::Stopped` if the worker's command channel is already gone.
    #[since(version = "0.10.0")]
    pub async fn external_state_machine_request<F>(&self, req: F) -> Result<(), Fatal<C>>
    where
        SM: OptionalSend + 'static,
        F: FnOnce(&mut SM) -> BoxFuture<()> + OptionalSend + 'static,
    {
        // The weak sender fails to upgrade once the worker has shut down.
        let Some(tx) = self.sm_cmd_tx.upgrade() else {
            return Err(Fatal::Stopped);
        };
        let sm_cmd = sm::Command::ExternalFunc {
            func: Box::new(move |sm| req(sm)),
        };
        tx.send(sm_cmd).await.map_err(|_e| Fatal::Stopped)
    }
}