reddb_server/replication/
commit_waiter.rs1use std::collections::HashMap;
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::sync::{Condvar, Mutex};
30use std::time::{Duration, Instant};
31
/// Bookkeeping that lives behind the waiter's mutex.
#[derive(Default, Debug)]
struct State {
    /// Highest LSN recorded so far for each replica, keyed by replica id.
    durable_lsn: HashMap<String, u64>,
}
39
/// Counters describing how `CommitWaiter::await_acks` calls have ended.
///
/// Every field is an atomic, so the struct can be updated through a shared
/// reference.
#[derive(Default, Debug)]
pub struct CommitWaiterMetrics {
    /// Number of waits that ended because the target LSN was reached.
    pub reached_total: AtomicU64,
    /// Number of waits that ended because the deadline passed first.
    pub timed_out_total: AtomicU64,
    /// Number of calls that returned immediately because no acks were required.
    pub not_required_total: AtomicU64,
    /// Duration, in microseconds, of the most recently finished wait.
    pub last_wait_micros: AtomicU64,
}
52
53#[derive(Debug)]
54pub struct CommitWaiter {
55 state: Mutex<State>,
56 cond: Condvar,
57 metrics: CommitWaiterMetrics,
58}
59
60impl Default for CommitWaiter {
61 fn default() -> Self {
62 Self {
63 state: Mutex::new(State::default()),
64 cond: Condvar::new(),
65 metrics: CommitWaiterMetrics::default(),
66 }
67 }
68}
69
/// Result of a `CommitWaiter::await_acks` call.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AwaitOutcome {
    /// The target LSN was reached; the payload is the number of replicas that
    /// were at or past the target when the wait ended.
    Reached(u32),
    /// The deadline passed first. `observed` is the number of replicas at or
    /// past the target at that moment; `required` echoes the caller's request.
    TimedOut { observed: u32, required: u32 },
    /// The caller asked for zero acknowledgements, so nothing was awaited.
    NotRequired,
}
85
86impl CommitWaiter {
87 pub fn new() -> Self {
88 Self::default()
89 }
90
91 pub fn record_replica_ack(&self, replica_id: &str, lsn: u64) {
95 let mut state = self.state.lock().expect("commit waiter mutex");
96 let entry = state.durable_lsn.entry(replica_id.to_string()).or_insert(0);
97 if lsn > *entry {
98 *entry = lsn;
99 self.cond.notify_all();
100 }
101 }
102
103 pub fn drop_replica(&self, replica_id: &str) {
110 let mut state = self.state.lock().expect("commit waiter mutex");
111 if state.durable_lsn.remove(replica_id).is_some() {
112 self.cond.notify_all();
113 }
114 }
115
116 pub fn snapshot(&self) -> Vec<(String, u64)> {
119 let state = self.state.lock().expect("commit waiter mutex");
120 let mut v: Vec<(String, u64)> = state
121 .durable_lsn
122 .iter()
123 .map(|(k, v)| (k.clone(), *v))
124 .collect();
125 v.sort_by(|a, b| a.0.cmp(&b.0));
126 v
127 }
128
129 pub fn commit_watermark(&self, required: u32) -> u64 {
137 let state = self.state.lock().expect("commit waiter mutex");
138 commit_watermark(&state.durable_lsn, required)
139 }
140
141 pub fn await_acks(&self, target_lsn: u64, required: u32, timeout: Duration) -> AwaitOutcome {
150 if required == 0 {
151 self.metrics
152 .not_required_total
153 .fetch_add(1, Ordering::Relaxed);
154 return AwaitOutcome::NotRequired;
155 }
156 let started = Instant::now();
157 let deadline = started + timeout;
158 let mut state = self.state.lock().expect("commit waiter mutex");
159 loop {
160 let watermark = commit_watermark(&state.durable_lsn, required);
161 if watermark >= target_lsn {
162 self.record_outcome_metrics(true, started);
163 let observed = count_at_or_past(&state.durable_lsn, target_lsn);
164 return AwaitOutcome::Reached(observed);
165 }
166 let now = Instant::now();
167 if now >= deadline {
168 self.record_outcome_metrics(false, started);
169 let observed = count_at_or_past(&state.durable_lsn, target_lsn);
170 return AwaitOutcome::TimedOut { observed, required };
171 }
172 let remaining = deadline - now;
173 let (next_state, _wait_result) = self
174 .cond
175 .wait_timeout(state, remaining)
176 .expect("commit waiter condvar");
177 state = next_state;
178 }
179 }
180
181 pub fn wait_for_change_until<F>(&self, timeout: Option<Duration>, mut is_satisfied: F) -> bool
184 where
185 F: FnMut() -> bool,
186 {
187 let started = Instant::now();
188 let mut state = self.state.lock().expect("commit waiter mutex");
189 loop {
190 if is_satisfied() {
191 return true;
192 }
193 let Some(limit) = timeout else {
194 state = self.cond.wait(state).expect("commit waiter condvar");
195 continue;
196 };
197 let elapsed = started.elapsed();
198 if elapsed >= limit {
199 return false;
200 }
201 let remaining = limit - elapsed;
202 let (next_state, _wait_result) = self
203 .cond
204 .wait_timeout(state, remaining)
205 .expect("commit waiter condvar");
206 state = next_state;
207 }
208 }
209
210 pub fn wait_for_commit_watermark(
212 &self,
213 target_lsn: u64,
214 required: u32,
215 timeout: Option<Duration>,
216 ) -> bool {
217 if required == 0 {
218 return true;
219 }
220 let started = Instant::now();
221 let mut state = self.state.lock().expect("commit waiter mutex");
222 loop {
223 if commit_watermark(&state.durable_lsn, required) >= target_lsn {
224 return true;
225 }
226 let Some(limit) = timeout else {
227 state = self.cond.wait(state).expect("commit waiter condvar");
228 continue;
229 };
230 let elapsed = started.elapsed();
231 if elapsed >= limit {
232 return false;
233 }
234 let remaining = limit - elapsed;
235 let (next_state, _wait_result) = self
236 .cond
237 .wait_timeout(state, remaining)
238 .expect("commit waiter condvar");
239 state = next_state;
240 }
241 }
242
243 fn record_outcome_metrics(&self, reached: bool, started: Instant) {
244 let elapsed = (started.elapsed().as_micros() as u64).max(1);
245 self.metrics
246 .last_wait_micros
247 .store(elapsed, Ordering::Relaxed);
248 if reached {
249 self.metrics.reached_total.fetch_add(1, Ordering::Relaxed);
250 } else {
251 self.metrics.timed_out_total.fetch_add(1, Ordering::Relaxed);
252 }
253 }
254
255 pub fn metrics_snapshot(&self) -> (u64, u64, u64, u64) {
257 (
258 self.metrics.reached_total.load(Ordering::Relaxed),
259 self.metrics.timed_out_total.load(Ordering::Relaxed),
260 self.metrics.not_required_total.load(Ordering::Relaxed),
261 self.metrics.last_wait_micros.load(Ordering::Relaxed),
262 )
263 }
264}
265
/// Counts the entries in `map` whose LSN is at or past `target_lsn`.
fn count_at_or_past(map: &HashMap<String, u64>, target_lsn: u64) -> u32 {
    let caught_up = map
        .values()
        .copied()
        .filter(|&lsn| lsn >= target_lsn)
        .count();
    // Plain narrowing cast, as in the original.
    caught_up as u32
}
269
/// Returns the `required`-th highest LSN in `map`.
///
/// Yields 0 when `required` is 0 or when `map` holds fewer than `required`
/// entries.
fn commit_watermark(map: &HashMap<String, u64>, required: u32) -> u64 {
    let quorum = required as usize;
    if required == 0 || quorum > map.len() {
        return 0;
    }
    let mut ascending: Vec<u64> = map.values().copied().collect();
    ascending.sort_unstable();
    // In ascending order the k-th highest value sits k slots from the end.
    ascending[ascending.len() - quorum]
}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// With `required == 0` the call returns at once, whatever the timeout.
    #[test]
    fn required_zero_is_immediate_no_op() {
        let waiter = CommitWaiter::new();
        let outcome = waiter.await_acks(100, 0, Duration::from_secs(60));
        assert_eq!(outcome, AwaitOutcome::NotRequired);
    }

    /// Acks recorded before the wait starts satisfy it immediately.
    #[test]
    fn reaches_threshold_with_existing_acks() {
        let waiter = CommitWaiter::new();
        for replica in ["a", "b"] {
            waiter.record_replica_ack(replica, 200);
        }
        let outcome = waiter.await_acks(150, 2, Duration::from_millis(10));
        assert_eq!(outcome, AwaitOutcome::Reached(2));
    }

    /// The watermark for `n` is the n-th highest LSN, or 0 with too few replicas.
    #[test]
    fn commit_watermark_is_nth_highest_durable_lsn() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 10);
        waiter.record_replica_ack("b", 30);
        waiter.record_replica_ack("c", 20);

        assert_eq!(waiter.commit_watermark(1), 30);
        assert_eq!(waiter.commit_watermark(2), 20);
        assert_eq!(waiter.commit_watermark(3), 10);
        assert_eq!(waiter.commit_watermark(4), 0);

        // A stale ack (15 < 30) must not pull "b" backwards.
        waiter.record_replica_ack("b", 15);
        assert_eq!(waiter.commit_watermark(2), 20);
    }

    /// A replica that is behind the target does not count towards it.
    #[test]
    fn times_out_when_no_one_has_acked() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 100);
        match waiter.await_acks(500, 1, Duration::from_millis(20)) {
            AwaitOutcome::TimedOut { observed, required } => {
                assert_eq!(observed, 0);
                assert_eq!(required, 1);
            }
            other => panic!("expected TimedOut, got {other:?}"),
        }
    }

    /// An ack delivered from another thread wakes a blocked waiter.
    #[test]
    fn ack_arriving_during_wait_unblocks_caller() {
        let shared = Arc::new(CommitWaiter::new());
        let blocked = {
            let for_thread = Arc::clone(&shared);
            thread::spawn(move || for_thread.await_acks(1000, 1, Duration::from_secs(2)))
        };
        // Give the spawned thread time to start waiting before the ack lands.
        thread::sleep(Duration::from_millis(50));
        shared.record_replica_ack("late", 1000);
        let outcome = blocked.join().expect("waiter thread");
        assert_eq!(outcome, AwaitOutcome::Reached(1));
    }

    /// Repeating the same ack never counts one replica more than once.
    #[test]
    fn ack_idempotent_does_not_double_count() {
        let waiter = CommitWaiter::new();
        for _ in 0..3 {
            waiter.record_replica_ack("a", 50);
        }
        let single = waiter.await_acks(50, 1, Duration::from_millis(5));
        assert_eq!(single, AwaitOutcome::Reached(1));
        let pair = waiter.await_acks(50, 2, Duration::from_millis(20));
        assert!(matches!(
            pair,
            AwaitOutcome::TimedOut {
                observed: 1,
                required: 2
            }
        ));
    }

    /// A lower LSN arriving after a higher one is ignored.
    #[test]
    fn ack_only_advances_lsn_forward() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 200);
        waiter.record_replica_ack("a", 100);
        let table = waiter.snapshot();
        assert_eq!(table, vec![("a".to_string(), 200)]);
    }

    /// A dropped replica no longer contributes to the ack count.
    #[test]
    fn drop_replica_removes_from_count() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 100);
        waiter.record_replica_ack("b", 100);
        waiter.drop_replica("a");
        let outcome = waiter.await_acks(100, 2, Duration::from_millis(20));
        assert!(matches!(
            outcome,
            AwaitOutcome::TimedOut {
                observed: 1,
                required: 2
            }
        ));
    }

    /// Each outcome kind bumps exactly its own counter.
    #[test]
    fn metrics_count_each_outcome_kind() {
        let waiter = CommitWaiter::new();
        let brief = Duration::from_millis(5);
        // NotRequired.
        waiter.await_acks(100, 0, brief);
        // TimedOut: nobody has acked yet.
        waiter.await_acks(100, 1, brief);
        waiter.record_replica_ack("a", 100);
        // Reached.
        waiter.await_acks(100, 1, brief);

        let (reached, timed_out, not_required, last_micros) = waiter.metrics_snapshot();
        assert_eq!(reached, 1, "one Reached call");
        assert_eq!(timed_out, 1, "one TimedOut call");
        assert_eq!(not_required, 1, "one NotRequired call");
        assert!(last_micros > 0, "last_wait_micros must be set");
    }
}