Skip to main content

reddb_server/replication/
commit_waiter.rs

1//! Synchronous commit waiter (PLAN.md Phase 11.4 — `ack_n`).
2//!
3//! Bridges the primary's commit path with replica ACKs. The commit
4//! caller picks a `target_lsn` (the LSN it just made durable
5//! locally) and asks the waiter "block until at least N replicas
//! have ack'd this LSN, or the timeout expires." Replica ACK RPCs
//! call `record_replica_ack`, which (when the replica's LSN advances)
//! wakes every waiter so each can recheck whether its threshold is met.
9//!
10//! ## Thread safety
11//!
//! The waiter uses a `Mutex<State>` + `Condvar` so the `await_acks`
//! call blocks the caller's thread without spinning. Acks advance a
//! per-replica `durable_lsn` map and broadcast on the condvar.
//! Waiters wake, recompute the count of replicas at or past their
//! target, and either return `AwaitOutcome::Reached(count)`, re-wait,
//! or return `AwaitOutcome::TimedOut` once the deadline has passed.
17//!
18//! ## Why this is just the foundation
19//!
20//! The actual write commit path doesn't yet call `await_acks` —
21//! wiring it in touches every public mutation surface and changes
22//! latency characteristics across the board. This module ships the
23//! primitive + the ack registry so the wiring change can land as
24//! one focused PR per surface (HTTP, gRPC, wire protocol) rather
25//! than a single massive diff.
26
27use std::collections::HashMap;
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::sync::{Condvar, Mutex};
30use std::time::{Duration, Instant};
31
/// Mutable state guarded by the `CommitWaiter` mutex.
#[derive(Debug, Default)]
struct State {
    /// Highest durable LSN each replica has acknowledged, keyed by
    /// replica id. Advanced by `record_replica_ack`, pruned by
    /// `drop_replica`. A replica with no entry has not acked yet: it is
    /// not counted toward any ack quorum and does not contribute to the
    /// commit watermark (an entry recorded at LSN 0 *is* counted).
    durable_lsn: HashMap<String, u64>,
}
39
/// Outcome counters for /metrics (PLAN.md Phase 11.4).
///
/// Operators alert on `timed_out_total` rising: either the commit
/// policy is too tight or replicas are stalled. Every field is an
/// independent atomic updated with `Ordering::Relaxed`, so reading
/// several of them does not give a mutually consistent snapshot.
#[derive(Debug, Default)]
pub struct CommitWaiterMetrics {
    /// `await_acks` calls that saw enough replicas reach the target.
    pub reached_total: AtomicU64,
    /// `await_acks` calls whose deadline expired first.
    pub timed_out_total: AtomicU64,
    /// `await_acks` calls with `required == 0` (returned immediately).
    pub not_required_total: AtomicU64,
    /// Wall-clock micros of the most recent `Reached` or `TimedOut`
    /// wait. This is a last-value gauge, not a histogram, to keep the
    /// no-extra-deps line. On its own it cannot provide a p95; sampling
    /// it at scrape time gives only a rough view of the distribution.
    pub last_wait_micros: AtomicU64,
}
52
53#[derive(Debug)]
54pub struct CommitWaiter {
55    state: Mutex<State>,
56    cond: Condvar,
57    metrics: CommitWaiterMetrics,
58}
59
60impl Default for CommitWaiter {
61    fn default() -> Self {
62        Self {
63            state: Mutex::new(State::default()),
64            cond: Condvar::new(),
65            metrics: CommitWaiterMetrics::default(),
66        }
67    }
68}
69
/// Result of a `CommitWaiter::await_acks` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AwaitOutcome {
    /// At least `required` replicas were at or past `target_lsn`
    /// before the deadline. Carries the count observed when the wait
    /// ended, which can exceed `required` if extra replicas were
    /// already ahead.
    Reached(u32),
    /// The deadline passed with fewer than `required` replicas at or
    /// past `target_lsn`. `observed` is how many were there, so the
    /// caller can log how close the commit came.
    TimedOut { observed: u32, required: u32 },
    /// `required == 0`: nothing to wait for. Returned immediately,
    /// without looking at replica state.
    NotRequired,
}
85
86impl CommitWaiter {
87    pub fn new() -> Self {
88        Self::default()
89    }
90
91    /// Replica reports it has durably persisted up to `lsn`.
92    /// Idempotent: only advances forward. Wakes every waiter so they
93    /// can recheck their threshold.
94    pub fn record_replica_ack(&self, replica_id: &str, lsn: u64) {
95        let mut state = self.state.lock().expect("commit waiter mutex");
96        let entry = state.durable_lsn.entry(replica_id.to_string()).or_insert(0);
97        if lsn > *entry {
98            *entry = lsn;
99            self.cond.notify_all();
100        }
101    }
102
103    /// Best-effort cleanup when a replica disconnects. Removes its
104    /// durable LSN from the map so it doesn't artificially inflate
105    /// `ack_n` counts. Wakes waiters because the count of replicas
106    /// at the target may have decreased — they need to re-evaluate
107    /// against the new reality (some will start failing if their
108    /// margin was thin).
109    pub fn drop_replica(&self, replica_id: &str) {
110        let mut state = self.state.lock().expect("commit waiter mutex");
111        if state.durable_lsn.remove(replica_id).is_some() {
112            self.cond.notify_all();
113        }
114    }
115
116    /// Snapshot of the current durable-LSN map. Useful for
117    /// observability and tests; doesn't unblock waiters.
118    pub fn snapshot(&self) -> Vec<(String, u64)> {
119        let state = self.state.lock().expect("commit waiter mutex");
120        let mut v: Vec<(String, u64)> = state
121            .durable_lsn
122            .iter()
123            .map(|(k, v)| (k.clone(), *v))
124            .collect();
125        v.sort_by(|a, b| a.0.cmp(&b.0));
126        v
127    }
128
129    /// Highest LSN durable on `required` replicas.
130    ///
131    /// For `ack_n=2`, this is the second-highest durable LSN in the
132    /// ack table. If fewer than `required` replicas have acked, the
133    /// watermark is 0. `required == 0` is not a quorum requirement, so
134    /// observability reports 0 rather than fabricating an infinite
135    /// watermark.
136    pub fn commit_watermark(&self, required: u32) -> u64 {
137        let state = self.state.lock().expect("commit waiter mutex");
138        commit_watermark(&state.durable_lsn, required)
139    }
140
141    /// Block until at least `required` replicas have durable LSN
142    /// `>= target_lsn`, or `timeout` expires. `required == 0` is a
143    /// no-op (returns `NotRequired` instantly).
144    ///
145    /// Uses `Condvar::wait_timeout` to avoid spinning. On every wake
146    /// (whether from an ack or a spurious wakeup), we recompute the
147    /// count and either return or wait again with the remaining
148    /// budget.
149    pub fn await_acks(&self, target_lsn: u64, required: u32, timeout: Duration) -> AwaitOutcome {
150        if required == 0 {
151            self.metrics
152                .not_required_total
153                .fetch_add(1, Ordering::Relaxed);
154            return AwaitOutcome::NotRequired;
155        }
156        let started = Instant::now();
157        let deadline = started + timeout;
158        let mut state = self.state.lock().expect("commit waiter mutex");
159        loop {
160            let watermark = commit_watermark(&state.durable_lsn, required);
161            if watermark >= target_lsn {
162                self.record_outcome_metrics(true, started);
163                let observed = count_at_or_past(&state.durable_lsn, target_lsn);
164                return AwaitOutcome::Reached(observed);
165            }
166            let now = Instant::now();
167            if now >= deadline {
168                self.record_outcome_metrics(false, started);
169                let observed = count_at_or_past(&state.durable_lsn, target_lsn);
170                return AwaitOutcome::TimedOut { observed, required };
171            }
172            let remaining = deadline - now;
173            let (next_state, _wait_result) = self
174                .cond
175                .wait_timeout(state, remaining)
176                .expect("commit waiter condvar");
177            state = next_state;
178        }
179    }
180
181    /// Wait until `is_satisfied` returns true, waking on replica ack
182    /// notifications instead of a caller-side polling sleep.
183    pub fn wait_for_change_until<F>(&self, timeout: Option<Duration>, mut is_satisfied: F) -> bool
184    where
185        F: FnMut() -> bool,
186    {
187        let started = Instant::now();
188        let mut state = self.state.lock().expect("commit waiter mutex");
189        loop {
190            if is_satisfied() {
191                return true;
192            }
193            let Some(limit) = timeout else {
194                state = self.cond.wait(state).expect("commit waiter condvar");
195                continue;
196            };
197            let elapsed = started.elapsed();
198            if elapsed >= limit {
199                return false;
200            }
201            let remaining = limit - elapsed;
202            let (next_state, _wait_result) = self
203                .cond
204                .wait_timeout(state, remaining)
205                .expect("commit waiter condvar");
206            state = next_state;
207        }
208    }
209
210    /// Wait until the named commit watermark reaches `target_lsn`.
211    pub fn wait_for_commit_watermark(
212        &self,
213        target_lsn: u64,
214        required: u32,
215        timeout: Option<Duration>,
216    ) -> bool {
217        if required == 0 {
218            return true;
219        }
220        let started = Instant::now();
221        let mut state = self.state.lock().expect("commit waiter mutex");
222        loop {
223            if commit_watermark(&state.durable_lsn, required) >= target_lsn {
224                return true;
225            }
226            let Some(limit) = timeout else {
227                state = self.cond.wait(state).expect("commit waiter condvar");
228                continue;
229            };
230            let elapsed = started.elapsed();
231            if elapsed >= limit {
232                return false;
233            }
234            let remaining = limit - elapsed;
235            let (next_state, _wait_result) = self
236                .cond
237                .wait_timeout(state, remaining)
238                .expect("commit waiter condvar");
239            state = next_state;
240        }
241    }
242
243    fn record_outcome_metrics(&self, reached: bool, started: Instant) {
244        let elapsed = (started.elapsed().as_micros() as u64).max(1);
245        self.metrics
246            .last_wait_micros
247            .store(elapsed, Ordering::Relaxed);
248        if reached {
249            self.metrics.reached_total.fetch_add(1, Ordering::Relaxed);
250        } else {
251            self.metrics.timed_out_total.fetch_add(1, Ordering::Relaxed);
252        }
253    }
254
255    /// Snapshot of outcome counters for /metrics + tests.
256    pub fn metrics_snapshot(&self) -> (u64, u64, u64, u64) {
257        (
258            self.metrics.reached_total.load(Ordering::Relaxed),
259            self.metrics.timed_out_total.load(Ordering::Relaxed),
260            self.metrics.not_required_total.load(Ordering::Relaxed),
261            self.metrics.last_wait_micros.load(Ordering::Relaxed),
262        )
263    }
264}
265
/// Number of replicas whose recorded durable LSN is `>= target_lsn`.
fn count_at_or_past(map: &HashMap<String, u64>, target_lsn: u64) -> u32 {
    let mut hits: usize = 0;
    for &lsn in map.values() {
        if lsn >= target_lsn {
            hits += 1;
        }
    }
    // NOTE(review): narrowing `as` cast; assumes replica counts stay far
    // below `u32::MAX`.
    hits as u32
}
269
/// The `required`-th highest durable LSN in `map`, i.e. the highest LSN
/// that at least `required` replicas have made durable.
///
/// Returns 0 when `required` is 0 or fewer than `required` replicas have
/// an entry.
fn commit_watermark(map: &HashMap<String, u64>, required: u32) -> u64 {
    let needed = required as usize;
    if needed == 0 || needed > map.len() {
        return 0;
    }
    let mut lsns: Vec<u64> = map.values().copied().collect();
    lsns.sort_unstable();
    // Ascending order: the `needed`-th highest sits `needed` slots from
    // the end, and the guard above keeps this index in bounds.
    lsns[lsns.len() - needed]
}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn required_zero_is_immediate_no_op() {
        let waiter = CommitWaiter::new();
        let outcome = waiter.await_acks(100, 0, Duration::from_secs(60));
        assert_eq!(outcome, AwaitOutcome::NotRequired);
    }

    #[test]
    fn reaches_threshold_with_existing_acks() {
        let waiter = CommitWaiter::new();
        for replica in ["a", "b"] {
            waiter.record_replica_ack(replica, 200);
        }
        assert_eq!(
            waiter.await_acks(150, 2, Duration::from_millis(10)),
            AwaitOutcome::Reached(2)
        );
    }

    #[test]
    fn commit_watermark_is_nth_highest_durable_lsn() {
        let waiter = CommitWaiter::new();
        for (replica, lsn) in [("a", 10), ("b", 30), ("c", 20)] {
            waiter.record_replica_ack(replica, lsn);
        }

        // Nth-highest durable LSN for N = 1..=3; N = 4 exceeds the
        // number of replicas that have acked.
        for (required, expected) in [(1, 30), (2, 20), (3, 10), (4, 0)] {
            assert_eq!(
                waiter.commit_watermark(required),
                expected,
                "required = {required}"
            );
        }

        // A stale ack from "b" must not drag the watermark down.
        waiter.record_replica_ack("b", 15);
        assert_eq!(waiter.commit_watermark(2), 20);
    }

    #[test]
    fn times_out_when_no_one_has_acked() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 100);
        let outcome = waiter.await_acks(500, 1, Duration::from_millis(20));
        assert_eq!(
            outcome,
            AwaitOutcome::TimedOut {
                observed: 0,
                required: 1
            }
        );
    }

    #[test]
    fn ack_arriving_during_wait_unblocks_caller() {
        let shared = Arc::new(CommitWaiter::new());
        let handle = {
            let waiter = Arc::clone(&shared);
            thread::spawn(move || waiter.await_acks(1000, 1, Duration::from_secs(2)))
        };
        // Give the spawned thread a moment to park on the condvar.
        thread::sleep(Duration::from_millis(50));
        shared.record_replica_ack("late", 1000);
        let outcome = handle.join().expect("waiter thread");
        assert_eq!(outcome, AwaitOutcome::Reached(1));
    }

    #[test]
    fn ack_idempotent_does_not_double_count() {
        let waiter = CommitWaiter::new();
        for _ in 0..3 {
            waiter.record_replica_ack("a", 50);
        }
        assert_eq!(
            waiter.await_acks(50, 1, Duration::from_millis(5)),
            AwaitOutcome::Reached(1)
        );
        // Three acks from one replica still count once, so a quorum of
        // two cannot be met.
        assert_eq!(
            waiter.await_acks(50, 2, Duration::from_millis(20)),
            AwaitOutcome::TimedOut {
                observed: 1,
                required: 2
            }
        );
    }

    #[test]
    fn ack_only_advances_lsn_forward() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 200);
        // A late, older ack must not regress the recorded LSN.
        waiter.record_replica_ack("a", 100);
        assert_eq!(waiter.snapshot(), vec![("a".to_string(), 200)]);
    }

    #[test]
    fn drop_replica_removes_from_count() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 100);
        waiter.record_replica_ack("b", 100);
        waiter.drop_replica("a");
        assert_eq!(
            waiter.await_acks(100, 2, Duration::from_millis(20)),
            AwaitOutcome::TimedOut {
                observed: 1,
                required: 2
            }
        );
    }

    #[test]
    fn metrics_count_each_outcome_kind() {
        let waiter = CommitWaiter::new();
        let budget = Duration::from_millis(5);
        waiter.await_acks(100, 0, budget); // NotRequired
        waiter.await_acks(100, 1, budget); // TimedOut: nobody has acked yet
        waiter.record_replica_ack("a", 100);
        waiter.await_acks(100, 1, budget); // Reached

        let (reached, timed_out, not_required, last_micros) = waiter.metrics_snapshot();
        assert_eq!(reached, 1, "one Reached call");
        assert_eq!(timed_out, 1, "one TimedOut call");
        assert_eq!(not_required, 1, "one NotRequired call");
        // Only Reached/TimedOut update the gauge; NotRequired skips it,
        // so the latest value always reflects an actual wait.
        assert!(last_micros > 0, "last_wait_micros must be set");
    }
}