dbx_core/replication/
slave.rs1use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use tokio::sync::broadcast;
6
7use crate::replication::protocol::ReplicationMessage;
8
/// Errors that can abort a WAL replay on a [`ReplicationSlave`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReplayError {
    /// Receiving from the replication channel failed; carries the receiver's
    /// error message.
    ChannelError(String),
    /// The caller-supplied apply callback rejected a WAL entry; carries the
    /// message the callback returned.
    ApplyError(String),
}

impl std::fmt::Display for ReplayError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ReplayError::ChannelError(s) => write!(f, "channel error: {s}"),
            ReplayError::ApplyError(s) => write!(f, "apply error: {s}"),
        }
    }
}

// Implementing `Error` lets callers box this as `dyn std::error::Error` and
// propagate it with `?` through generic error types. The payloads are plain
// strings, so there is no underlying `source()` to expose.
impl std::error::Error for ReplayError {}
26
/// Replica-side state for replaying WAL entries received over a replication
/// channel.
///
/// It tracks the LSN of the most recently applied entry, so replays skip
/// entries whose LSN is at or below it.
#[derive(Debug)]
pub struct ReplicationSlave {
    /// LSN of the last successfully applied WAL entry. It holds `u64::MAX`
    /// while nothing has been applied yet, which the public getter reports
    /// as `None`.
    ///
    /// NOTE(review): this is wrapped in `Arc`, but nothing visible here
    /// clones it. Confirm whether it is shared elsewhere.
    last_applied_lsn: Arc<AtomicU64>,
}
37
38impl ReplicationSlave {
39 pub fn new() -> Self {
40 Self {
41 last_applied_lsn: Arc::new(AtomicU64::new(u64::MAX)), }
43 }
44
45 pub fn last_applied_lsn(&self) -> Option<u64> {
47 let v = self.last_applied_lsn.load(Ordering::SeqCst);
48 if v == u64::MAX { None } else { Some(v) }
49 }
50
51 pub async fn replay_n<F>(
56 &self,
57 rx: &mut broadcast::Receiver<ReplicationMessage>,
58 count: usize,
59 mut apply_fn: F,
60 ) -> Result<usize, ReplayError>
61 where
62 F: FnMut(u64, &[u8]) -> Result<(), String>,
63 {
64 let mut applied = 0;
65
66 loop {
67 if count > 0 && applied >= count {
68 break;
69 }
70
71 let msg = rx
72 .recv()
73 .await
74 .map_err(|e| ReplayError::ChannelError(e.to_string()))?;
75
76 if let ReplicationMessage::WalEntry { lsn, data, .. } = msg {
77 let last = self.last_applied_lsn.load(Ordering::SeqCst);
79 if last != u64::MAX && lsn <= last {
80 continue;
82 }
83
84 apply_fn(lsn, &data).map_err(ReplayError::ApplyError)?;
85 self.last_applied_lsn.store(lsn, Ordering::SeqCst);
86 applied += 1;
87 }
88 }
90
91 Ok(applied)
92 }
93}
94
95impl Default for ReplicationSlave {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::replication::master::ReplicationMaster;

    /// Entries replicated by the master are applied in order, and the
    /// slave's last-applied LSN tracks the final one.
    #[tokio::test]
    async fn test_slave_replays_wal_entries() {
        let (master, mut rx) = ReplicationMaster::new(16);
        let slave = ReplicationSlave::new();
        // A fresh slave has not applied anything yet.
        assert_eq!(slave.last_applied_lsn(), None);

        master.replicate(b"entry_0".to_vec());
        master.replicate(b"entry_1".to_vec());
        master.replicate(b"entry_2".to_vec());

        let mut replayed = Vec::new();
        let count = slave
            .replay_n(&mut rx, 3, |lsn, data| {
                replayed.push((lsn, data.to_vec()));
                Ok(())
            })
            .await
            .unwrap();

        assert_eq!(count, 3);
        assert_eq!(slave.last_applied_lsn(), Some(2));
        assert_eq!(
            replayed,
            vec![
                (0, b"entry_0".to_vec()),
                (1, b"entry_1".to_vec()),
                (2, b"entry_2".to_vec()),
            ]
        );
    }

    /// An error from the apply callback is surfaced as `ApplyError` with the
    /// callback's message. The rejected entry is not recorded as applied.
    #[tokio::test]
    async fn test_slave_apply_error_propagates() {
        let (master, mut rx) = ReplicationMaster::new(16);
        let slave = ReplicationSlave::new();

        master.replicate(b"bad_data".to_vec());

        let result = slave
            .replay_n(&mut rx, 1, |_lsn, _data| {
                Err("intentional error".to_string())
            })
            .await;

        assert!(matches!(
            &result,
            Err(ReplayError::ApplyError(msg)) if msg == "intentional error"
        ));
        // The failed apply must not advance the last-applied LSN.
        assert_eq!(slave.last_applied_lsn(), None);
    }
}