//! nodedb_cluster/multi_raft.rs — Multi-Raft coordinator for a single node.

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::time::Duration;
4
5use tracing::{debug, info};
6
7use nodedb_raft::node::RaftConfig;
8use nodedb_raft::{
9    AppendEntriesRequest, AppendEntriesResponse, RaftNode, Ready, RequestVoteRequest,
10    RequestVoteResponse,
11};
12
13use crate::error::{ClusterError, Result};
14use crate::raft_storage::RedbLogStorage;
15use crate::routing::RoutingTable;
16
/// Snapshot of a single Raft group's state for observability.
#[derive(Debug, Clone)]
pub struct GroupStatus {
    /// ID of the Raft group this snapshot describes.
    pub group_id: u64,
    /// Role as a human-readable string ("Leader", "Follower", "Candidate", "Learner").
    pub role: String,
    /// Leader node ID as reported by the group's Raft node; 0 is treated
    /// elsewhere in this module as "no leader known" (see `leader_for_vshard`).
    pub leader_id: u64,
    /// Current term reported by the group's Raft node.
    pub term: u64,
    /// Commit index reported by the group's Raft node.
    pub commit_index: u64,
    /// Last applied index reported by the group's Raft node.
    pub last_applied: u64,
    /// Number of members listed for this group in the routing table
    /// (0 if the routing table has no entry for the group).
    pub member_count: usize,
    /// Number of vShards the routing table maps to this group.
    pub vshard_count: usize,
}
30
/// Multi-Raft coordinator managing multiple Raft groups on a single node.
///
/// This coordinator:
/// - Manages all Raft groups hosted on this node
/// - Batches heartbeats across groups sharing the same leader
/// - Routes incoming RPCs to the correct group
/// - Collects `Ready` output from all groups for the caller to execute
pub struct MultiRaft {
    /// This node's ID.
    node_id: u64,
    /// Raft groups hosted on this node (group_id → RaftNode).
    groups: HashMap<u64, RaftNode<RedbLogStorage>>,
    /// Routing table (vShard → group mapping).
    routing: RoutingTable,
    /// Default election timeout range (defaults 150–300 ms, set in `new`;
    /// overridable via `with_election_timeout`).
    election_timeout_min: Duration,
    election_timeout_max: Duration,
    /// Heartbeat interval (default 50 ms, set in `new`; overridable via
    /// `with_heartbeat_interval`).
    heartbeat_interval: Duration,
    /// Data directory for persistent Raft log storage. Each group's log is
    /// opened at `<data_dir>/raft/group-<group_id>.redb` (see `add_group`).
    data_dir: PathBuf,
}
53
/// Aggregated output from all Raft groups after a tick.
///
/// `MultiRaft::tick` only pushes entries for groups whose `Ready` was
/// non-empty, so every element here has work for the caller to execute.
#[derive(Debug, Default)]
pub struct MultiRaftReady {
    /// Per-group ready output: (group_id, Ready).
    pub groups: Vec<(u64, Ready)>,
}
60
61impl MultiRaftReady {
62    pub fn is_empty(&self) -> bool {
63        self.groups.iter().all(|(_gid, r)| r.is_empty())
64    }
65
66    /// Total committed entries across all groups.
67    pub fn total_committed(&self) -> usize {
68        self.groups
69            .iter()
70            .map(|(_, r)| r.committed_entries.len())
71            .sum()
72    }
73}
74
impl MultiRaft {
    /// Create a coordinator for `node_id` with no groups and default timing
    /// (election timeout 150–300 ms, heartbeat every 50 ms).
    pub fn new(node_id: u64, routing: RoutingTable, data_dir: PathBuf) -> Self {
        Self {
            node_id,
            groups: HashMap::new(),
            routing,
            election_timeout_min: Duration::from_millis(150),
            election_timeout_max: Duration::from_millis(300),
            heartbeat_interval: Duration::from_millis(50),
            data_dir,
        }
    }

    /// Configure election timeout range.
    pub fn with_election_timeout(mut self, min: Duration, max: Duration) -> Self {
        self.election_timeout_min = min;
        self.election_timeout_max = max;
        self
    }

    /// Configure heartbeat interval.
    pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
        self.heartbeat_interval = interval;
        self
    }

    /// Initialize a Raft group on this node.
    ///
    /// Opens (or creates) persistent log storage at
    /// `<data_dir>/raft/group-<group_id>.redb` and registers the group.
    ///
    /// NOTE(review): `HashMap::insert` means re-adding an existing `group_id`
    /// silently replaces the previous `RaftNode` — confirm callers never
    /// re-add a live group.
    ///
    /// # Errors
    /// Returns `ClusterError::Transport` if the storage file cannot be opened.
    pub fn add_group(&mut self, group_id: u64, peers: Vec<u64>) -> Result<()> {
        let config = RaftConfig {
            node_id: self.node_id,
            group_id,
            peers,
            election_timeout_min: self.election_timeout_min,
            election_timeout_max: self.election_timeout_max,
            heartbeat_interval: self.heartbeat_interval,
        };

        let storage_path = self.data_dir.join(format!("raft/group-{group_id}.redb"));
        let storage = RedbLogStorage::open(&storage_path).map_err(|e| ClusterError::Transport {
            detail: format!("failed to open raft storage for group {group_id}: {e}"),
        })?;
        let node = RaftNode::new(config, storage);
        self.groups.insert(group_id, node);

        info!(node = self.node_id, group = group_id, path = %storage_path.display(), "added raft group with persistent storage");
        Ok(())
    }

    /// Tick all Raft groups. Returns aggregated ready output.
    ///
    /// Groups whose `Ready` is empty are omitted from the result.
    pub fn tick(&mut self) -> MultiRaftReady {
        let mut ready = MultiRaftReady::default();

        for (&group_id, node) in &mut self.groups {
            node.tick();
            let r = node.take_ready();
            if !r.is_empty() {
                ready.groups.push((group_id, r));
            }
        }

        ready
    }

    /// Propose a command to the Raft group that owns the given vShard.
    ///
    /// Returns `(group_id, log_index)` on success.
    ///
    /// # Errors
    /// Fails if the vShard has no routing entry, the owning group is not
    /// hosted on this node (`GroupNotFound`), or the node rejects the proposal.
    pub fn propose(&mut self, vshard_id: u16, data: Vec<u8>) -> Result<(u64, u64)> {
        let group_id = self.routing.group_for_vshard(vshard_id)?;
        let node = self
            .groups
            .get_mut(&group_id)
            .ok_or(ClusterError::GroupNotFound { group_id })?;
        let log_index = node.propose(data)?;
        Ok((group_id, log_index))
    }

    /// Route an AppendEntries RPC to the correct group.
    ///
    /// # Errors
    /// Returns `GroupNotFound` if `req.group_id` is not hosted here.
    pub fn handle_append_entries(
        &mut self,
        req: &AppendEntriesRequest,
    ) -> Result<AppendEntriesResponse> {
        let node = self
            .groups
            .get_mut(&req.group_id)
            .ok_or(ClusterError::GroupNotFound {
                group_id: req.group_id,
            })?;
        Ok(node.handle_append_entries(req))
    }

    /// Route a RequestVote RPC to the correct group.
    ///
    /// # Errors
    /// Returns `GroupNotFound` if `req.group_id` is not hosted here.
    pub fn handle_request_vote(&mut self, req: &RequestVoteRequest) -> Result<RequestVoteResponse> {
        let node = self
            .groups
            .get_mut(&req.group_id)
            .ok_or(ClusterError::GroupNotFound {
                group_id: req.group_id,
            })?;
        Ok(node.handle_request_vote(req))
    }

    /// Route an InstallSnapshot RPC to the correct group.
    ///
    /// # Errors
    /// Returns `GroupNotFound` if `req.group_id` is not hosted here.
    pub fn handle_install_snapshot(
        &mut self,
        req: &nodedb_raft::InstallSnapshotRequest,
    ) -> Result<nodedb_raft::InstallSnapshotResponse> {
        let node = self
            .groups
            .get_mut(&req.group_id)
            .ok_or(ClusterError::GroupNotFound {
                group_id: req.group_id,
            })?;
        Ok(node.handle_install_snapshot(req))
    }

    /// Get the current term and snapshot metadata for a group (for building InstallSnapshot RPCs).
    ///
    /// Returns `(current_term, snapshot_index, snapshot_term)` — in that order.
    pub fn snapshot_metadata(&self, group_id: u64) -> Result<(u64, u64, u64)> {
        let node = self
            .groups
            .get(&group_id)
            .ok_or(ClusterError::GroupNotFound { group_id })?;
        Ok((
            node.current_term(),
            node.log_snapshot_index(),
            node.log_snapshot_term(),
        ))
    }

    /// Handle AppendEntries response for a specific group.
    pub fn handle_append_entries_response(
        &mut self,
        group_id: u64,
        peer: u64,
        resp: &AppendEntriesResponse,
    ) -> Result<()> {
        let node = self
            .groups
            .get_mut(&group_id)
            .ok_or(ClusterError::GroupNotFound { group_id })?;
        node.handle_append_entries_response(peer, resp);
        Ok(())
    }

    /// Handle RequestVote response for a specific group.
    pub fn handle_request_vote_response(
        &mut self,
        group_id: u64,
        peer: u64,
        resp: &RequestVoteResponse,
    ) -> Result<()> {
        let node = self
            .groups
            .get_mut(&group_id)
            .ok_or(ClusterError::GroupNotFound { group_id })?;
        node.handle_request_vote_response(peer, resp);
        Ok(())
    }

    /// Advance applied index for a group after processing committed entries.
    pub fn advance_applied(&mut self, group_id: u64, applied_to: u64) -> Result<()> {
        let node = self
            .groups
            .get_mut(&group_id)
            .ok_or(ClusterError::GroupNotFound { group_id })?;
        node.advance_applied(applied_to);
        Ok(())
    }

    /// Shared access to the routing table.
    pub fn routing(&self) -> &RoutingTable {
        &self.routing
    }

    /// Mutable access to the routing table.
    pub fn routing_mut(&mut self) -> &mut RoutingTable {
        &mut self.routing
    }

    /// This node's ID.
    pub fn node_id(&self) -> u64 {
        self.node_id
    }

    /// Number of Raft groups hosted on this node.
    pub fn group_count(&self) -> usize {
        self.groups.len()
    }

    /// Mutable access to the underlying Raft groups (for testing / bootstrap).
    pub fn groups_mut(&mut self) -> &mut HashMap<u64, RaftNode<RedbLogStorage>> {
        &mut self.groups
    }

    /// Propose a configuration change to a Raft group.
    ///
    /// The change is proposed as a regular Raft log entry with a special
    /// prefix. When committed, the state machine applies it via
    /// [`reconfigure_group`].
    ///
    /// Returns `(group_id, log_index)` on success.
    pub fn propose_conf_change(
        &mut self,
        group_id: u64,
        change: &crate::conf_change::ConfChange,
    ) -> Result<(u64, u64)> {
        let node = self
            .groups
            .get_mut(&group_id)
            .ok_or(ClusterError::GroupNotFound { group_id })?;
        // Conf changes travel through the normal log path; the entry encoding
        // (special prefix) is produced by ConfChange::to_entry_data.
        let data = change.to_entry_data();
        let log_index = node.propose(data)?;
        Ok((group_id, log_index))
    }

    /// Apply a committed configuration change to a Raft group.
    ///
    /// Called by the state machine after a ConfChange entry is committed.
    /// Also updates the RoutingTable's group membership.
    ///
    /// # Errors
    /// Returns `GroupNotFound` if `group_id` is not hosted here.
    pub fn apply_conf_change(
        &mut self,
        group_id: u64,
        change: &crate::conf_change::ConfChange,
    ) -> Result<()> {
        use crate::conf_change::ConfChangeType;

        let node = self
            .groups
            .get_mut(&group_id)
            .ok_or(ClusterError::GroupNotFound { group_id })?;

        match change.change_type {
            ConfChangeType::AddNode | ConfChangeType::PromoteLearner => {
                node.add_peer(change.node_id);
                // Update routing table members.
                // Skipped when the routing table has no entry for this group.
                if let Some(info) = self.routing.group_info(group_id)
                    && !info.members.contains(&change.node_id)
                {
                    let mut new_members = info.members.clone();
                    new_members.push(change.node_id);
                    self.routing.set_group_members(group_id, new_members);
                }
            }
            ConfChangeType::RemoveNode => {
                node.remove_peer(change.node_id);
                if let Some(info) = self.routing.group_info(group_id) {
                    let new_members: Vec<u64> = info
                        .members
                        .iter()
                        .copied()
                        .filter(|&id| id != change.node_id)
                        .collect();
                    self.routing.set_group_members(group_id, new_members);
                }
            }
            ConfChangeType::AddLearner => {
                // Learners don't vote — just start replicating to them.
                // The RaftNode doesn't need to know about learners for voting,
                // but the leader needs to send AppendEntries to them.
                // For now, add as a regular peer (simplified — full learner
                // support would track them separately).
                // NOTE(review): this makes AddLearner indistinguishable from
                // AddNode at the RaftNode level; routing members are NOT
                // updated here, unlike the AddNode branch.
                node.add_peer(change.node_id);
            }
        }

        debug!(
            node = self.node_id,
            group = group_id,
            change_type = ?change.change_type,
            target_node = change.node_id,
            new_peers = ?self.groups.get(&group_id).map(|n| n.peers()),
            "applied conf change"
        );

        Ok(())
    }

    /// Query a peer's match_index from a specific Raft group's leader state.
    ///
    /// Returns `None` if the group is not hosted here or the node has no
    /// match_index for `peer`.
    pub fn match_index_for(&self, group_id: u64, peer: u64) -> Option<u64> {
        self.groups.get(&group_id)?.match_index_for(peer)
    }

    /// Snapshot of all Raft group states for observability.
    ///
    /// The result is sorted by `group_id` for stable output.
    pub fn group_statuses(&self) -> Vec<GroupStatus> {
        let mut statuses = Vec::with_capacity(self.groups.len());
        for (&group_id, node) in &self.groups {
            let vshard_count = self.routing.vshards_for_group(group_id).len();
            // Missing routing entry yields an empty member list (count 0).
            let members = self
                .routing
                .group_info(group_id)
                .map(|info| info.members.clone())
                .unwrap_or_default();

            statuses.push(GroupStatus {
                group_id,
                role: format!("{:?}", node.role()),
                leader_id: node.leader_id(),
                term: node.current_term(),
                commit_index: node.commit_index(),
                last_applied: node.last_applied(),
                member_count: members.len(),
                vshard_count,
            });
        }
        statuses.sort_by_key(|s| s.group_id);
        statuses
    }

    /// Get the leader for a given vShard (from local group state).
    ///
    /// Returns `Ok(None)` when no leader is known yet — a `leader_id()` of 0
    /// is the "unknown leader" sentinel.
    pub fn leader_for_vshard(&self, vshard_id: u16) -> Result<Option<u64>> {
        let group_id = self.routing.group_for_vshard(vshard_id)?;
        let node = self
            .groups
            .get(&group_id)
            .ok_or(ClusterError::GroupNotFound { group_id })?;
        let lid = node.leader_id();
        Ok(if lid == 0 { None } else { Some(lid) })
    }
}
389
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Move every group's election deadline into the past so the next
    /// `tick` starts an election.
    fn expire_election_timers(mr: &mut MultiRaft) {
        let past = Instant::now() - Duration::from_millis(1);
        for node in mr.groups.values_mut() {
            node.election_deadline_override(past);
        }
    }

    #[test]
    fn single_node_multi_raft() {
        let dir = tempfile::tempdir().unwrap();
        let routing = RoutingTable::uniform(4, &[1], 1);
        let mut multi = MultiRaft::new(1, routing, dir.path().to_path_buf());

        // Four groups, each a single-node quorum (no peers).
        for group in 0..4 {
            multi.add_group(group, Vec::new()).unwrap();
        }
        assert_eq!(multi.group_count(), 4);

        // With expired timers, one tick lets every group elect itself leader.
        expire_election_timers(&mut multi);
        let ready = multi.tick();
        assert_eq!(ready.groups.len(), 4);
    }

    #[test]
    fn propose_routes_to_correct_group() {
        let dir = tempfile::tempdir().unwrap();
        let routing = RoutingTable::uniform(4, &[1], 1);
        let mut multi = MultiRaft::new(1, routing, dir.path().to_path_buf());

        for group in 0..4 {
            multi.add_group(group, Vec::new()).unwrap();
        }
        expire_election_timers(&mut multi);
        multi.tick();

        // Drain the post-election ready and mark its commits as applied.
        for (group, ready) in multi.tick().groups {
            if let Some(entry) = ready.committed_entries.last() {
                multi.advance_applied(group, entry.index).unwrap();
            }
        }

        // vShard 0 maps to group 0, vShard 1 to group 1, etc.
        let (_group, index) = multi.propose(0, b"cmd-shard-0".to_vec()).unwrap();
        assert!(index > 0);

        let (_group, index) = multi.propose(256, b"cmd-shard-256".to_vec()).unwrap();
        assert!(index > 0);
    }

    #[test]
    fn three_node_multi_raft_election() {
        let nodes = vec![1, 2, 3];
        let routing = RoutingTable::uniform(2, &nodes, 3);

        // One coordinator (with its own temp dir) per node.
        let dir1 = tempfile::tempdir().unwrap();
        let dir2 = tempfile::tempdir().unwrap();
        let dir3 = tempfile::tempdir().unwrap();
        let mut mr1 = MultiRaft::new(1, routing.clone(), dir1.path().to_path_buf());
        let mut mr2 = MultiRaft::new(2, routing.clone(), dir2.path().to_path_buf());
        let mut mr3 = MultiRaft::new(3, routing.clone(), dir3.path().to_path_buf());

        for group in 0..2u64 {
            mr1.add_group(group, vec![2, 3]).unwrap();
            mr2.add_group(group, vec![1, 3]).unwrap();
            mr3.add_group(group, vec![1, 2]).unwrap();
        }

        // Only node 1's timers expire, so only node 1 campaigns.
        expire_election_timers(&mut mr1);
        let ready1 = mr1.tick();

        // Deliver each vote request to its target and feed the response back.
        for (group, ready) in &ready1.groups {
            for (peer, request) in &ready.vote_requests {
                match *peer {
                    2 => {
                        let response = mr2.handle_request_vote(request).unwrap();
                        mr1.handle_request_vote_response(*group, 2, &response)
                            .unwrap();
                    }
                    3 => {
                        let response = mr3.handle_request_vote(request).unwrap();
                        mr1.handle_request_vote_response(*group, 3, &response)
                            .unwrap();
                    }
                    _ => {}
                }
            }
        }

        // Having won a majority, node 1 should lead both groups.
        for group in 0..2u64 {
            let leader = mr1.leader_for_vshard(group as u16 * 512).unwrap();
            assert_eq!(leader, Some(1));
        }
    }
}