Skip to main content

reddb_server/replication/
commit_waiter.rs

//! Synchronous commit waiter (PLAN.md Phase 11.4 — `ack_n`).
//!
//! Bridges the primary's commit path with replica ACKs. The commit
//! caller picks a `target_lsn` (the LSN it just made durable
//! locally) and asks the waiter "block until at least N replicas
//! have ack'd this LSN, or the timeout expires." Replica ACK RPCs
//! call `record_replica_ack` which signals every waiter whose
//! threshold is now met.
//!
//! ## Thread safety
//!
//! The waiter uses a `Mutex<State>` + `Condvar` so the `await_acks`
//! call blocks the caller's thread without spinning. Acks bump a
//! per-replica `durable_lsn` map and broadcast on the condvar.
//! Waiters wake, recompute the count of replicas at or past their
//! target, and either return `Reached(count)` or re-wait.
//!
//! ## Why this is just the foundation
//!
//! The actual write commit path doesn't yet call `await_acks` —
//! wiring it in touches every public mutation surface and changes
//! latency characteristics across the board. This module ships the
//! primitive + the ack registry so the wiring change can land as
//! one focused PR per surface (HTTP, gRPC, wire protocol) rather
//! than a single massive diff.

27use std::collections::HashMap;
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::sync::{Condvar, Mutex};
30use std::time::{Duration, Instant};
31
/// Mutable state guarded by `CommitWaiter::state`. Kept minimal so
/// the mutex protects exactly the data the condvar predicate reads.
#[derive(Debug, Default)]
struct State {
    /// Per-replica durable LSN. Updated by `record_replica_ack`.
    /// Replicas absent from this map are treated as having durable
    /// LSN 0 (haven't acked anything yet).
    durable_lsn: HashMap<String, u64>,
}
39
/// Outcome counters for /metrics. PLAN.md Phase 11.4 — operators
/// alert on `timed_out` rising (commit policy is too tight or
/// replicas are stalled) and watch `last_wait_micros` for the p95.
#[derive(Debug, Default)]
pub struct CommitWaiterMetrics {
    /// Number of `await_acks` calls that returned `Reached`.
    pub reached_total: AtomicU64,
    /// Number of `await_acks` calls that returned `TimedOut`.
    pub timed_out_total: AtomicU64,
    /// Number of `await_acks` calls with `required == 0`
    /// (`NotRequired` — no wait performed).
    pub not_required_total: AtomicU64,
    /// Wall-clock micros of the most recent `Reached` or `TimedOut`
    /// wait. Gauge, not histogram — keeps the no-extra-deps line.
    pub last_wait_micros: AtomicU64,
}
52
/// Blocking `ack_n` primitive: tracks each replica's durable LSN and
/// lets commit callers wait until N replicas reach a target LSN.
#[derive(Debug)]
pub struct CommitWaiter {
    /// Per-replica durable-LSN map; the condvar predicate reads it.
    state: Mutex<State>,
    /// Broadcast on every map change so all waiters re-evaluate.
    cond: Condvar,
    /// Outcome counters; exported via `metrics_snapshot`.
    metrics: CommitWaiterMetrics,
}
59
60impl Default for CommitWaiter {
61    fn default() -> Self {
62        Self {
63            state: Mutex::new(State::default()),
64            cond: Condvar::new(),
65            metrics: CommitWaiterMetrics::default(),
66        }
67    }
68}
69
/// Result of `CommitWaiter::await_acks`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AwaitOutcome {
    /// At least `required` replicas reached `target_lsn` before the
    /// deadline. The returned count is the number observed at the
    /// moment we unblocked (may exceed `required` if many replicas
    /// were already ahead).
    Reached(u32),
    /// The deadline expired with fewer than `required` replicas at
    /// or past `target_lsn`. The count we observed is included so
    /// the caller can log how close we got.
    TimedOut { observed: u32, required: u32 },
    /// `required == 0` — degenerate case, returns immediately. No
    /// replica state is consulted.
    NotRequired,
}
85
86impl CommitWaiter {
87    pub fn new() -> Self {
88        Self::default()
89    }
90
91    /// Replica reports it has durably persisted up to `lsn`.
92    /// Idempotent: only advances forward. Wakes every waiter so they
93    /// can recheck their threshold.
94    pub fn record_replica_ack(&self, replica_id: &str, lsn: u64) {
95        let mut state = self.state.lock().expect("commit waiter mutex");
96        let entry = state.durable_lsn.entry(replica_id.to_string()).or_insert(0);
97        if lsn > *entry {
98            *entry = lsn;
99            self.cond.notify_all();
100        }
101    }
102
103    /// Best-effort cleanup when a replica disconnects. Removes its
104    /// durable LSN from the map so it doesn't artificially inflate
105    /// `ack_n` counts. Wakes waiters because the count of replicas
106    /// at the target may have decreased — they need to re-evaluate
107    /// against the new reality (some will start failing if their
108    /// margin was thin).
109    pub fn drop_replica(&self, replica_id: &str) {
110        let mut state = self.state.lock().expect("commit waiter mutex");
111        if state.durable_lsn.remove(replica_id).is_some() {
112            self.cond.notify_all();
113        }
114    }
115
116    /// Snapshot of the current durable-LSN map. Useful for
117    /// observability and tests; doesn't unblock waiters.
118    pub fn snapshot(&self) -> Vec<(String, u64)> {
119        let state = self.state.lock().expect("commit waiter mutex");
120        let mut v: Vec<(String, u64)> = state
121            .durable_lsn
122            .iter()
123            .map(|(k, v)| (k.clone(), *v))
124            .collect();
125        v.sort_by(|a, b| a.0.cmp(&b.0));
126        v
127    }
128
129    /// Block until at least `required` replicas have durable LSN
130    /// `>= target_lsn`, or `timeout` expires. `required == 0` is a
131    /// no-op (returns `NotRequired` instantly).
132    ///
133    /// Uses `Condvar::wait_timeout` to avoid spinning. On every wake
134    /// (whether from an ack or a spurious wakeup), we recompute the
135    /// count and either return or wait again with the remaining
136    /// budget.
137    pub fn await_acks(&self, target_lsn: u64, required: u32, timeout: Duration) -> AwaitOutcome {
138        if required == 0 {
139            self.metrics
140                .not_required_total
141                .fetch_add(1, Ordering::Relaxed);
142            return AwaitOutcome::NotRequired;
143        }
144        let started = Instant::now();
145        let deadline = started + timeout;
146        let mut state = self.state.lock().expect("commit waiter mutex");
147        loop {
148            let observed = count_at_or_past(&state.durable_lsn, target_lsn);
149            if observed >= required {
150                self.record_outcome_metrics(true, started);
151                return AwaitOutcome::Reached(observed);
152            }
153            let now = Instant::now();
154            if now >= deadline {
155                self.record_outcome_metrics(false, started);
156                return AwaitOutcome::TimedOut { observed, required };
157            }
158            let remaining = deadline - now;
159            let (next_state, _wait_result) = self
160                .cond
161                .wait_timeout(state, remaining)
162                .expect("commit waiter condvar");
163            state = next_state;
164        }
165    }
166
167    fn record_outcome_metrics(&self, reached: bool, started: Instant) {
168        let elapsed = (started.elapsed().as_micros() as u64).max(1);
169        self.metrics
170            .last_wait_micros
171            .store(elapsed, Ordering::Relaxed);
172        if reached {
173            self.metrics.reached_total.fetch_add(1, Ordering::Relaxed);
174        } else {
175            self.metrics.timed_out_total.fetch_add(1, Ordering::Relaxed);
176        }
177    }
178
179    /// Snapshot of outcome counters for /metrics + tests.
180    pub fn metrics_snapshot(&self) -> (u64, u64, u64, u64) {
181        (
182            self.metrics.reached_total.load(Ordering::Relaxed),
183            self.metrics.timed_out_total.load(Ordering::Relaxed),
184            self.metrics.not_required_total.load(Ordering::Relaxed),
185            self.metrics.last_wait_micros.load(Ordering::Relaxed),
186        )
187    }
188}
189
/// Number of replicas whose recorded durable LSN is at or past
/// `target_lsn`. Replicas missing from the map are not counted
/// (they implicitly sit at LSN 0).
fn count_at_or_past(map: &HashMap<String, u64>, target_lsn: u64) -> u32 {
    let hits = map
        .values()
        .fold(0usize, |acc, lsn| if *lsn >= target_lsn { acc + 1 } else { acc });
    hits as u32
}
193
#[cfg(test)]
mod tests {
    //! Unit tests for `CommitWaiter`. Timeouts are kept short so the
    //! suite stays fast even when a threshold is intentionally unmet;
    //! only the cross-thread test sleeps.

    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn required_zero_is_immediate_no_op() {
        let w = CommitWaiter::new();
        // Deliberately huge timeout: `required == 0` must return
        // without consulting replica state or waiting.
        let r = w.await_acks(100, 0, Duration::from_secs(60));
        assert_eq!(r, AwaitOutcome::NotRequired);
    }

    #[test]
    fn reaches_threshold_with_existing_acks() {
        let w = CommitWaiter::new();
        w.record_replica_ack("a", 200);
        w.record_replica_ack("b", 200);
        // Both replicas are already past target 150, so this returns
        // on the first predicate check without blocking.
        let r = w.await_acks(150, 2, Duration::from_millis(10));
        assert_eq!(r, AwaitOutcome::Reached(2));
    }

    #[test]
    fn times_out_when_no_one_has_acked() {
        let w = CommitWaiter::new();
        w.record_replica_ack("a", 100);
        // Target 500 is ahead of every replica; the wait must expire
        // and report how many replicas it observed at the target (0).
        let r = w.await_acks(500, 1, Duration::from_millis(20));
        match r {
            AwaitOutcome::TimedOut { observed, required } => {
                assert_eq!(observed, 0);
                assert_eq!(required, 1);
            }
            other => panic!("expected TimedOut, got {other:?}"),
        }
    }

    #[test]
    fn ack_arriving_during_wait_unblocks_caller() {
        let w = Arc::new(CommitWaiter::new());
        let waiter = Arc::clone(&w);
        let handle = thread::spawn(move || waiter.await_acks(1000, 1, Duration::from_secs(2)));
        // Give the waiter a moment to enter the condvar wait.
        thread::sleep(Duration::from_millis(50));
        w.record_replica_ack("late", 1000);
        let outcome = handle.join().expect("waiter thread");
        assert_eq!(outcome, AwaitOutcome::Reached(1));
    }

    #[test]
    fn ack_idempotent_does_not_double_count() {
        let w = CommitWaiter::new();
        // Three acks from the same replica are one logical replica.
        w.record_replica_ack("a", 50);
        w.record_replica_ack("a", 50);
        w.record_replica_ack("a", 50);
        let r = w.await_acks(50, 1, Duration::from_millis(5));
        assert_eq!(r, AwaitOutcome::Reached(1));
        // Threshold of 2 still fails — only one replica is registered.
        let r2 = w.await_acks(50, 2, Duration::from_millis(20));
        assert!(matches!(
            r2,
            AwaitOutcome::TimedOut {
                observed: 1,
                required: 2
            }
        ));
    }

    #[test]
    fn ack_only_advances_lsn_forward() {
        let w = CommitWaiter::new();
        w.record_replica_ack("a", 200);
        // Older ack must not regress the recorded LSN.
        w.record_replica_ack("a", 100);
        let snap = w.snapshot();
        assert_eq!(snap, vec![("a".to_string(), 200)]);
    }

    #[test]
    fn drop_replica_removes_from_count() {
        let w = CommitWaiter::new();
        w.record_replica_ack("a", 100);
        w.record_replica_ack("b", 100);
        // After the drop only "b" remains, so `required == 2` can no
        // longer be met and the wait must time out at observed == 1.
        w.drop_replica("a");
        let r = w.await_acks(100, 2, Duration::from_millis(20));
        assert!(matches!(
            r,
            AwaitOutcome::TimedOut {
                observed: 1,
                required: 2
            }
        ));
    }

    #[test]
    fn metrics_count_each_outcome_kind() {
        let w = CommitWaiter::new();
        // not_required
        w.await_acks(100, 0, Duration::from_millis(5));
        // timed_out
        w.await_acks(100, 1, Duration::from_millis(5));
        // reached
        w.record_replica_ack("a", 100);
        w.await_acks(100, 1, Duration::from_millis(5));

        let (reached, timed_out, not_required, last_micros) = w.metrics_snapshot();
        assert_eq!(reached, 1, "one Reached call");
        assert_eq!(timed_out, 1, "one TimedOut call");
        assert_eq!(not_required, 1, "one NotRequired call");
        // last_wait_micros is set on Reached/TimedOut, NotRequired
        // skips the gauge so the most recent measurement reflects
        // an actual wait.
        assert!(last_micros > 0, "last_wait_micros must be set");
    }
}
307}