// bitfold_protocol/sequence_buffer.rs

1use std::clone::Clone;
2
3use super::packet::SequenceNumber;
4
5/// Circular buffer for tracking sequenced items with wrapping sequence numbers.
6/// Used by AcknowledgmentHandler to track received packets for ACK bitfield generation.
7#[derive(Debug)]
8pub struct SequenceBuffer<T: Clone + Default> {
9    sequence_num: SequenceNumber,
10    entry_sequences: Box<[Option<SequenceNumber>]>,
11    entries: Box<[T]>,
12}
13
14impl<T: Clone + Default> SequenceBuffer<T> {
15    /// Creates a new sequence buffer with the specified capacity.
16    pub fn with_capacity(size: u16) -> Self {
17        Self {
18            sequence_num: 0,
19            entry_sequences: vec![None; size as usize].into_boxed_slice(),
20            entries: vec![T::default(); size as usize].into_boxed_slice(),
21        }
22    }
23
24    /// Returns the current sequence number.
25    pub fn sequence_num(&self) -> SequenceNumber {
26        self.sequence_num
27    }
28
29    /// Gets a mutable reference to an entry by sequence number.
30    pub fn get_mut(&mut self, sequence_num: SequenceNumber) -> Option<&mut T> {
31        if self.exists(sequence_num) {
32            let index = self.index(sequence_num);
33            return Some(&mut self.entries[index]);
34        }
35        None
36    }
37
38    /// Inserts an entry at the specified sequence number.
39    pub fn insert(&mut self, sequence_num: SequenceNumber, entry: T) -> Option<&mut T> {
40        if sequence_less_than(
41            sequence_num,
42            self.sequence_num.wrapping_sub(self.entry_sequences.len() as u16),
43        ) {
44            return None;
45        }
46        self.advance_sequence(sequence_num);
47        let index = self.index(sequence_num);
48        self.entry_sequences[index] = Some(sequence_num);
49        self.entries[index] = entry;
50        Some(&mut self.entries[index])
51    }
52
53    /// Checks if an entry exists at the given sequence number.
54    pub fn exists(&self, sequence_num: SequenceNumber) -> bool {
55        let index = self.index(sequence_num);
56        if let Some(s) = self.entry_sequences[index] {
57            return s == sequence_num;
58        }
59        false
60    }
61
62    /// Removes and returns the entry at the specified sequence number.
63    pub fn remove(&mut self, sequence_num: SequenceNumber) -> Option<T> {
64        if self.exists(sequence_num) {
65            let index = self.index(sequence_num);
66            let value = std::mem::take(&mut self.entries[index]);
67            self.entry_sequences[index] = None;
68            return Some(value);
69        }
70        None
71    }
72
73    fn advance_sequence(&mut self, sequence_num: SequenceNumber) {
74        if sequence_greater_than(sequence_num.wrapping_add(1), self.sequence_num) {
75            self.remove_entries(u32::from(sequence_num));
76            self.sequence_num = sequence_num.wrapping_add(1);
77        }
78    }
79
80    fn remove_entries(&mut self, mut finish_sequence: u32) {
81        let start_sequence = u32::from(self.sequence_num);
82        if finish_sequence < start_sequence {
83            finish_sequence += 65536;
84        }
85        if finish_sequence - start_sequence < self.entry_sequences.len() as u32 {
86            for sequence in start_sequence..=finish_sequence {
87                self.remove(sequence as u16);
88            }
89        } else {
90            for index in 0..self.entry_sequences.len() {
91                self.entries[index] = T::default();
92                self.entry_sequences[index] = None;
93            }
94        }
95    }
96
97    fn index(&self, sequence: SequenceNumber) -> usize {
98        sequence as usize % self.entry_sequences.len()
99    }
100}
101
/// Returns `true` when `s1` is newer than `s2` on the wrapping 16-bit sequence circle.
///
/// A forward distance of up to half the sequence space counts as "newer"; anything
/// farther is treated as having wrapped around. When the two values are exactly
/// 32768 apart, the numerically larger one is considered newer. Equal values are
/// never newer than each other.
pub fn sequence_greater_than(s1: u16, s2: u16) -> bool {
    // Half of the 16-bit sequence space.
    const HALF_RANGE: u16 = 32768;
    if s1 > s2 {
        s1 - s2 <= HALF_RANGE
    } else if s1 < s2 {
        s2 - s1 > HALF_RANGE
    } else {
        false
    }
}
106
107/// Compares sequence numbers with wrapping arithmetic.
108pub fn sequence_less_than(s1: u16, s2: u16) -> bool {
109    sequence_greater_than(s2, s1)
110}