reddb_server/replication/
commit_waiter.rs1use std::collections::HashMap;
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::sync::{Condvar, Mutex};
30use std::time::{Duration, Instant};
31
/// Mutable bookkeeping that `CommitWaiter` keeps behind its mutex.
#[derive(Default, Debug)]
struct State {
    /// Highest LSN recorded so far for each replica, keyed by replica id.
    durable_lsn: HashMap<String, u64>,
}
39
/// Counters describing how `CommitWaiter::await_acks` calls have resolved.
///
/// Every field starts at zero.
#[derive(Default, Debug)]
pub struct CommitWaiterMetrics {
    /// Number of waits that returned because enough replicas had acknowledged.
    pub reached_total: AtomicU64,
    /// Number of waits that returned because the timeout elapsed first.
    pub timed_out_total: AtomicU64,
    /// Number of calls that returned immediately because zero acks were required.
    pub not_required_total: AtomicU64,
    /// Elapsed microseconds of the most recent reached / timed-out wait.
    /// Written as at least 1, so 0 means no such wait has completed yet.
    pub last_wait_micros: AtomicU64,
}
52
53#[derive(Debug)]
54pub struct CommitWaiter {
55 state: Mutex<State>,
56 cond: Condvar,
57 metrics: CommitWaiterMetrics,
58}
59
60impl Default for CommitWaiter {
61 fn default() -> Self {
62 Self {
63 state: Mutex::new(State::default()),
64 cond: Condvar::new(),
65 metrics: CommitWaiterMetrics::default(),
66 }
67 }
68}
69
/// Result of a single `CommitWaiter::await_acks` call.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AwaitOutcome {
    /// Enough replicas were at or past the target LSN; carries how many were
    /// counted when the wait ended.
    Reached(u32),
    /// The timeout elapsed first. `observed` is the number of replicas at or
    /// past the target at that moment; `required` echoes the requested threshold.
    TimedOut { observed: u32, required: u32 },
    /// The caller asked for zero acknowledgements, so nothing was awaited.
    NotRequired,
}
85
86impl CommitWaiter {
87 pub fn new() -> Self {
88 Self::default()
89 }
90
91 pub fn record_replica_ack(&self, replica_id: &str, lsn: u64) {
95 let mut state = self.state.lock().expect("commit waiter mutex");
96 let entry = state.durable_lsn.entry(replica_id.to_string()).or_insert(0);
97 if lsn > *entry {
98 *entry = lsn;
99 self.cond.notify_all();
100 }
101 }
102
103 pub fn drop_replica(&self, replica_id: &str) {
110 let mut state = self.state.lock().expect("commit waiter mutex");
111 if state.durable_lsn.remove(replica_id).is_some() {
112 self.cond.notify_all();
113 }
114 }
115
116 pub fn snapshot(&self) -> Vec<(String, u64)> {
119 let state = self.state.lock().expect("commit waiter mutex");
120 let mut v: Vec<(String, u64)> = state
121 .durable_lsn
122 .iter()
123 .map(|(k, v)| (k.clone(), *v))
124 .collect();
125 v.sort_by(|a, b| a.0.cmp(&b.0));
126 v
127 }
128
129 pub fn await_acks(&self, target_lsn: u64, required: u32, timeout: Duration) -> AwaitOutcome {
138 if required == 0 {
139 self.metrics
140 .not_required_total
141 .fetch_add(1, Ordering::Relaxed);
142 return AwaitOutcome::NotRequired;
143 }
144 let started = Instant::now();
145 let deadline = started + timeout;
146 let mut state = self.state.lock().expect("commit waiter mutex");
147 loop {
148 let observed = count_at_or_past(&state.durable_lsn, target_lsn);
149 if observed >= required {
150 self.record_outcome_metrics(true, started);
151 return AwaitOutcome::Reached(observed);
152 }
153 let now = Instant::now();
154 if now >= deadline {
155 self.record_outcome_metrics(false, started);
156 return AwaitOutcome::TimedOut { observed, required };
157 }
158 let remaining = deadline - now;
159 let (next_state, _wait_result) = self
160 .cond
161 .wait_timeout(state, remaining)
162 .expect("commit waiter condvar");
163 state = next_state;
164 }
165 }
166
167 fn record_outcome_metrics(&self, reached: bool, started: Instant) {
168 let elapsed = (started.elapsed().as_micros() as u64).max(1);
169 self.metrics
170 .last_wait_micros
171 .store(elapsed, Ordering::Relaxed);
172 if reached {
173 self.metrics.reached_total.fetch_add(1, Ordering::Relaxed);
174 } else {
175 self.metrics.timed_out_total.fetch_add(1, Ordering::Relaxed);
176 }
177 }
178
179 pub fn metrics_snapshot(&self) -> (u64, u64, u64, u64) {
181 (
182 self.metrics.reached_total.load(Ordering::Relaxed),
183 self.metrics.timed_out_total.load(Ordering::Relaxed),
184 self.metrics.not_required_total.load(Ordering::Relaxed),
185 self.metrics.last_wait_micros.load(Ordering::Relaxed),
186 )
187 }
188}
189
/// Counts the entries in `map` whose LSN is greater than or equal to `target_lsn`.
fn count_at_or_past(map: &HashMap<String, u64>, target_lsn: u64) -> u32 {
    let mut reached: usize = 0;
    for &lsn in map.values() {
        if lsn >= target_lsn {
            reached += 1;
        }
    }
    reached as u32
}
193
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Asking for zero acks returns at once, regardless of the timeout.
    #[test]
    fn required_zero_is_immediate_no_op() {
        let waiter = CommitWaiter::new();
        assert_eq!(
            waiter.await_acks(100, 0, Duration::from_secs(60)),
            AwaitOutcome::NotRequired
        );
    }

    /// Acks recorded before the call satisfy the wait without blocking.
    #[test]
    fn reaches_threshold_with_existing_acks() {
        let waiter = CommitWaiter::new();
        for replica in ["a", "b"] {
            waiter.record_replica_ack(replica, 200);
        }
        let outcome = waiter.await_acks(150, 2, Duration::from_millis(10));
        assert_eq!(outcome, AwaitOutcome::Reached(2));
    }

    /// A replica that is behind the target LSN does not count, so the wait
    /// times out with zero observed acks.
    #[test]
    fn times_out_when_no_one_has_acked() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 100);
        match waiter.await_acks(500, 1, Duration::from_millis(20)) {
            AwaitOutcome::TimedOut { observed, required } => {
                assert_eq!(observed, 0);
                assert_eq!(required, 1);
            }
            other => panic!("expected TimedOut, got {other:?}"),
        }
    }

    /// An ack recorded from another thread wakes a caller that is already blocked.
    #[test]
    fn ack_arriving_during_wait_unblocks_caller() {
        let shared = Arc::new(CommitWaiter::new());
        let blocked = {
            let waiter = Arc::clone(&shared);
            thread::spawn(move || waiter.await_acks(1000, 1, Duration::from_secs(2)))
        };
        // Give the spawned thread time to start waiting before the ack lands.
        thread::sleep(Duration::from_millis(50));
        shared.record_replica_ack("late", 1000);
        let outcome = blocked.join().expect("waiter thread");
        assert_eq!(outcome, AwaitOutcome::Reached(1));
    }

    /// Repeating the same ack never makes one replica count more than once.
    #[test]
    fn ack_idempotent_does_not_double_count() {
        let waiter = CommitWaiter::new();
        for _ in 0..3 {
            waiter.record_replica_ack("a", 50);
        }
        let single = waiter.await_acks(50, 1, Duration::from_millis(5));
        assert_eq!(single, AwaitOutcome::Reached(1));
        let pair = waiter.await_acks(50, 2, Duration::from_millis(20));
        assert!(matches!(
            pair,
            AwaitOutcome::TimedOut {
                observed: 1,
                required: 2
            }
        ));
    }

    /// A lower LSN arriving after a higher one leaves the recorded value alone.
    #[test]
    fn ack_only_advances_lsn_forward() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 200);
        waiter.record_replica_ack("a", 100);
        assert_eq!(waiter.snapshot(), vec![("a".to_string(), 200)]);
    }

    /// A dropped replica stops counting toward the threshold.
    #[test]
    fn drop_replica_removes_from_count() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 100);
        waiter.record_replica_ack("b", 100);
        waiter.drop_replica("a");
        let outcome = waiter.await_acks(100, 2, Duration::from_millis(20));
        assert!(matches!(
            outcome,
            AwaitOutcome::TimedOut {
                observed: 1,
                required: 2
            }
        ));
    }

    /// One call of each outcome kind bumps exactly its own counter.
    #[test]
    fn metrics_count_each_outcome_kind() {
        let waiter = CommitWaiter::new();
        let brief = Duration::from_millis(5);
        // NotRequired, then TimedOut (no acks yet), then Reached.
        waiter.await_acks(100, 0, brief);
        waiter.await_acks(100, 1, brief);
        waiter.record_replica_ack("a", 100);
        waiter.await_acks(100, 1, brief);

        let (reached, timed_out, not_required, last_micros) = waiter.metrics_snapshot();
        assert_eq!(reached, 1, "one Reached call");
        assert_eq!(timed_out, 1, "one TimedOut call");
        assert_eq!(not_required, 1, "one NotRequired call");
        assert!(last_micros > 0, "last_wait_micros must be set");
    }
}