aegis_replication/
raft.rs

1//! Aegis Raft Consensus
2//!
3//! Core Raft consensus algorithm implementation.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use crate::log::{LogEntry, LogIndex, ReplicatedLog, Term};
9use crate::node::{NodeId, NodeRole};
10use crate::state::{Command, CommandResult, StateMachine};
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, Instant};
15
16// =============================================================================
17// Raft Configuration
18// =============================================================================
19
/// Tunable timing and batching parameters for a Raft node.
#[derive(Debug, Clone)]
pub struct RaftConfig {
    /// Lower bound of the election timeout window; the value the election timer is
    /// compared against in this file.
    pub election_timeout_min: Duration,
    /// Upper bound of the election timeout window (stored only; not read in this file).
    pub election_timeout_max: Duration,
    /// Heartbeat interval (stored only; not read in this file).
    pub heartbeat_interval: Duration,
    /// Maximum number of log entries placed in a single append-entries request.
    pub max_entries_per_request: usize,
    /// Snapshot threshold (stored only; not read in this file).
    pub snapshot_threshold: u64,
}

impl Default for RaftConfig {
    /// 150-300 ms election window, 50 ms heartbeat, 100 entries per request and a
    /// snapshot threshold of 10000.
    fn default() -> Self {
        let ms = Duration::from_millis;
        Self {
            election_timeout_min: ms(150),
            election_timeout_max: ms(300),
            heartbeat_interval: ms(50),
            max_entries_per_request: 100,
            snapshot_threshold: 10_000,
        }
    }
}

impl RaftConfig {
    /// Create a configuration holding the default values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Return the configuration with both election timeout bounds replaced.
    ///
    /// The bounds are stored as given; `min <= max` is not checked.
    pub fn with_election_timeout(self, min: Duration, max: Duration) -> Self {
        Self {
            election_timeout_min: min,
            election_timeout_max: max,
            ..self
        }
    }

    /// Return the configuration with the heartbeat interval replaced.
    pub fn with_heartbeat_interval(self, interval: Duration) -> Self {
        Self {
            heartbeat_interval: interval,
            ..self
        }
    }
}
58
59// =============================================================================
60// Raft State
61// =============================================================================
62
63/// Persistent state for Raft consensus.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct RaftState {
66    pub current_term: Term,
67    pub voted_for: Option<NodeId>,
68    pub commit_index: LogIndex,
69    pub last_applied: LogIndex,
70}
71
72impl Default for RaftState {
73    fn default() -> Self {
74        Self {
75            current_term: 0,
76            voted_for: None,
77            commit_index: 0,
78            last_applied: 0,
79        }
80    }
81}
82
83// =============================================================================
84// Vote Request/Response
85// =============================================================================
86
/// Request for a vote during leader election.
///
/// Built by `RaftNode::start_election` and consumed by `RaftNode::handle_vote_request`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteRequest {
    /// Candidate's term, already incremented for this election.
    pub term: Term,
    /// Id of the node asking for the vote.
    pub candidate_id: NodeId,
    /// Index of the last entry in the candidate's log.
    pub last_log_index: LogIndex,
    /// Term of the last entry in the candidate's log.
    pub last_log_term: Term,
}
95
/// Response to a vote request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteResponse {
    /// Responder's current term after processing the request.
    pub term: Term,
    /// Whether the responder granted its vote to the candidate.
    pub vote_granted: bool,
    /// Id of the responding node; the candidate counts distinct voter ids.
    pub voter_id: NodeId,
}
103
104// =============================================================================
105// Append Entries Request/Response
106// =============================================================================
107
/// Request to append entries to the log.
///
/// Built per peer by `RaftNode::create_append_entries`; `entries` may be empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesRequest {
    /// Sender's current term.
    pub term: Term,
    /// Id of the sending leader; the receiver records it as its known leader.
    pub leader_id: NodeId,
    /// Index of the entry immediately before `entries` (0 when there is none).
    pub prev_log_index: LogIndex,
    /// Term of the entry at `prev_log_index` (0 when the sender has no such entry).
    pub prev_log_term: Term,
    /// Entries to replicate, capped at `RaftConfig::max_entries_per_request`.
    pub entries: Vec<LogEntry>,
    /// Sender's commit index.
    pub leader_commit: LogIndex,
}
118
/// Response to an append entries request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesResponse {
    /// Responder's current term.
    pub term: Term,
    /// Whether the request passed the term and previous-entry checks and was applied.
    pub success: bool,
    /// On success, the log index the leader records as this peer's replication progress.
    pub match_index: LogIndex,
    /// On a log mismatch, the index the leader should use as the peer's next index.
    pub conflict_index: Option<LogIndex>,
    /// On a term mismatch at `prev_log_index`, the term the responder holds there.
    pub conflict_term: Option<Term>,
}
128
129// =============================================================================
130// Install Snapshot Request/Response
131// =============================================================================
132
/// Request to install a snapshot.
///
/// NOTE(review): nothing in this file builds or handles this message, so the field
/// notes below are assumptions based on the names; confirm against the handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstallSnapshotRequest {
    /// Sender's term (presumably the leader's current term).
    pub term: Term,
    /// Id of the sending node (presumably the leader).
    pub leader_id: NodeId,
    /// Presumably the index of the last log entry covered by the snapshot.
    pub last_included_index: LogIndex,
    /// Presumably the term of the entry at `last_included_index`.
    pub last_included_term: Term,
    /// Presumably the position of `data` within the whole snapshot; TODO confirm units.
    pub offset: u64,
    /// Snapshot bytes carried by this request.
    pub data: Vec<u8>,
    /// Presumably true when this is the final chunk.
    pub done: bool,
}
144
/// Response to an install snapshot request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstallSnapshotResponse {
    /// Responder's term (presumably its current term; no handler is visible in this file).
    pub term: Term,
}
150
151// =============================================================================
152// Raft Node
153// =============================================================================
154
/// A node in the Raft cluster.
///
/// Each mutable piece of state sits behind its own `RwLock`, which is why every
/// method takes `&self`.
pub struct RaftNode {
    /// This node's identifier; used as candidate, leader and voter id in messages.
    id: NodeId,
    /// Timing and batching parameters.
    config: RaftConfig,
    /// Current term, vote and commit/apply indices.
    state: RwLock<RaftState>,
    /// Current role: follower, candidate or leader.
    role: RwLock<NodeRole>,
    /// The replicated log.
    log: Arc<ReplicatedLog>,
    /// State machine that committed commands are applied to.
    state_machine: Arc<StateMachine>,
    /// Ids of the other cluster members; `cluster_size` counts this node separately.
    peers: RwLock<HashSet<NodeId>>,
    /// Leader this node currently recognises, if any.
    leader_id: RwLock<Option<NodeId>>,
    /// Per peer: index of the next log entry to send.
    next_index: RwLock<HashMap<NodeId, LogIndex>>,
    /// Per peer: highest log index reported as replicated.
    match_index: RwLock<HashMap<NodeId, LogIndex>>,
    /// Instant at which the election timer was last reset.
    last_heartbeat: RwLock<Instant>,
    /// Ids that granted a vote in the election this node is running (itself included).
    votes_received: RwLock<HashSet<NodeId>>,
}
170
171impl RaftNode {
172    /// Create a new Raft node.
173    pub fn new(id: impl Into<NodeId>, config: RaftConfig) -> Self {
174        Self {
175            id: id.into(),
176            config,
177            state: RwLock::new(RaftState::default()),
178            role: RwLock::new(NodeRole::Follower),
179            log: Arc::new(ReplicatedLog::new()),
180            state_machine: Arc::new(StateMachine::new()),
181            peers: RwLock::new(HashSet::new()),
182            leader_id: RwLock::new(None),
183            next_index: RwLock::new(HashMap::new()),
184            match_index: RwLock::new(HashMap::new()),
185            last_heartbeat: RwLock::new(Instant::now()),
186            votes_received: RwLock::new(HashSet::new()),
187        }
188    }
189
190    /// Get the node ID.
191    pub fn id(&self) -> NodeId {
192        self.id.clone()
193    }
194
195    /// Get the current role.
196    pub fn role(&self) -> NodeRole {
197        *self.role.read().unwrap()
198    }
199
200    /// Get the current term.
201    pub fn current_term(&self) -> Term {
202        self.state.read().unwrap().current_term
203    }
204
205    /// Get the current leader ID.
206    pub fn leader_id(&self) -> Option<NodeId> {
207        self.leader_id.read().unwrap().clone()
208    }
209
210    /// Check if this node is the leader.
211    pub fn is_leader(&self) -> bool {
212        self.role() == NodeRole::Leader
213    }
214
215    /// Add a peer to the cluster.
216    pub fn add_peer(&self, peer_id: NodeId) {
217        let mut peers = self.peers.write().unwrap();
218        peers.insert(peer_id.clone());
219
220        let last_log = self.log.last_index();
221        self.next_index.write().unwrap().insert(peer_id.clone(), last_log + 1);
222        self.match_index.write().unwrap().insert(peer_id, 0);
223    }
224
225    /// Remove a peer from the cluster.
226    pub fn remove_peer(&self, peer_id: &NodeId) {
227        let mut peers = self.peers.write().unwrap();
228        peers.remove(peer_id);
229        self.next_index.write().unwrap().remove(peer_id);
230        self.match_index.write().unwrap().remove(peer_id);
231    }
232
233    /// Get the list of peers.
234    pub fn peers(&self) -> Vec<NodeId> {
235        self.peers.read().unwrap().iter().cloned().collect()
236    }
237
238    /// Get the cluster size (including self).
239    pub fn cluster_size(&self) -> usize {
240        self.peers.read().unwrap().len() + 1
241    }
242
243    /// Get the quorum size.
244    pub fn quorum_size(&self) -> usize {
245        (self.cluster_size() / 2) + 1
246    }
247
248    /// Reset the heartbeat timer.
249    pub fn reset_heartbeat(&self) {
250        *self.last_heartbeat.write().unwrap() = Instant::now();
251    }
252
253    /// Check if the election timeout has elapsed.
254    pub fn election_timeout_elapsed(&self) -> bool {
255        let elapsed = self.last_heartbeat.read().unwrap().elapsed();
256        elapsed >= self.config.election_timeout_min
257    }
258
259    // =========================================================================
260    // Leader Election
261    // =========================================================================
262
263    /// Start an election as a candidate.
264    pub fn start_election(&self) -> VoteRequest {
265        let mut state = self.state.write().unwrap();
266        state.current_term += 1;
267        state.voted_for = Some(self.id.clone());
268
269        *self.role.write().unwrap() = NodeRole::Candidate;
270        *self.leader_id.write().unwrap() = None;
271
272        let mut votes = self.votes_received.write().unwrap();
273        votes.clear();
274        votes.insert(self.id.clone());
275
276        self.reset_heartbeat();
277
278        VoteRequest {
279            term: state.current_term,
280            candidate_id: self.id.clone(),
281            last_log_index: self.log.last_index(),
282            last_log_term: self.log.last_term(),
283        }
284    }
285
286    /// Handle a vote request.
287    pub fn handle_vote_request(&self, request: &VoteRequest) -> VoteResponse {
288        let mut state = self.state.write().unwrap();
289
290        if request.term > state.current_term {
291            state.current_term = request.term;
292            state.voted_for = None;
293            *self.role.write().unwrap() = NodeRole::Follower;
294            *self.leader_id.write().unwrap() = None;
295        }
296
297        let vote_granted = request.term >= state.current_term
298            && (state.voted_for.is_none() || state.voted_for.as_ref() == Some(&request.candidate_id))
299            && self.log.is_up_to_date(request.last_log_index, request.last_log_term);
300
301        if vote_granted {
302            state.voted_for = Some(request.candidate_id.clone());
303            self.reset_heartbeat();
304        }
305
306        VoteResponse {
307            term: state.current_term,
308            vote_granted,
309            voter_id: self.id.clone(),
310        }
311    }
312
313    /// Handle a vote response.
314    pub fn handle_vote_response(&self, response: &VoteResponse) -> bool {
315        let current_term = {
316            let mut state = self.state.write().unwrap();
317
318            if response.term > state.current_term {
319                state.current_term = response.term;
320                state.voted_for = None;
321                *self.role.write().unwrap() = NodeRole::Follower;
322                *self.leader_id.write().unwrap() = None;
323                return false;
324            }
325
326            if self.role() != NodeRole::Candidate || response.term != state.current_term {
327                return false;
328            }
329
330            state.current_term
331        };
332
333        if response.vote_granted {
334            self.votes_received.write().unwrap().insert(response.voter_id.clone());
335        }
336
337        let votes = self.votes_received.read().unwrap().len();
338        if votes >= self.quorum_size() {
339            self.become_leader_with_term(current_term);
340            return true;
341        }
342
343        false
344    }
345
346    /// Become the leader.
347    #[allow(dead_code)]
348    fn become_leader(&self) {
349        let term = self.current_term();
350        self.become_leader_with_term(term);
351    }
352
353    /// Become the leader with a specific term (avoids deadlock when called with state lock held).
354    fn become_leader_with_term(&self, term: Term) {
355        *self.role.write().unwrap() = NodeRole::Leader;
356        *self.leader_id.write().unwrap() = Some(self.id.clone());
357
358        let last_log = self.log.last_index();
359        let peers: Vec<_> = self.peers.read().unwrap().iter().cloned().collect();
360
361        let mut next_index = self.next_index.write().unwrap();
362        let mut match_index = self.match_index.write().unwrap();
363
364        for peer in peers {
365            next_index.insert(peer.clone(), last_log + 1);
366            match_index.insert(peer, 0);
367        }
368
369        drop(next_index);
370        drop(match_index);
371
372        let noop = LogEntry::noop(last_log + 1, term);
373        self.log.append(noop);
374    }
375
376    // =========================================================================
377    // Log Replication
378    // =========================================================================
379
380    /// Propose a command (leader only).
381    pub fn propose(&self, command: Command) -> Result<LogIndex, String> {
382        if !self.is_leader() {
383            return Err("Not the leader".to_string());
384        }
385
386        let term = self.current_term();
387        let index = self.log.last_index() + 1;
388        let entry = LogEntry::command(index, term, command.to_bytes());
389
390        self.log.append(entry);
391        Ok(index)
392    }
393
394    /// Create an append entries request for a peer.
395    pub fn create_append_entries(&self, peer_id: &NodeId) -> Option<AppendEntriesRequest> {
396        if !self.is_leader() {
397            return None;
398        }
399
400        let next_index = *self.next_index.read().unwrap().get(peer_id)?;
401        let prev_log_index = next_index.saturating_sub(1);
402        let prev_log_term = self.log.term_at(prev_log_index).unwrap_or(0);
403
404        let entries = self.log.get_range(next_index, self.log.last_index() + 1);
405        let entries: Vec<_> = entries
406            .into_iter()
407            .take(self.config.max_entries_per_request)
408            .collect();
409
410        let state = self.state.read().unwrap();
411
412        Some(AppendEntriesRequest {
413            term: state.current_term,
414            leader_id: self.id.clone(),
415            prev_log_index,
416            prev_log_term,
417            entries,
418            leader_commit: state.commit_index,
419        })
420    }
421
422    /// Handle an append entries request.
423    pub fn handle_append_entries(&self, request: &AppendEntriesRequest) -> AppendEntriesResponse {
424        let mut state = self.state.write().unwrap();
425
426        if request.term < state.current_term {
427            return AppendEntriesResponse {
428                term: state.current_term,
429                success: false,
430                match_index: 0,
431                conflict_index: None,
432                conflict_term: None,
433            };
434        }
435
436        if request.term > state.current_term {
437            state.current_term = request.term;
438            state.voted_for = None;
439        }
440
441        *self.role.write().unwrap() = NodeRole::Follower;
442        *self.leader_id.write().unwrap() = Some(request.leader_id.clone());
443        self.reset_heartbeat();
444
445        if request.prev_log_index > 0 {
446            match self.log.term_at(request.prev_log_index) {
447                None => {
448                    return AppendEntriesResponse {
449                        term: state.current_term,
450                        success: false,
451                        match_index: self.log.last_index(),
452                        conflict_index: Some(self.log.last_index() + 1),
453                        conflict_term: None,
454                    };
455                }
456                Some(term) if term != request.prev_log_term => {
457                    let conflict_index = self.find_first_index_of_term(term);
458                    return AppendEntriesResponse {
459                        term: state.current_term,
460                        success: false,
461                        match_index: 0,
462                        conflict_index: Some(conflict_index),
463                        conflict_term: Some(term),
464                    };
465                }
466                _ => {}
467            }
468        }
469
470        if !request.entries.is_empty() {
471            if let Some(conflict) = self.log.find_conflict(&request.entries) {
472                self.log.truncate_from(conflict);
473            }
474
475            let existing_last = self.log.last_index();
476            let new_entries: Vec<_> = request
477                .entries
478                .iter()
479                .filter(|e| e.index > existing_last)
480                .cloned()
481                .collect();
482
483            if !new_entries.is_empty() {
484                self.log.append_entries(new_entries);
485            }
486        }
487
488        if request.leader_commit > state.commit_index {
489            let last_new_index = if request.entries.is_empty() {
490                request.prev_log_index
491            } else {
492                request.entries.last().unwrap().index
493            };
494            state.commit_index = std::cmp::min(request.leader_commit, last_new_index);
495            self.log.set_commit_index(state.commit_index);
496        }
497
498        AppendEntriesResponse {
499            term: state.current_term,
500            success: true,
501            match_index: self.log.last_index(),
502            conflict_index: None,
503            conflict_term: None,
504        }
505    }
506
507    /// Handle an append entries response.
508    pub fn handle_append_entries_response(
509        &self,
510        peer_id: &NodeId,
511        response: &AppendEntriesResponse,
512    ) {
513        let mut state = self.state.write().unwrap();
514
515        if response.term > state.current_term {
516            state.current_term = response.term;
517            state.voted_for = None;
518            *self.role.write().unwrap() = NodeRole::Follower;
519            *self.leader_id.write().unwrap() = None;
520            return;
521        }
522
523        if !self.is_leader() {
524            return;
525        }
526
527        let mut next_index = self.next_index.write().unwrap();
528        let mut match_index = self.match_index.write().unwrap();
529
530        if response.success {
531            match_index.insert(peer_id.clone(), response.match_index);
532            next_index.insert(peer_id.clone(), response.match_index + 1);
533            drop(next_index);
534            drop(match_index);
535            drop(state);
536            self.try_advance_commit_index();
537        } else {
538            if let Some(conflict_index) = response.conflict_index {
539                next_index.insert(peer_id.clone(), conflict_index);
540            } else {
541                let current = *next_index.get(peer_id).unwrap_or(&1);
542                next_index.insert(peer_id.clone(), current.saturating_sub(1).max(1));
543            }
544        }
545    }
546
547    /// Try to advance the commit index based on match indices.
548    fn try_advance_commit_index(&self) {
549        let match_indices: Vec<_> = {
550            let match_index = self.match_index.read().unwrap();
551            let mut indices: Vec<_> = match_index.values().copied().collect();
552            indices.push(self.log.last_index());
553            indices.sort_unstable();
554            indices
555        };
556
557        let quorum_index = match_indices.len() / 2;
558        let new_commit = match_indices[quorum_index];
559
560        let mut state = self.state.write().unwrap();
561        if new_commit > state.commit_index {
562            if let Some(term) = self.log.term_at(new_commit) {
563                if term == state.current_term {
564                    state.commit_index = new_commit;
565                    self.log.set_commit_index(new_commit);
566                }
567            }
568        }
569    }
570
571    fn find_first_index_of_term(&self, term: Term) -> LogIndex {
572        let mut index = self.log.last_index();
573        while index > 0 {
574            if let Some(t) = self.log.term_at(index) {
575                if t != term {
576                    return index + 1;
577                }
578            }
579            index -= 1;
580        }
581        1
582    }
583
584    // =========================================================================
585    // State Machine Application
586    // =========================================================================
587
588    /// Apply committed entries to the state machine.
589    pub fn apply_committed(&self) -> Vec<CommandResult> {
590        let mut results = Vec::new();
591
592        while self.log.has_entries_to_apply() {
593            if let Some(entry) = self.log.next_to_apply() {
594                if let Some(command) = Command::from_bytes(&entry.data) {
595                    let result = self.state_machine.apply(&command, entry.index);
596                    results.push(result);
597                }
598                self.log.set_last_applied(entry.index);
599
600                let mut state = self.state.write().unwrap();
601                state.last_applied = entry.index;
602            }
603        }
604
605        results
606    }
607
608    /// Get a value from the state machine.
609    pub fn get(&self, key: &str) -> Option<Vec<u8>> {
610        self.state_machine.get(key)
611    }
612
613    /// Get the log.
614    pub fn log(&self) -> &ReplicatedLog {
615        &self.log
616    }
617
618    /// Get the state machine.
619    pub fn state_machine(&self) -> &StateMachine {
620        &self.state_machine
621    }
622}
623
624// =============================================================================
625// Tests
626// =============================================================================
627
#[cfg(test)]
mod tests {
    use super::*;

    /// Node with the given id and the default configuration.
    fn node(id: &'static str) -> RaftNode {
        RaftNode::new(id, RaftConfig::default())
    }

    /// Granted vote from `voter_id` for `term`.
    fn granted_vote(term: Term, voter_id: NodeId) -> VoteResponse {
        VoteResponse {
            term,
            vote_granted: true,
            voter_id,
        }
    }

    /// Node `id` with the single peer `peer`, elected leader of term 1 by that peer's vote.
    fn elected_leader(id: &'static str, peer: &'static str) -> RaftNode {
        let leader = node(id);
        leader.add_peer(NodeId::new(peer));
        leader.start_election();
        leader.handle_vote_response(&granted_vote(1, NodeId::new(peer)));
        leader
    }

    #[test]
    fn test_raft_config() {
        let defaults = RaftConfig::default();
        assert_eq!(defaults.election_timeout_min, Duration::from_millis(150));
        assert_eq!(defaults.heartbeat_interval, Duration::from_millis(50));
    }

    #[test]
    fn test_raft_node_creation() {
        let fresh = node("node1");
        assert_eq!(fresh.id().as_str(), "node1");
        assert_eq!(fresh.role(), NodeRole::Follower);
        assert_eq!(fresh.current_term(), 0);
        assert!(!fresh.is_leader());
    }

    #[test]
    fn test_add_peer() {
        let local = node("node1");
        for peer in ["node2", "node3"] {
            local.add_peer(NodeId::new(peer));
        }

        assert_eq!(local.cluster_size(), 3);
        assert_eq!(local.quorum_size(), 2);
        assert_eq!(local.peers().len(), 2);
    }

    #[test]
    fn test_start_election() {
        let candidate = node("node1");
        let ballot = candidate.start_election();

        assert_eq!(ballot.term, 1);
        assert_eq!(ballot.candidate_id.as_str(), "node1");
        assert_eq!(candidate.role(), NodeRole::Candidate);
        assert_eq!(candidate.current_term(), 1);
    }

    #[test]
    fn test_vote_request_handling() {
        let candidate = node("node1");
        let voter = node("node2");

        let reply = voter.handle_vote_request(&candidate.start_election());

        assert!(reply.vote_granted);
        assert_eq!(reply.term, 1);
    }

    #[test]
    fn test_become_leader() {
        let candidate = node("node1");
        candidate.add_peer(NodeId::new("node2"));

        let ballot = candidate.start_election();
        let reply = granted_vote(ballot.term, NodeId::new("node2"));

        assert!(candidate.handle_vote_response(&reply));
        assert!(candidate.is_leader());
        assert_eq!(candidate.leader_id(), Some(NodeId::new("node1")));
    }

    #[test]
    fn test_propose_command() {
        let leader = elected_leader("node1", "node2");

        let outcome = leader.propose(Command::set("key1", b"value1".to_vec()));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_append_entries() {
        let follower = node("follower");
        let leader = elected_leader("leader", "follower");

        leader
            .propose(Command::set("key", b"value".to_vec()))
            .unwrap();

        let request = leader
            .create_append_entries(&NodeId::new("follower"))
            .unwrap();
        let reply = follower.handle_append_entries(&request);

        assert!(reply.success);
        assert_eq!(follower.log().last_index(), leader.log().last_index());
    }

    #[test]
    fn test_follower_rejects_old_term() {
        let follower = node("follower");
        follower.state.write().unwrap().current_term = 5;

        let stale_heartbeat = AppendEntriesRequest {
            term: 3,
            leader_id: NodeId::new("old_leader"),
            prev_log_index: 0,
            prev_log_term: 0,
            entries: Vec::new(),
            leader_commit: 0,
        };

        let reply = follower.handle_append_entries(&stale_heartbeat);
        assert!(!reply.success);
        assert_eq!(reply.term, 5);
    }
}