use std::pin::Pin;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use tracing::debug;
use nodedb_raft::message::LogEntry;
use crate::applied_watcher::GroupAppliedWatchers;
use crate::catalog::ClusterCatalog;
use crate::error::Result;
use crate::forward::{NoopPlanExecutor, PlanExecutor};
use crate::loop_metrics::LoopMetrics;
use crate::metadata_group::applier::{MetadataApplier, NoopMetadataApplier};
use crate::multi_raft::MultiRaft;
use crate::topology::ClusterTopology;
use crate::transport::NexarTransport;
/// Default cadence of the raft driver loop; override per-loop with
/// [`RaftLoop::with_tick_interval`].
pub(super) const DEFAULT_TICK_INTERVAL: Duration = Duration::from_millis(10);
/// Applies batches of committed raft log entries for a data group.
pub trait CommitApplier: Send + Sync + 'static {
    /// Applies the newly committed `entries` of `group_id` and returns the
    /// index of the last entry applied (used as the new applied index).
    fn apply_committed(&self, group_id: u64, entries: &[LogEntry]) -> u64;
}
/// Hook consulted around snapshot installation, keyed by group id and the
/// snapshot's `last_included_index`, so repeatedly failing snapshots can be
/// quarantined instead of retried forever.
pub trait SnapshotQuarantineHook: Send + Sync + 'static {
    /// Returns `true` if this snapshot is currently quarantined and should
    /// not be (re)installed.
    fn is_quarantined(&self, group_id: u64, last_included_index: u64) -> bool;
    /// Records a successful install of this snapshot.
    fn record_success(&self, group_id: u64, last_included_index: u64);
    /// Records a failed install attempt with its error text; the returned
    /// bool's meaning is defined by the implementation — presumably whether
    /// the snapshot is now quarantined (TODO: confirm against implementors).
    fn record_failure(&self, group_id: u64, last_included_index: u64, error: &str) -> bool;
}
/// Shared async callback for vshard envelopes: takes the raw envelope bytes
/// and resolves to the raw response bytes (or an error).
pub type VShardEnvelopeHandler = Arc<
    dyn Fn(Vec<u8>) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>> + Send>>
        + Send
        + Sync,
>;
/// Driver for all raft groups hosted on one node: a single periodic loop
/// ([`RaftLoop::run`]) ticks the [`MultiRaft`] state machine and feeds
/// committed entries to the configured appliers.
pub struct RaftLoop<A: CommitApplier, P: PlanExecutor = NoopPlanExecutor> {
    /// This node's raft id (taken from `MultiRaft::node_id`).
    pub(super) node_id: u64,
    /// All raft groups on this node, locked for cross-thread access.
    pub(super) multi_raft: Arc<Mutex<MultiRaft>>,
    /// Network transport used to reach peer nodes.
    pub(super) transport: Arc<NexarTransport>,
    /// Shared view of the cluster topology.
    pub(super) topology: Arc<RwLock<ClusterTopology>>,
    /// Applies committed entries for data groups.
    pub(super) applier: A,
    /// Applies committed entries for the metadata group (no-op by default).
    pub(super) metadata_applier: Arc<dyn MetadataApplier>,
    /// Executes forwarded plans (no-op by default).
    pub(super) plan_executor: Arc<P>,
    /// Period between driver ticks; defaults to [`DEFAULT_TICK_INTERVAL`].
    pub(super) tick_interval: Duration,
    /// Optional async handler for vshard envelopes.
    pub(super) vshard_handler: Option<VShardEnvelopeHandler>,
    /// Optional cluster catalog; when set, the persisted cluster epoch is
    /// loaded from it (see `with_catalog`).
    pub(super) catalog: Option<Arc<ClusterCatalog>>,
    /// Broadcasts `true` once shutdown begins (see `begin_shutdown`).
    pub(super) shutdown_watch: tokio::sync::watch::Sender<bool>,
    /// Up/tick-duration metrics for this loop.
    pub(super) loop_metrics: Arc<LoopMetrics>,
    /// Broadcasts readiness to subscribers (see `subscribe_ready`);
    /// set elsewhere — not flipped in this file.
    pub(super) ready_watch: tokio::sync::watch::Sender<bool>,
    /// Watchers notified as groups advance their applied index.
    pub(super) group_watchers: Arc<GroupAppliedWatchers>,
    /// Whether this node was metadata-group leader on the previous tick —
    /// presumably used to detect leadership transitions (set outside this
    /// chunk; TODO confirm).
    pub(super) prev_metadata_leader: std::sync::atomic::AtomicBool,
    /// Optional quarantine hook consulted around snapshot installs.
    pub(super) snapshot_quarantine_hook: Option<Arc<dyn SnapshotQuarantineHook>>,
    /// In-progress chunked snapshot installs.
    pub(super) partial_snapshots: Arc<crate::install_snapshot::PartialSnapshotMap>,
    /// Optional data directory; when set, orphaned partial snapshot files
    /// are garbage-collected at startup (see `run`).
    pub(super) data_dir: Option<std::path::PathBuf>,
    /// Chunk size for snapshot transfer (default 4 MiB).
    pub(super) snapshot_chunk_bytes: u64,
    /// Max age before orphaned partial snapshot files are swept (default 300 s).
    pub(super) orphan_partial_max_age_secs: u64,
}
impl<A: CommitApplier> RaftLoop<A> {
    /// Builds a raft loop over `multi_raft` with default collaborators:
    /// a no-op metadata applier, a no-op plan executor, the default tick
    /// interval, and fresh shutdown/ready watch channels.
    ///
    /// Customize with the `with_*` builder methods before spawning
    /// [`RaftLoop::run`].
    pub fn new(
        multi_raft: MultiRaft,
        transport: Arc<NexarTransport>,
        topology: Arc<RwLock<ClusterTopology>>,
        applier: A,
    ) -> Self {
        // The initial receivers are discarded; consumers subscribe on demand
        // via the senders (e.g. `subscribe_ready`).
        let (shutdown_tx, _shutdown_rx) = tokio::sync::watch::channel(false);
        let (ready_tx, _ready_rx) = tokio::sync::watch::channel(false);
        let id = multi_raft.node_id();
        Self {
            node_id: id,
            multi_raft: Arc::new(Mutex::new(multi_raft)),
            transport,
            topology,
            applier,
            metadata_applier: Arc::new(NoopMetadataApplier),
            plan_executor: Arc::new(NoopPlanExecutor),
            tick_interval: DEFAULT_TICK_INTERVAL,
            vshard_handler: None,
            catalog: None,
            shutdown_watch: shutdown_tx,
            ready_watch: ready_tx,
            loop_metrics: LoopMetrics::new("raft_tick_loop"),
            group_watchers: Arc::new(GroupAppliedWatchers::new()),
            prev_metadata_leader: std::sync::atomic::AtomicBool::new(false),
            snapshot_quarantine_hook: None,
            // Start with no in-flight chunked snapshot installs.
            partial_snapshots: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
            data_dir: None,
            snapshot_chunk_bytes: 4 * 1024 * 1024,
            orphan_partial_max_age_secs: 300,
        }
    }
}
impl<A: CommitApplier, P: PlanExecutor> RaftLoop<A, P> {
    /// Installs the snapshot-quarantine hook on an already-built loop
    /// (mutating counterpart of [`Self::with_snapshot_quarantine_hook`]).
    pub fn set_snapshot_quarantine_hook(&mut self, hook: Arc<dyn SnapshotQuarantineHook>) {
        self.snapshot_quarantine_hook = Some(hook);
    }

    /// Replaces the plan executor, changing the loop's executor type
    /// parameter; every other field is carried over unchanged.
    pub fn with_plan_executor<P2: PlanExecutor>(self, executor: Arc<P2>) -> RaftLoop<A, P2> {
        RaftLoop {
            node_id: self.node_id,
            multi_raft: self.multi_raft,
            transport: self.transport,
            topology: self.topology,
            applier: self.applier,
            metadata_applier: self.metadata_applier,
            plan_executor: executor,
            tick_interval: self.tick_interval,
            vshard_handler: self.vshard_handler,
            catalog: self.catalog,
            shutdown_watch: self.shutdown_watch,
            ready_watch: self.ready_watch,
            loop_metrics: self.loop_metrics,
            group_watchers: self.group_watchers,
            prev_metadata_leader: self.prev_metadata_leader,
            snapshot_quarantine_hook: self.snapshot_quarantine_hook,
            partial_snapshots: self.partial_snapshots,
            data_dir: self.data_dir,
            snapshot_chunk_bytes: self.snapshot_chunk_bytes,
            orphan_partial_max_age_secs: self.orphan_partial_max_age_secs,
        }
    }

    /// Uses an externally shared set of group-applied watchers.
    pub fn with_group_watchers(mut self, watchers: Arc<GroupAppliedWatchers>) -> Self {
        self.group_watchers = watchers;
        self
    }

    /// Builder form of [`Self::set_snapshot_quarantine_hook`].
    pub fn with_snapshot_quarantine_hook(mut self, hook: Arc<dyn SnapshotQuarantineHook>) -> Self {
        self.snapshot_quarantine_hook = Some(hook);
        self
    }

    /// Sets the data directory used for partial-snapshot GC at startup.
    pub fn with_data_dir(mut self, data_dir: std::path::PathBuf) -> Self {
        self.data_dir = Some(data_dir);
        self
    }

    /// Overrides the snapshot transfer chunk size (default 4 MiB).
    pub fn with_snapshot_chunk_bytes(mut self, chunk_bytes: u64) -> Self {
        self.snapshot_chunk_bytes = chunk_bytes;
        self
    }

    /// Overrides the max age before orphaned partial snapshot files are
    /// swept (default 300 s).
    pub fn with_orphan_partial_max_age_secs(mut self, secs: u64) -> Self {
        self.orphan_partial_max_age_secs = secs;
        self
    }

    /// Shared handle to the group-applied watchers.
    pub fn group_watchers(&self) -> Arc<GroupAppliedWatchers> {
        Arc::clone(&self.group_watchers)
    }

    /// Shared handle to this loop's metrics.
    pub fn loop_metrics(&self) -> Arc<LoopMetrics> {
        Arc::clone(&self.loop_metrics)
    }

    /// Number of raft groups currently hosted on this node.
    /// NOTE(review): the name suggests "pending" groups, but this returns
    /// the total `group_count()` — confirm intent with callers.
    pub fn pending_groups(&self) -> usize {
        // Recover from a poisoned lock rather than panicking: the guarded
        // data is still readable for a count.
        let mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
        mr.group_count()
    }

    /// Broadcasts shutdown to all subscribers of the internal watch.
    pub fn begin_shutdown(&self) {
        let _ = self.shutdown_watch.send(true);
    }

    /// Subscribes to the readiness watch (flipped elsewhere).
    pub fn subscribe_ready(&self) -> tokio::sync::watch::Receiver<bool> {
        self.ready_watch.subscribe()
    }

    /// Sets the async handler for vshard envelopes.
    pub fn with_vshard_handler(mut self, handler: VShardEnvelopeHandler) -> Self {
        self.vshard_handler = Some(handler);
        self
    }

    /// Replaces the metadata-group applier (defaults to a no-op).
    pub fn with_metadata_applier(mut self, applier: Arc<dyn MetadataApplier>) -> Self {
        self.metadata_applier = applier;
        self
    }

    /// Overrides the tick interval (defaults to [`DEFAULT_TICK_INTERVAL`]).
    pub fn with_tick_interval(mut self, interval: Duration) -> Self {
        self.tick_interval = interval;
        self
    }

    /// Attaches the cluster catalog and seeds the local cluster epoch from
    /// its persisted value (falling back to 0 with a warning on failure).
    pub fn with_catalog(mut self, catalog: Arc<ClusterCatalog>) -> Self {
        if let Err(e) = crate::cluster_epoch::init_local_cluster_epoch_from_catalog(&catalog) {
            tracing::warn!(error = %e, "failed to load persisted cluster_epoch; defaulting to 0");
        }
        self.catalog = Some(catalog);
        self
    }

    /// This node's raft id.
    pub fn node_id(&self) -> u64 {
        self.node_id
    }

    /// Drives the tick loop until `shutdown` observes `true` or its sender
    /// is dropped. Marks the loop up/down in metrics and, when a data
    /// directory is configured, garbage-collects orphaned partial snapshot
    /// files once at startup.
    pub async fn run(&self, mut shutdown: tokio::sync::watch::Receiver<bool>) {
        let mut interval = tokio::time::interval(self.tick_interval);
        // Delay (rather than burst) after a missed tick so one slow tick
        // doesn't trigger a flurry of back-to-back ticks.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        self.loop_metrics.set_up(true);
        if let Some(ref dir) = self.data_dir {
            match crate::install_snapshot::gc::sweep_orphans(dir, self.orphan_partial_max_age_secs)
            {
                Ok((removed, errs)) => {
                    if removed > 0 {
                        tracing::info!(removed, "startup: removed orphaned partial snapshot files");
                    }
                    for e in errs {
                        tracing::warn!(error = %e, "startup: partial snapshot GC error");
                    }
                }
                Err(e) => {
                    tracing::warn!(error = %e, "startup: failed to sweep partial snapshot directory");
                }
            }
        }
        loop {
            tokio::select! {
                _ = interval.tick() => {
                    let started = Instant::now();
                    self.do_tick();
                    self.loop_metrics.observe(started.elapsed());
                }
                changed = shutdown.changed() => {
                    match changed {
                        Ok(()) => {
                            if *shutdown.borrow() {
                                debug!("raft loop shutting down");
                                self.begin_shutdown();
                                break;
                            }
                        }
                        // BUGFIX: previously the Err case (sender dropped)
                        // was ignored, but a dropped sender makes `changed()`
                        // resolve with Err immediately on every select
                        // iteration — busy-spinning the loop — and shutdown
                        // could never be signalled anyway. Exit instead.
                        Err(_) => {
                            debug!("raft loop shutdown channel closed; exiting");
                            self.begin_shutdown();
                            break;
                        }
                    }
                }
            }
        }
        self.loop_metrics.set_up(false);
    }

    /// Shared handle to the underlying multi-raft state machine.
    pub fn multi_raft_handle(&self) -> Arc<Mutex<crate::multi_raft::MultiRaft>> {
        self.multi_raft.clone()
    }

    /// Snapshot of per-group statuses.
    pub fn group_statuses(&self) -> Vec<crate::multi_raft::GroupStatus> {
        // As in `pending_groups`, a poisoned lock is recovered rather than
        // propagated: statuses are read-only diagnostics.
        let mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
        mr.group_statuses()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::routing::RoutingTable;
    use nodedb_types::config::tuning::ClusterTransportTuning;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Instant;

    /// Test applier that counts how many entries it has applied. The same
    /// shared counter also backs the metadata applier it hands out, so one
    /// count observes commits on both data and metadata groups.
    pub(crate) struct CountingApplier {
        applied: Arc<AtomicU64>,
    }

    impl CountingApplier {
        pub(crate) fn new() -> Self {
            Self {
                applied: Arc::new(AtomicU64::new(0)),
            }
        }

        /// Total entries applied so far (data + metadata combined).
        pub(crate) fn count(&self) -> u64 {
            self.applied.load(Ordering::Relaxed)
        }

        /// Metadata applier sharing this applier's counter.
        pub(crate) fn metadata_applier(&self) -> Arc<CountingMetadataApplier> {
            Arc::new(CountingMetadataApplier {
                applied: self.applied.clone(),
            })
        }
    }

    impl CommitApplier for CountingApplier {
        fn apply_committed(&self, _group_id: u64, entries: &[LogEntry]) -> u64 {
            self.applied
                .fetch_add(entries.len() as u64, Ordering::Relaxed);
            // Report the index of the last applied entry (0 for an empty batch).
            entries.last().map(|e| e.index).unwrap_or(0)
        }
    }

    /// Metadata-side counterpart of [`CountingApplier`]; shares its counter.
    pub(crate) struct CountingMetadataApplier {
        applied: Arc<AtomicU64>,
    }

    impl MetadataApplier for CountingMetadataApplier {
        fn apply(&self, entries: &[(u64, Vec<u8>)]) -> u64 {
            self.applied
                .fetch_add(entries.len() as u64, Ordering::Relaxed);
            entries.last().map(|(idx, _)| *idx).unwrap_or(0)
        }
    }

    /// Insecure loopback transport on an ephemeral port for `node_id`.
    fn make_transport(node_id: u64) -> Arc<NexarTransport> {
        Arc::new(
            NexarTransport::new(
                node_id,
                "127.0.0.1:0".parse().unwrap(),
                crate::transport::credentials::TransportCredentials::Insecure,
            )
            .unwrap(),
        )
    }

    /// Single-node cluster: the loop alone should elect itself leader,
    /// commit the leader no-op, and then commit a proposed entry.
    #[tokio::test]
    async fn single_node_raft_loop_commits() {
        let dir = tempfile::tempdir().unwrap();
        let transport = make_transport(1);
        let rt = RoutingTable::uniform(1, &[1], 1);
        let mut mr = MultiRaft::new(1, rt, dir.path().to_path_buf());
        // Group 0 is the metadata group; group 1 is a data group.
        mr.add_group(0, vec![]).unwrap();
        mr.add_group(1, vec![]).unwrap();
        // Force an already-expired election deadline so the first tick
        // starts an election immediately instead of waiting the timeout.
        for node in mr.groups_mut().values_mut() {
            node.election_deadline_override(Instant::now() - Duration::from_millis(1));
        }
        let applier = CountingApplier::new();
        let meta = applier.metadata_applier();
        let topo = Arc::new(RwLock::new(ClusterTopology::new()));
        let raft_loop =
            Arc::new(RaftLoop::new(mr, transport, topo, applier).with_metadata_applier(meta));
        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let rl = raft_loop.clone();
        let run_handle = tokio::spawn(async move {
            rl.run(shutdown_rx).await;
        });
        // Give the loop a few ticks to elect and commit the leader no-op.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(
            raft_loop.applier.count() >= 1,
            "expected at least 1 applied entry (no-op), got {}",
            raft_loop.applier.count()
        );
        // Index 1 is the no-op, so the proposal lands at index >= 2.
        let (_gid, idx) = raft_loop.propose(0, b"hello".to_vec()).unwrap();
        assert!(idx >= 2);
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(
            raft_loop.applier.count() >= 2,
            "expected at least 2 applied entries, got {}",
            raft_loop.applier.count()
        );
        shutdown_tx.send(true).unwrap();
        run_handle.abort();
    }

    /// Three-node cluster over real transports: node 1 (with an expired
    /// election deadline) should win the election, commit the no-op, and
    /// replicate a proposed entry to nodes 2 and 3.
    #[tokio::test]
    async fn three_node_election_over_quic() {
        let t1 = make_transport(1);
        let t2 = make_transport(2);
        let t3 = make_transport(3);
        // Full mesh: every node knows every peer's ephemeral address.
        t1.register_peer(2, t2.local_addr());
        t1.register_peer(3, t3.local_addr());
        t2.register_peer(1, t1.local_addr());
        t2.register_peer(3, t3.local_addr());
        t3.register_peer(1, t1.local_addr());
        t3.register_peer(2, t2.local_addr());
        let rt = RoutingTable::uniform(1, &[1, 2, 3], 3);
        let dir1 = tempfile::tempdir().unwrap();
        let mut mr1 = MultiRaft::new(1, rt.clone(), dir1.path().to_path_buf());
        mr1.add_group(0, vec![2, 3]).unwrap();
        mr1.add_group(1, vec![2, 3]).unwrap();
        // Only node 1 gets an expired deadline, biasing it to win elections.
        for node in mr1.groups_mut().values_mut() {
            node.election_deadline_override(Instant::now() - Duration::from_millis(1));
        }
        // Nodes 2 and 3 use the configured (longer) election timeouts so
        // they don't race node 1 for leadership.
        let transport_tuning = ClusterTransportTuning::default();
        let election_timeout_min =
            Duration::from_millis(transport_tuning.effective_election_timeout_min_ms());
        let election_timeout_max =
            Duration::from_millis(transport_tuning.effective_election_timeout_max_ms());
        let dir2 = tempfile::tempdir().unwrap();
        let mut mr2 = MultiRaft::new(2, rt.clone(), dir2.path().to_path_buf())
            .with_election_timeout(election_timeout_min, election_timeout_max);
        mr2.add_group(0, vec![1, 3]).unwrap();
        mr2.add_group(1, vec![1, 3]).unwrap();
        let dir3 = tempfile::tempdir().unwrap();
        let mut mr3 = MultiRaft::new(3, rt.clone(), dir3.path().to_path_buf())
            .with_election_timeout(election_timeout_min, election_timeout_max);
        mr3.add_group(0, vec![1, 2]).unwrap();
        mr3.add_group(1, vec![1, 2]).unwrap();
        let a1 = CountingApplier::new();
        let m1 = a1.metadata_applier();
        let a2 = CountingApplier::new();
        let m2 = a2.metadata_applier();
        let a3 = CountingApplier::new();
        let m3 = a3.metadata_applier();
        let topo1 = Arc::new(RwLock::new(ClusterTopology::new()));
        let topo2 = Arc::new(RwLock::new(ClusterTopology::new()));
        let topo3 = Arc::new(RwLock::new(ClusterTopology::new()));
        let rl1 = Arc::new(RaftLoop::new(mr1, t1.clone(), topo1, a1).with_metadata_applier(m1));
        let rl2 = Arc::new(RaftLoop::new(mr2, t2.clone(), topo2, a2).with_metadata_applier(m2));
        let rl3 = Arc::new(RaftLoop::new(mr3, t3.clone(), topo3, a3).with_metadata_applier(m3));
        let (shutdown_tx, _) = tokio::sync::watch::channel(false);
        // Spawn a transport server task and a raft-loop task per node.
        let rl2_h = rl2.clone();
        let sr2 = shutdown_tx.subscribe();
        tokio::spawn(async move { t2.serve(rl2_h, sr2).await });
        let rl3_h = rl3.clone();
        let sr3 = shutdown_tx.subscribe();
        tokio::spawn(async move { t3.serve(rl3_h, sr3).await });
        let rl1_r = rl1.clone();
        let sr1 = shutdown_tx.subscribe();
        tokio::spawn(async move { rl1_r.run(sr1).await });
        let rl2_r = rl2.clone();
        let sr2r = shutdown_tx.subscribe();
        tokio::spawn(async move { rl2_r.run(sr2r).await });
        let rl3_r = rl3.clone();
        let sr3r = shutdown_tx.subscribe();
        tokio::spawn(async move { rl3_r.run(sr3r).await });
        let rl1_h = rl1.clone();
        let sr1h = shutdown_tx.subscribe();
        tokio::spawn(async move { t1.serve(rl1_h, sr1h).await });
        // Wait (bounded) for node 1 to commit the leader no-op.
        let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
        loop {
            if rl1.applier.count() >= 1 {
                break;
            }
            assert!(
                tokio::time::Instant::now() < deadline,
                "node 1 should have committed at least the no-op, got {}",
                rl1.applier.count()
            );
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        let (_gid, idx) = rl1.propose(0, b"distributed-cmd".to_vec()).unwrap();
        assert!(idx >= 2);
        // Wait (bounded) for the proposal to replicate to all three nodes.
        let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
        loop {
            if rl1.applier.count() >= 2 && rl2.applier.count() >= 1 && rl3.applier.count() >= 1 {
                break;
            }
            assert!(
                tokio::time::Instant::now() < deadline,
                "replication timed out: n1={}, n2={}, n3={}",
                rl1.applier.count(),
                rl2.applier.count(),
                rl3.applier.count()
            );
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        shutdown_tx.send(true).unwrap();
    }
}