1use crate::traits::BlockStore;
37use ipfrs_core::{Block, Cid, Result};
38use parking_lot::RwLock;
39use serde::{Deserialize, Serialize};
40use std::collections::HashMap;
41use std::sync::Arc;
42use std::time::{Duration, Instant};
43use tokio::sync::{mpsc, oneshot};
44use tokio::time;
45use tracing::{debug, info};
46
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
49pub struct NodeId(pub u64);
50
51impl std::fmt::Display for NodeId {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 write!(f, "Node({})", self.0)
54 }
55}
56
57#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
59pub struct Term(pub u64);
60
61impl Term {
62 pub fn increment(&mut self) {
63 self.0 += 1;
64 }
65}
66
67#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
69pub struct LogIndex(pub u64);
70
71impl LogIndex {
72 pub fn increment(&mut self) {
73 self.0 += 1;
74 }
75}
76
/// Role a RAFT node currently plays. Nodes are created as `Follower`.
// Variant order is significant for index-based serde formats; do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeState {
    /// Passive: accepts entries from a leader and votes in elections.
    Follower,
    /// Has started an election for a new term (set by the election timer).
    Candidate,
    /// Accepts new commands via `RaftNode::append_entry`.
    Leader,
}
87
/// One entry of the replicated log.
// Field order defines the serde wire layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogEntry {
    /// Term in which the leader created this entry.
    pub term: Term,
    /// 1-based position of this entry in the log.
    pub index: LogIndex,
    /// State-machine command applied to the block store once committed.
    pub command: Command,
}
98
/// Block-store operation carried by a log entry.
// Variant order is significant for index-based serde formats; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Command {
    /// Store `data` under the CID encoded in `cid_bytes`.
    /// Entries whose bytes do not parse as a `Cid` are not written when applied.
    Put { cid_bytes: Vec<u8>, data: Vec<u8> },
    /// Remove the block whose CID is encoded in `cid_bytes`.
    Delete { cid_bytes: Vec<u8> },
    /// No state change; only advances the applied index.
    NoOp,
}
109
/// AppendEntries RPC sent by a leader to replicate entries (empty = heartbeat).
// Field order defines the serde wire layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesRequest {
    /// Leader's current term.
    pub term: Term,
    /// Sender; the follower records it as the current leader.
    pub leader_id: NodeId,
    /// Index of the entry immediately preceding `entries` (0 = none).
    pub prev_log_index: LogIndex,
    /// Term of the entry at `prev_log_index`, used for the consistency check.
    pub prev_log_term: Term,
    /// Entries to append after `prev_log_index`, in order.
    pub entries: Vec<LogEntry>,
    /// Leader's commit index.
    pub leader_commit: LogIndex,
}
126
/// Reply to an [`AppendEntriesRequest`].
// Field order defines the serde wire layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesResponse {
    /// Responder's current term, so a stale leader can step down.
    pub term: Term,
    /// Whether the entries were accepted.
    pub success: bool,
    /// On a log-consistency failure, a hint for where the leader should back
    /// up to; `None` on success or when rejected for a stale term.
    pub conflict_index: Option<LogIndex>,
}
137
/// RequestVote RPC sent by a candidate.
// Field order defines the serde wire layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestVoteRequest {
    /// Candidate's term.
    pub term: Term,
    /// Candidate requesting the vote.
    pub candidate_id: NodeId,
    /// Index of the candidate's last log entry (0 = empty log).
    pub last_log_index: LogIndex,
    /// Term of the candidate's last log entry.
    pub last_log_term: Term,
}
150
/// Reply to a [`RequestVoteRequest`].
// Field order defines the serde wire layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestVoteResponse {
    /// Responder's current term.
    pub term: Term,
    /// Whether the vote was granted to the candidate.
    pub vote_granted: bool,
}
159
/// Timing and batching parameters for a [`RaftNode`].
#[derive(Debug, Clone)]
pub struct RaftConfig {
    /// Leader heartbeat period (NOTE: not read by `RaftNode` in this file).
    pub heartbeat_interval: Duration,
    /// Lower bound of the randomized election timeout.
    pub election_timeout_min: Duration,
    /// Upper bound (inclusive) of the randomized election timeout.
    pub election_timeout_max: Duration,
    /// Cap on entries per AppendEntries (NOTE: not read by `RaftNode` in this file).
    pub max_entries_per_append: usize,
}

impl Default for RaftConfig {
    /// 50 ms heartbeats, a 150–300 ms election window and 100 entries per append.
    fn default() -> Self {
        let millis = Duration::from_millis;
        Self {
            heartbeat_interval: millis(50),
            election_timeout_min: millis(150),
            election_timeout_max: millis(300),
            max_entries_per_append: 100,
        }
    }
}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
185struct PersistentState {
186 current_term: Term,
188 voted_for: Option<NodeId>,
190}
191
192impl Default for PersistentState {
193 fn default() -> Self {
194 Self {
195 current_term: Term(0),
196 voted_for: None,
197 }
198 }
199}
200
/// In-memory progress counters; both start at index 0 via `Default`.
#[derive(Debug, Default)]
struct VolatileState {
    /// Highest log index known to be committed.
    commit_index: LogIndex,
    /// Highest log index whose command has been applied to the block store.
    last_applied: LogIndex,
}
209
/// Per-follower replication bookkeeping a leader keeps.
#[derive(Debug)]
// Never constructed in this file yet; kept for leader-side replication.
#[allow(dead_code)]
struct LeaderState {
    /// Next log index to send to each follower.
    next_index: HashMap<NodeId, LogIndex>,
    /// Highest log index known to be replicated on each follower.
    match_index: HashMap<NodeId, LogIndex>,
}
219
/// A single RAFT participant that replicates block-store commands and
/// applies committed ones to `store`.
///
/// Mutable state lives behind `Arc<RwLock<_>>` (parking_lot) so the election
/// timer task spawned by `start` can share it with the RPC handlers.
pub struct RaftNode<S: BlockStore> {
    /// This node's identifier.
    id: NodeId,
    /// Peer node identifiers passed to `new`.
    peers: Vec<NodeId>,
    /// Current role.
    state: Arc<RwLock<NodeState>>,
    /// Current term and vote for that term.
    persistent: Arc<RwLock<PersistentState>>,
    /// Commit / apply progress.
    volatile: Arc<RwLock<VolatileState>>,
    // Initialised to `None` and not otherwise used in this file.
    #[allow(dead_code)]
    /// Replication bookkeeping, meaningful only while leader.
    leader_state: Arc<RwLock<Option<LeaderState>>>,
    /// Replicated log; vector position `i` holds log index `i + 1`.
    log: Arc<RwLock<Vec<LogEntry>>>,
    /// Block store that committed commands are applied to.
    store: Arc<S>,
    /// Timing / batching parameters.
    config: RaftConfig,
    /// When the last accepted AppendEntries arrived; read by the election timer.
    last_heartbeat: Arc<RwLock<Instant>>,
    /// Leader named by the most recently accepted AppendEntries.
    current_leader: Arc<RwLock<Option<NodeId>>>,
    /// Sending half of the internal RPC channel (no code in this file sends on it).
    rpc_tx: mpsc::UnboundedSender<RpcMessage>,
    /// Receiving half, taken by the first call to `start`; `None` afterwards.
    rpc_rx: Arc<RwLock<Option<mpsc::UnboundedReceiver<RpcMessage>>>>,
}
249
/// An inbound RPC together with the one-shot channel its reply is sent on.
#[derive(Debug)]
// Variants are only matched in this file; nothing here constructs them.
#[allow(dead_code)]
enum RpcMessage {
    /// AppendEntries request awaiting an [`AppendEntriesResponse`].
    AppendEntries {
        request: AppendEntriesRequest,
        response_tx: oneshot::Sender<AppendEntriesResponse>,
    },
    /// RequestVote request awaiting a [`RequestVoteResponse`].
    RequestVote {
        request: RequestVoteRequest,
        response_tx: oneshot::Sender<RequestVoteResponse>,
    },
}
263
264impl<S: BlockStore + Send + Sync + 'static> RaftNode<S> {
265 pub fn new(id: NodeId, peers: Vec<NodeId>, store: S, config: RaftConfig) -> Result<Self> {
267 let (rpc_tx, rpc_rx) = mpsc::unbounded_channel();
268
269 Ok(Self {
270 id,
271 peers,
272 state: Arc::new(RwLock::new(NodeState::Follower)),
273 persistent: Arc::new(RwLock::new(PersistentState::default())),
274 volatile: Arc::new(RwLock::new(VolatileState::default())),
275 leader_state: Arc::new(RwLock::new(None)),
276 log: Arc::new(RwLock::new(Vec::new())),
277 store: Arc::new(store),
278 config,
279 last_heartbeat: Arc::new(RwLock::new(Instant::now())),
280 current_leader: Arc::new(RwLock::new(None)),
281 rpc_tx,
282 rpc_rx: Arc::new(RwLock::new(Some(rpc_rx))),
283 })
284 }
285
286 pub async fn start(&mut self) -> Result<()> {
288 info!("Starting RAFT node {}", self.id);
289
290 let mut rpc_rx = self
292 .rpc_rx
293 .write()
294 .take()
295 .ok_or_else(|| ipfrs_core::Error::Internal("Node already started".to_string()))?;
296
297 let _election_handle = self.spawn_election_timer();
299
300 loop {
302 tokio::select! {
303 Some(msg) = rpc_rx.recv() => {
305 self.handle_rpc(msg).await?;
306 }
307 _ = time::sleep(Duration::from_millis(10)) => {
309 self.apply_committed_entries().await?;
310 }
311 }
312 }
313 }
314
315 fn spawn_election_timer(&self) -> tokio::task::JoinHandle<()> {
317 let id = self.id;
318 let state = Arc::clone(&self.state);
319 let persistent = Arc::clone(&self.persistent);
320 let last_heartbeat = Arc::clone(&self.last_heartbeat);
321 let config = self.config.clone();
322 let _peers = self.peers.clone();
323 let _log = Arc::clone(&self.log);
324 let _rpc_tx = self.rpc_tx.clone();
325
326 tokio::spawn(async move {
327 loop {
328 let timeout = Self::random_election_timeout(&config);
330 time::sleep(timeout).await;
331
332 let current_state = *state.read();
334 let elapsed = last_heartbeat.read().elapsed();
335
336 if current_state != NodeState::Leader && elapsed >= timeout {
337 info!("{}: Election timeout, starting election", id);
338 *state.write() = NodeState::Candidate;
340 persistent.write().current_term.increment();
341 persistent.write().voted_for = Some(id);
342 }
343 }
344 })
345 }
346
347 fn random_election_timeout(config: &RaftConfig) -> Duration {
349 use rand::Rng;
350 let min = config.election_timeout_min.as_millis() as u64;
351 let max = config.election_timeout_max.as_millis() as u64;
352 let timeout_ms = rand::rng().random_range(min..=max);
353 Duration::from_millis(timeout_ms)
354 }
355
356 async fn handle_rpc(&self, msg: RpcMessage) -> Result<()> {
358 match msg {
359 RpcMessage::AppendEntries {
360 request,
361 response_tx,
362 } => {
363 let response = self.handle_append_entries(request).await?;
364 let _ = response_tx.send(response);
365 }
366 RpcMessage::RequestVote {
367 request,
368 response_tx,
369 } => {
370 let response = self.handle_request_vote(request).await?;
371 let _ = response_tx.send(response);
372 }
373 }
374 Ok(())
375 }
376
377 #[allow(clippy::unused_async)]
379 async fn handle_append_entries(
380 &self,
381 request: AppendEntriesRequest,
382 ) -> Result<AppendEntriesResponse> {
383 let mut persistent = self.persistent.write();
384 let current_term = persistent.current_term;
385
386 if request.term < current_term {
388 return Ok(AppendEntriesResponse {
389 term: current_term,
390 success: false,
391 conflict_index: None,
392 });
393 }
394
395 if request.term > current_term {
397 persistent.current_term = request.term;
398 persistent.voted_for = None;
399 *self.state.write() = NodeState::Follower;
400 }
401
402 *self.last_heartbeat.write() = Instant::now();
404 *self.current_leader.write() = Some(request.leader_id);
405
406 let mut log = self.log.write();
407
408 if request.prev_log_index.0 > 0 {
410 if request.prev_log_index.0 > log.len() as u64 {
411 return Ok(AppendEntriesResponse {
412 term: persistent.current_term,
413 success: false,
414 conflict_index: Some(LogIndex(log.len() as u64)),
415 });
416 }
417
418 let prev_entry = &log[(request.prev_log_index.0 - 1) as usize];
419 if prev_entry.term != request.prev_log_term {
420 let conflict_term = prev_entry.term;
422 let mut conflict_index = request.prev_log_index.0;
423 for entry in log.iter().rev() {
424 if entry.term != conflict_term {
425 break;
426 }
427 conflict_index = entry.index.0;
428 }
429
430 return Ok(AppendEntriesResponse {
431 term: persistent.current_term,
432 success: false,
433 conflict_index: Some(LogIndex(conflict_index)),
434 });
435 }
436 }
437
438 for entry in request.entries {
440 let index = (entry.index.0 - 1) as usize;
441 if index >= log.len() {
442 log.push(entry);
443 } else if log[index].term != entry.term {
444 log.truncate(index);
446 log.push(entry);
447 }
448 }
449
450 if request.leader_commit.0 > self.volatile.read().commit_index.0 {
452 let new_commit = request.leader_commit.0.min(log.len() as u64);
453 self.volatile.write().commit_index = LogIndex(new_commit);
454 }
455
456 Ok(AppendEntriesResponse {
457 term: persistent.current_term,
458 success: true,
459 conflict_index: None,
460 })
461 }
462
463 #[allow(clippy::unused_async)]
465 async fn handle_request_vote(
466 &self,
467 request: RequestVoteRequest,
468 ) -> Result<RequestVoteResponse> {
469 let mut persistent = self.persistent.write();
470 let current_term = persistent.current_term;
471
472 if request.term < current_term {
474 return Ok(RequestVoteResponse {
475 term: current_term,
476 vote_granted: false,
477 });
478 }
479
480 if request.term > current_term {
482 persistent.current_term = request.term;
483 persistent.voted_for = None;
484 *self.state.write() = NodeState::Follower;
485 }
486
487 let vote_granted = if persistent.voted_for.is_none()
489 || persistent.voted_for == Some(request.candidate_id)
490 {
491 let log = self.log.read();
493 let last_log_index = log.len() as u64;
494 let last_log_term = log.last().map(|e| e.term).unwrap_or(Term(0));
495
496 let log_ok = request.last_log_term > last_log_term
497 || (request.last_log_term == last_log_term
498 && request.last_log_index.0 >= last_log_index);
499
500 if log_ok {
501 persistent.voted_for = Some(request.candidate_id);
502 true
503 } else {
504 false
505 }
506 } else {
507 false
508 };
509
510 Ok(RequestVoteResponse {
511 term: persistent.current_term,
512 vote_granted,
513 })
514 }
515
516 async fn apply_committed_entries(&self) -> Result<()> {
518 let commit_index = self.volatile.read().commit_index;
519
520 loop {
521 let command = {
523 let mut volatile = self.volatile.write();
524
525 if volatile.last_applied.0 >= commit_index.0 {
526 break;
527 }
528
529 volatile.last_applied.0 += 1;
530 let entry = &self.log.read()[(volatile.last_applied.0 - 1) as usize];
531 entry.command.clone()
532 }; match command {
536 Command::Put { cid_bytes, data } => {
537 if let Ok(cid) = Cid::try_from(cid_bytes.as_slice()) {
539 let block = Block::from_parts(cid, bytes::Bytes::from(data));
540 self.store.put(&block).await?;
541 debug!("Applied PUT: {}", block.cid());
542 }
543 }
544 Command::Delete { cid_bytes } => {
545 if let Ok(cid) = Cid::try_from(cid_bytes.as_slice()) {
547 self.store.delete(&cid).await?;
548 debug!("Applied DELETE: {}", cid);
549 }
550 }
551 Command::NoOp => {
552 debug!("Applied NoOp");
553 }
554 }
555 }
556
557 Ok(())
558 }
559
560 #[allow(clippy::unused_async)]
562 pub async fn append_entry(&self, command: Command) -> Result<LogIndex> {
563 let state = *self.state.read();
564 if state != NodeState::Leader {
565 return Err(ipfrs_core::Error::Internal("Not the leader".to_string()));
566 }
567
568 let mut log = self.log.write();
569 let index = LogIndex((log.len() + 1) as u64);
570 let term = self.persistent.read().current_term;
571
572 let entry = LogEntry {
573 term,
574 index,
575 command,
576 };
577
578 log.push(entry);
579 Ok(index)
580 }
581
582 pub fn current_leader(&self) -> Option<NodeId> {
584 *self.current_leader.read()
585 }
586
587 pub fn is_leader(&self) -> bool {
589 *self.state.read() == NodeState::Leader
590 }
591
592 pub fn current_term(&self) -> Term {
594 self.persistent.read().current_term
595 }
596}
597
/// Snapshot of a node's RAFT status for reporting.
///
/// NOTE(review): not constructed anywhere in this file.
// Field order defines the serde wire layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaftStats {
    /// Node the snapshot describes.
    pub node_id: NodeId,
    /// Role as text (presumably the `NodeState` name; confirm with the producer).
    pub state: String,
    /// Current term.
    pub term: Term,
    /// Known leader, if any.
    pub leader: Option<NodeId>,
    /// Number of entries in the log.
    pub log_size: usize,
    /// Highest committed log index.
    pub commit_index: LogIndex,
    /// Highest applied log index.
    pub last_applied: LogIndex,
}
616
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryBlockStore;

    /// Builds node 1 of a three-node cluster (peers 2 and 3) over an
    /// in-memory store with the default configuration.
    fn member_of_three() -> RaftNode<MemoryBlockStore> {
        RaftNode::new(
            NodeId(1),
            vec![NodeId(2), NodeId(3)],
            MemoryBlockStore::new(),
            RaftConfig::default(),
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_node_creation() {
        let created = RaftNode::new(
            NodeId(1),
            vec![NodeId(2), NodeId(3)],
            MemoryBlockStore::new(),
            RaftConfig::default(),
        );
        assert!(created.is_ok());
    }

    #[tokio::test]
    async fn test_append_entries_lower_term() {
        let node = member_of_three();
        // The node already sits in term 5, so a term-3 leader is stale.
        node.persistent.write().current_term = Term(5);

        let stale_heartbeat = AppendEntriesRequest {
            term: Term(3),
            leader_id: NodeId(2),
            prev_log_index: LogIndex(0),
            prev_log_term: Term(0),
            entries: Vec::new(),
            leader_commit: LogIndex(0),
        };

        let reply = node.handle_append_entries(stale_heartbeat).await.unwrap();
        assert!(!reply.success);
        assert_eq!(reply.term, Term(5));
    }

    #[tokio::test]
    async fn test_request_vote_grant() {
        let node = member_of_three();

        let ballot = RequestVoteRequest {
            term: Term(1),
            candidate_id: NodeId(2),
            last_log_index: LogIndex(0),
            last_log_term: Term(0),
        };

        let reply = node.handle_request_vote(ballot).await.unwrap();
        assert!(reply.vote_granted);
        assert_eq!(node.persistent.read().voted_for, Some(NodeId(2)));
    }

    #[tokio::test]
    async fn test_request_vote_deny_already_voted() {
        let node = member_of_three();
        {
            // Already voted for node 2 in term 1.
            let mut persisted = node.persistent.write();
            persisted.voted_for = Some(NodeId(2));
            persisted.current_term = Term(1);
        }

        // A second candidate in the same term must be refused.
        let rival_ballot = RequestVoteRequest {
            term: Term(1),
            candidate_id: NodeId(3),
            last_log_index: LogIndex(0),
            last_log_term: Term(0),
        };

        let reply = node.handle_request_vote(rival_ballot).await.unwrap();
        assert!(!reply.vote_granted);
    }
}