1use std::pin::Pin;
11use std::sync::{Arc, Mutex, RwLock};
12use std::time::Duration;
13
14use tracing::{debug, warn};
15
16use nodedb_raft::message::LogEntry;
17use nodedb_raft::transport::RaftTransport;
18
19use crate::conf_change::ConfChange;
20use crate::error::{ClusterError, Result};
21use crate::forward::RequestForwarder;
22use crate::health;
23use crate::multi_raft::MultiRaft;
24use crate::rpc_codec::RaftRpc;
25use crate::topology::ClusterTopology;
26use crate::transport::{NexarTransport, RaftRpcHandler};
27
28const DEFAULT_TICK_INTERVAL: Duration = Duration::from_millis(10);
34
/// Callback through which committed raft log entries are handed to the
/// application layer.
pub trait CommitApplier: Send + Sync + 'static {
    /// Apply the committed `entries` for raft group `group_id`.
    ///
    /// Returns the highest log index that was applied, or 0 when nothing
    /// was applied — the raft loop only advances the group's applied index
    /// when the returned value is > 0.
    fn apply_committed(&self, group_id: u64, entries: &[LogEntry]) -> u64;
}
45
/// Shared async callback for serving inbound `VShardEnvelope` RPCs:
/// takes the raw envelope bytes and resolves to the raw response bytes.
pub type VShardEnvelopeHandler = Arc<
    dyn Fn(Vec<u8>) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>> + Send>>
        + Send
        + Sync,
>;
55
/// Background driver for all raft groups hosted by this node: ticks the
/// state machines, ships outbound RPCs, applies committed entries, and
/// answers inbound raft/health/forward RPCs.
pub struct RaftLoop<A: CommitApplier, F: RequestForwarder = crate::forward::NoopForwarder> {
    // This node's raft id (taken from `MultiRaft::node_id` at construction).
    node_id: u64,
    // All hosted raft groups; locked in short scopes and shared with
    // spawned RPC tasks.
    multi_raft: Arc<Mutex<MultiRaft>>,
    // Transport used for outbound append-entries / vote / snapshot RPCs.
    transport: Arc<NexarTransport>,
    // Shared cluster topology, read for Ping responses and updated by
    // TopologyUpdate RPCs.
    topology: Arc<RwLock<ClusterTopology>>,
    // Application callback that consumes committed entries.
    applier: A,
    // Handler for `ForwardRequest` RPCs (defaults to a no-op forwarder).
    forwarder: Arc<F>,
    // Time between raft ticks; see `DEFAULT_TICK_INTERVAL`.
    tick_interval: Duration,
    // Optional handler for `VShardEnvelope` RPCs; such RPCs error out
    // when this is `None`.
    vshard_handler: Option<VShardEnvelopeHandler>,
}
73
74impl<A: CommitApplier> RaftLoop<A> {
75 pub fn new(
76 multi_raft: MultiRaft,
77 transport: Arc<NexarTransport>,
78 topology: Arc<RwLock<ClusterTopology>>,
79 applier: A,
80 ) -> Self {
81 let node_id = multi_raft.node_id();
82 Self {
83 node_id,
84 multi_raft: Arc::new(Mutex::new(multi_raft)),
85 transport,
86 topology,
87 applier,
88 forwarder: Arc::new(crate::forward::NoopForwarder),
89 tick_interval: DEFAULT_TICK_INTERVAL,
90 vshard_handler: None,
91 }
92 }
93}
94
95impl<A: CommitApplier, F: RequestForwarder> RaftLoop<A, F> {
96 pub fn with_forwarder(
98 multi_raft: MultiRaft,
99 transport: Arc<NexarTransport>,
100 topology: Arc<RwLock<ClusterTopology>>,
101 applier: A,
102 forwarder: Arc<F>,
103 ) -> Self {
104 let node_id = multi_raft.node_id();
105 Self {
106 node_id,
107 multi_raft: Arc::new(Mutex::new(multi_raft)),
108 transport,
109 topology,
110 applier,
111 forwarder,
112 tick_interval: DEFAULT_TICK_INTERVAL,
113 vshard_handler: None,
114 }
115 }
116
117 pub fn with_vshard_handler(mut self, handler: VShardEnvelopeHandler) -> Self {
119 self.vshard_handler = Some(handler);
120 self
121 }
122
123 pub fn with_tick_interval(mut self, interval: Duration) -> Self {
124 self.tick_interval = interval;
125 self
126 }
127
128 pub async fn run(&self, mut shutdown: tokio::sync::watch::Receiver<bool>) {
133 let mut interval = tokio::time::interval(self.tick_interval);
134 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
135
136 loop {
137 tokio::select! {
138 _ = interval.tick() => {
139 self.do_tick();
140 }
141 _ = shutdown.changed() => {
142 if *shutdown.borrow() {
143 debug!("raft loop shutting down");
144 break;
145 }
146 }
147 }
148 }
149 }
150
151 fn do_tick(&self) {
153 let ready = {
155 let mut mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
156 mr.tick()
157 };
158
159 if ready.is_empty() {
160 return;
161 }
162
163 use std::collections::HashMap as BatchMap;
167
168 let mut ae_batches: BatchMap<u64, Vec<(u64, nodedb_raft::AppendEntriesRequest)>> =
169 BatchMap::new();
170 let mut vote_batches: BatchMap<u64, Vec<(u64, nodedb_raft::RequestVoteRequest)>> =
171 BatchMap::new();
172
173 for (group_id, group_ready) in &ready.groups {
174 for (peer, req) in &group_ready.messages {
175 ae_batches
176 .entry(*peer)
177 .or_default()
178 .push((*group_id, req.clone()));
179 }
180 for (peer, req) in &group_ready.vote_requests {
181 vote_batches
182 .entry(*peer)
183 .or_default()
184 .push((*group_id, req.clone()));
185 }
186 }
187
188 for (peer, messages) in ae_batches {
190 let transport = self.transport.clone();
191 let mr = self.multi_raft.clone();
192 tokio::spawn(async move {
193 for (group_id, req) in messages {
194 match transport.append_entries(peer, req).await {
195 Ok(resp) => {
196 let mut mr = mr.lock().unwrap_or_else(|p| p.into_inner());
197 if let Err(e) = mr.handle_append_entries_response(group_id, peer, &resp)
198 {
199 debug!(group_id, peer, error = %e, "handle ae response");
200 }
201 }
202 Err(e) => {
203 warn!(group_id, peer, error = %e, "append_entries RPC failed");
204 break; }
206 }
207 }
208 });
209 }
210
211 for (peer, votes) in vote_batches {
213 let transport = self.transport.clone();
214 let mr = self.multi_raft.clone();
215 tokio::spawn(async move {
216 for (group_id, req) in votes {
217 match transport.request_vote(peer, req).await {
218 Ok(resp) => {
219 let mut mr = mr.lock().unwrap_or_else(|p| p.into_inner());
220 if let Err(e) = mr.handle_request_vote_response(group_id, peer, &resp) {
221 debug!(group_id, peer, error = %e, "handle vote response");
222 }
223 }
224 Err(e) => {
225 warn!(group_id, peer, error = %e, "request_vote RPC failed");
226 break;
227 }
228 }
229 }
230 });
231 }
232
233 for (group_id, group_ready) in ready.groups {
234 if !group_ready.committed_entries.is_empty() {
236 for entry in &group_ready.committed_entries {
238 if let Some(cc) = ConfChange::from_entry_data(&entry.data) {
239 let mut mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
240 if let Err(e) = mr.apply_conf_change(group_id, &cc) {
241 warn!(group_id, error = %e, "failed to apply conf change");
242 }
243 }
244 }
245
246 let last_applied = self
248 .applier
249 .apply_committed(group_id, &group_ready.committed_entries);
250 if last_applied > 0 {
251 let mut mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
252 if let Err(e) = mr.advance_applied(group_id, last_applied) {
253 warn!(group_id, error = %e, "failed to advance applied index");
254 }
255 }
256 }
257
258 if !group_ready.snapshots_needed.is_empty() {
260 let snapshot_meta = {
262 let mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
263 mr.snapshot_metadata(group_id).ok()
264 };
265
266 if let Some((term, snap_index, snap_term)) = snapshot_meta {
267 for peer in group_ready.snapshots_needed {
268 let transport = self.transport.clone();
269 let mr = self.multi_raft.clone();
270 let req = nodedb_raft::InstallSnapshotRequest {
271 term,
272 leader_id: self.node_id,
273 last_included_index: snap_index,
274 last_included_term: snap_term,
275 offset: 0,
276 data: vec![], done: true,
278 group_id,
279 };
280 tokio::spawn(async move {
281 match transport.install_snapshot(peer, req).await {
282 Ok(resp) => {
283 if resp.term > term {
284 let mut mr = mr.lock().unwrap_or_else(|p| p.into_inner());
285 let _ = mr.handle_append_entries_response(
287 group_id,
288 peer,
289 &nodedb_raft::AppendEntriesResponse {
290 term: resp.term,
291 success: false,
292 last_log_index: 0,
293 },
294 );
295 }
296 debug!(group_id, peer, "install_snapshot sent");
297 }
298 Err(e) => {
299 warn!(
300 group_id, peer, error = %e,
301 "install_snapshot RPC failed"
302 );
303 }
304 }
305 });
306 }
307 }
308 }
309 }
310 }
311
312 pub fn propose(&self, vshard_id: u16, data: Vec<u8>) -> Result<(u64, u64)> {
316 let mut mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
317 mr.propose(vshard_id, data)
318 }
319
320 pub fn group_statuses(&self) -> Vec<crate::multi_raft::GroupStatus> {
322 let mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
323 mr.group_statuses()
324 }
325
326 pub fn propose_conf_change(&self, group_id: u64, change: &ConfChange) -> Result<(u64, u64)> {
330 let mut mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
331 mr.propose_conf_change(group_id, change)
332 }
333}
334
impl<A: CommitApplier, F: RequestForwarder> RaftRpcHandler for RaftLoop<A, F> {
    /// Dispatch one inbound RPC to the matching raft / health / forwarding
    /// handler and wrap the result in the corresponding response variant.
    async fn handle_rpc(&self, rpc: RaftRpc) -> Result<RaftRpc> {
        match rpc {
            RaftRpc::AppendEntriesRequest(req) => {
                // On a poisoned lock, recover the inner state and continue.
                let mut mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
                let resp = mr.handle_append_entries(&req)?;
                Ok(RaftRpc::AppendEntriesResponse(resp))
            }
            RaftRpc::RequestVoteRequest(req) => {
                let mut mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
                let resp = mr.handle_request_vote(&req)?;
                Ok(RaftRpc::RequestVoteResponse(resp))
            }
            RaftRpc::InstallSnapshotRequest(req) => {
                let mut mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
                let resp = mr.handle_install_snapshot(&req)?;
                Ok(RaftRpc::InstallSnapshotResponse(resp))
            }
            RaftRpc::Ping(req) => {
                // Read the topology version in a tight scope so the read
                // lock is released before building the reply.
                let topo_version = {
                    let topo = self.topology.read().unwrap_or_else(|p| p.into_inner());
                    topo.version()
                };
                Ok(health::handle_ping(self.node_id, topo_version, &req))
            }
            RaftRpc::TopologyUpdate(update) => {
                let (_updated, ack) =
                    health::handle_topology_update(self.node_id, &self.topology, &update);
                Ok(ack)
            }
            RaftRpc::ForwardRequest(req) => {
                let resp = self.forwarder.execute_forwarded(req).await;
                Ok(RaftRpc::ForwardResponse(resp))
            }
            RaftRpc::VShardEnvelope(bytes) => {
                // Only serviceable when a handler was installed via
                // `with_vshard_handler`.
                if let Some(ref handler) = self.vshard_handler {
                    let response_bytes = handler(bytes).await?;
                    Ok(RaftRpc::VShardEnvelope(response_bytes))
                } else {
                    Err(ClusterError::Transport {
                        detail: "VShardEnvelope handler not configured".into(),
                    })
                }
            }
            // Response variants (and any future request kinds) are not
            // valid inbound requests here.
            other => Err(ClusterError::Transport {
                detail: format!("unexpected request type in RPC handler: {other:?}"),
            }),
        }
    }
}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::routing::RoutingTable;
    use nodedb_types::config::tuning::ClusterTransportTuning;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Instant;

    /// Test applier that simply counts every entry handed to it.
    struct CountingApplier {
        // Running total of applied entries across all calls.
        applied: AtomicU64,
    }

    impl CountingApplier {
        fn new() -> Self {
            Self {
                applied: AtomicU64::new(0),
            }
        }

        /// Number of entries applied so far.
        fn count(&self) -> u64 {
            self.applied.load(Ordering::Relaxed)
        }
    }

    impl CommitApplier for CountingApplier {
        fn apply_committed(&self, _group_id: u64, entries: &[LogEntry]) -> u64 {
            self.applied
                .fetch_add(entries.len() as u64, Ordering::Relaxed);
            // Report the highest applied index, or 0 for an empty batch.
            entries.last().map(|e| e.index).unwrap_or(0)
        }
    }

    /// Bind a transport for `node_id` on an ephemeral loopback port.
    fn make_transport(node_id: u64) -> Arc<NexarTransport> {
        Arc::new(NexarTransport::new(node_id, "127.0.0.1:0".parse().unwrap()).unwrap())
    }

    #[tokio::test]
    async fn single_node_raft_loop_commits() {
        let dir = tempfile::tempdir().unwrap();
        let transport = make_transport(1);
        let rt = RoutingTable::uniform(1, &[1], 1);
        let mut mr = MultiRaft::new(1, rt, dir.path().to_path_buf());
        mr.add_group(0, vec![]).unwrap();

        // Expire the election deadline so the lone node campaigns (and
        // wins) on the very first tick.
        for node in mr.groups_mut().values_mut() {
            node.election_deadline_override(Instant::now() - Duration::from_millis(1));
        }

        let applier = CountingApplier::new();
        let topo = Arc::new(RwLock::new(ClusterTopology::new()));
        let raft_loop = Arc::new(RaftLoop::new(mr, transport, topo, applier));

        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);

        let rl = raft_loop.clone();
        let run_handle = tokio::spawn(async move {
            rl.run(shutdown_rx).await;
        });

        // NOTE(review): timing-based wait — assumes 50ms is enough for the
        // post-election no-op to commit on CI hardware.
        tokio::time::sleep(Duration::from_millis(50)).await;

        assert!(
            raft_loop.applier.count() >= 1,
            "expected at least 1 applied entry (no-op), got {}",
            raft_loop.applier.count()
        );

        let (_gid, idx) = raft_loop.propose(0, b"hello".to_vec()).unwrap();
        // The no-op occupies an earlier index, so the proposal lands at >= 2.
        assert!(idx >= 2);

        tokio::time::sleep(Duration::from_millis(50)).await;

        assert!(
            raft_loop.applier.count() >= 2,
            "expected at least 2 applied entries, got {}",
            raft_loop.applier.count()
        );

        shutdown_tx.send(true).unwrap();
        run_handle.abort();
    }

    #[tokio::test]
    async fn three_node_election_over_quic() {
        let t1 = make_transport(1);
        let t2 = make_transport(2);
        let t3 = make_transport(3);

        // Full-mesh peer registration between the three transports.
        t1.register_peer(2, t2.local_addr());
        t1.register_peer(3, t3.local_addr());
        t2.register_peer(1, t1.local_addr());
        t2.register_peer(3, t3.local_addr());
        t3.register_peer(1, t1.local_addr());
        t3.register_peer(2, t2.local_addr());

        let rt = RoutingTable::uniform(1, &[1, 2, 3], 3);

        // Node 1 gets an already-expired election deadline so it campaigns
        // immediately and becomes the leader.
        let dir1 = tempfile::tempdir().unwrap();
        let mut mr1 = MultiRaft::new(1, rt.clone(), dir1.path().to_path_buf());
        mr1.add_group(0, vec![2, 3]).unwrap();
        for node in mr1.groups_mut().values_mut() {
            node.election_deadline_override(Instant::now() - Duration::from_millis(1));
        }

        // Nodes 2 and 3 use the (long) configured election timeouts so they
        // stay followers for the duration of the test.
        let transport_tuning = ClusterTransportTuning::default();
        let election_timeout_min = Duration::from_secs(transport_tuning.election_timeout_min_secs);
        let election_timeout_max = Duration::from_secs(transport_tuning.election_timeout_max_secs);

        let dir2 = tempfile::tempdir().unwrap();
        let mut mr2 = MultiRaft::new(2, rt.clone(), dir2.path().to_path_buf())
            .with_election_timeout(election_timeout_min, election_timeout_max);
        mr2.add_group(0, vec![1, 3]).unwrap();

        let dir3 = tempfile::tempdir().unwrap();
        let mut mr3 = MultiRaft::new(3, rt.clone(), dir3.path().to_path_buf())
            .with_election_timeout(election_timeout_min, election_timeout_max);
        mr3.add_group(0, vec![1, 2]).unwrap();

        let a1 = CountingApplier::new();
        let a2 = CountingApplier::new();
        let a3 = CountingApplier::new();

        let topo1 = Arc::new(RwLock::new(ClusterTopology::new()));
        let topo2 = Arc::new(RwLock::new(ClusterTopology::new()));
        let topo3 = Arc::new(RwLock::new(ClusterTopology::new()));

        let rl1 = Arc::new(RaftLoop::new(mr1, t1.clone(), topo1, a1));
        let rl2 = Arc::new(RaftLoop::new(mr2, t2.clone(), topo2, a2));
        let rl3 = Arc::new(RaftLoop::new(mr3, t3.clone(), topo3, a3));

        let (shutdown_tx, _) = tokio::sync::watch::channel(false);

        // Start the follower RPC servers first so node 1's first vote
        // requests have someone to talk to.
        let rl2_h = rl2.clone();
        let sr2 = shutdown_tx.subscribe();
        tokio::spawn(async move { t2.serve(rl2_h, sr2).await });

        let rl3_h = rl3.clone();
        let sr3 = shutdown_tx.subscribe();
        tokio::spawn(async move { t3.serve(rl3_h, sr3).await });

        // Then the tick loops for all three nodes.
        let rl1_r = rl1.clone();
        let sr1 = shutdown_tx.subscribe();
        tokio::spawn(async move { rl1_r.run(sr1).await });

        let rl2_r = rl2.clone();
        let sr2r = shutdown_tx.subscribe();
        tokio::spawn(async move { rl2_r.run(sr2r).await });

        let rl3_r = rl3.clone();
        let sr3r = shutdown_tx.subscribe();
        tokio::spawn(async move { rl3_r.run(sr3r).await });

        // Node 1's server last — it only needs to receive responses.
        let rl1_h = rl1.clone();
        let sr1h = shutdown_tx.subscribe();
        tokio::spawn(async move { t1.serve(rl1_h, sr1h).await });

        // NOTE(review): timing-based wait for the election + no-op commit.
        tokio::time::sleep(Duration::from_millis(200)).await;

        assert!(
            rl1.applier.count() >= 1,
            "node 1 should have committed at least the no-op, got {}",
            rl1.applier.count()
        );

        let (_gid, idx) = rl1.propose(0, b"distributed-cmd".to_vec()).unwrap();
        assert!(idx >= 2);

        tokio::time::sleep(Duration::from_millis(200)).await;

        assert!(
            rl1.applier.count() >= 2,
            "node 1: expected >= 2 applied, got {}",
            rl1.applier.count()
        );

        // Followers lag the leader by up to a heartbeat, so only require
        // one applied entry on nodes 2 and 3.
        assert!(
            rl2.applier.count() >= 1,
            "node 2: expected >= 1 applied, got {}",
            rl2.applier.count()
        );
        assert!(
            rl3.applier.count() >= 1,
            "node 3: expected >= 1 applied, got {}",
            rl3.applier.count()
        );

        shutdown_tx.send(true).unwrap();
    }

    #[tokio::test]
    async fn rpc_handler_routes_append_entries() {
        let dir = tempfile::tempdir().unwrap();
        let transport = make_transport(1);
        let rt = RoutingTable::uniform(1, &[1], 1);
        let mut mr = MultiRaft::new(1, rt, dir.path().to_path_buf());
        mr.add_group(0, vec![]).unwrap();

        // Make the node campaign on the first manual tick.
        for node in mr.groups_mut().values_mut() {
            node.election_deadline_override(Instant::now() - Duration::from_millis(1));
        }

        let topo = Arc::new(RwLock::new(ClusterTopology::new()));
        let raft_loop = RaftLoop::new(mr, transport, topo, CountingApplier::new());

        raft_loop.do_tick();
        tokio::time::sleep(Duration::from_millis(20)).await;

        // An empty heartbeat from a claimed leader at a much higher term.
        let req = RaftRpc::AppendEntriesRequest(nodedb_raft::AppendEntriesRequest {
            term: 99,
            leader_id: 2,
            prev_log_index: 0,
            prev_log_term: 0,
            entries: vec![],
            leader_commit: 0,
            group_id: 0,
        });

        let resp = raft_loop.handle_rpc(req).await.unwrap();
        match resp {
            RaftRpc::AppendEntriesResponse(r) => {
                // Higher term is accepted and echoed back.
                assert!(r.success);
                assert_eq!(r.term, 99);
            }
            other => panic!("expected AppendEntriesResponse, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn rpc_handler_routes_request_vote() {
        let dir = tempfile::tempdir().unwrap();
        let transport = make_transport(1);
        let rt = RoutingTable::uniform(1, &[1, 2, 3], 3);
        let mut mr = MultiRaft::new(1, rt, dir.path().to_path_buf());
        mr.add_group(0, vec![2, 3]).unwrap();

        let topo = Arc::new(RwLock::new(ClusterTopology::new()));
        let raft_loop = RaftLoop::new(mr, transport, topo, CountingApplier::new());

        // First vote request of term 1 from candidate 2 with an up-to-date
        // (empty) log.
        let req = RaftRpc::RequestVoteRequest(nodedb_raft::RequestVoteRequest {
            term: 1,
            candidate_id: 2,
            last_log_index: 0,
            last_log_term: 0,
            group_id: 0,
        });

        let resp = raft_loop.handle_rpc(req).await.unwrap();
        match resp {
            RaftRpc::RequestVoteResponse(r) => {
                assert!(r.vote_granted);
                assert_eq!(r.term, 1);
            }
            other => panic!("expected RequestVoteResponse, got {other:?}"),
        }
    }
}