use crate::traits::BlockStore;
use ipfrs_core::{Block, Cid, Result};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, oneshot};
use tokio::time;
use tracing::{debug, info};
/// Identifier of a node in the RAFT cluster, wrapping a raw `u64`.
///
/// Hashable so it can key the per-peer maps in `LeaderState`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NodeId(pub u64);
impl std::fmt::Display for NodeId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Node({})", self.0)
}
}
/// Election term number.
///
/// Ordered so that the RPC handlers can compare a request's term against the
/// node's current term.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Term(pub u64);
impl Term {
pub fn increment(&mut self) {
self.0 += 1;
}
}
/// Position of an entry in the replicated log.
///
/// Indices are 1-based: the entry with index `i` is stored at position
/// `i - 1` of the in-memory log vector. The default value `0` therefore
/// refers to no entry at all.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
pub struct LogIndex(pub u64);
impl LogIndex {
pub fn increment(&mut self) {
self.0 += 1;
}
}
/// Role a node currently plays in the cluster.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeState {
/// Initial role set by `RaftNode::new`; also assumed whenever an RPC with a
/// higher term is observed.
Follower,
/// Role entered by the election timer when no heartbeat arrived in time.
Candidate,
/// Role required by `RaftNode::append_entry` to accept new commands.
Leader,
}
/// A single entry of the replicated log.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogEntry {
/// Term in which the entry was created.
pub term: Term,
/// 1-based position of the entry in the log.
pub index: LogIndex,
/// Operation applied to the block store once the entry is committed.
pub command: Command,
}
/// Operation carried by a log entry and applied to the block store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Command {
/// Store a block: `cid_bytes` is decoded with `Cid::try_from` and paired
/// with `data` as the block payload.
Put { cid_bytes: Vec<u8>, data: Vec<u8> },
/// Delete the block whose CID decodes from `cid_bytes`.
Delete { cid_bytes: Vec<u8> },
/// Entry with no effect on the store; applying it only emits a debug log.
NoOp,
}
/// AppendEntries RPC payload, sent by a leader to replicate log entries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesRequest {
/// Sender's term; the receiver rejects the request if this is lower than
/// its own current term.
pub term: Term,
/// Sender's id; the receiver records it as the current leader.
pub leader_id: NodeId,
/// Index of the entry immediately preceding `entries`; `0` skips the
/// log-consistency check.
pub prev_log_index: LogIndex,
/// Term the receiver must have at `prev_log_index` for the request to match.
pub prev_log_term: Term,
/// Entries to append; may be empty.
pub entries: Vec<LogEntry>,
/// Leader's commit index; the receiver advances its own commit index up to
/// the smaller of this value and its log length.
pub leader_commit: LogIndex,
}
/// Reply to an [`AppendEntriesRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppendEntriesResponse {
/// Responder's current term.
pub term: Term,
/// `true` when the request was accepted and the entries were processed.
pub success: bool,
/// Set when the request failed the log-consistency check: a log index the
/// responder reports back as the point of mismatch. `None` on success and
/// on stale-term rejections.
pub conflict_index: Option<LogIndex>,
}
/// RequestVote RPC payload, sent by a candidate asking for a vote.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestVoteRequest {
/// Candidate's term; the receiver denies the vote if this is lower than its
/// own current term.
pub term: Term,
/// Id of the node requesting the vote.
pub candidate_id: NodeId,
/// Index of the candidate's last log entry (compared against the
/// receiver's log length).
pub last_log_index: LogIndex,
/// Term of the candidate's last log entry (`Term(0)` is what the receiver
/// uses for an empty log).
pub last_log_term: Term,
}
/// Reply to a [`RequestVoteRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestVoteResponse {
/// Responder's current term.
pub term: Term,
/// `true` when the responder voted for the requesting candidate.
pub vote_granted: bool,
}
/// Timing and batching parameters for a [`RaftNode`].
#[derive(Debug, Clone)]
pub struct RaftConfig {
/// Interval between leader heartbeats.
/// NOTE(review): not read by any code visible in this file — confirm where it is used.
pub heartbeat_interval: Duration,
/// Lower bound (inclusive) of the randomized election timeout.
pub election_timeout_min: Duration,
/// Upper bound (inclusive) of the randomized election timeout.
pub election_timeout_max: Duration,
/// Maximum number of entries per AppendEntries request.
/// NOTE(review): not read by any code visible in this file — confirm where it is used.
pub max_entries_per_append: usize,
}
impl Default for RaftConfig {
fn default() -> Self {
Self {
heartbeat_interval: Duration::from_millis(50),
election_timeout_min: Duration::from_millis(150),
election_timeout_max: Duration::from_millis(300),
max_entries_per_append: 100,
}
}
}
/// Term and vote bookkeeping for a node.
///
/// NOTE(review): named after RAFT's persistent state, but no code visible in
/// this file writes it to durable storage — confirm whether that happens elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistentState {
/// Latest term this node has seen.
current_term: Term,
/// Candidate this node voted for in `current_term`, if any; cleared
/// whenever a higher term is adopted.
voted_for: Option<NodeId>,
}
impl Default for PersistentState {
fn default() -> Self {
Self {
current_term: Term(0),
voted_for: None,
}
}
}
/// In-memory progress markers; both start at index 0.
#[derive(Debug, Default)]
struct VolatileState {
/// Highest log index known to be committed.
commit_index: LogIndex,
/// Highest log index whose command has been handed to the block store.
last_applied: LogIndex,
}
/// Per-peer replication bookkeeping kept by a leader.
///
/// Not yet read or written by any code visible in this file (hence the
/// `dead_code` allowance).
#[derive(Debug)]
#[allow(dead_code)]
struct LeaderState {
/// Presumably the next log index to send to each peer (RAFT `nextIndex`) — TODO confirm once used.
next_index: HashMap<NodeId, LogIndex>,
/// Presumably the highest log index known replicated on each peer (RAFT `matchIndex`) — TODO confirm once used.
match_index: HashMap<NodeId, LogIndex>,
}
/// A single RAFT participant that replicates block-store commands.
///
/// All mutable state lives behind `Arc<RwLock<..>>` so it can be shared with
/// the background election-timer task.
pub struct RaftNode<S: BlockStore> {
/// This node's identifier.
id: NodeId,
/// Identifiers of the other cluster members.
peers: Vec<NodeId>,
/// Current role (follower, candidate or leader).
state: Arc<RwLock<NodeState>>,
/// Current term and the vote cast in it.
persistent: Arc<RwLock<PersistentState>>,
/// Commit and apply progress.
volatile: Arc<RwLock<VolatileState>>,
/// Leader-only replication bookkeeping; initialised to `None` and not yet
/// used by the code in this file.
#[allow(dead_code)]
leader_state: Arc<RwLock<Option<LeaderState>>>,
/// Replicated log; the entry with index `i` is stored at position `i - 1`.
log: Arc<RwLock<Vec<LogEntry>>>,
/// Block store that committed commands are applied to.
store: Arc<S>,
/// Timing configuration.
config: RaftConfig,
/// Time the last accepted AppendEntries request was received; read by the
/// election timer.
last_heartbeat: Arc<RwLock<Instant>>,
/// Leader most recently learned from an accepted AppendEntries request.
current_leader: Arc<RwLock<Option<NodeId>>>,
/// Sending half of the inbound RPC queue.
rpc_tx: mpsc::UnboundedSender<RpcMessage>,
/// Receiving half of the inbound RPC queue; taken (left as `None`) by
/// `start`, which makes a second `start` call fail.
rpc_rx: Arc<RwLock<Option<mpsc::UnboundedReceiver<RpcMessage>>>>,
}
/// Inbound RPC queued for the node's main loop, paired with the one-shot
/// channel on which the reply is sent back.
#[derive(Debug)]
#[allow(dead_code)]
enum RpcMessage {
/// An AppendEntries request awaiting a reply.
AppendEntries {
request: AppendEntriesRequest,
response_tx: oneshot::Sender<AppendEntriesResponse>,
},
/// A RequestVote request awaiting a reply.
RequestVote {
request: RequestVoteRequest,
response_tx: oneshot::Sender<RequestVoteResponse>,
},
}
impl<S: BlockStore + Send + Sync + 'static> RaftNode<S> {
pub fn new(id: NodeId, peers: Vec<NodeId>, store: S, config: RaftConfig) -> Result<Self> {
let (rpc_tx, rpc_rx) = mpsc::unbounded_channel();
Ok(Self {
id,
peers,
state: Arc::new(RwLock::new(NodeState::Follower)),
persistent: Arc::new(RwLock::new(PersistentState::default())),
volatile: Arc::new(RwLock::new(VolatileState::default())),
leader_state: Arc::new(RwLock::new(None)),
log: Arc::new(RwLock::new(Vec::new())),
store: Arc::new(store),
config,
last_heartbeat: Arc::new(RwLock::new(Instant::now())),
current_leader: Arc::new(RwLock::new(None)),
rpc_tx,
rpc_rx: Arc::new(RwLock::new(Some(rpc_rx))),
})
}
pub async fn start(&mut self) -> Result<()> {
info!("Starting RAFT node {}", self.id);
let mut rpc_rx = self
.rpc_rx
.write()
.take()
.ok_or_else(|| ipfrs_core::Error::Internal("Node already started".to_string()))?;
let _election_handle = self.spawn_election_timer();
loop {
tokio::select! {
Some(msg) = rpc_rx.recv() => {
self.handle_rpc(msg).await?;
}
_ = time::sleep(Duration::from_millis(10)) => {
self.apply_committed_entries().await?;
}
}
}
}
/// Spawns the background task that turns this node into a candidate when no
/// heartbeat has been seen for a full (randomized) election timeout.
///
/// On each pass the task sleeps for a fresh random timeout and then, unless
/// the node is the leader or a heartbeat arrived in the meantime, switches
/// the role to `Candidate`, increments the term and votes for itself.
///
/// Returns the handle of the spawned task; the task loops forever unless
/// aborted through that handle.
fn spawn_election_timer(&self) -> tokio::task::JoinHandle<()> {
    let id = self.id;
    let state = Arc::clone(&self.state);
    let persistent = Arc::clone(&self.persistent);
    let last_heartbeat = Arc::clone(&self.last_heartbeat);
    let config = self.config.clone();

    tokio::spawn(async move {
        loop {
            let timeout = Self::random_election_timeout(&config);
            time::sleep(timeout).await;

            // Role, term and vote change together under one `persistent`
            // guard (taken first, matching the RPC handlers' lock order) so a
            // concurrently handled RPC can never observe or interleave with a
            // half-finished transition. The scope ends before the next
            // `.await`, so no guard is held across a suspension point.
            {
                let mut persistent = persistent.write();
                let mut role = state.write();
                let timed_out = last_heartbeat.read().elapsed() >= timeout;
                if *role != NodeState::Leader && timed_out {
                    info!("{}: Election timeout, starting election", id);
                    *role = NodeState::Candidate;
                    persistent.current_term.increment();
                    persistent.voted_for = Some(id);
                }
            }
        }
    })
}
/// Picks a uniformly random election timeout, in whole milliseconds, between
/// `election_timeout_min` and `election_timeout_max` (both inclusive).
///
/// If the configuration has the bounds the wrong way round they are swapped
/// instead of panicking on an empty range; durations too large for a `u64`
/// millisecond count saturate at `u64::MAX`.
fn random_election_timeout(config: &RaftConfig) -> Duration {
    use rand::Rng;

    let whole_millis = |d: Duration| u64::try_from(d.as_millis()).unwrap_or(u64::MAX);
    let first = whole_millis(config.election_timeout_min);
    let second = whole_millis(config.election_timeout_max);
    // `random_range` panics on an empty range, so order the bounds first.
    let (low, high) = if first <= second {
        (first, second)
    } else {
        (second, first)
    };

    let timeout_ms = rand::rng().random_range(low..=high);
    Duration::from_millis(timeout_ms)
}
/// Dispatches one queued RPC to its handler and sends the reply back on the
/// message's one-shot channel.
///
/// A failed send (the requester dropped its receiving end) is ignored.
/// Handler errors are propagated to the caller.
async fn handle_rpc(&self, msg: RpcMessage) -> Result<()> {
    match msg {
        RpcMessage::AppendEntries {
            request,
            response_tx,
        } => {
            let reply = self.handle_append_entries(request).await?;
            // The requester may have gone away; nothing to do in that case.
            let _ = response_tx.send(reply);
            Ok(())
        }
        RpcMessage::RequestVote {
            request,
            response_tx,
        } => {
            let reply = self.handle_request_vote(request).await?;
            // The requester may have gone away; nothing to do in that case.
            let _ = response_tx.send(reply);
            Ok(())
        }
    }
}
/// Handles an AppendEntries request from a (claimed) leader.
///
/// Steps, in order:
/// 1. Reject requests whose term is older than ours.
/// 2. Adopt a newer term (clearing our vote) and step down to follower; a
///    candidate also steps down when the request carries its own term.
/// 3. Record the heartbeat time and the sender as current leader.
/// 4. Check that our log contains `prev_log_index` with `prev_log_term`;
///    on mismatch reply with a conflict index.
/// 5. Append the new entries, truncating our log at the first entry whose
///    term disagrees.
/// 6. Advance the commit index to `min(leader_commit, log length)`; it is
///    never moved backwards.
///
/// Malformed entries (index 0, or an index that would leave a gap after our
/// last entry) cause a failure reply with our log length as conflict index.
///
/// The current implementation never returns `Err`.
// Kept `async` for a uniform handler signature even though nothing is awaited.
#[allow(clippy::unused_async)]
async fn handle_append_entries(
    &self,
    request: AppendEntriesRequest,
) -> Result<AppendEntriesResponse> {
    let mut persistent = self.persistent.write();
    let current_term = persistent.current_term;

    if request.term < current_term {
        return Ok(AppendEntriesResponse {
            term: current_term,
            success: false,
            conflict_index: None,
        });
    }

    let newer_term = request.term > current_term;
    if newer_term {
        persistent.current_term = request.term;
        persistent.voted_for = None;
    }
    {
        // A candidate that hears from a leader of its own term has lost the
        // election and must return to follower as well.
        let mut role = self.state.write();
        if newer_term || *role == NodeState::Candidate {
            *role = NodeState::Follower;
        }
    }

    *self.last_heartbeat.write() = Instant::now();
    *self.current_leader.write() = Some(request.leader_id);

    let term = persistent.current_term;
    let reject = move |conflict: u64| AppendEntriesResponse {
        term,
        success: false,
        conflict_index: Some(LogIndex(conflict)),
    };

    let mut log = self.log.write();

    if request.prev_log_index.0 > 0 {
        let log_len = log.len() as u64;
        if request.prev_log_index.0 > log_len {
            return Ok(reject(log_len));
        }

        let prev_pos = (request.prev_log_index.0 - 1) as usize;
        let conflict_term = log[prev_pos].term;
        if conflict_term != request.prev_log_term {
            // Walk backwards from the mismatching entry (not from the log
            // tail) to the first entry of the conflicting term.
            let first_of_term = log[..=prev_pos]
                .iter()
                .rev()
                .take_while(|entry| entry.term == conflict_term)
                .last()
                .map_or(request.prev_log_index.0, |entry| entry.index.0);
            return Ok(reject(first_of_term));
        }
    }

    for entry in request.entries {
        // Indices are 1-based; 0 and anything past "one after our last
        // entry" cannot be placed and would corrupt the index/position
        // relationship.
        let slot = match entry.index.0.checked_sub(1) {
            Some(pos) if pos <= log.len() as u64 => pos as usize,
            _ => return Ok(reject(log.len() as u64)),
        };

        if slot == log.len() {
            log.push(entry);
        } else if log[slot].term != entry.term {
            // Conflicting suffix: drop it and take the leader's entry.
            log.truncate(slot);
            log.push(entry);
        }
    }

    let last_index = log.len() as u64;
    // Release the log before touching `volatile` so the two locks are never
    // held together here.
    drop(log);

    {
        let mut volatile = self.volatile.write();
        let target = request.leader_commit.0.min(last_index);
        if target > volatile.commit_index.0 {
            volatile.commit_index = LogIndex(target);
        }
    }

    Ok(AppendEntriesResponse {
        term,
        success: true,
        conflict_index: None,
    })
}
#[allow(clippy::unused_async)]
async fn handle_request_vote(
&self,
request: RequestVoteRequest,
) -> Result<RequestVoteResponse> {
let mut persistent = self.persistent.write();
let current_term = persistent.current_term;
if request.term < current_term {
return Ok(RequestVoteResponse {
term: current_term,
vote_granted: false,
});
}
if request.term > current_term {
persistent.current_term = request.term;
persistent.voted_for = None;
*self.state.write() = NodeState::Follower;
}
let vote_granted = if persistent.voted_for.is_none()
|| persistent.voted_for == Some(request.candidate_id)
{
let log = self.log.read();
let last_log_index = log.len() as u64;
let last_log_term = log.last().map(|e| e.term).unwrap_or(Term(0));
let log_ok = request.last_log_term > last_log_term
|| (request.last_log_term == last_log_term
&& request.last_log_index.0 >= last_log_index);
if log_ok {
persistent.voted_for = Some(request.candidate_id);
true
} else {
false
}
} else {
false
};
Ok(RequestVoteResponse {
term: persistent.current_term,
vote_granted,
})
}
async fn apply_committed_entries(&self) -> Result<()> {
let commit_index = self.volatile.read().commit_index;
loop {
let command = {
let mut volatile = self.volatile.write();
if volatile.last_applied.0 >= commit_index.0 {
break;
}
volatile.last_applied.0 += 1;
let entry = &self.log.read()[(volatile.last_applied.0 - 1) as usize];
entry.command.clone()
};
match command {
Command::Put { cid_bytes, data } => {
if let Ok(cid) = Cid::try_from(cid_bytes.as_slice()) {
let block = Block::from_parts(cid, bytes::Bytes::from(data));
self.store.put(&block).await?;
debug!("Applied PUT: {}", block.cid());
}
}
Command::Delete { cid_bytes } => {
if let Ok(cid) = Cid::try_from(cid_bytes.as_slice()) {
self.store.delete(&cid).await?;
debug!("Applied DELETE: {}", cid);
}
}
Command::NoOp => {
debug!("Applied NoOp");
}
}
}
Ok(())
}
/// Appends `command` to the local log as a new entry in the current term
/// and returns the (1-based) index assigned to it.
///
/// Only the local log is modified; nothing in this method replicates or
/// commits the entry.
///
/// # Errors
///
/// Returns `Error::Internal("Not the leader")` when this node is not
/// currently the leader.
// Kept `async` for API stability even though nothing is awaited.
#[allow(clippy::unused_async)]
pub async fn append_entry(&self, command: Command) -> Result<LogIndex> {
    if *self.state.read() != NodeState::Leader {
        return Err(ipfrs_core::Error::Internal("Not the leader".to_string()));
    }

    // Take `persistent` before `log`, the same order the RPC handlers use,
    // so the two paths can never wait on each other's locks. Holding the
    // term guard across the push also ties the entry to the term it was
    // appended in.
    let persistent = self.persistent.read();
    let mut log = self.log.write();

    let index = LogIndex(log.len() as u64 + 1);
    log.push(LogEntry {
        term: persistent.current_term,
        index,
        command,
    });
    Ok(index)
}
/// Returns the leader last recorded from an accepted AppendEntries request,
/// or `None` if no leader has been seen yet.
pub fn current_leader(&self) -> Option<NodeId> {
    let leader = self.current_leader.read();
    *leader
}
/// Returns `true` when this node's current role is `Leader`.
pub fn is_leader(&self) -> bool {
    let role = *self.state.read();
    matches!(role, NodeState::Leader)
}
/// Returns the latest term this node has seen.
pub fn current_term(&self) -> Term {
    let persistent = self.persistent.read();
    persistent.current_term
}
}
/// Serializable snapshot of a node's status.
///
/// NOTE(review): nothing visible in this file constructs this type — the
/// field descriptions below follow the field names and types; confirm
/// against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaftStats {
/// Identifier of the reporting node.
pub node_id: NodeId,
/// Node role as text (presumably a rendering of `NodeState` — TODO confirm format).
pub state: String,
/// Term at the time of the snapshot.
pub term: Term,
/// Known leader, if any.
pub leader: Option<NodeId>,
/// Number of entries in the log.
pub log_size: usize,
/// Commit index at the time of the snapshot.
pub commit_index: LogIndex,
/// Last applied index at the time of the snapshot.
pub last_applied: LogIndex,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryBlockStore;

    /// Builds node 1 of a three-node cluster (peers 2 and 3) backed by an
    /// in-memory store with the default configuration.
    fn three_node_member() -> RaftNode<MemoryBlockStore> {
        RaftNode::new(
            NodeId(1),
            vec![NodeId(2), NodeId(3)],
            MemoryBlockStore::new(),
            RaftConfig::default(),
        )
        .unwrap()
    }

    /// Builds a vote request from `candidate` in `term` with an empty log.
    fn vote_request(term: u64, candidate: u64) -> RequestVoteRequest {
        RequestVoteRequest {
            term: Term(term),
            candidate_id: NodeId(candidate),
            last_log_index: LogIndex(0),
            last_log_term: Term(0),
        }
    }

    /// Constructing a node succeeds.
    #[tokio::test]
    async fn test_node_creation() {
        let created = RaftNode::new(
            NodeId(1),
            vec![NodeId(2), NodeId(3)],
            MemoryBlockStore::new(),
            RaftConfig::default(),
        );
        assert!(created.is_ok());
    }

    /// An AppendEntries request from an older term is rejected and the reply
    /// carries the node's own term.
    #[tokio::test]
    async fn test_append_entries_lower_term() {
        let node = three_node_member();
        node.persistent.write().current_term = Term(5);

        let stale = AppendEntriesRequest {
            term: Term(3),
            leader_id: NodeId(2),
            prev_log_index: LogIndex(0),
            prev_log_term: Term(0),
            entries: Vec::new(),
            leader_commit: LogIndex(0),
        };

        let reply = node.handle_append_entries(stale).await.unwrap();
        assert!(!reply.success);
        assert_eq!(reply.term, Term(5));
    }

    /// A fresh node grants its vote and remembers the candidate.
    #[tokio::test]
    async fn test_request_vote_grant() {
        let node = three_node_member();

        let reply = node.handle_request_vote(vote_request(1, 2)).await.unwrap();

        assert!(reply.vote_granted);
        assert_eq!(node.persistent.read().voted_for, Some(NodeId(2)));
    }

    /// A node that already voted in the current term denies a different
    /// candidate.
    #[tokio::test]
    async fn test_request_vote_deny_already_voted() {
        let node = three_node_member();
        {
            // Scope the guard: the handler takes the same lock.
            let mut persistent = node.persistent.write();
            persistent.voted_for = Some(NodeId(2));
            persistent.current_term = Term(1);
        }

        let reply = node.handle_request_vote(vote_request(1, 3)).await.unwrap();

        assert!(!reply.vote_granted);
    }
}