ipfrs_storage/raft.rs

//! RAFT Consensus Protocol Implementation
//!
//! This module implements the RAFT consensus algorithm for distributed storage.
//! RAFT provides strong consistency guarantees through leader election and log replication.
//!
//! # Architecture
//!
//! - **RaftNode**: Main RAFT node that participates in consensus
//! - **Log**: Append-only log of replicated operations (`Vec<LogEntry>`)
//! - **State machine**: Committed commands are applied to the underlying BlockStore
//! - **RPC protocol**: AppendEntries and RequestVote for node communication
//!
//! # Example
//!
//! ```ignore
//! use ipfrs_storage::raft::{RaftNode, RaftConfig, NodeId};
//! use ipfrs_storage::sled::SledBlockStore;
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//!     let store = SledBlockStore::new("/tmp/raft-node-1")?;
//!     let config = RaftConfig::default();
//!
//!     let mut node = RaftNode::new(
//!         NodeId(1),
//!         vec![NodeId(2), NodeId(3)],
//!         store,
//!         config,
//!     )?;
//!
//!     node.start().await?;
//!     Ok(())
//! }
//! ```

use crate::traits::BlockStore;
use ipfrs_core::{Block, Cid, Result};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, oneshot};
use tokio::time;
use tracing::{debug, info};

/// Unique identifier for a RAFT node
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NodeId(pub u64);

impl std::fmt::Display for NodeId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Node({})", self.0)
    }
}

/// RAFT term number (monotonically increasing)
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Term(pub u64);

impl Term {
    pub fn increment(&mut self) {
        self.0 += 1;
    }
}

/// Index in the RAFT log
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
pub struct LogIndex(pub u64);

impl LogIndex {
    pub fn increment(&mut self) {
        self.0 += 1;
    }
}

/// Node state in RAFT protocol
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeState {
    /// Follower state (default)
    Follower,
    /// Candidate state (during election)
    Candidate,
    /// Leader state (elected leader)
    Leader,
}

/// RAFT log entry containing a command
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogEntry {
    /// Term when entry was received by leader
    pub term: Term,
    /// Index in the log
    pub index: LogIndex,
    /// Command to execute on state machine
    pub command: Command,
}

/// Command that can be applied to the state machine (BlockStore)
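///
/// For illustration, a `Put` command might be built from an existing block
/// roughly like this (a sketch; `Cid::to_bytes()` and `Block::data()` are
/// assumed to be available from `ipfrs_core`):
///
/// ```ignore
/// // `block` is an ipfrs_core::Block obtained elsewhere
/// let cmd = Command::Put {
///     cid_bytes: block.cid().to_bytes(),
///     data: block.data().to_vec(),
/// };
/// ```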
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Command {
    /// Put a block (stores CID and data separately)
    Put { cid_bytes: Vec<u8>, data: Vec<u8> },
    /// Delete a block (CID stored as bytes)
    Delete { cid_bytes: Vec<u8> },
    /// No-op (used for leader election)
    NoOp,
}

/// AppendEntries RPC request
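///
/// Sent by the leader both to replicate log entries and as a heartbeat; a
/// heartbeat is simply a request with an empty `entries` vector. An illustrative
/// heartbeat (field values are arbitrary):
///
/// ```ignore
/// let heartbeat = AppendEntriesRequest {
///     term: Term(7),
///     leader_id: NodeId(1),
///     prev_log_index: LogIndex(42),
///     prev_log_term: Term(7),
///     entries: vec![],
///     leader_commit: LogIndex(40),
/// };
/// ```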
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesRequest {
    /// Leader's term
    pub term: Term,
    /// Leader's ID (so follower can redirect clients)
    pub leader_id: NodeId,
    /// Index of log entry immediately preceding new ones
    pub prev_log_index: LogIndex,
    /// Term of prev_log_index entry
    pub prev_log_term: Term,
    /// Log entries to store (empty for heartbeat)
    pub entries: Vec<LogEntry>,
    /// Leader's commit index
    pub leader_commit: LogIndex,
}

/// AppendEntries RPC response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesResponse {
    /// Current term, for leader to update itself
    pub term: Term,
    /// True if follower contained entry matching prev_log_index and prev_log_term
    pub success: bool,
    /// Hint for leader to backtrack (next index to try)
    pub conflict_index: Option<LogIndex>,
}

/// RequestVote RPC request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestVoteRequest {
    /// Candidate's term
    pub term: Term,
    /// Candidate requesting vote
    pub candidate_id: NodeId,
    /// Index of candidate's last log entry
    pub last_log_index: LogIndex,
    /// Term of candidate's last log entry
    pub last_log_term: Term,
}

/// RequestVote RPC response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestVoteResponse {
    /// Current term, for candidate to update itself
    pub term: Term,
    /// True means candidate received vote
    pub vote_granted: bool,
}

/// Configuration for RAFT node
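///
/// All fields have defaults (see the `Default` impl below); individual timings
/// can be overridden with struct-update syntax, for example:
///
/// ```ignore
/// let config = RaftConfig {
///     heartbeat_interval: Duration::from_millis(100),
///     ..RaftConfig::default()
/// };
/// ```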
#[derive(Debug, Clone)]
pub struct RaftConfig {
    /// Heartbeat interval (when leader)
    pub heartbeat_interval: Duration,
    /// Election timeout range (randomized to avoid split votes)
    pub election_timeout_min: Duration,
    pub election_timeout_max: Duration,
    /// Maximum number of entries to send in one AppendEntries RPC
    pub max_entries_per_append: usize,
}

impl Default for RaftConfig {
    fn default() -> Self {
        Self {
            heartbeat_interval: Duration::from_millis(50),
            election_timeout_min: Duration::from_millis(150),
            election_timeout_max: Duration::from_millis(300),
            max_entries_per_append: 100,
        }
    }
}

/// Persistent state on all servers (must survive restarts)
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistentState {
    /// Latest term server has seen
    current_term: Term,
    /// Candidate that received vote in current term
    voted_for: Option<NodeId>,
}

impl Default for PersistentState {
    fn default() -> Self {
        Self {
            current_term: Term(0),
            voted_for: None,
        }
    }
}

/// Volatile state on all servers
#[derive(Debug, Default)]
struct VolatileState {
    /// Index of highest log entry known to be committed
    commit_index: LogIndex,
    /// Index of highest log entry applied to state machine
    last_applied: LogIndex,
}

/// Volatile state on leaders (reinitialized after election)
#[derive(Debug)]
#[allow(dead_code)]
struct LeaderState {
    /// For each server, index of next log entry to send
    next_index: HashMap<NodeId, LogIndex>,
    /// For each server, index of highest log entry known to be replicated
    match_index: HashMap<NodeId, LogIndex>,
}

/// RAFT node that participates in consensus
pub struct RaftNode<S: BlockStore> {
    /// This node's ID
    id: NodeId,
    /// Other nodes in the cluster
    peers: Vec<NodeId>,
    /// Current state (Follower, Candidate, Leader)
    state: Arc<RwLock<NodeState>>,
    /// Persistent state
    persistent: Arc<RwLock<PersistentState>>,
    /// Volatile state
    volatile: Arc<RwLock<VolatileState>>,
    /// Leader state (only valid when Leader)
    #[allow(dead_code)]
    leader_state: Arc<RwLock<Option<LeaderState>>>,
    /// RAFT log
    log: Arc<RwLock<Vec<LogEntry>>>,
    /// Underlying block store (state machine)
    store: Arc<S>,
    /// Configuration
    config: RaftConfig,
    /// Last time we heard from leader (for election timeout)
    last_heartbeat: Arc<RwLock<Instant>>,
    /// Current leader (if known)
    current_leader: Arc<RwLock<Option<NodeId>>>,
    /// Channel for RPC requests
    rpc_tx: mpsc::UnboundedSender<RpcMessage>,
    rpc_rx: Arc<RwLock<Option<mpsc::UnboundedReceiver<RpcMessage>>>>,
}

/// RPC message for internal communication
#[derive(Debug)]
#[allow(dead_code)]
enum RpcMessage {
    AppendEntries {
        request: AppendEntriesRequest,
        response_tx: oneshot::Sender<AppendEntriesResponse>,
    },
    RequestVote {
        request: RequestVoteRequest,
        response_tx: oneshot::Sender<RequestVoteResponse>,
    },
}

impl<S: BlockStore + Send + Sync + 'static> RaftNode<S> {
    /// Create a new RAFT node
    pub fn new(id: NodeId, peers: Vec<NodeId>, store: S, config: RaftConfig) -> Result<Self> {
        let (rpc_tx, rpc_rx) = mpsc::unbounded_channel();

        Ok(Self {
            id,
            peers,
            state: Arc::new(RwLock::new(NodeState::Follower)),
            persistent: Arc::new(RwLock::new(PersistentState::default())),
            volatile: Arc::new(RwLock::new(VolatileState::default())),
            leader_state: Arc::new(RwLock::new(None)),
            log: Arc::new(RwLock::new(Vec::new())),
            store: Arc::new(store),
            config,
            last_heartbeat: Arc::new(RwLock::new(Instant::now())),
            current_leader: Arc::new(RwLock::new(None)),
            rpc_tx,
            rpc_rx: Arc::new(RwLock::new(Some(rpc_rx))),
        })
    }

    /// Start the RAFT node
    pub async fn start(&mut self) -> Result<()> {
        info!("Starting RAFT node {}", self.id);

        // Take the receiver out of the option
        let mut rpc_rx = self
            .rpc_rx
            .write()
            .take()
            .ok_or_else(|| ipfrs_core::Error::Internal("Node already started".to_string()))?;

        // Spawn election timer
        let _election_handle = self.spawn_election_timer();

        // Main event loop
        loop {
            tokio::select! {
                // Handle RPC messages
                Some(msg) = rpc_rx.recv() => {
                    self.handle_rpc(msg).await?;
                }
                // Periodic tasks (apply committed entries)
                _ = time::sleep(Duration::from_millis(10)) => {
                    self.apply_committed_entries().await?;
                }
            }
        }
    }

    /// Spawn election timer task
    fn spawn_election_timer(&self) -> tokio::task::JoinHandle<()> {
        let id = self.id;
        let state = Arc::clone(&self.state);
        let persistent = Arc::clone(&self.persistent);
        let last_heartbeat = Arc::clone(&self.last_heartbeat);
        let config = self.config.clone();
        let _peers = self.peers.clone();
        let _log = Arc::clone(&self.log);
        let _rpc_tx = self.rpc_tx.clone();

        tokio::spawn(async move {
            loop {
                // Calculate randomized election timeout
                let timeout = Self::random_election_timeout(&config);
                time::sleep(timeout).await;

                // Check if we should start an election
                let current_state = *state.read();
                let elapsed = last_heartbeat.read().elapsed();

                if current_state != NodeState::Leader && elapsed >= timeout {
                    info!("{}: Election timeout, starting election", id);
                    // Start election (simplified - would need to send RequestVote RPCs)
                    *state.write() = NodeState::Candidate;
                    persistent.write().current_term.increment();
                    persistent.write().voted_for = Some(id);
                }
            }
        })
    }

    /// Get a random election timeout
    fn random_election_timeout(config: &RaftConfig) -> Duration {
        use rand::Rng;
        let min = config.election_timeout_min.as_millis() as u64;
        let max = config.election_timeout_max.as_millis() as u64;
        let timeout_ms = rand::rng().random_range(min..=max);
        Duration::from_millis(timeout_ms)
    }

    /// Handle incoming RPC message
    async fn handle_rpc(&self, msg: RpcMessage) -> Result<()> {
        match msg {
            RpcMessage::AppendEntries {
                request,
                response_tx,
            } => {
                let response = self.handle_append_entries(request).await?;
                let _ = response_tx.send(response);
            }
            RpcMessage::RequestVote {
                request,
                response_tx,
            } => {
                let response = self.handle_request_vote(request).await?;
                let _ = response_tx.send(response);
            }
        }
        Ok(())
    }

    /// Handle AppendEntries RPC
    #[allow(clippy::unused_async)]
    async fn handle_append_entries(
        &self,
        request: AppendEntriesRequest,
    ) -> Result<AppendEntriesResponse> {
        let mut persistent = self.persistent.write();
        let current_term = persistent.current_term;

        // Reply false if term < currentTerm
        if request.term < current_term {
            return Ok(AppendEntriesResponse {
                term: current_term,
                success: false,
                conflict_index: None,
            });
        }

        // Update term if we see a higher one
        if request.term > current_term {
            persistent.current_term = request.term;
            persistent.voted_for = None;
            *self.state.write() = NodeState::Follower;
        }

        // Reset election timer (we heard from leader)
        *self.last_heartbeat.write() = Instant::now();
        *self.current_leader.write() = Some(request.leader_id);

        let mut log = self.log.write();

        // Reply false if log doesn't contain entry at prev_log_index with prev_log_term
        if request.prev_log_index.0 > 0 {
            if request.prev_log_index.0 > log.len() as u64 {
                return Ok(AppendEntriesResponse {
                    term: persistent.current_term,
                    success: false,
                    conflict_index: Some(LogIndex(log.len() as u64)),
                });
            }

            let prev_entry = &log[(request.prev_log_index.0 - 1) as usize];
            if prev_entry.term != request.prev_log_term {
                // Find the first index of the conflicting term (scanning back from
                // prev_log_index) so the leader can skip past the whole term at once
                let conflict_term = prev_entry.term;
                let mut conflict_index = request.prev_log_index.0;
                for entry in log[..request.prev_log_index.0 as usize].iter().rev() {
                    if entry.term != conflict_term {
                        break;
                    }
                    conflict_index = entry.index.0;
                }

                return Ok(AppendEntriesResponse {
                    term: persistent.current_term,
                    success: false,
                    conflict_index: Some(LogIndex(conflict_index)),
                });
            }
        }

        // Append new entries
        for entry in request.entries {
            let index = (entry.index.0 - 1) as usize;
            if index >= log.len() {
                log.push(entry);
            } else if log[index].term != entry.term {
                // Delete conflicting entry and all that follow
                log.truncate(index);
                log.push(entry);
            }
        }

        // Update commit index
        if request.leader_commit.0 > self.volatile.read().commit_index.0 {
            let new_commit = request.leader_commit.0.min(log.len() as u64);
            self.volatile.write().commit_index = LogIndex(new_commit);
        }

        Ok(AppendEntriesResponse {
            term: persistent.current_term,
            success: true,
            conflict_index: None,
        })
    }

    /// Handle RequestVote RPC
    #[allow(clippy::unused_async)]
    async fn handle_request_vote(
        &self,
        request: RequestVoteRequest,
    ) -> Result<RequestVoteResponse> {
        let mut persistent = self.persistent.write();
        let current_term = persistent.current_term;

        // Reply false if term < currentTerm
        if request.term < current_term {
            return Ok(RequestVoteResponse {
                term: current_term,
                vote_granted: false,
            });
        }

        // Update term if we see a higher one
        if request.term > current_term {
            persistent.current_term = request.term;
            persistent.voted_for = None;
            *self.state.write() = NodeState::Follower;
        }

        // Grant vote if we haven't voted or voted for this candidate
        let vote_granted = if persistent.voted_for.is_none()
            || persistent.voted_for == Some(request.candidate_id)
        {
            // Check if candidate's log is at least as up-to-date
            let log = self.log.read();
            let last_log_index = log.len() as u64;
            let last_log_term = log.last().map(|e| e.term).unwrap_or(Term(0));

            let log_ok = request.last_log_term > last_log_term
                || (request.last_log_term == last_log_term
                    && request.last_log_index.0 >= last_log_index);

            if log_ok {
                persistent.voted_for = Some(request.candidate_id);
                true
            } else {
                false
            }
        } else {
            false
        };

        Ok(RequestVoteResponse {
            term: persistent.current_term,
            vote_granted,
        })
    }

    /// Apply committed entries to the state machine
    async fn apply_committed_entries(&self) -> Result<()> {
        let commit_index = self.volatile.read().commit_index;

        loop {
            // Extract the command while holding the lock
            let command = {
                let mut volatile = self.volatile.write();

                if volatile.last_applied.0 >= commit_index.0 {
                    break;
                }

                volatile.last_applied.0 += 1;
                // Bind the log read guard so the borrowed entry outlives the lookup
                let log = self.log.read();
                let entry = &log[(volatile.last_applied.0 - 1) as usize];
                entry.command.clone()
            }; // Lock is dropped here

            // Apply command to state machine (without holding the lock)
            match command {
                Command::Put { cid_bytes, data } => {
                    // Reconstruct CID and Block
                    if let Ok(cid) = Cid::try_from(cid_bytes.as_slice()) {
                        let block = Block::from_parts(cid, bytes::Bytes::from(data));
                        self.store.put(&block).await?;
                        debug!("Applied PUT: {}", block.cid());
                    }
                }
                Command::Delete { cid_bytes } => {
                    // Deserialize CID from bytes
                    if let Ok(cid) = Cid::try_from(cid_bytes.as_slice()) {
                        self.store.delete(&cid).await?;
                        debug!("Applied DELETE: {}", cid);
                    }
                }
                Command::NoOp => {
                    debug!("Applied NoOp");
                }
            }
        }

        Ok(())
    }

    /// Append a new entry to the log (leader only)
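    ///
    /// Returns an error if this node is not currently the leader. An illustrative
    /// call site (a sketch, not a prescribed usage pattern):
    ///
    /// ```ignore
    /// if node.is_leader() {
    ///     let index = node.append_entry(Command::NoOp).await?;
    ///     println!("appended at {:?}", index);
    /// }
    /// ```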
    #[allow(clippy::unused_async)]
    pub async fn append_entry(&self, command: Command) -> Result<LogIndex> {
        let state = *self.state.read();
        if state != NodeState::Leader {
            return Err(ipfrs_core::Error::Internal("Not the leader".to_string()));
        }

        let mut log = self.log.write();
        let index = LogIndex((log.len() + 1) as u64);
        let term = self.persistent.read().current_term;

        let entry = LogEntry {
            term,
            index,
            command,
        };

        log.push(entry);
        Ok(index)
    }

    /// Get the current leader ID
    pub fn current_leader(&self) -> Option<NodeId> {
        *self.current_leader.read()
    }

    /// Check if this node is the leader
    pub fn is_leader(&self) -> bool {
        *self.state.read() == NodeState::Leader
    }

    /// Get current term
    pub fn current_term(&self) -> Term {
        self.persistent.read().current_term
    }
}

/// Statistics about the RAFT node
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaftStats {
    /// Node ID
    pub node_id: NodeId,
    /// Current state
    pub state: String,
    /// Current term
    pub term: Term,
    /// Current leader (if known)
    pub leader: Option<NodeId>,
    /// Log size
    pub log_size: usize,
    /// Commit index
    pub commit_index: LogIndex,
    /// Last applied index
    pub last_applied: LogIndex,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryBlockStore;

    #[tokio::test]
    async fn test_node_creation() {
        let store = MemoryBlockStore::new();
        let config = RaftConfig::default();
        let node = RaftNode::new(NodeId(1), vec![NodeId(2), NodeId(3)], store, config);
        assert!(node.is_ok());
    }

    #[tokio::test]
    async fn test_append_entries_lower_term() {
        let store = MemoryBlockStore::new();
        let config = RaftConfig::default();
        let node = RaftNode::new(NodeId(1), vec![NodeId(2), NodeId(3)], store, config).unwrap();

        // Set current term to 5
        node.persistent.write().current_term = Term(5);

        let request = AppendEntriesRequest {
            term: Term(3),
            leader_id: NodeId(2),
            prev_log_index: LogIndex(0),
            prev_log_term: Term(0),
            entries: vec![],
            leader_commit: LogIndex(0),
        };

        let response = node.handle_append_entries(request).await.unwrap();
        assert!(!response.success);
        assert_eq!(response.term, Term(5));
    }

    #[tokio::test]
    async fn test_request_vote_grant() {
        let store = MemoryBlockStore::new();
        let config = RaftConfig::default();
        let node = RaftNode::new(NodeId(1), vec![NodeId(2), NodeId(3)], store, config).unwrap();

        let request = RequestVoteRequest {
            term: Term(1),
            candidate_id: NodeId(2),
            last_log_index: LogIndex(0),
            last_log_term: Term(0),
        };

        let response = node.handle_request_vote(request).await.unwrap();
        assert!(response.vote_granted);
        assert_eq!(node.persistent.read().voted_for, Some(NodeId(2)));
    }

    #[tokio::test]
    async fn test_request_vote_deny_already_voted() {
        let store = MemoryBlockStore::new();
        let config = RaftConfig::default();
        let node = RaftNode::new(NodeId(1), vec![NodeId(2), NodeId(3)], store, config).unwrap();

        // Vote for node 2
        node.persistent.write().voted_for = Some(NodeId(2));
        node.persistent.write().current_term = Term(1);

        // Node 3 requests vote
        let request = RequestVoteRequest {
            term: Term(1),
            candidate_id: NodeId(3),
            last_log_index: LogIndex(0),
            last_log_term: Term(0),
        };

        let response = node.handle_request_vote(request).await.unwrap();
        assert!(!response.vote_granted);
    }
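
    // append_entry is gated on NodeState::Leader; a freshly created node starts
    // as a Follower, so a direct client append should be rejected.
    #[tokio::test]
    async fn test_append_entry_rejected_on_follower() {
        let store = MemoryBlockStore::new();
        let config = RaftConfig::default();
        let node = RaftNode::new(NodeId(1), vec![NodeId(2), NodeId(3)], store, config).unwrap();

        assert!(!node.is_leader());

        let result = node.append_entry(Command::NoOp).await;
        assert!(result.is_err());
    }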
}