1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::time::Duration;
4
5use tracing::{debug, info};
6
7use nodedb_raft::node::RaftConfig;
8use nodedb_raft::{
9 AppendEntriesRequest, AppendEntriesResponse, RaftNode, Ready, RequestVoteRequest,
10 RequestVoteResponse,
11};
12
13use crate::error::{ClusterError, Result};
14use crate::raft_storage::RedbLogStorage;
15use crate::routing::RoutingTable;
16
/// Point-in-time summary of one locally hosted Raft group.
///
/// Instances are built by `MultiRaft::group_statuses`.
#[derive(Debug, Clone)]
pub struct GroupStatus {
    /// Identifier of the Raft group.
    pub group_id: u64,
    /// `Debug` rendering of this node's role in the group (`format!("{:?}", node.role())`).
    pub role: String,
    /// Leader id as reported by the group's `RaftNode`; `MultiRaft::leader_for_vshard`
    /// treats `0` as "no known leader".
    pub leader_id: u64,
    /// Current Raft term of the group on this node.
    pub term: u64,
    /// Commit index of the group on this node.
    pub commit_index: u64,
    /// Last applied index of the group on this node.
    pub last_applied: u64,
    /// Number of members recorded for the group in the routing table (0 if the
    /// routing table has no info for the group).
    pub member_count: usize,
    /// Number of vshards the routing table maps to this group.
    pub vshard_count: usize,
}
30
/// Hosts many independent Raft groups on a single node and routes
/// proposals and RPCs to them.
pub struct MultiRaft {
    /// Id of this node; shared by every group it hosts.
    node_id: u64,
    /// Locally hosted Raft groups, keyed by group id.
    groups: HashMap<u64, RaftNode<RedbLogStorage>>,
    /// Vshard-to-group routing and per-group membership.
    routing: RoutingTable,
    /// Lower election-timeout bound copied into each group created by `add_group`.
    election_timeout_min: Duration,
    /// Upper election-timeout bound copied into each group created by `add_group`.
    election_timeout_max: Duration,
    /// Heartbeat interval copied into each group created by `add_group`.
    heartbeat_interval: Duration,
    /// Root directory; each group's log is stored at `raft/group-<id>.redb` beneath it.
    data_dir: PathBuf,
}
53
/// Batch of per-group `Ready` outputs collected by `MultiRaft::tick`.
///
/// `tick` only pushes groups whose `Ready` is non-empty.
#[derive(Debug, Default)]
pub struct MultiRaftReady {
    /// `(group_id, ready)` pairs, one per group that produced output.
    pub groups: Vec<(u64, Ready)>,
}
60
61impl MultiRaftReady {
62 pub fn is_empty(&self) -> bool {
63 self.groups.iter().all(|(_gid, r)| r.is_empty())
64 }
65
66 pub fn total_committed(&self) -> usize {
68 self.groups
69 .iter()
70 .map(|(_, r)| r.committed_entries.len())
71 .sum()
72 }
73}
74
75impl MultiRaft {
76 pub fn new(node_id: u64, routing: RoutingTable, data_dir: PathBuf) -> Self {
77 Self {
78 node_id,
79 groups: HashMap::new(),
80 routing,
81 election_timeout_min: Duration::from_millis(150),
82 election_timeout_max: Duration::from_millis(300),
83 heartbeat_interval: Duration::from_millis(50),
84 data_dir,
85 }
86 }
87
88 pub fn with_election_timeout(mut self, min: Duration, max: Duration) -> Self {
90 self.election_timeout_min = min;
91 self.election_timeout_max = max;
92 self
93 }
94
95 pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
97 self.heartbeat_interval = interval;
98 self
99 }
100
101 pub fn add_group(&mut self, group_id: u64, peers: Vec<u64>) -> Result<()> {
103 let config = RaftConfig {
104 node_id: self.node_id,
105 group_id,
106 peers,
107 election_timeout_min: self.election_timeout_min,
108 election_timeout_max: self.election_timeout_max,
109 heartbeat_interval: self.heartbeat_interval,
110 };
111
112 let storage_path = self.data_dir.join(format!("raft/group-{group_id}.redb"));
113 let storage = RedbLogStorage::open(&storage_path).map_err(|e| ClusterError::Transport {
114 detail: format!("failed to open raft storage for group {group_id}: {e}"),
115 })?;
116 let node = RaftNode::new(config, storage);
117 self.groups.insert(group_id, node);
118
119 info!(node = self.node_id, group = group_id, path = %storage_path.display(), "added raft group with persistent storage");
120 Ok(())
121 }
122
123 pub fn tick(&mut self) -> MultiRaftReady {
125 let mut ready = MultiRaftReady::default();
126
127 for (&group_id, node) in &mut self.groups {
128 node.tick();
129 let r = node.take_ready();
130 if !r.is_empty() {
131 ready.groups.push((group_id, r));
132 }
133 }
134
135 ready
136 }
137
138 pub fn propose(&mut self, vshard_id: u16, data: Vec<u8>) -> Result<(u64, u64)> {
142 let group_id = self.routing.group_for_vshard(vshard_id)?;
143 let node = self
144 .groups
145 .get_mut(&group_id)
146 .ok_or(ClusterError::GroupNotFound { group_id })?;
147 let log_index = node.propose(data)?;
148 Ok((group_id, log_index))
149 }
150
151 pub fn handle_append_entries(
153 &mut self,
154 req: &AppendEntriesRequest,
155 ) -> Result<AppendEntriesResponse> {
156 let node = self
157 .groups
158 .get_mut(&req.group_id)
159 .ok_or(ClusterError::GroupNotFound {
160 group_id: req.group_id,
161 })?;
162 Ok(node.handle_append_entries(req))
163 }
164
165 pub fn handle_request_vote(&mut self, req: &RequestVoteRequest) -> Result<RequestVoteResponse> {
167 let node = self
168 .groups
169 .get_mut(&req.group_id)
170 .ok_or(ClusterError::GroupNotFound {
171 group_id: req.group_id,
172 })?;
173 Ok(node.handle_request_vote(req))
174 }
175
176 pub fn handle_install_snapshot(
178 &mut self,
179 req: &nodedb_raft::InstallSnapshotRequest,
180 ) -> Result<nodedb_raft::InstallSnapshotResponse> {
181 let node = self
182 .groups
183 .get_mut(&req.group_id)
184 .ok_or(ClusterError::GroupNotFound {
185 group_id: req.group_id,
186 })?;
187 Ok(node.handle_install_snapshot(req))
188 }
189
190 pub fn snapshot_metadata(&self, group_id: u64) -> Result<(u64, u64, u64)> {
192 let node = self
193 .groups
194 .get(&group_id)
195 .ok_or(ClusterError::GroupNotFound { group_id })?;
196 Ok((
197 node.current_term(),
198 node.log_snapshot_index(),
199 node.log_snapshot_term(),
200 ))
201 }
202
203 pub fn handle_append_entries_response(
205 &mut self,
206 group_id: u64,
207 peer: u64,
208 resp: &AppendEntriesResponse,
209 ) -> Result<()> {
210 let node = self
211 .groups
212 .get_mut(&group_id)
213 .ok_or(ClusterError::GroupNotFound { group_id })?;
214 node.handle_append_entries_response(peer, resp);
215 Ok(())
216 }
217
218 pub fn handle_request_vote_response(
220 &mut self,
221 group_id: u64,
222 peer: u64,
223 resp: &RequestVoteResponse,
224 ) -> Result<()> {
225 let node = self
226 .groups
227 .get_mut(&group_id)
228 .ok_or(ClusterError::GroupNotFound { group_id })?;
229 node.handle_request_vote_response(peer, resp);
230 Ok(())
231 }
232
233 pub fn advance_applied(&mut self, group_id: u64, applied_to: u64) -> Result<()> {
235 let node = self
236 .groups
237 .get_mut(&group_id)
238 .ok_or(ClusterError::GroupNotFound { group_id })?;
239 node.advance_applied(applied_to);
240 Ok(())
241 }
242
243 pub fn routing(&self) -> &RoutingTable {
244 &self.routing
245 }
246
247 pub fn routing_mut(&mut self) -> &mut RoutingTable {
248 &mut self.routing
249 }
250
251 pub fn node_id(&self) -> u64 {
252 self.node_id
253 }
254
255 pub fn group_count(&self) -> usize {
256 self.groups.len()
257 }
258
259 pub fn groups_mut(&mut self) -> &mut HashMap<u64, RaftNode<RedbLogStorage>> {
261 &mut self.groups
262 }
263
264 pub fn propose_conf_change(
272 &mut self,
273 group_id: u64,
274 change: &crate::conf_change::ConfChange,
275 ) -> Result<(u64, u64)> {
276 let node = self
277 .groups
278 .get_mut(&group_id)
279 .ok_or(ClusterError::GroupNotFound { group_id })?;
280 let data = change.to_entry_data();
281 let log_index = node.propose(data)?;
282 Ok((group_id, log_index))
283 }
284
285 pub fn apply_conf_change(
290 &mut self,
291 group_id: u64,
292 change: &crate::conf_change::ConfChange,
293 ) -> Result<()> {
294 use crate::conf_change::ConfChangeType;
295
296 let node = self
297 .groups
298 .get_mut(&group_id)
299 .ok_or(ClusterError::GroupNotFound { group_id })?;
300
301 match change.change_type {
302 ConfChangeType::AddNode | ConfChangeType::PromoteLearner => {
303 node.add_peer(change.node_id);
304 if let Some(info) = self.routing.group_info(group_id)
306 && !info.members.contains(&change.node_id)
307 {
308 let mut new_members = info.members.clone();
309 new_members.push(change.node_id);
310 self.routing.set_group_members(group_id, new_members);
311 }
312 }
313 ConfChangeType::RemoveNode => {
314 node.remove_peer(change.node_id);
315 if let Some(info) = self.routing.group_info(group_id) {
316 let new_members: Vec<u64> = info
317 .members
318 .iter()
319 .copied()
320 .filter(|&id| id != change.node_id)
321 .collect();
322 self.routing.set_group_members(group_id, new_members);
323 }
324 }
325 ConfChangeType::AddLearner => {
326 node.add_peer(change.node_id);
332 }
333 }
334
335 debug!(
336 node = self.node_id,
337 group = group_id,
338 change_type = ?change.change_type,
339 target_node = change.node_id,
340 new_peers = ?self.groups.get(&group_id).map(|n| n.peers()),
341 "applied conf change"
342 );
343
344 Ok(())
345 }
346
347 pub fn match_index_for(&self, group_id: u64, peer: u64) -> Option<u64> {
349 self.groups.get(&group_id)?.match_index_for(peer)
350 }
351
352 pub fn group_statuses(&self) -> Vec<GroupStatus> {
354 let mut statuses = Vec::with_capacity(self.groups.len());
355 for (&group_id, node) in &self.groups {
356 let vshard_count = self.routing.vshards_for_group(group_id).len();
357 let members = self
358 .routing
359 .group_info(group_id)
360 .map(|info| info.members.clone())
361 .unwrap_or_default();
362
363 statuses.push(GroupStatus {
364 group_id,
365 role: format!("{:?}", node.role()),
366 leader_id: node.leader_id(),
367 term: node.current_term(),
368 commit_index: node.commit_index(),
369 last_applied: node.last_applied(),
370 member_count: members.len(),
371 vshard_count,
372 });
373 }
374 statuses.sort_by_key(|s| s.group_id);
375 statuses
376 }
377
378 pub fn leader_for_vshard(&self, vshard_id: u16) -> Result<Option<u64>> {
380 let group_id = self.routing.group_for_vshard(vshard_id)?;
381 let node = self
382 .groups
383 .get(&group_id)
384 .ok_or(ClusterError::GroupNotFound { group_id })?;
385 let lid = node.leader_id();
386 Ok(if lid == 0 { None } else { Some(lid) })
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Moves every group's election deadline into the past so the next
    /// `tick()` makes each group start an election.
    fn expire_election_deadlines(multi: &mut MultiRaft) {
        let past = Instant::now() - Duration::from_millis(1);
        for node in multi.groups.values_mut() {
            node.election_deadline_override(past);
        }
    }

    #[test]
    fn single_node_multi_raft() {
        let tmp = tempfile::tempdir().unwrap();
        let routing = RoutingTable::uniform(4, &[1], 1);
        let mut multi = MultiRaft::new(1, routing, tmp.path().to_path_buf());

        for group_id in 0..4 {
            multi.add_group(group_id, Vec::new()).unwrap();
        }
        assert_eq!(multi.group_count(), 4);

        expire_election_deadlines(&mut multi);

        // Every group should produce output on the first tick.
        let batch = multi.tick();
        assert_eq!(batch.groups.len(), 4);
    }

    #[test]
    fn propose_routes_to_correct_group() {
        let tmp = tempfile::tempdir().unwrap();
        let routing = RoutingTable::uniform(4, &[1], 1);
        let mut multi = MultiRaft::new(1, routing, tmp.path().to_path_buf());

        for group_id in 0..4 {
            multi.add_group(group_id, Vec::new()).unwrap();
        }
        expire_election_deadlines(&mut multi);

        multi.tick();
        let second = multi.tick();
        for (group_id, group_ready) in second.groups {
            if let Some(entry) = group_ready.committed_entries.last() {
                multi.advance_applied(group_id, entry.index).unwrap();
            }
        }

        for (vshard, payload) in [(0u16, &b"cmd-shard-0"[..]), (256, &b"cmd-shard-256"[..])] {
            let (_group, index) = multi.propose(vshard, payload.to_vec()).unwrap();
            assert!(index > 0);
        }
    }

    #[test]
    fn three_node_multi_raft_election() {
        let members = vec![1, 2, 3];
        let routing = RoutingTable::uniform(2, &members, 3);

        let (tmp1, tmp2, tmp3) = (
            tempfile::tempdir().unwrap(),
            tempfile::tempdir().unwrap(),
            tempfile::tempdir().unwrap(),
        );
        let mut mr1 = MultiRaft::new(1, routing.clone(), tmp1.path().to_path_buf());
        let mut mr2 = MultiRaft::new(2, routing.clone(), tmp2.path().to_path_buf());
        let mut mr3 = MultiRaft::new(3, routing, tmp3.path().to_path_buf());

        for group_id in 0..2u64 {
            mr1.add_group(group_id, vec![2, 3]).unwrap();
            mr2.add_group(group_id, vec![1, 3]).unwrap();
            mr3.add_group(group_id, vec![1, 2]).unwrap();
        }

        // Only node 1 is forced to campaign.
        expire_election_deadlines(&mut mr1);
        let campaign = mr1.tick();

        for (group_id, group_ready) in &campaign.groups {
            for (peer_id, vote_req) in &group_ready.vote_requests {
                let voter = match *peer_id {
                    2 => &mut mr2,
                    3 => &mut mr3,
                    _ => continue,
                };
                let resp = voter.handle_request_vote(vote_req).unwrap();
                mr1.handle_request_vote_response(*group_id, *peer_id, &resp)
                    .unwrap();
            }
        }

        for group_id in 0..2u64 {
            // Assumes vshard `group_id * 512` belongs to group `group_id` under
            // this uniform two-group routing.
            let vshard = group_id as u16 * 512;
            assert_eq!(mr1.leader_for_vshard(vshard).unwrap(), Some(1));
        }
    }
}