use std::collections::{BTreeSet, VecDeque};
/// Bounded set of sequence numbers that are currently recorded as lost.
///
/// Entries live in a [`BTreeSet`], so every view of the list
/// ([`lost_sequences`](Self::lost_sequences),
/// [`compressed_ranges`](Self::compressed_ranges), [`oldest`](Self::oldest))
/// is in ascending numeric order of the raw `u32` values.
///
/// The derived `Default` yields a list whose `max_entries` is `0`, which
/// never accepts an entry; use [`LossList::new`] to get a usable list.
#[derive(Debug, Default)]
pub struct LossList {
    /// Recorded sequence numbers, unique and numerically sorted.
    lost_packets: BTreeSet<u32>,
    /// Insertions are skipped once `lost_packets` holds this many entries.
    max_entries: usize,
}

impl LossList {
    /// Creates an empty list that holds at most `max_entries` sequence numbers.
    #[must_use]
    pub fn new(max_entries: usize) -> Self {
        Self {
            lost_packets: BTreeSet::new(),
            max_entries,
        }
    }

    /// Reports whether the list is still below its capacity.
    fn has_room(&self) -> bool {
        self.lost_packets.len() < self.max_entries
    }

    /// Records `seq` as lost.
    ///
    /// The call is silently ignored when the list is already at capacity.
    pub fn add(&mut self, seq: u32) {
        if !self.has_room() {
            return;
        }
        self.lost_packets.insert(seq);
    }

    /// Records every sequence number in the half-open range `[start, end)`.
    ///
    /// The walk uses wrapping increments, so a range with `end < start`
    /// continues through `u32::MAX` back to `0`. It stops early as soon as
    /// the list reaches capacity; `start == end` records nothing.
    pub fn add_range(&mut self, start: u32, end: u32) {
        let mut next = start;
        loop {
            if next == end || !self.has_room() {
                break;
            }
            self.lost_packets.insert(next);
            next = next.wrapping_add(1);
        }
    }

    /// Removes `seq` from the list, returning `true` if it was present.
    pub fn remove(&mut self, seq: u32) -> bool {
        self.lost_packets.remove(&seq)
    }

    /// Returns `true` if `seq` is currently recorded as lost.
    #[must_use]
    pub fn contains(&self, seq: u32) -> bool {
        self.lost_packets.contains(&seq)
    }

    /// Number of sequence numbers currently recorded.
    #[must_use]
    pub fn len(&self) -> usize {
        self.lost_packets.len()
    }

    /// Returns `true` when nothing is recorded as lost.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.lost_packets.is_empty()
    }

    /// Returns all recorded sequence numbers in ascending numeric order.
    #[must_use]
    pub fn lost_sequences(&self) -> Vec<u32> {
        let mut sequences = Vec::with_capacity(self.lost_packets.len());
        sequences.extend(&self.lost_packets);
        sequences
    }

    /// Collapses the recorded sequence numbers into inclusive runs of
    /// consecutive values.
    ///
    /// Runs are emitted in ascending numeric order. Because the underlying
    /// set is iterated numerically, a run never continues across the
    /// `u32::MAX` -> `0` boundary: such losses are reported as two ranges.
    #[must_use]
    pub fn compressed_ranges(&self) -> Vec<LossRange> {
        let mut ranges = Vec::new();
        // The run currently being grown, if any.
        let mut open: Option<LossRange> = None;
        for &seq in &self.lost_packets {
            open = match open {
                // `seq` directly follows the open run: extend it.
                Some(run) if seq == run.end.wrapping_add(1) => Some(LossRange { end: seq, ..run }),
                // Otherwise flush whatever was open and start a new run at `seq`.
                finished => {
                    ranges.extend(finished);
                    Some(LossRange {
                        start: seq,
                        end: seq,
                    })
                }
            };
        }
        ranges.extend(open);
        ranges
    }

    /// Forgets every recorded sequence number; the capacity is unchanged.
    pub fn clear(&mut self) {
        self.lost_packets.clear();
    }

    /// Returns the numerically smallest recorded sequence number, if any.
    #[must_use]
    pub fn oldest(&self) -> Option<u32> {
        self.lost_packets.iter().copied().next()
    }
}
/// Inclusive run of consecutive sequence numbers, `start..=end`.
///
/// `end` may be numerically smaller than `start`, in which case the range is
/// interpreted as wrapping through `u32::MAX` back to `0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LossRange {
    /// First sequence number in the range (inclusive).
    pub start: u32,
    /// Last sequence number in the range (inclusive).
    pub end: u32,
}

impl LossRange {
    /// Builds a range that covers exactly one sequence number.
    #[must_use]
    pub const fn single(seq: u32) -> Self {
        Self {
            start: seq,
            end: seq,
        }
    }

    /// Number of sequence numbers covered by the range, wrap-around included.
    ///
    /// A range spanning the entire sequence space holds 2^32 values, which
    /// does not fit in a `u32`; in that case the result saturates at
    /// `u32::MAX` instead of overflowing.
    #[must_use]
    pub const fn count(&self) -> u32 {
        // The wrapping distance covers both the plain (`end >= start`) and the
        // wrapped (`end < start`) layout; the +1 makes the range inclusive.
        self.end.wrapping_sub(self.start).saturating_add(1)
    }

    /// Returns `true` when the range covers exactly one sequence number.
    #[must_use]
    pub const fn is_single(&self) -> bool {
        self.start == self.end
    }
}
/// Reordering buffer that tracks the next in-order sequence number and parks
/// packets that arrive ahead of it.
#[derive(Debug)]
pub struct ReceiveBuffer {
    /// Next sequence number that counts as in-order.
    expected_seq: u32,
    /// Packets that arrived early. Invariant: no duplicate sequence numbers,
    /// sorted by wrap-aware distance ahead of `expected_seq`, and no entry is
    /// behind `expected_seq`.
    buffer: VecDeque<BufferedPacket>,
    /// Maximum number of packets kept in `buffer`.
    max_size: usize,
}

/// A packet held back because it arrived ahead of the expected sequence number.
#[derive(Debug, Clone)]
pub struct BufferedPacket {
    /// Sequence number of the packet.
    pub seq: u32,
    /// Packet payload.
    pub data: bytes::Bytes,
}

impl ReceiveBuffer {
    /// Upper bound on the number of entries returned by [`Self::detect_gaps`].
    const MAX_REPORTED_GAPS: usize = 1000;

    /// Creates an empty buffer that expects `initial_seq` next and parks at
    /// most `max_size` out-of-order packets.
    #[must_use]
    pub const fn new(initial_seq: u32, max_size: usize) -> Self {
        Self {
            expected_seq: initial_seq,
            buffer: VecDeque::new(),
            max_size,
        }
    }

    /// Next sequence number that will be accepted as in-order.
    #[must_use]
    pub const fn expected_seq(&self) -> u32 {
        self.expected_seq
    }

    /// Handles an incoming packet.
    ///
    /// * `seq == expected_seq`: the expected sequence number advances and
    ///   `Some(seq)` is returned. `data` is not stored.
    /// * `seq` ahead of `expected_seq` (per `seq_after`): the packet is parked,
    ///   unless it is already parked or the buffer is full. Returns `None`.
    /// * anything else (old or duplicate packet): ignored. Returns `None`.
    pub fn process(&mut self, seq: u32, data: bytes::Bytes) -> Option<u32> {
        if seq == self.expected_seq {
            // A parked copy of this very packet would fall behind
            // `expected_seq` and block `try_deliver` forever; drop it now.
            if matches!(self.buffer.front(), Some(p) if p.seq == seq) {
                self.buffer.pop_front();
            }
            self.expected_seq = self.expected_seq.wrapping_add(1);
            return Some(seq);
        }
        if !seq_after(seq, self.expected_seq) {
            return None;
        }

        // Order by distance ahead of `expected_seq` rather than by raw value,
        // so packets on both sides of the u32 wrap keep their delivery order.
        let base = self.expected_seq;
        let offset = seq.wrapping_sub(base);
        let slot = self
            .buffer
            .binary_search_by_key(&offset, |p| p.seq.wrapping_sub(base));
        // `Ok(_)` means this sequence number is already parked: ignore it.
        if let Err(index) = slot {
            if self.buffer.len() < self.max_size {
                self.buffer.insert(index, BufferedPacket { seq, data });
            }
        }
        None
    }

    /// Releases parked packets that have become in-order.
    ///
    /// Pops packets from the front while they match `expected_seq`, advancing
    /// it each time, and returns the released sequence numbers in order. The
    /// popped packets themselves (including their payload) are dropped.
    pub fn try_deliver(&mut self) -> Vec<u32> {
        let mut delivered = Vec::new();
        while let Some(front) = self.buffer.front() {
            if front.seq != self.expected_seq {
                break;
            }
            delivered.push(front.seq);
            self.expected_seq = self.expected_seq.wrapping_add(1);
            self.buffer.pop_front();
        }
        delivered
    }

    /// Lists the sequence numbers missing between `expected_seq` and the last
    /// parked packet, in delivery order.
    ///
    /// Returns an empty list when nothing is parked. At most 1000 sequence
    /// numbers are reported.
    #[must_use]
    pub fn detect_gaps(&self) -> Vec<u32> {
        let mut gaps = Vec::new();
        // Next sequence number we would need to see for there to be no gap.
        let mut needed = self.expected_seq;
        for packet in &self.buffer {
            while needed != packet.seq {
                if gaps.len() >= Self::MAX_REPORTED_GAPS {
                    return gaps;
                }
                gaps.push(needed);
                needed = needed.wrapping_add(1);
            }
            needed = packet.seq.wrapping_add(1);
        }
        gaps
    }

    /// Number of packets currently parked.
    #[must_use]
    pub fn buffered_count(&self) -> usize {
        self.buffer.len()
    }

    /// Drops every parked packet; `expected_seq` is left unchanged.
    pub fn clear(&mut self) {
        self.buffer.clear();
    }
}
/// Wrap-aware "strictly after" test for 32-bit sequence numbers.
///
/// Returns `true` when `a` lies between 1 and `0x7FFF_FFFF` steps ahead of
/// `b`, counting with wrapping arithmetic. Equal values, and values exactly
/// `0x8000_0000` or more ahead, yield `false`.
const fn seq_after(a: u32, b: u32) -> bool {
    matches!(a.wrapping_sub(b), 1..=0x7FFF_FFFF)
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Builds a one-byte payload to feed into the receive buffer.
    fn payload(byte: u8) -> Bytes {
        Bytes::from(vec![byte])
    }

    /// Adding, querying and removing individual sequence numbers.
    #[test]
    fn test_loss_list_add_remove() {
        let mut losses = LossList::new(100);
        for seq in [10, 20, 30] {
            losses.add(seq);
        }
        assert_eq!(losses.len(), 3);
        assert!(losses.contains(10));
        assert!(losses.contains(20));
        assert!(!losses.contains(15));

        let removed = losses.remove(20);
        assert!(removed);
        assert_eq!(losses.len(), 2);
        assert!(!losses.contains(20));
    }

    /// `add_range` covers `start` up to, but not including, `end`.
    #[test]
    fn test_loss_list_range() {
        let mut losses = LossList::new(100);
        losses.add_range(10, 15);

        assert_eq!(losses.len(), 5);
        assert!(losses.contains(10));
        assert!(losses.contains(14));
        assert!(!losses.contains(15));
    }

    /// Consecutive sequence numbers are merged into inclusive ranges.
    #[test]
    fn test_loss_list_compressed_ranges() {
        let mut losses = LossList::new(100);
        for seq in [10, 11, 12, 20, 21] {
            losses.add(seq);
        }

        let ranges = losses.compressed_ranges();
        assert_eq!(ranges.len(), 2);
        assert_eq!(ranges[0], LossRange { start: 10, end: 12 });
        assert_eq!(ranges[1], LossRange { start: 20, end: 21 });
    }

    /// `count` is inclusive of both ends; a single-element range counts one.
    #[test]
    fn test_loss_range_count() {
        let span = LossRange { start: 10, end: 14 };
        assert_eq!(span.count(), 5);

        let lone = LossRange::single(42);
        assert_eq!(lone.count(), 1);
        assert!(lone.is_single());
    }

    /// In-order packets are acknowledged immediately and advance the cursor.
    #[test]
    fn test_receive_buffer_in_order() {
        let mut rx = ReceiveBuffer::new(100, 50);

        assert_eq!(rx.process(100, payload(1)), Some(100));
        assert_eq!(rx.expected_seq(), 101);

        assert_eq!(rx.process(101, payload(2)), Some(101));
        assert_eq!(rx.expected_seq(), 102);
    }

    /// An early packet is parked and released once the gap before it closes.
    #[test]
    fn test_receive_buffer_out_of_order() {
        let mut rx = ReceiveBuffer::new(100, 50);

        assert_eq!(rx.process(102, payload(3)), None);
        assert_eq!(rx.buffered_count(), 1);

        assert_eq!(rx.process(100, payload(1)), Some(100));
        assert_eq!(rx.process(101, payload(2)), Some(101));

        let released = rx.try_deliver();
        assert_eq!(released, vec![102]);
        assert_eq!(rx.expected_seq(), 103);
    }

    /// Everything between the expected number and the first parked packet is a gap.
    #[test]
    fn test_receive_buffer_detect_gaps() {
        let mut rx = ReceiveBuffer::new(100, 50);
        rx.process(105, payload(1));
        rx.process(106, payload(2));

        let missing = rx.detect_gaps();
        assert_eq!(missing, vec![100, 101, 102, 103, 104]);
    }

    /// Ordering checks, including comparisons across the u32 wrap.
    #[test]
    fn test_seq_after() {
        assert!(seq_after(10, 5));
        assert!(!seq_after(5, 10));
        assert!(!seq_after(5, 5));
        assert!(seq_after(0, 0xFFFF_FFFF));
        assert!(seq_after(10, 0xFFFF_FFF0));
    }

    /// `oldest` reports the smallest recorded number regardless of insertion order.
    #[test]
    fn test_loss_list_oldest() {
        let mut losses = LossList::new(100);
        for seq in [30, 10, 20] {
            losses.add(seq);
        }
        assert_eq!(losses.oldest(), Some(10));
    }
}