use hashbrown::HashMap;
use noxu_sync::{Condvar, Mutex};
use std::time::{Duration, Instant};
/// Tracks replica acknowledgements per VLSN and lets callers block until enough have arrived.
pub struct AckTracker {
    /// Ack state for every VLSN currently being tracked, keyed by VLSN.
    pending_acks: Mutex<HashMap<u64, PendingAck>>,
    /// Used together with `pending_acks`; notified when tracked ack state changes
    /// (e.g. by `record_ack`).
    ack_signal: Condvar,
    /// Number of newly recorded acks; duplicate and unknown-VLSN acks are not counted.
    total_acks: Mutex<u64>,
    /// Running total maintained by `check_timeouts`.
    total_timeouts: Mutex<u64>,
}
/// Ack bookkeeping for a single tracked VLSN.
#[derive(Debug)]
struct PendingAck {
    /// Same value as this entry's key in the tracker map.
    vlsn: u64,
    /// Number of distinct replicas that must ack before the VLSN counts as satisfied.
    needed: u32,
    /// Replicas that have acked so far, with the time each ack was recorded.
    received: HashMap<String, Instant>,
    /// When this entry was created; `check_timeouts` measures age from here.
    created: Instant,
}

impl PendingAck {
    /// Creates an entry for `vlsn` that needs `needed` acks, with none received yet.
    fn new(vlsn: u64, needed: u32) -> Self {
        Self {
            vlsn,
            needed,
            received: HashMap::new(),
            created: Instant::now(),
        }
    }

    /// Returns `true` once at least `needed` distinct replicas have acked
    /// (immediately `true` when `needed` is 0).
    fn is_satisfied(&self) -> bool {
        // Widen `needed` instead of narrowing `len()` with `as u32`, which could silently
        // truncate. A `needed` that does not fit in `usize` can never be reached.
        usize::try_from(self.needed).is_ok_and(|needed| self.received.len() >= needed)
    }
}
/// Outcome of recording one replica's ack via `AckTracker::record_ack`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AckResult {
    /// The ack was recorded, but the VLSN still needs more distinct replicas.
    Pending,
    /// The ack was recorded and the VLSN now has at least the required number of acks.
    Satisfied,
    /// The VLSN is not tracked (never registered or already cleaned up); nothing was recorded.
    Unknown,
    /// This replica had already acked the VLSN; nothing was recorded.
    Duplicate,
}
impl AckTracker {
pub fn new() -> Self {
Self {
pending_acks: Mutex::new(HashMap::new()),
ack_signal: Condvar::new(),
total_acks: Mutex::new(0),
total_timeouts: Mutex::new(0),
}
}
pub fn register(&self, vlsn: u64, needed_acks: u32) {
let mut pending = self.pending_acks.lock();
pending
.entry(vlsn)
.or_insert_with(|| PendingAck::new(vlsn, needed_acks));
}
pub fn record_ack(&self, vlsn: u64, replica_name: &str) -> AckResult {
let mut pending = self.pending_acks.lock();
let ack = match pending.get_mut(&vlsn) {
Some(a) => a,
None => return AckResult::Unknown,
};
if ack.received.contains_key(replica_name) {
return AckResult::Duplicate;
}
ack.received.insert(replica_name.to_string(), Instant::now());
let satisfied = ack.is_satisfied();
drop(pending);
*self.total_acks.lock() += 1;
self.ack_signal.notify_all();
if satisfied { AckResult::Satisfied } else { AckResult::Pending }
}
pub fn wait_until_satisfied<F: Fn() -> bool>(
&self,
vlsn: u64,
timeout: Duration,
should_abort: F,
) -> bool {
let deadline = Instant::now() + timeout;
let mut guard = self.pending_acks.lock();
loop {
match guard.get(&vlsn) {
None => return true,
Some(ack) if ack.is_satisfied() => return true,
_ => {}
}
if should_abort() {
return false;
}
let now = Instant::now();
if now >= deadline {
return false;
}
let res = self.ack_signal.wait_for(&mut guard, deadline - now);
if res.timed_out() && Instant::now() >= deadline {
match guard.get(&vlsn) {
None => return true,
Some(ack) if ack.is_satisfied() => return true,
_ => return false,
}
}
}
}
pub fn is_satisfied(&self, vlsn: u64) -> bool {
let pending = self.pending_acks.lock();
match pending.get(&vlsn) {
Some(ack) => ack.is_satisfied(),
None => false,
}
}
pub fn received_count(&self, vlsn: u64) -> Option<u32> {
let pending = self.pending_acks.lock();
pending.get(&vlsn).map(|ack| ack.received.len() as u32)
}
pub fn cleanup_through(&self, vlsn: u64) {
let mut pending = self.pending_acks.lock();
pending.retain(|&v, _| v > vlsn);
}
pub fn pending_count(&self) -> usize {
self.pending_acks.lock().len()
}
pub fn check_timeouts(&self, timeout: Duration) -> Vec<u64> {
let pending = self.pending_acks.lock();
let now = Instant::now();
let mut timed_out = Vec::new();
for ack in pending.values() {
if !ack.is_satisfied()
&& let Some(elapsed) = now.checked_duration_since(ack.created)
&& elapsed > timeout
{
timed_out.push(ack.vlsn);
*self.total_timeouts.lock() += 1;
}
}
timed_out
}
pub fn get_total_acks(&self) -> u64 {
*self.total_acks.lock()
}
pub fn get_total_timeouts(&self) -> u64 {
*self.total_timeouts.lock()
}
}
impl Default for AckTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PendingAck` whose `created` timestamp lies `age` in the past.
    ///
    /// Returns `None` if the monotonic clock cannot represent that instant. In that case the
    /// plain `Instant::now() - age` would panic, so tests skip instead of failing spuriously.
    fn backdated_ack(vlsn: u64, needed: u32, age: Duration) -> Option<PendingAck> {
        let mut ack = PendingAck::new(vlsn, needed);
        ack.created = Instant::now().checked_sub(age)?;
        Some(ack)
    }

    #[test]
    fn test_new_tracker() {
        let tracker = AckTracker::new();
        assert_eq!(tracker.pending_count(), 0);
        assert_eq!(tracker.get_total_acks(), 0);
        assert_eq!(tracker.get_total_timeouts(), 0);
    }

    #[test]
    fn test_default_impl() {
        let tracker = AckTracker::default();
        assert_eq!(tracker.pending_count(), 0);
    }

    #[test]
    fn test_register() {
        let tracker = AckTracker::new();
        tracker.register(100, 2);
        assert_eq!(tracker.pending_count(), 1);
        assert!(!tracker.is_satisfied(100));
    }

    #[test]
    fn test_register_idempotent() {
        let tracker = AckTracker::new();
        tracker.register(100, 2);
        // Re-registering keeps the original requirement of 2 acks; the 5 is ignored.
        tracker.register(100, 5);
        assert_eq!(tracker.pending_count(), 1);
        tracker.record_ack(100, "replica1");
        tracker.record_ack(100, "replica2");
        assert!(tracker.is_satisfied(100));
    }

    #[test]
    fn test_record_ack_pending() {
        let tracker = AckTracker::new();
        tracker.register(100, 2);
        let result = tracker.record_ack(100, "replica1");
        assert_eq!(result, AckResult::Pending);
        assert!(!tracker.is_satisfied(100));
        assert_eq!(tracker.get_total_acks(), 1);
    }

    #[test]
    fn test_record_ack_satisfied() {
        let tracker = AckTracker::new();
        tracker.register(100, 2);
        tracker.record_ack(100, "replica1");
        let result = tracker.record_ack(100, "replica2");
        assert_eq!(result, AckResult::Satisfied);
        assert!(tracker.is_satisfied(100));
        assert_eq!(tracker.get_total_acks(), 2);
    }

    #[test]
    fn test_single_ack_needed() {
        let tracker = AckTracker::new();
        tracker.register(100, 1);
        let result = tracker.record_ack(100, "replica1");
        assert_eq!(result, AckResult::Satisfied);
        assert!(tracker.is_satisfied(100));
    }

    #[test]
    fn test_record_ack_unknown_vlsn() {
        let tracker = AckTracker::new();
        let result = tracker.record_ack(999, "replica1");
        assert_eq!(result, AckResult::Unknown);
        assert_eq!(tracker.get_total_acks(), 0);
    }

    #[test]
    fn test_record_ack_duplicate() {
        let tracker = AckTracker::new();
        tracker.register(100, 2);
        tracker.record_ack(100, "replica1");
        let result = tracker.record_ack(100, "replica1");
        assert_eq!(result, AckResult::Duplicate);
        assert!(!tracker.is_satisfied(100));
        // The duplicate must not be counted.
        assert_eq!(tracker.get_total_acks(), 1);
    }

    #[test]
    fn test_is_satisfied_unknown_vlsn() {
        let tracker = AckTracker::new();
        assert!(!tracker.is_satisfied(999));
    }

    #[test]
    fn test_multiple_vlsns() {
        let tracker = AckTracker::new();
        tracker.register(100, 1);
        tracker.register(101, 2);
        tracker.register(102, 1);
        assert_eq!(tracker.pending_count(), 3);
        tracker.record_ack(100, "r1");
        assert!(tracker.is_satisfied(100));
        assert!(!tracker.is_satisfied(101));
        tracker.record_ack(101, "r1");
        assert!(!tracker.is_satisfied(101));
        tracker.record_ack(101, "r2");
        assert!(tracker.is_satisfied(101));
    }

    #[test]
    fn test_cleanup_through() {
        let tracker = AckTracker::new();
        tracker.register(100, 1);
        tracker.register(101, 1);
        tracker.register(102, 1);
        tracker.register(200, 1);
        assert_eq!(tracker.pending_count(), 4);
        // The bound is inclusive: 100..=102 go away, 200 stays.
        tracker.cleanup_through(102);
        assert_eq!(tracker.pending_count(), 1);
        assert_eq!(tracker.record_ack(100, "r1"), AckResult::Unknown);
        assert_eq!(tracker.record_ack(200, "r1"), AckResult::Satisfied);
    }

    #[test]
    fn test_cleanup_through_zero() {
        let tracker = AckTracker::new();
        tracker.register(100, 1);
        tracker.cleanup_through(0);
        assert_eq!(tracker.pending_count(), 1);
    }

    #[test]
    fn test_cleanup_through_all() {
        let tracker = AckTracker::new();
        tracker.register(1, 1);
        tracker.register(2, 1);
        tracker.cleanup_through(100);
        assert_eq!(tracker.pending_count(), 0);
    }

    #[test]
    fn test_check_timeouts_none() {
        let tracker = AckTracker::new();
        tracker.register(100, 1);
        let timed_out = tracker.check_timeouts(Duration::from_secs(60));
        assert!(timed_out.is_empty());
        assert_eq!(tracker.get_total_timeouts(), 0);
    }

    #[test]
    fn test_check_timeouts_with_expired() {
        let tracker = AckTracker::new();
        let Some(ack) = backdated_ack(50, 1, Duration::from_secs(120)) else {
            return;
        };
        {
            let mut pending = tracker.pending_acks.lock();
            pending.insert(50, ack);
        }
        let timed_out = tracker.check_timeouts(Duration::from_secs(60));
        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], 50);
        assert_eq!(tracker.get_total_timeouts(), 1);
    }

    #[test]
    fn test_check_timeouts_skips_satisfied() {
        let tracker = AckTracker::new();
        let Some(mut ack) = backdated_ack(50, 1, Duration::from_secs(120)) else {
            return;
        };
        ack.received.insert("r1".to_string(), Instant::now());
        {
            let mut pending = tracker.pending_acks.lock();
            pending.insert(50, ack);
        }
        let timed_out = tracker.check_timeouts(Duration::from_secs(60));
        assert!(timed_out.is_empty());
    }

    #[test]
    fn test_extra_acks_beyond_needed() {
        let tracker = AckTracker::new();
        tracker.register(100, 1);
        assert_eq!(tracker.record_ack(100, "r1"), AckResult::Satisfied);
        assert_eq!(tracker.record_ack(100, "r2"), AckResult::Satisfied);
        assert_eq!(tracker.get_total_acks(), 2);
    }

    #[test]
    fn test_zero_acks_needed() {
        let tracker = AckTracker::new();
        tracker.register(100, 0);
        assert!(tracker.is_satisfied(100));
    }

    #[test]
    fn test_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<AckTracker>();
    }

    #[test]
    fn wait_until_satisfied_wakes_on_ack() {
        use std::sync::Arc;
        use std::thread;
        let tracker = Arc::new(AckTracker::new());
        tracker.register(42, 2);
        let t2 = Arc::clone(&tracker);
        let waiter = thread::spawn(move || {
            t2.wait_until_satisfied(42, Duration::from_secs(5), || false)
        });
        thread::sleep(Duration::from_millis(20));
        assert_eq!(tracker.record_ack(42, "r1"), AckResult::Pending);
        assert_eq!(tracker.record_ack(42, "r2"), AckResult::Satisfied);
        // Measured after both acks, so this bounds how long the waiter takes to notice them.
        let start = Instant::now();
        let ok = waiter.join().unwrap();
        assert!(ok, "wait_until_satisfied must return true once satisfied");
        assert!(
            start.elapsed() < Duration::from_secs(2),
            "must wake on ack, not spin to timeout"
        );
    }

    #[test]
    fn wait_until_satisfied_times_out_without_enough_acks() {
        let tracker = AckTracker::new();
        tracker.register(7, 3);
        tracker.record_ack(7, "only-one");
        let ok = tracker.wait_until_satisfied(7, Duration::from_millis(50), || false);
        assert!(!ok, "must time out when acks are insufficient");
    }
}