use super::types::SequenceNumber;
use std::collections::VecDeque;
use std::time::{Duration, Instant};
/// Default ARQ window size, in packets.
pub const DEFAULT_WINDOW_SIZE: u8 = 8;
/// Default initial retransmission timeout (RTO).
pub const DEFAULT_RTO: Duration = Duration::from_millis(2_000);
/// Default upper bound on the retransmission timeout.
pub const MAX_RTO: Duration = Duration::from_millis(30_000);
/// Default number of retransmissions attempted per packet before giving up.
pub const DEFAULT_MAX_RETRIES: u32 = 5;
/// Tuning parameters for a `SendWindow`.
#[derive(Debug, Clone)]
pub struct ArqConfig {
    /// Maximum number of unacknowledged entries the send window holds.
    pub window_size: u8,
    /// Retransmission timeout a new send window starts with.
    pub initial_rto: Duration,
    /// Upper bound applied to the retransmission timeout when it is backed off.
    pub max_rto: Duration,
    /// Retransmissions allowed per entry before `SendWindow::get_retransmissions`
    /// reports failure by returning `None`.
    pub max_retries: u32,
}
impl Default for ArqConfig {
fn default() -> Self {
Self {
window_size: DEFAULT_WINDOW_SIZE,
initial_rto: DEFAULT_RTO,
max_rto: MAX_RTO,
max_retries: DEFAULT_MAX_RETRIES,
}
}
}
impl ArqConfig {
pub fn for_ble() -> Self {
Self {
window_size: 4, initial_rto: Duration::from_millis(1500),
max_rto: Duration::from_secs(15),
max_retries: 5,
}
}
pub fn for_lora() -> Self {
Self {
window_size: 2, initial_rto: Duration::from_secs(10),
max_rto: Duration::from_secs(60),
max_retries: 3,
}
}
}
/// A payload that has been sent but not yet acknowledged, plus its transmission history.
#[derive(Debug, Clone)]
pub struct SendEntry {
    /// Sequence number the payload was sent under.
    pub seq: SequenceNumber,
    /// Payload bytes, kept so they can be retransmitted.
    pub data: Vec<u8>,
    /// When the entry was created (its first transmission).
    // Only read by `total_time`, which is itself marked `dead_code` in this module.
    #[allow(dead_code)]
    first_sent: Instant,
    /// Time of the most recent (re)transmission; drives timeout checks.
    last_sent: Instant,
    /// Number of times this entry has been sent, including the first transmission.
    pub transmissions: u32,
}
impl SendEntry {
pub fn new(seq: SequenceNumber, data: Vec<u8>) -> Self {
let now = Instant::now();
Self {
seq,
data,
first_sent: now,
last_sent: now,
transmissions: 1,
}
}
pub fn time_since_sent(&self) -> Duration {
self.last_sent.elapsed()
}
#[allow(dead_code)]
pub fn total_time(&self) -> Duration {
self.first_sent.elapsed()
}
pub fn mark_retransmitted(&mut self) {
self.last_sent = Instant::now();
self.transmissions += 1;
}
}
/// Sender half of a sliding-window ARQ: assigns sequence numbers, holds
/// unacknowledged payloads, and tracks the retransmission timeout.
#[derive(Debug)]
pub struct SendWindow {
    /// Window size, RTO bounds and retry limit.
    config: ArqConfig,
    /// Sequence number `send` will assign next.
    next_seq: SequenceNumber,
    /// Reference point for ACK distance comparisons; moved to the sequence
    /// number after the last accepted cumulative ACK.
    base_seq: SequenceNumber,
    /// Sent but unacknowledged entries, oldest at the front.
    unacked: VecDeque<SendEntry>,
    /// Current retransmission timeout, including any backoff.
    current_rto: Duration,
    /// Smoothed round-trip time; `None` until the first RTT sample.
    srtt: Option<Duration>,
}
impl SendWindow {
    /// Creates an empty send window using `config`.
    ///
    /// Sequence numbering starts at 0 and the retransmission timeout starts at
    /// `config.initial_rto`.
    pub fn new(config: ArqConfig) -> Self {
        Self {
            current_rto: config.initial_rto,
            config,
            next_seq: SequenceNumber::new(0),
            base_seq: SequenceNumber::new(0),
            unacked: VecDeque::new(),
            srtt: None,
        }
    }

    /// Creates an empty send window with `ArqConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(ArqConfig::default())
    }

    /// Returns the sequence number the next call to `send` will assign.
    pub fn next_seq(&self) -> SequenceNumber {
        self.next_seq
    }

    /// Returns `true` while fewer than `window_size` entries are unacknowledged.
    pub fn can_send(&self) -> bool {
        self.unacked.len() < self.config.window_size as usize
    }

    /// Returns `true` when no further entries may be queued.
    pub fn is_full(&self) -> bool {
        !self.can_send()
    }

    /// Number of sent-but-unacknowledged entries.
    pub fn in_flight(&self) -> usize {
        self.unacked.len()
    }

    /// Same as `in_flight`.
    pub fn len(&self) -> usize {
        self.in_flight()
    }

    /// Returns `true` when nothing is awaiting acknowledgement.
    pub fn is_empty(&self) -> bool {
        self.unacked.is_empty()
    }

    /// Queues `data` under the next sequence number and returns that number.
    ///
    /// Returns `None` (consuming no sequence number) when the window is full.
    pub fn send(&mut self, data: Vec<u8>) -> Option<SequenceNumber> {
        if !self.can_send() {
            return None;
        }
        let seq = self.next_seq;
        self.next_seq = self.next_seq.next();
        self.unacked.push_back(SendEntry::new(seq, data));
        Some(seq)
    }

    /// Queues `data` under a caller-chosen sequence number.
    ///
    /// # Errors
    /// Returns `ConstrainedError::SendBufferFull` when the window is full.
    ///
    /// NOTE(review): unlike `send`, this does not advance `next_seq` and does not
    /// check that `seq` follows the newest queued entry; mixing `add` with `send`,
    /// or adding non-consecutive numbers, can confuse `acknowledge`. Confirm intended use.
    pub fn add(
        &mut self,
        seq: SequenceNumber,
        data: Vec<u8>,
    ) -> Result<(), super::types::ConstrainedError> {
        if self.is_full() {
            return Err(super::types::ConstrainedError::SendBufferFull);
        }
        self.unacked.push_back(SendEntry::new(seq, data));
        Ok(())
    }

    /// Applies a cumulative acknowledgement covering every queued entry up to and
    /// including `ack`, and returns how many entries were released.
    ///
    /// RTT is sampled only from entries transmitted exactly once (Karn's algorithm):
    /// an ACK for a retransmitted entry cannot be attributed to a specific transmission.
    ///
    /// An `ack` whose distance from `base_seq` exceeds that of the newest queued entry
    /// is ignored and 0 is returned. Assuming `distance_to` is a forward (wrapping)
    /// distance, this covers stale/duplicate ACKs for already-released numbers, which
    /// would otherwise look "far ahead" and release unreceived data, as well as ACKs
    /// for numbers never sent.
    pub fn acknowledge(&mut self, ack: SequenceNumber) -> usize {
        let ack_dist = self.base_seq.distance_to(ack);
        let newest_dist = match self.unacked.back() {
            Some(newest) => self.base_seq.distance_to(newest.seq),
            None => return 0,
        };
        if ack_dist > newest_dist {
            return 0;
        }
        let mut count = 0;
        while let Some(front) = self.unacked.front() {
            if self.base_seq.distance_to(front.seq) > ack_dist {
                break;
            }
            if let Some(entry) = self.unacked.pop_front() {
                if entry.transmissions == 1 {
                    self.update_rtt(entry.time_since_sent());
                }
                count += 1;
            }
        }
        if count > 0 {
            self.base_seq = ack.next();
        }
        count
    }

    /// Folds an RTT `sample` into the smoothed RTT (EWMA with gain 1/8) and sets the
    /// RTO to twice the SRTT, bounded to `[initial_rto, max_rto]`.
    ///
    /// The bounds apply to the very first sample as well, so one fast or slow initial
    /// ACK cannot push the RTO outside the configured range. `max` followed by `min` is
    /// used instead of `clamp` so that a config with `initial_rto > max_rto` saturates
    /// to `max_rto` instead of panicking.
    fn update_rtt(&mut self, sample: Duration) {
        const ALPHA: f64 = 0.125;
        let srtt = match self.srtt {
            Some(prev) => Duration::from_secs_f64(
                (1.0 - ALPHA) * prev.as_secs_f64() + ALPHA * sample.as_secs_f64(),
            ),
            None => sample,
        };
        self.srtt = Some(srtt);
        self.current_rto = srtt
            .saturating_mul(2)
            .max(self.config.initial_rto)
            .min(self.config.max_rto);
    }

    /// Current retransmission timeout, including any backoff.
    pub fn rto(&self) -> Duration {
        self.current_rto
    }

    /// Collects every entry whose time since last transmission exceeds the current
    /// RTO, marks each as retransmitted, and returns their `(seq, payload)` pairs.
    ///
    /// Returns `None`, leaving every entry untouched, if any timed-out entry has
    /// already been retransmitted `max_retries` times (`transmissions > max_retries`).
    /// If anything is returned for retransmission, the RTO doubles (exponential
    /// backoff), capped at `max_rto`.
    pub fn get_retransmissions(&mut self) -> Option<Vec<(SequenceNumber, Vec<u8>)>> {
        let rto = self.current_rto;
        let max_retries = self.config.max_retries;
        // Snapshot the timed-out set once so the give-up check and the marking pass
        // agree even though the clock keeps moving between them.
        let expired: Vec<usize> = self
            .unacked
            .iter()
            .enumerate()
            .filter(|(_, entry)| entry.time_since_sent() > rto)
            .map(|(idx, _)| idx)
            .collect();
        // Decide before mutating: giving up must not leave entries half-marked.
        if expired
            .iter()
            .any(|&idx| self.unacked[idx].transmissions > max_retries)
        {
            return None;
        }
        let mut retransmits = Vec::with_capacity(expired.len());
        for idx in expired {
            let entry = &mut self.unacked[idx];
            retransmits.push((entry.seq, entry.data.clone()));
            entry.mark_retransmitted();
        }
        if !retransmits.is_empty() {
            self.current_rto = self.current_rto.saturating_mul(2).min(self.config.max_rto);
        }
        Some(retransmits)
    }

    /// Drops all queued entries and restores the initial sequence numbers, RTO
    /// and RTT state.
    pub fn reset(&mut self) {
        self.next_seq = SequenceNumber::new(0);
        self.base_seq = SequenceNumber::new(0);
        self.unacked.clear();
        self.current_rto = self.config.initial_rto;
        self.srtt = None;
    }
}
/// Receiver half of a sliding-window ARQ: delivers payloads in sequence order and
/// buffers in-window packets that arrive ahead of a gap.
#[derive(Debug)]
pub struct ReceiveWindow {
    /// Window size passed to `SequenceNumber::is_in_window` when accepting packets.
    window_size: u8,
    /// Next sequence number that can be delivered in order.
    next_expected: SequenceNumber,
    /// Last sequence number delivered in order (initially the starting sequence number).
    cumulative_ack: SequenceNumber,
    /// Early packets awaiting delivery, ordered by distance from `next_expected`.
    out_of_order: VecDeque<(SequenceNumber, Vec<u8>)>,
}
impl ReceiveWindow {
    /// Creates an empty receive window that expects sequence number 0 next.
    ///
    /// `window_size` is passed to `SequenceNumber::is_in_window` to decide which
    /// incoming sequence numbers are accepted.
    pub fn new(window_size: u8) -> Self {
        let start = SequenceNumber::new(0);
        Self {
            window_size,
            next_expected: start,
            cumulative_ack: start,
            out_of_order: VecDeque::new(),
        }
    }

    /// Creates a receive window sized with `DEFAULT_WINDOW_SIZE`.
    pub fn with_defaults() -> Self {
        Self::new(DEFAULT_WINDOW_SIZE)
    }

    /// Returns the last sequence number recorded as delivered in order.
    ///
    /// NOTE(review): before anything has been delivered (and right after `reset` or
    /// `reset_with_seq`) this returns the starting sequence number even though that
    /// packet has not been received; confirm callers do not send it as an ACK then.
    pub fn cumulative_ack(&self) -> SequenceNumber {
        self.cumulative_ack
    }

    /// Reports whether `seq` is acceptable relative to the next expected sequence
    /// number, as decided by `SequenceNumber::is_in_window`.
    pub fn is_in_window(&self, seq: SequenceNumber) -> bool {
        let expected = self.next_expected;
        expected.is_in_window(seq, self.window_size)
    }

    /// Accepts a packet and returns every payload that has become deliverable in order.
    ///
    /// Returns `None` when `seq` is outside the window, or when it arrives ahead of a
    /// gap and is buffered (or dropped as an already-buffered duplicate). Otherwise
    /// returns `seq`'s payload followed by any buffered payloads that now continue the
    /// in-order run.
    pub fn receive(
        &mut self,
        seq: SequenceNumber,
        data: Vec<u8>,
    ) -> Option<Vec<(SequenceNumber, Vec<u8>)>> {
        if !self.is_in_window(seq) {
            return None;
        }
        if seq != self.next_expected {
            self.buffer_out_of_order(seq, data);
            return None;
        }
        let mut ready = vec![(seq, data)];
        self.mark_delivered(seq);
        // Drain buffered packets for as long as they extend the in-order run.
        loop {
            let wanted = self.next_expected;
            let hit = self.out_of_order.iter().position(|(s, _)| *s == wanted);
            match hit.and_then(|idx| self.out_of_order.remove(idx)) {
                Some((s, d)) => {
                    self.mark_delivered(s);
                    ready.push((s, d));
                }
                None => break,
            }
        }
        Some(ready)
    }

    /// Records `seq` (which equals `next_expected`) as delivered: advances the
    /// expected sequence number and makes `seq` the cumulative ACK.
    fn mark_delivered(&mut self, seq: SequenceNumber) {
        self.next_expected = self.next_expected.next();
        self.cumulative_ack = seq;
    }

    /// Buffers an in-window packet that cannot be delivered yet, keeping the buffer
    /// ordered by distance from `next_expected`. A sequence number that is already
    /// buffered is ignored.
    fn buffer_out_of_order(&mut self, seq: SequenceNumber, data: Vec<u8>) {
        if self.out_of_order.iter().any(|(s, _)| *s == seq) {
            return;
        }
        let base = self.next_expected;
        let target = base.distance_to(seq);
        let slot = self
            .out_of_order
            .iter()
            .position(|(s, _)| base.distance_to(*s) > target)
            .unwrap_or(self.out_of_order.len());
        self.out_of_order.insert(slot, (seq, data));
    }

    /// Discards buffered packets and restarts the window at sequence number 0.
    pub fn reset(&mut self) {
        self.reset_with_seq(SequenceNumber::new(0));
    }

    /// Discards buffered packets and restarts the window at `start_seq`, which becomes
    /// both the next expected sequence number and the reported cumulative ACK.
    pub fn reset_with_seq(&mut self, start_seq: SequenceNumber) {
        self.next_expected = start_seq;
        self.cumulative_ack = start_seq;
        self.out_of_order.clear();
    }

    /// Number of early packets currently buffered.
    pub fn buffered_count(&self) -> usize {
        self.out_of_order.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `SequenceNumber` from a literal.
    macro_rules! sn {
        ($n:expr) => {
            SequenceNumber::new($n)
        };
    }

    #[test]
    fn test_arq_config_defaults() {
        let cfg = ArqConfig::default();
        assert_eq!(cfg.window_size, DEFAULT_WINDOW_SIZE);
        assert_eq!(cfg.initial_rto, DEFAULT_RTO);
    }

    #[test]
    fn test_arq_config_ble() {
        let ble = ArqConfig::for_ble();
        assert!(ble.window_size < DEFAULT_WINDOW_SIZE, "BLE window should be below default");
        assert!(ble.initial_rto < DEFAULT_RTO, "BLE RTO should be below default");
    }

    #[test]
    fn test_send_entry() {
        let entry = SendEntry::new(sn!(5), b"test".to_vec());
        assert_eq!(entry.seq, sn!(5));
        assert_eq!(entry.transmissions, 1);
        assert!(entry.time_since_sent() < Duration::from_secs(1));
    }

    #[test]
    fn test_send_window_basic() {
        let mut win = SendWindow::with_defaults();
        assert!(win.can_send());
        assert_eq!(win.in_flight(), 0);
        let first = win
            .send(b"hello".to_vec())
            .expect("an empty window accepts a send");
        assert_eq!(first, sn!(0));
        assert_eq!(win.in_flight(), 1);
        assert_eq!(win.acknowledge(sn!(0)), 1);
        assert_eq!(win.in_flight(), 0);
    }

    #[test]
    fn test_send_window_full() {
        let mut win = SendWindow::new(ArqConfig {
            window_size: 2,
            ..ArqConfig::default()
        });
        for payload in &[b"1", b"2"] {
            assert!(win.send(payload.to_vec()).is_some());
        }
        assert!(!win.can_send());
        assert!(win.send(b"3".to_vec()).is_none());
    }

    #[test]
    fn test_send_window_cumulative_ack() {
        let mut win = SendWindow::with_defaults();
        for payload in &[b"1", b"2", b"3"] {
            win.send(payload.to_vec());
        }
        assert_eq!(win.in_flight(), 3);
        // Acknowledging seq 1 cumulatively releases seq 0 as well.
        assert_eq!(win.acknowledge(sn!(1)), 2);
        assert_eq!(win.in_flight(), 1);
    }

    #[test]
    fn test_receive_window_in_order() {
        let mut win = ReceiveWindow::with_defaults();
        let delivered = win
            .receive(sn!(0), b"first".to_vec())
            .expect("an in-order packet is delivered");
        assert_eq!(delivered.len(), 1);
        assert_eq!(delivered[0].1, b"first");
        assert_eq!(win.cumulative_ack(), sn!(0));
    }

    #[test]
    fn test_receive_window_out_of_order() {
        let mut win = ReceiveWindow::with_defaults();
        assert!(win.receive(sn!(1), b"second".to_vec()).is_none());
        assert_eq!(win.buffered_count(), 1);
        let delivered = win
            .receive(sn!(0), b"first".to_vec())
            .expect("filling the gap delivers both packets");
        assert_eq!(delivered.len(), 2);
        assert_eq!(delivered[0].1, b"first");
        assert_eq!(delivered[1].1, b"second");
        assert_eq!(win.cumulative_ack(), sn!(1));
        assert_eq!(win.buffered_count(), 0);
    }

    #[test]
    fn test_receive_window_duplicate() {
        let mut win = ReceiveWindow::with_defaults();
        win.receive(sn!(0), b"first".to_vec());
        assert!(win.receive(sn!(0), b"first".to_vec()).is_none());
    }

    #[test]
    fn test_receive_window_out_of_window() {
        let mut win = ReceiveWindow::new(4);
        assert!(win.receive(sn!(10), b"data".to_vec()).is_none());
        assert_eq!(win.buffered_count(), 0);
    }

    #[test]
    fn test_send_window_reset() {
        let mut win = SendWindow::with_defaults();
        win.send(b"data".to_vec());
        assert_eq!(win.in_flight(), 1);
        win.reset();
        assert_eq!(win.in_flight(), 0);
        assert_eq!(win.next_seq(), sn!(0));
    }

    #[test]
    fn test_receive_window_reset() {
        let mut win = ReceiveWindow::with_defaults();
        win.receive(sn!(1), b"data".to_vec());
        assert_eq!(win.buffered_count(), 1);
        win.reset();
        assert_eq!(win.buffered_count(), 0);
    }
}