1use std::collections::HashSet;
4use std::path::PathBuf;
5use std::time::Duration;
6
/// Identifier of a node in the cluster.
pub type NodeId = u64;

/// Raft term number.
pub type Term = u64;

/// Index of an entry in the Raft log.
pub type LogIndex = u64;
15
/// A single-node change to cluster membership.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MembershipChange {
    /// Add a node to the cluster.
    AddNode {
        /// ID of the node being added.
        node_id: NodeId,
        /// Address string for the new node; its format is not checked here.
        address: String,
    },
    /// Remove a node from the cluster.
    RemoveNode {
        /// ID of the node being removed.
        node_id: NodeId,
    },
}
32
33#[derive(Debug, Clone, PartialEq, Eq)]
35pub struct ClusterConfig {
36 members: Vec<(NodeId, String)>,
38 version: u64,
40}
41
42impl ClusterConfig {
43 pub fn new(members: Vec<(NodeId, String)>, version: u64) -> Self {
45 Self { members, version }
46 }
47
48 pub fn member_ids(&self) -> HashSet<NodeId> {
50 self.members.iter().map(|(id, _)| *id).collect()
51 }
52
53 pub fn members(&self) -> &[(NodeId, String)] {
55 &self.members
56 }
57
58 pub fn version(&self) -> u64 {
60 self.version
61 }
62
63 pub fn contains(&self, node_id: NodeId) -> bool {
65 self.members.iter().any(|(id, _)| *id == node_id)
66 }
67
68 pub fn quorum_size(&self) -> usize {
70 self.members.len() / 2 + 1
71 }
72
73 pub fn len(&self) -> usize {
75 self.members.len()
76 }
77
78 pub fn is_empty(&self) -> bool {
80 self.members.is_empty()
81 }
82
83 pub fn with_added_member(&self, node_id: NodeId, address: String) -> Self {
85 let mut members = self.members.clone();
86 if !self.contains(node_id) {
87 members.push((node_id, address));
88 }
89 Self {
90 members,
91 version: self.version + 1,
92 }
93 }
94
95 pub fn without_member(&self, node_id: NodeId) -> Self {
97 let members: Vec<_> = self
98 .members
99 .iter()
100 .filter(|(id, _)| *id != node_id)
101 .cloned()
102 .collect();
103 Self {
104 members,
105 version: self.version + 1,
106 }
107 }
108}
109
110#[derive(Debug, Clone, PartialEq, Eq)]
116pub enum ConfigState {
117 Stable(ClusterConfig),
119 Joint {
121 old: ClusterConfig,
123 new: ClusterConfig,
125 },
126}
127
128impl ConfigState {
129 pub fn new_stable(members: Vec<(NodeId, String)>) -> Self {
131 ConfigState::Stable(ClusterConfig::new(members, 0))
132 }
133
134 pub fn all_member_ids(&self) -> HashSet<NodeId> {
136 match self {
137 ConfigState::Stable(config) => config.member_ids(),
138 ConfigState::Joint { old, new } => {
139 let mut ids = old.member_ids();
140 ids.extend(new.member_ids());
141 ids
142 }
143 }
144 }
145
146 pub fn is_joint(&self) -> bool {
148 matches!(self, ConfigState::Joint { .. })
149 }
150
151 pub fn version(&self) -> u64 {
153 match self {
154 ConfigState::Stable(config) => config.version(),
155 ConfigState::Joint { old, new } => old.version().max(new.version()),
156 }
157 }
158
159 pub fn has_quorum(&self, responding_nodes: &HashSet<NodeId>) -> bool {
164 match self {
165 ConfigState::Stable(config) => {
166 let count = config.member_ids().intersection(responding_nodes).count();
167 count >= config.quorum_size()
168 }
169 ConfigState::Joint { old, new } => {
170 let old_count = old.member_ids().intersection(responding_nodes).count();
171 let new_count = new.member_ids().intersection(responding_nodes).count();
172 old_count >= old.quorum_size() && new_count >= new.quorum_size()
173 }
174 }
175 }
176
177 pub fn stable_config(&self) -> Option<&ClusterConfig> {
179 match self {
180 ConfigState::Stable(config) => Some(config),
181 ConfigState::Joint { .. } => None,
182 }
183 }
184
185 pub fn all_members(&self) -> Vec<(NodeId, String)> {
187 match self {
188 ConfigState::Stable(config) => config.members().to_vec(),
189 ConfigState::Joint { old, new } => {
190 let mut seen = HashSet::new();
191 let mut result = Vec::new();
192 for (id, addr) in old.members().iter().chain(new.members().iter()) {
193 if seen.insert(*id) {
194 result.push((*id, addr.clone()));
195 }
196 }
197 result
198 }
199 }
200 }
201}
202
/// Role currently played by a Raft node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeState {
    Follower,
    Candidate,
    Leader,
}

impl NodeState {
    /// Returns the variant name as a static string.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Leader => "Leader",
            Self::Candidate => "Candidate",
            Self::Follower => "Follower",
        }
    }
}
224
225#[derive(Debug, Clone)]
227pub struct RaftConfig {
228 pub node_id: NodeId,
230 pub peers: Vec<NodeId>,
232 pub election_timeout_range: (u64, u64),
234 pub heartbeat_interval: u64,
236 pub max_entries_per_message: usize,
238 pub enable_compaction: bool,
240 pub snapshot_threshold: u64,
242 pub max_snapshots: usize,
244 pub snapshot_dir: Option<PathBuf>,
246 pub persistence_dir: Option<PathBuf>,
248 pub wal_dir: Option<PathBuf>,
250 pub sync_on_write: bool,
252 pub snapshot_chunk_threshold_bytes: u64,
254 pub snapshot_chunk_size_bytes: usize,
256}
257
258impl RaftConfig {
259 pub fn new(node_id: NodeId, peers: Vec<NodeId>) -> Self {
261 Self {
262 node_id,
263 peers,
264 election_timeout_range: (150, 300),
265 heartbeat_interval: 50,
266 max_entries_per_message: 100,
267 enable_compaction: true,
268 snapshot_threshold: 10000,
269 max_snapshots: 3,
270 snapshot_dir: None,
271 persistence_dir: None,
272 wal_dir: None,
273 sync_on_write: true,
274 snapshot_chunk_threshold_bytes: 4 * 1024 * 1024,
275 snapshot_chunk_size_bytes: 1024 * 1024,
276 }
277 }
278
279 pub fn random_election_timeout(&self) -> Duration {
281 use std::collections::hash_map::RandomState;
282 use std::hash::BuildHasher;
283
284 let (min, max) = self.election_timeout_range;
285 let range = max - min;
286
287 let now = std::time::SystemTime::now()
289 .duration_since(std::time::UNIX_EPOCH)
290 .map(|d| d.as_nanos())
291 .unwrap_or(0);
292
293 let random_value = RandomState::new().hash_one(now);
294
295 let timeout_ms = min + (random_value % range);
296 Duration::from_millis(timeout_ms)
297 }
298
299 pub fn heartbeat_interval(&self) -> Duration {
301 Duration::from_millis(self.heartbeat_interval)
302 }
303
304 pub fn validate(&self) -> Result<(), String> {
306 if !self.peers.contains(&self.node_id) {
308 return Err(format!("Node ID {} not found in peers list", self.node_id));
309 }
310
311 if self.peers.len() % 2 == 0 {
313 return Err(format!(
314 "Raft requires odd number of nodes, got {}",
315 self.peers.len()
316 ));
317 }
318
319 if self.peers.len() < 3 {
321 return Err(format!(
322 "Raft requires at least 3 nodes for fault tolerance, got {}",
323 self.peers.len()
324 ));
325 }
326
327 let (min, max) = self.election_timeout_range;
329 if min >= max {
330 return Err(format!(
331 "Election timeout min ({}) must be less than max ({})",
332 min, max
333 ));
334 }
335
336 if self.heartbeat_interval >= min {
338 return Err(format!(
339 "Heartbeat interval ({}) must be less than election timeout min ({})",
340 self.heartbeat_interval, min
341 ));
342 }
343
344 Ok(())
345 }
346
347 pub fn quorum_size(&self) -> usize {
349 self.peers.len() / 2 + 1
350 }
351}
352
/// Parameters for heartbeat-based failure detection.
#[derive(Debug, Clone)]
pub struct HeartbeatConfig {
    /// Interval between heartbeats, in milliseconds.
    pub interval_ms: u64,
    /// Timeout in milliseconds; `validate` requires it to exceed `interval_ms`.
    pub timeout_ms: u64,
    /// Tolerated number of missed heartbeats (presumably before a node is reported failed — confirm).
    pub max_missed: u32,
}

impl HeartbeatConfig {
    /// Creates a config from explicit values without validating them.
    pub fn new(interval_ms: u64, timeout_ms: u64, max_missed: u32) -> Self {
        HeartbeatConfig {
            max_missed,
            timeout_ms,
            interval_ms,
        }
    }

    /// Default values: 100 ms interval, 500 ms timeout, 3 missed heartbeats.
    pub fn default_config() -> Self {
        Self::new(100, 500, 3)
    }

    /// Checks that the values are usable.
    ///
    /// # Errors
    ///
    /// Returns a message if any of these hold:
    /// - the interval or the timeout is zero;
    /// - the timeout does not exceed the interval;
    /// - `max_missed` is zero.
    pub fn validate(&self) -> Result<(), String> {
        if self.interval_ms == 0 {
            Err(String::from("Heartbeat interval must be > 0"))
        } else if self.timeout_ms == 0 {
            Err(String::from("Heartbeat timeout must be > 0"))
        } else if self.timeout_ms <= self.interval_ms {
            Err(format!(
                "Heartbeat timeout ({}) must be greater than interval ({})",
                self.timeout_ms, self.interval_ms
            ))
        } else if self.max_missed == 0 {
            Err(String::from("max_missed must be > 0"))
        } else {
            Ok(())
        }
    }
}

impl Default for HeartbeatConfig {
    /// Same values as [`HeartbeatConfig::default_config`].
    fn default() -> Self {
        HeartbeatConfig::default_config()
    }
}
410
/// A fencing token that packs a 32-bit term (upper half) and a 32-bit
/// sequence number (lower half) into one `u64`.
///
/// The derived ordering compares the raw `u64`, so tokens order by term first
/// and by sequence number within a term.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct FencingToken(pub u64);

impl FencingToken {
    /// Packs `term` into the upper 32 bits and `seq` into the lower 32 bits.
    pub fn new(term: u32, seq: u32) -> Self {
        FencingToken((u64::from(term) << 32) | u64::from(seq))
    }

    /// Extracts the term from the upper 32 bits.
    pub fn term(self) -> u32 {
        (self.raw() >> 32) as u32
    }

    /// Extracts the sequence number from the lower 32 bits.
    pub fn seq(self) -> u32 {
        (self.raw() & u64::from(u32::MAX)) as u32
    }

    /// Returns the packed `u64` representation.
    pub fn raw(self) -> u64 {
        let FencingToken(packed) = self;
        packed
    }

    /// Returns a token with the same term and the sequence number incremented by one.
    ///
    /// NOTE(review): the sequence uses wrapping arithmetic. Bumping a token
    /// whose `seq` is `u32::MAX` yields `seq == 0` in the same term, which
    /// compares *less than* `self`. Confirm callers cannot reach that point if
    /// they rely on strictly increasing tokens.
    pub fn bump_seq(self) -> Self {
        Self::new(self.term(), self.seq().wrapping_add(1))
    }

    /// Returns the first token (`seq == 0`) for `term`.
    pub fn new_leader_term(term: u32) -> Self {
        FencingToken::new(term, 0)
    }
}
468
/// A change in a node's liveness status.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FailureEvent {
    /// The node is considered failed.
    NodeFailed {
        /// The node that failed.
        node_id: NodeId,
        /// Count of missed heartbeats (presumably compared against `HeartbeatConfig::max_missed` — confirm).
        missed_count: u32,
        /// Milliseconds since the node was last seen.
        last_seen_ago_ms: u64,
    },
    /// A previously failed node is reachable again.
    NodeRecovered {
        /// The node that recovered.
        node_id: NodeId,
    },
}
487
/// Unit tests for the configuration types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // --- NodeState -------------------------------------------------------

    #[test]
    fn test_node_state_as_str() {
        assert_eq!(NodeState::Follower.as_str(), "Follower");
        assert_eq!(NodeState::Candidate.as_str(), "Candidate");
        assert_eq!(NodeState::Leader.as_str(), "Leader");
    }

    // --- RaftConfig --------------------------------------------------------

    #[test]
    fn test_raft_config_new() {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        assert_eq!(config.node_id, 1);
        assert_eq!(config.peers, vec![1, 2, 3]);
        // Defaults set by `RaftConfig::new`.
        assert_eq!(config.election_timeout_range, (150, 300));
        assert_eq!(config.heartbeat_interval, 50);
    }

    #[test]
    fn test_raft_config_validate_valid() {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_raft_config_validate_node_not_in_peers() {
        let config = RaftConfig::new(4, vec![1, 2, 3]);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_raft_config_validate_even_number_of_nodes() {
        let config = RaftConfig::new(1, vec![1, 2, 3, 4]);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_raft_config_validate_too_few_nodes() {
        // A single node is odd, so this exercises the "at least 3 nodes" check.
        let config = RaftConfig::new(1, vec![1]);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_raft_config_quorum_size() {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        assert_eq!(config.quorum_size(), 2);

        let config = RaftConfig::new(1, vec![1, 2, 3, 4, 5]);
        assert_eq!(config.quorum_size(), 3);
    }

    #[test]
    fn test_random_election_timeout() {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        let timeout1 = config.random_election_timeout();
        let timeout2 = config.random_election_timeout();

        // Both draws must fall within the default 150..300 ms range.
        assert!(timeout1.as_millis() >= 150);
        assert!(timeout1.as_millis() <= 300);
        assert!(timeout2.as_millis() >= 150);
        assert!(timeout2.as_millis() <= 300);
    }

    // --- ClusterConfig -----------------------------------------------------

    #[test]
    fn test_cluster_config_new() {
        let members = vec![(1, "addr1".to_string()), (2, "addr2".to_string())];
        let cfg = ClusterConfig::new(members.clone(), 0);
        assert_eq!(cfg.len(), 2);
        assert_eq!(cfg.version(), 0);
        assert!(cfg.contains(1));
        assert!(cfg.contains(2));
        assert!(!cfg.contains(3));
    }

    #[test]
    fn test_cluster_config_quorum() {
        let members = vec![(1, "a".into()), (2, "b".into()), (3, "c".into())];
        let cfg = ClusterConfig::new(members, 0);
        assert_eq!(cfg.quorum_size(), 2);
    }

    #[test]
    fn test_cluster_config_add_remove() {
        let members = vec![(1, "a".into()), (2, "b".into()), (3, "c".into())];
        let cfg = ClusterConfig::new(members, 0);

        // Each derived config bumps the version by one.
        let cfg2 = cfg.with_added_member(4, "d".into());
        assert_eq!(cfg2.len(), 4);
        assert!(cfg2.contains(4));
        assert_eq!(cfg2.version(), 1);

        let cfg3 = cfg2.without_member(2);
        assert_eq!(cfg3.len(), 3);
        assert!(!cfg3.contains(2));
        assert_eq!(cfg3.version(), 2);
    }

    #[test]
    fn test_cluster_config_add_existing_is_noop() {
        // Only the member count is asserted here; the version is not checked.
        let members = vec![(1, "a".into()), (2, "b".into())];
        let cfg = ClusterConfig::new(members, 0);
        let cfg2 = cfg.with_added_member(1, "a2".into());
        assert_eq!(cfg2.len(), 2);
    }

    // --- ConfigState -------------------------------------------------------

    #[test]
    fn test_config_state_stable_quorum() {
        let members = vec![(1, "a".into()), (2, "b".into()), (3, "c".into())];
        let cs = ConfigState::new_stable(members);
        assert!(!cs.is_joint());

        // 1 of 3 is not a majority; 2 of 3 is.
        let mut responding = HashSet::new();
        responding.insert(1);
        assert!(!cs.has_quorum(&responding));
        responding.insert(2);
        assert!(cs.has_quorum(&responding));
    }

    #[test]
    fn test_config_state_joint_quorum() {
        let old = ClusterConfig::new(vec![(1, "a".into()), (2, "b".into()), (3, "c".into())], 0);
        let new = ClusterConfig::new(
            vec![
                (1, "a".into()),
                (2, "b".into()),
                (3, "c".into()),
                (4, "d".into()),
            ],
            1,
        );
        let cs = ConfigState::Joint {
            old: old.clone(),
            new: new.clone(),
        };
        assert!(cs.is_joint());

        // {1, 2} is a majority of `old` (2/3) but not of `new` (2/4 < 3),
        // so the joint state has no quorum.
        let mut r = HashSet::new();
        r.insert(1);
        r.insert(2);
        assert!(!cs.has_quorum(&r));

        // {1, 2, 3} is a majority of both sets.
        r.insert(3);
        assert!(cs.has_quorum(&r));
    }

    #[test]
    fn test_config_state_all_members() {
        let old = ClusterConfig::new(vec![(1, "a".into()), (2, "b".into()), (3, "c".into())], 0);
        let new = ClusterConfig::new(vec![(1, "a".into()), (2, "b".into()), (4, "d".into())], 1);
        let cs = ConfigState::Joint { old, new };
        let members = cs.all_members();
        let ids: HashSet<NodeId> = members.iter().map(|(id, _)| *id).collect();
        // Union of {1,2,3} and {1,2,4}.
        assert_eq!(ids.len(), 4);
        assert!(ids.contains(&3));
        assert!(ids.contains(&4));
    }

    #[test]
    fn test_config_state_version() {
        let cs = ConfigState::new_stable(vec![(1, "a".into())]);
        assert_eq!(cs.version(), 0);
    }

    // --- HeartbeatConfig ---------------------------------------------------

    #[test]
    fn test_heartbeat_config_new() {
        let config = HeartbeatConfig::new(100, 500, 3);
        assert_eq!(config.interval_ms, 100);
        assert_eq!(config.timeout_ms, 500);
        assert_eq!(config.max_missed, 3);
    }

    #[test]
    fn test_heartbeat_config_default() {
        let config = HeartbeatConfig::default();
        assert_eq!(config.interval_ms, 100);
        assert_eq!(config.timeout_ms, 500);
        assert_eq!(config.max_missed, 3);
    }

    #[test]
    fn test_heartbeat_config_validate_ok() {
        let config = HeartbeatConfig::new(100, 500, 3);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_heartbeat_config_validate_zero_interval() {
        let config = HeartbeatConfig::new(0, 500, 3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_heartbeat_config_validate_zero_timeout() {
        let config = HeartbeatConfig::new(100, 0, 3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_heartbeat_config_validate_timeout_less_than_interval() {
        let config = HeartbeatConfig::new(100, 50, 3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_heartbeat_config_validate_timeout_equal_interval() {
        // The timeout must be strictly greater than the interval.
        let config = HeartbeatConfig::new(100, 100, 3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_heartbeat_config_validate_zero_max_missed() {
        let config = HeartbeatConfig::new(100, 500, 0);
        assert!(config.validate().is_err());
    }

    // --- FailureEvent ------------------------------------------------------

    #[test]
    fn test_failure_event_node_failed_eq() {
        let a = FailureEvent::NodeFailed {
            node_id: 2,
            missed_count: 3,
            last_seen_ago_ms: 500,
        };
        let b = FailureEvent::NodeFailed {
            node_id: 2,
            missed_count: 3,
            last_seen_ago_ms: 500,
        };
        assert_eq!(a, b);
    }

    #[test]
    fn test_failure_event_node_recovered_eq() {
        let a = FailureEvent::NodeRecovered { node_id: 2 };
        let b = FailureEvent::NodeRecovered { node_id: 2 };
        assert_eq!(a, b);
    }

    #[test]
    fn test_failure_event_ne() {
        // Different variants for the same node are not equal.
        let a = FailureEvent::NodeFailed {
            node_id: 2,
            missed_count: 3,
            last_seen_ago_ms: 500,
        };
        let b = FailureEvent::NodeRecovered { node_id: 2 };
        assert_ne!(a, b);
    }
}