use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, RwLock, broadcast, oneshot};
use crate::replication::protocol::ReplicationMessage;
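/// Role a node currently plays in the replication cluster.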
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeRole {
Master,
Slave,
Candidate,
}
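/// Errors surfaced by the receiver loop.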
#[derive(Debug)]
pub enum NodeError {
ChannelError(String),
ApplyError(String),
}
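/// A replication cluster member: replicates WAL entries as Master, applies
/// them as Slave, and runs a simplified leader election as Candidate.
/// All peers share a single `broadcast` channel as the transport.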
pub struct ReplicationNode {
pub node_id: u32,
pub cluster_size: usize,
role: Arc<RwLock<NodeRole>>,
tx: broadcast::Sender<ReplicationMessage>,
last_lsn: Arc<AtomicU64>,
current_term: Arc<AtomicU64>,
voted_for: Arc<Mutex<Option<u32>>>,
votes_received: Arc<Mutex<u32>>,
last_heartbeat: Arc<RwLock<Instant>>,
quorum_tracker: Arc<QuorumAckTracker>,
pub quorum_write_timeout: Duration,
}
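/// Tracks acknowledgements per LSN and wakes all registered waiters once
/// the ack count reaches the quorum threshold.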
struct QuorumAckTracker {
pending: DashMap<u64, (u32, Vec<oneshot::Sender<()>>)>,
}
impl QuorumAckTracker {
fn new() -> Self {
Self {
pending: DashMap::new(),
}
}
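    /// Register a waiter for `lsn`; the returned receiver fires once a quorum of acks has arrived.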
fn register(&self, lsn: u64) -> oneshot::Receiver<()> {
let (tx, rx) = oneshot::channel();
self.pending
.entry(lsn)
.or_insert_with(|| (0, Vec::new()))
.1
.push(tx);
rx
}
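    /// Record one ack for `lsn` and, on reaching `quorum`, notify and clear all waiters.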
fn ack(&self, lsn: u64, quorum: u32) {
if let Some(mut entry) = self.pending.get_mut(&lsn) {
entry.0 += 1;
if entry.0 >= quorum {
let senders: Vec<_> = entry.1.drain(..).collect();
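                // Release the shard guard before removing the key: removing while
                // still holding a `get_mut` reference into the entry would deadlock DashMap.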
drop(entry);
self.pending.remove(&lsn);
for s in senders {
let _ = s.send(());
}
}
}
}
}
impl ReplicationNode {
pub fn new(
node_id: u32,
cluster_size: usize,
initial_role: NodeRole,
tx: broadcast::Sender<ReplicationMessage>,
) -> Self {
Self {
node_id,
cluster_size,
role: Arc::new(RwLock::new(initial_role)),
tx,
last_lsn: Arc::new(AtomicU64::new(0)),
current_term: Arc::new(AtomicU64::new(0)),
voted_for: Arc::new(Mutex::new(None)),
votes_received: Arc::new(Mutex::new(0)),
last_heartbeat: Arc::new(RwLock::new(Instant::now())),
quorum_tracker: Arc::new(QuorumAckTracker::new()),
quorum_write_timeout: Duration::from_secs(5),
}
}
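    /// Construct a node taking `cluster_size` and `quorum_write_timeout` from the given `ReplicationConfig`.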
pub fn new_from_config(
node_id: u32,
initial_role: NodeRole,
tx: broadcast::Sender<ReplicationMessage>,
config: &crate::replication::transport::ReplicationConfig,
) -> Self {
let mut node = Self::new(node_id, config.cluster_size, initial_role, tx);
node.quorum_write_timeout = config.quorum_write_timeout;
node
}
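    /// Majority threshold for the configured cluster size (e.g. 2 of 3, 3 of 5).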
fn quorum(&self) -> u32 {
(self.cluster_size / 2 + 1) as u32
}
pub async fn role(&self) -> NodeRole {
*self.role.read().await
}
pub fn term(&self) -> u64 {
self.current_term.load(Ordering::SeqCst)
}
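    /// Assign the next LSN and broadcast a WAL entry. Errors if this node is
    /// not the Master. On clusters larger than one node, waits until a quorum
    /// of `Acknowledge` messages arrives for the LSN or `quorum_write_timeout`
    /// elapses.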
pub async fn replicate(&self, data: Vec<u8>) -> Result<u64, String> {
let role = self.role().await;
if role != NodeRole::Master {
return Err("Only Master can replicate".to_string());
}
let lsn = self.last_lsn.fetch_add(1, Ordering::SeqCst) + 1;
let msg = ReplicationMessage::WalEntry {
node_id: self.node_id,
lsn,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("system clock is set before the UNIX epoch")
                .as_micros() as u64,
data,
};
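        // A single-node cluster is satisfied by the local write alone: there are no remote acks to wait for.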
if self.quorum() <= 1 {
let _ = self.tx.send(msg);
return Ok(lsn);
}
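        // Register for acks before broadcasting so an ack cannot race ahead of the waiter.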
let ack_rx = self.quorum_tracker.register(lsn);
let _ = self.tx.send(msg);
tokio::time::timeout(self.quorum_write_timeout, ack_rx)
.await
.map_err(|_| format!("quorum write timeout for LSN {lsn}"))?
.map_err(|_| format!("quorum tracker channel dropped for LSN {lsn}"))?;
Ok(lsn)
}
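    /// Broadcast a heartbeat carrying the current LSN; no-op unless this node is Master.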
pub async fn send_heartbeat(&self) {
if self.role().await != NodeRole::Master {
return;
}
let lsn = self.last_lsn.load(Ordering::SeqCst);
let _ = self.tx.send(ReplicationMessage::Heartbeat {
node_id: self.node_id,
lsn,
});
}
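    /// Consume messages from the broadcast channel until it closes, invoking
    /// `apply_fn(lsn, timestamp, data)` for WAL entries from other nodes.
    /// A lagged receiver skips the messages it missed and keeps reading.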
pub async fn run_receiver_loop<F>(
&self,
mut rx: broadcast::Receiver<ReplicationMessage>,
mut apply_fn: F,
) -> Result<(), NodeError>
where
F: FnMut(u64, u64, &[u8]) -> Result<(), String>,
{
loop {
match rx.recv().await {
Ok(msg) => self.handle_message(msg, &mut apply_fn).await?,
Err(broadcast::error::RecvError::Closed) => {
return Err(NodeError::ChannelError("Channel closed".into()));
}
Err(broadcast::error::RecvError::Lagged(_)) => continue,
}
}
}
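    /// Apply a single protocol message to this node's state.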
async fn handle_message<F>(
&self,
msg: ReplicationMessage,
apply_fn: &mut F,
) -> Result<(), NodeError>
where
F: FnMut(u64, u64, &[u8]) -> Result<(), String>,
{
match msg {
ReplicationMessage::Heartbeat { node_id, .. } => {
if node_id != self.node_id {
*self.last_heartbeat.write().await = Instant::now();
}
}
ReplicationMessage::WalEntry {
node_id,
lsn,
timestamp,
data,
} => {
if node_id == self.node_id {
return Ok(());
}
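                // Apply only entries ahead of our local position; duplicates and
                // already-applied LSNs are skipped without an ack.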
let local_lsn = self.last_lsn.load(Ordering::SeqCst);
if lsn > local_lsn {
apply_fn(lsn, timestamp, &data).map_err(NodeError::ApplyError)?;
self.last_lsn.store(lsn, Ordering::SeqCst);
let _ = self.tx.send(ReplicationMessage::Acknowledge {
node_id: self.node_id,
lsn,
});
}
}
ReplicationMessage::VoteRequest {
node_id: candidate_id,
term,
last_lsn,
} => {
if candidate_id == self.node_id {
return Ok(());
}
                let mut voted_for = self.voted_for.lock().await;
                // A request with a newer term invalidates any vote cast in an older term.
                if term > self.current_term.load(Ordering::SeqCst) {
                    self.current_term.store(term, Ordering::SeqCst);
                    *voted_for = None;
                }
                let my_term = self.current_term.load(Ordering::SeqCst);
                let my_lsn = self.last_lsn.load(Ordering::SeqCst);
                // Grant only if the candidate's log is at least as current as ours
                // and we have not already voted for a different node this term.
                let grant = term >= my_term
                    && (voted_for.is_none() || *voted_for == Some(candidate_id))
                    && last_lsn >= my_lsn;
                if grant {
                    *voted_for = Some(candidate_id);
                }
let _ = self.tx.send(ReplicationMessage::VoteResponse {
node_id: candidate_id,
voter_id: self.node_id,
term,
granted: grant,
});
}
ReplicationMessage::VoteResponse {
node_id,
voter_id: _,
term,
granted,
} => {
if node_id != self.node_id {
return Ok(());
}
let my_term = self.current_term.load(Ordering::SeqCst);
if term != my_term {
return Ok(());
}
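                // Count grants only while still campaigning; responses that
                // arrive after promotion (or demotion) are ignored.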
if granted && self.role().await == NodeRole::Candidate {
let mut votes = self.votes_received.lock().await;
*votes += 1;
if *votes >= self.quorum() {
let mut role = self.role.write().await;
*role = NodeRole::Master;
drop(role);
let _ = self.tx.send(ReplicationMessage::Promotion {
node_id: self.node_id,
term: my_term,
});
}
}
}
ReplicationMessage::Promotion { node_id, term } => {
if node_id == self.node_id {
return Ok(());
}
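                // An equal or newer term wins: step down and treat the promotion as a fresh heartbeat.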
let mut role = self.role.write().await;
if term >= self.current_term.load(Ordering::SeqCst) {
self.current_term.store(term, Ordering::SeqCst);
*role = NodeRole::Slave;
*self.last_heartbeat.write().await = Instant::now();
}
}
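            // An ack from a peer. The master's own copy is not counted here,
            // so a pending write is released only after `quorum()` peer acks.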
ReplicationMessage::Acknowledge { lsn, .. } => {
self.quorum_tracker.ack(lsn, self.quorum());
}
_ => {}
}
Ok(())
}
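    /// Transition to Candidate and broadcast a vote request, unless this node
    /// is already Master or a heartbeat arrived within the 200 ms election
    /// timeout. Returns `true` if an election was started.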
pub async fn start_election(&self) -> bool {
if self.role().await == NodeRole::Master {
return false;
}
        let elapsed = self.last_heartbeat.read().await.elapsed();
        if elapsed < Duration::from_millis(200) {
            return false;
        }
        *self.role.write().await = NodeRole::Candidate;
        let new_term = self.current_term.fetch_add(1, Ordering::SeqCst) + 1;
        // Vote for ourselves.
        *self.voted_for.lock().await = Some(self.node_id);
        *self.votes_received.lock().await = 1;
let last_lsn = self.last_lsn.load(Ordering::SeqCst);
let _ = self.tx.send(ReplicationMessage::VoteRequest {
node_id: self.node_id,
term: new_term,
last_lsn,
});
true
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_single_node_wins_election() {
let (tx, mut rx) = broadcast::channel(32);
let node = ReplicationNode::new(1, 1, NodeRole::Slave, tx.clone());
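        // Sleep past the 200 ms election timeout so start_election() does not bail out early.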
tokio::time::sleep(Duration::from_millis(250)).await;
let started = node.start_election().await;
        assert!(started, "the election should start");
let msg = rx.recv().await.unwrap();
match msg {
ReplicationMessage::VoteRequest {
node_id,
term,
last_lsn,
} => {
assert_eq!(node_id, 1);
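                // Feed the vote response back by hand; no receiver loop is running in this test.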
let _ = node
.handle_message(
ReplicationMessage::VoteResponse {
node_id: 1,
voter_id: 1,
term,
granted: true,
},
&mut |_, _, _| Ok(()),
)
.await;
let _ = last_lsn;
}
            _ => panic!("expected a VoteRequest"),
}
assert_eq!(node.role().await, NodeRole::Master);
}
#[tokio::test]
async fn test_quorum_requires_majority() {
let (tx, _rx) = broadcast::channel(32);
let node = ReplicationNode::new(1, 3, NodeRole::Candidate, tx.clone());
node.current_term.store(1, Ordering::SeqCst);
*node.votes_received.lock().await = 1;
node.handle_message(
ReplicationMessage::VoteResponse {
node_id: 1,
voter_id: 2,
term: 1,
granted: false,
},
&mut |_, _, _| Ok(()),
)
.await
.unwrap();
assert_eq!(
node.role().await,
NodeRole::Candidate,
"아직 Candidate여야 함"
);
node.handle_message(
ReplicationMessage::VoteResponse {
node_id: 1,
voter_id: 3,
term: 1,
granted: true,
},
&mut |_, _, _| Ok(()),
)
.await
.unwrap();
assert_eq!(
node.role().await,
NodeRole::Master,
"과반 획득 후 Master여야 함"
);
}
#[tokio::test]
async fn test_higher_term_promotion_demotes_master() {
let (tx, _rx) = broadcast::channel(32);
let node = ReplicationNode::new(1, 3, NodeRole::Master, tx.clone());
node.current_term.store(1, Ordering::SeqCst);
node.handle_message(
ReplicationMessage::Promotion {
node_id: 2,
term: 2,
},
&mut |_, _, _| Ok(()),
)
.await
.unwrap();
assert_eq!(
node.role().await,
NodeRole::Slave,
"더 높은 term의 Promotion → Slave 강등"
);
assert_eq!(node.term(), 2);
}
#[tokio::test]
async fn test_replicate_only_as_master() {
let (tx, _rx) = broadcast::channel(16);
let node = ReplicationNode::new(1, 1, NodeRole::Slave, tx.clone());
let result = node.replicate(b"data".to_vec()).await;
assert!(result.is_err(), "Slave는 복제 불가");
}
#[tokio::test]
async fn test_quorum_write_single_node() {
let (tx, _rx) = broadcast::channel(16);
let node = ReplicationNode::new(1, 1, NodeRole::Master, tx.clone());
let lsn = node.replicate(b"data".to_vec()).await;
assert_eq!(lsn, Ok(1), "단일 노드: quorum = 1 → 즉시 Ok(1)");
}
#[tokio::test]
async fn test_quorum_write_three_nodes() {
let (tx, _rx) = broadcast::channel(32);
let master = Arc::new(ReplicationNode::new(1, 3, NodeRole::Master, tx.clone()));
let slave2 = Arc::new(ReplicationNode::new(2, 3, NodeRole::Slave, tx.clone()));
let slave3 = Arc::new(ReplicationNode::new(3, 3, NodeRole::Slave, tx.clone()));
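        // Run a receiver loop per node so the slaves apply the entry and broadcast
        // acks back, releasing the master's pending quorum write.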
let master_rx_loop = Arc::clone(&master);
let rx_master = tx.subscribe();
tokio::spawn(async move {
master_rx_loop
.run_receiver_loop(rx_master, |_, _, _| Ok(()))
.await
.ok();
});
let slave2_clone = Arc::clone(&slave2);
let rx2 = tx.subscribe();
tokio::spawn(async move {
slave2_clone
.run_receiver_loop(rx2, |_, _, _| Ok(()))
.await
.ok();
});
let slave3_clone = Arc::clone(&slave3);
let rx3 = tx.subscribe();
tokio::spawn(async move {
slave3_clone
.run_receiver_loop(rx3, |_, _, _| Ok(()))
.await
.ok();
});
let lsn = master.replicate(b"quorum_data".to_vec()).await;
assert_eq!(lsn, Ok(1), "quorum 달성 후 LSN 1 반환");
}
#[tokio::test]
async fn test_quorum_write_timeout() {
let (tx, _rx) = broadcast::channel(16);
let mut node = ReplicationNode::new(1, 3, NodeRole::Master, tx.clone());
node.quorum_write_timeout = Duration::from_millis(50);
let result = node.replicate(b"data".to_vec()).await;
assert!(result.is_err(), "timeout → Err 반환 필요");
        assert!(
            result.unwrap_err().contains("quorum write timeout"),
            "the error message should mention 'quorum write timeout'"
        );
}
#[tokio::test]
async fn test_new_from_config_injects_timeout() {
use crate::replication::transport::ReplicationConfig;
let config = ReplicationConfig {
quorum_write_timeout: Duration::from_millis(123),
..ReplicationConfig::default()
};
let (tx, _rx) = broadcast::channel(16);
let node = ReplicationNode::new_from_config(1, NodeRole::Master, tx, &config);
assert_eq!(node.quorum_write_timeout, Duration::from_millis(123));
}
}