// amaters_cluster/types.rs
//! Core types for Raft consensus

use std::collections::HashSet;
use std::path::PathBuf;
use std::time::Duration;

/// Node identifier
pub type NodeId = u64;

/// Raft term number
pub type Term = u64;

/// Log entry index (1-indexed, 0 means no entry)
pub type LogIndex = u64;

/// A membership change request for dynamic cluster reconfiguration
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MembershipChange {
    /// Add a new node to the cluster
    AddNode {
        /// The node ID to add
        node_id: NodeId,
        /// The network address of the node
        address: String,
    },
    /// Remove an existing node from the cluster
    RemoveNode {
        /// The node ID to remove
        node_id: NodeId,
    },
}

/// Tracks cluster members with their addresses and a monotonically increasing version
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClusterConfig {
    /// Map of node IDs to their network addresses
    members: Vec<(NodeId, String)>,
    /// Monotonically increasing version number for this configuration
    version: u64,
}

impl ClusterConfig {
    /// Create a new cluster config with the given members and version.
    pub fn new(members: Vec<(NodeId, String)>, version: u64) -> Self {
        ClusterConfig { members, version }
    }

    /// Collect the IDs of every member into a set.
    pub fn member_ids(&self) -> HashSet<NodeId> {
        self.members.iter().map(|&(id, _)| id).collect()
    }

    /// Borrow all members as (node_id, address) pairs.
    pub fn members(&self) -> &[(NodeId, String)] {
        self.members.as_slice()
    }

    /// The version number of this configuration.
    pub fn version(&self) -> u64 {
        self.version
    }

    /// Whether `node_id` is currently a member.
    pub fn contains(&self, node_id: NodeId) -> bool {
        self.members.iter().any(|&(id, _)| id == node_id)
    }

    /// Majority quorum size for this config: floor(n / 2) + 1.
    pub fn quorum_size(&self) -> usize {
        1 + self.members.len() / 2
    }

    /// Number of members in the configuration.
    pub fn len(&self) -> usize {
        self.members.len()
    }

    /// True when the configuration has no members.
    pub fn is_empty(&self) -> bool {
        self.members.is_empty()
    }

    /// Return a copy with `node_id` added (if absent) and the version bumped.
    ///
    /// NOTE: the version increments even when the node was already present,
    /// matching the existing semantics relied on by callers.
    pub fn with_added_member(&self, node_id: NodeId, address: String) -> Self {
        let mut next = self.clone();
        if !next.contains(node_id) {
            next.members.push((node_id, address));
        }
        next.version += 1;
        next
    }

    /// Return a copy with `node_id` removed (if present) and the version bumped.
    pub fn without_member(&self, node_id: NodeId) -> Self {
        let mut next = self.clone();
        next.members.retain(|&(id, _)| id != node_id);
        next.version += 1;
        next
    }
}

110/// The state of cluster configuration during membership changes.
111///
112/// Implements the Raft joint consensus protocol (Section 6):
113/// - `Stable`: Normal operation with a single configuration
114/// - `Joint`: Transitional state requiring majority from BOTH old and new configs
115#[derive(Debug, Clone, PartialEq, Eq)]
116pub enum ConfigState {
117    /// Normal operation with a single configuration
118    Stable(ClusterConfig),
119    /// Joint consensus: decisions require majority of both old and new configs
120    Joint {
121        /// The old (current) configuration
122        old: ClusterConfig,
123        /// The new (target) configuration
124        new: ClusterConfig,
125    },
126}
127
128impl ConfigState {
129    /// Create a new stable config state
130    pub fn new_stable(members: Vec<(NodeId, String)>) -> Self {
131        ConfigState::Stable(ClusterConfig::new(members, 0))
132    }
133
134    /// Get all unique member node IDs across both configs (if joint)
135    pub fn all_member_ids(&self) -> HashSet<NodeId> {
136        match self {
137            ConfigState::Stable(config) => config.member_ids(),
138            ConfigState::Joint { old, new } => {
139                let mut ids = old.member_ids();
140                ids.extend(new.member_ids());
141                ids
142            }
143        }
144    }
145
146    /// Check if we are in joint consensus
147    pub fn is_joint(&self) -> bool {
148        matches!(self, ConfigState::Joint { .. })
149    }
150
151    /// Get the current version (max of both configs if joint)
152    pub fn version(&self) -> u64 {
153        match self {
154            ConfigState::Stable(config) => config.version(),
155            ConfigState::Joint { old, new } => old.version().max(new.version()),
156        }
157    }
158
159    /// Check if a given set of responding nodes forms a quorum.
160    ///
161    /// During joint consensus, a quorum requires majority in BOTH the old
162    /// and new configurations independently.
163    pub fn has_quorum(&self, responding_nodes: &HashSet<NodeId>) -> bool {
164        match self {
165            ConfigState::Stable(config) => {
166                let count = config.member_ids().intersection(responding_nodes).count();
167                count >= config.quorum_size()
168            }
169            ConfigState::Joint { old, new } => {
170                let old_count = old.member_ids().intersection(responding_nodes).count();
171                let new_count = new.member_ids().intersection(responding_nodes).count();
172                old_count >= old.quorum_size() && new_count >= new.quorum_size()
173            }
174        }
175    }
176
177    /// Get the stable config (only valid if not in joint state)
178    pub fn stable_config(&self) -> Option<&ClusterConfig> {
179        match self {
180            ConfigState::Stable(config) => Some(config),
181            ConfigState::Joint { .. } => None,
182        }
183    }
184
185    /// Get all members as (node_id, address) pairs
186    pub fn all_members(&self) -> Vec<(NodeId, String)> {
187        match self {
188            ConfigState::Stable(config) => config.members().to_vec(),
189            ConfigState::Joint { old, new } => {
190                let mut seen = HashSet::new();
191                let mut result = Vec::new();
192                for (id, addr) in old.members().iter().chain(new.members().iter()) {
193                    if seen.insert(*id) {
194                        result.push((*id, addr.clone()));
195                    }
196                }
197                result
198            }
199        }
200    }
201}
202
/// Raft node state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeState {
    /// Follower state - passive, responds to RPCs
    Follower,
    /// Candidate state - requesting votes for leadership
    Candidate,
    /// Leader state - handles client requests and replicates log
    Leader,
}

impl NodeState {
    /// Human-readable name of the state, as a static string.
    pub fn as_str(&self) -> &'static str {
        match *self {
            NodeState::Leader => "Leader",
            NodeState::Candidate => "Candidate",
            NodeState::Follower => "Follower",
        }
    }
}

225/// Configuration for a Raft node
226#[derive(Debug, Clone)]
227pub struct RaftConfig {
228    /// This node's ID
229    pub node_id: NodeId,
230    /// List of all peer node IDs (including this node)
231    pub peers: Vec<NodeId>,
232    /// Election timeout range (min, max) in milliseconds
233    pub election_timeout_range: (u64, u64),
234    /// Heartbeat interval in milliseconds
235    pub heartbeat_interval: u64,
236    /// Maximum number of entries to send in a single AppendEntries RPC
237    pub max_entries_per_message: usize,
238    /// Whether to enable log compaction
239    pub enable_compaction: bool,
240    /// Snapshot threshold (number of log entries before triggering snapshot)
241    pub snapshot_threshold: u64,
242    /// Maximum number of snapshots to retain on disk
243    pub max_snapshots: usize,
244    /// Directory for storing snapshots (None = snapshots disabled on disk)
245    pub snapshot_dir: Option<PathBuf>,
246    /// Directory for Raft persistent state and log (None = in-memory only)
247    pub persistence_dir: Option<PathBuf>,
248    /// Directory for segment-based WAL replay on startup (None = WAL replay disabled)
249    pub wal_dir: Option<PathBuf>,
250    /// Whether to fsync after every persistent write (default: true)
251    pub sync_on_write: bool,
252}
253
254impl RaftConfig {
255    /// Create a new Raft configuration with sensible defaults
256    pub fn new(node_id: NodeId, peers: Vec<NodeId>) -> Self {
257        Self {
258            node_id,
259            peers,
260            election_timeout_range: (150, 300),
261            heartbeat_interval: 50,
262            max_entries_per_message: 100,
263            enable_compaction: true,
264            snapshot_threshold: 10000,
265            max_snapshots: 3,
266            snapshot_dir: None,
267            persistence_dir: None,
268            wal_dir: None,
269            sync_on_write: true,
270        }
271    }
272
273    /// Get a random election timeout within the configured range
274    pub fn random_election_timeout(&self) -> Duration {
275        use std::collections::hash_map::RandomState;
276        use std::hash::BuildHasher;
277
278        let (min, max) = self.election_timeout_range;
279        let range = max - min;
280
281        // Use current time as seed for randomization
282        let now = std::time::SystemTime::now()
283            .duration_since(std::time::UNIX_EPOCH)
284            .map(|d| d.as_nanos())
285            .unwrap_or(0);
286
287        let random_value = RandomState::new().hash_one(now);
288
289        let timeout_ms = min + (random_value % range);
290        Duration::from_millis(timeout_ms)
291    }
292
293    /// Get the heartbeat interval
294    pub fn heartbeat_interval(&self) -> Duration {
295        Duration::from_millis(self.heartbeat_interval)
296    }
297
298    /// Validate the configuration
299    pub fn validate(&self) -> Result<(), String> {
300        // Check that node_id is in peers list
301        if !self.peers.contains(&self.node_id) {
302            return Err(format!("Node ID {} not found in peers list", self.node_id));
303        }
304
305        // Check for odd number of nodes (for quorum)
306        if self.peers.len() % 2 == 0 {
307            return Err(format!(
308                "Raft requires odd number of nodes, got {}",
309                self.peers.len()
310            ));
311        }
312
313        // Check minimum nodes
314        if self.peers.len() < 3 {
315            return Err(format!(
316                "Raft requires at least 3 nodes for fault tolerance, got {}",
317                self.peers.len()
318            ));
319        }
320
321        // Check election timeout range
322        let (min, max) = self.election_timeout_range;
323        if min >= max {
324            return Err(format!(
325                "Election timeout min ({}) must be less than max ({})",
326                min, max
327            ));
328        }
329
330        // Check heartbeat interval vs election timeout
331        if self.heartbeat_interval >= min {
332            return Err(format!(
333                "Heartbeat interval ({}) must be less than election timeout min ({})",
334                self.heartbeat_interval, min
335            ));
336        }
337
338        Ok(())
339    }
340
341    /// Calculate the quorum size (majority)
342    pub fn quorum_size(&self) -> usize {
343        self.peers.len() / 2 + 1
344    }
345}
346
/// Configuration for heartbeat-based failure detection
#[derive(Debug, Clone)]
pub struct HeartbeatConfig {
    /// Interval between heartbeat sends in milliseconds
    pub interval_ms: u64,
    /// Time in milliseconds after which a peer is considered potentially failed
    pub timeout_ms: u64,
    /// Number of consecutive missed heartbeats before declaring failure
    pub max_missed: u32,
}

impl HeartbeatConfig {
    /// Build a heartbeat configuration from explicit settings.
    pub fn new(interval_ms: u64, timeout_ms: u64, max_missed: u32) -> Self {
        HeartbeatConfig {
            interval_ms,
            timeout_ms,
            max_missed,
        }
    }

    /// Default settings: 100ms interval, 500ms timeout, 3 missed max.
    pub fn default_config() -> Self {
        Self::new(100, 500, 3)
    }

    /// Check the settings for internal consistency.
    ///
    /// Checks run in a fixed order and the first violation is reported.
    pub fn validate(&self) -> Result<(), String> {
        if self.interval_ms == 0 {
            return Err(String::from("Heartbeat interval must be > 0"));
        }
        if self.timeout_ms == 0 {
            return Err(String::from("Heartbeat timeout must be > 0"));
        }
        if self.timeout_ms <= self.interval_ms {
            return Err(format!(
                "Heartbeat timeout ({}) must be greater than interval ({})",
                self.timeout_ms, self.interval_ms
            ));
        }
        if self.max_missed == 0 {
            return Err(String::from("max_missed must be > 0"));
        }
        Ok(())
    }
}

impl Default for HeartbeatConfig {
    fn default() -> Self {
        HeartbeatConfig::default_config()
    }
}

/// A packed monotonically increasing fencing token that uniquely identifies a write epoch.
///
/// Encoded as a single `u64`:
/// - High 32 bits = Raft term (capped at `u32::MAX` for compactness; terms exceeding `u32::MAX`
///   are exceedingly unlikely in any realistic deployment).
/// - Low 32 bits  = per-term monotonic sequence number.
///
/// Each time a node becomes leader it resets the sequence to zero and bumps the term
/// component via [`FencingToken::new_leader_term`].  Storage backends and followers use
/// the token to reject stale writes from former leaders: a write is stale when its
/// token's term is less than the current term, or the term matches but the sequence
/// has been superseded by a higher-sequence write in the same term.
///
/// # Format
///
/// ```text
/// [term: 32 bits][seq: 32 bits]
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct FencingToken(pub u64);

impl FencingToken {
    /// Pack `term` and `seq` into a new fencing token.
    pub fn new(term: u32, seq: u32) -> Self {
        FencingToken((u64::from(term) << 32) | u64::from(seq))
    }

    /// Extract the term component (high 32 bits).
    pub fn term(self) -> u32 {
        (self.0 >> 32) as u32
    }

    /// Extract the sequence component (low 32 bits).
    pub fn seq(self) -> u32 {
        // Truncating cast keeps only the low 32 bits by design.
        self.0 as u32
    }

    /// Return the raw `u64` representation.
    pub fn raw(self) -> u64 {
        self.0
    }

    /// Same term, sequence incremented by one.
    ///
    /// Wraps on `u32` overflow (extremely unlikely in practice).
    pub fn bump_seq(self) -> Self {
        FencingToken::new(self.term(), self.seq().wrapping_add(1))
    }

    /// Fresh token for a new leader term; the sequence number resets to 0.
    pub fn new_leader_term(term: u32) -> Self {
        FencingToken(u64::from(term) << 32)
    }
}

/// Events emitted by the failure detector
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FailureEvent {
    /// A node has been detected as failed (missed too many heartbeats)
    NodeFailed {
        /// The node that failed
        node_id: NodeId,
        /// Number of consecutive missed heartbeats at the time of detection
        missed_count: u32,
        /// Milliseconds elapsed since the last successful heartbeat
        last_seen_ago_ms: u64,
    },
    /// A previously failed node has recovered (heartbeat received again)
    NodeRecovered {
        /// The node that recovered
        node_id: NodeId,
    },
}

// Unit tests for the public types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_node_state_as_str() {
        assert_eq!(NodeState::Follower.as_str(), "Follower");
        assert_eq!(NodeState::Candidate.as_str(), "Candidate");
        assert_eq!(NodeState::Leader.as_str(), "Leader");
    }

    #[test]
    fn test_raft_config_new() {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        assert_eq!(config.node_id, 1);
        assert_eq!(config.peers, vec![1, 2, 3]);
        assert_eq!(config.election_timeout_range, (150, 300));
        assert_eq!(config.heartbeat_interval, 50);
    }

    #[test]
    fn test_raft_config_validate_valid() {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_raft_config_validate_node_not_in_peers() {
        let config = RaftConfig::new(4, vec![1, 2, 3]);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_raft_config_validate_even_number_of_nodes() {
        let config = RaftConfig::new(1, vec![1, 2, 3, 4]);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_raft_config_validate_too_few_nodes() {
        let config = RaftConfig::new(1, vec![1]);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_raft_config_quorum_size() {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        assert_eq!(config.quorum_size(), 2);

        let config = RaftConfig::new(1, vec![1, 2, 3, 4, 5]);
        assert_eq!(config.quorum_size(), 3);
    }

    #[test]
    fn test_random_election_timeout() {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        let timeout1 = config.random_election_timeout();
        let timeout2 = config.random_election_timeout();

        // Both should be within range (no equality check: two calls may
        // legitimately produce the same value)
        assert!(timeout1.as_millis() >= 150);
        assert!(timeout1.as_millis() <= 300);
        assert!(timeout2.as_millis() >= 150);
        assert!(timeout2.as_millis() <= 300);
    }

    // ── ClusterConfig tests ─────────────────────────────────────────

    #[test]
    fn test_cluster_config_new() {
        let members = vec![(1, "addr1".to_string()), (2, "addr2".to_string())];
        let cfg = ClusterConfig::new(members.clone(), 0);
        assert_eq!(cfg.len(), 2);
        assert_eq!(cfg.version(), 0);
        assert!(cfg.contains(1));
        assert!(cfg.contains(2));
        assert!(!cfg.contains(3));
    }

    #[test]
    fn test_cluster_config_quorum() {
        let members = vec![(1, "a".into()), (2, "b".into()), (3, "c".into())];
        let cfg = ClusterConfig::new(members, 0);
        assert_eq!(cfg.quorum_size(), 2); // 3/2 + 1 = 2
    }

    #[test]
    fn test_cluster_config_add_remove() {
        let members = vec![(1, "a".into()), (2, "b".into()), (3, "c".into())];
        let cfg = ClusterConfig::new(members, 0);

        let cfg2 = cfg.with_added_member(4, "d".into());
        assert_eq!(cfg2.len(), 4);
        assert!(cfg2.contains(4));
        assert_eq!(cfg2.version(), 1);

        let cfg3 = cfg2.without_member(2);
        assert_eq!(cfg3.len(), 3);
        assert!(!cfg3.contains(2));
        assert_eq!(cfg3.version(), 2);
    }

    #[test]
    fn test_cluster_config_add_existing_is_noop() {
        let members = vec![(1, "a".into()), (2, "b".into())];
        let cfg = ClusterConfig::new(members, 0);
        let cfg2 = cfg.with_added_member(1, "a2".into());
        // Should still have 2 members (not duplicated)
        assert_eq!(cfg2.len(), 2);
    }

    // ── ConfigState tests ───────────────────────────────────────────

    #[test]
    fn test_config_state_stable_quorum() {
        let members = vec![(1, "a".into()), (2, "b".into()), (3, "c".into())];
        let cs = ConfigState::new_stable(members);
        assert!(!cs.is_joint());

        let mut responding = HashSet::new();
        responding.insert(1);
        assert!(!cs.has_quorum(&responding)); // 1 of 3 -- no quorum

        responding.insert(2);
        assert!(cs.has_quorum(&responding)); // 2 of 3 -- quorum
    }

    #[test]
    fn test_config_state_joint_quorum() {
        let old = ClusterConfig::new(vec![(1, "a".into()), (2, "b".into()), (3, "c".into())], 0);
        let new = ClusterConfig::new(
            vec![
                (1, "a".into()),
                (2, "b".into()),
                (3, "c".into()),
                (4, "d".into()),
            ],
            1,
        );
        let cs = ConfigState::Joint {
            old: old.clone(),
            new: new.clone(),
        };
        assert!(cs.is_joint());

        // Need majority of old (2/3) AND new (3/4)
        let mut r = HashSet::new();
        r.insert(1);
        r.insert(2);
        // old: 2/3 ok, new: 2/4 not ok
        assert!(!cs.has_quorum(&r));

        r.insert(3);
        // old: 3/3 ok, new: 3/4 ok
        assert!(cs.has_quorum(&r));
    }

    #[test]
    fn test_config_state_all_members() {
        let old = ClusterConfig::new(vec![(1, "a".into()), (2, "b".into()), (3, "c".into())], 0);
        let new = ClusterConfig::new(vec![(1, "a".into()), (2, "b".into()), (4, "d".into())], 1);
        let cs = ConfigState::Joint { old, new };
        let members = cs.all_members();
        let ids: HashSet<NodeId> = members.iter().map(|(id, _)| *id).collect();
        assert_eq!(ids.len(), 4); // 1, 2, 3, 4
        assert!(ids.contains(&3));
        assert!(ids.contains(&4));
    }

    #[test]
    fn test_config_state_version() {
        let cs = ConfigState::new_stable(vec![(1, "a".into())]);
        assert_eq!(cs.version(), 0);
    }

    // ── HeartbeatConfig tests ───────────────────────────────────────

    #[test]
    fn test_heartbeat_config_new() {
        let config = HeartbeatConfig::new(100, 500, 3);
        assert_eq!(config.interval_ms, 100);
        assert_eq!(config.timeout_ms, 500);
        assert_eq!(config.max_missed, 3);
    }

    #[test]
    fn test_heartbeat_config_default() {
        let config = HeartbeatConfig::default();
        assert_eq!(config.interval_ms, 100);
        assert_eq!(config.timeout_ms, 500);
        assert_eq!(config.max_missed, 3);
    }

    #[test]
    fn test_heartbeat_config_validate_ok() {
        let config = HeartbeatConfig::new(100, 500, 3);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_heartbeat_config_validate_zero_interval() {
        let config = HeartbeatConfig::new(0, 500, 3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_heartbeat_config_validate_zero_timeout() {
        let config = HeartbeatConfig::new(100, 0, 3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_heartbeat_config_validate_timeout_less_than_interval() {
        let config = HeartbeatConfig::new(100, 50, 3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_heartbeat_config_validate_timeout_equal_interval() {
        let config = HeartbeatConfig::new(100, 100, 3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_heartbeat_config_validate_zero_max_missed() {
        let config = HeartbeatConfig::new(100, 500, 0);
        assert!(config.validate().is_err());
    }

    // ── FailureEvent tests ──────────────────────────────────────────

    #[test]
    fn test_failure_event_node_failed_eq() {
        let a = FailureEvent::NodeFailed {
            node_id: 2,
            missed_count: 3,
            last_seen_ago_ms: 500,
        };
        let b = FailureEvent::NodeFailed {
            node_id: 2,
            missed_count: 3,
            last_seen_ago_ms: 500,
        };
        assert_eq!(a, b);
    }

    #[test]
    fn test_failure_event_node_recovered_eq() {
        let a = FailureEvent::NodeRecovered { node_id: 2 };
        let b = FailureEvent::NodeRecovered { node_id: 2 };
        assert_eq!(a, b);
    }

    #[test]
    fn test_failure_event_ne() {
        let a = FailureEvent::NodeFailed {
            node_id: 2,
            missed_count: 3,
            last_seen_ago_ms: 500,
        };
        let b = FailureEvent::NodeRecovered { node_id: 2 };
        assert_ne!(a, b);
    }
}