Skip to main content

reddb_server/replication/
commit_waiter.rs

//! Synchronous commit waiter (PLAN.md Phase 11.4 — `ack_n`).
//!
//! Bridges the primary's commit path with replica ACKs. The commit
//! caller picks a `target_lsn` (the LSN it just made durable
//! locally) and asks the waiter to "block until at least N replicas
//! have acked this LSN, or the timeout expires." Replica ACK RPCs
//! call `record_replica_ack`, which wakes every blocked waiter so it
//! can recheck whether its threshold is now met.
//!
//! ## Thread safety
//!
//! The waiter uses a `Mutex<State>` + `Condvar` so the `await_acks`
//! call blocks the caller's thread without spinning. Acks bump a
//! per-replica `durable_lsn` map and broadcast on the condvar.
//! Waiters wake, recompute the count of replicas at or past their
//! target, and either return `AwaitOutcome::Reached(count)` or re-wait
//! until their deadline (returning `AwaitOutcome::TimedOut`).
//!
//! ## Runtime integration
//!
//! Public mutation surfaces call `RedDBRuntime::enforce_commit_policy`
//! after successful writes. That runtime method maps `ack_n` to this
//! waiter and decides whether a timeout is soft telemetry or a hard
//! client-visible failure based on `RED_COMMIT_FAIL_ON_TIMEOUT`.
24
25use std::collections::HashMap;
26use std::sync::atomic::{AtomicU64, Ordering};
27use std::sync::{Condvar, Mutex};
28use std::time::{Duration, Instant};
29
/// Ack table protected by the waiter's mutex.
#[derive(Default, Debug)]
struct State {
    /// Highest LSN each replica has reported as durable, keyed by replica id.
    /// Entries only move forward via `record_replica_ack` and leave the map
    /// via `drop_replica`. A replica with no entry has not acked anything and
    /// contributes nothing to any quorum count.
    durable_lsn: HashMap<String, u64>,
}
37
38/// Outcome counters for /metrics. PLAN.md Phase 11.4 — operators
39/// alert on `timed_out` rising (commit policy is too tight or
40/// replicas are stalled) and watch `last_wait_micros` for the p95.
41#[derive(Debug, Default)]
42pub struct CommitWaiterMetrics {
43    pub reached_total: AtomicU64,
44    pub timed_out_total: AtomicU64,
45    pub not_required_total: AtomicU64,
46    /// Wall-clock micros of the most recent `Reached` or `TimedOut`
47    /// wait. Gauge, not histogram — keeps the no-extra-deps line.
48    pub last_wait_micros: AtomicU64,
49}
50
51#[derive(Debug)]
52pub struct CommitWaiter {
53    state: Mutex<State>,
54    cond: Condvar,
55    metrics: CommitWaiterMetrics,
56}
57
58impl Default for CommitWaiter {
59    fn default() -> Self {
60        Self {
61            state: Mutex::new(State::default()),
62            cond: Condvar::new(),
63            metrics: CommitWaiterMetrics::default(),
64        }
65    }
66}
67
/// Result of a single `CommitWaiter::await_acks` call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AwaitOutcome {
    /// Enough replicas were at or past `target_lsn` before the deadline.
    /// Carries the number of such replicas when the wait ended, which can
    /// exceed `required` if several replicas were already ahead.
    Reached(u32),
    /// The deadline passed first. `observed` is how many replicas were at
    /// or past `target_lsn` at that moment, so callers can log how close
    /// the commit came to its `required` quorum.
    TimedOut { observed: u32, required: u32 },
    /// `required == 0`: nothing to wait for, so the call returned at once
    /// without consulting replica state.
    NotRequired,
}
83
84impl CommitWaiter {
85    pub fn new() -> Self {
86        Self::default()
87    }
88
89    /// Replica reports it has durably persisted up to `lsn`.
90    /// Idempotent: only advances forward. Wakes every waiter so they
91    /// can recheck their threshold.
92    pub fn record_replica_ack(&self, replica_id: &str, lsn: u64) {
93        let mut state = self.state.lock().expect("commit waiter mutex");
94        let entry = state.durable_lsn.entry(replica_id.to_string()).or_insert(0);
95        if lsn > *entry {
96            *entry = lsn;
97            self.cond.notify_all();
98        }
99    }
100
101    /// Best-effort cleanup when a replica disconnects. Removes its
102    /// durable LSN from the map so it doesn't artificially inflate
103    /// `ack_n` counts. Wakes waiters because the count of replicas
104    /// at the target may have decreased — they need to re-evaluate
105    /// against the new reality (some will start failing if their
106    /// margin was thin).
107    pub fn drop_replica(&self, replica_id: &str) {
108        let mut state = self.state.lock().expect("commit waiter mutex");
109        if state.durable_lsn.remove(replica_id).is_some() {
110            self.cond.notify_all();
111        }
112    }
113
114    /// Snapshot of the current durable-LSN map. Useful for
115    /// observability and tests; doesn't unblock waiters.
116    pub fn snapshot(&self) -> Vec<(String, u64)> {
117        let state = self.state.lock().expect("commit waiter mutex");
118        let mut v: Vec<(String, u64)> = state
119            .durable_lsn
120            .iter()
121            .map(|(k, v)| (k.clone(), *v))
122            .collect();
123        v.sort_by(|a, b| a.0.cmp(&b.0));
124        v
125    }
126
127    /// Highest LSN durable on `required` replicas.
128    ///
129    /// For `ack_n=2`, this is the second-highest durable LSN in the
130    /// ack table. If fewer than `required` replicas have acked, the
131    /// watermark is 0. `required == 0` is not a quorum requirement, so
132    /// observability reports 0 rather than fabricating an infinite
133    /// watermark.
134    pub fn commit_watermark(&self, required: u32) -> u64 {
135        let state = self.state.lock().expect("commit waiter mutex");
136        commit_watermark(&state.durable_lsn, required)
137    }
138
139    /// Block until at least `required` replicas have durable LSN
140    /// `>= target_lsn`, or `timeout` expires. `required == 0` is a
141    /// no-op (returns `NotRequired` instantly).
142    ///
143    /// Uses `Condvar::wait_timeout` to avoid spinning. On every wake
144    /// (whether from an ack or a spurious wakeup), we recompute the
145    /// count and either return or wait again with the remaining
146    /// budget.
147    pub fn await_acks(&self, target_lsn: u64, required: u32, timeout: Duration) -> AwaitOutcome {
148        if required == 0 {
149            self.metrics
150                .not_required_total
151                .fetch_add(1, Ordering::Relaxed);
152            return AwaitOutcome::NotRequired;
153        }
154        let started = Instant::now();
155        let deadline = started + timeout;
156        let mut state = self.state.lock().expect("commit waiter mutex");
157        loop {
158            let watermark = commit_watermark(&state.durable_lsn, required);
159            if watermark >= target_lsn {
160                self.record_outcome_metrics(true, started);
161                let observed = count_at_or_past(&state.durable_lsn, target_lsn);
162                return AwaitOutcome::Reached(observed);
163            }
164            let now = Instant::now();
165            if now >= deadline {
166                self.record_outcome_metrics(false, started);
167                let observed = count_at_or_past(&state.durable_lsn, target_lsn);
168                return AwaitOutcome::TimedOut { observed, required };
169            }
170            let remaining = deadline - now;
171            let (next_state, _wait_result) = self
172                .cond
173                .wait_timeout(state, remaining)
174                .expect("commit waiter condvar");
175            state = next_state;
176        }
177    }
178
179    /// Wait until `is_satisfied` returns true, waking on replica ack
180    /// notifications instead of a caller-side polling sleep.
181    pub fn wait_for_change_until<F>(&self, timeout: Option<Duration>, mut is_satisfied: F) -> bool
182    where
183        F: FnMut() -> bool,
184    {
185        let started = Instant::now();
186        let mut state = self.state.lock().expect("commit waiter mutex");
187        loop {
188            if is_satisfied() {
189                return true;
190            }
191            let Some(limit) = timeout else {
192                state = self.cond.wait(state).expect("commit waiter condvar");
193                continue;
194            };
195            let elapsed = started.elapsed();
196            if elapsed >= limit {
197                return false;
198            }
199            let remaining = limit - elapsed;
200            let (next_state, _wait_result) = self
201                .cond
202                .wait_timeout(state, remaining)
203                .expect("commit waiter condvar");
204            state = next_state;
205        }
206    }
207
208    /// Wait until the named commit watermark reaches `target_lsn`.
209    pub fn wait_for_commit_watermark(
210        &self,
211        target_lsn: u64,
212        required: u32,
213        timeout: Option<Duration>,
214    ) -> bool {
215        if required == 0 {
216            return true;
217        }
218        let started = Instant::now();
219        let mut state = self.state.lock().expect("commit waiter mutex");
220        loop {
221            if commit_watermark(&state.durable_lsn, required) >= target_lsn {
222                return true;
223            }
224            let Some(limit) = timeout else {
225                state = self.cond.wait(state).expect("commit waiter condvar");
226                continue;
227            };
228            let elapsed = started.elapsed();
229            if elapsed >= limit {
230                return false;
231            }
232            let remaining = limit - elapsed;
233            let (next_state, _wait_result) = self
234                .cond
235                .wait_timeout(state, remaining)
236                .expect("commit waiter condvar");
237            state = next_state;
238        }
239    }
240
241    fn record_outcome_metrics(&self, reached: bool, started: Instant) {
242        let elapsed = (started.elapsed().as_micros() as u64).max(1);
243        self.metrics
244            .last_wait_micros
245            .store(elapsed, Ordering::Relaxed);
246        if reached {
247            self.metrics.reached_total.fetch_add(1, Ordering::Relaxed);
248        } else {
249            self.metrics.timed_out_total.fetch_add(1, Ordering::Relaxed);
250        }
251    }
252
253    /// Snapshot of outcome counters for /metrics + tests.
254    pub fn metrics_snapshot(&self) -> (u64, u64, u64, u64) {
255        (
256            self.metrics.reached_total.load(Ordering::Relaxed),
257            self.metrics.timed_out_total.load(Ordering::Relaxed),
258            self.metrics.not_required_total.load(Ordering::Relaxed),
259            self.metrics.last_wait_micros.load(Ordering::Relaxed),
260        )
261    }
262}
263
/// Number of replicas whose durable LSN is at or past `target_lsn`.
/// Replicas absent from `map` are not counted.
fn count_at_or_past(map: &HashMap<String, u64>, target_lsn: u64) -> u32 {
    // Wrapping keeps this panic-free, matching a truncating `as u32`
    // conversion of a usize count.
    map.values().fold(0u32, |caught_up, &lsn| {
        if lsn >= target_lsn {
            caught_up.wrapping_add(1)
        } else {
            caught_up
        }
    })
}
267
/// The `required`-th highest durable LSN in `map`, i.e. the highest LSN that
/// at least `required` replicas have made durable.
///
/// Returns 0 when `required` is 0 or fewer than `required` replicas have
/// acked. Callers therefore cannot distinguish "no quorum" from a genuine
/// watermark of LSN 0.
fn commit_watermark(map: &HashMap<String, u64>, required: u32) -> u64 {
    let quorum = required as usize;
    if quorum == 0 || quorum > map.len() {
        return 0;
    }
    let mut ascending: Vec<u64> = map.values().copied().collect();
    ascending.sort_unstable();
    // k-th highest in descending order == (len - k)-th in ascending order.
    ascending[ascending.len() - quorum]
}
276
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Build a waiter pre-loaded with the given `(replica, lsn)` acks, applied in order.
    fn waiter_with(acks: &[(&str, u64)]) -> CommitWaiter {
        let waiter = CommitWaiter::new();
        for &(replica, lsn) in acks {
            waiter.record_replica_ack(replica, lsn);
        }
        waiter
    }

    #[test]
    fn required_zero_is_immediate_no_op() {
        let outcome = waiter_with(&[]).await_acks(100, 0, Duration::from_secs(60));
        assert_eq!(outcome, AwaitOutcome::NotRequired);
    }

    #[test]
    fn reaches_threshold_with_existing_acks() {
        let waiter = waiter_with(&[("a", 200), ("b", 200)]);
        assert_eq!(
            waiter.await_acks(150, 2, Duration::from_millis(10)),
            AwaitOutcome::Reached(2)
        );
    }

    #[test]
    fn commit_watermark_is_nth_highest_durable_lsn() {
        let waiter = waiter_with(&[("a", 10), ("b", 30), ("c", 20)]);
        let expected: [(u32, u64); 4] = [(1, 30), (2, 20), (3, 10), (4, 0)];
        for (required, watermark) in expected {
            assert_eq!(waiter.commit_watermark(required), watermark);
        }

        // A stale ack from "b" must not pull the watermark down.
        waiter.record_replica_ack("b", 15);
        assert_eq!(waiter.commit_watermark(2), 20);
    }

    #[test]
    fn times_out_when_no_one_has_acked() {
        // "a" is registered but behind the target, so nobody counts.
        let waiter = waiter_with(&[("a", 100)]);
        match waiter.await_acks(500, 1, Duration::from_millis(20)) {
            AwaitOutcome::TimedOut { observed, required } => {
                assert_eq!((observed, required), (0, 1));
            }
            other => panic!("expected TimedOut, got {other:?}"),
        }
    }

    #[test]
    fn ack_arriving_during_wait_unblocks_caller() {
        let shared = Arc::new(CommitWaiter::new());
        let handle = {
            let shared = Arc::clone(&shared);
            thread::spawn(move || shared.await_acks(1000, 1, Duration::from_secs(2)))
        };
        // Give the spawned thread time to park on the condvar before acking.
        thread::sleep(Duration::from_millis(50));
        shared.record_replica_ack("late", 1000);
        assert_eq!(handle.join().expect("waiter thread"), AwaitOutcome::Reached(1));
    }

    #[test]
    fn ack_idempotent_does_not_double_count() {
        let waiter = waiter_with(&[("a", 50), ("a", 50), ("a", 50)]);
        assert_eq!(
            waiter.await_acks(50, 1, Duration::from_millis(5)),
            AwaitOutcome::Reached(1)
        );
        // Repeated acks from one replica still count once, so a quorum of 2 fails.
        assert_eq!(
            waiter.await_acks(50, 2, Duration::from_millis(20)),
            AwaitOutcome::TimedOut {
                observed: 1,
                required: 2
            }
        );
    }

    #[test]
    fn ack_only_advances_lsn_forward() {
        // The older second ack must not regress the recorded LSN.
        let waiter = waiter_with(&[("a", 200), ("a", 100)]);
        assert_eq!(waiter.snapshot(), vec![("a".to_string(), 200)]);
    }

    #[test]
    fn drop_replica_removes_from_count() {
        let waiter = waiter_with(&[("a", 100), ("b", 100)]);
        waiter.drop_replica("a");
        assert_eq!(
            waiter.await_acks(100, 2, Duration::from_millis(20)),
            AwaitOutcome::TimedOut {
                observed: 1,
                required: 2
            }
        );
    }

    #[test]
    fn metrics_count_each_outcome_kind() {
        let waiter = CommitWaiter::new();
        waiter.await_acks(100, 0, Duration::from_millis(5)); // NotRequired
        waiter.await_acks(100, 1, Duration::from_millis(5)); // TimedOut
        waiter.record_replica_ack("a", 100);
        waiter.await_acks(100, 1, Duration::from_millis(5)); // Reached

        let (reached, timed_out, not_required, last_micros) = waiter.metrics_snapshot();
        assert_eq!(reached, 1, "one Reached call");
        assert_eq!(timed_out, 1, "one TimedOut call");
        assert_eq!(not_required, 1, "one NotRequired call");
        // Only Reached/TimedOut update the gauge, so the latest value
        // always reflects an actual wait.
        assert!(last_micros > 0, "last_wait_micros must be set");
    }
}