1use dashmap::DashMap;
11use std::sync::Arc;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::time::{Duration, Instant};
14use tokio::sync::broadcast;
15use tokio::sync::{Mutex, RwLock, oneshot};
16
17use crate::replication::protocol::ReplicationMessage;
18
/// Role a node currently plays in the replication cluster.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeRole {
    /// Accepts writes and replicates them to the rest of the cluster.
    Master,
    /// Applies WAL entries received from the master.
    Slave,
    /// Currently campaigning for votes in a leader election.
    Candidate,
}
27
/// Errors surfaced by the receiver loop and message handling.
#[derive(Debug)]
pub enum NodeError {
    /// The broadcast channel failed (e.g. was closed); carries a description.
    ChannelError(String),
    /// The user-supplied apply callback rejected a WAL entry.
    ApplyError(String),
}
34
/// A single node of the replication cluster, communicating over a shared
/// tokio broadcast channel.
pub struct ReplicationNode {
    /// Unique identifier of this node within the cluster.
    pub node_id: u32,
    /// Total number of nodes in the cluster; drives the quorum size.
    pub cluster_size: usize,
    /// Current role (Master / Slave / Candidate).
    role: Arc<RwLock<NodeRole>>,
    /// Broadcast sender shared by all nodes; used for every outgoing message.
    tx: broadcast::Sender<ReplicationMessage>,
    /// Highest log sequence number written (master) or applied (slave).
    last_lsn: Arc<AtomicU64>,
    /// Current election term.
    current_term: Arc<AtomicU64>,
    /// Candidate this node voted for in the current term, if any.
    voted_for: Arc<Mutex<Option<u32>>>,
    /// Votes accumulated while campaigning as Candidate.
    votes_received: Arc<Mutex<u32>>,
    /// Last time a foreign heartbeat (or promotion) was observed; the
    /// election timeout is measured against this.
    last_heartbeat: Arc<RwLock<Instant>>,
    /// Tracks per-LSN acknowledgements for in-flight quorum writes.
    quorum_tracker: Arc<QuorumAckTracker>,
    /// How long `replicate` waits for a quorum of acknowledgements.
    pub quorum_write_timeout: Duration,
}
60
/// Tracks acknowledgements for in-flight quorum writes.
struct QuorumAckTracker {
    /// LSN -> (acks received so far, waiters to notify once quorum is met).
    pending: DashMap<u64, (u32, Vec<oneshot::Sender<()>>)>,
}
73
74impl QuorumAckTracker {
75 fn new() -> Self {
76 Self {
77 pending: DashMap::new(),
78 }
79 }
80
81 fn register(&self, lsn: u64) -> oneshot::Receiver<()> {
84 let (tx, rx) = oneshot::channel();
85 self.pending
86 .entry(lsn)
87 .or_insert_with(|| (0, Vec::new()))
88 .1
89 .push(tx);
90 rx
91 }
92
93 fn ack(&self, lsn: u64, quorum: u32) {
95 if let Some(mut entry) = self.pending.get_mut(&lsn) {
96 entry.0 += 1;
97 if entry.0 >= quorum {
98 let senders: Vec<_> = entry.1.drain(..).collect();
99 drop(entry);
100 self.pending.remove(&lsn);
101 for s in senders {
102 let _ = s.send(());
103 }
104 }
105 }
106 }
107}
108
109impl ReplicationNode {
114 pub fn new(
116 node_id: u32,
117 cluster_size: usize,
118 initial_role: NodeRole,
119 tx: broadcast::Sender<ReplicationMessage>,
120 ) -> Self {
121 Self {
122 node_id,
123 cluster_size,
124 role: Arc::new(RwLock::new(initial_role)),
125 tx,
126 last_lsn: Arc::new(AtomicU64::new(0)),
127 current_term: Arc::new(AtomicU64::new(0)),
128 voted_for: Arc::new(Mutex::new(None)),
129 votes_received: Arc::new(Mutex::new(0)),
130 last_heartbeat: Arc::new(RwLock::new(Instant::now())),
131 quorum_tracker: Arc::new(QuorumAckTracker::new()),
132 quorum_write_timeout: Duration::from_secs(5),
133 }
134 }
135
136 pub fn new_from_config(
140 node_id: u32,
141 initial_role: NodeRole,
142 tx: broadcast::Sender<ReplicationMessage>,
143 config: &crate::replication::transport::ReplicationConfig,
144 ) -> Self {
145 let mut node = Self::new(node_id, config.cluster_size, initial_role, tx);
146 node.quorum_write_timeout = config.quorum_write_timeout;
147 node
148 }
149
150 fn quorum(&self) -> u32 {
152 (self.cluster_size / 2 + 1) as u32
153 }
154
155 pub async fn role(&self) -> NodeRole {
157 *self.role.read().await
158 }
159
160 pub fn term(&self) -> u64 {
162 self.current_term.load(Ordering::SeqCst)
163 }
164
165 pub async fn replicate(&self, data: Vec<u8>) -> Result<u64, String> {
171 let role = self.role().await;
172 if role != NodeRole::Master {
173 return Err("Only Master can replicate".to_string());
174 }
175
176 let lsn = self.last_lsn.fetch_add(1, Ordering::SeqCst) + 1;
177 let msg = ReplicationMessage::WalEntry {
178 node_id: self.node_id,
179 lsn,
180 timestamp: std::time::SystemTime::now()
181 .duration_since(std::time::UNIX_EPOCH)
182 .unwrap()
183 .as_micros() as u64,
184 data,
185 };
186
187 if self.quorum() <= 1 {
189 let _ = self.tx.send(msg);
190 return Ok(lsn);
191 }
192
193 let ack_rx = self.quorum_tracker.register(lsn);
196 let _ = self.tx.send(msg);
197
198 tokio::time::timeout(self.quorum_write_timeout, ack_rx)
200 .await
201 .map_err(|_| format!("quorum write timeout for LSN {lsn}"))?
202 .map_err(|_| format!("quorum tracker channel dropped for LSN {lsn}"))?;
203
204 Ok(lsn)
205 }
206
207 pub async fn send_heartbeat(&self) {
209 if self.role().await != NodeRole::Master {
210 return;
211 }
212 let lsn = self.last_lsn.load(Ordering::SeqCst);
213 let _ = self.tx.send(ReplicationMessage::Heartbeat {
214 node_id: self.node_id,
215 lsn,
216 });
217 }
218
219 pub async fn run_receiver_loop<F>(
221 &self,
222 mut rx: broadcast::Receiver<ReplicationMessage>,
223 mut apply_fn: F,
224 ) -> Result<(), NodeError>
225 where
226 F: FnMut(u64, u64, &[u8]) -> Result<(), String>,
227 {
228 loop {
229 match rx.recv().await {
230 Ok(msg) => self.handle_message(msg, &mut apply_fn).await?,
231 Err(broadcast::error::RecvError::Closed) => {
232 return Err(NodeError::ChannelError("Channel closed".into()));
233 }
234 Err(broadcast::error::RecvError::Lagged(_)) => continue,
235 }
236 }
237 }
238
239 async fn handle_message<F>(
241 &self,
242 msg: ReplicationMessage,
243 apply_fn: &mut F,
244 ) -> Result<(), NodeError>
245 where
246 F: FnMut(u64, u64, &[u8]) -> Result<(), String>,
247 {
248 match msg {
249 ReplicationMessage::Heartbeat { node_id, .. } => {
251 if node_id != self.node_id {
252 *self.last_heartbeat.write().await = Instant::now();
253 }
254 }
255
256 ReplicationMessage::WalEntry {
258 node_id,
259 lsn,
260 timestamp,
261 data,
262 } => {
263 if node_id == self.node_id {
264 return Ok(());
265 }
266 let local_lsn = self.last_lsn.load(Ordering::SeqCst);
267 if lsn > local_lsn {
268 apply_fn(lsn, timestamp, &data).map_err(NodeError::ApplyError)?;
269 self.last_lsn.store(lsn, Ordering::SeqCst);
270 let _ = self.tx.send(ReplicationMessage::Acknowledge {
272 node_id: self.node_id,
273 lsn,
274 });
275 }
276 }
277
278 ReplicationMessage::VoteRequest {
280 node_id: candidate_id,
281 term,
282 last_lsn,
283 } => {
284 if candidate_id == self.node_id {
285 return Ok(());
286 }
287 let my_term = self.current_term.load(Ordering::SeqCst);
288 let my_lsn = self.last_lsn.load(Ordering::SeqCst);
289
290 let mut voted_for = self.voted_for.lock().await;
291 let grant = term >= my_term
294 && (voted_for.is_none() || *voted_for == Some(candidate_id))
295 && last_lsn >= my_lsn;
296
297 if grant {
298 *voted_for = Some(candidate_id);
299 self.current_term.store(term, Ordering::SeqCst);
301 }
302
303 let _ = self.tx.send(ReplicationMessage::VoteResponse {
304 node_id: candidate_id,
305 voter_id: self.node_id,
306 term,
307 granted: grant,
308 });
309 }
310
311 ReplicationMessage::VoteResponse {
313 node_id,
314 voter_id: _,
315 term,
316 granted,
317 } => {
318 if node_id != self.node_id {
319 return Ok(());
320 }
321 let my_term = self.current_term.load(Ordering::SeqCst);
322 if term != my_term {
324 return Ok(());
325 }
326 if granted && self.role().await == NodeRole::Candidate {
327 let mut votes = self.votes_received.lock().await;
328 *votes += 1;
329 if *votes >= self.quorum() {
330 let mut role = self.role.write().await;
332 *role = NodeRole::Master;
333 drop(role);
334 let _ = self.tx.send(ReplicationMessage::Promotion {
335 node_id: self.node_id,
336 term: my_term,
337 });
338 }
339 }
340 }
341
342 ReplicationMessage::Promotion { node_id, term } => {
344 if node_id == self.node_id {
345 return Ok(());
346 }
347 let mut role = self.role.write().await;
348 if term >= self.current_term.load(Ordering::SeqCst) {
350 self.current_term.store(term, Ordering::SeqCst);
351 *role = NodeRole::Slave;
352 *self.last_heartbeat.write().await = Instant::now();
353 }
354 }
355
356 ReplicationMessage::Acknowledge { lsn, .. } => {
358 self.quorum_tracker.ack(lsn, self.quorum());
359 }
360
361 _ => {}
362 }
363 Ok(())
364 }
365
366 pub async fn start_election(&self) -> bool {
373 if self.role().await == NodeRole::Master {
374 return false;
375 }
376 let elapsed = self.last_heartbeat.read().await.elapsed();
377 if elapsed < Duration::from_millis(200) {
378 return false; }
380
381 *self.role.write().await = NodeRole::Candidate;
383
384 let new_term = self.current_term.fetch_add(1, Ordering::SeqCst) + 1;
386 *self.voted_for.lock().await = Some(self.node_id); *self.votes_received.lock().await = 1; let last_lsn = self.last_lsn.load(Ordering::SeqCst);
390 let _ = self.tx.send(ReplicationMessage::VoteRequest {
391 node_id: self.node_id,
392 term: new_term,
393 last_lsn,
394 });
395
396 true
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    // Single-node cluster: quorum is 1, so one granted vote is enough
    // for promotion to Master.
    #[tokio::test]
    async fn test_single_node_wins_election() {
        let (tx, mut rx) = broadcast::channel(32);
        let node = ReplicationNode::new(1, 1, NodeRole::Slave, tx.clone());

        // Let the 200 ms election timeout elapse before campaigning.
        tokio::time::sleep(Duration::from_millis(250)).await;
        let started = node.start_election().await;
        assert!(started, "선거 시작되어야 함");

        // The node must have broadcast a VoteRequest for itself.
        let msg = rx.recv().await.unwrap();
        match msg {
            ReplicationMessage::VoteRequest {
                node_id,
                term,
                last_lsn,
            } => {
                assert_eq!(node_id, 1);
                // Feed a granted VoteResponse back into the candidate.
                let _ = node
                    .handle_message(
                        ReplicationMessage::VoteResponse {
                            node_id: 1,
                            voter_id: 1,
                            term,
                            granted: true,
                        },
                        &mut |_, _, _| Ok(()),
                    )
                    .await;
                let _ = last_lsn;
            }
            _ => panic!("VoteRequest 예상"),
        }

        assert_eq!(node.role().await, NodeRole::Master);
    }

    // Three-node cluster: one rejected vote must not promote; the second
    // granted vote (self-vote + 1) reaches the quorum of 2.
    #[tokio::test]
    async fn test_quorum_requires_majority() {
        let (tx, _rx) = broadcast::channel(32);
        let node = ReplicationNode::new(1, 3, NodeRole::Candidate, tx.clone());
        node.current_term.store(1, Ordering::SeqCst);
        // Simulate the self-vote a real election would have recorded.
        *node.votes_received.lock().await = 1;
        node.handle_message(
            ReplicationMessage::VoteResponse {
                node_id: 1,
                voter_id: 2,
                term: 1,
                granted: false,
            },
            &mut |_, _, _| Ok(()),
        )
        .await
        .unwrap();
        assert_eq!(
            node.role().await,
            NodeRole::Candidate,
            "아직 Candidate여야 함"
        );

        node.handle_message(
            ReplicationMessage::VoteResponse {
                node_id: 1,
                voter_id: 3,
                term: 1,
                granted: true,
            },
            &mut |_, _, _| Ok(()),
        )
        .await
        .unwrap();
        assert_eq!(
            node.role().await,
            NodeRole::Master,
            "과반 획득 후 Master여야 함"
        );
    }

    // A Promotion announcement with a higher term must demote a sitting
    // Master and advance its term.
    #[tokio::test]
    async fn test_higher_term_promotion_demotes_master() {
        let (tx, _rx) = broadcast::channel(32);
        let node = ReplicationNode::new(1, 3, NodeRole::Master, tx.clone());
        node.current_term.store(1, Ordering::SeqCst);

        node.handle_message(
            ReplicationMessage::Promotion {
                node_id: 2,
                term: 2,
            },
            &mut |_, _, _| Ok(()),
        )
        .await
        .unwrap();

        assert_eq!(
            node.role().await,
            NodeRole::Slave,
            "더 높은 term의 Promotion → Slave 강등"
        );
        assert_eq!(node.term(), 2);
    }

    // Only a Master may replicate; Slaves must get an error.
    #[tokio::test]
    async fn test_replicate_only_as_master() {
        let (tx, _rx) = broadcast::channel(16);
        let node = ReplicationNode::new(1, 1, NodeRole::Slave, tx.clone());
        let result = node.replicate(b"data".to_vec()).await;
        assert!(result.is_err(), "Slave는 복제 불가");
    }

    // Single-node cluster: quorum of 1 means replicate resolves without
    // waiting for any acknowledgements.
    #[tokio::test]
    async fn test_quorum_write_single_node() {
        let (tx, _rx) = broadcast::channel(16);
        let node = ReplicationNode::new(1, 1, NodeRole::Master, tx.clone());

        let lsn = node.replicate(b"data".to_vec()).await;
        assert_eq!(lsn, Ok(1), "단일 노드: quorum = 1 → 즉시 Ok(1)");
    }

    // Three nodes on one broadcast channel: the slaves' Acknowledge
    // messages, processed by the master's own receiver loop, satisfy
    // the quorum and unblock replicate().
    #[tokio::test]
    async fn test_quorum_write_three_nodes() {
        let (tx, _rx) = broadcast::channel(32);

        let master = Arc::new(ReplicationNode::new(1, 3, NodeRole::Master, tx.clone()));
        let slave2 = Arc::new(ReplicationNode::new(2, 3, NodeRole::Slave, tx.clone()));
        let slave3 = Arc::new(ReplicationNode::new(3, 3, NodeRole::Slave, tx.clone()));

        // The master needs its own receiver loop to count incoming acks.
        let master_rx_loop = Arc::clone(&master);
        let rx_master = tx.subscribe();
        tokio::spawn(async move {
            master_rx_loop
                .run_receiver_loop(rx_master, |_, _, _| Ok(()))
                .await
                .ok();
        });

        let slave2_clone = Arc::clone(&slave2);
        let rx2 = tx.subscribe();
        tokio::spawn(async move {
            slave2_clone
                .run_receiver_loop(rx2, |_, _, _| Ok(()))
                .await
                .ok();
        });

        let slave3_clone = Arc::clone(&slave3);
        let rx3 = tx.subscribe();
        tokio::spawn(async move {
            slave3_clone
                .run_receiver_loop(rx3, |_, _, _| Ok(()))
                .await
                .ok();
        });

        let lsn = master.replicate(b"quorum_data".to_vec()).await;
        assert_eq!(lsn, Ok(1), "quorum 달성 후 LSN 1 반환");
    }

    // No receiver loops running → no acks ever arrive → replicate must
    // fail with the quorum-timeout error.
    #[tokio::test]
    async fn test_quorum_write_timeout() {
        let (tx, _rx) = broadcast::channel(16);
        let mut node = ReplicationNode::new(1, 3, NodeRole::Master, tx.clone());
        // Shorten the timeout so the test stays fast.
        node.quorum_write_timeout = Duration::from_millis(50);

        let result = node.replicate(b"data".to_vec()).await;
        assert!(result.is_err(), "timeout → Err 반환 필요");
        assert!(
            result.unwrap_err().contains("quorum write timeout"),
            "에러 메시지에 'quorum write timeout' 포함되어야 함"
        );
    }

    // new_from_config must propagate the configured quorum write timeout.
    #[tokio::test]
    async fn test_new_from_config_injects_timeout() {
        use crate::replication::transport::ReplicationConfig;

        let config = ReplicationConfig {
            quorum_write_timeout: Duration::from_millis(123),
            ..ReplicationConfig::default()
        };
        let (tx, _rx) = broadcast::channel(16);
        let node = ReplicationNode::new_from_config(1, NodeRole::Master, tx, &config);
        assert_eq!(node.quorum_write_timeout, Duration::from_millis(123));
    }
}