1use std::collections::HashSet;
4use std::path::PathBuf;
5use std::time::Duration;
6
/// Identifier of a node in the cluster.
pub type NodeId = u64;

/// Raft term number.
pub type Term = u64;

/// Index of an entry in the Raft log.
pub type LogIndex = u64;
15
/// A single-node change to the cluster membership.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MembershipChange {
    /// Add a node together with its address.
    AddNode {
        /// Identifier of the node being added.
        node_id: NodeId,
        /// Address of the node being added; its format is not constrained here.
        address: String,
    },
    /// Remove a node from the cluster.
    RemoveNode {
        /// Identifier of the node being removed.
        node_id: NodeId,
    },
}
32
33#[derive(Debug, Clone, PartialEq, Eq)]
35pub struct ClusterConfig {
36 members: Vec<(NodeId, String)>,
38 version: u64,
40}
41
42impl ClusterConfig {
43 pub fn new(members: Vec<(NodeId, String)>, version: u64) -> Self {
45 Self { members, version }
46 }
47
48 pub fn member_ids(&self) -> HashSet<NodeId> {
50 self.members.iter().map(|(id, _)| *id).collect()
51 }
52
53 pub fn members(&self) -> &[(NodeId, String)] {
55 &self.members
56 }
57
58 pub fn version(&self) -> u64 {
60 self.version
61 }
62
63 pub fn contains(&self, node_id: NodeId) -> bool {
65 self.members.iter().any(|(id, _)| *id == node_id)
66 }
67
68 pub fn quorum_size(&self) -> usize {
70 self.members.len() / 2 + 1
71 }
72
73 pub fn len(&self) -> usize {
75 self.members.len()
76 }
77
78 pub fn is_empty(&self) -> bool {
80 self.members.is_empty()
81 }
82
83 pub fn with_added_member(&self, node_id: NodeId, address: String) -> Self {
85 let mut members = self.members.clone();
86 if !self.contains(node_id) {
87 members.push((node_id, address));
88 }
89 Self {
90 members,
91 version: self.version + 1,
92 }
93 }
94
95 pub fn without_member(&self, node_id: NodeId) -> Self {
97 let members: Vec<_> = self
98 .members
99 .iter()
100 .filter(|(id, _)| *id != node_id)
101 .cloned()
102 .collect();
103 Self {
104 members,
105 version: self.version + 1,
106 }
107 }
108}
109
110#[derive(Debug, Clone, PartialEq, Eq)]
116pub enum ConfigState {
117 Stable(ClusterConfig),
119 Joint {
121 old: ClusterConfig,
123 new: ClusterConfig,
125 },
126}
127
128impl ConfigState {
129 pub fn new_stable(members: Vec<(NodeId, String)>) -> Self {
131 ConfigState::Stable(ClusterConfig::new(members, 0))
132 }
133
134 pub fn all_member_ids(&self) -> HashSet<NodeId> {
136 match self {
137 ConfigState::Stable(config) => config.member_ids(),
138 ConfigState::Joint { old, new } => {
139 let mut ids = old.member_ids();
140 ids.extend(new.member_ids());
141 ids
142 }
143 }
144 }
145
146 pub fn is_joint(&self) -> bool {
148 matches!(self, ConfigState::Joint { .. })
149 }
150
151 pub fn version(&self) -> u64 {
153 match self {
154 ConfigState::Stable(config) => config.version(),
155 ConfigState::Joint { old, new } => old.version().max(new.version()),
156 }
157 }
158
159 pub fn has_quorum(&self, responding_nodes: &HashSet<NodeId>) -> bool {
164 match self {
165 ConfigState::Stable(config) => {
166 let count = config.member_ids().intersection(responding_nodes).count();
167 count >= config.quorum_size()
168 }
169 ConfigState::Joint { old, new } => {
170 let old_count = old.member_ids().intersection(responding_nodes).count();
171 let new_count = new.member_ids().intersection(responding_nodes).count();
172 old_count >= old.quorum_size() && new_count >= new.quorum_size()
173 }
174 }
175 }
176
177 pub fn stable_config(&self) -> Option<&ClusterConfig> {
179 match self {
180 ConfigState::Stable(config) => Some(config),
181 ConfigState::Joint { .. } => None,
182 }
183 }
184
185 pub fn all_members(&self) -> Vec<(NodeId, String)> {
187 match self {
188 ConfigState::Stable(config) => config.members().to_vec(),
189 ConfigState::Joint { old, new } => {
190 let mut seen = HashSet::new();
191 let mut result = Vec::new();
192 for (id, addr) in old.members().iter().chain(new.members().iter()) {
193 if seen.insert(*id) {
194 result.push((*id, addr.clone()));
195 }
196 }
197 result
198 }
199 }
200 }
201}
202
/// Role a Raft node currently holds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeState {
    /// Follower role.
    Follower,
    /// Candidate role (standing for election).
    Candidate,
    /// Leader role.
    Leader,
}

impl NodeState {
    /// Returns the variant's name as a static string, e.g. `"Leader"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Leader => "Leader",
            Self::Candidate => "Candidate",
            Self::Follower => "Follower",
        }
    }
}
224
225#[derive(Debug, Clone)]
227pub struct RaftConfig {
228 pub node_id: NodeId,
230 pub peers: Vec<NodeId>,
232 pub election_timeout_range: (u64, u64),
234 pub heartbeat_interval: u64,
236 pub max_entries_per_message: usize,
238 pub enable_compaction: bool,
240 pub snapshot_threshold: u64,
242 pub max_snapshots: usize,
244 pub snapshot_dir: Option<PathBuf>,
246 pub persistence_dir: Option<PathBuf>,
248 pub wal_dir: Option<PathBuf>,
250 pub sync_on_write: bool,
252}
253
254impl RaftConfig {
255 pub fn new(node_id: NodeId, peers: Vec<NodeId>) -> Self {
257 Self {
258 node_id,
259 peers,
260 election_timeout_range: (150, 300),
261 heartbeat_interval: 50,
262 max_entries_per_message: 100,
263 enable_compaction: true,
264 snapshot_threshold: 10000,
265 max_snapshots: 3,
266 snapshot_dir: None,
267 persistence_dir: None,
268 wal_dir: None,
269 sync_on_write: true,
270 }
271 }
272
273 pub fn random_election_timeout(&self) -> Duration {
275 use std::collections::hash_map::RandomState;
276 use std::hash::BuildHasher;
277
278 let (min, max) = self.election_timeout_range;
279 let range = max - min;
280
281 let now = std::time::SystemTime::now()
283 .duration_since(std::time::UNIX_EPOCH)
284 .map(|d| d.as_nanos())
285 .unwrap_or(0);
286
287 let random_value = RandomState::new().hash_one(now);
288
289 let timeout_ms = min + (random_value % range);
290 Duration::from_millis(timeout_ms)
291 }
292
293 pub fn heartbeat_interval(&self) -> Duration {
295 Duration::from_millis(self.heartbeat_interval)
296 }
297
298 pub fn validate(&self) -> Result<(), String> {
300 if !self.peers.contains(&self.node_id) {
302 return Err(format!("Node ID {} not found in peers list", self.node_id));
303 }
304
305 if self.peers.len() % 2 == 0 {
307 return Err(format!(
308 "Raft requires odd number of nodes, got {}",
309 self.peers.len()
310 ));
311 }
312
313 if self.peers.len() < 3 {
315 return Err(format!(
316 "Raft requires at least 3 nodes for fault tolerance, got {}",
317 self.peers.len()
318 ));
319 }
320
321 let (min, max) = self.election_timeout_range;
323 if min >= max {
324 return Err(format!(
325 "Election timeout min ({}) must be less than max ({})",
326 min, max
327 ));
328 }
329
330 if self.heartbeat_interval >= min {
332 return Err(format!(
333 "Heartbeat interval ({}) must be less than election timeout min ({})",
334 self.heartbeat_interval, min
335 ));
336 }
337
338 Ok(())
339 }
340
341 pub fn quorum_size(&self) -> usize {
343 self.peers.len() / 2 + 1
344 }
345}
346
/// Timing parameters for heartbeat-based failure detection.
#[derive(Debug, Clone)]
pub struct HeartbeatConfig {
    /// Interval between heartbeats, in milliseconds.
    pub interval_ms: u64,
    /// Heartbeat timeout in milliseconds; `validate` requires it to exceed `interval_ms`.
    pub timeout_ms: u64,
    /// Missed-heartbeat limit (presumably consecutive misses before a node is declared failed).
    pub max_missed: u32,
}

impl HeartbeatConfig {
    /// Builds a configuration from explicit values without validating them.
    pub fn new(interval_ms: u64, timeout_ms: u64, max_missed: u32) -> Self {
        HeartbeatConfig {
            interval_ms,
            timeout_ms,
            max_missed,
        }
    }

    /// Default settings: 100 ms interval, 500 ms timeout, 3 missed heartbeats.
    pub fn default_config() -> Self {
        Self::new(100, 500, 3)
    }

    /// Checks the values for consistency.
    ///
    /// # Errors
    ///
    /// Checks are applied in this order, and the first failure is returned:
    /// 1. the interval is zero;
    /// 2. the timeout is zero;
    /// 3. the timeout is not strictly greater than the interval;
    /// 4. `max_missed` is zero.
    pub fn validate(&self) -> Result<(), String> {
        let (interval, timeout) = (self.interval_ms, self.timeout_ms);
        if interval == 0 {
            Err(String::from("Heartbeat interval must be > 0"))
        } else if timeout == 0 {
            Err(String::from("Heartbeat timeout must be > 0"))
        } else if timeout <= interval {
            Err(format!(
                "Heartbeat timeout ({}) must be greater than interval ({})",
                timeout, interval
            ))
        } else if self.max_missed == 0 {
            Err(String::from("max_missed must be > 0"))
        } else {
            Ok(())
        }
    }
}

impl Default for HeartbeatConfig {
    /// Same as [`HeartbeatConfig::default_config`].
    fn default() -> Self {
        HeartbeatConfig::default_config()
    }
}
404
/// Fencing token packing a leader term (high 32 bits) and a per-term sequence
/// number (low 32 bits) into one `u64`.
///
/// Because the term occupies the high bits, the derived ordering compares the
/// term first and the sequence second. Any token from a later term therefore
/// outranks every token from an earlier one. Tokens are meant to only move
/// forward; see [`FencingToken::bump_seq`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct FencingToken(pub u64);

impl FencingToken {
    /// Packs `term` into the high 32 bits and `seq` into the low 32 bits.
    pub fn new(term: u32, seq: u32) -> Self {
        Self((u64::from(term) << 32) | u64::from(seq))
    }

    /// Term component (high 32 bits).
    pub fn term(self) -> u32 {
        // Lossless: after the shift only 32 significant bits remain.
        (self.0 >> 32) as u32
    }

    /// Sequence component (low 32 bits).
    pub fn seq(self) -> u32 {
        // Intentional truncation to the low half.
        self.0 as u32
    }

    /// Raw packed value.
    pub fn raw(self) -> u64 {
        self.0
    }

    /// Returns the next token in the same term.
    ///
    /// Returns `None` if the sequence number is already `u32::MAX`, so callers
    /// can detect exhaustion.
    pub fn checked_bump_seq(self) -> Option<Self> {
        self.seq()
            .checked_add(1)
            .map(|seq| Self::new(self.term(), seq))
    }

    /// Returns the next token in the same term.
    ///
    /// At `seq == u32::MAX` the token saturates (is returned unchanged) instead
    /// of wrapping to `(term, 0)`. A wrapped token would compare *lower* than
    /// its predecessor and break the ordering that fencing relies on. Use
    /// [`FencingToken::checked_bump_seq`] to detect exhaustion.
    pub fn bump_seq(self) -> Self {
        self.checked_bump_seq().unwrap_or(self)
    }

    /// First token of a new leader term (sequence 0).
    pub fn new_leader_term(term: u32) -> Self {
        Self::new(term, 0)
    }
}
462
/// Event describing a change in a peer's liveness status.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FailureEvent {
    /// A node has been declared failed.
    NodeFailed {
        /// The node declared failed.
        node_id: NodeId,
        /// Number of missed heartbeats recorded (presumably at the time of the declaration — confirm with the producer).
        missed_count: u32,
        /// Milliseconds since the node was last seen.
        last_seen_ago_ms: u64,
    },
    /// A node has been reported as recovered.
    NodeRecovered {
        /// The node reported as recovered.
        node_id: NodeId,
    },
}
481
#[cfg(test)]
mod tests {
    use super::*;

    // ---- NodeState ----

    #[test]
    fn test_node_state_as_str() {
        assert_eq!(NodeState::Follower.as_str(), "Follower");
        assert_eq!(NodeState::Candidate.as_str(), "Candidate");
        assert_eq!(NodeState::Leader.as_str(), "Leader");
    }

    // ---- RaftConfig ----

    #[test]
    fn test_raft_config_new() {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        assert_eq!(config.node_id, 1);
        assert_eq!(config.peers, vec![1, 2, 3]);
        assert_eq!(config.election_timeout_range, (150, 300));
        assert_eq!(config.heartbeat_interval, 50);
    }

    #[test]
    fn test_raft_config_validate_valid() {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_raft_config_validate_node_not_in_peers() {
        let config = RaftConfig::new(4, vec![1, 2, 3]);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_raft_config_validate_even_number_of_nodes() {
        let config = RaftConfig::new(1, vec![1, 2, 3, 4]);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_raft_config_validate_too_few_nodes() {
        let config = RaftConfig::new(1, vec![1]);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_raft_config_quorum_size() {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        assert_eq!(config.quorum_size(), 2);

        let config = RaftConfig::new(1, vec![1, 2, 3, 4, 5]);
        assert_eq!(config.quorum_size(), 3);
    }

    #[test]
    fn test_random_election_timeout() {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        let timeout1 = config.random_election_timeout();
        let timeout2 = config.random_election_timeout();

        // Timeouts are drawn from the half-open range [150, 300).
        assert!(timeout1.as_millis() >= 150);
        assert!(timeout1.as_millis() < 300);
        assert!(timeout2.as_millis() >= 150);
        assert!(timeout2.as_millis() < 300);
    }

    // ---- ClusterConfig ----

    #[test]
    fn test_cluster_config_new() {
        let members = vec![(1, "addr1".to_string()), (2, "addr2".to_string())];
        let cfg = ClusterConfig::new(members, 0);
        assert_eq!(cfg.len(), 2);
        assert_eq!(cfg.version(), 0);
        assert!(cfg.contains(1));
        assert!(cfg.contains(2));
        assert!(!cfg.contains(3));
    }

    #[test]
    fn test_cluster_config_quorum() {
        let members = vec![(1, "a".into()), (2, "b".into()), (3, "c".into())];
        let cfg = ClusterConfig::new(members, 0);
        assert_eq!(cfg.quorum_size(), 2);
    }

    #[test]
    fn test_cluster_config_add_remove() {
        let members = vec![(1, "a".into()), (2, "b".into()), (3, "c".into())];
        let cfg = ClusterConfig::new(members, 0);

        let cfg2 = cfg.with_added_member(4, "d".into());
        assert_eq!(cfg2.len(), 4);
        assert!(cfg2.contains(4));
        assert_eq!(cfg2.version(), 1);

        let cfg3 = cfg2.without_member(2);
        assert_eq!(cfg3.len(), 3);
        assert!(!cfg3.contains(2));
        assert_eq!(cfg3.version(), 2);
    }

    #[test]
    fn test_cluster_config_add_existing_is_noop() {
        let members = vec![(1, "a".into()), (2, "b".into())];
        let cfg = ClusterConfig::new(members, 0);
        let cfg2 = cfg.with_added_member(1, "a2".into());
        assert_eq!(cfg2.len(), 2);
        // The existing entry, including its address, is left untouched.
        assert_eq!(cfg2.members(), cfg.members());
    }

    #[test]
    fn test_cluster_config_without_absent_member() {
        let members = vec![(1, "a".into()), (2, "b".into())];
        let cfg = ClusterConfig::new(members, 0);
        let cfg2 = cfg.without_member(9);
        assert_eq!(cfg2.members(), cfg.members());
    }

    // ---- ConfigState ----

    #[test]
    fn test_config_state_stable_quorum() {
        let members = vec![(1, "a".into()), (2, "b".into()), (3, "c".into())];
        let cs = ConfigState::new_stable(members);
        assert!(!cs.is_joint());

        let mut responding = HashSet::new();
        responding.insert(1);
        assert!(!cs.has_quorum(&responding));

        responding.insert(2);
        assert!(cs.has_quorum(&responding));
    }

    #[test]
    fn test_config_state_joint_quorum() {
        let old = ClusterConfig::new(vec![(1, "a".into()), (2, "b".into()), (3, "c".into())], 0);
        let new = ClusterConfig::new(
            vec![
                (1, "a".into()),
                (2, "b".into()),
                (3, "c".into()),
                (4, "d".into()),
            ],
            1,
        );
        let cs = ConfigState::Joint { old, new };
        assert!(cs.is_joint());

        // Majority of old (2/3) but not of new (2/4, needs 3).
        let mut r = HashSet::new();
        r.insert(1);
        r.insert(2);
        assert!(!cs.has_quorum(&r));

        // Majority of both.
        r.insert(3);
        assert!(cs.has_quorum(&r));
    }

    #[test]
    fn test_config_state_joint_quorum_requires_both_majorities() {
        let old = ClusterConfig::new(vec![(1, "a".into()), (2, "b".into()), (3, "c".into())], 0);
        let new = ClusterConfig::new(vec![(3, "c".into()), (4, "d".into()), (5, "e".into())], 1);
        let cs = ConfigState::Joint { old, new };

        // Old majority only.
        let mut r: HashSet<NodeId> = [1, 2].iter().copied().collect();
        assert!(!cs.has_quorum(&r));

        // Still only one member of new.
        r.insert(4);
        assert!(!cs.has_quorum(&r));

        // Majority in both.
        r.insert(5);
        assert!(cs.has_quorum(&r));
    }

    #[test]
    fn test_config_state_all_members() {
        let old = ClusterConfig::new(vec![(1, "a".into()), (2, "b".into()), (3, "c".into())], 0);
        let new = ClusterConfig::new(vec![(1, "a".into()), (2, "b".into()), (4, "d".into())], 1);
        let cs = ConfigState::Joint { old, new };
        let members = cs.all_members();
        let ids: HashSet<NodeId> = members.iter().map(|(id, _)| *id).collect();
        assert_eq!(ids.len(), 4);
        assert!(ids.contains(&3));
        assert!(ids.contains(&4));
    }

    #[test]
    fn test_config_state_all_members_keeps_old_address() {
        let old = ClusterConfig::new(vec![(1, "old-addr".into())], 0);
        let new = ClusterConfig::new(vec![(1, "new-addr".into()), (2, "b".into())], 1);
        let cs = ConfigState::Joint { old, new };
        let expected: Vec<(NodeId, String)> =
            vec![(1, "old-addr".to_string()), (2, "b".to_string())];
        assert_eq!(cs.all_members(), expected);
    }

    #[test]
    fn test_config_state_member_ids_and_stable_config() {
        let old = ClusterConfig::new(vec![(1, "a".into()), (2, "b".into())], 0);
        let new = ClusterConfig::new(vec![(2, "b".into()), (3, "c".into())], 1);

        let stable = ConfigState::Stable(old.clone());
        assert_eq!(stable.stable_config(), Some(&old));
        assert_eq!(stable.all_member_ids(), old.member_ids());

        let joint = ConfigState::Joint { old, new };
        assert!(joint.stable_config().is_none());
        let expected: HashSet<NodeId> = [1, 2, 3].iter().copied().collect();
        assert_eq!(joint.all_member_ids(), expected);
        assert_eq!(joint.version(), 1);
    }

    #[test]
    fn test_config_state_version() {
        let cs = ConfigState::new_stable(vec![(1, "a".into())]);
        assert_eq!(cs.version(), 0);
    }

    // ---- HeartbeatConfig ----

    #[test]
    fn test_heartbeat_config_new() {
        let config = HeartbeatConfig::new(100, 500, 3);
        assert_eq!(config.interval_ms, 100);
        assert_eq!(config.timeout_ms, 500);
        assert_eq!(config.max_missed, 3);
    }

    #[test]
    fn test_heartbeat_config_default() {
        let config = HeartbeatConfig::default();
        assert_eq!(config.interval_ms, 100);
        assert_eq!(config.timeout_ms, 500);
        assert_eq!(config.max_missed, 3);
    }

    #[test]
    fn test_heartbeat_config_validate_ok() {
        let config = HeartbeatConfig::new(100, 500, 3);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_heartbeat_config_validate_zero_interval() {
        let config = HeartbeatConfig::new(0, 500, 3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_heartbeat_config_validate_zero_timeout() {
        let config = HeartbeatConfig::new(100, 0, 3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_heartbeat_config_validate_timeout_less_than_interval() {
        let config = HeartbeatConfig::new(100, 50, 3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_heartbeat_config_validate_timeout_equal_interval() {
        let config = HeartbeatConfig::new(100, 100, 3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_heartbeat_config_validate_zero_max_missed() {
        let config = HeartbeatConfig::new(100, 500, 0);
        assert!(config.validate().is_err());
    }

    // ---- FencingToken ----

    #[test]
    fn test_fencing_token_packing_and_order() {
        let t = FencingToken::new(7, 42);
        assert_eq!(t.term(), 7);
        assert_eq!(t.seq(), 42);
        assert_eq!(t.raw(), (7u64 << 32) | 42);
        assert_eq!(t.bump_seq(), FencingToken::new(7, 43));
        assert!(t.bump_seq() > t);

        let next_term = FencingToken::new_leader_term(8);
        assert_eq!(next_term, FencingToken::new(8, 0));
        // A later term outranks every sequence number of an earlier term.
        assert!(next_term > FencingToken::new(7, u32::MAX));
        assert_eq!(FencingToken::default().raw(), 0);
    }

    // ---- FailureEvent ----

    #[test]
    fn test_failure_event_node_failed_eq() {
        let a = FailureEvent::NodeFailed {
            node_id: 2,
            missed_count: 3,
            last_seen_ago_ms: 500,
        };
        let b = FailureEvent::NodeFailed {
            node_id: 2,
            missed_count: 3,
            last_seen_ago_ms: 500,
        };
        assert_eq!(a, b);
    }

    #[test]
    fn test_failure_event_node_recovered_eq() {
        let a = FailureEvent::NodeRecovered { node_id: 2 };
        let b = FailureEvent::NodeRecovered { node_id: 2 };
        assert_eq!(a, b);
    }

    #[test]
    fn test_failure_event_ne() {
        let a = FailureEvent::NodeFailed {
            node_id: 2,
            missed_count: 3,
            last_seen_ago_ms: 500,
        };
        let b = FailureEvent::NodeRecovered { node_id: 2 };
        assert_ne!(a, b);
    }
}