interceptor/nack/generator/
generator_stream.rs

1use util::sync::Mutex;
2
3use super::*;
4use crate::nack::UINT16SIZE_HALF;
5
6struct GeneratorStreamInternal {
7    packets: Vec<u64>,
8    size: u16,
9    end: u16,
10    started: bool,
11    last_consecutive: u16,
12}
13
14impl GeneratorStreamInternal {
15    fn new(log2_size_minus_6: u8) -> Self {
16        GeneratorStreamInternal {
17            packets: vec![0u64; 1 << log2_size_minus_6],
18            size: 1 << (log2_size_minus_6 + 6),
19            end: 0,
20            started: false,
21            last_consecutive: 0,
22        }
23    }
24
25    fn add(&mut self, seq: u16) {
26        if !self.started {
27            self.set_received(seq);
28            self.end = seq;
29            self.started = true;
30            self.last_consecutive = seq;
31            return;
32        }
33
34        let last_consecutive_plus1 = self.last_consecutive.wrapping_add(1);
35        let diff = seq.wrapping_sub(self.end);
36        if diff == 0 {
37            return;
38        } else if diff < UINT16SIZE_HALF {
39            // this means a positive diff, in other words seq > end (with counting for rollovers)
40            let mut i = self.end.wrapping_add(1);
41            while i != seq {
42                // clear packets between end and seq (these may contain packets from a "size" ago)
43                self.del_received(i);
44                i = i.wrapping_add(1);
45            }
46            self.end = seq;
47
48            let seq_sub_last_consecutive = seq.wrapping_sub(self.last_consecutive);
49            if last_consecutive_plus1 == seq {
50                self.last_consecutive = seq;
51            } else if seq_sub_last_consecutive > self.size {
52                let diff = seq.wrapping_sub(self.size);
53                self.last_consecutive = diff;
54                self.fix_last_consecutive(); // there might be valid packets at the beginning of the buffer now
55            }
56        } else if last_consecutive_plus1 == seq {
57            // negative diff, seq < end (with counting for rollovers)
58            self.last_consecutive = seq;
59            self.fix_last_consecutive(); // there might be other valid packets after seq
60        }
61
62        self.set_received(seq);
63    }
64
65    fn get(&self, seq: u16) -> bool {
66        let diff = self.end.wrapping_sub(seq);
67        if diff >= UINT16SIZE_HALF {
68            return false;
69        }
70
71        if diff >= self.size {
72            return false;
73        }
74
75        self.get_received(seq)
76    }
77
78    fn missing_seq_numbers(&self, skip_last_n: u16) -> Vec<u16> {
79        let until = self.end.wrapping_sub(skip_last_n);
80        let diff = until.wrapping_sub(self.last_consecutive);
81        if diff >= UINT16SIZE_HALF {
82            // until < s.last_consecutive (counting for rollover)
83            return vec![];
84        }
85
86        let mut missing_packet_seq_nums = vec![];
87        let mut i = self.last_consecutive.wrapping_add(1);
88        let util_plus1 = until.wrapping_add(1);
89        while i != util_plus1 {
90            if !self.get_received(i) {
91                missing_packet_seq_nums.push(i);
92            }
93            i = i.wrapping_add(1);
94        }
95
96        missing_packet_seq_nums
97    }
98
99    fn set_received(&mut self, seq: u16) {
100        let pos = (seq % self.size) as usize;
101        self.packets[pos / 64] |= 1u64 << (pos % 64);
102    }
103
104    fn del_received(&mut self, seq: u16) {
105        let pos = (seq % self.size) as usize;
106        self.packets[pos / 64] &= u64::MAX ^ (1u64 << (pos % 64));
107    }
108
109    fn get_received(&self, seq: u16) -> bool {
110        let pos = (seq % self.size) as usize;
111        (self.packets[pos / 64] & (1u64 << (pos % 64))) != 0
112    }
113
114    fn fix_last_consecutive(&mut self) {
115        let mut i = self.last_consecutive.wrapping_add(1);
116        while i != self.end.wrapping_add(1) && self.get_received(i) {
117            // find all consecutive packets
118            i = i.wrapping_add(1);
119        }
120        self.last_consecutive = i.wrapping_sub(1);
121    }
122}
123
124pub(super) struct GeneratorStream {
125    parent_rtp_reader: Arc<dyn RTPReader + Send + Sync>,
126
127    internal: Mutex<GeneratorStreamInternal>,
128}
129
130impl GeneratorStream {
131    pub(super) fn new(log2_size_minus_6: u8, reader: Arc<dyn RTPReader + Send + Sync>) -> Self {
132        GeneratorStream {
133            parent_rtp_reader: reader,
134            internal: Mutex::new(GeneratorStreamInternal::new(log2_size_minus_6)),
135        }
136    }
137
138    pub(super) fn missing_seq_numbers(&self, skip_last_n: u16) -> Vec<u16> {
139        let internal = self.internal.lock();
140        internal.missing_seq_numbers(skip_last_n)
141    }
142
143    pub(super) fn add(&self, seq: u16) {
144        let mut internal = self.internal.lock();
145        internal.add(seq);
146    }
147}
148
149/// RTPReader is used by Interceptor.bind_remote_stream.
150#[async_trait]
151impl RTPReader for GeneratorStream {
152    /// read a rtp packet
153    async fn read(
154        &self,
155        buf: &mut [u8],
156        a: &Attributes,
157    ) -> Result<(rtp::packet::Packet, Attributes)> {
158        let (pkt, attr) = self.parent_rtp_reader.read(buf, a).await?;
159
160        self.add(pkt.header.sequence_number);
161
162        Ok((pkt, attr))
163    }
164}
165
#[cfg(test)]
mod test {
    use super::*;

    /// Replays a fixed add/query script against a 128-entry receive log, once
    /// per starting sequence number; some starting points make the script cross
    /// the u16 rollover.
    #[test]
    fn test_generator_stream() {
        let tests: Vec<u16> = vec![
            0, 1, 127, 128, 129, 511, 512, 513, 32767, 32768, 32769, 65407, 65408, 65409, 65534,
            65535,
        ];
        for start in tests {
            let mut rl = GeneratorStreamInternal::new(1);

            // Inclusive, wrapping range min..=max.
            let all = |min: u16, max: u16| -> Vec<u16> {
                let mut result = vec![];
                let mut i = min;
                let max_plus_1 = max.wrapping_add(1);
                while i != max_plus_1 {
                    result.push(i);
                    i = i.wrapping_add(1);
                }
                result
            };

            let join = |parts: &[&[u16]]| -> Vec<u16> {
                let mut result = vec![];
                for p in parts {
                    result.extend_from_slice(p);
                }
                result
            };

            // All helpers below take offsets relative to `start`.
            let add = |rl: &mut GeneratorStreamInternal, nums: &[u16]| {
                for n in nums {
                    let seq = start.wrapping_add(*n);
                    rl.add(seq);
                }
            };

            let assert_get = |rl: &GeneratorStreamInternal, nums: &[u16]| {
                for n in nums {
                    let seq = start.wrapping_add(*n);
                    assert!(rl.get(seq), "not found: {seq}");
                }
            };

            let assert_not_get = |rl: &GeneratorStreamInternal, nums: &[u16]| {
                for n in nums {
                    let seq = start.wrapping_add(*n);
                    assert!(
                        !rl.get(seq),
                        "packet found: start {}, n {}, seq {}",
                        start,
                        *n,
                        seq
                    );
                }
            };

            let assert_missing = |rl: &GeneratorStreamInternal, skip_last_n: u16, nums: &[u16]| {
                let missing = rl.missing_seq_numbers(skip_last_n);
                let mut want = vec![];
                for n in nums {
                    let seq = start.wrapping_add(*n);
                    want.push(seq);
                }
                assert_eq!(want, missing, "missing want/got, ");
            };

            let assert_last_consecutive = |rl: &GeneratorStreamInternal, last_consecutive: u16| {
                let want = last_consecutive.wrapping_add(start);
                assert_eq!(rl.last_consecutive, want, "invalid last_consecutive want");
            };

            add(&mut rl, &[0]);
            assert_get(&rl, &[0]);
            assert_missing(&rl, 0, &[]);
            assert_last_consecutive(&rl, 0); // first element added

            add(&mut rl, &all(1, 127));
            assert_get(&rl, &all(1, 127));
            assert_missing(&rl, 0, &[]);
            assert_last_consecutive(&rl, 127);

            add(&mut rl, &[128]);
            assert_get(&rl, &[128]);
            assert_not_get(&rl, &[0]);
            assert_missing(&rl, 0, &[]);
            assert_last_consecutive(&rl, 128);

            add(&mut rl, &[130]);
            assert_get(&rl, &[130]);
            assert_not_get(&rl, &[1, 2, 129]);
            assert_missing(&rl, 0, &[129]);
            assert_last_consecutive(&rl, 128);

            add(&mut rl, &[333]);
            assert_get(&rl, &[333]);
            assert_not_get(&rl, &all(0, 332));
            assert_missing(&rl, 0, &all(206, 332)); // all 127 elements missing before 333
            assert_missing(&rl, 10, &all(206, 323)); // skip last 10 packets (324-333) from check
            assert_last_consecutive(&rl, 205); // lastConsecutive is still out of the buffer

            add(&mut rl, &[329]);
            assert_get(&rl, &[329]);
            assert_missing(&rl, 0, &join(&[&all(206, 328), &all(330, 332)]));
            assert_missing(&rl, 5, &join(&[&all(206, 328)])); // skip last 5 packets (329-333) from check
            assert_last_consecutive(&rl, 205);

            add(&mut rl, &all(207, 320));
            assert_get(&rl, &all(207, 320));
            assert_missing(&rl, 0, &join(&[&[206], &all(321, 328), &all(330, 332)]));
            assert_last_consecutive(&rl, 205);

            add(&mut rl, &[334]);
            assert_get(&rl, &[334]);
            assert_not_get(&rl, &[206]);
            assert_missing(&rl, 0, &join(&[&all(321, 328), &all(330, 332)]));
            assert_last_consecutive(&rl, 320); // head of buffer is full of consecutive packages

            add(&mut rl, &all(322, 328));
            assert_get(&rl, &all(322, 328));
            assert_missing(&rl, 0, &join(&[&[321], &all(330, 332)]));
            assert_last_consecutive(&rl, 320);

            add(&mut rl, &[321]);
            assert_get(&rl, &[321]);
            assert_missing(&rl, 0, &all(330, 332));
            assert_last_consecutive(&rl, 329); // after adding a single missing packet, lastConsecutive should jump forward
        }
    }

    /// Out-of-order packets around the u16 rollover must not panic and must be
    /// reflected in the missing list and in `last_consecutive`.
    #[test]
    fn test_generator_stream_rollover() {
        let mut rl = GeneratorStreamInternal::new(1);
        rl.add(65533);
        rl.add(65535);
        assert_eq!(rl.missing_seq_numbers(0), vec![65534]);
        rl.add(65534);
        assert_eq!(rl.missing_seq_numbers(0), Vec::<u16>::new());
        assert_eq!(rl.last_consecutive, 65535);

        let mut rl = GeneratorStreamInternal::new(1);
        rl.add(65534);
        rl.add(0);
        assert_eq!(rl.missing_seq_numbers(0), vec![65535]);
        rl.add(65535);
        assert_eq!(rl.missing_seq_numbers(0), Vec::<u16>::new());
        assert_eq!(rl.last_consecutive, 0);
    }
}