1use crate::log::{LogEntry, LogIndex, ReplicatedLog, Term};
9use crate::node::{NodeId, NodeRole};
10use crate::state::{Command, CommandResult, Snapshot, StateMachine, StateMachineBackend};
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, Instant};
15
/// Maximum number of snapshot bytes carried by one `InstallSnapshotRequest` chunk (64 KiB);
/// used by `RaftNode::create_install_snapshot` when slicing the serialized snapshot.
const SNAPSHOT_CHUNK_SIZE: usize = 64 * 1024;
18
/// Timing and sizing parameters for a [`RaftNode`].
///
/// Start from [`RaftConfig::default`] / [`RaftConfig::new`] and adjust with the `with_*`
/// builders or struct-update syntax.
#[derive(Debug, Clone)]
pub struct RaftConfig {
    /// Election timeout lower bound; `RaftNode::election_timeout_elapsed` compares the time
    /// since the last heartbeat against this value.
    pub election_timeout_min: Duration,
    /// Election timeout upper bound.
    // NOTE(review): not read anywhere in this file, so the election timeout is not randomized
    // between min and max here — confirm whether an external driver uses it.
    pub election_timeout_max: Duration,
    /// Intended leader heartbeat period (not read inside this file; presumably used by the
    /// driver loop — TODO confirm).
    pub heartbeat_interval: Duration,
    /// Maximum number of log entries packed into one `AppendEntriesRequest`.
    pub max_entries_per_request: usize,
    /// Log length (in entries) at or above which `RaftNode::should_snapshot` returns true.
    pub snapshot_threshold: u64,
    /// How far into the future `RaftNode::extend_lease` pushes the leader lease.
    pub lease_duration: Duration,
}
35
36impl Default for RaftConfig {
37 fn default() -> Self {
38 Self {
39 election_timeout_min: Duration::from_millis(150),
40 election_timeout_max: Duration::from_millis(300),
41 heartbeat_interval: Duration::from_millis(50),
42 max_entries_per_request: 100,
43 snapshot_threshold: 10000,
44 lease_duration: Duration::from_millis(200),
45 }
46 }
47}
48
49impl RaftConfig {
50 pub fn new() -> Self {
51 Self::default()
52 }
53
54 pub fn with_election_timeout(mut self, min: Duration, max: Duration) -> Self {
55 self.election_timeout_min = min;
56 self.election_timeout_max = max;
57 self
58 }
59
60 pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
61 self.heartbeat_interval = interval;
62 self
63 }
64}
65
/// Term/vote/commit bookkeeping of a node, guarded by `RaftNode::state`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RaftState {
    /// Latest term this node has seen.
    pub current_term: Term,
    /// Candidate this node voted for in `current_term`, if any.
    pub voted_for: Option<NodeId>,
    /// Highest log index known to be committed.
    pub commit_index: LogIndex,
    /// Highest log index applied to the state machine.
    pub last_applied: LogIndex,
}
78
/// `RequestVote` RPC sent by a candidate (built by `RaftNode::start_election`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteRequest {
    /// Candidate's (newly incremented) term.
    pub term: Term,
    /// Candidate requesting the vote.
    pub candidate_id: NodeId,
    /// Index of the candidate's last log entry (for the up-to-date check).
    pub last_log_index: LogIndex,
    /// Term of the candidate's last log entry (for the up-to-date check).
    pub last_log_term: Term,
}
91
/// Reply to a [`VoteRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteResponse {
    /// Voter's current term after processing the request.
    pub term: Term,
    /// Whether the vote was granted.
    pub vote_granted: bool,
    /// Node that sent this reply (used to de-duplicate votes).
    pub voter_id: NodeId,
}
99
/// `AppendEntries` RPC: log replication from the leader; with no entries it is a heartbeat.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesRequest {
    /// Leader's current term.
    pub term: Term,
    /// Leader's id, recorded by followers as the current leader.
    pub leader_id: NodeId,
    /// Index of the entry immediately preceding `entries` (0 = start of log).
    pub prev_log_index: LogIndex,
    /// Term of the entry at `prev_log_index` on the leader.
    pub prev_log_term: Term,
    /// Entries to store, starting at `prev_log_index + 1`.
    pub entries: Vec<LogEntry>,
    /// Leader's commit index.
    pub leader_commit: LogIndex,
}
114
/// Reply to an [`AppendEntriesRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesResponse {
    /// Responder's current term.
    pub term: Term,
    /// Whether the consistency check passed and the entries were accepted.
    pub success: bool,
    /// On success, the index the leader records as this follower's `match_index`; not used
    /// by the leader on failure.
    pub match_index: LogIndex,
    /// On failure, the index the leader should retry from (becomes the peer's `next_index`).
    pub conflict_index: Option<LogIndex>,
    /// On a term mismatch, the follower's term at `prev_log_index`.
    pub conflict_term: Option<Term>,
}
124
/// One chunk of a leader snapshot, sent to a follower whose `next_index` falls inside the
/// leader's compacted prefix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstallSnapshotRequest {
    /// Leader's current term.
    pub term: Term,
    /// Leader's id, recorded by the follower as its leader.
    pub leader_id: NodeId,
    /// Last log index covered by the snapshot.
    pub last_included_index: LogIndex,
    /// Term of the entry at `last_included_index`.
    pub last_included_term: Term,
    /// Byte offset of `data` within the serialized snapshot.
    pub offset: u64,
    /// This chunk's bytes.
    pub data: Vec<u8>,
    /// `true` on the final chunk.
    pub done: bool,
    /// CRC-32 (`crc32fast::hash`) of the whole serialized snapshot, set only on the final chunk
    /// by `RaftNode::create_install_snapshot` and verified by the receiver when present.
    // `serde(default)` lets messages that lack this field deserialize as `None`.
    #[serde(default)]
    pub checksum: Option<u32>,
}
144
/// Reply to an [`InstallSnapshotRequest`]; carries only the follower's current term.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstallSnapshotResponse {
    /// Follower's current term after processing the chunk.
    pub term: Term,
}
150
/// Describes a snapshot: the log position it replaces and its serialized size.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotMetadata {
    /// Last log index whose effects are included in the snapshot.
    pub last_included_index: LogIndex,
    /// Term of the entry at `last_included_index`.
    pub last_included_term: Term,
    /// Length in bytes of the serialized snapshot.
    pub size: u64,
}
162
/// Follower-side reassembly buffer for a chunked snapshot transfer in progress.
struct PendingSnapshot {
    /// Index/term announced by the first chunk (`size` is left at 0 here).
    metadata: SnapshotMetadata,
    /// Bytes received so far, in order.
    data: Vec<u8>,
    /// Offset the next chunk must start at (= bytes received so far).
    offset: u64,
}
169
/// A single Raft participant: term/vote state, the replicated log, the state machine that
/// committed commands are applied to, and leader-side replication progress per peer.
///
/// Each mutable field sits behind its own `RwLock` (presumably so `&self` methods can be
/// driven from several threads — TODO confirm the intended threading model). The locks are
/// independent, so updates spanning several fields are not atomic as a whole.
pub struct RaftNode {
    /// This node's identifier.
    id: NodeId,
    /// Timing and sizing parameters.
    config: RaftConfig,
    /// Term, vote, commit index and last-applied index.
    state: RwLock<RaftState>,
    /// Current role (follower / candidate / leader).
    role: RwLock<NodeRole>,
    /// The replicated log.
    log: Arc<ReplicatedLog>,
    /// Backend to which committed commands are applied.
    state_machine: Arc<dyn StateMachineBackend>,
    /// Other cluster members; this node is not included (`cluster_size` adds 1).
    peers: RwLock<HashSet<NodeId>>,
    /// Leader this node currently recognises, if any.
    leader_id: RwLock<Option<NodeId>>,
    /// Leader-side: next log index to send to each peer.
    next_index: RwLock<HashMap<NodeId, LogIndex>>,
    /// Leader-side: highest log index known to be replicated on each peer.
    match_index: RwLock<HashMap<NodeId, LogIndex>>,
    /// When the election timer was last reset (leader contact, vote grant, own election).
    last_heartbeat: RwLock<Instant>,
    /// Votes collected (including our own) during the current candidacy.
    votes_received: RwLock<HashSet<NodeId>>,
    /// Metadata of the most recent locally taken or installed snapshot.
    snapshot_metadata: RwLock<Option<SnapshotMetadata>>,
    /// Follower-side buffer for a chunked InstallSnapshot transfer in progress.
    pending_snapshot: RwLock<Option<PendingSnapshot>>,
    /// Leader lease deadline; `None` when no lease is held.
    lease_expiry: RwLock<Option<Instant>>,
}
194
195impl RaftNode {
196 pub fn new(id: impl Into<NodeId>, config: RaftConfig) -> Self {
198 Self::with_state_machine(id, config, Arc::new(StateMachine::new()))
199 }
200
201 pub fn with_state_machine(
204 id: impl Into<NodeId>,
205 config: RaftConfig,
206 state_machine: Arc<dyn StateMachineBackend>,
207 ) -> Self {
208 Self {
209 id: id.into(),
210 config,
211 state: RwLock::new(RaftState::default()),
212 role: RwLock::new(NodeRole::Follower),
213 log: Arc::new(ReplicatedLog::new()),
214 state_machine,
215 peers: RwLock::new(HashSet::new()),
216 leader_id: RwLock::new(None),
217 next_index: RwLock::new(HashMap::new()),
218 match_index: RwLock::new(HashMap::new()),
219 last_heartbeat: RwLock::new(Instant::now()),
220 votes_received: RwLock::new(HashSet::new()),
221 snapshot_metadata: RwLock::new(None),
222 pending_snapshot: RwLock::new(None),
223 lease_expiry: RwLock::new(None),
224 }
225 }
226
227 pub fn id(&self) -> NodeId {
229 self.id.clone()
230 }
231
232 pub fn role(&self) -> NodeRole {
234 *self.role.read().expect("raft role lock poisoned")
235 }
236
237 pub fn current_term(&self) -> Term {
239 self.state
240 .read()
241 .expect("raft state lock poisoned")
242 .current_term
243 }
244
245 pub fn leader_id(&self) -> Option<NodeId> {
247 self.leader_id
248 .read()
249 .expect("raft leader_id lock poisoned")
250 .clone()
251 }
252
253 pub fn is_leader(&self) -> bool {
255 self.role() == NodeRole::Leader
256 }
257
258 pub fn is_leader_with_lease(&self) -> bool {
262 self.is_leader() && self.has_valid_lease()
263 }
264
265 pub fn extend_lease(&self) {
268 let expiry = Instant::now() + self.config.lease_duration;
269 *self
270 .lease_expiry
271 .write()
272 .expect("raft lease_expiry lock poisoned") = Some(expiry);
273 }
274
275 pub fn has_valid_lease(&self) -> bool {
277 match *self
278 .lease_expiry
279 .read()
280 .expect("raft lease_expiry lock poisoned")
281 {
282 Some(expiry) => Instant::now() < expiry,
283 None => false,
284 }
285 }
286
287 pub fn check_lease(&self) {
290 if self.is_leader() && !self.has_valid_lease() {
291 *self.role.write().expect("raft role lock poisoned") = NodeRole::Follower;
292 *self
293 .leader_id
294 .write()
295 .expect("raft leader_id lock poisoned") = None;
296 *self
298 .lease_expiry
299 .write()
300 .expect("raft lease_expiry lock poisoned") = None;
301 }
302 }
303
304 pub fn add_peer(&self, peer_id: NodeId) {
306 let mut peers = self.peers.write().expect("raft peers lock poisoned");
307 peers.insert(peer_id.clone());
308
309 let last_log = self.log.last_index();
310 self.next_index
311 .write()
312 .expect("raft next_index lock poisoned")
313 .insert(peer_id.clone(), last_log + 1);
314 self.match_index
315 .write()
316 .expect("raft match_index lock poisoned")
317 .insert(peer_id, 0);
318 }
319
320 pub fn remove_peer(&self, peer_id: &NodeId) {
322 let mut peers = self.peers.write().expect("raft peers lock poisoned");
323 peers.remove(peer_id);
324 self.next_index
325 .write()
326 .expect("raft next_index lock poisoned")
327 .remove(peer_id);
328 self.match_index
329 .write()
330 .expect("raft match_index lock poisoned")
331 .remove(peer_id);
332 }
333
334 pub fn peers(&self) -> Vec<NodeId> {
336 self.peers
337 .read()
338 .expect("raft peers lock poisoned")
339 .iter()
340 .cloned()
341 .collect()
342 }
343
344 pub fn cluster_size(&self) -> usize {
346 self.peers.read().expect("raft peers lock poisoned").len() + 1
347 }
348
349 pub fn quorum_size(&self) -> usize {
351 (self.cluster_size() / 2) + 1
352 }
353
354 pub fn reset_heartbeat(&self) {
356 *self
357 .last_heartbeat
358 .write()
359 .expect("raft last_heartbeat lock poisoned") = Instant::now();
360 }
361
362 pub fn election_timeout_elapsed(&self) -> bool {
364 let elapsed = self
365 .last_heartbeat
366 .read()
367 .expect("raft last_heartbeat lock poisoned")
368 .elapsed();
369 elapsed >= self.config.election_timeout_min
370 }
371
372 pub fn start_election(&self) -> VoteRequest {
378 let mut state = self.state.write().expect("raft state lock poisoned");
379 state.current_term += 1;
380 state.voted_for = Some(self.id.clone());
381
382 *self.role.write().expect("raft role lock poisoned") = NodeRole::Candidate;
383 *self
384 .leader_id
385 .write()
386 .expect("raft leader_id lock poisoned") = None;
387
388 let mut votes = self
389 .votes_received
390 .write()
391 .expect("raft votes_received lock poisoned");
392 votes.clear();
393 votes.insert(self.id.clone());
394
395 self.reset_heartbeat();
396
397 VoteRequest {
398 term: state.current_term,
399 candidate_id: self.id.clone(),
400 last_log_index: self.log.last_index(),
401 last_log_term: self.log.last_term(),
402 }
403 }
404
405 pub fn handle_vote_request(&self, request: &VoteRequest) -> VoteResponse {
407 let mut state = self.state.write().expect("raft state lock poisoned");
408
409 if request.term > state.current_term {
410 state.current_term = request.term;
411 state.voted_for = None;
412 *self.role.write().expect("raft role lock poisoned") = NodeRole::Follower;
413 *self
414 .leader_id
415 .write()
416 .expect("raft leader_id lock poisoned") = None;
417 }
418
419 let vote_granted = request.term >= state.current_term
420 && (state.voted_for.is_none()
421 || state.voted_for.as_ref() == Some(&request.candidate_id))
422 && self
423 .log
424 .is_up_to_date(request.last_log_index, request.last_log_term);
425
426 if vote_granted {
427 state.voted_for = Some(request.candidate_id.clone());
428 self.reset_heartbeat();
429 }
430
431 VoteResponse {
432 term: state.current_term,
433 vote_granted,
434 voter_id: self.id.clone(),
435 }
436 }
437
438 pub fn handle_vote_response(&self, response: &VoteResponse) -> bool {
440 let current_term = {
441 let mut state = self.state.write().expect("raft state lock poisoned");
442
443 if response.term > state.current_term {
444 state.current_term = response.term;
445 state.voted_for = None;
446 *self.role.write().expect("raft role lock poisoned") = NodeRole::Follower;
447 *self
448 .leader_id
449 .write()
450 .expect("raft leader_id lock poisoned") = None;
451 return false;
452 }
453
454 if self.role() != NodeRole::Candidate || response.term != state.current_term {
455 return false;
456 }
457
458 state.current_term
459 };
460
461 if response.vote_granted {
462 self.votes_received
463 .write()
464 .expect("raft votes_received lock poisoned")
465 .insert(response.voter_id.clone());
466 }
467
468 let votes = self
469 .votes_received
470 .read()
471 .expect("raft votes_received lock poisoned")
472 .len();
473 if votes >= self.quorum_size() {
474 self.become_leader_with_term(current_term);
475 return true;
476 }
477
478 false
479 }
480
481 #[allow(dead_code)]
483 fn become_leader(&self) {
484 let term = self.current_term();
485 self.become_leader_with_term(term);
486 }
487
488 fn become_leader_with_term(&self, term: Term) {
490 *self.role.write().expect("raft role lock poisoned") = NodeRole::Leader;
491 *self
492 .leader_id
493 .write()
494 .expect("raft leader_id lock poisoned") = Some(self.id.clone());
495
496 let last_log = self.log.last_index();
497 let peers: Vec<_> = self
498 .peers
499 .read()
500 .expect("raft peers lock poisoned")
501 .iter()
502 .cloned()
503 .collect();
504
505 let mut next_index = self
506 .next_index
507 .write()
508 .expect("raft next_index lock poisoned");
509 let mut match_index = self
510 .match_index
511 .write()
512 .expect("raft match_index lock poisoned");
513
514 for peer in peers {
515 next_index.insert(peer.clone(), last_log + 1);
516 match_index.insert(peer, 0);
517 }
518
519 drop(next_index);
520 drop(match_index);
521
522 let noop = LogEntry::noop(last_log + 1, term);
523 self.log.append(noop);
524 }
525
526 pub fn propose(&self, command: Command) -> Result<LogIndex, String> {
532 if !self.is_leader() {
533 return Err("Not the leader".to_string());
534 }
535
536 let term = self.current_term();
537 let index = self.log.last_index() + 1;
538 let entry = LogEntry::command(index, term, command.to_bytes());
539
540 self.log.append(entry);
541 Ok(index)
542 }
543
544 pub fn create_append_entries(&self, peer_id: &NodeId) -> Option<AppendEntriesRequest> {
546 if !self.is_leader() {
547 return None;
548 }
549
550 let next_index = *self
551 .next_index
552 .read()
553 .expect("raft next_index lock poisoned")
554 .get(peer_id)?;
555 let prev_log_index = next_index.saturating_sub(1);
556 let prev_log_term = self.log.term_at(prev_log_index).unwrap_or(0);
557
558 let entries = self.log.get_range(next_index, self.log.last_index() + 1);
559 let entries: Vec<_> = entries
560 .into_iter()
561 .take(self.config.max_entries_per_request)
562 .collect();
563
564 let state = self.state.read().expect("raft state lock poisoned");
565
566 Some(AppendEntriesRequest {
567 term: state.current_term,
568 leader_id: self.id.clone(),
569 prev_log_index,
570 prev_log_term,
571 entries,
572 leader_commit: state.commit_index,
573 })
574 }
575
576 pub fn handle_append_entries(&self, request: &AppendEntriesRequest) -> AppendEntriesResponse {
578 let mut state = self.state.write().expect("raft state lock poisoned");
579
580 if request.term < state.current_term {
581 return AppendEntriesResponse {
582 term: state.current_term,
583 success: false,
584 match_index: 0,
585 conflict_index: None,
586 conflict_term: None,
587 };
588 }
589
590 if request.term > state.current_term {
591 state.current_term = request.term;
592 state.voted_for = None;
593 }
594
595 *self.role.write().expect("raft role lock poisoned") = NodeRole::Follower;
596 *self
597 .leader_id
598 .write()
599 .expect("raft leader_id lock poisoned") = Some(request.leader_id.clone());
600 self.reset_heartbeat();
601
602 if request.prev_log_index > 0 {
603 match self.log.term_at(request.prev_log_index) {
604 None => {
605 return AppendEntriesResponse {
606 term: state.current_term,
607 success: false,
608 match_index: self.log.last_index(),
609 conflict_index: Some(self.log.last_index() + 1),
610 conflict_term: None,
611 };
612 }
613 Some(term) if term != request.prev_log_term => {
614 let conflict_index = self.find_first_index_of_term(term);
615 return AppendEntriesResponse {
616 term: state.current_term,
617 success: false,
618 match_index: 0,
619 conflict_index: Some(conflict_index),
620 conflict_term: Some(term),
621 };
622 }
623 _ => {}
624 }
625 }
626
627 if !request.entries.is_empty() {
628 if let Some(conflict) = self.log.find_conflict(&request.entries) {
629 self.log.truncate_from(conflict);
630 }
631
632 let existing_last = self.log.last_index();
633 let new_entries: Vec<_> = request
634 .entries
635 .iter()
636 .filter(|e| e.index > existing_last)
637 .cloned()
638 .collect();
639
640 if !new_entries.is_empty() {
641 self.log.append_entries(new_entries);
642 }
643 }
644
645 if request.leader_commit > state.commit_index {
646 let last_new_index = if request.entries.is_empty() {
647 request.prev_log_index
648 } else {
649 request
650 .entries
651 .last()
652 .expect("entries confirmed non-empty")
653 .index
654 };
655 state.commit_index = std::cmp::min(request.leader_commit, last_new_index);
656 self.log.set_commit_index(state.commit_index);
657 }
658
659 AppendEntriesResponse {
660 term: state.current_term,
661 success: true,
662 match_index: self.log.last_index(),
663 conflict_index: None,
664 conflict_term: None,
665 }
666 }
667
668 pub fn handle_append_entries_response(
670 &self,
671 peer_id: &NodeId,
672 response: &AppendEntriesResponse,
673 ) {
674 let mut state = self.state.write().expect("raft state lock poisoned");
675
676 if response.term > state.current_term {
677 state.current_term = response.term;
678 state.voted_for = None;
679 *self.role.write().expect("raft role lock poisoned") = NodeRole::Follower;
680 *self
681 .leader_id
682 .write()
683 .expect("raft leader_id lock poisoned") = None;
684 return;
685 }
686
687 if !self.is_leader() {
688 return;
689 }
690
691 let mut next_index = self
692 .next_index
693 .write()
694 .expect("raft next_index lock poisoned");
695 let mut match_index = self
696 .match_index
697 .write()
698 .expect("raft match_index lock poisoned");
699
700 if response.success {
701 match_index.insert(peer_id.clone(), response.match_index);
702 next_index.insert(peer_id.clone(), response.match_index + 1);
703
704 let ack_count = match_index.values().filter(|&&idx| idx > 0).count() + 1; let quorum = self.quorum_size();
708
709 drop(next_index);
710 drop(match_index);
711 drop(state);
712
713 if ack_count >= quorum {
714 self.extend_lease();
715 }
716
717 self.try_advance_commit_index();
718 } else if let Some(conflict_index) = response.conflict_index {
719 next_index.insert(peer_id.clone(), conflict_index);
720 } else {
721 let current = *next_index.get(peer_id).unwrap_or(&1);
722 next_index.insert(peer_id.clone(), current.saturating_sub(1).max(1));
723 }
724 }
725
726 fn try_advance_commit_index(&self) {
728 let match_indices: Vec<_> = {
729 let match_index = self
730 .match_index
731 .read()
732 .expect("raft match_index lock poisoned");
733 let mut indices: Vec<_> = match_index.values().copied().collect();
734 indices.push(self.log.last_index());
735 indices.sort_unstable();
736 indices
737 };
738
739 let quorum_index = match_indices.len() / 2;
740 let new_commit = match_indices[quorum_index];
741
742 let mut state = self.state.write().expect("raft state lock poisoned");
743 if new_commit > state.commit_index {
744 if let Some(term) = self.log.term_at(new_commit) {
745 if term == state.current_term {
746 state.commit_index = new_commit;
747 self.log.set_commit_index(new_commit);
748 }
749 }
750 }
751 }
752
753 fn find_first_index_of_term(&self, term: Term) -> LogIndex {
754 let mut index = self.log.last_index();
755 while index > 0 {
756 if let Some(t) = self.log.term_at(index) {
757 if t != term {
758 return index + 1;
759 }
760 }
761 index -= 1;
762 }
763 1
764 }
765
766 pub fn apply_committed(&self) -> Vec<CommandResult> {
772 let mut results = Vec::new();
773
774 while self.log.has_entries_to_apply() {
775 if let Some(entry) = self.log.next_to_apply() {
776 if let Some(command) = Command::from_bytes(&entry.data) {
777 let result = self.state_machine.apply(&command, entry.index);
778 results.push(result);
779 }
780 self.log.set_last_applied(entry.index);
781
782 let mut state = self.state.write().expect("raft state lock poisoned");
783 state.last_applied = entry.index;
784 }
785 }
786
787 results
788 }
789
790 pub fn get(&self, key: &str) -> Option<Vec<u8>> {
792 self.state_machine.get(key)
793 }
794
795 pub fn log(&self) -> &ReplicatedLog {
797 &self.log
798 }
799
800 pub fn state_machine(&self) -> &dyn StateMachineBackend {
802 self.state_machine.as_ref()
803 }
804
805 pub fn should_snapshot(&self) -> bool {
811 let log_size = self.log.len() as u64;
812 log_size >= self.config.snapshot_threshold
813 }
814
815 pub fn take_snapshot(&self) -> Option<SnapshotMetadata> {
818 let state = self.state.read().expect("raft state lock poisoned");
819 let last_applied = state.last_applied;
820
821 if last_applied == 0 {
822 return None;
823 }
824
825 let last_applied_term = self.log.term_at(last_applied)?;
827
828 let snapshot = self.state_machine.snapshot();
830 let snapshot_bytes = snapshot.to_bytes();
831 let size = snapshot_bytes.len() as u64;
832
833 self.log.compact(last_applied, last_applied_term);
835
836 let metadata = SnapshotMetadata {
837 last_included_index: last_applied,
838 last_included_term: last_applied_term,
839 size,
840 };
841
842 *self
843 .snapshot_metadata
844 .write()
845 .expect("raft snapshot_metadata lock poisoned") = Some(metadata.clone());
846 Some(metadata)
847 }
848
849 pub fn get_snapshot_data(&self) -> Option<(SnapshotMetadata, Vec<u8>)> {
851 let metadata = self
852 .snapshot_metadata
853 .read()
854 .expect("raft snapshot_metadata lock poisoned")
855 .clone()?;
856 let snapshot = self.state_machine.snapshot();
857 let data = snapshot.to_bytes();
858 Some((metadata, data))
859 }
860
861 pub fn create_install_snapshot(
864 &self,
865 peer_id: &NodeId,
866 offset: u64,
867 ) -> Option<InstallSnapshotRequest> {
868 if !self.is_leader() {
869 return None;
870 }
871
872 let next_index = *self
873 .next_index
874 .read()
875 .expect("raft next_index lock poisoned")
876 .get(peer_id)?;
877 let (metadata, data) = self.get_snapshot_data()?;
878
879 if next_index > metadata.last_included_index {
881 return None;
882 }
883
884 let term = self.current_term();
885
886 let start = offset as usize;
888 let end = std::cmp::min(start + SNAPSHOT_CHUNK_SIZE, data.len());
889 let chunk = data[start..end].to_vec();
890 let done = end >= data.len();
891
892 let checksum = if done {
894 Some(crc32fast::hash(&data))
895 } else {
896 None
897 };
898
899 Some(InstallSnapshotRequest {
900 term,
901 leader_id: self.id.clone(),
902 last_included_index: metadata.last_included_index,
903 last_included_term: metadata.last_included_term,
904 offset,
905 data: chunk,
906 done,
907 checksum,
908 })
909 }
910
911 pub fn handle_install_snapshot(
913 &self,
914 request: &InstallSnapshotRequest,
915 ) -> InstallSnapshotResponse {
916 let mut state = self.state.write().expect("raft state lock poisoned");
917
918 if request.term < state.current_term {
920 return InstallSnapshotResponse {
921 term: state.current_term,
922 };
923 }
924
925 if request.term > state.current_term {
927 state.current_term = request.term;
928 state.voted_for = None;
929 *self.role.write().expect("raft role lock poisoned") = NodeRole::Follower;
930 }
931
932 *self
934 .leader_id
935 .write()
936 .expect("raft leader_id lock poisoned") = Some(request.leader_id.clone());
937 self.reset_heartbeat();
938
939 let mut pending = self
940 .pending_snapshot
941 .write()
942 .expect("raft pending_snapshot lock poisoned");
943
944 if request.offset == 0 {
946 *pending = Some(PendingSnapshot {
947 metadata: SnapshotMetadata {
948 last_included_index: request.last_included_index,
949 last_included_term: request.last_included_term,
950 size: 0, },
952 data: Vec::new(),
953 offset: 0,
954 });
955 }
956
957 if let Some(ref mut snapshot) = *pending {
959 if request.offset != snapshot.offset {
960 *pending = None;
962 return InstallSnapshotResponse {
963 term: state.current_term,
964 };
965 }
966
967 snapshot.data.extend_from_slice(&request.data);
969 snapshot.offset += request.data.len() as u64;
970
971 if request.done {
973 let snapshot_data = std::mem::take(&mut snapshot.data);
974 let metadata = snapshot.metadata.clone();
975 drop(pending);
976
977 if let Some(expected_checksum) = request.checksum {
979 let actual_checksum = crc32fast::hash(&snapshot_data);
980 if actual_checksum != expected_checksum {
981 *self
983 .pending_snapshot
984 .write()
985 .expect("raft pending_snapshot lock poisoned") = None;
986 return InstallSnapshotResponse {
987 term: state.current_term,
988 };
989 }
990 }
991
992 if let Some(restored_snapshot) = Snapshot::from_bytes(&snapshot_data) {
994 self.state_machine.restore(restored_snapshot);
995
996 self.log
998 .compact(metadata.last_included_index, metadata.last_included_term);
999
1000 state.commit_index =
1002 std::cmp::max(state.commit_index, metadata.last_included_index);
1003 state.last_applied = metadata.last_included_index;
1004 self.log.set_commit_index(state.commit_index);
1005 self.log.set_last_applied(state.last_applied);
1006
1007 *self
1009 .snapshot_metadata
1010 .write()
1011 .expect("raft snapshot_metadata lock poisoned") = Some(SnapshotMetadata {
1012 last_included_index: metadata.last_included_index,
1013 last_included_term: metadata.last_included_term,
1014 size: snapshot_data.len() as u64,
1015 });
1016 }
1017
1018 *self
1020 .pending_snapshot
1021 .write()
1022 .expect("raft pending_snapshot lock poisoned") = None;
1023 }
1024 }
1025
1026 InstallSnapshotResponse {
1027 term: state.current_term,
1028 }
1029 }
1030
1031 pub fn handle_install_snapshot_response(
1033 &self,
1034 peer_id: &NodeId,
1035 response: &InstallSnapshotResponse,
1036 _last_chunk_offset: u64,
1037 was_last_chunk: bool,
1038 ) {
1039 let mut state = self.state.write().expect("raft state lock poisoned");
1040
1041 if response.term > state.current_term {
1043 state.current_term = response.term;
1044 state.voted_for = None;
1045 *self.role.write().expect("raft role lock poisoned") = NodeRole::Follower;
1046 *self
1047 .leader_id
1048 .write()
1049 .expect("raft leader_id lock poisoned") = None;
1050 return;
1051 }
1052
1053 if !self.is_leader() {
1054 return;
1055 }
1056
1057 if was_last_chunk {
1059 if let Some(ref metadata) = *self
1060 .snapshot_metadata
1061 .read()
1062 .expect("raft snapshot_metadata lock poisoned")
1063 {
1064 self.next_index
1065 .write()
1066 .expect("raft next_index lock poisoned")
1067 .insert(peer_id.clone(), metadata.last_included_index + 1);
1068 self.match_index
1069 .write()
1070 .expect("raft match_index lock poisoned")
1071 .insert(peer_id.clone(), metadata.last_included_index);
1072 }
1073 }
1074 }
1075
1076 pub fn peer_needs_snapshot(&self, peer_id: &NodeId) -> bool {
1078 let next_index = match self
1079 .next_index
1080 .read()
1081 .expect("raft next_index lock poisoned")
1082 .get(peer_id)
1083 {
1084 Some(&idx) => idx,
1085 None => return false,
1086 };
1087
1088 if let Some(ref metadata) = *self
1090 .snapshot_metadata
1091 .read()
1092 .expect("raft snapshot_metadata lock poisoned")
1093 {
1094 return next_index <= metadata.last_included_index;
1095 }
1096
1097 false
1098 }
1099
1100 pub fn snapshot_metadata(&self) -> Option<SnapshotMetadata> {
1102 self.snapshot_metadata
1103 .read()
1104 .expect("raft snapshot_metadata lock poisoned")
1105 .clone()
1106 }
1107}
1108
1109#[cfg(test)]
1114mod tests {
1115 use super::*;
1116
1117 #[test]
1118 fn test_raft_config() {
1119 let config = RaftConfig::default();
1120 assert_eq!(config.election_timeout_min, Duration::from_millis(150));
1121 assert_eq!(config.heartbeat_interval, Duration::from_millis(50));
1122 }
1123
1124 #[test]
1125 fn test_raft_node_creation() {
1126 let node = RaftNode::new("node1", RaftConfig::default());
1127 assert_eq!(node.id().as_str(), "node1");
1128 assert_eq!(node.role(), NodeRole::Follower);
1129 assert_eq!(node.current_term(), 0);
1130 assert!(!node.is_leader());
1131 }
1132
1133 #[test]
1134 fn test_add_peer() {
1135 let node = RaftNode::new("node1", RaftConfig::default());
1136 node.add_peer(NodeId::new("node2"));
1137 node.add_peer(NodeId::new("node3"));
1138
1139 assert_eq!(node.cluster_size(), 3);
1140 assert_eq!(node.quorum_size(), 2);
1141 assert_eq!(node.peers().len(), 2);
1142 }
1143
1144 #[test]
1145 fn test_start_election() {
1146 let node = RaftNode::new("node1", RaftConfig::default());
1147 let request = node.start_election();
1148
1149 assert_eq!(request.term, 1);
1150 assert_eq!(request.candidate_id.as_str(), "node1");
1151 assert_eq!(node.role(), NodeRole::Candidate);
1152 assert_eq!(node.current_term(), 1);
1153 }
1154
1155 #[test]
1156 fn test_vote_request_handling() {
1157 let node1 = RaftNode::new("node1", RaftConfig::default());
1158 let node2 = RaftNode::new("node2", RaftConfig::default());
1159
1160 let request = node1.start_election();
1161 let response = node2.handle_vote_request(&request);
1162
1163 assert!(response.vote_granted);
1164 assert_eq!(response.term, 1);
1165 }
1166
1167 #[test]
1168 fn test_become_leader() {
1169 let node = RaftNode::new("node1", RaftConfig::default());
1170 node.add_peer(NodeId::new("node2"));
1171
1172 let request = node.start_election();
1173 let response = VoteResponse {
1174 term: request.term,
1175 vote_granted: true,
1176 voter_id: NodeId::new("node2"),
1177 };
1178
1179 let became_leader = node.handle_vote_response(&response);
1180 assert!(became_leader);
1181 assert!(node.is_leader());
1182 assert_eq!(node.leader_id(), Some(NodeId::new("node1")));
1183 }
1184
1185 #[test]
1186 fn test_propose_command() {
1187 let node = RaftNode::new("node1", RaftConfig::default());
1188 node.add_peer(NodeId::new("node2"));
1189
1190 node.start_election();
1191 let response = VoteResponse {
1192 term: 1,
1193 vote_granted: true,
1194 voter_id: NodeId::new("node2"),
1195 };
1196 node.handle_vote_response(&response);
1197
1198 let command = Command::set("key1", b"value1".to_vec());
1199 let result = node.propose(command);
1200 assert!(result.is_ok());
1201 }
1202
1203 #[test]
1204 fn test_append_entries() {
1205 let leader = RaftNode::new("leader", RaftConfig::default());
1206 let follower = RaftNode::new("follower", RaftConfig::default());
1207
1208 leader.add_peer(NodeId::new("follower"));
1209 leader.start_election();
1210 let vote = VoteResponse {
1211 term: 1,
1212 vote_granted: true,
1213 voter_id: NodeId::new("follower"),
1214 };
1215 leader.handle_vote_response(&vote);
1216
1217 let command = Command::set("key", b"value".to_vec());
1218 leader.propose(command).unwrap();
1219
1220 let request = leader
1221 .create_append_entries(&NodeId::new("follower"))
1222 .unwrap();
1223 let response = follower.handle_append_entries(&request);
1224
1225 assert!(response.success);
1226 assert_eq!(follower.log().last_index(), leader.log().last_index());
1227 }
1228
1229 #[test]
1230 fn test_follower_rejects_old_term() {
1231 let follower = RaftNode::new("follower", RaftConfig::default());
1232
1233 {
1234 let mut state = follower.state.write().unwrap();
1235 state.current_term = 5;
1236 }
1237
1238 let request = AppendEntriesRequest {
1239 term: 3,
1240 leader_id: NodeId::new("old_leader"),
1241 prev_log_index: 0,
1242 prev_log_term: 0,
1243 entries: vec![],
1244 leader_commit: 0,
1245 };
1246
1247 let response = follower.handle_append_entries(&request);
1248 assert!(!response.success);
1249 assert_eq!(response.term, 5);
1250 }
1251
1252 #[test]
1253 fn test_should_snapshot() {
1254 let config = RaftConfig {
1255 snapshot_threshold: 3,
1256 ..RaftConfig::default()
1257 };
1258 let node = RaftNode::new("node1", config);
1259
1260 assert!(!node.should_snapshot());
1262
1263 use crate::log::LogEntry;
1265 node.log.append(LogEntry::command(1, 1, vec![1]));
1266 node.log.append(LogEntry::command(2, 1, vec![2]));
1267 assert!(!node.should_snapshot());
1268
1269 node.log.append(LogEntry::command(3, 1, vec![3]));
1271 assert!(node.should_snapshot());
1272 }
1273
1274 #[test]
1275 fn test_take_snapshot() {
1276 let node = RaftNode::new("node1", RaftConfig::default());
1277
1278 let cmd1 = Command::set("key1", b"value1".to_vec());
1280 let cmd2 = Command::set("key2", b"value2".to_vec());
1281
1282 use crate::log::LogEntry;
1284 node.log.append(LogEntry::command(1, 1, cmd1.to_bytes()));
1285 node.log.append(LogEntry::command(2, 1, cmd2.to_bytes()));
1286 node.log.set_commit_index(2);
1287
1288 node.state_machine.apply(&cmd1, 1);
1290 node.state_machine.apply(&cmd2, 2);
1291
1292 {
1293 let mut state = node.state.write().unwrap();
1294 state.last_applied = 2;
1295 state.commit_index = 2;
1296 }
1297
1298 let metadata = node.take_snapshot();
1300 assert!(metadata.is_some());
1301
1302 let metadata = metadata.unwrap();
1303 assert_eq!(metadata.last_included_index, 2);
1304 assert_eq!(metadata.last_included_term, 1);
1305 assert!(metadata.size > 0);
1306
1307 assert!(node.snapshot_metadata().is_some());
1309 }
1310
/// A complete snapshot delivered in a single `done` request is applied to the
/// follower's state machine, and the follower's `last_applied` /
/// `commit_index` advance to the snapshot boundary.
#[test]
fn test_install_snapshot_single_chunk() {
    let leader = RaftNode::new("leader", RaftConfig::new());
    let follower = RaftNode::new("follower", RaftConfig::new());

    // Make the sender a term-2 leader.
    leader.state.write().unwrap().current_term = 2;
    *leader.role.write().unwrap() = NodeRole::Leader;

    let first = Command::set("key1", b"value1".to_vec());
    let second = Command::set("key2", b"value2".to_vec());
    leader.state_machine.apply(&first, 1);
    leader.state_machine.apply(&second, 2);

    // Ship the whole serialized state machine at offset 0 in one request.
    let payload = leader.state_machine.snapshot().to_bytes();
    let request = InstallSnapshotRequest {
        term: 2,
        leader_id: NodeId::new("leader"),
        last_included_index: 2,
        last_included_term: 1,
        offset: 0,
        data: payload,
        done: true,
        checksum: None,
    };

    let reply = follower.handle_install_snapshot(&request);
    assert_eq!(reply.term, 2);

    for (key, value) in [("key1", b"value1"), ("key2", b"value2")] {
        assert_eq!(follower.state_machine.get(key).unwrap(), value);
    }

    let follower_state = follower.state.read().unwrap();
    assert_eq!(follower_state.last_applied, 2);
    assert!(follower_state.commit_index >= 2);
}
1357
/// Installs a snapshot split into two chunks. Checks that nothing is applied
/// until the final (`done`) chunk arrives. After that, both the state machine
/// and the follower's Raft bookkeeping must reflect the snapshot.
#[test]
fn test_install_snapshot_multiple_chunks() {
    let follower = RaftNode::new("follower", RaftConfig::default());

    let mut data = std::collections::HashMap::new();
    data.insert("key1".to_string(), b"value1".to_vec());
    data.insert("key2".to_string(), b"value2".to_vec());

    let snapshot = crate::state::Snapshot {
        data,
        last_applied: 5,
        version: 2,
    };
    let snapshot_bytes = snapshot.to_bytes();

    // Split roughly in half. Both halves must be non-empty, otherwise the test
    // would not exercise reassembly at all.
    let chunk_size = snapshot_bytes.len() / 2;
    assert!(chunk_size > 0, "snapshot too small to split into two chunks");
    let (chunk1, chunk2) = snapshot_bytes.split_at(chunk_size);

    let request1 = InstallSnapshotRequest {
        term: 1,
        leader_id: NodeId::new("leader"),
        last_included_index: 5,
        last_included_term: 1,
        offset: 0,
        data: chunk1.to_vec(),
        done: false,
        checksum: None,
    };
    let response1 = follower.handle_install_snapshot(&request1);
    assert_eq!(response1.term, 1);
    // A non-final chunk must only be buffered; applying a partial snapshot
    // would expose a half-deserialized state.
    assert!(follower.state_machine.is_empty());

    let request2 = InstallSnapshotRequest {
        term: 1,
        leader_id: NodeId::new("leader"),
        last_included_index: 5,
        last_included_term: 1,
        offset: chunk_size as u64,
        data: chunk2.to_vec(),
        done: true,
        checksum: None,
    };
    let response2 = follower.handle_install_snapshot(&request2);
    assert_eq!(response2.term, 1);

    assert_eq!(follower.state_machine.get("key1").unwrap(), b"value1");
    assert_eq!(follower.state_machine.get("key2").unwrap(), b"value2");
    assert_eq!(follower.state_machine.last_applied(), 5);

    // Same bookkeeping expectations as the single-chunk install.
    let state = follower.state.read().unwrap();
    assert_eq!(state.last_applied, 5);
    assert!(state.commit_index >= 5);
}
1412
/// A snapshot from a leader whose term is behind the follower's is refused.
/// The follower answers with its own (newer) term and leaves the state
/// machine untouched.
#[test]
fn test_install_snapshot_rejects_old_term() {
    let follower = RaftNode::new("follower", RaftConfig::new());
    follower.state.write().unwrap().current_term = 5;

    let stale = InstallSnapshotRequest {
        term: 3,
        leader_id: NodeId::new("old_leader"),
        last_included_index: 10,
        last_included_term: 2,
        offset: 0,
        data: vec![1, 2, 3],
        done: true,
        checksum: None,
    };

    assert_eq!(follower.handle_install_snapshot(&stale).term, 5);
    assert!(follower.state_machine.is_empty());
}
1440
/// `peer_needs_snapshot` is true only while a snapshot exists and the peer's
/// `next_index` does not lie past the snapshot's last included index.
#[test]
fn test_peer_needs_snapshot() {
    let leader = RaftNode::new("leader", RaftConfig::new());
    let peer = NodeId::new("follower");
    leader.add_peer(peer.clone());

    // No snapshot yet: nothing to ship.
    assert!(!leader.peer_needs_snapshot(&peer));

    // Compact through index 100. The freshly added peer's next_index is
    // presumably at or below 100, so it now needs the snapshot.
    let compacted = SnapshotMetadata {
        last_included_index: 100,
        last_included_term: 5,
        size: 1000,
    };
    *leader.snapshot_metadata.write().unwrap() = Some(compacted);
    assert!(leader.peer_needs_snapshot(&peer));

    // Once the peer expects the entry right after the snapshot, no snapshot is needed.
    leader.next_index.write().unwrap().insert(peer.clone(), 101);
    assert!(!leader.peer_needs_snapshot(&peer));
}
1467
/// Takes a snapshot on a term-2 leader and checks that
/// `create_install_snapshot` builds a first chunk carrying the leader's term,
/// id, and the snapshot boundary.
#[test]
fn test_create_install_snapshot() {
    let leader = RaftNode::new("leader", RaftConfig::default());
    leader.add_peer(NodeId::new("follower"));

    {
        let mut state = leader.state.write().unwrap();
        state.current_term = 2;
    }
    *leader.role.write().unwrap() = NodeRole::Leader;

    let cmd = Command::set("test_key", b"test_value".to_vec());
    leader.state_machine.apply(&cmd, 1);

    use crate::log::LogEntry;
    leader.log.append(LogEntry::command(1, 2, cmd.to_bytes()));
    leader.log.set_commit_index(1);
    {
        let mut state = leader.state.write().unwrap();
        state.last_applied = 1;
        state.commit_index = 1;
    }

    // Fail here with a clear message if the snapshot could not be taken,
    // rather than later when no request can be built.
    let metadata = leader
        .take_snapshot()
        .expect("take_snapshot should succeed once index 1 is applied");
    assert_eq!(metadata.last_included_index, 1);

    leader
        .next_index
        .write()
        .unwrap()
        .insert(NodeId::new("follower"), 0);

    let request = leader
        .create_install_snapshot(&NodeId::new("follower"), 0)
        .expect("a snapshot exists, so an InstallSnapshot request should be built");
    assert_eq!(request.term, 2);
    assert_eq!(request.leader_id.as_str(), "leader");
    assert_eq!(request.last_included_index, 1);
    assert_eq!(request.last_included_term, 2);
    assert_eq!(request.offset, 0);
    assert!(!request.data.is_empty());
}
1516
/// A freshly constructed node reports neither a valid lease nor
/// leader-with-lease status.
#[test]
fn test_leader_lease_not_valid_initially() {
    let fresh = RaftNode::new("node1", RaftConfig::new());
    let (lease, leader_lease) = (fresh.has_valid_lease(), fresh.is_leader_with_lease());
    assert!(!lease);
    assert!(!leader_lease);
}
1527
/// Calling `extend_lease` makes `has_valid_lease` true right away, here on a
/// node that never ran an election.
#[test]
fn test_extend_and_check_lease() {
    let raft = RaftNode::new("node1", RaftConfig::new());
    raft.extend_lease();
    assert!(raft.has_valid_lease());
}
1534
/// Winning an election makes the node leader but does not by itself grant a
/// lease; `is_leader_with_lease` only turns true after `extend_lease`.
#[test]
fn test_is_leader_with_lease() {
    let node = RaftNode::new("node1", RaftConfig::new());
    node.add_peer(NodeId::new("node2"));

    // In this two-node cluster, node2's granted vote is enough to win.
    node.start_election();
    node.handle_vote_response(&VoteResponse {
        term: 1,
        vote_granted: true,
        voter_id: NodeId::new("node2"),
    });
    assert!(node.is_leader());

    assert!(!node.is_leader_with_lease());
    node.extend_lease();
    assert!(node.is_leader_with_lease());
}
1557
/// A leader without a valid lease must step down to follower when
/// `check_lease` runs.
#[test]
fn test_check_lease_steps_down_leader() {
    let lease_duration = Duration::from_millis(1);
    let config = RaftConfig {
        lease_duration,
        ..RaftConfig::default()
    };
    let node = RaftNode::new("node1", config);
    node.add_peer(NodeId::new("node2"));

    node.start_election();
    let response = VoteResponse {
        term: 1,
        vote_granted: true,
        voter_id: NodeId::new("node2"),
    };
    node.handle_vote_response(&response);
    assert!(node.is_leader());

    // Wait well past the lease duration so the outcome does not depend on
    // racing a 1 ms timer, even if taking leadership starts a lease.
    std::thread::sleep(lease_duration * 10);
    assert!(!node.has_valid_lease());

    node.check_lease();
    assert!(!node.is_leader());
    assert_eq!(node.role(), NodeRole::Follower);
}
1582
/// In a three-node cluster, one successful AppendEntries acknowledgement from
/// a follower is enough for the leader to hold a valid lease.
#[test]
fn test_lease_extended_on_majority_ack() {
    let node = RaftNode::new("leader", RaftConfig::new());
    for peer in ["follower1", "follower2"] {
        node.add_peer(NodeId::new(peer));
    }

    node.start_election();
    let vote = VoteResponse {
        term: 1,
        vote_granted: true,
        voter_id: NodeId::new("follower1"),
    };
    node.handle_vote_response(&vote);
    assert!(node.is_leader());

    node.propose(Command::set("key", b"value".to_vec())).unwrap();

    // NOTE(review): match_index 2 presumably covers an entry appended on
    // election plus the proposal; confirm against the leader's log layout.
    let ack = AppendEntriesResponse {
        term: 1,
        success: true,
        match_index: 2,
        conflict_index: None,
        conflict_term: None,
    };
    node.handle_append_entries_response(&NodeId::new("follower1"), &ack);

    assert!(node.has_valid_lease());
    assert!(node.is_leader_with_lease());
}
1616
/// A snapshot small enough for one chunk is sent as a single `done` request.
/// That request carries the CRC32 of the complete snapshot payload.
#[test]
fn test_create_install_snapshot_has_checksum() {
    let leader = RaftNode::new("leader", RaftConfig::default());
    leader.add_peer(NodeId::new("follower"));

    {
        let mut state = leader.state.write().unwrap();
        state.current_term = 2;
    }
    *leader.role.write().unwrap() = NodeRole::Leader;

    let cmd = Command::set("test_key", b"test_value".to_vec());
    leader.state_machine.apply(&cmd, 1);

    use crate::log::LogEntry;
    leader.log.append(LogEntry::command(1, 2, cmd.to_bytes()));
    leader.log.set_commit_index(1);
    {
        let mut state = leader.state.write().unwrap();
        state.last_applied = 1;
        state.commit_index = 1;
    }

    // Fail here with a clear message if the snapshot could not be taken.
    let metadata = leader
        .take_snapshot()
        .expect("take_snapshot should succeed once index 1 is applied");
    assert_eq!(metadata.last_included_index, 1);

    leader
        .next_index
        .write()
        .unwrap()
        .insert(NodeId::new("follower"), 0);

    let request = leader
        .create_install_snapshot(&NodeId::new("follower"), 0)
        .unwrap();
    assert!(request.done);

    // The checksum covers the whole snapshot, not just this chunk.
    let (_, full_data) = leader.get_snapshot_data().unwrap();
    let expected_checksum = crc32fast::hash(&full_data);
    assert_eq!(request.checksum, Some(expected_checksum));
}
1666
/// A snapshot whose declared checksum does not match its payload is thrown
/// away. The follower's state machine stays empty.
#[test]
fn test_snapshot_checksum_mismatch_discards() {
    let follower = RaftNode::new("follower", RaftConfig::new());

    let payload = crate::state::Snapshot {
        data: std::collections::HashMap::from([("key1".to_string(), b"value1".to_vec())]),
        last_applied: 3,
        version: 2,
    }
    .to_bytes();

    // 0xDEADBEEF is a bogus checksum standing in for a corrupted transfer.
    let corrupted = InstallSnapshotRequest {
        term: 1,
        leader_id: NodeId::new("leader"),
        last_included_index: 3,
        last_included_term: 1,
        offset: 0,
        data: payload,
        done: true,
        checksum: Some(0xDEADBEEF),
    };
    let _ = follower.handle_install_snapshot(&corrupted);

    assert!(follower.state_machine.is_empty());
}
1698
/// A snapshot whose CRC32 matches its payload is accepted and applied to the
/// follower's state machine.
#[test]
fn test_snapshot_checksum_valid_applies() {
    let follower = RaftNode::new("follower", RaftConfig::new());

    let payload = crate::state::Snapshot {
        data: std::collections::HashMap::from([("key1".to_string(), b"value1".to_vec())]),
        last_applied: 3,
        version: 2,
    }
    .to_bytes();
    let crc = crc32fast::hash(&payload);

    let intact = InstallSnapshotRequest {
        term: 1,
        leader_id: NodeId::new("leader"),
        last_included_index: 3,
        last_included_term: 1,
        offset: 0,
        data: payload,
        done: true,
        checksum: Some(crc),
    };
    let _ = follower.handle_install_snapshot(&intact);

    assert_eq!(follower.state_machine.get("key1").unwrap(), b"value1");
}
1729}