1use crate::log::{LogEntry, LogIndex, ReplicatedLog, Term};
9use crate::node::{NodeId, NodeRole};
10use crate::state::{Command, CommandResult, StateMachine};
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, Instant};
15
/// Tunable timing and batching parameters for a [`RaftNode`].
///
/// Start from [`RaftConfig::default`] (or [`RaftConfig::new`]) and adjust
/// individual settings with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RaftConfig {
    /// Lower end of the election timeout window: a node never reports its
    /// election timeout as elapsed before this much time has passed since the
    /// last heartbeat reset.
    pub election_timeout_min: Duration,
    /// Upper end of the election timeout window (set together with the
    /// minimum by [`RaftConfig::with_election_timeout`]).
    pub election_timeout_max: Duration,
    /// Interval between leader heartbeats.
    /// NOTE(review): not read by `RaftNode` in this module; presumably consumed
    /// by whatever drives the node — confirm.
    pub heartbeat_interval: Duration,
    /// Maximum number of log entries placed in a single AppendEntries request.
    pub max_entries_per_request: usize,
    /// NOTE(review): not read anywhere in this module; presumably the number of
    /// log entries after which a snapshot should be taken — confirm.
    pub snapshot_threshold: u64,
}

impl Default for RaftConfig {
    /// 150–300 ms election timeout, 50 ms heartbeat interval, up to 100
    /// entries per request and a snapshot threshold of 10 000.
    fn default() -> Self {
        Self {
            election_timeout_min: Duration::from_millis(150),
            election_timeout_max: Duration::from_millis(300),
            heartbeat_interval: Duration::from_millis(50),
            max_entries_per_request: 100,
            snapshot_threshold: 10000,
        }
    }
}

impl RaftConfig {
    /// Same as [`RaftConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the election timeout window. `min` is expected to be `<= max`;
    /// the values are stored exactly as given.
    #[must_use]
    pub fn with_election_timeout(mut self, min: Duration, max: Duration) -> Self {
        self.election_timeout_min = min;
        self.election_timeout_max = max;
        self
    }

    /// Sets the heartbeat interval.
    #[must_use]
    pub fn with_heartbeat_interval(mut self, interval: Duration) -> Self {
        self.heartbeat_interval = interval;
        self
    }

    /// Sets the maximum number of entries carried by one AppendEntries request.
    #[must_use]
    pub fn with_max_entries_per_request(mut self, max_entries: usize) -> Self {
        self.max_entries_per_request = max_entries;
        self
    }

    /// Sets the snapshot threshold.
    #[must_use]
    pub fn with_snapshot_threshold(mut self, threshold: u64) -> Self {
        self.snapshot_threshold = threshold;
        self
    }
}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct RaftState {
66 pub current_term: Term,
67 pub voted_for: Option<NodeId>,
68 pub commit_index: LogIndex,
69 pub last_applied: LogIndex,
70}
71
72impl Default for RaftState {
73 fn default() -> Self {
74 Self {
75 current_term: 0,
76 voted_for: None,
77 commit_index: 0,
78 last_applied: 0,
79 }
80 }
81}
82
/// Arguments of the RequestVote RPC, built by `RaftNode::start_election` and
/// answered by `RaftNode::handle_vote_request`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteRequest {
    /// The candidate's term, already incremented for this election.
    pub term: Term,
    /// Node asking for the vote.
    pub candidate_id: NodeId,
    /// Index of the candidate's last log entry; together with `last_log_term`
    /// a voter uses it to check the candidate's log is up to date.
    pub last_log_index: LogIndex,
    /// Term of the candidate's last log entry.
    pub last_log_term: Term,
}
95
/// Reply to a [`VoteRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteResponse {
    /// The voter's current term after processing the request; a candidate that
    /// sees a higher term steps down to follower.
    pub term: Term,
    /// Whether the vote was granted.
    pub vote_granted: bool,
    /// The responding node; the candidate records voters in a set, so each
    /// voter is counted at most once.
    pub voter_id: NodeId,
}
103
/// Arguments of the AppendEntries RPC. It replicates log entries and, with an
/// empty `entries`, also acts as a heartbeat: any request carrying a current
/// term resets the receiver's election timer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesRequest {
    /// The leader's current term.
    pub term: Term,
    /// The sending leader, recorded by followers as the known leader.
    pub leader_id: NodeId,
    /// Index of the entry immediately preceding `entries` (0 = start of log).
    pub prev_log_index: LogIndex,
    /// Term of the entry at `prev_log_index`; 0 when the leader has no entry
    /// there.
    pub prev_log_term: Term,
    /// Entries to store, in index order; at most
    /// `RaftConfig::max_entries_per_request` of them.
    pub entries: Vec<LogEntry>,
    /// The leader's commit index; followers advance their own commit index
    /// toward it.
    pub leader_commit: LogIndex,
}
118
/// Reply to an [`AppendEntriesRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesResponse {
    /// The responder's current term; a leader that sees a higher term steps
    /// down.
    pub term: Term,
    /// `true` if the responder's log matched at `prev_log_index` /
    /// `prev_log_term` and the entries were accepted.
    pub success: bool,
    /// On success, the leader records this as the peer's replicated index and
    /// resumes sending at `match_index + 1`. The leader ignores it on failure.
    pub match_index: LogIndex,
    /// On a log mismatch, the index the leader should resume sending from;
    /// `None` makes the leader step back by one entry instead.
    pub conflict_index: Option<LogIndex>,
    /// On a term mismatch, the responder's term at `prev_log_index`; `None`
    /// when the responder has no entry there.
    /// NOTE(review): the leader does not currently read this field.
    pub conflict_term: Option<Term>,
}
128
/// Arguments of an InstallSnapshot RPC.
///
/// NOTE(review): nothing in this module builds or handles this message. The
/// field notes below follow the field names (which mirror the Raft paper's
/// InstallSnapshot RPC) and should be confirmed against the sender/handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstallSnapshotRequest {
    /// The sending leader's term.
    pub term: Term,
    /// The sending leader.
    pub leader_id: NodeId,
    /// Presumably the last log index covered by the snapshot.
    pub last_included_index: LogIndex,
    /// Presumably the term of the entry at `last_included_index`.
    pub last_included_term: Term,
    /// Presumably the byte offset of `data` within the whole snapshot.
    pub offset: u64,
    /// Raw snapshot bytes carried by this message.
    pub data: Vec<u8>,
    /// Presumably `true` on the final chunk of the snapshot.
    pub done: bool,
}
144
/// Reply to an [`InstallSnapshotRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstallSnapshotResponse {
    /// Presumably the responder's current term, as in the other Raft replies.
    pub term: Term,
}
150
/// A single Raft participant: election state, the replicated log, the state
/// machine it feeds, and (while leading) per-peer replication progress.
///
/// Every field is guarded by its own lock. Several methods take `state` first
/// and acquire other locks while still holding it; new code should keep that
/// order to avoid deadlocks.
pub struct RaftNode {
    /// This node's identifier, used in outgoing votes and requests.
    id: NodeId,
    /// Timing and batching parameters.
    config: RaftConfig,
    /// Current term, vote, commit index and last-applied index.
    state: RwLock<RaftState>,
    /// Follower, candidate or leader.
    role: RwLock<NodeRole>,
    /// The replicated log.
    log: Arc<ReplicatedLog>,
    /// State machine that committed commands are applied to.
    state_machine: Arc<StateMachine>,
    /// The other cluster members; this node itself is not stored here.
    peers: RwLock<HashSet<NodeId>>,
    /// Leader of the current term, if known.
    leader_id: RwLock<Option<NodeId>>,
    /// Per peer: next log index to send (used while leading).
    next_index: RwLock<HashMap<NodeId, LogIndex>>,
    /// Per peer: highest log index known to be replicated there.
    match_index: RwLock<HashMap<NodeId, LogIndex>>,
    /// When the election timer was last reset.
    last_heartbeat: RwLock<Instant>,
    /// Nodes (including this one) that granted a vote in the current election.
    votes_received: RwLock<HashSet<NodeId>>,
}
170
171impl RaftNode {
172 pub fn new(id: impl Into<NodeId>, config: RaftConfig) -> Self {
174 Self {
175 id: id.into(),
176 config,
177 state: RwLock::new(RaftState::default()),
178 role: RwLock::new(NodeRole::Follower),
179 log: Arc::new(ReplicatedLog::new()),
180 state_machine: Arc::new(StateMachine::new()),
181 peers: RwLock::new(HashSet::new()),
182 leader_id: RwLock::new(None),
183 next_index: RwLock::new(HashMap::new()),
184 match_index: RwLock::new(HashMap::new()),
185 last_heartbeat: RwLock::new(Instant::now()),
186 votes_received: RwLock::new(HashSet::new()),
187 }
188 }
189
190 pub fn id(&self) -> NodeId {
192 self.id.clone()
193 }
194
195 pub fn role(&self) -> NodeRole {
197 *self.role.read().unwrap()
198 }
199
200 pub fn current_term(&self) -> Term {
202 self.state.read().unwrap().current_term
203 }
204
205 pub fn leader_id(&self) -> Option<NodeId> {
207 self.leader_id.read().unwrap().clone()
208 }
209
210 pub fn is_leader(&self) -> bool {
212 self.role() == NodeRole::Leader
213 }
214
215 pub fn add_peer(&self, peer_id: NodeId) {
217 let mut peers = self.peers.write().unwrap();
218 peers.insert(peer_id.clone());
219
220 let last_log = self.log.last_index();
221 self.next_index.write().unwrap().insert(peer_id.clone(), last_log + 1);
222 self.match_index.write().unwrap().insert(peer_id, 0);
223 }
224
225 pub fn remove_peer(&self, peer_id: &NodeId) {
227 let mut peers = self.peers.write().unwrap();
228 peers.remove(peer_id);
229 self.next_index.write().unwrap().remove(peer_id);
230 self.match_index.write().unwrap().remove(peer_id);
231 }
232
233 pub fn peers(&self) -> Vec<NodeId> {
235 self.peers.read().unwrap().iter().cloned().collect()
236 }
237
238 pub fn cluster_size(&self) -> usize {
240 self.peers.read().unwrap().len() + 1
241 }
242
243 pub fn quorum_size(&self) -> usize {
245 (self.cluster_size() / 2) + 1
246 }
247
248 pub fn reset_heartbeat(&self) {
250 *self.last_heartbeat.write().unwrap() = Instant::now();
251 }
252
253 pub fn election_timeout_elapsed(&self) -> bool {
255 let elapsed = self.last_heartbeat.read().unwrap().elapsed();
256 elapsed >= self.config.election_timeout_min
257 }
258
259 pub fn start_election(&self) -> VoteRequest {
265 let mut state = self.state.write().unwrap();
266 state.current_term += 1;
267 state.voted_for = Some(self.id.clone());
268
269 *self.role.write().unwrap() = NodeRole::Candidate;
270 *self.leader_id.write().unwrap() = None;
271
272 let mut votes = self.votes_received.write().unwrap();
273 votes.clear();
274 votes.insert(self.id.clone());
275
276 self.reset_heartbeat();
277
278 VoteRequest {
279 term: state.current_term,
280 candidate_id: self.id.clone(),
281 last_log_index: self.log.last_index(),
282 last_log_term: self.log.last_term(),
283 }
284 }
285
286 pub fn handle_vote_request(&self, request: &VoteRequest) -> VoteResponse {
288 let mut state = self.state.write().unwrap();
289
290 if request.term > state.current_term {
291 state.current_term = request.term;
292 state.voted_for = None;
293 *self.role.write().unwrap() = NodeRole::Follower;
294 *self.leader_id.write().unwrap() = None;
295 }
296
297 let vote_granted = request.term >= state.current_term
298 && (state.voted_for.is_none() || state.voted_for.as_ref() == Some(&request.candidate_id))
299 && self.log.is_up_to_date(request.last_log_index, request.last_log_term);
300
301 if vote_granted {
302 state.voted_for = Some(request.candidate_id.clone());
303 self.reset_heartbeat();
304 }
305
306 VoteResponse {
307 term: state.current_term,
308 vote_granted,
309 voter_id: self.id.clone(),
310 }
311 }
312
313 pub fn handle_vote_response(&self, response: &VoteResponse) -> bool {
315 let current_term = {
316 let mut state = self.state.write().unwrap();
317
318 if response.term > state.current_term {
319 state.current_term = response.term;
320 state.voted_for = None;
321 *self.role.write().unwrap() = NodeRole::Follower;
322 *self.leader_id.write().unwrap() = None;
323 return false;
324 }
325
326 if self.role() != NodeRole::Candidate || response.term != state.current_term {
327 return false;
328 }
329
330 state.current_term
331 };
332
333 if response.vote_granted {
334 self.votes_received.write().unwrap().insert(response.voter_id.clone());
335 }
336
337 let votes = self.votes_received.read().unwrap().len();
338 if votes >= self.quorum_size() {
339 self.become_leader_with_term(current_term);
340 return true;
341 }
342
343 false
344 }
345
346 #[allow(dead_code)]
348 fn become_leader(&self) {
349 let term = self.current_term();
350 self.become_leader_with_term(term);
351 }
352
353 fn become_leader_with_term(&self, term: Term) {
355 *self.role.write().unwrap() = NodeRole::Leader;
356 *self.leader_id.write().unwrap() = Some(self.id.clone());
357
358 let last_log = self.log.last_index();
359 let peers: Vec<_> = self.peers.read().unwrap().iter().cloned().collect();
360
361 let mut next_index = self.next_index.write().unwrap();
362 let mut match_index = self.match_index.write().unwrap();
363
364 for peer in peers {
365 next_index.insert(peer.clone(), last_log + 1);
366 match_index.insert(peer, 0);
367 }
368
369 drop(next_index);
370 drop(match_index);
371
372 let noop = LogEntry::noop(last_log + 1, term);
373 self.log.append(noop);
374 }
375
376 pub fn propose(&self, command: Command) -> Result<LogIndex, String> {
382 if !self.is_leader() {
383 return Err("Not the leader".to_string());
384 }
385
386 let term = self.current_term();
387 let index = self.log.last_index() + 1;
388 let entry = LogEntry::command(index, term, command.to_bytes());
389
390 self.log.append(entry);
391 Ok(index)
392 }
393
394 pub fn create_append_entries(&self, peer_id: &NodeId) -> Option<AppendEntriesRequest> {
396 if !self.is_leader() {
397 return None;
398 }
399
400 let next_index = *self.next_index.read().unwrap().get(peer_id)?;
401 let prev_log_index = next_index.saturating_sub(1);
402 let prev_log_term = self.log.term_at(prev_log_index).unwrap_or(0);
403
404 let entries = self.log.get_range(next_index, self.log.last_index() + 1);
405 let entries: Vec<_> = entries
406 .into_iter()
407 .take(self.config.max_entries_per_request)
408 .collect();
409
410 let state = self.state.read().unwrap();
411
412 Some(AppendEntriesRequest {
413 term: state.current_term,
414 leader_id: self.id.clone(),
415 prev_log_index,
416 prev_log_term,
417 entries,
418 leader_commit: state.commit_index,
419 })
420 }
421
422 pub fn handle_append_entries(&self, request: &AppendEntriesRequest) -> AppendEntriesResponse {
424 let mut state = self.state.write().unwrap();
425
426 if request.term < state.current_term {
427 return AppendEntriesResponse {
428 term: state.current_term,
429 success: false,
430 match_index: 0,
431 conflict_index: None,
432 conflict_term: None,
433 };
434 }
435
436 if request.term > state.current_term {
437 state.current_term = request.term;
438 state.voted_for = None;
439 }
440
441 *self.role.write().unwrap() = NodeRole::Follower;
442 *self.leader_id.write().unwrap() = Some(request.leader_id.clone());
443 self.reset_heartbeat();
444
445 if request.prev_log_index > 0 {
446 match self.log.term_at(request.prev_log_index) {
447 None => {
448 return AppendEntriesResponse {
449 term: state.current_term,
450 success: false,
451 match_index: self.log.last_index(),
452 conflict_index: Some(self.log.last_index() + 1),
453 conflict_term: None,
454 };
455 }
456 Some(term) if term != request.prev_log_term => {
457 let conflict_index = self.find_first_index_of_term(term);
458 return AppendEntriesResponse {
459 term: state.current_term,
460 success: false,
461 match_index: 0,
462 conflict_index: Some(conflict_index),
463 conflict_term: Some(term),
464 };
465 }
466 _ => {}
467 }
468 }
469
470 if !request.entries.is_empty() {
471 if let Some(conflict) = self.log.find_conflict(&request.entries) {
472 self.log.truncate_from(conflict);
473 }
474
475 let existing_last = self.log.last_index();
476 let new_entries: Vec<_> = request
477 .entries
478 .iter()
479 .filter(|e| e.index > existing_last)
480 .cloned()
481 .collect();
482
483 if !new_entries.is_empty() {
484 self.log.append_entries(new_entries);
485 }
486 }
487
488 if request.leader_commit > state.commit_index {
489 let last_new_index = if request.entries.is_empty() {
490 request.prev_log_index
491 } else {
492 request.entries.last().unwrap().index
493 };
494 state.commit_index = std::cmp::min(request.leader_commit, last_new_index);
495 self.log.set_commit_index(state.commit_index);
496 }
497
498 AppendEntriesResponse {
499 term: state.current_term,
500 success: true,
501 match_index: self.log.last_index(),
502 conflict_index: None,
503 conflict_term: None,
504 }
505 }
506
507 pub fn handle_append_entries_response(
509 &self,
510 peer_id: &NodeId,
511 response: &AppendEntriesResponse,
512 ) {
513 let mut state = self.state.write().unwrap();
514
515 if response.term > state.current_term {
516 state.current_term = response.term;
517 state.voted_for = None;
518 *self.role.write().unwrap() = NodeRole::Follower;
519 *self.leader_id.write().unwrap() = None;
520 return;
521 }
522
523 if !self.is_leader() {
524 return;
525 }
526
527 let mut next_index = self.next_index.write().unwrap();
528 let mut match_index = self.match_index.write().unwrap();
529
530 if response.success {
531 match_index.insert(peer_id.clone(), response.match_index);
532 next_index.insert(peer_id.clone(), response.match_index + 1);
533 drop(next_index);
534 drop(match_index);
535 drop(state);
536 self.try_advance_commit_index();
537 } else {
538 if let Some(conflict_index) = response.conflict_index {
539 next_index.insert(peer_id.clone(), conflict_index);
540 } else {
541 let current = *next_index.get(peer_id).unwrap_or(&1);
542 next_index.insert(peer_id.clone(), current.saturating_sub(1).max(1));
543 }
544 }
545 }
546
547 fn try_advance_commit_index(&self) {
549 let match_indices: Vec<_> = {
550 let match_index = self.match_index.read().unwrap();
551 let mut indices: Vec<_> = match_index.values().copied().collect();
552 indices.push(self.log.last_index());
553 indices.sort_unstable();
554 indices
555 };
556
557 let quorum_index = match_indices.len() / 2;
558 let new_commit = match_indices[quorum_index];
559
560 let mut state = self.state.write().unwrap();
561 if new_commit > state.commit_index {
562 if let Some(term) = self.log.term_at(new_commit) {
563 if term == state.current_term {
564 state.commit_index = new_commit;
565 self.log.set_commit_index(new_commit);
566 }
567 }
568 }
569 }
570
571 fn find_first_index_of_term(&self, term: Term) -> LogIndex {
572 let mut index = self.log.last_index();
573 while index > 0 {
574 if let Some(t) = self.log.term_at(index) {
575 if t != term {
576 return index + 1;
577 }
578 }
579 index -= 1;
580 }
581 1
582 }
583
584 pub fn apply_committed(&self) -> Vec<CommandResult> {
590 let mut results = Vec::new();
591
592 while self.log.has_entries_to_apply() {
593 if let Some(entry) = self.log.next_to_apply() {
594 if let Some(command) = Command::from_bytes(&entry.data) {
595 let result = self.state_machine.apply(&command, entry.index);
596 results.push(result);
597 }
598 self.log.set_last_applied(entry.index);
599
600 let mut state = self.state.write().unwrap();
601 state.last_applied = entry.index;
602 }
603 }
604
605 results
606 }
607
608 pub fn get(&self, key: &str) -> Option<Vec<u8>> {
610 self.state_machine.get(key)
611 }
612
613 pub fn log(&self) -> &ReplicatedLog {
615 &self.log
616 }
617
618 pub fn state_machine(&self) -> &StateMachine {
620 &self.state_machine
621 }
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates `id` with the single peer `peer`, starts an election and
    /// delivers `peer`'s granted vote, returning the resulting node.
    fn elected_leader(id: &'static str, peer: &'static str) -> RaftNode {
        let node = RaftNode::new(id, RaftConfig::default());
        node.add_peer(NodeId::new(peer));
        let ballot = node.start_election();
        node.handle_vote_response(&VoteResponse {
            term: ballot.term,
            vote_granted: true,
            voter_id: NodeId::new(peer),
        });
        node
    }

    #[test]
    fn test_raft_config() {
        let defaults = RaftConfig::default();
        assert_eq!(defaults.election_timeout_min, Duration::from_millis(150));
        assert_eq!(defaults.heartbeat_interval, Duration::from_millis(50));
    }

    #[test]
    fn test_raft_node_creation() {
        let fresh = RaftNode::new("node1", RaftConfig::default());
        assert_eq!(fresh.id().as_str(), "node1");
        assert_eq!(fresh.role(), NodeRole::Follower);
        assert_eq!(fresh.current_term(), 0);
        assert!(!fresh.is_leader());
    }

    #[test]
    fn test_add_peer() {
        let node = RaftNode::new("node1", RaftConfig::default());
        for peer in ["node2", "node3"] {
            node.add_peer(NodeId::new(peer));
        }

        assert_eq!(node.cluster_size(), 3);
        assert_eq!(node.quorum_size(), 2);
        assert_eq!(node.peers().len(), 2);
    }

    #[test]
    fn test_start_election() {
        let candidate = RaftNode::new("node1", RaftConfig::default());
        let ballot = candidate.start_election();

        assert_eq!(ballot.term, 1);
        assert_eq!(ballot.candidate_id.as_str(), "node1");
        assert_eq!(candidate.role(), NodeRole::Candidate);
        assert_eq!(candidate.current_term(), 1);
    }

    #[test]
    fn test_vote_request_handling() {
        let candidate = RaftNode::new("node1", RaftConfig::default());
        let voter = RaftNode::new("node2", RaftConfig::default());

        let reply = voter.handle_vote_request(&candidate.start_election());

        assert!(reply.vote_granted);
        assert_eq!(reply.term, 1);
    }

    #[test]
    fn test_become_leader() {
        let node = RaftNode::new("node1", RaftConfig::default());
        node.add_peer(NodeId::new("node2"));

        let ballot = node.start_election();
        let won = node.handle_vote_response(&VoteResponse {
            term: ballot.term,
            vote_granted: true,
            voter_id: NodeId::new("node2"),
        });

        assert!(won);
        assert!(node.is_leader());
        assert_eq!(node.leader_id(), Some(NodeId::new("node1")));
    }

    #[test]
    fn test_propose_command() {
        let leader = elected_leader("node1", "node2");

        let outcome = leader.propose(Command::set("key1", b"value1".to_vec()));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_append_entries() {
        let leader = elected_leader("leader", "follower");
        let follower = RaftNode::new("follower", RaftConfig::default());

        leader.propose(Command::set("key", b"value".to_vec())).unwrap();

        let outgoing = leader.create_append_entries(&NodeId::new("follower")).unwrap();
        let reply = follower.handle_append_entries(&outgoing);

        assert!(reply.success);
        assert_eq!(follower.log().last_index(), leader.log().last_index());
    }

    #[test]
    fn test_follower_rejects_old_term() {
        let follower = RaftNode::new("follower", RaftConfig::default());
        follower.state.write().unwrap().current_term = 5;

        let stale = AppendEntriesRequest {
            term: 3,
            leader_id: NodeId::new("old_leader"),
            prev_log_index: 0,
            prev_log_term: 0,
            entries: vec![],
            leader_commit: 0,
        };

        let reply = follower.handle_append_entries(&stale);
        assert!(!reply.success);
        assert_eq!(reply.term, 5);
    }
}