use anyhow::Result;
#[allow(unused_imports)]
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio::time::{interval, sleep, timeout};
use tracing::{debug, error, info};
use crate::raft::{OxirsNodeId, RdfCommand};
/// The role a node currently plays in the Raft protocol.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum NodeState {
    /// Waits for heartbeats; becomes a candidate when its election timeout expires.
    Follower,
    /// Requests votes from peers in an attempt to become leader for its term.
    Candidate,
    /// Sends heartbeats to peers and advances the commit index.
    Leader,
}
/// Raft election term number. It starts at 0 and increases with each new election.
pub type Term = u64;
/// 1-based position of an entry in the replicated log; 0 means "no entry".
pub type LogIndex = u64;
/// A single command in the replicated log, tagged with the term it was created in.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct LogEntry {
    /// Term of the leader that created this entry.
    pub term: Term,
    /// 1-based position of this entry in the log.
    pub index: LogIndex,
    /// The RDF operation this entry carries.
    pub command: RdfCommand,
}
/// RequestVote RPC arguments, sent by a candidate to each peer.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct VoteRequest {
    /// The candidate's current term.
    pub term: Term,
    /// The node asking for the vote.
    pub candidate_id: OxirsNodeId,
    /// Index of the candidate's last log entry (0 if its log is empty).
    pub last_log_index: LogIndex,
    /// Term of the candidate's last log entry (0 if its log is empty).
    pub last_log_term: Term,
}
/// Reply to a [`VoteRequest`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct VoteResponse {
    /// The responder's current term, so a stale candidate can step down.
    pub term: Term,
    /// Whether the responder voted for the candidate.
    pub vote_granted: bool,
}
/// AppendEntries RPC arguments, sent by the leader.
///
/// An empty `entries` list is used as a heartbeat.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct AppendEntriesRequest {
    /// The leader's current term.
    pub term: Term,
    /// The sending leader, so followers can record who leads.
    pub leader_id: OxirsNodeId,
    /// Index of the entry immediately before `entries` (0 if none).
    pub prev_log_index: LogIndex,
    /// Term of the entry at `prev_log_index` (0 if none).
    pub prev_log_term: Term,
    /// Entries to append after `prev_log_index`.
    pub entries: Vec<LogEntry>,
    /// The leader's commit index.
    pub leader_commit: LogIndex,
}
/// Reply to an [`AppendEntriesRequest`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct AppendEntriesResponse {
    /// The responder's current term.
    pub term: Term,
    /// Whether the entries were accepted (the log matched at `prev_log_index`).
    pub success: bool,
    /// On a log mismatch, a hint for where the leader should retry. It is either
    /// the mismatching `prev_log_index` or one past the follower's last entry.
    pub conflict_index: Option<LogIndex>,
    /// Term of the follower's conflicting entry, when one exists at `prev_log_index`.
    pub conflict_term: Option<Term>,
}
/// Timing and batching parameters for a [`RaftStateMachine`].
///
/// All durations are expressed in milliseconds.
#[derive(Debug, Clone)]
pub struct RaftConfig {
    /// Lower bound (inclusive) of the randomized election timeout, in ms.
    pub election_timeout_min: u64,
    /// Upper bound (inclusive) of the randomized election timeout, in ms.
    pub election_timeout_max: u64,
    /// Period between leader heartbeats, in ms.
    pub heartbeat_interval: u64,
    /// Maximum number of entries per batch.
    /// NOTE(review): not read by any code in this file — confirm its consumer.
    pub max_batch_size: usize,
}

impl Default for RaftConfig {
    /// A 150–300 ms election timeout, 50 ms heartbeats and batches of 100 entries.
    fn default() -> Self {
        let election_timeout_min = 150;
        let election_timeout_max = 300;
        let heartbeat_interval = 50;
        let max_batch_size = 100;
        RaftConfig {
            election_timeout_min,
            election_timeout_max,
            heartbeat_interval,
            max_batch_size,
        }
    }
}
/// Inputs that drive role transitions in [`RaftStateMachine::handle_event`].
#[derive(Debug, Clone, Copy)]
pub enum RaftEvent {
    /// The follower's election timer expired, or a candidate's election round timed out.
    ElectionTimeout,
    /// A message carried a term newer than this node's; the payload is that term.
    HigherTermDiscovered(Term),
    /// The candidate collected a majority of votes.
    ElectionWon,
    /// The candidate failed to collect a majority of votes.
    ElectionLost,
    /// A leader for the current term contacted this candidate.
    LeaderDiscovered(OxirsNodeId),
    /// The leader must step down.
    LeadershipLost,
}
/// Per-node Raft state. Every mutable field sits behind an `Arc`, so clones
/// of this struct share (rather than copy) the node's state.
pub struct RaftStateMachine {
    /// Identifier of this node.
    node_id: OxirsNodeId,
    /// Current role (follower / candidate / leader).
    state: Arc<RwLock<NodeState>>,
    /// Latest term this node has seen.
    current_term: Arc<RwLock<Term>>,
    /// Node this node voted for in `current_term`, if any.
    voted_for: Arc<RwLock<Option<OxirsNodeId>>>,
    /// Replicated log; the entry at vector position `i` has log index `i + 1`.
    log: Arc<RwLock<Vec<LogEntry>>>,
    /// Highest log index known to be committed.
    commit_index: Arc<RwLock<LogIndex>>,
    /// Highest log index applied to the state machine.
    /// NOTE(review): initialized to 0 and not advanced by any code in this file.
    last_applied: Arc<RwLock<LogIndex>>,
    /// Leader-only: next log index to send to each peer.
    next_index: Arc<RwLock<HashMap<OxirsNodeId, LogIndex>>>,
    /// Leader-only: highest log index known to be replicated on each peer.
    match_index: Arc<RwLock<HashMap<OxirsNodeId, LogIndex>>>,
    /// The other members of the cluster.
    peers: Arc<RwLock<HashSet<OxirsNodeId>>>,
    /// Leader this node currently believes in, if any.
    current_leader: Arc<RwLock<Option<OxirsNodeId>>>,
    /// When the election timer was last reset (heartbeat, granted vote, role change).
    last_heartbeat: Arc<Mutex<Instant>>,
    /// Timing configuration.
    config: Arc<RaftConfig>,
    /// Publishes every role transition to the receiver returned by `new`.
    state_tx: mpsc::UnboundedSender<NodeState>,
    /// Set by `shutdown()` to stop the background role loop.
    shutdown: Arc<RwLock<bool>>,
}
impl RaftStateMachine {
    /// Creates a state machine for `node_id` that starts as a follower in term 0
    /// with an empty log.
    ///
    /// Returns the state machine together with the receiving half of an unbounded
    /// channel on which every role transition is published.
    pub fn new(
        node_id: OxirsNodeId,
        peers: HashSet<OxirsNodeId>,
        config: RaftConfig,
    ) -> (Self, mpsc::UnboundedReceiver<NodeState>) {
        let (state_tx, state_rx) = mpsc::unbounded_channel();
        let state_machine = Self {
            node_id,
            state: Arc::new(RwLock::new(NodeState::Follower)),
            current_term: Arc::new(RwLock::new(0)),
            voted_for: Arc::new(RwLock::new(None)),
            log: Arc::new(RwLock::new(Vec::new())),
            commit_index: Arc::new(RwLock::new(0)),
            last_applied: Arc::new(RwLock::new(0)),
            next_index: Arc::new(RwLock::new(HashMap::new())),
            match_index: Arc::new(RwLock::new(HashMap::new())),
            peers: Arc::new(RwLock::new(peers)),
            current_leader: Arc::new(RwLock::new(None)),
            last_heartbeat: Arc::new(Mutex::new(Instant::now())),
            config: Arc::new(config),
            state_tx,
            shutdown: Arc::new(RwLock::new(false)),
        };
        (state_machine, state_rx)
    }

    /// Resets the node to follower and spawns the role loop as a background tokio task.
    ///
    /// The task runs on a clone that shares all state with `self`. It exits once
    /// [`Self::shutdown`] has been called.
    pub async fn start(&self) {
        info!("Starting Raft state machine for node {}", self.node_id);
        self.transition_to_follower().await;
        let state_machine = self.clone();
        tokio::spawn(async move {
            state_machine.run_state_machine().await;
        });
    }

    /// Returns `true` once [`Self::shutdown`] has been requested.
    async fn is_shutdown_requested(&self) -> bool {
        *self.shutdown.read().await
    }

    /// Repeatedly runs the loop for the node's current role until shutdown is requested.
    async fn run_state_machine(&self) {
        loop {
            if self.is_shutdown_requested().await {
                info!("Shutting down Raft state machine for node {}", self.node_id);
                break;
            }
            let current_state = *self.state.read().await;
            match current_state {
                NodeState::Follower => self.run_follower().await,
                NodeState::Candidate => self.run_candidate().await,
                NodeState::Leader => self.run_leader().await,
            }
        }
    }

    /// Follower loop: polls every 10 ms and raises `ElectionTimeout` when no
    /// heartbeat arrived within a randomized election timeout.
    async fn run_follower(&self) {
        debug!("Node {} running as follower", self.node_id);
        let election_timeout = self.random_election_timeout();
        let timeout_duration = Duration::from_millis(election_timeout);
        loop {
            // Also leave on shutdown; otherwise the flag is only seen after a role change.
            if self.is_shutdown_requested().await
                || *self.state.read().await != NodeState::Follower
            {
                break;
            }
            let last_heartbeat = *self.last_heartbeat.lock().await;
            let elapsed = last_heartbeat.elapsed();
            if elapsed >= timeout_duration {
                info!(
                    "Node {} election timeout expired, becoming candidate",
                    self.node_id
                );
                self.handle_event(RaftEvent::ElectionTimeout).await;
                break;
            }
            sleep(Duration::from_millis(10)).await;
        }
    }

    /// Candidate loop: runs one election, then waits for a role change, shutdown,
    /// or another randomized election timeout.
    async fn run_candidate(&self) {
        debug!("Node {} running as candidate", self.node_id);
        if let Err(e) = self.start_election().await {
            error!("Election failed for node {}: {}", self.node_id, e);
            // `start_election` already reported the outcome. This is a defensive
            // fallback and is ignored when the node is no longer a candidate.
            self.handle_event(RaftEvent::ElectionLost).await;
            return;
        }
        let election_timeout = self.random_election_timeout();
        let timeout_duration = Duration::from_millis(election_timeout);
        let start_time = Instant::now();
        loop {
            if self.is_shutdown_requested().await
                || *self.state.read().await != NodeState::Candidate
            {
                break;
            }
            if start_time.elapsed() >= timeout_duration {
                info!(
                    "Node {} election timeout expired, restarting election",
                    self.node_id
                );
                self.handle_event(RaftEvent::ElectionTimeout).await;
                break;
            }
            sleep(Duration::from_millis(10)).await;
        }
    }

    /// Leader loop: initializes per-peer replication state, then on every
    /// heartbeat tick sends heartbeats, replicates, and advances the commit index.
    async fn run_leader(&self) {
        info!("Node {} running as leader", self.node_id);
        self.initialize_leader_state().await;
        let mut heartbeat_interval =
            interval(Duration::from_millis(self.config.heartbeat_interval));
        loop {
            // Without the shutdown check a stable leader would never stop.
            if self.is_shutdown_requested().await
                || *self.state.read().await != NodeState::Leader
            {
                break;
            }
            heartbeat_interval.tick().await;
            self.send_heartbeats().await;
            self.replicate_log_entries().await;
            self.update_commit_index().await;
        }
    }

    /// Applies a role-transition event for the current role.
    ///
    /// Combinations of role and event that have no transition are logged at
    /// debug level and ignored.
    pub async fn handle_event(&self, event: RaftEvent) {
        let current_state = *self.state.read().await;
        match (current_state, event) {
            (NodeState::Follower, RaftEvent::ElectionTimeout) => {
                self.transition_to_candidate().await;
            }
            (NodeState::Follower, RaftEvent::HigherTermDiscovered(term)) => {
                self.update_term(term).await;
            }
            (NodeState::Candidate, RaftEvent::ElectionWon) => {
                self.transition_to_leader().await;
            }
            (NodeState::Candidate, RaftEvent::ElectionLost) => {
                self.transition_to_follower().await;
            }
            (NodeState::Candidate, RaftEvent::LeaderDiscovered(leader_id)) => {
                *self.current_leader.write().await = Some(leader_id);
                self.transition_to_follower().await;
            }
            (NodeState::Candidate, RaftEvent::HigherTermDiscovered(term)) => {
                self.update_term(term).await;
                self.transition_to_follower().await;
            }
            (NodeState::Leader, RaftEvent::HigherTermDiscovered(term)) => {
                self.update_term(term).await;
                self.transition_to_follower().await;
            }
            (NodeState::Leader, RaftEvent::LeadershipLost) => {
                self.transition_to_follower().await;
            }
            _ => {
                debug!("Ignoring event {:?} in state {:?}", event, current_state);
            }
        }
    }

    /// Becomes a follower, resets the election timer and publishes the transition.
    async fn transition_to_follower(&self) {
        info!("Node {} transitioning to follower", self.node_id);
        *self.state.write().await = NodeState::Follower;
        *self.last_heartbeat.lock().await = Instant::now();
        let _ = self.state_tx.send(NodeState::Follower);
    }

    /// Becomes a candidate: starts a new term, votes for itself, resets the
    /// election timer, forgets the current leader and publishes the transition.
    async fn transition_to_candidate(&self) {
        info!("Node {} transitioning to candidate", self.node_id);
        *self.state.write().await = NodeState::Candidate;
        *self.current_term.write().await += 1;
        *self.voted_for.write().await = Some(self.node_id);
        *self.last_heartbeat.lock().await = Instant::now();
        *self.current_leader.write().await = None;
        let _ = self.state_tx.send(NodeState::Candidate);
    }

    /// Becomes leader, records itself as the current leader and publishes the transition.
    async fn transition_to_leader(&self) {
        info!("Node {} transitioning to leader", self.node_id);
        *self.state.write().await = NodeState::Leader;
        *self.current_leader.write().await = Some(self.node_id);
        let _ = self.state_tx.send(NodeState::Leader);
    }

    /// Resets per-peer replication state: `next_index` becomes one past our last
    /// log entry and `match_index` becomes 0 for every peer.
    async fn initialize_leader_state(&self) {
        let peers = self.peers.read().await;
        let last_log_index = self.get_last_log_index().await;
        let mut next_index = self.next_index.write().await;
        let mut match_index = self.match_index.write().await;
        for peer in peers.iter() {
            next_index.insert(*peer, last_log_index + 1);
            match_index.insert(*peer, 0);
        }
    }

    /// Runs one election round for the current term and reports the outcome
    /// through `handle_event` (`ElectionWon`, `ElectionLost`, or
    /// `HigherTermDiscovered`).
    ///
    /// Peer responses are currently simulated: each peer answers after a random
    /// 10–50 ms delay and grants its vote with probability 0.5.
    ///
    /// # Errors
    /// Returns an error if a higher term was observed or no majority was reached
    /// before `election_timeout_min` elapsed.
    async fn start_election(&self) -> Result<()> {
        let current_term = *self.current_term.read().await;
        let last_log_index = self.get_last_log_index().await;
        let last_log_term = self.get_last_log_term().await;
        let request = VoteRequest {
            term: current_term,
            candidate_id: self.node_id,
            last_log_index,
            last_log_term,
        };
        let peers = self.peers.read().await.clone();
        // Majority of the whole cluster (peers plus this node).
        let votes_needed = (peers.len() + 1) / 2 + 1;
        // The candidate's own vote.
        let mut votes_received = 1;
        // Peers whose grant has already been counted, so duplicates are ignored.
        let mut granted_by: HashSet<OxirsNodeId> = HashSet::new();
        let (vote_tx, mut vote_rx) = mpsc::unbounded_channel();
        for peer in peers {
            let request = request.clone();
            let vote_tx = vote_tx.clone();
            let delay_ms = {
                let mut random = Random::default();
                random.random_range(10..50)
            };
            let vote_granted = {
                let mut random = Random::default();
                random.random_bool_with_chance(0.5)
            };
            tokio::spawn(async move {
                sleep(Duration::from_millis(delay_ms)).await;
                let response = VoteResponse {
                    term: request.term,
                    vote_granted,
                };
                let _ = vote_tx.send((peer, response));
            });
        }
        // Drop our own sender so `recv` yields `None` once every peer has answered,
        // instead of always waiting for the full deadline.
        drop(vote_tx);
        let timeout_duration = Duration::from_millis(self.config.election_timeout_min);
        let deadline = Instant::now() + timeout_duration;
        while votes_received < votes_needed && Instant::now() < deadline {
            let remaining = deadline.saturating_duration_since(Instant::now());
            match timeout(remaining, vote_rx.recv()).await {
                Ok(Some((peer_id, response))) => {
                    if response.term > current_term {
                        self.handle_event(RaftEvent::HigherTermDiscovered(response.term))
                            .await;
                        return Err(anyhow::anyhow!("Higher term discovered"));
                    }
                    if response.vote_granted
                        && response.term == current_term
                        && granted_by.insert(peer_id)
                    {
                        votes_received += 1;
                        debug!("Node {} received vote from {}", self.node_id, peer_id);
                    }
                }
                // All peers answered, or the deadline passed.
                Ok(None) | Err(_) => break,
            }
        }
        if votes_received >= votes_needed {
            info!(
                "Node {} won election with {} votes",
                self.node_id, votes_received
            );
            self.handle_event(RaftEvent::ElectionWon).await;
            Ok(())
        } else {
            info!(
                "Node {} lost election with {} votes",
                self.node_id, votes_received
            );
            self.handle_event(RaftEvent::ElectionLost).await;
            Err(anyhow::anyhow!("Insufficient votes"))
        }
    }

    /// Builds an empty AppendEntries (heartbeat) request for every peer.
    ///
    /// NOTE(review): the request is constructed but not sent anywhere yet; the
    /// transport is presumably wired in elsewhere — confirm.
    async fn send_heartbeats(&self) {
        let peers = self.peers.read().await.clone();
        let current_term = *self.current_term.read().await;
        let commit_index = *self.commit_index.read().await;
        for peer in peers {
            let prev_log_index = self.get_prev_log_index_for_peer(peer).await;
            let prev_log_term = self.get_prev_log_term_for_peer(peer).await;
            let _request = AppendEntriesRequest {
                term: current_term,
                leader_id: self.node_id,
                prev_log_index,
                prev_log_term,
                entries: Vec::new(),
                leader_commit: commit_index,
            };
            debug!("Leader {} sending heartbeat to {}", self.node_id, peer);
        }
    }

    /// Placeholder for log replication; currently only logs.
    async fn replicate_log_entries(&self) {
        debug!("Leader {} replicating log entries", self.node_id);
    }

    /// Advances `commit_index` to the highest index that is replicated on a
    /// majority of the cluster (peers' `match_index` plus the leader's own log),
    /// provided the entry there belongs to the current term.
    async fn update_commit_index(&self) {
        let match_index = self.match_index.read().await;
        let current_term = *self.current_term.read().await;
        let log = self.log.read().await;
        // One value per cluster member; the leader contributes its last log index.
        let mut indices: Vec<LogIndex> = match_index.values().copied().collect();
        indices.push(log.len() as LogIndex);
        indices.sort_unstable();
        // With n ascending values, the one at position (n - 1) / 2 is <= the values
        // of n - (n - 1) / 2 = n / 2 + 1 members, i.e. it is held by a majority.
        // (Position n / 2 would only be held by half of an even-sized cluster.)
        // `indices` is never empty because the leader's own index was pushed.
        let new_commit_index = indices[(indices.len() - 1) / 2];
        let mut commit_index = self.commit_index.write().await;
        if new_commit_index > *commit_index {
            // Only entries from the current term are committed by counting replicas.
            if let Some(entry) = log.get((new_commit_index - 1) as usize) {
                if entry.term == current_term {
                    *commit_index = new_commit_index;
                    debug!(
                        "Leader {} updated commit index to {}",
                        self.node_id, new_commit_index
                    );
                }
            }
        }
    }

    /// Handles an incoming RequestVote RPC.
    ///
    /// A newer term makes this node adopt it, clear its vote and become a follower.
    /// The vote is granted only when the request is for the current term, this
    /// node has not voted for a different candidate, and the candidate's log is
    /// at least as up to date as ours. Granting a vote resets the election timer.
    pub async fn handle_vote_request(&self, request: VoteRequest) -> VoteResponse {
        let mut current_term = self.current_term.write().await;
        let mut voted_for = self.voted_for.write().await;
        if request.term > *current_term {
            *current_term = request.term;
            *voted_for = None;
            self.transition_to_follower().await;
        }
        let vote_granted = if request.term < *current_term
            || (voted_for.is_some() && *voted_for != Some(request.candidate_id))
            || !self.is_candidate_log_up_to_date(&request).await
        {
            false
        } else {
            *voted_for = Some(request.candidate_id);
            *self.last_heartbeat.lock().await = Instant::now();
            true
        };
        VoteResponse {
            term: *current_term,
            vote_granted,
        }
    }

    /// Handles an incoming AppendEntries RPC (heartbeat or replication).
    ///
    /// The request is rejected when it comes from an older term, or when our log
    /// has no entry at `prev_log_index` with term `prev_log_term`; the response
    /// then carries a conflict hint. Otherwise any conflicting suffix is replaced
    /// by the leader's entries. The commit index is advanced to
    /// `min(leader_commit, index of the last new entry)` and never decreases.
    pub async fn handle_append_entries(
        &self,
        request: AppendEntriesRequest,
    ) -> AppendEntriesResponse {
        let mut current_term = self.current_term.write().await;
        if request.term > *current_term {
            *current_term = request.term;
            *self.voted_for.write().await = None;
            self.transition_to_follower().await;
        }
        if request.term < *current_term {
            return AppendEntriesResponse {
                term: *current_term,
                success: false,
                conflict_index: None,
                conflict_term: None,
            };
        }
        *self.last_heartbeat.lock().await = Instant::now();
        if *self.state.read().await == NodeState::Candidate {
            self.handle_event(RaftEvent::LeaderDiscovered(request.leader_id))
                .await;
        }
        *self.current_leader.write().await = Some(request.leader_id);
        let mut log = self.log.write().await;
        if request.prev_log_index > 0 {
            if let Some(entry) = log.get((request.prev_log_index - 1) as usize) {
                if entry.term != request.prev_log_term {
                    return AppendEntriesResponse {
                        term: *current_term,
                        success: false,
                        conflict_index: Some(request.prev_log_index),
                        conflict_term: Some(entry.term),
                    };
                }
            } else {
                return AppendEntriesResponse {
                    term: *current_term,
                    success: false,
                    conflict_index: Some(log.len() as LogIndex + 1),
                    conflict_term: None,
                };
            }
        }
        for (i, entry) in request.entries.iter().enumerate() {
            let index = request.prev_log_index + i as LogIndex + 1;
            let slot = (index - 1) as usize;
            if let Some(existing) = log.get(slot) {
                // A term mismatch invalidates this entry and everything after it.
                if existing.term != entry.term {
                    log.truncate(slot);
                    log.push(entry.clone());
                }
            } else {
                log.push(entry.clone());
            }
        }
        // Clamp to the last entry this request vouched for, not `log.len()`: entries
        // beyond it may be stale leftovers the leader has not verified.
        let last_new_index = request.prev_log_index + request.entries.len() as LogIndex;
        let candidate_commit = request.leader_commit.min(last_new_index);
        let mut commit_index = self.commit_index.write().await;
        if candidate_commit > *commit_index {
            *commit_index = candidate_commit;
        }
        AppendEntriesResponse {
            term: *current_term,
            success: true,
            conflict_index: None,
            conflict_term: None,
        }
    }

    /// Adopts `new_term` and clears the vote, but only if it is newer than the
    /// current term. Terms never move backwards, and a vote already cast in the
    /// current term is never cleared.
    async fn update_term(&self, new_term: Term) {
        let mut term = self.current_term.write().await;
        if new_term > *term {
            *term = new_term;
            *self.voted_for.write().await = None;
        }
    }

    /// Picks an election timeout in ms, uniformly from the configured inclusive range.
    ///
    /// The bounds are ordered first, so a config with `min > max` does not panic.
    fn random_election_timeout(&self) -> u64 {
        let (a, b) = (
            self.config.election_timeout_min,
            self.config.election_timeout_max,
        );
        let (low, high) = if a <= b { (a, b) } else { (b, a) };
        let mut random = Random::default();
        random.gen_range(low..=high)
    }

    /// Index of the last log entry (0 when the log is empty).
    async fn get_last_log_index(&self) -> LogIndex {
        self.log.read().await.len() as LogIndex
    }

    /// Term of the last log entry (0 when the log is empty).
    async fn get_last_log_term(&self) -> Term {
        self.log.read().await.last().map(|e| e.term).unwrap_or(0)
    }

    /// Index immediately before the next entry to send to `peer`.
    /// Treats an unknown peer as `next_index = 1`, which gives 0.
    async fn get_prev_log_index_for_peer(&self, peer: OxirsNodeId) -> LogIndex {
        self.next_index
            .read()
            .await
            .get(&peer)
            .cloned()
            .unwrap_or(1)
            .saturating_sub(1)
    }

    /// Term of the entry at [`Self::get_prev_log_index_for_peer`], or 0 if there is none.
    async fn get_prev_log_term_for_peer(&self, peer: OxirsNodeId) -> Term {
        let prev_index = self.get_prev_log_index_for_peer(peer).await;
        if prev_index == 0 {
            0
        } else {
            self.log
                .read()
                .await
                .get((prev_index - 1) as usize)
                .map(|e| e.term)
                .unwrap_or(0)
        }
    }

    /// Raft's "at least as up to date" check: a later last term wins, and on
    /// equal terms the longer (or equal) log wins.
    async fn is_candidate_log_up_to_date(&self, request: &VoteRequest) -> bool {
        let last_log_index = self.get_last_log_index().await;
        let last_log_term = self.get_last_log_term().await;
        request.last_log_term > last_log_term
            || (request.last_log_term == last_log_term && request.last_log_index >= last_log_index)
    }

    /// Requests shutdown. The background role loop observes the flag within one
    /// polling or heartbeat interval and exits.
    pub async fn shutdown(&self) {
        *self.shutdown.write().await = true;
    }

    /// Current role of this node.
    pub async fn get_state(&self) -> NodeState {
        *self.state.read().await
    }

    /// Latest term seen by this node.
    pub async fn get_current_term(&self) -> Term {
        *self.current_term.read().await
    }

    /// Leader this node currently recognizes, if any.
    pub async fn get_current_leader(&self) -> Option<OxirsNodeId> {
        *self.current_leader.read().await
    }

    /// Whether this node is currently the leader.
    pub async fn is_leader(&self) -> bool {
        *self.state.read().await == NodeState::Leader
    }

    /// Appends `command` to the local log under the current term.
    ///
    /// The entry is only appended locally; this method does not wait for it to be
    /// replicated or committed.
    ///
    /// # Errors
    /// Returns an error if this node is not the leader.
    pub async fn propose_command(&self, command: RdfCommand) -> Result<()> {
        if !self.is_leader().await {
            return Err(anyhow::anyhow!("Not the leader"));
        }
        let term = *self.current_term.read().await;
        let mut log = self.log.write().await;
        let index = log.len() as LogIndex + 1;
        let entry = LogEntry {
            term,
            index,
            command,
        };
        log.push(entry);
        Ok(())
    }
}
impl Clone for RaftStateMachine {
    /// Returns a second handle onto the same node.
    ///
    /// Each `Arc` field is reference-counted rather than deep-copied, and the
    /// role-change sender is cloned, so both handles publish to the same receiver.
    fn clone(&self) -> Self {
        // Destructuring makes the compiler flag any field added later but not handled here.
        let Self {
            node_id,
            state,
            current_term,
            voted_for,
            log,
            commit_index,
            last_applied,
            next_index,
            match_index,
            peers,
            current_leader,
            last_heartbeat,
            config,
            state_tx,
            shutdown,
        } = self;
        Self {
            node_id: *node_id,
            state: Arc::clone(state),
            current_term: Arc::clone(current_term),
            voted_for: Arc::clone(voted_for),
            log: Arc::clone(log),
            commit_index: Arc::clone(commit_index),
            last_applied: Arc::clone(last_applied),
            next_index: Arc::clone(next_index),
            match_index: Arc::clone(match_index),
            peers: Arc::clone(peers),
            current_leader: Arc::clone(current_leader),
            last_heartbeat: Arc::clone(last_heartbeat),
            config: Arc::clone(config),
            state_tx: state_tx.clone(),
            shutdown: Arc::clone(shutdown),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds node 1 with no peers and the given configuration.
    fn lone_node(config: RaftConfig) -> (RaftStateMachine, mpsc::UnboundedReceiver<NodeState>) {
        RaftStateMachine::new(1, HashSet::new(), config)
    }

    /// A term-1 vote request from `candidate_id` with an empty log.
    fn ballot(candidate_id: OxirsNodeId) -> VoteRequest {
        VoteRequest {
            term: 1,
            candidate_id,
            last_log_index: 0,
            last_log_term: 0,
        }
    }

    #[tokio::test]
    async fn test_node_state_transitions() {
        let (node, mut transitions) = lone_node(RaftConfig::default());
        assert_eq!(node.get_state().await, NodeState::Follower);

        node.handle_event(RaftEvent::ElectionTimeout).await;
        assert_eq!(node.get_state().await, NodeState::Candidate);
        if let Some(published) = transitions.recv().await {
            assert_eq!(published, NodeState::Candidate);
        }

        node.handle_event(RaftEvent::ElectionWon).await;
        assert_eq!(node.get_state().await, NodeState::Leader);
        assert!(node.is_leader().await);

        node.handle_event(RaftEvent::HigherTermDiscovered(10)).await;
        assert_eq!(node.get_state().await, NodeState::Follower);
        assert!(!node.is_leader().await);
    }

    #[tokio::test]
    async fn test_term_updates() {
        let (node, _) = lone_node(RaftConfig::default());
        assert_eq!(node.get_current_term().await, 0);

        node.handle_event(RaftEvent::ElectionTimeout).await;
        assert_eq!(node.get_current_term().await, 1);

        node.handle_event(RaftEvent::HigherTermDiscovered(5)).await;
        assert_eq!(node.get_current_term().await, 5);
    }

    #[tokio::test]
    async fn test_vote_request_handling() {
        let (node, _) = lone_node(RaftConfig::default());

        let first = node.handle_vote_request(ballot(2)).await;
        assert!(first.vote_granted);
        assert_eq!(first.term, 1);

        // Already voted for node 2 in term 1, so node 3 must be refused.
        let second = node.handle_vote_request(ballot(3)).await;
        assert!(!second.vote_granted);
    }

    #[tokio::test]
    async fn test_append_entries_handling() {
        let (node, _) = lone_node(RaftConfig::default());
        let entry = LogEntry {
            term: 1,
            index: 1,
            command: RdfCommand::Insert {
                subject: "s".to_string(),
                predicate: "p".to_string(),
                object: "o".to_string(),
            },
        };
        let request = AppendEntriesRequest {
            term: 1,
            leader_id: 2,
            prev_log_index: 0,
            prev_log_term: 0,
            entries: vec![entry],
            leader_commit: 0,
        };

        let reply = node.handle_append_entries(request).await;
        assert!(reply.success);
        assert_eq!(reply.term, 1);
        assert_eq!(node.get_last_log_index().await, 1);
        assert_eq!(node.get_last_log_term().await, 1);
    }

    #[tokio::test]
    async fn test_leader_election_timeout() {
        let fast = RaftConfig {
            election_timeout_min: 50,
            election_timeout_max: 100,
            ..Default::default()
        };
        let (node, _) = lone_node(fast);
        node.start().await;
        sleep(Duration::from_millis(150)).await;

        let role = node.get_state().await;
        assert!(matches!(role, NodeState::Candidate | NodeState::Leader));
        node.shutdown().await;
    }
}