use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use tokio::sync::{Mutex, RwLock};
use tokio::time;
/// Identifier of a cluster member.
pub type NodeId = String;
/// Raft election term number.
pub type Term = u64;
/// Position of an entry in the replicated log; the first entry is 1 and 0 means "no entry".
pub type LogIndex = u64;
/// Role a Raft node currently plays in the cluster.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum NodeState {
    /// Passive member that accepts entries from a leader and votes in elections.
    Follower,
    /// Node that has started an election and is collecting votes.
    Candidate,
    /// Node that won an election and sends heartbeats to its peers.
    Leader,
}
/// A single command stored in the replicated log.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LogEntry {
    /// Term of the leader that created this entry.
    pub term: Term,
    /// Position of this entry in the log (the first entry is 1).
    pub index: LogIndex,
    /// State-machine command carried by the entry.
    pub command: Command,
    /// Wall-clock time at which the entry was created.
    pub timestamp: SystemTime,
}
/// State-machine operation carried by a log entry.
///
/// Variant order is left untouched on purpose: some serde formats encode the
/// variant index, so reordering would change the wire format.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum Command {
    /// Insert the vector `vector` under `id` together with JSON `metadata`.
    Insert {
        id: String,
        vector: Vec<f32>,
        metadata: serde_json::Value,
    },
    /// Delete whatever is stored under `id`.
    Delete { id: String },
    /// Update the vector and JSON metadata stored under `id`.
    Update {
        id: String,
        vector: Vec<f32>,
        metadata: serde_json::Value,
    },
    /// Carries no operation; a newly elected leader appends one at the start of its term.
    NoOp,
}
/// Arguments of the `RequestVote` RPC sent by a candidate.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RequestVoteRequest {
    /// The candidate's term.
    pub term: Term,
    /// Id of the candidate asking for the vote.
    pub candidate_id: NodeId,
    /// Index of the candidate's last log entry (0 if its log is empty).
    pub last_log_index: LogIndex,
    /// Term of the candidate's last log entry (0 if its log is empty).
    pub last_log_term: Term,
}
/// Reply to a `RequestVote` RPC.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RequestVoteResponse {
    /// The responder's current term, so a stale candidate can update itself.
    pub term: Term,
    /// Whether the responder voted for the candidate.
    pub vote_granted: bool,
}
/// Arguments of the `AppendEntries` RPC; also used as a heartbeat when `entries` is empty.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AppendEntriesRequest {
    /// The leader's term.
    pub term: Term,
    /// Id of the sending leader.
    pub leader_id: NodeId,
    /// Index of the entry immediately preceding `entries` (0 when starting at the log head).
    pub prev_log_index: LogIndex,
    /// Term of the entry at `prev_log_index`.
    pub prev_log_term: Term,
    /// Entries to append; empty for a pure heartbeat.
    pub entries: Vec<LogEntry>,
    /// The leader's commit index.
    pub leader_commit: LogIndex,
}
/// Reply to an `AppendEntries` RPC.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AppendEntriesResponse {
    /// The responder's current term, so a stale leader can step down.
    pub term: Term,
    /// Whether the follower's log matched `prev_log_index`/`prev_log_term` and was updated.
    pub success: bool,
    /// On a log mismatch, a hint about where the follower's log diverges; `None` otherwise.
    pub conflict_index: Option<LogIndex>,
    /// On a term mismatch, the term of the follower's conflicting entry; `None` otherwise.
    pub conflict_term: Option<Term>,
}
/// Static configuration of a single Raft node.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RaftConfig {
    /// This node's id.
    pub node_id: NodeId,
    /// Ids of the other cluster members.
    pub peers: Vec<NodeId>,
    /// Lower bound (milliseconds) of the randomized election timeout.
    pub election_timeout_min_ms: u64,
    /// Upper bound (milliseconds, inclusive) of the randomized election timeout.
    pub election_timeout_max_ms: u64,
    /// Interval (milliseconds) between leader heartbeats.
    pub heartbeat_interval_ms: u64,
    /// Maximum number of log entries per replication batch.
    pub max_entries_per_batch: usize,
}
impl Default for RaftConfig {
fn default() -> Self {
Self {
node_id: "node-0".to_string(),
peers: vec![],
election_timeout_min_ms: 150,
election_timeout_max_ms: 300,
heartbeat_interval_ms: 50,
max_entries_per_batch: 100,
}
}
}
/// The part of Raft state the protocol requires to be durable.
///
/// NOTE(review): nothing here writes it to stable storage; it lives only in memory.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct PersistentState {
    /// Latest term this node has seen.
    current_term: Term,
    /// Candidate voted for in `current_term`, if any.
    voted_for: Option<NodeId>,
    /// The log; the entry at position `i` has index `i + 1`.
    log: Vec<LogEntry>,
}
impl Default for PersistentState {
fn default() -> Self {
Self {
current_term: 0,
voted_for: None,
log: Vec::new(),
}
}
}
/// In-memory progress counters kept by every node.
#[derive(Clone, Debug)]
struct VolatileState {
    /// Highest log index known to be committed.
    commit_index: LogIndex,
    /// Highest log index handed out by `get_entries_to_apply`.
    last_applied: LogIndex,
}
impl Default for VolatileState {
fn default() -> Self {
Self {
commit_index: 0,
last_applied: 0,
}
}
}
/// Per-peer replication bookkeeping, present only while this node is leader.
#[derive(Clone, Debug)]
struct LeaderState {
    /// For each peer, the index of the next log entry to send to it.
    next_index: HashMap<NodeId, LogIndex>,
    /// For each peer, the highest index known to be replicated on it.
    match_index: HashMap<NodeId, LogIndex>,
}
/// A single Raft participant. All mutable state sits behind async locks so the
/// background tasks and RPC handlers can share the node through an `Arc`.
pub struct RaftNode {
    /// Static configuration supplied at construction.
    config: RaftConfig,
    /// Current role (follower / candidate / leader).
    state: Arc<RwLock<NodeState>>,
    /// Term, vote and log.
    persistent: Arc<RwLock<PersistentState>>,
    /// Commit / apply progress.
    volatile: Arc<Mutex<VolatileState>>,
    /// Replication bookkeeping; set when the node becomes leader.
    leader_state: Arc<Mutex<Option<LeaderState>>>,
    /// Last time a leader contacted us or we granted a vote; drives the election timer.
    last_heartbeat: Arc<Mutex<Instant>>,
    /// Ids of the other cluster members.
    peers: Arc<RwLock<HashSet<NodeId>>>,
}
/// Lock ordering used throughout this impl (acquire left before right, never the
/// reverse) to stay deadlock-free:
/// `persistent` -> `state` / `volatile` -> `leader_state`; `peers` and
/// `last_heartbeat` are only taken briefly and never held while waiting on the others.
impl RaftNode {
    /// Creates a node that starts as a `Follower` at term 0 with an empty log.
    ///
    /// The peer set is built from `config.peers`. Duplicates are collapsed, and the
    /// node's own id is dropped so it is never counted twice towards the election quorum.
    pub fn new(config: RaftConfig) -> Self {
        let peers: HashSet<NodeId> = config
            .peers
            .iter()
            .filter(|peer| **peer != config.node_id)
            .cloned()
            .collect();
        Self {
            config,
            state: Arc::new(RwLock::new(NodeState::Follower)),
            persistent: Arc::new(RwLock::new(PersistentState::default())),
            volatile: Arc::new(Mutex::new(VolatileState::default())),
            leader_state: Arc::new(Mutex::new(None)),
            last_heartbeat: Arc::new(Mutex::new(Instant::now())),
            peers: Arc::new(RwLock::new(peers)),
        }
    }

    /// Returns the current term.
    pub async fn current_term(&self) -> Term {
        self.persistent.read().await.current_term
    }

    /// Returns a snapshot of the node's current role.
    pub async fn state(&self) -> NodeState {
        self.state.read().await.clone()
    }

    /// Returns `true` if this node currently believes it is the leader.
    pub async fn is_leader(&self) -> bool {
        matches!(*self.state.read().await, NodeState::Leader)
    }

    /// Returns this node's id while it is the leader, otherwise `None`.
    ///
    /// NOTE: followers do not record the `leader_id` carried by `AppendEntries`, so a
    /// follower cannot report who the leader is and always gets `None`.
    pub async fn leader_id(&self) -> Option<NodeId> {
        if self.is_leader().await {
            Some(self.config.node_id.clone())
        } else {
            None
        }
    }

    /// Spawns the election-timer and heartbeat tasks on the Tokio runtime.
    ///
    /// Both tasks loop forever and are detached (their join handles are not kept).
    pub async fn start(self: Arc<Self>) {
        let node = Arc::clone(&self);
        tokio::spawn(async move {
            node.election_timer_loop().await;
        });
        tokio::spawn(async move {
            self.heartbeat_loop().await;
        });
    }

    /// Starts an election whenever a full randomized timeout passes without contact
    /// from a leader (or a granted vote) while this node is not the leader.
    async fn election_timer_loop(self: Arc<Self>) {
        loop {
            let timeout = self.random_election_timeout();
            time::sleep(timeout).await;
            let last_heartbeat = *self.last_heartbeat.lock().await;
            if last_heartbeat.elapsed() >= timeout && !self.is_leader().await {
                self.start_election().await;
            }
        }
    }

    /// Every `heartbeat_interval_ms`, sends heartbeats if this node is the leader.
    async fn heartbeat_loop(self: Arc<Self>) {
        let interval = Duration::from_millis(self.config.heartbeat_interval_ms);
        loop {
            time::sleep(interval).await;
            if self.is_leader().await {
                self.send_heartbeats().await;
            }
        }
    }

    /// Begins a new election: increments the term, votes for itself and becomes `Candidate`.
    ///
    /// NOTE: no RPC transport is wired up, so only the self-vote is counted. The node
    /// wins only when it alone is a majority (a single-node cluster).
    pub async fn start_election(&self) {
        let (election_term, last_log_index, last_log_term) = {
            let mut persistent = self.persistent.write().await;
            persistent.current_term += 1;
            persistent.voted_for = Some(self.config.node_id.clone());
            // Change role while still holding the log lock so no RPC handler can
            // observe the new term paired with the old role.
            *self.state.write().await = NodeState::Candidate;
            let (last_index, last_term) = persistent
                .log
                .last()
                .map(|e| (e.index, e.term))
                .unwrap_or((0, 0));
            (persistent.current_term, last_index, last_term)
        };
        // Give this election a full timeout before the timer loop starts another.
        *self.last_heartbeat.lock().await = Instant::now();

        let _request = RequestVoteRequest {
            term: election_term,
            candidate_id: self.config.node_id.clone(),
            last_log_index,
            last_log_term,
        };
        // TODO: send `_request` to every peer and count the granted votes.

        let cluster_size = self.peers.read().await.len() + 1;
        let votes_needed = cluster_size / 2 + 1;
        let votes_received = 1; // our own vote
        if votes_received >= votes_needed {
            self.become_leader(election_term).await;
        }
    }

    /// Becomes `Leader` for `election_term`, initialises per-peer replication state
    /// and appends a `NoOp` entry for the new term.
    ///
    /// Does nothing if the node is no longer a `Candidate` in `election_term`, for
    /// example because a higher-term RPC arrived after the election started.
    async fn become_leader(&self, election_term: Term) {
        let last_log_index = {
            let persistent = self.persistent.read().await;
            let mut state = self.state.write().await;
            if persistent.current_term != election_term || *state != NodeState::Candidate {
                return;
            }
            *state = NodeState::Leader;
            persistent.log.last().map(|e| e.index).unwrap_or(0)
        };

        let peers = self.peers.read().await.clone();
        let next_index = peers
            .iter()
            .map(|peer| (peer.clone(), last_log_index + 1))
            .collect();
        let match_index = peers.iter().map(|peer| (peer.clone(), 0)).collect();
        *self.leader_state.lock().await = Some(LeaderState {
            next_index,
            match_index,
        });

        // The NoOp lets a new leader commit entries from earlier terms (Raft §5.4.2).
        // It only fails if leadership was already lost again, so the error is ignored.
        let _ = self.append_entry(Command::NoOp).await;
    }

    /// Builds one `AppendEntries` request per peer from the leader's replication state.
    ///
    /// Each request carries the `prev_log_index`/`prev_log_term` the peer is expected
    /// to hold (derived from its `next_index`) plus up to `max_entries_per_batch`
    /// entries after it. An empty batch acts as a pure heartbeat.
    ///
    /// NOTE: no RPC transport is wired up yet; the requests are built and dropped.
    async fn send_heartbeats(&self) {
        let peers = self.peers.read().await.clone();
        let persistent = self.persistent.read().await;
        let commit_index = self.volatile.lock().await.commit_index;
        let leader_state = self.leader_state.lock().await;
        let next_indices = match leader_state.as_ref() {
            Some(state) => &state.next_index,
            None => return,
        };

        let term = persistent.current_term;
        let last_log_index = persistent.log.last().map(|e| e.index).unwrap_or(0);
        for peer in peers.iter() {
            // Peers added after the election have no entry yet; probe from our log end.
            let next_index = next_indices
                .get(peer)
                .copied()
                .unwrap_or(last_log_index + 1)
                .max(1);
            let prev_log_index = next_index - 1;
            let prev_log_term = match prev_log_index {
                0 => 0,
                i => persistent
                    .log
                    .get((i - 1) as usize)
                    .map(|e| e.term)
                    .unwrap_or(0),
            };
            let entries: Vec<LogEntry> = persistent
                .log
                .iter()
                .skip(prev_log_index as usize)
                .take(self.config.max_entries_per_batch)
                .cloned()
                .collect();
            let _request = AppendEntriesRequest {
                term,
                leader_id: self.config.node_id.clone(),
                prev_log_index,
                prev_log_term,
                entries,
                leader_commit: commit_index,
            };
            // TODO: send `_request` to `peer` and fold the response into
            // `next_index` / `match_index`.
        }
    }

    /// Appends `command` to the local log as a new entry in the current term.
    ///
    /// Returns the index of the new entry.
    ///
    /// # Errors
    /// Returns `Err("Not the leader")` if this node is not currently the leader.
    pub async fn append_entry(&self, command: Command) -> Result<LogIndex, String> {
        // Check leadership only after taking the log lock. Every step-down path in
        // this impl changes the role while holding this lock, so the check cannot go
        // stale before the entry is written.
        let mut persistent = self.persistent.write().await;
        if !self.is_leader().await {
            return Err("Not the leader".to_string());
        }
        let index = persistent.log.last().map_or(1, |e| e.index + 1);
        let term = persistent.current_term;
        persistent.log.push(LogEntry {
            term,
            index,
            command,
            timestamp: SystemTime::now(),
        });
        Ok(index)
    }

    /// Handles a `RequestVote` RPC (Raft §5.2, §5.4.1).
    ///
    /// A higher request term is adopted first, and the node steps down to `Follower`.
    /// The vote is granted only if all of the following hold:
    /// * the request is for the current term,
    /// * we have not voted for a different candidate in that term,
    /// * the candidate's log is at least as up to date as ours.
    ///
    /// Granting a vote also resets the election timer.
    pub async fn handle_request_vote(&self, request: RequestVoteRequest) -> RequestVoteResponse {
        let mut persistent = self.persistent.write().await;
        if request.term > persistent.current_term {
            persistent.current_term = request.term;
            persistent.voted_for = None;
            *self.state.write().await = NodeState::Follower;
        }

        let (our_last_index, our_last_term) = persistent
            .log
            .last()
            .map(|e| (e.index, e.term))
            .unwrap_or((0, 0));
        let same_term = request.term == persistent.current_term;
        let can_vote = persistent
            .voted_for
            .as_ref()
            .map_or(true, |voted| *voted == request.candidate_id);
        // Lexicographic (term, index) comparison == Raft's "at least as up-to-date".
        let log_ok =
            (request.last_log_term, request.last_log_index) >= (our_last_term, our_last_index);

        let vote_granted = same_term && can_vote && log_ok;
        if vote_granted {
            persistent.voted_for = Some(request.candidate_id);
            *self.last_heartbeat.lock().await = Instant::now();
        }
        RequestVoteResponse {
            term: persistent.current_term,
            vote_granted,
        }
    }

    /// Handles an `AppendEntries` RPC from a leader (Raft §5.3).
    ///
    /// * Requests from an older term are rejected without touching any state.
    /// * Otherwise the leader's term is adopted, the node becomes `Follower` and the
    ///   election timer is reset.
    /// * If our log lacks `prev_log_index` with `prev_log_term`, the request is rejected
    ///   with a hint. `conflict_index` is the index the leader should resume sending
    ///   from: one past our last entry when our log is too short, or else the first
    ///   index of the conflicting term (which is reported in `conflict_term`).
    /// * New entries are appended. Our log is truncated at the first entry whose term
    ///   disagrees with the leader's.
    /// * `commit_index` advances to `min(leader_commit, index of last new entry)` and
    ///   never moves backwards.
    pub async fn handle_append_entries(
        &self,
        request: AppendEntriesRequest,
    ) -> AppendEntriesResponse {
        let mut persistent = self.persistent.write().await;
        // A stale leader must not reset our election timer, otherwise a deposed
        // leader could suppress elections indefinitely.
        if request.term < persistent.current_term {
            return AppendEntriesResponse {
                term: persistent.current_term,
                success: false,
                conflict_index: None,
                conflict_term: None,
            };
        }
        if request.term > persistent.current_term {
            persistent.current_term = request.term;
            persistent.voted_for = None;
        }
        // The sender is the legitimate leader of this term, so a candidate steps down.
        *self.state.write().await = NodeState::Follower;
        *self.last_heartbeat.lock().await = Instant::now();
        let current_term = persistent.current_term;

        if request.prev_log_index > 0 {
            let prev_pos = (request.prev_log_index - 1) as usize;
            match persistent.log.get(prev_pos) {
                None => {
                    return AppendEntriesResponse {
                        term: current_term,
                        success: false,
                        conflict_index: Some(persistent.log.len() as LogIndex + 1),
                        conflict_term: None,
                    };
                }
                Some(entry) if entry.term != request.prev_log_term => {
                    let conflict_term = entry.term;
                    // Walk back to the first entry of the conflicting term so the
                    // leader can skip the whole term in one round trip.
                    let conflict_index = persistent.log[..prev_pos]
                        .iter()
                        .rev()
                        .take_while(|e| e.term == conflict_term)
                        .last()
                        .map_or(request.prev_log_index, |e| e.index);
                    return AppendEntriesResponse {
                        term: current_term,
                        success: false,
                        conflict_index: Some(conflict_index),
                        conflict_term: Some(conflict_term),
                    };
                }
                Some(_) => {}
            }
        }

        // Raft's "index of last new entry": only up to here is our log known to match
        // the leader's, so only up to here may be committed.
        let last_new_index = request.prev_log_index + request.entries.len() as LogIndex;
        let mut pos = request.prev_log_index as usize;
        for entry in request.entries {
            let existing_term = persistent.log.get(pos).map(|e| e.term);
            match existing_term {
                // Already present (e.g. a delayed duplicate RPC): keep it, never truncate.
                Some(term) if term == entry.term => {}
                // Conflict: drop this entry and everything after it.
                Some(_) => {
                    persistent.log.truncate(pos);
                    persistent.log.push(entry);
                }
                None => persistent.log.push(entry),
            }
            pos += 1;
        }

        let mut volatile = self.volatile.lock().await;
        if request.leader_commit > volatile.commit_index {
            let target = request.leader_commit.min(last_new_index);
            volatile.commit_index = volatile.commit_index.max(target);
        }

        AppendEntriesResponse {
            term: current_term,
            success: true,
            conflict_index: None,
            conflict_term: None,
        }
    }

    /// Returns committed-but-not-yet-applied entries and marks them as applied.
    ///
    /// Never hands out more entries than the local log holds.
    pub async fn get_entries_to_apply(&self) -> Vec<LogEntry> {
        // `persistent` before `volatile`: the reverse order deadlocks against
        // `handle_append_entries`, which holds the log write lock while locking `volatile`.
        let persistent = self.persistent.read().await;
        let mut volatile = self.volatile.lock().await;
        let start = volatile.last_applied as usize;
        let end = (volatile.commit_index as usize).min(persistent.log.len());
        if start >= end {
            return Vec::new();
        }
        let entries = persistent.log[start..end].to_vec();
        volatile.last_applied = end as LogIndex;
        entries
    }

    /// Picks a uniformly random election timeout between the configured bounds (inclusive).
    ///
    /// If the bounds are configured the wrong way round they are swapped instead of
    /// panicking.
    fn random_election_timeout(&self) -> Duration {
        use rand::Rng;
        let min_ms = self.config.election_timeout_min_ms;
        let max_ms = self.config.election_timeout_max_ms;
        let (lo, hi) = if min_ms <= max_ms {
            (min_ms, max_ms)
        } else {
            (max_ms, min_ms)
        };
        Duration::from_millis(rand::thread_rng().gen_range(lo..=hi))
    }

    /// Returns a snapshot of log size, commit/apply progress and the current term.
    pub async fn log_stats(&self) -> LogStats {
        let persistent = self.persistent.read().await;
        let volatile = self.volatile.lock().await;
        LogStats {
            total_entries: persistent.log.len(),
            committed_entries: volatile.commit_index as usize,
            applied_entries: volatile.last_applied as usize,
            current_term: persistent.current_term,
        }
    }
}
/// Point-in-time summary of a node's log, as returned by `RaftNode::log_stats`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LogStats {
    /// Number of entries currently in the log.
    pub total_entries: usize,
    /// The node's commit index.
    pub committed_entries: usize,
    /// The node's last-applied index.
    pub applied_entries: usize,
    /// The node's current term.
    pub current_term: Term,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NoOp` log entry with the given term and index.
    fn noop_entry(term: Term, index: LogIndex) -> LogEntry {
        LogEntry {
            term,
            index,
            command: Command::NoOp,
            timestamp: SystemTime::now(),
        }
    }

    #[tokio::test]
    async fn test_raft_node_creation() {
        let node = RaftNode::new(RaftConfig::default());
        assert_eq!(node.state().await, NodeState::Follower);
        assert_eq!(node.current_term().await, 0);
    }

    #[tokio::test]
    async fn test_leader_election() {
        let config = RaftConfig {
            node_id: "node-1".to_string(),
            peers: vec![],
            ..Default::default()
        };
        let node = RaftNode::new(config);
        node.start_election().await;
        assert_eq!(node.state().await, NodeState::Leader);
        assert!(node.is_leader().await);
    }

    #[tokio::test]
    async fn test_append_entry() {
        let node = RaftNode::new(RaftConfig::default());
        *node.state.write().await = NodeState::Leader;
        node.persistent.write().await.current_term = 1;
        let command = Command::Insert {
            id: "vec1".to_string(),
            vector: vec![1.0, 2.0, 3.0],
            metadata: serde_json::json!({}),
        };
        let result = node.append_entry(command).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_append_entry_rejected_when_not_leader() {
        let node = RaftNode::new(RaftConfig::default());
        let result = node.append_entry(Command::NoOp).await;
        assert_eq!(result, Err("Not the leader".to_string()));
    }

    #[tokio::test]
    async fn test_request_vote_grant() {
        let node = RaftNode::new(RaftConfig::default());
        let request = RequestVoteRequest {
            term: 1,
            candidate_id: "node-2".to_string(),
            last_log_index: 0,
            last_log_term: 0,
        };
        let response = node.handle_request_vote(request).await;
        assert!(response.vote_granted);
        assert_eq!(response.term, 1);
    }

    #[tokio::test]
    async fn test_request_vote_deny_old_term() {
        let node = RaftNode::new(RaftConfig::default());
        node.persistent.write().await.current_term = 2;
        let request = RequestVoteRequest {
            term: 1,
            candidate_id: "node-2".to_string(),
            last_log_index: 0,
            last_log_term: 0,
        };
        let response = node.handle_request_vote(request).await;
        assert!(!response.vote_granted);
        assert_eq!(response.term, 2);
    }

    #[tokio::test]
    async fn test_append_entries_heartbeat() {
        let node = RaftNode::new(RaftConfig::default());
        let request = AppendEntriesRequest {
            term: 1,
            leader_id: "node-leader".to_string(),
            prev_log_index: 0,
            prev_log_term: 0,
            entries: vec![],
            leader_commit: 0,
        };
        let response = node.handle_append_entries(request).await;
        assert!(response.success);
        assert_eq!(response.term, 1);
    }

    #[tokio::test]
    async fn test_append_entries_stale_term_rejected() {
        let node = RaftNode::new(RaftConfig::default());
        node.persistent.write().await.current_term = 3;
        let request = AppendEntriesRequest {
            term: 2,
            leader_id: "old-leader".to_string(),
            prev_log_index: 0,
            prev_log_term: 0,
            entries: vec![],
            leader_commit: 0,
        };
        let response = node.handle_append_entries(request).await;
        assert!(!response.success);
        assert_eq!(response.term, 3);
        assert_eq!(node.current_term().await, 3);
    }

    #[tokio::test]
    async fn test_append_entries_with_entry() {
        let node = RaftNode::new(RaftConfig::default());
        let request = AppendEntriesRequest {
            term: 1,
            leader_id: "node-leader".to_string(),
            prev_log_index: 0,
            prev_log_term: 0,
            entries: vec![noop_entry(1, 1)],
            leader_commit: 0,
        };
        let response = node.handle_append_entries(request).await;
        assert!(response.success);
        let persistent = node.persistent.read().await;
        assert_eq!(persistent.log.len(), 1);
    }

    #[tokio::test]
    async fn test_append_entries_truncates_conflicting_suffix() {
        let node = RaftNode::new(RaftConfig::default());
        {
            let mut persistent = node.persistent.write().await;
            persistent.log.push(noop_entry(1, 1));
            persistent.log.push(noop_entry(1, 2));
        }
        let request = AppendEntriesRequest {
            term: 2,
            leader_id: "node-leader".to_string(),
            prev_log_index: 1,
            prev_log_term: 1,
            entries: vec![noop_entry(2, 2)],
            leader_commit: 0,
        };
        let response = node.handle_append_entries(request).await;
        assert!(response.success);
        let persistent = node.persistent.read().await;
        assert_eq!(persistent.log.len(), 2);
        assert_eq!(persistent.log[1].term, 2);
    }

    #[tokio::test]
    async fn test_append_entries_advances_commit_index() {
        let node = RaftNode::new(RaftConfig::default());
        let request = AppendEntriesRequest {
            term: 1,
            leader_id: "node-leader".to_string(),
            prev_log_index: 0,
            prev_log_term: 0,
            entries: vec![noop_entry(1, 1)],
            leader_commit: 1,
        };
        let response = node.handle_append_entries(request).await;
        assert!(response.success);
        assert_eq!(node.log_stats().await.committed_entries, 1);
    }

    #[tokio::test]
    async fn test_log_stats() {
        let node = RaftNode::new(RaftConfig::default());
        *node.state.write().await = NodeState::Leader;
        node.persistent.write().await.current_term = 1;
        node.append_entry(Command::NoOp).await.unwrap();
        node.append_entry(Command::NoOp).await.unwrap();
        let stats = node.log_stats().await;
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.current_term, 1);
    }

    #[tokio::test]
    async fn test_commit_and_apply() {
        let node = RaftNode::new(RaftConfig::default());
        *node.state.write().await = NodeState::Leader;
        node.persistent.write().await.current_term = 1;
        node.append_entry(Command::NoOp).await.unwrap();
        node.append_entry(Command::NoOp).await.unwrap();
        node.volatile.lock().await.commit_index = 2;
        let entries = node.get_entries_to_apply().await;
        assert_eq!(entries.len(), 2);
        let stats = node.log_stats().await;
        assert_eq!(stats.applied_entries, 2);
    }

    #[tokio::test]
    async fn test_get_entries_to_apply_is_incremental() {
        let node = RaftNode::new(RaftConfig::default());
        *node.state.write().await = NodeState::Leader;
        node.persistent.write().await.current_term = 1;
        for _ in 0..3 {
            node.append_entry(Command::NoOp).await.unwrap();
        }
        node.volatile.lock().await.commit_index = 2;
        let first: Vec<LogIndex> = node
            .get_entries_to_apply()
            .await
            .iter()
            .map(|e| e.index)
            .collect();
        assert_eq!(first, vec![1, 2]);
        node.volatile.lock().await.commit_index = 3;
        let second: Vec<LogIndex> = node
            .get_entries_to_apply()
            .await
            .iter()
            .map(|e| e.index)
            .collect();
        assert_eq!(second, vec![3]);
        assert!(node.get_entries_to_apply().await.is_empty());
    }

    #[tokio::test]
    async fn test_term_update_on_higher_term() {
        let node = RaftNode::new(RaftConfig::default());
        assert_eq!(node.current_term().await, 0);
        let request = RequestVoteRequest {
            term: 5,
            candidate_id: "node-2".to_string(),
            last_log_index: 0,
            last_log_term: 0,
        };
        node.handle_request_vote(request).await;
        assert_eq!(node.current_term().await, 5);
    }

    #[tokio::test]
    async fn test_higher_term_vote_request_demotes_leader() {
        let node = RaftNode::new(RaftConfig::default());
        *node.state.write().await = NodeState::Leader;
        node.persistent.write().await.current_term = 1;
        let request = RequestVoteRequest {
            term: 5,
            candidate_id: "node-2".to_string(),
            last_log_index: 0,
            last_log_term: 0,
        };
        node.handle_request_vote(request).await;
        assert_eq!(node.state().await, NodeState::Follower);
    }

    #[tokio::test]
    async fn test_leader_id() {
        let config = RaftConfig {
            node_id: "node-1".to_string(),
            ..Default::default()
        };
        let node = RaftNode::new(config);
        assert_eq!(node.leader_id().await, None);
        *node.state.write().await = NodeState::Leader;
        assert_eq!(node.leader_id().await, Some("node-1".to_string()));
    }

    #[tokio::test]
    async fn test_log_replication_conflict() {
        let node = RaftNode::new(RaftConfig::default());
        node.persistent.write().await.log.push(noop_entry(1, 1));
        let request = AppendEntriesRequest {
            term: 2,
            leader_id: "node-leader".to_string(),
            prev_log_index: 1,
            prev_log_term: 2,
            entries: vec![],
            leader_commit: 0,
        };
        let response = node.handle_append_entries(request).await;
        assert!(!response.success);
        assert_eq!(response.conflict_index, Some(1));
        assert_eq!(response.conflict_term, Some(1));
    }

    #[tokio::test]
    async fn test_already_voted() {
        let node = RaftNode::new(RaftConfig::default());
        let request1 = RequestVoteRequest {
            term: 1,
            candidate_id: "node-2".to_string(),
            last_log_index: 0,
            last_log_term: 0,
        };
        let response1 = node.handle_request_vote(request1).await;
        assert!(response1.vote_granted);
        let request2 = RequestVoteRequest {
            term: 1,
            candidate_id: "node-3".to_string(),
            last_log_index: 0,
            last_log_term: 0,
        };
        let response2 = node.handle_request_vote(request2).await;
        assert!(!response2.vote_granted);
    }

    #[tokio::test]
    async fn test_candidate_log_not_up_to_date() {
        let node = RaftNode::new(RaftConfig::default());
        node.persistent.write().await.log.push(noop_entry(2, 1));
        let request = RequestVoteRequest {
            term: 3,
            candidate_id: "node-2".to_string(),
            last_log_index: 0,
            last_log_term: 1,
        };
        let response = node.handle_request_vote(request).await;
        assert!(!response.vote_granted);
    }
}