extern crate alloc;
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use core::time::Duration;
use crate::error::XrceError;
use crate::header::StreamId;
use crate::serial_number::SerialNumber16;
use crate::submessages::{AckNackPayload, HeartbeatPayload};
/// Default minimum interval between two heartbeats while messages are in flight.
pub const DEFAULT_HEARTBEAT_PERIOD: Duration = Duration::from_millis(500);
/// Default number of submitted-but-unacknowledged messages a sender may hold.
pub const SENDER_WINDOW_CAP: usize = 16;
/// Default capacity of the receive-side reorder buffer.
pub const RECEIVER_BUFFER_CAP: usize = 64;
/// Largest payload, in bytes, accepted by `ReliableStreamState::submit`.
pub const RELIABLE_MAX_PAYLOAD: usize = u16::MAX as usize;
/// Tuning parameters for a reliable stream.
///
/// All fields are plain values, so configurations can be copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReliableConfig {
    /// Minimum time between two heartbeats emitted while messages are in flight.
    pub heartbeat_period: Duration,
    /// Maximum number of submitted messages that may be awaiting acknowledgement.
    pub sender_window: usize,
    /// Capacity of the receive-side reorder buffer.
    pub receiver_buffer: usize,
}
impl Default for ReliableConfig {
fn default() -> Self {
Self {
heartbeat_period: DEFAULT_HEARTBEAT_PERIOD,
sender_window: SENDER_WINDOW_CAP,
receiver_buffer: RECEIVER_BUFFER_CAP,
}
}
}
/// State of one reliable stream, covering both directions:
/// the sending side (sequence assignment, retained in-flight payloads, heartbeat timing)
/// and the receiving side (reorder buffer and in-order delivery).
#[derive(Debug, Clone)]
pub struct ReliableStreamState {
    // Stream this state was created for.
    stream_id: StreamId,
    // Window sizes and heartbeat period.
    config: ReliableConfig,
    // Sequence number the next `submit` will assign.
    next_seq: SerialNumber16,
    // Submitted payloads not yet released by an ACKNACK, keyed by raw sequence number.
    // NOTE: BTreeMap order is numeric on the raw u16, not serial-number order, so it
    // does not reflect send order once sequence numbers wrap.
    in_flight: BTreeMap<u16, Vec<u8>>,
    // `now` value of the last `pending_heartbeat` call that emitted a heartbeat.
    last_heartbeat: Option<Duration>,
    // Next sequence number `drain_in_order` will deliver.
    expected_seq: SerialNumber16,
    // Received payloads waiting for delivery, keyed by raw sequence number.
    received: BTreeMap<u16, Vec<u8>>,
}
impl ReliableStreamState {
    /// Creates the state for one reliable stream. The send and receive sides
    /// both start at sequence number 0.
    ///
    /// # Panics
    ///
    /// Panics if `stream_id.is_reliable()` returns `false`.
    #[must_use]
    pub fn new(stream_id: StreamId, config: ReliableConfig) -> Self {
        assert!(
            stream_id.is_reliable(),
            "ReliableStreamState requires reliable stream id (>=128)"
        );
        Self {
            stream_id,
            config,
            next_seq: SerialNumber16::new(0),
            in_flight: BTreeMap::new(),
            last_heartbeat: None,
            expected_seq: SerialNumber16::new(0),
            received: BTreeMap::new(),
        }
    }

    /// Returns the stream id this state was created for.
    #[must_use]
    pub fn stream_id(&self) -> StreamId {
        self.stream_id
    }

    /// Number of submitted messages not yet released by an ACKNACK.
    #[must_use]
    pub fn in_flight_count(&self) -> usize {
        self.in_flight.len()
    }

    /// Number of received messages buffered but not yet returned by [`Self::drain_in_order`].
    #[must_use]
    pub fn out_of_order_count(&self) -> usize {
        self.received.len()
    }

    /// Next sequence number the receiving side will deliver.
    #[must_use]
    pub fn expected(&self) -> SerialNumber16 {
        self.expected_seq
    }

    /// Assigns the next sequence number to `payload`. The payload is retained in flight
    /// until an ACKNACK releases it (or the state is reset).
    ///
    /// # Errors
    ///
    /// - [`XrceError::PayloadTooLarge`] if `payload` is longer than [`RELIABLE_MAX_PAYLOAD`].
    /// - [`XrceError::ValueOutOfRange`] if `config.sender_window` messages are already in flight.
    pub fn submit(&mut self, payload: Vec<u8>) -> Result<SerialNumber16, XrceError> {
        if payload.len() > RELIABLE_MAX_PAYLOAD {
            return Err(XrceError::PayloadTooLarge {
                limit: RELIABLE_MAX_PAYLOAD,
                actual: payload.len(),
            });
        }
        if self.in_flight.len() >= self.config.sender_window {
            return Err(XrceError::ValueOutOfRange {
                message: "reliable sender window full",
            });
        }
        let seq = self.next_seq;
        self.in_flight.insert(seq.raw(), payload);
        self.next_seq = self.next_seq.next();
        Ok(seq)
    }

    /// Returns the retained payload for `seq` (for retransmission), or `None` if it is
    /// not in flight.
    #[must_use]
    pub fn get_in_flight(&self, seq: SerialNumber16) -> Option<&[u8]> {
        self.in_flight.get(&seq.raw()).map(Vec::as_slice)
    }

    /// Returns a HEARTBEAT announcing the oldest and newest in-flight sequence numbers,
    /// if one is due at time `now`.
    ///
    /// Returns `None` in two cases: nothing is in flight, or less than
    /// `config.heartbeat_period` has elapsed since the last emitted heartbeat.
    /// Emitting a heartbeat records `now` as the new reference time.
    pub fn pending_heartbeat(&mut self, now: Duration) -> Option<HeartbeatPayload> {
        if self.in_flight.is_empty() {
            return None;
        }
        let due = match self.last_heartbeat {
            None => true,
            Some(t) => now.saturating_sub(t) >= self.config.heartbeat_period,
        };
        if !due {
            return None;
        }
        self.last_heartbeat = Some(now);
        // The map's numeric key order is wrong across wraparound ({65535, 0} would give
        // first=0, last=65535). Every in-flight seq was assigned before `next_seq`, so
        // rank them by how far behind `next_seq` they are: oldest = largest distance.
        let newest = self.next_seq.raw();
        let first = self
            .in_flight
            .keys()
            .copied()
            .max_by_key(|&k| newest.wrapping_sub(k))?;
        let last = self
            .in_flight
            .keys()
            .copied()
            .min_by_key(|&k| newest.wrapping_sub(k))?;
        Some(HeartbeatPayload {
            first_unacked_seq_nr: first as i16,
            last_unacked_seq_nr: last as i16,
            stream_id: self.stream_id.0,
        })
    }

    /// Releases the in-flight messages that a peer's ACKNACK acknowledges.
    ///
    /// - Every message before `first_unacked_seq_num` (in serial-number order) is released.
    /// - A clear bit in the (little-endian) `nack_bitmap` counts as a selective
    ///   acknowledgement only when it lies below the highest set bit.
    ///
    /// The highest-set-bit rule matters because a receiver may leave bits clear for
    /// sequence numbers it has not heard of yet (see `hint_last_seen` in
    /// [`Self::pending_acknack`]). Releasing those would discard messages the peer never
    /// received. Set bits, and everything above the highest set bit, stay in flight so
    /// they can be retransmitted.
    ///
    /// The payload's `stream_id` is not checked.
    pub fn recv_acknack(&mut self, payload: AckNackPayload) {
        let base = payload.first_unacked_seq_num as u16;
        let bitmap = u16::from_le_bytes(payload.nack_bitmap);

        // Cumulative acknowledgement: drop keys strictly before `base` within half the
        // sequence space.
        self.in_flight.retain(|&k, _| {
            let diff = base.wrapping_sub(k);
            !(diff > 0 && diff < SerialNumber16::HALF_WINDOW)
        });

        // Selective acknowledgement, limited to the range the receiver demonstrably knows.
        if let Some(highest_nacked) = (0u16..16).rev().find(|&i| (bitmap >> i) & 1 == 1) {
            for i in 0..highest_nacked {
                if (bitmap >> i) & 1 == 0 {
                    self.in_flight.remove(&base.wrapping_add(i));
                }
            }
        }
    }

    /// Buffers a received DATA payload for in-order delivery via [`Self::drain_in_order`].
    ///
    /// Silently dropped (returns `Ok`):
    /// - messages already delivered (before [`Self::expected`]);
    /// - duplicates of messages that are already buffered.
    ///
    /// # Errors
    ///
    /// [`XrceError::ValueOutOfRange`] if the reorder buffer already holds
    /// `config.receiver_buffer` messages and `seq` is not the next expected sequence
    /// number.
    pub fn recv_data(&mut self, seq: SerialNumber16, payload: Vec<u8>) -> Result<(), XrceError> {
        if seq.wrapping_lt(self.expected_seq) {
            return Ok(());
        }
        let key = seq.raw();
        if self.received.contains_key(&key) {
            return Ok(());
        }
        // The expected message is always accepted. It can be drained immediately, and
        // refusing it while the buffer is full of later messages would stall the stream
        // forever.
        let is_next = key == self.expected_seq.raw();
        if !is_next && self.received.len() >= self.config.receiver_buffer {
            return Err(XrceError::ValueOutOfRange {
                message: "reliable receiver buffer full",
            });
        }
        self.received.insert(key, payload);
        Ok(())
    }

    /// Removes and returns the contiguous run of buffered messages that starts at
    /// [`Self::expected`], advancing the expected sequence number past them.
    pub fn drain_in_order(&mut self) -> Vec<(SerialNumber16, Vec<u8>)> {
        let mut out = Vec::new();
        while let Some(payload) = self.received.remove(&self.expected_seq.raw()) {
            out.push((self.expected_seq, payload));
            self.expected_seq = self.expected_seq.next();
        }
        out
    }

    /// Builds an ACKNACK describing the receive state.
    ///
    /// - `first_unacked_seq_num` is [`Self::expected`].
    /// - Bit `i` of the little-endian bitmap is set when `expected() + i` has not been
    ///   received.
    /// - When `hint_last_seen` is given, sequence numbers after it are never NACKed
    ///   (their bits stay clear).
    #[must_use]
    pub fn pending_acknack(&self, hint_last_seen: Option<SerialNumber16>) -> AckNackPayload {
        let base = self.expected_seq;
        let mut bitmap: u16 = 0;
        for i in 0u16..16 {
            let seq = base.raw().wrapping_add(i);
            if let Some(h) = hint_last_seen {
                if SerialNumber16::new(seq).wrapping_gt(h) {
                    continue;
                }
            }
            if !self.received.contains_key(&seq) {
                bitmap |= 1u16 << i;
            }
        }
        AckNackPayload {
            first_unacked_seq_num: base.raw() as i16,
            nack_bitmap: bitmap.to_le_bytes(),
            stream_id: self.stream_id.0,
        }
    }

    /// Returns both directions to their freshly constructed state. Sequence numbers
    /// restart at 0 and all buffers and heartbeat timing are cleared; the stream id
    /// and configuration are kept.
    pub fn reset(&mut self) {
        self.next_seq = SerialNumber16::new(0);
        self.in_flight.clear();
        self.last_heartbeat = None;
        self.expected_seq = SerialNumber16::new(0);
        self.received.clear();
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used, clippy::unwrap_used)]
    use super::*;

    /// Fresh state on the built-in reliable stream with the default configuration.
    fn fresh() -> ReliableStreamState {
        ReliableStreamState::new(StreamId::BUILTIN_RELIABLE, ReliableConfig::default())
    }

    /// Drains the receiver and keeps only the raw sequence numbers, in delivery order.
    fn drain_raw(state: &mut ReliableStreamState) -> Vec<u16> {
        state
            .drain_in_order()
            .into_iter()
            .map(|(seq, _)| seq.raw())
            .collect()
    }

    #[test]
    fn submit_assigns_monotonic_seqnrs() {
        let mut state = fresh();
        let first = state.submit(alloc::vec![1, 2]).unwrap();
        let second = state.submit(alloc::vec![3, 4]).unwrap();
        assert_eq!((first.raw(), second.raw()), (0, 1));
        assert_eq!(state.in_flight_count(), 2);
    }

    #[test]
    fn submit_rejects_payload_too_large() {
        let mut state = fresh();
        let oversized = alloc::vec![0u8; RELIABLE_MAX_PAYLOAD + 1];
        let outcome = state.submit(oversized);
        assert!(matches!(outcome, Err(XrceError::PayloadTooLarge { .. })));
    }

    #[test]
    fn submit_rejects_when_window_full() {
        let mut state = fresh();
        (0..SENDER_WINDOW_CAP).for_each(|_| {
            state.submit(alloc::vec![0]).unwrap();
        });
        assert!(state.submit(alloc::vec![0]).is_err());
    }

    #[test]
    fn pending_heartbeat_fires_first_time() {
        let mut state = fresh();
        state.submit(alloc::vec![1]).unwrap();
        let maybe_hb = state.pending_heartbeat(Duration::from_secs(0));
        assert!(maybe_hb.is_some());
        let hb = maybe_hb.unwrap();
        assert_eq!(hb.first_unacked_seq_nr, 0);
        assert_eq!(hb.last_unacked_seq_nr, 0);
        assert_eq!(hb.stream_id, StreamId::BUILTIN_RELIABLE.0);
    }

    #[test]
    fn pending_heartbeat_silenced_until_period_elapsed() {
        let mut state = fresh();
        state.submit(alloc::vec![1]).unwrap();
        let fired: Vec<bool> = [0u64, 100, 600]
            .iter()
            .map(|&ms| state.pending_heartbeat(Duration::from_millis(ms)).is_some())
            .collect();
        assert_eq!(fired, alloc::vec![true, false, true]);
    }

    #[test]
    fn pending_heartbeat_none_when_window_empty() {
        assert!(fresh().pending_heartbeat(Duration::from_secs(0)).is_none());
    }

    #[test]
    fn recv_acknack_clears_acked_seqnrs() {
        let mut state = fresh();
        for byte in [0xA0u8, 0xA1, 0xA2] {
            state.submit(alloc::vec![byte]).unwrap();
        }
        assert_eq!(state.in_flight_count(), 3);
        state.recv_acknack(AckNackPayload {
            first_unacked_seq_num: 2,
            nack_bitmap: [0x01, 0x00],
            stream_id: StreamId::BUILTIN_RELIABLE.0,
        });
        assert_eq!(state.in_flight_count(), 1);
        assert!(state.get_in_flight(SerialNumber16::new(2)).is_some());
    }

    #[test]
    fn recv_acknack_full_clear_when_no_bits_set() {
        let mut state = fresh();
        (0..5).for_each(|_| {
            state.submit(alloc::vec![0]).unwrap();
        });
        state.recv_acknack(AckNackPayload {
            first_unacked_seq_num: 5,
            nack_bitmap: [0, 0],
            stream_id: 0x80,
        });
        assert_eq!(state.in_flight_count(), 0);
    }

    #[test]
    fn recv_data_buffers_in_order() {
        let mut state = fresh();
        for (raw, byte) in [(0u16, 10u8), (1, 11)] {
            state
                .recv_data(SerialNumber16::new(raw), alloc::vec![byte])
                .unwrap();
        }
        assert_eq!(drain_raw(&mut state), alloc::vec![0, 1]);
        assert_eq!(state.expected().raw(), 2);
    }

    #[test]
    fn recv_data_reorders_out_of_order() {
        let mut state = fresh();
        state
            .recv_data(SerialNumber16::new(2), alloc::vec![22])
            .unwrap();
        state
            .recv_data(SerialNumber16::new(0), alloc::vec![20])
            .unwrap();
        assert_eq!(drain_raw(&mut state), alloc::vec![0]);

        state
            .recv_data(SerialNumber16::new(1), alloc::vec![21])
            .unwrap();
        assert_eq!(drain_raw(&mut state), alloc::vec![1, 2]);
    }

    #[test]
    fn recv_data_drops_duplicates() {
        let mut state = fresh();
        state.recv_data(SerialNumber16::new(0), alloc::vec![1]).unwrap();
        let _ = state.drain_in_order();
        state
            .recv_data(SerialNumber16::new(0), alloc::vec![99])
            .unwrap();
        assert_eq!(state.out_of_order_count(), 0);
    }

    #[test]
    fn recv_data_rejects_when_buffer_full() {
        let mut state = fresh();
        let cap = RECEIVER_BUFFER_CAP as u16;
        for raw in 1..=cap {
            state
                .recv_data(SerialNumber16::new(raw), alloc::vec![1])
                .unwrap();
        }
        let overflow = state.recv_data(SerialNumber16::new(cap + 1), alloc::vec![1]);
        assert!(overflow.is_err());
    }

    #[test]
    fn pending_acknack_marks_missing_slots() {
        let mut state = fresh();
        for raw in [1u16, 3] {
            state
                .recv_data(SerialNumber16::new(raw), alloc::vec![raw as u8])
                .unwrap();
        }
        let acknack = state.pending_acknack(Some(SerialNumber16::new(3)));
        let bitmap = u16::from_le_bytes(acknack.nack_bitmap);
        let nacked = |bit: u16| bitmap & (1u16 << bit) != 0;
        // Slots 0 and 2 were never received; 1 and 3 were.
        assert!(nacked(0));
        assert!(nacked(2));
        assert!(!nacked(1));
        assert!(!nacked(3));
    }

    #[test]
    fn reset_clears_state_completely() {
        let mut state = fresh();
        state.submit(alloc::vec![1, 2]).unwrap();
        state.recv_data(SerialNumber16::new(0), alloc::vec![3]).unwrap();
        state.reset();
        assert_eq!(
            (
                state.in_flight_count(),
                state.out_of_order_count(),
                state.expected().raw()
            ),
            (0, 0, 0)
        );
    }

    #[test]
    #[should_panic(expected = "reliable stream id")]
    fn constructor_panics_on_best_effort_stream() {
        let best_effort = StreamId(1);
        let _ = ReliableStreamState::new(best_effort, ReliableConfig::default());
    }

    #[test]
    fn end_to_end_sender_receiver_with_loss_recovery() {
        let stream = StreamId(0x80);
        let mut sender = ReliableStreamState::new(stream, ReliableConfig::default());
        let mut receiver = ReliableStreamState::new(stream, ReliableConfig::default());

        let s0 = sender.submit(alloc::vec![10]).expect("submit 0");
        let s1 = sender.submit(alloc::vec![11]).expect("submit 1");
        let s2 = sender.submit(alloc::vec![12]).expect("submit 2");
        assert_eq!(sender.in_flight_count(), 3);

        // s1 is lost in transit: only s0 and s2 reach the receiver.
        receiver.recv_data(s0, alloc::vec![10]).expect("recv s0");
        receiver.recv_data(s2, alloc::vec![12]).expect("recv s2");
        let delivered = receiver.drain_in_order();
        assert_eq!(delivered.len(), 1);
        assert_eq!(delivered[0].1, alloc::vec![10]);

        // The receiver's ACKNACK must leave s1 available for retransmission.
        sender.recv_acknack(receiver.pending_acknack(Some(s2)));
        assert!(
            sender.get_in_flight(s1).is_some(),
            "s1 muss retransmittable sein"
        );
        let retransmitted = sender.get_in_flight(s1).expect("s1 retx").to_vec();
        receiver
            .recv_data(s1, retransmitted)
            .expect("recv retx s1");

        let recovered = receiver.drain_in_order();
        assert_eq!(recovered.len(), 2);
        assert_eq!(recovered[0].1, alloc::vec![11]);
        assert_eq!(recovered[1].1, alloc::vec![12]);
    }

    #[test]
    fn config_submessages_delivered_in_order_via_reliable_stream() {
        let mut sender = ReliableStreamState::new(StreamId(0x80), ReliableConfig::default());
        let mut receiver = ReliableStreamState::new(StreamId(0x80), ReliableConfig::default());

        let seqs: Vec<SerialNumber16> = (0..5u8)
            .map(|i| sender.submit(alloc::vec![i]).expect("submit"))
            .collect();

        // Deliver in a scrambled order; the receiver must restore submission order.
        for idx in [2usize, 0, 4, 1, 3] {
            receiver
                .recv_data(seqs[idx], alloc::vec![idx as u8])
                .expect("recv");
        }

        let drained = receiver.drain_in_order();
        assert_eq!(drained.len(), 5);
        for (expected_byte, (_, payload)) in (0u8..).zip(drained.iter()) {
            assert_eq!(payload, &alloc::vec![expected_byte]);
        }
    }
}