1use crate::network::{ConsensusNetwork, NetworkError, RaftMessage, RaftNetwork};
7use crate::raft::{
8 ClusterConfig, LogCommand, LogEntry, LogIndex, LogInfo, NodeId, RaftRole, RaftStorage, Term,
9};
10use async_trait::async_trait;
11use rand::Rng;
12use serde::{Serialize, de::DeserializeOwned};
13use std::collections::HashMap;
14
15use std::time::Duration;
16
17use serde::Deserialize;
22
/// Timing and batching parameters for the Raft consensus engine.
///
/// All duration fields deserialize from integer milliseconds (via
/// `deserialize_duration`) and fall back to the matching `default_*`
/// helper when absent from the configuration source.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
pub struct RaftConfig {
    /// Lower bound of the randomized election timeout.
    #[serde(
        deserialize_with = "deserialize_duration",
        default = "default_election_min"
    )]
    pub election_timeout_min: Duration,
    /// Upper bound of the randomized election timeout.
    #[serde(
        deserialize_with = "deserialize_duration",
        default = "default_election_max"
    )]
    pub election_timeout_max: Duration,
    /// Interval between leader heartbeats (empty AppendEntries rounds).
    #[serde(
        deserialize_with = "deserialize_duration",
        default = "default_heartbeat"
    )]
    pub heartbeat_interval: Duration,
    /// Timeout applied to individual RPCs.
    #[serde(
        deserialize_with = "deserialize_duration",
        default = "default_rpc_timeout"
    )]
    pub rpc_timeout: Duration,
    /// Maximum number of log entries shipped in a single AppendEntries RPC.
    #[serde(default = "default_max_entries")]
    pub max_entries_per_rpc: usize,
}
54
55fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
56where
57 D: serde::Deserializer<'de>,
58{
59 let ms = u64::deserialize(deserializer)?;
60 Ok(Duration::from_millis(ms))
61}
62
/// Millisecond shorthand shared by the serde default helpers below.
const fn ms(millis: u64) -> Duration {
    Duration::from_millis(millis)
}

/// Default lower bound of the randomized election timeout (150 ms).
fn default_election_min() -> Duration {
    ms(150)
}

/// Default upper bound of the randomized election timeout (300 ms).
fn default_election_max() -> Duration {
    ms(300)
}

/// Default leader heartbeat interval (50 ms).
fn default_heartbeat() -> Duration {
    ms(50)
}

/// Default per-RPC timeout (100 ms).
fn default_rpc_timeout() -> Duration {
    ms(100)
}

/// Default cap on log entries per AppendEntries RPC.
fn default_max_entries() -> usize {
    100
}
78
79impl Default for RaftConfig {
80 fn default() -> Self {
81 Self {
82 election_timeout_min: default_election_min(),
83 election_timeout_max: default_election_max(),
84 heartbeat_interval: default_heartbeat(),
85 rpc_timeout: default_rpc_timeout(),
86 max_entries_per_rpc: default_max_entries(),
87 }
88 }
89}
90
91impl RaftConfig {
92 pub fn load() -> Self {
105 let builder = config::Config::builder()
106 .add_source(config::File::with_name("raft").required(false))
107 .add_source(config::Environment::with_prefix("PRABORROW"));
108
109 match builder.build().and_then(|c| c.try_deserialize()) {
110 Ok(config) => config,
111 Err(e) => {
112 tracing::warn!("Failed to load configuration: {}. Using defaults.", e);
113 Self::default()
114 }
115 }
116 }
117
118 pub fn random_election_timeout(&self) -> Duration {
120 let min = self.election_timeout_min.as_millis() as u64;
121 let max = self.election_timeout_max.as_millis() as u64;
122 let timeout_ms = rand::rng().random_range(min..=max);
123 Duration::from_millis(timeout_ms)
124 }
125
126 pub fn validate(&self) -> Result<(), String> {
128 if self.election_timeout_min >= self.election_timeout_max {
129 return Err("election_timeout_min must be less than election_timeout_max".to_string());
130 }
131 if self.heartbeat_interval.mul_f64(2.0) > self.election_timeout_min {
133 return Err(
134 "heartbeat_interval must be at most half of election_timeout_min".to_string(),
135 );
136 }
137 if self.heartbeat_interval >= self.election_timeout_min {
138 return Err("heartbeat_interval must be less than election_timeout_min".to_string());
139 }
140 if self.rpc_timeout.is_zero() {
141 return Err("rpc_timeout must be non-zero".to_string());
142 }
143 if self.max_entries_per_rpc == 0 {
144 return Err("max_entries_per_rpc must be positive".to_string());
145 }
146 Ok(())
147 }
148
149 pub fn builder() -> RaftConfigBuilder {
151 RaftConfigBuilder::default()
152 }
153}
154
/// Fluent builder for [`RaftConfig`]; any field left unset falls back to the
/// same default used by serde deserialization.
#[derive(Default)]
pub struct RaftConfigBuilder {
    election_timeout_min: Option<Duration>,
    election_timeout_max: Option<Duration>,
    heartbeat_interval: Option<Duration>,
    rpc_timeout: Option<Duration>,
    max_entries_per_rpc: Option<usize>,
}
164
165impl RaftConfigBuilder {
166 pub fn election_timeout_min(mut self, timeout: Duration) -> Self {
167 self.election_timeout_min = Some(timeout);
168 self
169 }
170
171 pub fn election_timeout_max(mut self, timeout: Duration) -> Self {
172 self.election_timeout_max = Some(timeout);
173 self
174 }
175
176 pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
177 self.heartbeat_interval = Some(interval);
178 self
179 }
180
181 pub fn rpc_timeout(mut self, timeout: Duration) -> Self {
182 self.rpc_timeout = Some(timeout);
183 self
184 }
185
186 pub fn max_entries_per_rpc(mut self, max: usize) -> Self {
187 self.max_entries_per_rpc = Some(max);
188 self
189 }
190
191 pub fn build(self) -> Result<RaftConfig, String> {
192 let config = RaftConfig {
193 election_timeout_min: self
194 .election_timeout_min
195 .unwrap_or_else(default_election_min),
196 election_timeout_max: self
197 .election_timeout_max
198 .unwrap_or_else(default_election_max),
199 heartbeat_interval: self.heartbeat_interval.unwrap_or_else(default_heartbeat),
200 rpc_timeout: self.rpc_timeout.unwrap_or_else(default_rpc_timeout),
201 max_entries_per_rpc: self.max_entries_per_rpc.unwrap_or_else(default_max_entries),
202 };
203 config.validate()?;
204 Ok(config)
205 }
206}
207
/// Selects which consensus algorithm the factory should instantiate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConsensusStrategy {
    /// Raft — the only implemented strategy.
    Raft,
    /// Paxos — reserved; currently returns `NotImplemented`.
    Paxos,
}
216
/// Errors surfaced by consensus engines and their supporting machinery.
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum ConsensusError {
    /// The requested strategy exists in the API but has no implementation.
    #[error("Consensus strategy {0:?} is not yet implemented")]
    NotImplemented(ConsensusStrategy),
    /// Persistence-layer failure, stringified at the boundary.
    #[error("Storage error: {0}")]
    StorageError(String),
    /// Transport-layer failure, stringified at the boundary.
    #[error("Network error: {0}")]
    NetworkError(String),
    /// Operation requires leadership and this node is not the leader.
    #[error("Not leader")]
    NotLeader,
    #[error("Term mismatch")]
    TermMismatch,
    #[error("Snapshot error: {0}")]
    SnapshotError(String),
    #[error("Compaction error: {0}")]
    CompactionError(String),
    #[error("Integrity check failed: {0}")]
    IntegrityError(String),
    /// A log index was requested beyond the available range.
    #[error("Index out of bounds: requested {requested}, available {available}")]
    IndexOutOfBounds { requested: u64, available: u64 },
    #[error("Configuration change error: {0}")]
    ConfigChangeError(String),
    /// A joint-consensus membership change is still in flight.
    #[error("Configuration change in progress")]
    ConfigChangeInProgress,
    /// gRPC transport/TLS failure (only with the `grpc` feature).
    #[cfg(feature = "grpc")]
    #[error("TLS error: {0}")]
    Tls(String),
    #[error("Shutdown requested")]
    Shutdown,
}
252
253impl From<NetworkError> for ConsensusError {
254 fn from(e: NetworkError) -> Self {
255 ConsensusError::NetworkError(e.to_string())
256 }
257}
258
259impl From<Box<dyn std::error::Error>> for ConsensusError {
260 fn from(e: Box<dyn std::error::Error>) -> Self {
261 ConsensusError::NetworkError(e.to_string())
262 }
263}
264
#[cfg(feature = "grpc")]
impl From<tonic::transport::Error> for ConsensusError {
    /// Surfaces gRPC transport failures as [`ConsensusError::Tls`].
    fn from(err: tonic::transport::Error) -> Self {
        Self::Tls(err.to_string())
    }
}
271
/// Abstraction over a consensus algorithm (currently only Raft).
///
/// Implementations drive the protocol via [`run`](ConsensusEngine::run) and
/// expose leadership and commit-state observability to callers.
#[async_trait]
pub trait ConsensusEngine<T>: Send {
    /// Runs the consensus loop; returns only on an unrecoverable error.
    async fn run(&mut self) -> Result<(), ConsensusError>;

    /// Appends a client value to the replicated log (leader only).
    async fn propose(&mut self, value: T) -> Result<LogIndex, ConsensusError>;

    /// Initiates a cluster membership change (leader only).
    async fn propose_conf_change(
        &mut self,
        change: crate::raft::ConfChange,
    ) -> Result<LogIndex, ConsensusError>;

    /// The node currently believed to be leader, if known.
    fn leader_id(&self) -> Option<NodeId>;

    /// Whether this node is currently acting as leader.
    fn is_leader(&self) -> bool;

    /// The persisted current term.
    async fn current_term(&self) -> Term;

    /// Highest log index known to be committed.
    fn commit_index(&self) -> LogIndex;
}
303
/// Volatile per-peer replication bookkeeping, held only while leader.
#[derive(Debug, Default)]
struct LeaderState {
    // Next log index to send to each peer (starts at leader's last index + 1).
    next_index: HashMap<NodeId, LogIndex>,
    // Highest log index known to be replicated on each peer (starts at 0).
    match_index: HashMap<NodeId, LogIndex>,
}
316
317impl LeaderState {
318 fn new(peers: &[NodeId], last_log_index: LogIndex) -> Self {
319 let mut next_index = HashMap::new();
320 let mut match_index = HashMap::new();
321
322 for &peer in peers {
323 next_index.insert(peer, last_log_index + 1);
324 match_index.insert(peer, 0);
325 }
326
327 Self {
328 next_index,
329 match_index,
330 }
331 }
332}
333
/// A Raft consensus engine generic over the replicated value `T`, the
/// network transport `N`, and the persistence backend `S`.
pub struct RaftEngine<T, N, S>
where
    T: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
    N: RaftNetwork<T>,
    S: RaftStorage<T>,
{
    // This node's unique identifier.
    id: NodeId,

    // Durable state: current term, vote, and the log.
    storage: S,

    // Volatile role state.
    role: RaftRole,
    commit_index: LogIndex,
    _last_applied: LogIndex, // reserved for a state-machine apply loop

    // Per-peer replication indices; Some only while leader.
    leader_state: Option<LeaderState>,

    // Last node observed acting as leader, if any.
    current_leader: Option<NodeId>,

    // Transport for Raft RPCs.
    network: N,

    // Timing and batching parameters.
    config: RaftConfig,

    // Active cluster membership (Single, or Joint during a change).
    cluster_config: ClusterConfig,

    // Ballot box used while campaigning as candidate.
    votes_received: HashMap<NodeId, bool>,

    _phantom: std::marker::PhantomData<T>,
}
373
374impl<T, N, S> RaftEngine<T, N, S>
375where
376 T: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
377 N: RaftNetwork<T>,
378 S: RaftStorage<T>,
379{
380 pub fn new(id: NodeId, network: N, storage: S, config: RaftConfig) -> Self {
385 tracing::info!(node_id = id, "Creating Raft engine");
386
387 let peers = network.peer_ids();
388 let mut nodes = vec![id];
389 nodes.extend(peers);
390
391 Self {
392 id,
393 storage,
394 role: RaftRole::Follower,
395 commit_index: 0,
396 _last_applied: 0,
397 leader_state: None,
398 current_leader: None,
399 network,
400 config,
401 votes_received: HashMap::new(),
402 cluster_config: ClusterConfig::Single(nodes),
403 _phantom: std::marker::PhantomData,
404 }
405 }
406
    /// Convenience accessor: ids of all peers known to the network layer.
    fn peer_ids(&self) -> Vec<NodeId> {
        self.network.peer_ids()
    }
411
412 fn has_majority(&self, votes: &HashMap<NodeId, bool>) -> bool {
414 let voters: Vec<NodeId> = votes
415 .iter()
416 .filter(|&(_, &granted)| granted)
417 .map(|(&id, _)| id)
418 .collect();
419
420 self.cluster_config.has_majority(&voters)
421 }
422
423 async fn become_follower(&mut self, term: Term) {
428 let old_role = self.role.clone();
429 self.role = RaftRole::Follower;
430 self.leader_state = None;
431 self.votes_received.clear();
432
433 let _ = self.storage.set_term(term).await;
434 let _ = self.storage.set_vote(None).await;
435
436 if old_role != RaftRole::Follower {
437 tracing::info!(
438 node_id = self.id,
439 from_role = %old_role,
440 new_term = term,
441 "Became follower"
442 );
443 }
444 }
445
446 async fn become_candidate(&mut self) {
447 let current_term = self.storage.get_term().await.unwrap_or(0);
448 let new_term = current_term + 1;
449
450 self.role = RaftRole::Candidate;
451 self.leader_state = None;
452 self.current_leader = None;
453 self.votes_received.clear();
454
455 let _ = self
457 .storage
458 .set_term_and_vote(new_term, Some(self.id))
459 .await;
460 self.votes_received.insert(self.id, true);
461
462 tracing::info!(
463 node_id = self.id,
464 term = new_term,
465 "Became candidate, starting election"
466 );
467 }
468
469 async fn become_leader(&mut self) {
470 let term = self.storage.get_term().await.unwrap_or(0);
471 let log_info = self.storage.get_last_log_info().await.unwrap_or_default();
472
473 self.role = RaftRole::Leader;
474 self.current_leader = Some(self.id);
475 self.leader_state = Some(LeaderState::new(&self.peer_ids(), log_info.last_index));
476 self.votes_received.clear();
477
478 tracing::info!(node_id = self.id, term = term, "Became leader");
479 }
480
    /// Main consensus loop: runs one iteration of the handler for the current
    /// role, forever. Role transitions happen inside the handlers; each
    /// handler returns so this loop can re-dispatch on the new role.
    pub async fn run_loop(&mut self) -> Result<(), ConsensusError> {
        tracing::info!(node_id = self.id, "Starting Raft consensus loop");

        loop {
            match &self.role {
                RaftRole::Follower => self.run_follower().await?,
                RaftRole::Candidate => self.run_candidate().await?,
                RaftRole::Leader => self.run_leader().await?,
            }
        }
    }
497
    /// One follower iteration: wait for either an incoming message or the
    /// randomized election timeout, whichever fires first.
    async fn run_follower(&mut self) -> Result<(), ConsensusError> {
        let timeout = self.config.random_election_timeout();

        tokio::select! {
            result = self.network.receive() => {
                match result {
                    Ok(msg) => self.handle_message(msg).await?,
                    Err(e) => {
                        // Transient receive failures are logged, not fatal.
                        tracing::warn!(error = %e, "Network receive error");
                    }
                }
            }

            // Heard nothing from a leader within the timeout: start an election.
            _ = tokio::time::sleep(timeout) => {
                tracing::debug!(
                    node_id = self.id,
                    timeout_ms = timeout.as_millis(),
                    "Election timeout, becoming candidate"
                );
                self.become_candidate().await;
            }
        }

        Ok(())
    }
526
    /// One candidate iteration: broadcast RequestVote to every peer, then
    /// process responses until we win, stand down, or the election times out.
    async fn run_candidate(&mut self) -> Result<(), ConsensusError> {
        let term = self.storage.get_term().await.unwrap_or(0);
        let log_info = self.storage.get_last_log_info().await.unwrap_or_default();

        // Fire-and-forget: failures to reach individual peers are ignored.
        for peer_id in self.peer_ids() {
            let _ = self
                .network
                .send_request_vote(
                    peer_id,
                    term,
                    self.id,
                    log_info.last_index,
                    log_info.last_term,
                )
                .await;
        }

        let timeout = self.config.random_election_timeout();
        let deadline = tokio::time::Instant::now() + timeout;

        while tokio::time::Instant::now() < deadline {
            let remaining = deadline - tokio::time::Instant::now();

            tokio::select! {
                result = self.network.receive() => {
                    match result {
                        Ok(msg) => {
                            self.handle_message(msg).await?;

                            // Won the election — switch to the leader loop.
                            if self.role == RaftRole::Leader {
                                return Ok(());
                            }

                            // Saw a higher term or a live leader — stand down.
                            if self.role == RaftRole::Follower {
                                return Ok(());
                            }
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "Network receive error during election");
                        }
                    }
                }

                // Split vote: restart with a fresh randomized timeout.
                _ = tokio::time::sleep(remaining) => {
                    tracing::debug!(node_id = self.id, "Election timeout, restarting");
                    self.become_candidate().await;
                    return Ok(());
                }
            }
        }

        Ok(())
    }
586
    /// One leader iteration: replicate/heartbeat to all peers, handle any
    /// inbound message until the next heartbeat tick, then try to advance the
    /// commit index based on follower acknowledgements.
    async fn run_leader(&mut self) -> Result<(), ConsensusError> {
        self.send_append_entries_to_all().await?;

        tokio::select! {
            result = self.network.receive() => {
                match result {
                    Ok(msg) => self.handle_message(msg).await?,
                    Err(e) => {
                        tracing::warn!(error = %e, "Network receive error");
                    }
                }
            }

            // Heartbeat tick: fall through so the next iteration re-sends
            // AppendEntries to every peer.
            _ = tokio::time::sleep(self.config.heartbeat_interval) => {
            }
        }

        self.advance_commit_index().await?;

        Ok(())
    }
613
    /// Dispatches an incoming Raft RPC or RPC response to its handler.
    async fn handle_message(&mut self, msg: RaftMessage<T>) -> Result<(), ConsensusError> {
        match msg {
            RaftMessage::RequestVote {
                term,
                candidate_id,
                last_log_index,
                last_log_term,
            } => {
                self.handle_request_vote(term, candidate_id, last_log_index, last_log_term)
                    .await?;
            }

            RaftMessage::RequestVoteResponse {
                term,
                vote_granted,
                from_id,
            } => {
                self.handle_request_vote_response(term, vote_granted, from_id)
                    .await?;
            }

            RaftMessage::AppendEntries {
                term,
                leader_id,
                prev_log_index,
                prev_log_term,
                entries,
                leader_commit,
            } => {
                self.handle_append_entries(
                    term,
                    leader_id,
                    prev_log_index,
                    prev_log_term,
                    entries,
                    leader_commit,
                )
                .await?;
            }

            RaftMessage::AppendEntriesResponse {
                term,
                success,
                match_index,
                from_id,
            } => {
                self.handle_append_entries_response(term, success, match_index, from_id)
                    .await?;
            }

            RaftMessage::InstallSnapshot {
                term,
                leader_id,
                snapshot,
            } => {
                self.handle_install_snapshot(term, leader_id, snapshot)
                    .await?;
            }

            RaftMessage::InstallSnapshotResponse {
                term,
                success: _,
                from_id: _,
            } => {
                // Only the term matters here: a newer term forces step-down.
                if term > self.storage.get_term().await.unwrap_or(0) {
                    self.become_follower(term).await;
                }
            }
        }

        Ok(())
    }
691
    /// Handles a RequestVote RPC: steps down on a newer term, then grants the
    /// vote iff we have not voted for a different candidate this term and the
    /// candidate's log is at least as up-to-date as ours (Raft §5.2, §5.4.1).
    async fn handle_request_vote(
        &mut self,
        term: Term,
        candidate_id: NodeId,
        last_log_index: LogIndex,
        last_log_term: Term,
    ) -> Result<(), ConsensusError> {
        let current_term = self.storage.get_term().await.unwrap_or(0);

        // A higher term always forces us back to follower first.
        if term > current_term {
            self.become_follower(term).await;
        }

        // Re-read persisted state: become_follower may have changed it.
        let current_term = self.storage.get_term().await.unwrap_or(0);
        let voted_for = self.storage.get_vote().await.unwrap_or(None);
        let our_log_info = self.storage.get_last_log_info().await.unwrap_or_default();

        let vote_granted = term >= current_term
            && (voted_for.is_none() || voted_for == Some(candidate_id))
            && self.is_log_up_to_date(last_log_index, last_log_term, &our_log_info);

        if vote_granted {
            // Persist the vote before responding so a restart cannot double-vote.
            let _ = self.storage.set_vote(Some(candidate_id)).await;
            tracing::debug!(
                node_id = self.id,
                candidate = candidate_id,
                term = term,
                "Granted vote"
            );
        }

        let response = RaftMessage::RequestVoteResponse {
            term: current_term,
            vote_granted,
            from_id: self.id,
        };

        self.network.respond(candidate_id, response).await?;

        Ok(())
    }
743
    /// Tallies a RequestVote response while campaigning; becomes leader once a
    /// majority (per the active cluster configuration) has granted.
    async fn handle_request_vote_response(
        &mut self,
        term: Term,
        vote_granted: bool,
        from_id: NodeId,
    ) -> Result<(), ConsensusError> {
        let current_term = self.storage.get_term().await.unwrap_or(0);

        // A responder with a higher term invalidates our candidacy.
        if term > current_term {
            self.become_follower(term).await;
            return Ok(());
        }

        // Stale response, or we already changed role: ignore.
        if term < current_term || self.role != RaftRole::Candidate {
            return Ok(());
        }

        self.votes_received.insert(from_id, vote_granted);

        if vote_granted {
            let votes = self.votes_received.values().filter(|&&v| v).count();

            tracing::debug!(
                node_id = self.id,
                from = from_id,
                votes = votes,
                "Received vote"
            );

            if self.has_majority(&self.votes_received) {
                self.become_leader().await;
            }
        }

        Ok(())
    }
784
785 fn is_log_up_to_date(
787 &self,
788 last_log_index: LogIndex,
789 last_log_term: Term,
790 our_log: &LogInfo,
791 ) -> bool {
792 if last_log_term > our_log.last_term {
796 return true;
797 }
798 if last_log_term == our_log.last_term && last_log_index >= our_log.last_index {
799 return true;
800 }
801 false
802 }
803
804 async fn send_append_entries_to_all(&mut self) -> Result<(), ConsensusError> {
809 let term = self.storage.get_term().await.unwrap_or(0);
810 let peers = self.peer_ids();
811
812 for peer_id in peers {
813 if let Err(e) = self.send_append_entries_to_peer(peer_id, term).await {
814 tracing::warn!(peer = peer_id, error = %e, "Failed to send AppendEntries");
815 }
816 }
817
818 Ok(())
819 }
820
    /// Builds and sends one AppendEntries RPC to `peer_id`, starting at the
    /// peer's `next_index` and batching at most `max_entries_per_rpc` entries.
    async fn send_append_entries_to_peer(
        &mut self,
        peer_id: NodeId,
        term: Term,
    ) -> Result<(), ConsensusError> {
        let leader_state = self
            .leader_state
            .as_ref()
            .ok_or(ConsensusError::NotLeader)?;

        let next_idx = *leader_state.next_index.get(&peer_id).unwrap_or(&1);
        let prev_log_index = next_idx.saturating_sub(1);

        // Term of the entry just before the batch; index 0 is the sentinel
        // "empty log" position with term 0.
        let prev_log_term = if prev_log_index == 0 {
            0
        } else {
            self.storage
                .get_log_entry(prev_log_index)
                .await?
                .map(|e| e.term)
                .unwrap_or(0)
        };

        // Batch [next_idx, end_idx): capped by config and by our log length.
        let last_log_info = self.storage.get_last_log_info().await?;
        let end_idx =
            (next_idx + self.config.max_entries_per_rpc as u64).min(last_log_info.last_index + 1);
        let entries_iter = self.storage.get_log_range(next_idx, end_idx).await?;
        let entries: Vec<LogEntry<T>> = entries_iter.collect::<Result<_, _>>()?;

        self.network
            .send_append_entries(
                peer_id,
                term,
                self.id,
                prev_log_index,
                prev_log_term,
                entries,
                self.commit_index,
            )
            .await?;

        Ok(())
    }
866
    /// Follower-side handler for AppendEntries (heartbeat + replication).
    ///
    /// Rejects stale terms, steps down on newer ones, performs the
    /// prev-entry consistency check, truncates conflicting log suffixes,
    /// appends genuinely new entries (adopting any Config entries
    /// immediately), and advances the local commit index toward
    /// `leader_commit`.
    async fn handle_append_entries(
        &mut self,
        term: Term,
        leader_id: NodeId,
        prev_log_index: LogIndex,
        prev_log_term: Term,
        entries: Vec<LogEntry<T>>,
        leader_commit: LogIndex,
    ) -> Result<(), ConsensusError> {
        let current_term = self.storage.get_term().await.unwrap_or(0);

        // Stale leader: refuse and report our newer term so it steps down.
        if term < current_term {
            let response = RaftMessage::AppendEntriesResponse {
                term: current_term,
                success: false,
                match_index: 0,
                from_id: self.id,
            };
            self.network.respond(leader_id, response).await?;
            return Ok(());
        }

        if term > current_term {
            self.become_follower(term).await;
        }

        // A valid AppendEntries identifies the current leader.
        self.current_leader = Some(leader_id);

        // A candidate that hears from a legitimate leader abandons its election.
        if self.role == RaftRole::Candidate {
            self.become_follower(term).await;
        }

        let our_log_info = self.storage.get_last_log_info().await?;

        // Consistency check: our log must hold prev_log_index with prev_log_term.
        let success = if prev_log_index == 0 {
            true
        } else if prev_log_index > our_log_info.last_index {
            false
        } else {
            self.storage
                .get_log_entry(prev_log_index)
                .await?
                .map(|e| e.term == prev_log_term)
                .unwrap_or(false)
        };

        let match_index = if success {
            if !entries.is_empty() {
                // Find the first incoming entry whose term conflicts with what
                // we already store at that index, and truncate from there.
                for entry in &entries {
                    let conflict = match self.storage.get_log_entry(entry.index).await? {
                        Some(existing) => existing.term != entry.term,
                        None => false,
                    };

                    if conflict {
                        self.storage.truncate_log(entry.index).await?;
                        break;
                    }
                }

                // Only append entries we don't already hold.
                let log_info = self.storage.get_last_log_info().await?;
                let new_entries: Vec<_> = entries
                    .into_iter()
                    .filter(|e| e.index > log_info.last_index)
                    .collect();

                if !new_entries.is_empty() {
                    self.storage.append_entries(&new_entries).await?;

                    // Configuration entries take effect when appended, not
                    // when committed, per the Raft membership-change protocol.
                    for entry in &new_entries {
                        if let LogCommand::Config(config) = &entry.command {
                            tracing::info!(
                                node_id = self.id,
                                index = entry.index,
                                "Node transitioning to new configuration from log"
                            );
                            self.cluster_config = config.clone();

                            // NOTE(review): peers are registered with empty
                            // addresses here — presumably the network layer
                            // resolves them elsewhere; confirm.
                            let _ = self
                                .network
                                .update_peers(
                                    config
                                        .all_nodes()
                                        .into_iter()
                                        .map(|id| crate::network::PeerInfo {
                                            id,
                                            address: "".to_string(),
                                        })
                                        .collect(),
                                )
                                .await;
                        }
                    }
                }
            }

            // Advance commit index, bounded by what we actually hold locally.
            let our_log_info = self.storage.get_last_log_info().await?;
            if leader_commit > self.commit_index {
                self.commit_index = leader_commit.min(our_log_info.last_index);
                self.storage.set_commit_index(self.commit_index).await?;
            }

            self.storage.get_last_log_info().await?.last_index
        } else {
            0
        };

        let response = RaftMessage::AppendEntriesResponse {
            term: self.storage.get_term().await.unwrap_or(0),
            success,
            match_index,
            from_id: self.id,
        };

        self.network.respond(leader_id, response).await?;

        Ok(())
    }
1003
    /// Leader-side handling of an AppendEntries acknowledgement: on success,
    /// advance the peer's match/next indices; on failure, step `next_index`
    /// back by one so the next RPC retries from an earlier point.
    async fn handle_append_entries_response(
        &mut self,
        term: Term,
        success: bool,
        match_index: LogIndex,
        from_id: NodeId,
    ) -> Result<(), ConsensusError> {
        let current_term = self.storage.get_term().await.unwrap_or(0);

        // A follower with a higher term deposes us.
        if term > current_term {
            self.become_follower(term).await;
            return Ok(());
        }

        // Ignore responses from old terms or when no longer leader.
        if self.role != RaftRole::Leader || term != current_term {
            return Ok(());
        }

        let leader_state = self
            .leader_state
            .as_mut()
            .ok_or(ConsensusError::NotLeader)?;

        if success {
            // Guard against out-of-order acks: indices only move forward.
            if match_index > *leader_state.match_index.get(&from_id).unwrap_or(&0) {
                leader_state.match_index.insert(from_id, match_index);
                leader_state.next_index.insert(from_id, match_index + 1);
            }
        } else {
            // Log mismatch: back off by one index and retry next round.
            let next_idx = leader_state.next_index.get(&from_id).copied().unwrap_or(1);
            if next_idx > 1 {
                leader_state.next_index.insert(from_id, next_idx - 1);
            }
        }

        Ok(())
    }
1045
    /// Installs a snapshot pushed by the leader: rejects stale terms, steps
    /// down on newer ones, hands the snapshot to storage, and fast-forwards
    /// the commit index to the snapshot's last included index.
    async fn handle_install_snapshot(
        &mut self,
        term: Term,
        leader_id: NodeId,
        snapshot: crate::raft::Snapshot<T>,
    ) -> Result<(), ConsensusError> {
        let current_term = self.storage.get_term().await.unwrap_or(0);

        if term < current_term {
            let response = RaftMessage::InstallSnapshotResponse {
                term: current_term,
                success: false,
                from_id: self.id,
            };
            self.network.respond(leader_id, response).await?;
            return Ok(());
        }

        if term > current_term {
            self.become_follower(term).await;
        }

        self.current_leader = Some(leader_id);

        self.storage.install_snapshot(snapshot.clone()).await?;

        // Everything up to the snapshot boundary is committed by definition.
        if snapshot.last_included_index > self.commit_index {
            self.commit_index = snapshot.last_included_index;
            self.storage.set_commit_index(self.commit_index).await?;
        }

        let response = RaftMessage::InstallSnapshotResponse {
            term: self.storage.get_term().await.unwrap_or(0),
            success: true,
            from_id: self.id,
        };

        self.network.respond(leader_id, response).await?;

        Ok(())
    }
1093
    /// Leader-only: advances `commit_index` to the highest log index that is
    /// replicated on a majority AND belongs to the current term (Raft §5.4.2:
    /// a leader never commits prior-term entries by counting replicas).
    ///
    /// If a newly committed entry is a Joint configuration, the leader then
    /// appends the final Single configuration to complete the change.
    #[tracing::instrument(skip(self), level = "debug")]
    async fn advance_commit_index(&mut self) -> Result<(), ConsensusError> {
        if self.role != RaftRole::Leader {
            return Ok(());
        }

        let leader_state = self
            .leader_state
            .as_ref()
            .ok_or(ConsensusError::NotLeader)?;

        let current_term = self.storage.get_term().await.unwrap_or(0);
        let log_info = self.storage.get_last_log_info().await?;

        // Deferred so we don't call &mut methods while leader_state is borrowed.
        let mut joint_entry_to_finalize = None;

        for (i, n) in ((self.commit_index + 1)..=log_info.last_index).enumerate() {
            // Cooperatively yield during long scans to avoid starving the runtime.
            if i % 100 == 0 {
                tokio::task::yield_now().await;
            }
            let entry = self.storage.get_log_entry(n).await?;
            // Only entries from our own term may be committed by counting.
            if entry.as_ref().map(|e| e.term) != Some(current_term) {
                continue;
            }

            // The leader itself always counts as holding the entry.
            let mut replicated = vec![self.id];
            for (&node_id, &match_idx) in &leader_state.match_index {
                if match_idx >= n {
                    replicated.push(node_id);
                }
            }

            if self.cluster_config.has_majority(&replicated) {
                self.commit_index = n;
                self.storage.set_commit_index(n).await?;

                tracing::debug!(
                    node_id = self.id,
                    commit_index = n,
                    is_joint = self.cluster_config.is_joint(),
                    "Advanced commit index"
                );

                if let Some(LogEntry {
                    command: LogCommand::Config(config @ ClusterConfig::Joint { .. }),
                    ..
                }) = entry
                {
                    joint_entry_to_finalize = Some(config.clone());
                }
            }
        }

        if let Some(config) = joint_entry_to_finalize {
            self.finalize_conf_change(config).await?;
        }

        Ok(())
    }
1167
    /// Second phase of a joint-consensus membership change: once the Joint
    /// entry commits, the leader appends the final `Single(new)` configuration
    /// and adopts it immediately.
    async fn finalize_conf_change(
        &mut self,
        joint_config: ClusterConfig,
    ) -> Result<(), ConsensusError> {
        let new_nodes = match joint_config {
            ClusterConfig::Joint { old: _, new } => new,
            _ => {
                return Err(ConsensusError::ConfigChangeError(
                    "Not in Joint state".into(),
                ));
            }
        };

        let final_config = ClusterConfig::Single(new_nodes);
        let term = self.storage.get_term().await.unwrap_or(0);
        let log_info = self.storage.get_last_log_info().await?;
        let next_idx = log_info.last_index + 1;

        let entry = LogEntry::config(next_idx, term, final_config.clone())?;
        self.storage.append_entries(&[entry]).await?;

        // Configurations take effect when appended, not when committed.
        self.cluster_config = final_config.clone();

        tracing::info!(
            node_id = self.id,
            index = next_idx,
            "Leader finalized configuration to Single state"
        );

        Ok(())
    }
1203}
1204
1205#[async_trait]
1206impl<T, N, S> ConsensusEngine<T> for RaftEngine<T, N, S>
1207where
1208 T: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
1209 N: RaftNetwork<T> + Send + Sync,
1210 S: RaftStorage<T> + Send,
1211{
    /// Drives the role state machine until an unrecoverable error occurs.
    async fn run(&mut self) -> Result<(), ConsensusError> {
        self.run_loop().await
    }
1215
1216 async fn propose(&mut self, value: T) -> Result<LogIndex, ConsensusError> {
1217 if self.role != RaftRole::Leader {
1218 return Err(ConsensusError::NotLeader);
1219 }
1220
1221 let term = self.storage.get_term().await.unwrap_or(0);
1222 let log_info = self.storage.get_last_log_info().await?;
1223 let new_index = log_info.last_index + 1;
1224
1225 let entry = LogEntry::new(new_index, term, value)?;
1226
1227 self.storage.append_entries(&[entry]).await?;
1228
1229 tracing::debug!(
1230 node_id = self.id,
1231 index = new_index,
1232 term = term,
1233 "Appended new entry"
1234 );
1235
1236 Ok(new_index)
1237 }
1238
    /// Begins a joint-consensus membership change (leader only).
    ///
    /// Rejects the request while a previous change is still in its Joint
    /// phase, then appends a `Joint { old, new }` configuration entry and
    /// adopts it immediately; `advance_commit_index` finalizes it to
    /// `Single(new)` once the Joint entry commits.
    async fn propose_conf_change(
        &mut self,
        change: crate::raft::ConfChange,
    ) -> Result<LogIndex, ConsensusError> {
        if self.role != RaftRole::Leader {
            return Err(ConsensusError::NotLeader);
        }

        // Only one membership change may be in flight at a time.
        if self.cluster_config.is_joint() {
            return Err(ConsensusError::ConfigChangeInProgress);
        }

        let old_nodes = self.cluster_config.all_nodes();
        let mut new_nodes = old_nodes.clone();
        match change {
            crate::raft::ConfChange::AddNode(id) => {
                if !new_nodes.contains(&id) {
                    new_nodes.push(id);
                }
            }
            crate::raft::ConfChange::RemoveNode(id) => {
                new_nodes.retain(|&x| x != id);
            }
        }

        let joint_config = ClusterConfig::Joint {
            old: old_nodes,
            new: new_nodes,
        };

        let term = self.storage.get_term().await.unwrap_or(0);
        let log_info = self.storage.get_last_log_info().await?;
        let next_idx = log_info.last_index + 1;

        let entry = LogEntry::config(next_idx, term, joint_config.clone())?;
        self.storage.append_entries(&[entry]).await?;

        // The Joint configuration takes effect immediately upon append.
        self.cluster_config = joint_config;

        tracing::info!(
            node_id = self.id,
            index = next_idx,
            "Leader proposed Joint Consensus configuration change"
        );

        Ok(next_idx)
    }
1287
    /// The node this engine currently believes is leader, if any.
    fn leader_id(&self) -> Option<NodeId> {
        self.current_leader
    }

    /// True when this node is acting as leader.
    fn is_leader(&self) -> bool {
        self.role == RaftRole::Leader
    }

    /// The persisted current term (0 if storage has none).
    async fn current_term(&self) -> Term {
        self.storage.get_term().await.unwrap_or(0)
    }

    /// Highest log index known to be committed.
    fn commit_index(&self) -> LogIndex {
        self.commit_index
    }
1303}
1304
/// Factory that constructs a boxed [`ConsensusEngine`] for a given strategy.
pub struct ConsensusFactory;

impl ConsensusFactory {
    /// Creates a consensus engine for `strategy`.
    ///
    /// Only [`ConsensusStrategy::Raft`] is supported (via the legacy adapter);
    /// `Paxos` yields [`ConsensusError::NotImplemented`].
    pub fn create_engine<T: Clone + Send + Sync + Serialize + DeserializeOwned + 'static>(
        strategy: ConsensusStrategy,
        id: NodeId,
        network: Box<dyn ConsensusNetwork>,
        storage: Box<dyn RaftStorage<T>>,
    ) -> Result<Box<dyn ConsensusEngine<T>>, ConsensusError> {
        tracing::info!(
            strategy = ?strategy,
            node_id = id,
            "Creating consensus engine"
        );

        match strategy {
            ConsensusStrategy::Raft => {
                Ok(Box::new(LegacyRaftEngineAdapter::new(id, network, storage)))
            }
            ConsensusStrategy::Paxos => {
                Err(ConsensusError::NotImplemented(ConsensusStrategy::Paxos))
            }
        }
    }
}
1336
/// Adapter wrapping the legacy `RaftNode` behind the [`ConsensusEngine`] trait.
struct LegacyRaftEngineAdapter<T: Send + Sync> {
    node: crate::raft::RaftNode<T>,
}
1341
1342impl<T: Clone + Send + Sync + Serialize + DeserializeOwned + 'static> LegacyRaftEngineAdapter<T> {
1343 fn new(
1344 id: NodeId,
1345 network: Box<dyn ConsensusNetwork>,
1346 storage: Box<dyn RaftStorage<T>>,
1347 ) -> Self {
1348 Self {
1349 node: crate::raft::RaftNode::new(id, network, storage, RaftConfig::default()),
1350 }
1351 }
1352}
1353
1354#[async_trait]
1355impl<T: Clone + Send + Sync + Serialize + DeserializeOwned + 'static> ConsensusEngine<T>
1356 for LegacyRaftEngineAdapter<T>
1357{
1358 async fn run(&mut self) -> Result<(), ConsensusError> {
1359 tracing::info!(
1360 node_id = self.node.id,
1361 "Starting legacy Raft consensus loop"
1362 );
1363
1364 loop {
1365 tokio::task::yield_now().await;
1366 }
1367 }
1368
1369 async fn propose(&mut self, _value: T) -> Result<LogIndex, ConsensusError> {
1370 Err(ConsensusError::NotLeader)
1372 }
1373
1374 async fn propose_conf_change(
1375 &mut self,
1376 _change: crate::raft::ConfChange,
1377 ) -> Result<LogIndex, ConsensusError> {
1378 Err(ConsensusError::NotLeader)
1379 }
1380
1381 fn leader_id(&self) -> Option<NodeId> {
1382 if self.node.role == RaftRole::Leader {
1383 Some(self.node.id)
1384 } else {
1385 None
1386 }
1387 }
1388
1389 fn is_leader(&self) -> bool {
1390 self.node.role == RaftRole::Leader
1391 }
1392
1393 async fn current_term(&self) -> Term {
1394 self.node.storage.get_term().await.unwrap_or(0)
1395 }
1396
1397 fn commit_index(&self) -> LogIndex {
1398 self.node.commit_index
1399 }
1400}
1401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::raft::InMemoryStorage;

    /// Network stub whose `receive` never resolves; lets tests construct
    /// engines without any real I/O.
    struct MockNetwork;

    #[async_trait]
    impl ConsensusNetwork for MockNetwork {
        async fn broadcast_vote_request(
            &self,
            _term: Term,
            _candidate_id: NodeId,
        ) -> Result<(), String> {
            Ok(())
        }

        async fn send_heartbeat(&self, _leader_id: NodeId, _term: Term) -> Result<(), String> {
            Ok(())
        }

        async fn receive(&self) -> Result<crate::network::Packet, String> {
            // Pend forever: these tests never expect inbound traffic.
            futures::future::pending().await
        }

        async fn update_peers(&self, _peers: Vec<String>) -> Result<(), String> {
            Ok(())
        }
    }

    // The Raft strategy must construct successfully.
    #[test]
    fn test_create_raft_engine() {
        let storage: Box<dyn RaftStorage<String>> = Box::new(InMemoryStorage::new());
        let network: Box<dyn ConsensusNetwork> = Box::new(MockNetwork);

        let result = ConsensusFactory::create_engine(ConsensusStrategy::Raft, 1, network, storage);

        assert!(result.is_ok());
    }

    // Paxos is declared but unimplemented; the factory must say so.
    #[test]
    fn test_paxos_not_implemented() {
        let storage: Box<dyn RaftStorage<String>> = Box::new(InMemoryStorage::new());
        let network: Box<dyn ConsensusNetwork> = Box::new(MockNetwork);

        let result = ConsensusFactory::create_engine(ConsensusStrategy::Paxos, 1, network, storage);

        assert!(matches!(result, Err(ConsensusError::NotImplemented(_))));
    }

    // Defaults must satisfy the heartbeat < election-timeout invariant.
    #[test]
    fn test_raft_config_default() {
        let config = RaftConfig::default();
        assert!(config.heartbeat_interval < config.election_timeout_min);
    }

    // Each sampled timeout must land inside the configured [min, max] window.
    #[test]
    fn test_raft_config_random_timeout() {
        let config = RaftConfig::default();
        let t1 = config.random_election_timeout();
        let t2 = config.random_election_timeout();

        assert!(t1 >= config.election_timeout_min);
        assert!(t1 <= config.election_timeout_max);
        assert!(t2 >= config.election_timeout_min);
        assert!(t2 <= config.election_timeout_max);
    }
}