use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use tokio::{
sync::{mpsc, Mutex, RwLock},
time::interval,
};
use crate::{clustering::RaftConfig, error::FusekiResult, store::Store};
/// Role a node currently plays in the Raft protocol.
#[derive(Debug, Clone, Copy, PartialEq)]
#[allow(dead_code)]
enum RaftState {
/// Initial role of a node built by `RaftNode::new`; also adopted whenever
/// `handle_append_entries` accepts a request or `handle_request_vote`
/// sees a higher term.
Follower,
/// Set by `RaftNode::start_election` after it bumps the term and votes
/// for itself.
Candidate,
/// Set by `RaftNode::become_leader` and by `RaftNode::bootstrap`.
Leader,
}
/// One entry of the replicated log.
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct LogEntry {
/// 1-based position in the log; `append_log_entry` assigns `log.len() + 1`.
pub index: u64,
/// Value of `current_term` at the moment the entry was appended.
pub term: u64,
/// The operation carried by this entry.
pub command: Command,
/// Optional client identifier. `append_log_entry` always stores `None`.
/// NOTE(review): intended use is not visible in this file — confirm.
pub client_id: Option<String>,
}
/// Operation stored in a [`LogEntry`].
///
/// In this file `apply_committed_entries` only emits a debug log line for
/// each variant; no variant mutates the `Store` here.
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub enum Command {
/// Associate `value` with `key`.
Set { key: String, value: Vec<u8> },
/// Remove `key`.
Delete { key: String },
/// Carries a new [`ClusterConfig`].
ConfigChange { config: ClusterConfig },
/// Entry with no payload; appended by `bootstrap` and `become_leader`.
NoOp,
}
/// Cluster membership as seen by a node.
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct ClusterConfig {
/// Node ids that are compared against `RaftNode::id` when iterating peers.
pub members: Vec<String>,
/// Always `None` in this file.
/// NOTE(review): presumably the target membership during a configuration
/// change — confirm against the code that sets it.
pub new_members: Option<Vec<String>>,
}
/// Term, vote and log of a node.
///
/// Despite the name, nothing in this file writes this state to stable
/// storage; it lives only behind an in-memory `RwLock`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct PersistentState {
/// Latest term this node has seen; starts at 0.
current_term: u64,
/// Candidate id this node voted for in `current_term`, if any. Cleared
/// whenever a higher term is adopted.
voted_for: Option<String>,
/// The log. Entry with Raft index `i` (1-based) is stored at `log[i - 1]`.
log: Vec<LogEntry>,
}
/// Commit/apply progress of a node; both counters start at 0.
#[derive(Debug, Clone)]
struct VolatileState {
/// Highest log index known to be committed; updated from
/// `leader_commit` in `handle_append_entries`.
commit_index: u64,
/// Highest log index handed to `apply_committed_entries`.
last_applied: u64,
}
/// Per-peer replication bookkeeping, created when a node becomes leader.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct LeaderState {
/// For each peer id, the index of the next log entry to send to it.
/// `become_leader` initialises it to the leader's log length + 1.
next_index: HashMap<String, u64>,
/// For each peer id, the highest index recorded as replicated on it.
/// `become_leader` initialises it to 0; nothing in this file updates it.
match_index: HashMap<String, u64>,
}
/// Envelope for every message exchanged between nodes.
///
/// `start_rpc_handler` dispatches the three request variants and ignores
/// the three response variants.
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub enum RpcMessage {
/// Log replication / heartbeat request.
AppendEntries(AppendEntriesRequest),
/// Reply to [`RpcMessage::AppendEntries`].
AppendEntriesResponse(AppendEntriesResponse),
/// Vote solicitation from a candidate.
RequestVote(RequestVoteRequest),
/// Reply to [`RpcMessage::RequestVote`].
RequestVoteResponse(RequestVoteResponse),
/// Snapshot transfer request.
InstallSnapshot(InstallSnapshotRequest),
/// Reply to [`RpcMessage::InstallSnapshot`].
InstallSnapshotResponse(InstallSnapshotResponse),
}
/// Payload of an `AppendEntries` RPC (also used as heartbeat with empty
/// `entries`, see `send_heartbeats`).
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct AppendEntriesRequest {
/// Sender's current term.
pub term: u64,
/// Id of the sender; the receiver records it as `current_leader`.
pub leader_id: String,
/// 1-based index of the entry immediately preceding `entries`; 0 means
/// "no preceding entry" and skips the consistency check.
pub prev_log_index: u64,
/// Term of the entry at `prev_log_index` (0 when there is none).
pub prev_log_term: u64,
/// Entries to store after `prev_log_index`; empty for a heartbeat.
pub entries: Vec<LogEntry>,
/// Sender's commit index.
pub leader_commit: u64,
}
/// Reply to an [`AppendEntriesRequest`].
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct AppendEntriesResponse {
/// Responder's term: its own term when it rejects a stale request,
/// otherwise the request's term.
pub term: u64,
/// `true` when the request passed the term and log-consistency checks.
pub success: bool,
/// Length of the responder's log at the time of the reply.
pub last_log_index: u64,
}
/// Payload of a `RequestVote` RPC.
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct RequestVoteRequest {
/// Candidate's term.
pub term: u64,
/// Id of the node asking for the vote.
pub candidate_id: String,
/// Length of the candidate's log (its last 1-based index).
pub last_log_index: u64,
/// Term of the candidate's last log entry, 0 for an empty log.
pub last_log_term: u64,
}
/// Reply to a [`RequestVoteRequest`].
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct RequestVoteResponse {
/// Responder's term: its own term when it rejects a stale request,
/// otherwise the request's term.
pub term: u64,
/// Whether the responder granted its vote to the candidate.
pub vote_granted: bool,
}
/// Payload of an `InstallSnapshot` RPC.
///
/// Only `term` is read in this file (`handle_install_snapshot`); the
/// remaining field descriptions follow the field names.
/// NOTE(review): confirm them against the code that builds and consumes
/// snapshots.
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct InstallSnapshotRequest {
/// Sender's current term.
pub term: u64,
/// Id of the sender.
pub leader_id: String,
/// Presumably the last log index covered by the snapshot.
pub last_included_index: u64,
/// Presumably the term of the entry at `last_included_index`.
pub last_included_term: u64,
/// Presumably the byte offset of `data` within the snapshot.
pub offset: u64,
/// Snapshot bytes carried by this message.
pub data: Vec<u8>,
/// Presumably `true` on the final chunk.
pub done: bool,
}
/// Reply to an [`InstallSnapshotRequest`].
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct InstallSnapshotResponse {
/// Responder's term when the request is stale, otherwise the request's term.
pub term: u64,
}
/// A single Raft participant.
///
/// All mutable state sits behind `Arc<RwLock<..>>` so that the background
/// tasks spawned by `start` can share it through `clone_refs`.
#[allow(dead_code)]
pub struct RaftNode {
/// This node's id; compared against `ClusterConfig::members`.
id: String,
/// Timing configuration (`election_timeout`, `heartbeat_interval`).
config: RaftConfig,
/// Current role.
state: Arc<RwLock<RaftState>>,
/// Term, vote and log.
persistent: Arc<RwLock<PersistentState>>,
/// Commit / apply counters.
volatile: Arc<RwLock<VolatileState>>,
/// Replication bookkeeping; populated by `bootstrap` and `become_leader`.
leader_state: Arc<RwLock<Option<LeaderState>>>,
/// Id of the node currently believed to be leader, if any.
current_leader: Arc<RwLock<Option<String>>>,
/// Membership used for elections and heartbeats.
cluster_config: Arc<RwLock<ClusterConfig>>,
/// Sending half of the inbound RPC channel (capacity 1000). Not used
/// inside this file — presumably handed to a transport layer.
rpc_tx: mpsc::Sender<(String, RpcMessage)>,
/// Receiving half of the inbound RPC channel; drained by `start_rpc_handler`.
rpc_rx: Arc<Mutex<mpsc::Receiver<(String, RpcMessage)>>>,
/// Instant of the last event that reset the election timeout.
election_timer: Arc<RwLock<Instant>>,
/// Backing store; not touched anywhere in this file.
store: Arc<Store>,
}
#[allow(dead_code)]
impl RaftNode {
/// Creates a node in the `Follower` role with term 0, an empty log, an
/// empty membership and a freshly started election timer.
///
/// The inbound RPC channel is bounded to 1000 queued messages. No
/// background task is spawned here; call `start` for that.
///
/// # Errors
///
/// Never fails in the current implementation; the `FusekiResult` return
/// type is kept for the existing callers.
pub async fn new(id: String, config: RaftConfig, store: Arc<Store>) -> FusekiResult<Self> {
    let (rpc_tx, rpc_rx) = mpsc::channel(1000);

    let initial_persistent = PersistentState {
        current_term: 0,
        voted_for: None,
        log: Vec::new(),
    };
    let initial_volatile = VolatileState {
        commit_index: 0,
        last_applied: 0,
    };
    let empty_membership = ClusterConfig {
        members: Vec::new(),
        new_members: None,
    };

    let node = Self {
        id,
        config,
        state: Arc::new(RwLock::new(RaftState::Follower)),
        persistent: Arc::new(RwLock::new(initial_persistent)),
        volatile: Arc::new(RwLock::new(initial_volatile)),
        leader_state: Arc::new(RwLock::new(None)),
        current_leader: Arc::new(RwLock::new(None)),
        cluster_config: Arc::new(RwLock::new(empty_membership)),
        rpc_tx,
        rpc_rx: Arc::new(Mutex::new(rpc_rx)),
        election_timer: Arc::new(RwLock::new(Instant::now())),
        store,
    };
    Ok(node)
}
/// Spawns the node's four background tasks: RPC dispatch, election
/// timeout, leader heartbeat and committed-entry application.
///
/// Each `start_*` helper only spawns a detached tokio task and returns
/// immediately, so this function does not block on the tasks themselves.
///
/// # Errors
///
/// Never fails in the current implementation.
pub async fn start(&self) -> FusekiResult<()> {
    // Inbound RPCs first, then the three periodic timers.
    self.start_rpc_handler().await;

    self.start_election_timer().await;
    self.start_heartbeat_timer().await;
    self.start_log_applier().await;

    Ok(())
}
/// Turns this node into the leader of a brand-new single-member cluster.
///
/// Sets the membership to just this node, switches the role to `Leader`,
/// records itself as current leader, installs empty replication
/// bookkeeping and appends a `NoOp` entry to the log.
///
/// # Errors
///
/// Propagates any error from `append_log_entry`.
pub async fn bootstrap(&self) -> FusekiResult<()> {
    // Release the membership write lock before touching any other lock.
    // Holding it while awaiting `leader_state.write()` is the reverse of
    // the order used by `send_heartbeats` (leader_state -> cluster_config)
    // and could deadlock the two.
    {
        let mut config = self.cluster_config.write().await;
        config.members = vec![self.id.clone()];
    }

    *self.state.write().await = RaftState::Leader;
    *self.current_leader.write().await = Some(self.id.clone());
    // A single-node cluster has no peers to track.
    *self.leader_state.write().await = Some(LeaderState {
        next_index: HashMap::new(),
        match_index: HashMap::new(),
    });

    self.append_log_entry(Command::NoOp).await?;
    Ok(())
}
/// Spawns the task that drains the inbound RPC channel.
///
/// The task takes the receiver mutex once and keeps it for its whole
/// lifetime. Each request variant is handed to the matching handler on
/// the `RaftNodeRefs` snapshot and the handler's reply is sent back to
/// the originator; send failures are ignored. Response variants are
/// dropped. The task ends when the channel is closed.
async fn start_rpc_handler(&self) {
    let refs = self.clone_refs();
    let receiver = self.rpc_rx.clone();

    tokio::spawn(async move {
        let mut inbox = receiver.lock().await;
        while let Some((sender, message)) = inbox.recv().await {
            let reply = match message {
                RpcMessage::AppendEntries(req) => {
                    RpcMessage::AppendEntriesResponse(refs.handle_append_entries(req).await)
                }
                RpcMessage::RequestVote(req) => {
                    RpcMessage::RequestVoteResponse(refs.handle_request_vote(req).await)
                }
                RpcMessage::InstallSnapshot(req) => {
                    RpcMessage::InstallSnapshotResponse(
                        refs.handle_install_snapshot(req).await,
                    )
                }
                // Responses are not processed by this loop.
                RpcMessage::AppendEntriesResponse(_)
                | RpcMessage::RequestVoteResponse(_)
                | RpcMessage::InstallSnapshotResponse(_) => continue,
            };
            // Best effort: a failed reply is not retried.
            let _ = refs.send_rpc(&sender, reply).await;
        }
    });
}
async fn start_election_timer(&self) {
let node = self.clone_refs();
let config = self.config.clone();
tokio::spawn(async move {
let mut interval = interval(Duration::from_millis(50));
let mut rng = Random::seed(42);
loop {
interval.tick().await;
let state = *node.state.read().await;
if state != RaftState::Leader {
let last_heartbeat = *node.election_timer.read().await;
let timeout =
rng.gen_range(config.election_timeout.0..config.election_timeout.1);
if last_heartbeat.elapsed() > timeout {
node.start_election().await;
}
}
}
});
}
/// Spawns the task that, once per `config.heartbeat_interval`, calls
/// `send_heartbeats` on the `RaftNodeRefs` snapshot while the node's
/// role is `Leader`. Ticks in any other role do nothing.
async fn start_heartbeat_timer(&self) {
    let period = self.config.heartbeat_interval;
    let refs = self.clone_refs();

    tokio::spawn(async move {
        let mut ticker = interval(period);
        loop {
            ticker.tick().await;
            // The read guard is released before heartbeats are sent.
            let is_leader = matches!(*refs.state.read().await, RaftState::Leader);
            if is_leader {
                refs.send_heartbeats().await;
            }
        }
    });
}
/// Spawns the task that, every 100 ms, compares `commit_index` with
/// `last_applied` and asks the `RaftNodeRefs` snapshot to apply the
/// entries in `last_applied + 1 ..= commit_index` when there are any.
async fn start_log_applier(&self) {
    let refs = self.clone_refs();

    tokio::spawn(async move {
        let mut ticker = interval(Duration::from_millis(100));
        loop {
            ticker.tick().await;

            // Copy both counters out so the lock is not held while applying.
            let (last_applied, commit_index) = {
                let progress = refs.volatile.read().await;
                (progress.last_applied, progress.commit_index)
            };

            if last_applied < commit_index {
                refs.apply_committed_entries(last_applied + 1, commit_index)
                    .await;
            }
        }
    });
}
/// Starts a new election with this node as candidate.
///
/// Increments the term, votes for itself, switches to `Candidate`,
/// resets the election timer and sends a `RequestVote` to every other
/// member. If the tally reaches a majority of the membership and the
/// node is still a candidate in the same term, it calls `become_leader`.
///
/// The node only counts its own vote when it is listed in the current
/// membership, so a node with an empty or foreign configuration cannot
/// elect itself.
async fn start_election(&self) {
    tracing::info!("Node {} starting election", self.id);

    let (current_term, last_log_index, last_log_term) = {
        let mut persistent = self.persistent.write().await;
        persistent.current_term += 1;
        persistent.voted_for = Some(self.id.clone());
        (
            persistent.current_term,
            persistent.log.len() as u64,
            persistent.log.last().map(|e| e.term).unwrap_or(0),
        )
    };
    *self.state.write().await = RaftState::Candidate;
    *self.election_timer.write().await = Instant::now();

    // Work on a copy of the membership so that no `cluster_config` guard
    // is alive while awaiting RPCs or `become_leader`. `become_leader`
    // read-locks `cluster_config` again, and a second read on tokio's
    // write-preferring RwLock deadlocks if a writer queued in between.
    let members = self.cluster_config.read().await.members.clone();
    let majority = members.len() / 2 + 1;
    let mut votes = usize::from(members.contains(&self.id));

    for member in members.iter().filter(|m| m.as_str() != self.id.as_str()) {
        let req = RequestVoteRequest {
            term: current_term,
            candidate_id: self.id.clone(),
            last_log_index,
            last_log_term,
        };
        // NOTE(review): a successful *send* is counted as a granted vote
        // because no RequestVoteResponse path exists yet (responses are
        // dropped by the RPC handler). This must be replaced by counting
        // `vote_granted` replies once responses are routed back.
        if self
            .send_rpc(member, RpcMessage::RequestVote(req))
            .await
            .is_ok()
        {
            votes += 1;
        }
    }

    // The node may have stepped down or moved to a newer term while the
    // requests were in flight; only a candidate of *this* term may win.
    let still_candidate = *self.state.read().await == RaftState::Candidate;
    let same_term = self.persistent.read().await.current_term == current_term;
    if votes >= majority && still_candidate && same_term {
        self.become_leader().await;
    }
}
/// Promotes this node to leader.
///
/// Sets the role to `Leader`, records itself as current leader,
/// initialises `next_index` to the current log length + 1 and
/// `match_index` to 0 for every other member, and appends a `NoOp`
/// entry for the new term. A failure to append the no-op is logged and
/// otherwise ignored.
async fn become_leader(&self) {
    tracing::info!("Node {} became leader", self.id);
    *self.state.write().await = RaftState::Leader;
    *self.current_leader.write().await = Some(self.id.clone());

    // Snapshot the peer list and drop the membership guard straight away.
    // Keeping it while awaiting `leader_state.write()` is the reverse of
    // the order used by `send_heartbeats` (leader_state -> cluster_config).
    let peers: Vec<String> = {
        let config = self.cluster_config.read().await;
        config
            .members
            .iter()
            .filter(|member| member.as_str() != self.id.as_str())
            .cloned()
            .collect()
    };
    let log_length = self.persistent.read().await.log.len() as u64;

    let next_index: HashMap<String, u64> = peers
        .iter()
        .map(|peer| (peer.clone(), log_length + 1))
        .collect();
    let match_index: HashMap<String, u64> =
        peers.into_iter().map(|peer| (peer, 0)).collect();

    *self.leader_state.write().await = Some(LeaderState {
        next_index,
        match_index,
    });

    if self.append_log_entry(Command::NoOp).await.is_err() {
        tracing::warn!("Node {} failed to append leader no-op entry", self.id);
    }
}
async fn send_heartbeats(&self) {
let leader_state = self.leader_state.read().await;
if let Some(state) = leader_state.as_ref() {
let config = self.cluster_config.read().await;
let persistent = self.persistent.read().await;
let volatile = self.volatile.read().await;
for member in &config.members {
if member != &self.id {
let next_idx = state.next_index.get(member).copied().unwrap_or(1);
let prev_idx = next_idx.saturating_sub(1);
let prev_term = if prev_idx > 0 {
persistent
.log
.get(prev_idx as usize - 1)
.map(|e| e.term)
.unwrap_or(0)
} else {
0
};
let req = AppendEntriesRequest {
term: persistent.current_term,
leader_id: self.id.clone(),
prev_log_index: prev_idx,
prev_log_term: prev_term,
entries: vec![],
leader_commit: volatile.commit_index,
};
let _ = self.send_rpc(member, RpcMessage::AppendEntries(req)).await;
}
}
}
}
/// Handles an incoming `AppendEntries` request (including heartbeats).
///
/// 1. Rejects requests from an older term, replying with the local term.
/// 2. Adopts a newer term (clearing `voted_for`), resets the election
///    timer, becomes `Follower` and records the sender as leader.
/// 3. Rejects the request when `prev_log_index > 0` and the local log has
///    no entry there or its term differs from `prev_log_term`.
/// 4. Merges `entries` into the log: existing entries whose term matches
///    are kept; at the first missing or conflicting position the log is
///    truncated and the remaining new entries are appended. A duplicate
///    or delayed request therefore never removes entries it does not
///    conflict with.
/// 5. Advances `commit_index` to `min(leader_commit, index of the last
///    entry covered by this request)`, never moving it backwards.
///
/// `last_log_index` in every reply is the local log length.
async fn handle_append_entries(&self, req: AppendEntriesRequest) -> AppendEntriesResponse {
    let mut persistent = self.persistent.write().await;
    let current_term = persistent.current_term;

    if req.term < current_term {
        return AppendEntriesResponse {
            term: current_term,
            success: false,
            last_log_index: persistent.log.len() as u64,
        };
    }
    if req.term > current_term {
        persistent.current_term = req.term;
        persistent.voted_for = None;
    }

    *self.election_timer.write().await = Instant::now();
    *self.state.write().await = RaftState::Follower;
    *self.current_leader.write().await = Some(req.leader_id.clone());

    // Log-consistency check on the entry that precedes the new ones.
    if req.prev_log_index > 0 {
        let prev_matches = persistent
            .log
            .get(req.prev_log_index as usize - 1)
            .map_or(false, |entry| entry.term == req.prev_log_term);
        if !prev_matches {
            return AppendEntriesResponse {
                term: req.term,
                success: false,
                last_log_index: persistent.log.len() as u64,
            };
        }
    }

    // `entries[k]` has Raft index `prev_log_index + 1 + k`, i.e. it
    // belongs at `log[prev_log_index + k]`.
    let base = req.prev_log_index as usize;
    let last_new_index = req.prev_log_index + req.entries.len() as u64;

    // First incoming entry that is absent locally or has a different term.
    let first_divergence = req.entries.iter().enumerate().position(|(offset, incoming)| {
        persistent
            .log
            .get(base + offset)
            .map_or(true, |existing| existing.term != incoming.term)
    });
    if let Some(offset) = first_divergence {
        persistent.log.truncate(base + offset);
        persistent.log.extend(req.entries.into_iter().skip(offset));
    }

    {
        // Only entries up to `last_new_index` are known to match the
        // leader, so the commit index must not run past it.
        let mut volatile = self.volatile.write().await;
        let verified_commit = req.leader_commit.min(last_new_index);
        if verified_commit > volatile.commit_index {
            volatile.commit_index = verified_commit;
        }
    }

    AppendEntriesResponse {
        term: req.term,
        success: true,
        last_log_index: persistent.log.len() as u64,
    }
}
/// Handles an incoming `RequestVote` request.
///
/// A request from an older term is refused with the local term. A newer
/// term is adopted first (clearing `voted_for` and stepping down to
/// `Follower`). The vote is then granted when this node has not voted
/// for a different candidate in the term and the candidate's log passes
/// `is_log_up_to_date`. Granting a vote records the candidate and resets
/// the election timer.
async fn handle_request_vote(&self, req: RequestVoteRequest) -> RequestVoteResponse {
    let mut persistent = self.persistent.write().await;
    let our_term = persistent.current_term;

    if req.term < our_term {
        return RequestVoteResponse {
            term: our_term,
            vote_granted: false,
        };
    }

    if req.term > our_term {
        persistent.current_term = req.term;
        persistent.voted_for = None;
        *self.state.write().await = RaftState::Follower;
    }

    // Free to vote if no vote was cast yet, or it already went to this candidate.
    let vote_available = match persistent.voted_for.as_deref() {
        None => true,
        Some(previous) => previous == req.candidate_id.as_str(),
    };
    let candidate_log_ok =
        self.is_log_up_to_date(&persistent, req.last_log_index, req.last_log_term);

    let vote_granted = vote_available && candidate_log_ok;
    if vote_granted {
        persistent.voted_for = Some(req.candidate_id);
        *self.election_timer.write().await = Instant::now();
    }

    RequestVoteResponse {
        term: req.term,
        vote_granted,
    }
}
/// Handles an incoming `InstallSnapshot` request.
///
/// Currently a placeholder: no snapshot data is stored and no local
/// state is changed. The reply carries the local term when the request
/// is from an older term, and the request's term otherwise.
async fn handle_install_snapshot(
    &self,
    req: InstallSnapshotRequest,
) -> InstallSnapshotResponse {
    let our_term = self.persistent.read().await.current_term;
    // Whichever of the two terms is newer is echoed back.
    let term = if req.term < our_term { our_term } else { req.term };
    InstallSnapshotResponse { term }
}
/// Returns `true` when a candidate's log, described by its last index
/// and last term, is at least as up to date as the local log.
///
/// The candidate wins on a strictly higher last term; on equal terms it
/// must have an equal or higher last index. An empty local log counts as
/// term 0, index 0.
fn is_log_up_to_date(
    &self,
    persistent: &PersistentState,
    last_log_index: u64,
    last_log_term: u64,
) -> bool {
    let ours = (
        persistent.log.last().map_or(0, |entry| entry.term),
        persistent.log.len() as u64,
    );
    let theirs = (last_log_term, last_log_index);
    // Tuples compare lexicographically: term first, then index.
    theirs >= ours
}
/// Appends `command` to the local log and returns the new entry's
/// 1-based index.
///
/// The entry is stamped with the current term and no client id.
///
/// # Errors
///
/// Never fails in the current implementation.
async fn append_log_entry(&self, command: Command) -> FusekiResult<u64> {
    let mut persistent = self.persistent.write().await;

    let entry = LogEntry {
        index: persistent.log.len() as u64 + 1,
        term: persistent.current_term,
        command,
        client_id: None,
    };
    let assigned_index = entry.index;
    persistent.log.push(entry);

    Ok(assigned_index)
}
/// Applies the log entries with indices `start ..= end` and then records
/// `end` as `last_applied`.
///
/// "Applying" currently only emits a debug log line per entry (nothing
/// for `NoOp`). Indices with no entry in the local log are skipped.
/// `start` is expected to be at least 1, since entry `i` is read from
/// `log[i - 1]`.
async fn apply_committed_entries(&self, start: u64, end: u64) {
    let persistent = self.persistent.read().await;

    let committed = (start..=end).filter_map(|index| persistent.log.get(index as usize - 1));
    for entry in committed {
        match &entry.command {
            Command::Set { key, .. } => tracing::debug!("Applied Set({}, ...)", key),
            Command::Delete { key } => tracing::debug!("Applied Delete({})", key),
            Command::ConfigChange { .. } => tracing::debug!("Applied ConfigChange"),
            Command::NoOp => {}
        }
    }

    // The log read guard is still held here, as in the apply loop above.
    self.volatile.write().await.last_applied = end;
}
/// Sends `message` to the node identified by `target`.
///
/// Placeholder transport: the message is only written to the debug log
/// and the call always succeeds.
async fn send_rpc(&self, target: &str, message: RpcMessage) -> FusekiResult<()> {
    // No network layer is wired in here; log and report success.
    tracing::debug!("Sending {:?} to {}", message, target);

    Ok(())
}
fn clone_refs(&self) -> RaftNodeRefs {
RaftNodeRefs {
id: self.id.clone(),
config: self.config.clone(),
state: self.state.clone(),
persistent: self.persistent.clone(),
volatile: self.volatile.clone(),
leader_state: self.leader_state.clone(),
current_leader: self.current_leader.clone(),
cluster_config: self.cluster_config.clone(),
election_timer: self.election_timer.clone(),
}
}
}
/// Shareable view of a [`RaftNode`] handed to the spawned background
/// tasks; built by `RaftNode::clone_refs`.
///
/// Every `Arc` points at the same lock-protected state as the node it
/// was cloned from. The RPC channel halves and the store are not part
/// of this view.
struct RaftNodeRefs {
/// Copy of the owning node's id.
id: String,
/// Copy of the owning node's timing configuration.
config: RaftConfig,
/// Shared current role.
state: Arc<RwLock<RaftState>>,
/// Shared term, vote and log.
persistent: Arc<RwLock<PersistentState>>,
/// Shared commit / apply counters.
volatile: Arc<RwLock<VolatileState>>,
/// Shared replication bookkeeping.
leader_state: Arc<RwLock<Option<LeaderState>>>,
/// Shared id of the believed leader.
current_leader: Arc<RwLock<Option<String>>>,
/// Shared membership.
cluster_config: Arc<RwLock<ClusterConfig>>,
/// Shared instant of the last election-timer reset.
election_timer: Arc<RwLock<Instant>>,
}
/// Placeholder implementations used by the background tasks.
///
/// None of these methods reads or changes any state: the actions are
/// empty and the handlers return fixed "rejected" replies with term 0.
/// The working logic lives in the same-named methods of `RaftNode`.
impl RaftNodeRefs {
    /// Placeholder; does nothing.
    async fn start_election(&self) {}

    /// Placeholder; does nothing.
    async fn become_leader(&self) {}

    /// Placeholder; does nothing.
    async fn send_heartbeats(&self) {}

    /// Placeholder; ignores the request and always answers term 0,
    /// `success = false`, `last_log_index = 0`.
    async fn handle_append_entries(&self, _req: AppendEntriesRequest) -> AppendEntriesResponse {
        let (term, last_log_index) = (0, 0);
        AppendEntriesResponse {
            term,
            success: false,
            last_log_index,
        }
    }

    /// Placeholder; ignores the request and always refuses the vote with term 0.
    async fn handle_request_vote(&self, _req: RequestVoteRequest) -> RequestVoteResponse {
        let term = 0;
        RequestVoteResponse {
            term,
            vote_granted: false,
        }
    }

    /// Placeholder; ignores the request and always answers term 0.
    async fn handle_install_snapshot(
        &self,
        _req: InstallSnapshotRequest,
    ) -> InstallSnapshotResponse {
        let term = 0;
        InstallSnapshotResponse { term }
    }

    /// Placeholder; applies nothing.
    async fn apply_committed_entries(&self, _start: u64, _end: u64) {}

    /// Placeholder; appends nothing and reports index 0.
    async fn append_log_entry(&self, _command: Command) -> FusekiResult<u64> {
        let index = 0;
        Ok(index)
    }

    /// Placeholder; sends nothing and reports success.
    async fn send_rpc(&self, _target: &str, _message: RpcMessage) -> FusekiResult<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `LogEntry` survives a JSON round trip with its index and term intact.
    #[test]
    fn test_log_entry_serialization() {
        let command = Command::Set {
            key: "test".to_string(),
            value: vec![1, 2, 3],
        };
        let original = LogEntry {
            index: 1,
            term: 1,
            command,
            client_id: Some("client1".to_string()),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let round_tripped: LogEntry = serde_json::from_str(&encoded).unwrap();

        assert_eq!(round_tripped.index, original.index);
        assert_eq!(round_tripped.term, original.term);
    }

    /// The three roles are pairwise distinct under `PartialEq`.
    #[test]
    fn test_raft_state_transitions() {
        let pairs = [
            (RaftState::Follower, RaftState::Candidate),
            (RaftState::Candidate, RaftState::Leader),
            (RaftState::Leader, RaftState::Follower),
        ];
        for (left, right) in pairs {
            assert_ne!(left, right);
        }
    }
}