use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Condvar, Mutex};
use std::time::{Duration, Instant};
/// Mutable bookkeeping that `CommitWaiter` keeps behind its mutex.
#[derive(Default, Debug)]
struct State {
    /// Highest LSN acknowledged so far by each replica, keyed by replica id.
    durable_lsn: HashMap<String, u64>,
}
/// Counters describing how calls to `CommitWaiter::await_acks` resolved.
///
/// Every field starts at zero (`AtomicU64::default`).
#[derive(Default, Debug)]
pub struct CommitWaiterMetrics {
    /// Number of waits that ended with `AwaitOutcome::Reached`.
    pub reached_total: AtomicU64,
    /// Number of waits that ended with `AwaitOutcome::TimedOut`.
    pub timed_out_total: AtomicU64,
    /// Number of calls that returned `AwaitOutcome::NotRequired` without waiting.
    pub not_required_total: AtomicU64,
    /// Duration in microseconds of the most recent Reached/TimedOut wait.
    pub last_wait_micros: AtomicU64,
}
#[derive(Debug)]
pub struct CommitWaiter {
state: Mutex<State>,
cond: Condvar,
metrics: CommitWaiterMetrics,
}
impl Default for CommitWaiter {
fn default() -> Self {
Self {
state: Mutex::new(State::default()),
cond: Condvar::new(),
metrics: CommitWaiterMetrics::default(),
}
}
}
/// How a call to `CommitWaiter::await_acks` finished.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AwaitOutcome {
    /// The threshold was met. Carries how many replicas were observed at or
    /// past the target, which can exceed the number required.
    Reached(u32),
    /// The timeout elapsed first. `observed` is the count seen at the last
    /// check; `required` repeats the caller's threshold.
    TimedOut { observed: u32, required: u32 },
    /// The caller asked for zero acknowledgements, so nothing was awaited.
    NotRequired,
}
impl CommitWaiter {
    /// Creates a waiter with no registered replicas and all metrics at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that `replica_id` has acknowledged everything up to `lsn`.
    ///
    /// A replica's recorded LSN only moves forward. An ack at or below the
    /// stored value is ignored.
    ///
    /// Blocked `await_acks` callers are woken in two cases:
    /// - the replica is registered for the first time, or
    /// - its LSN advances.
    ///
    /// Either event can raise the number of replicas at or past a target.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn record_replica_ack(&self, replica_id: &str, lsn: u64) {
        let mut state = self.state.lock().expect("commit waiter mutex");
        // Look up first so the id is only allocated when the replica is new.
        let changed = match state.durable_lsn.get_mut(replica_id) {
            Some(current) => {
                let advanced = lsn > *current;
                if advanced {
                    *current = lsn;
                }
                advanced
            }
            None => {
                // A first ack must wake waiters even at LSN 0: a target of 0
                // is satisfied by any registered replica.
                state.durable_lsn.insert(replica_id.to_owned(), lsn);
                true
            }
        };
        if changed {
            self.cond.notify_all();
        }
    }

    /// Forgets `replica_id`, so it no longer counts toward any threshold.
    ///
    /// Waiters are notified only if the replica was actually registered.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn drop_replica(&self, replica_id: &str) {
        let mut state = self.state.lock().expect("commit waiter mutex");
        if state.durable_lsn.remove(replica_id).is_some() {
            self.cond.notify_all();
        }
    }

    /// Returns every registered replica with its acknowledged LSN.
    ///
    /// The result is sorted by replica id.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn snapshot(&self) -> Vec<(String, u64)> {
        let state = self.state.lock().expect("commit waiter mutex");
        let mut replicas: Vec<(String, u64)> = state
            .durable_lsn
            .iter()
            .map(|(id, &lsn)| (id.clone(), lsn))
            .collect();
        // Release the lock before sorting; the copy no longer needs it.
        drop(state);
        // Replica ids are unique map keys, so an unstable sort is deterministic.
        replicas.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        replicas
    }

    /// Blocks until at least `required` replicas have acknowledged an LSN at
    /// or past `target_lsn`, or until `timeout` elapses.
    ///
    /// Returns:
    /// - `AwaitOutcome::NotRequired` immediately when `required` is 0;
    /// - `AwaitOutcome::Reached` with the observed count once the threshold
    ///   is met;
    /// - `AwaitOutcome::TimedOut` with the last observed count otherwise.
    ///
    /// The threshold is checked before the deadline, so a zero `timeout` still
    /// succeeds when enough acks are already present. A `timeout` too large to
    /// add to the current instant (e.g. [`Duration::MAX`]) waits with no
    /// deadline instead of panicking.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn await_acks(&self, target_lsn: u64, required: u32, timeout: Duration) -> AwaitOutcome {
        if required == 0 {
            self.metrics
                .not_required_total
                .fetch_add(1, Ordering::Relaxed);
            return AwaitOutcome::NotRequired;
        }
        let started = Instant::now();
        // `Instant + Duration` panics on overflow. Treat an unrepresentable
        // deadline as "no deadline" rather than crashing the caller.
        let deadline = started.checked_add(timeout);
        let mut state = self.state.lock().expect("commit waiter mutex");
        loop {
            // Re-check on every wake-up. Wake-ups can be spurious, or caused
            // by changes that do not satisfy this caller's target.
            let observed = count_at_or_past(&state.durable_lsn, target_lsn);
            if observed >= required {
                self.record_outcome_metrics(true, started);
                return AwaitOutcome::Reached(observed);
            }
            state = match deadline {
                Some(deadline) => {
                    let now = Instant::now();
                    if now >= deadline {
                        self.record_outcome_metrics(false, started);
                        return AwaitOutcome::TimedOut { observed, required };
                    }
                    let (guard, _timed_out) = self
                        .cond
                        .wait_timeout(state, deadline - now)
                        .expect("commit waiter condvar");
                    guard
                }
                None => self.cond.wait(state).expect("commit waiter condvar"),
            };
        }
    }

    /// Stores the elapsed wait time and bumps the counter for the outcome.
    ///
    /// `reached` selects `reached_total` when true and `timed_out_total`
    /// when false.
    fn record_outcome_metrics(&self, reached: bool, started: Instant) {
        // Saturate rather than truncate the u128 microsecond count.
        let elapsed = started
            .elapsed()
            .as_micros()
            .min(u128::from(u64::MAX)) as u64;
        // Clamp to at least 1 so a recorded wait is distinguishable from the
        // initial 0.
        self.metrics
            .last_wait_micros
            .store(elapsed.max(1), Ordering::Relaxed);
        let counter = if reached {
            &self.metrics.reached_total
        } else {
            &self.metrics.timed_out_total
        };
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// Returns a snapshot of the outcome metrics.
    ///
    /// The tuple is `(reached_total, timed_out_total, not_required_total,
    /// last_wait_micros)`.
    pub fn metrics_snapshot(&self) -> (u64, u64, u64, u64) {
        let m = &self.metrics;
        (
            m.reached_total.load(Ordering::Relaxed),
            m.timed_out_total.load(Ordering::Relaxed),
            m.not_required_total.load(Ordering::Relaxed),
            m.last_wait_micros.load(Ordering::Relaxed),
        )
    }
}
/// Counts the replicas whose acknowledged LSN is at or past `target_lsn`.
///
/// The count saturates at `u32::MAX` instead of wrapping. A truncating
/// `usize as u32` cast could otherwise turn a met threshold into an unmet one.
fn count_at_or_past(map: &HashMap<String, u64>, target_lsn: u64) -> u32 {
    let matching = map.values().filter(|&&lsn| lsn >= target_lsn).count();
    // `usize -> u64` is lossless on every target with at most 64-bit pointers;
    // clamp there, then narrow.
    (matching as u64).min(u64::from(u32::MAX)) as u32
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Fails unless `outcome` is a timeout carrying exactly the given counts.
    fn assert_timed_out(outcome: AwaitOutcome, want_observed: u32, want_required: u32) {
        match outcome {
            AwaitOutcome::TimedOut { observed, required } => {
                assert_eq!(observed, want_observed);
                assert_eq!(required, want_required);
            }
            other => panic!("expected TimedOut, got {other:?}"),
        }
    }

    #[test]
    fn required_zero_is_immediate_no_op() {
        let waiter = CommitWaiter::new();
        let outcome = waiter.await_acks(100, 0, Duration::from_secs(60));
        assert_eq!(outcome, AwaitOutcome::NotRequired);
    }

    #[test]
    fn reaches_threshold_with_existing_acks() {
        let waiter = CommitWaiter::new();
        for replica in ["a", "b"] {
            waiter.record_replica_ack(replica, 200);
        }
        assert_eq!(
            waiter.await_acks(150, 2, Duration::from_millis(10)),
            AwaitOutcome::Reached(2)
        );
    }

    #[test]
    fn times_out_when_no_one_has_acked() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 100);
        let outcome = waiter.await_acks(500, 1, Duration::from_millis(20));
        assert_timed_out(outcome, 0, 1);
    }

    #[test]
    fn ack_arriving_during_wait_unblocks_caller() {
        let shared = Arc::new(CommitWaiter::new());
        let handle = {
            let waiter = Arc::clone(&shared);
            thread::spawn(move || waiter.await_acks(1000, 1, Duration::from_secs(2)))
        };
        thread::sleep(Duration::from_millis(50));
        shared.record_replica_ack("late", 1000);
        let outcome = handle.join().expect("waiter thread");
        assert_eq!(outcome, AwaitOutcome::Reached(1));
    }

    #[test]
    fn ack_idempotent_does_not_double_count() {
        let waiter = CommitWaiter::new();
        (0..3).for_each(|_| waiter.record_replica_ack("a", 50));
        assert_eq!(
            waiter.await_acks(50, 1, Duration::from_millis(5)),
            AwaitOutcome::Reached(1)
        );
        assert_timed_out(waiter.await_acks(50, 2, Duration::from_millis(20)), 1, 2);
    }

    #[test]
    fn ack_only_advances_lsn_forward() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 200);
        waiter.record_replica_ack("a", 100);
        assert_eq!(waiter.snapshot(), vec![(String::from("a"), 200)]);
    }

    #[test]
    fn drop_replica_removes_from_count() {
        let waiter = CommitWaiter::new();
        waiter.record_replica_ack("a", 100);
        waiter.record_replica_ack("b", 100);
        waiter.drop_replica("a");
        assert_timed_out(waiter.await_acks(100, 2, Duration::from_millis(20)), 1, 2);
    }

    #[test]
    fn metrics_count_each_outcome_kind() {
        let waiter = CommitWaiter::new();
        let budget = Duration::from_millis(5);
        waiter.await_acks(100, 0, budget);
        waiter.await_acks(100, 1, budget);
        waiter.record_replica_ack("a", 100);
        waiter.await_acks(100, 1, budget);
        let (reached, timed_out, not_required, last_micros) = waiter.metrics_snapshot();
        assert_eq!(reached, 1, "one Reached call");
        assert_eq!(timed_out, 1, "one TimedOut call");
        assert_eq!(not_required, 1, "one NotRequired call");
        assert!(last_micros > 0, "last_wait_micros must be set");
    }
}