reddb_server/replication/
commit_waiter.rs1use std::collections::HashMap;
26use std::sync::atomic::{AtomicU64, Ordering};
27use std::sync::{Condvar, Mutex};
28use std::time::{Duration, Instant};
29
/// Mutable bookkeeping that `CommitWaiter` keeps behind its mutex.
#[derive(Default, Debug)]
struct State {
    /// Highest LSN each replica has acknowledged, keyed by replica id.
    /// `record_replica_ack` only ever moves an entry forward.
    durable_lsn: HashMap<String, u64>,
}
37
/// Atomic counters describing how `CommitWaiter::await_acks` calls resolved.
///
/// Every update uses `Ordering::Relaxed`: these are diagnostics and are not
/// used to synchronize anything else.
#[derive(Default, Debug)]
pub struct CommitWaiterMetrics {
    /// Number of calls that returned `AwaitOutcome::Reached`.
    pub reached_total: AtomicU64,
    /// Number of calls that returned `AwaitOutcome::TimedOut`.
    pub timed_out_total: AtomicU64,
    /// Number of calls short-circuited as `AwaitOutcome::NotRequired` (`required == 0`).
    pub not_required_total: AtomicU64,
    /// Duration of the most recent reached/timed-out wait, in microseconds.
    pub last_wait_micros: AtomicU64,
}
50
51#[derive(Debug)]
52pub struct CommitWaiter {
53 state: Mutex<State>,
54 cond: Condvar,
55 metrics: CommitWaiterMetrics,
56}
57
58impl Default for CommitWaiter {
59 fn default() -> Self {
60 Self {
61 state: Mutex::new(State::default()),
62 cond: Condvar::new(),
63 metrics: CommitWaiterMetrics::default(),
64 }
65 }
66}
67
/// How a call to `CommitWaiter::await_acks` finished.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AwaitOutcome {
    /// The commit watermark reached the target LSN. The payload counts the
    /// replicas whose durable LSN was at or past the target when the wait ended.
    Reached(u32),
    /// The timeout elapsed first. `observed` counts the replicas at or past the
    /// target at that moment; `required` echoes what the caller asked for.
    TimedOut { observed: u32, required: u32 },
    /// `required` was 0, so no acknowledgement was waited for.
    NotRequired,
}
83
84impl CommitWaiter {
85 pub fn new() -> Self {
86 Self::default()
87 }
88
89 pub fn record_replica_ack(&self, replica_id: &str, lsn: u64) {
93 let mut state = self.state.lock().expect("commit waiter mutex");
94 let entry = state.durable_lsn.entry(replica_id.to_string()).or_insert(0);
95 if lsn > *entry {
96 *entry = lsn;
97 self.cond.notify_all();
98 }
99 }
100
101 pub fn drop_replica(&self, replica_id: &str) {
108 let mut state = self.state.lock().expect("commit waiter mutex");
109 if state.durable_lsn.remove(replica_id).is_some() {
110 self.cond.notify_all();
111 }
112 }
113
114 pub fn snapshot(&self) -> Vec<(String, u64)> {
117 let state = self.state.lock().expect("commit waiter mutex");
118 let mut v: Vec<(String, u64)> = state
119 .durable_lsn
120 .iter()
121 .map(|(k, v)| (k.clone(), *v))
122 .collect();
123 v.sort_by(|a, b| a.0.cmp(&b.0));
124 v
125 }
126
127 pub fn commit_watermark(&self, required: u32) -> u64 {
135 let state = self.state.lock().expect("commit waiter mutex");
136 commit_watermark(&state.durable_lsn, required)
137 }
138
139 pub fn await_acks(&self, target_lsn: u64, required: u32, timeout: Duration) -> AwaitOutcome {
148 if required == 0 {
149 self.metrics
150 .not_required_total
151 .fetch_add(1, Ordering::Relaxed);
152 return AwaitOutcome::NotRequired;
153 }
154 let started = Instant::now();
155 let deadline = started + timeout;
156 let mut state = self.state.lock().expect("commit waiter mutex");
157 loop {
158 let watermark = commit_watermark(&state.durable_lsn, required);
159 if watermark >= target_lsn {
160 self.record_outcome_metrics(true, started);
161 let observed = count_at_or_past(&state.durable_lsn, target_lsn);
162 return AwaitOutcome::Reached(observed);
163 }
164 let now = Instant::now();
165 if now >= deadline {
166 self.record_outcome_metrics(false, started);
167 let observed = count_at_or_past(&state.durable_lsn, target_lsn);
168 return AwaitOutcome::TimedOut { observed, required };
169 }
170 let remaining = deadline - now;
171 let (next_state, _wait_result) = self
172 .cond
173 .wait_timeout(state, remaining)
174 .expect("commit waiter condvar");
175 state = next_state;
176 }
177 }
178
179 pub fn wait_for_change_until<F>(&self, timeout: Option<Duration>, mut is_satisfied: F) -> bool
182 where
183 F: FnMut() -> bool,
184 {
185 let started = Instant::now();
186 let mut state = self.state.lock().expect("commit waiter mutex");
187 loop {
188 if is_satisfied() {
189 return true;
190 }
191 let Some(limit) = timeout else {
192 state = self.cond.wait(state).expect("commit waiter condvar");
193 continue;
194 };
195 let elapsed = started.elapsed();
196 if elapsed >= limit {
197 return false;
198 }
199 let remaining = limit - elapsed;
200 let (next_state, _wait_result) = self
201 .cond
202 .wait_timeout(state, remaining)
203 .expect("commit waiter condvar");
204 state = next_state;
205 }
206 }
207
208 pub fn wait_for_commit_watermark(
210 &self,
211 target_lsn: u64,
212 required: u32,
213 timeout: Option<Duration>,
214 ) -> bool {
215 if required == 0 {
216 return true;
217 }
218 let started = Instant::now();
219 let mut state = self.state.lock().expect("commit waiter mutex");
220 loop {
221 if commit_watermark(&state.durable_lsn, required) >= target_lsn {
222 return true;
223 }
224 let Some(limit) = timeout else {
225 state = self.cond.wait(state).expect("commit waiter condvar");
226 continue;
227 };
228 let elapsed = started.elapsed();
229 if elapsed >= limit {
230 return false;
231 }
232 let remaining = limit - elapsed;
233 let (next_state, _wait_result) = self
234 .cond
235 .wait_timeout(state, remaining)
236 .expect("commit waiter condvar");
237 state = next_state;
238 }
239 }
240
241 fn record_outcome_metrics(&self, reached: bool, started: Instant) {
242 let elapsed = (started.elapsed().as_micros() as u64).max(1);
243 self.metrics
244 .last_wait_micros
245 .store(elapsed, Ordering::Relaxed);
246 if reached {
247 self.metrics.reached_total.fetch_add(1, Ordering::Relaxed);
248 } else {
249 self.metrics.timed_out_total.fetch_add(1, Ordering::Relaxed);
250 }
251 }
252
253 pub fn metrics_snapshot(&self) -> (u64, u64, u64, u64) {
255 (
256 self.metrics.reached_total.load(Ordering::Relaxed),
257 self.metrics.timed_out_total.load(Ordering::Relaxed),
258 self.metrics.not_required_total.load(Ordering::Relaxed),
259 self.metrics.last_wait_micros.load(Ordering::Relaxed),
260 )
261 }
262}
263
/// Counts the replicas whose durable LSN is at or past `target_lsn`.
fn count_at_or_past(map: &HashMap<String, u64>, target_lsn: u64) -> u32 {
    let mut caught_up = 0usize;
    for &lsn in map.values() {
        if lsn >= target_lsn {
            caught_up += 1;
        }
    }
    caught_up as u32
}
267
/// Returns the `required`-th highest durable LSN in `map`. This is the largest
/// LSN that at least `required` replicas have reached.
///
/// Returns 0 when `required` is 0 or fewer than `required` replicas are tracked.
/// Note that 0 is also a legitimate LSN value, so a 0 result does not by itself
/// distinguish "too few replicas" from "replicas at LSN 0".
fn commit_watermark(map: &HashMap<String, u64>, required: u32) -> u64 {
    // Zero-based rank of the wanted value in descending order. Bail out when
    // there is nothing to wait for or not enough replicas.
    let rank = match usize::try_from(required) {
        Ok(r) if r > 0 && r <= map.len() => r - 1,
        _ => return 0,
    };
    let mut durable: Vec<u64> = map.values().copied().collect();
    // Only the element at `rank` (descending) matters, so an O(n) selection is
    // enough; there is no need to fully sort the slice.
    let (_, nth, _) = durable.select_nth_unstable_by(rank, |a, b| b.cmp(a));
    *nth
}
276
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Shorthand for a timeout expressed in milliseconds.
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    #[test]
    fn required_zero_is_immediate_no_op() {
        let waiter = CommitWaiter::new();
        // A long timeout must not be waited on when no replica is required.
        assert_eq!(
            waiter.await_acks(100, 0, Duration::from_secs(60)),
            AwaitOutcome::NotRequired
        );
    }

    #[test]
    fn reaches_threshold_with_existing_acks() {
        let waiter = CommitWaiter::new();
        for replica in ["a", "b"] {
            waiter.record_replica_ack(replica, 200);
        }
        assert_eq!(waiter.await_acks(150, 2, ms(10)), AwaitOutcome::Reached(2));
    }

    #[test]
    fn commit_watermark_is_nth_highest_durable_lsn() {
        let waiter = CommitWaiter::new();
        for (replica, lsn) in [("a", 10), ("b", 30), ("c", 20)] {
            waiter.record_replica_ack(replica, lsn);
        }
        for (required, expected) in [(1, 30), (2, 20), (3, 10), (4, 0)] {
            assert_eq!(
                waiter.commit_watermark(required),
                expected,
                "required = {required}"
            );
        }
        // A stale (lower) ack for "b" must not drag the watermark down.
        waiter.record_replica_ack("b", 15);
        assert_eq!(waiter.commit_watermark(2), 20);
    }

    #[test]
    fn times_out_when_no_one_has_acked() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 100);
        assert_eq!(
            waiter.await_acks(500, 1, ms(20)),
            AwaitOutcome::TimedOut {
                observed: 0,
                required: 1
            }
        );
    }

    #[test]
    fn ack_arriving_during_wait_unblocks_caller() {
        let shared = Arc::new(CommitWaiter::new());
        let handle = {
            let waiter = Arc::clone(&shared);
            thread::spawn(move || waiter.await_acks(1000, 1, Duration::from_secs(2)))
        };
        // Give the spawned thread time to block before the ack lands.
        thread::sleep(ms(50));
        shared.record_replica_ack("late", 1000);
        let outcome = handle.join().expect("waiter thread");
        assert_eq!(outcome, AwaitOutcome::Reached(1));
    }

    #[test]
    fn ack_idempotent_does_not_double_count() {
        let waiter = CommitWaiter::new();
        (0..3).for_each(|_| waiter.record_replica_ack("a", 50));
        assert_eq!(waiter.await_acks(50, 1, ms(5)), AwaitOutcome::Reached(1));
        // Repeated acks from one replica still count as a single replica.
        assert_eq!(
            waiter.await_acks(50, 2, ms(20)),
            AwaitOutcome::TimedOut {
                observed: 1,
                required: 2
            }
        );
    }

    #[test]
    fn ack_only_advances_lsn_forward() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 200);
        waiter.record_replica_ack("a", 100);
        assert_eq!(waiter.snapshot(), vec![(String::from("a"), 200)]);
    }

    #[test]
    fn drop_replica_removes_from_count() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 100);
        waiter.record_replica_ack("b", 100);
        waiter.drop_replica("a");
        assert_eq!(
            waiter.await_acks(100, 2, ms(20)),
            AwaitOutcome::TimedOut {
                observed: 1,
                required: 2
            }
        );
    }

    #[test]
    fn metrics_count_each_outcome_kind() {
        let waiter = CommitWaiter::new();
        let _ = waiter.await_acks(100, 0, ms(5)); // NotRequired
        let _ = waiter.await_acks(100, 1, ms(5)); // TimedOut: no replicas yet
        waiter.record_replica_ack("a", 100);
        let _ = waiter.await_acks(100, 1, ms(5)); // Reached

        let (reached, timed_out, not_required, last_micros) = waiter.metrics_snapshot();
        assert_eq!(reached, 1, "one Reached call");
        assert_eq!(timed_out, 1, "one TimedOut call");
        assert_eq!(not_required, 1, "one NotRequired call");
        assert!(last_micros > 0, "last_wait_micros must be set");
    }
}