use alloc::collections::BTreeMap;
use alloc::vec::Vec;
/// Base acknowledgement timeout in milliseconds (RFC 7252 §4.8 default `ACK_TIMEOUT` of 2 s).
pub const ACK_TIMEOUT_MS: u64 = 2_000;
/// Numerator of `ACK_RANDOM_FACTOR` written as an integer fraction (3 / 2 = 1.5).
pub const ACK_RANDOM_FACTOR_NUM: u64 = 3;
/// Denominator of `ACK_RANDOM_FACTOR` written as an integer fraction.
pub const ACK_RANDOM_FACTOR_DEN: u64 = 2;
/// How many times a confirmable message may be retransmitted before it is abandoned.
pub const MAX_RETRANSMIT: u32 = 4;
/// A confirmable message that has been sent and is still waiting for an ACK or RST.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PendingConfirmable {
    /// Message ID that a matching ACK or RST must carry.
    pub message_id: u16,
    /// Token supplied by the sender; stored and returned, never interpreted by the tracker.
    pub token: Vec<u8>,
    /// Message bytes handed back unchanged for retransmission (presumably the encoded datagram).
    pub bytes: Vec<u8>,
    /// Retransmissions still permitted before the message is reported as timed out.
    pub retransmits_left: u32,
    /// Absolute deadline, on the caller's `now_ms` clock, at which the next timeout fires.
    pub next_timeout_ms: u64,
    /// Timeout interval currently in effect; doubled after every retransmission.
    pub current_interval_ms: u64,
}
/// Bookkeeping for outstanding confirmable messages, keyed by message ID.
///
/// Entries are added by `send_confirmable`, removed by `receive_ack` / `receive_rst`, and
/// advanced (retransmitted or expired) by `tick`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ReliabilityTracker {
    /// Outstanding messages; a `BTreeMap` keeps iteration (and therefore output) ordered by ID.
    pending: BTreeMap<u16, PendingConfirmable>,
}
/// What a single call to `ReliabilityTracker::tick` decided.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct TickOutput {
    /// Messages whose deadline passed and that should be sent again, as clones carrying the
    /// already-updated retransmission state.
    pub retransmit: Vec<PendingConfirmable>,
    /// IDs of messages abandoned after exhausting their retransmissions; these are no longer
    /// tracked.
    pub timed_out: Vec<u16>,
}
impl ReliabilityTracker {
    /// Creates a tracker with no outstanding messages.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Starts tracking a confirmable message sent at `now_ms`.
    ///
    /// The first timeout fires `initial_interval()` milliseconds after `now_ms`. The message
    /// may then be retransmitted up to [`MAX_RETRANSMIT`] times by [`Self::tick`] before it is
    /// reported as timed out.
    ///
    /// If `message_id` is already pending, the old entry is replaced and its retransmission
    /// state starts over.
    pub fn send_confirmable(
        &mut self,
        message_id: u16,
        token: Vec<u8>,
        bytes: Vec<u8>,
        now_ms: u64,
    ) {
        let interval = initial_interval();
        self.pending.insert(
            message_id,
            PendingConfirmable {
                message_id,
                token,
                bytes,
                retransmits_left: MAX_RETRANSMIT,
                // Saturate so a caller clock near `u64::MAX` cannot overflow (panic in debug,
                // wrap to a past deadline in release).
                next_timeout_ms: now_ms.saturating_add(interval),
                current_interval_ms: interval,
            },
        );
    }

    /// Handles an ACK for `message_id`.
    ///
    /// Returns `true` if the message was pending (and is now no longer tracked), or `false`
    /// for an unknown or duplicate ACK.
    pub fn receive_ack(&mut self, message_id: u16) -> bool {
        self.pending.remove(&message_id).is_some()
    }

    /// Handles an RST for `message_id`.
    ///
    /// Returns `true` if the message was pending (and is now no longer tracked), or `false`
    /// if no such message was outstanding.
    pub fn receive_rst(&mut self, message_id: u16) -> bool {
        self.pending.remove(&message_id).is_some()
    }

    /// Advances all timers to `now_ms`.
    ///
    /// Each entry whose deadline has been reached (`now_ms >= next_timeout_ms`) is handled
    /// as follows:
    /// - If it has retransmissions left, its interval is doubled, its next deadline is set
    ///   relative to `now_ms`, and a clone is returned in [`TickOutput::retransmit`].
    /// - Otherwise it is removed and its ID is returned in [`TickOutput::timed_out`].
    ///
    /// Both output vectors are ordered by message ID.
    pub fn tick(&mut self, now_ms: u64) -> TickOutput {
        let mut out = TickOutput::default();
        // `retain` visits entries in key order and lets us drop expired ones in the same pass.
        self.pending.retain(|&mid, entry| {
            if now_ms < entry.next_timeout_ms {
                return true;
            }
            if entry.retransmits_left == 0 {
                out.timed_out.push(mid);
                return false;
            }
            entry.retransmits_left -= 1;
            entry.current_interval_ms = entry.current_interval_ms.saturating_mul(2);
            // Re-base on `now_ms` (the moment we retransmit), not on the missed deadline, so a
            // late tick does not cause an immediate second retransmission.
            entry.next_timeout_ms = now_ms.saturating_add(entry.current_interval_ms);
            out.retransmit.push(entry.clone());
            true
        });
        out
    }

    /// Number of confirmable messages still awaiting an ACK or RST.
    #[must_use]
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }
}
/// Initial retransmission timeout in milliseconds: `ACK_TIMEOUT_MS * ACK_RANDOM_FACTOR`.
///
/// This always uses the upper bound of the RFC 7252 randomized initial-timeout range rather
/// than drawing a random value.
fn initial_interval() -> u64 {
    // Multiply before dividing so integer division does not truncate the 3/2 factor to 1.
    let scaled = ACK_TIMEOUT_MS * ACK_RANDOM_FACTOR_NUM;
    scaled / ACK_RANDOM_FACTOR_DEN
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn fresh_tracker_is_empty() {
        let t = ReliabilityTracker::new();
        assert_eq!(t.pending_count(), 0);
    }

    #[test]
    fn send_then_ack_clears_pending() {
        let mut t = ReliabilityTracker::new();
        t.send_confirmable(42, alloc::vec![1, 2], alloc::vec![0; 10], 0);
        assert_eq!(t.pending_count(), 1);
        assert!(t.receive_ack(42));
        assert_eq!(t.pending_count(), 0);
        // A duplicate ACK no longer matches anything.
        assert!(!t.receive_ack(42));
    }

    #[test]
    fn unknown_ack_returns_false() {
        let mut t = ReliabilityTracker::new();
        assert!(!t.receive_ack(99));
    }

    #[test]
    fn unknown_rst_returns_false() {
        let mut t = ReliabilityTracker::new();
        assert!(!t.receive_rst(99));
    }

    #[test]
    fn rst_clears_pending() {
        let mut t = ReliabilityTracker::new();
        t.send_confirmable(42, alloc::vec![], alloc::vec![], 0);
        assert!(t.receive_rst(42));
        assert_eq!(t.pending_count(), 0);
    }

    #[test]
    fn tick_before_timeout_does_nothing() {
        let mut t = ReliabilityTracker::new();
        t.send_confirmable(42, alloc::vec![], alloc::vec![0; 10], 0);
        for now in [100, initial_interval() - 1] {
            let out = t.tick(now);
            assert!(out.retransmit.is_empty());
            assert!(out.timed_out.is_empty());
        }
        assert_eq!(t.pending_count(), 1);
    }

    #[test]
    fn tick_at_exact_deadline_retransmits() {
        let mut t = ReliabilityTracker::new();
        t.send_confirmable(42, alloc::vec![], alloc::vec![0; 5], 0);
        let out = t.tick(initial_interval());
        assert_eq!(out.retransmit.len(), 1);
        assert!(out.timed_out.is_empty());
    }

    #[test]
    fn tick_after_timeout_retransmits() {
        let mut t = ReliabilityTracker::new();
        t.send_confirmable(42, alloc::vec![7], alloc::vec![0; 5], 0);
        let now = initial_interval() + 1;
        let out = t.tick(now);
        assert_eq!(out.retransmit.len(), 1);
        let entry = &out.retransmit[0];
        assert_eq!(entry.message_id, 42);
        assert_eq!(entry.token, [7]);
        assert_eq!(entry.bytes, [0; 5]);
        assert_eq!(entry.retransmits_left, MAX_RETRANSMIT - 1);
        assert_eq!(entry.next_timeout_ms, now + initial_interval() * 2);
        assert!(out.timed_out.is_empty());
        assert_eq!(t.pending_count(), 1);
    }

    #[test]
    fn ack_after_retransmit_clears_pending() {
        let mut t = ReliabilityTracker::new();
        t.send_confirmable(42, alloc::vec![], alloc::vec![0; 5], 0);
        let _ = t.tick(initial_interval() + 1);
        assert!(t.receive_ack(42));
        assert_eq!(t.pending_count(), 0);
    }

    #[test]
    fn exhausting_retransmits_times_out() {
        let mut t = ReliabilityTracker::new();
        t.send_confirmable(42, alloc::vec![], alloc::vec![0; 5], 0);
        let mut now = 0u64;
        let mut interval = initial_interval();
        let mut retransmissions = 0usize;
        for _ in 0..MAX_RETRANSMIT {
            now += interval + 1;
            interval *= 2;
            let out = t.tick(now);
            retransmissions += out.retransmit.len();
            assert!(out.timed_out.is_empty(), "timed out before retransmits were exhausted");
        }
        assert_eq!(retransmissions, usize::try_from(MAX_RETRANSMIT).unwrap());
        now += interval + 1;
        let out = t.tick(now);
        assert!(out.retransmit.is_empty());
        assert_eq!(out.timed_out, [42], "should be timed out");
        assert_eq!(t.pending_count(), 0);
    }

    #[test]
    fn interval_doubles_per_retransmit() {
        let mut t = ReliabilityTracker::new();
        t.send_confirmable(42, alloc::vec![], alloc::vec![0; 5], 0);
        let first = initial_interval() + 1;
        let _ = t.tick(first);
        let after_first = t.pending.get(&42).unwrap().current_interval_ms;
        assert_eq!(after_first, initial_interval() * 2);
        let _ = t.tick(first + after_first);
        let after_second = t.pending.get(&42).unwrap().current_interval_ms;
        assert_eq!(after_second, initial_interval() * 4);
    }
}