rechannel/
sequence_buffer.rs

1use crate::packet::AckData;
2
/// Fixed-size ring of slots addressed by wrapping `u16` sequence numbers.
///
/// A sequence number maps to the slot `sequence % capacity`; the slot remembers
/// which sequence number it currently holds so that stale or aliased lookups
/// can be told apart from real hits.
#[derive(Debug)]
pub(crate) struct SequenceBuffer<T> {
    /// One past the most recent sequence number accepted by `insert` (wrapping).
    sequence: u16,
    /// Sequence number stored in each slot, or `None` when the slot is free.
    entry_sequences: Box<[Option<u16>]>,
    /// Payload stored in each slot, parallel to `entry_sequences`.
    entries: Box<[Option<T>]>,
}
9
10impl<T: Clone> SequenceBuffer<T> {
11    pub fn with_capacity(size: usize) -> Self {
12        assert!(size > 0, "tried to initialize SequenceBuffer with 0 size");
13
14        Self {
15            sequence: 0,
16            entry_sequences: vec![None; size].into_boxed_slice(),
17            entries: vec![None; size].into_boxed_slice(),
18        }
19    }
20
21    pub fn size(&self) -> usize {
22        self.entries.len()
23    }
24
25    pub fn get_mut(&mut self, sequence: u16) -> Option<&mut T> {
26        if self.exists(sequence) {
27            let index = self.index(sequence);
28            return self.entries[index].as_mut();
29        }
30        None
31    }
32
33    #[allow(dead_code)]
34    pub fn get(&self, sequence: u16) -> Option<&T> {
35        if self.exists(sequence) {
36            let index = self.index(sequence);
37            return self.entries[index].as_ref();
38        }
39        None
40    }
41
42    pub fn get_or_insert_with<F: FnOnce() -> T>(&mut self, sequence: u16, f: F) -> Option<&mut T> {
43        if self.exists(sequence) {
44            let index = self.index(sequence);
45            self.entries[index].as_mut()
46        } else {
47            self.insert(sequence, f())
48        }
49    }
50
51    #[inline]
52    pub fn index(&self, sequence: u16) -> usize {
53        sequence as usize % self.entries.len()
54    }
55
56    pub fn available(&self, sequence: u16) -> bool {
57        let index = self.index(sequence);
58        self.entry_sequences[index].is_none()
59    }
60
61    /// Returns whether or not we have previously inserted an entry for the given sequence number.
62    pub fn exists(&self, sequence: u16) -> bool {
63        let index = self.index(sequence);
64        if let Some(s) = self.entry_sequences[index] {
65            return s == sequence;
66        }
67        false
68    }
69
70    pub fn insert(&mut self, sequence: u16, data: T) -> Option<&mut T> {
71        if sequence_less_than(sequence, self.sequence.wrapping_sub(self.entry_sequences.len() as u16)) {
72            return None;
73        }
74
75        if sequence_greater_than(sequence.wrapping_add(1), self.sequence) {
76            self.remove_entries(u32::from(sequence));
77            self.sequence = sequence.wrapping_add(1);
78        }
79
80        let index = self.index(sequence);
81        self.entry_sequences[index] = Some(sequence);
82        self.entries[index] = Some(data);
83        self.entries[index].as_mut()
84    }
85
86    fn remove_entries(&mut self, mut finish_sequence: u32) {
87        let start_sequence = u32::from(self.sequence);
88        if finish_sequence < start_sequence {
89            finish_sequence += 65536;
90        }
91
92        if finish_sequence - start_sequence < self.entry_sequences.len() as u32 {
93            for sequence in start_sequence..=finish_sequence {
94                self.remove(sequence as u16);
95            }
96        } else {
97            for index in 0..self.entry_sequences.len() {
98                self.entries[index] = None;
99                self.entry_sequences[index] = None;
100            }
101        }
102    }
103
104    pub fn remove(&mut self, sequence: u16) -> Option<T> {
105        if self.exists(sequence) {
106            let index = self.index(sequence);
107            self.entry_sequences[index] = None;
108            let value = self.entries[index].take();
109            return value;
110        }
111        None
112    }
113
114    #[inline]
115    pub fn sequence(&self) -> u16 {
116        self.sequence
117    }
118
119    pub fn ack_data(&self) -> AckData {
120        let ack = self.sequence().wrapping_sub(1);
121        let mut ack_bits = 0;
122        let mut mask = 1;
123
124        for i in 0..32 {
125            let sequence = ack.wrapping_sub(i);
126            if self.exists(sequence) {
127                ack_bits |= mask;
128            }
129            mask <<= 1;
130        }
131
132        AckData { ack, ack_bits }
133    }
134}
135
/// Returns `true` when `s1` is considered newer than `s2` under `u16` wrap-around.
///
/// Sequence numbers wrap, so a plain `>` is not enough: the cutover sits in the middle
/// of the `u16` range. `s1` counts as greater when it is numerically larger by at most
/// 32768, or numerically smaller by more than 32768 (meaning it has wrapped past `s2`).
/// Equal values are never greater.
#[inline]
pub fn sequence_greater_than(s1: u16, s2: u16) -> bool {
    // Half of the `u16` range: the point where "ahead" flips to "behind".
    const HALF_RANGE: u16 = 32768;

    if s1 > s2 {
        s1 - s2 <= HALF_RANGE
    } else if s1 < s2 {
        s2 - s1 > HALF_RANGE
    } else {
        false
    }
}
142
143#[inline]
144pub fn sequence_less_than(s1: u16, s2: u16) -> bool {
145    sequence_greater_than(s2, s1)
146}
147
#[cfg(test)]
mod tests {
    use super::SequenceBuffer;

    /// Empty payload; these tests only care about which sequences are stored.
    #[derive(Clone, Default)]
    struct DataStub;

    /// Counts the slots that currently hold a sequence number.
    fn count_entries(buffer: &SequenceBuffer<DataStub>) -> usize {
        buffer.entry_sequences.iter().filter(|slot| slot.is_some()).count()
    }

    /// A fresh buffer must not report the highest sequence number as present.
    #[test]
    fn max_sequence_not_exists_by_default() {
        let empty: SequenceBuffer<DataStub> = SequenceBuffer::with_capacity(8);
        assert!(!empty.exists(u16::MAX));
    }

    /// An inserted sequence is reported as existing.
    #[test]
    fn insert() {
        let mut buffer = SequenceBuffer::with_capacity(2);
        buffer.insert(0, DataStub).unwrap();
        assert!(buffer.exists(0));
    }

    /// Removing an entry returns it and makes the sequence non-existent again.
    #[test]
    fn remove() {
        let mut buffer = SequenceBuffer::with_capacity(2);
        buffer.insert(0, DataStub).unwrap();

        let taken = buffer.remove(0);

        assert!(taken.is_some());
        assert!(!buffer.exists(0));
    }

    /// A newer sequence that maps to an occupied slot replaces the older entry.
    #[test]
    fn insert_over_older_entries() {
        let mut buffer = SequenceBuffer::with_capacity(8);
        buffer.insert(8, DataStub).unwrap();

        // Sequence 0 is already outside the window, so it must be rejected.
        buffer.insert(0, DataStub);
        assert!(!buffer.exists(0));

        buffer.insert(16, DataStub);
        assert!(buffer.exists(16));

        assert_eq!(count_entries(&buffer), 1);
    }

    /// Sequences older than the window are rejected, including ones before the wrap.
    #[test]
    fn insert_old_entries() {
        let mut buffer = SequenceBuffer::with_capacity(8);
        buffer.insert(11, DataStub);

        buffer.insert(2, DataStub);
        assert!(!buffer.exists(2));

        buffer.insert(u16::MAX, DataStub);
        assert!(!buffer.exists(u16::MAX));

        assert_eq!(count_entries(&buffer), 1);
    }

    /// The ack bitfield has one bit per stored sequence, counted back from the latest.
    #[test]
    fn ack_bits() {
        let mut buffer = SequenceBuffer::with_capacity(64);
        for &sequence in [0u16, 1, 3, 4, 5, 7, 30, 31].iter() {
            buffer.insert(sequence, DataStub).unwrap();
        }

        let ack_data = buffer.ack_data();

        assert_eq!(ack_data.ack, 31);
        assert_eq!(ack_data.ack_bits, 0b11011101000000000000000000000011u32);
    }

    /// A slot occupied by another sequence is not available.
    #[test]
    fn available() {
        let mut buffer = SequenceBuffer::with_capacity(2);
        buffer.insert(0, DataStub).unwrap();
        buffer.insert(1, DataStub).unwrap();
        assert!(!buffer.available(2));
    }
}