rtc_interceptor/nack/
generator.rs

1//! NACK Generator Interceptor - Generates NACK requests for missing packets.
2
3use super::receive_log::ReceiveLog;
4use super::stream_supports_nack;
5use crate::stream_info::StreamInfo;
6use crate::{Interceptor, Packet, TaggedPacket};
7use shared::TransportContext;
8use shared::error::Error;
9use std::collections::{HashMap, VecDeque};
10use std::marker::PhantomData;
11use std::time::{Duration, Instant};
12
/// Builder for the `NackGeneratorInterceptor`.
///
/// Collects the receive-log size, the NACK generation interval, how many of the
/// newest packets to leave out of each NACK pass, and the per-packet NACK cap.
/// `build` turns this configuration into a factory closure that wraps an inner
/// interceptor.
///
/// # Example
///
/// ```ignore
/// use rtc_interceptor::{Registry, NackGeneratorBuilder};
/// use std::time::Duration;
///
/// let chain = Registry::new()
///     .with(NackGeneratorBuilder::new()
///         .with_size(512)
///         .with_interval(Duration::from_millis(100))
///         .with_skip_last_n(2)
///         .build())
///     .build();
/// ```
pub struct NackGeneratorBuilder<P> {
    /// Receive-log capacity in packets; expected to be a power of two in 64..=32768.
    size: u16,
    /// Time between two consecutive NACK generation passes.
    interval: Duration,
    /// How many of the most recent packets are left out when generating NACKs.
    skip_last_n: u16,
    /// Upper bound on NACKs sent for a single missing packet; 0 means no bound.
    max_nacks_per_packet: u16,
    /// Ties the builder to the inner interceptor type `P` without storing a `P`.
    _phantom: PhantomData<P>,
}
40
41impl<P> Default for NackGeneratorBuilder<P> {
42    fn default() -> Self {
43        Self {
44            size: 512,
45            interval: Duration::from_millis(100),
46            skip_last_n: 0,
47            max_nacks_per_packet: 0,
48            _phantom: PhantomData,
49        }
50    }
51}
52
53impl<P> NackGeneratorBuilder<P> {
54    /// Create a new builder with default settings.
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /// Set the size of the receive log.
60    ///
61    /// Size must be a power of 2 between 64 and 32768 (inclusive).
62    pub fn with_size(mut self, size: u16) -> Self {
63        self.size = size;
64        self
65    }
66
67    /// Set the interval between NACK generation cycles.
68    pub fn with_interval(mut self, interval: Duration) -> Self {
69        self.interval = interval;
70        self
71    }
72
73    /// Set the number of most recent packets to skip when generating NACKs.
74    ///
75    /// This helps avoid generating NACKs for packets that are simply delayed
76    /// and haven't arrived yet.
77    pub fn with_skip_last_n(mut self, skip_last_n: u16) -> Self {
78        self.skip_last_n = skip_last_n;
79        self
80    }
81
82    /// Set the maximum number of NACKs to send per missing packet.
83    ///
84    /// Set to 0 (default) for unlimited NACKs.
85    pub fn with_max_nacks_per_packet(mut self, max: u16) -> Self {
86        self.max_nacks_per_packet = max;
87        self
88    }
89
90    /// Build the interceptor factory function.
91    pub fn build(self) -> impl FnOnce(P) -> NackGeneratorInterceptor<P> {
92        move |inner| {
93            NackGeneratorInterceptor::new(
94                inner,
95                self.size,
96                self.interval,
97                self.skip_last_n,
98                self.max_nacks_per_packet,
99            )
100        }
101    }
102}
103
104/// Interceptor that generates NACK requests for missing RTP packets.
105///
106/// This interceptor monitors incoming RTP packets on remote streams,
107/// tracks which sequence numbers have been received, and periodically
108/// generates RTCP TransportLayerNack packets for missing sequences.
109pub struct NackGeneratorInterceptor<P> {
110    inner: P,
111
112    /// Configuration
113    size: u16,
114    interval: Duration,
115    skip_last_n: u16,
116    max_nacks_per_packet: u16,
117
118    /// Next timeout for NACK generation
119    eto: Instant,
120
121    /// Sender SSRC for NACK packets
122    sender_ssrc: u32,
123
124    /// Receive logs per remote stream SSRC
125    receive_logs: HashMap<u32, ReceiveLog>,
126
127    /// NACK count per (SSRC, sequence number) for max_nacks_per_packet limiting
128    nack_counts: HashMap<u32, HashMap<u16, u16>>,
129
130    /// Queue for outgoing NACK packets
131    write_queue: VecDeque<TaggedPacket>,
132}
133
134impl<P> NackGeneratorInterceptor<P> {
135    fn new(
136        inner: P,
137        size: u16,
138        interval: Duration,
139        skip_last_n: u16,
140        max_nacks_per_packet: u16,
141    ) -> Self {
142        Self {
143            inner,
144            size,
145            interval,
146            skip_last_n,
147            max_nacks_per_packet,
148            eto: Instant::now(),
149            sender_ssrc: rand::random(),
150            receive_logs: HashMap::new(),
151            nack_counts: HashMap::new(),
152            write_queue: VecDeque::new(),
153        }
154    }
155
156    /// Generate NACKs for all streams with missing packets.
157    fn generate_nacks(&mut self, now: Instant) {
158        for (&ssrc, receive_log) in &self.receive_logs {
159            let missing = receive_log.missing_seq_numbers(self.skip_last_n);
160            if missing.is_empty() {
161                // Clear nack counts for this SSRC if no missing packets
162                self.nack_counts.remove(&ssrc);
163                continue;
164            }
165
166            // Initialize nack counts for this SSRC if needed
167            let nack_count = self.nack_counts.entry(ssrc).or_default();
168
169            // Filter by max_nacks_per_packet if configured
170            let filtered: Vec<u16> = if self.max_nacks_per_packet > 0 {
171                missing
172                    .iter()
173                    .filter(|&&seq| {
174                        let count = nack_count.entry(seq).or_insert(0);
175                        if *count < self.max_nacks_per_packet {
176                            *count += 1;
177                            true
178                        } else {
179                            false
180                        }
181                    })
182                    .copied()
183                    .collect()
184            } else {
185                missing.clone()
186            };
187
188            if filtered.is_empty() {
189                continue;
190            }
191
192            // Clean up nack counts for packets no longer missing
193            nack_count.retain(|seq, _| missing.contains(seq));
194
195            // Create NACK packet
196            let nack = rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack {
197                sender_ssrc: self.sender_ssrc,
198                media_ssrc: ssrc,
199                nacks: rtcp::transport_feedbacks::transport_layer_nack::nack_pairs_from_sequence_numbers(
200                    &filtered,
201                ),
202            };
203
204            self.write_queue.push_back(TaggedPacket {
205                now,
206                transport: TransportContext::default(),
207                message: Packet::Rtcp(vec![Box::new(nack)]),
208            });
209        }
210    }
211}
212
213impl<P: Interceptor> sansio::Protocol<TaggedPacket, TaggedPacket, ()>
214    for NackGeneratorInterceptor<P>
215{
216    type Rout = TaggedPacket;
217    type Wout = TaggedPacket;
218    type Eout = ();
219    type Error = Error;
220    type Time = Instant;
221
222    fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
223        // Track incoming RTP packets
224        if let Packet::Rtp(ref rtp_packet) = msg.message
225            && let Some(receive_log) = self.receive_logs.get_mut(&rtp_packet.header.ssrc)
226        {
227            receive_log.add(rtp_packet.header.sequence_number);
228        }
229
230        self.inner.handle_read(msg)
231    }
232
233    fn poll_read(&mut self) -> Option<Self::Rout> {
234        self.inner.poll_read()
235    }
236
237    fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
238        self.inner.handle_write(msg)
239    }
240
241    fn poll_write(&mut self) -> Option<Self::Wout> {
242        // First drain generated NACK packets
243        if let Some(pkt) = self.write_queue.pop_front() {
244            return Some(pkt);
245        }
246        self.inner.poll_write()
247    }
248
249    fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
250        if self.eto <= now {
251            self.eto = now + self.interval;
252            self.generate_nacks(now);
253        }
254
255        self.inner.handle_timeout(now)
256    }
257
258    fn poll_timeout(&mut self) -> Option<Self::Time> {
259        if let Some(inner_eto) = self.inner.poll_timeout()
260            && inner_eto < self.eto
261        {
262            return Some(inner_eto);
263        }
264        Some(self.eto)
265    }
266}
267
268impl<P: Interceptor> Interceptor for NackGeneratorInterceptor<P> {
269    fn bind_local_stream(&mut self, info: &StreamInfo) {
270        self.inner.bind_local_stream(info);
271    }
272
273    fn unbind_local_stream(&mut self, info: &StreamInfo) {
274        self.inner.unbind_local_stream(info);
275    }
276
277    fn bind_remote_stream(&mut self, info: &StreamInfo) {
278        if stream_supports_nack(info)
279            && let Some(receive_log) = ReceiveLog::new(self.size)
280        {
281            self.receive_logs.insert(info.ssrc, receive_log);
282        }
283        self.inner.bind_remote_stream(info);
284    }
285
286    fn unbind_remote_stream(&mut self, info: &StreamInfo) {
287        self.receive_logs.remove(&info.ssrc);
288        self.nack_counts.remove(&info.ssrc);
289        self.inner.unbind_remote_stream(info);
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTCPFeedback;
    use rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack;
    use sansio::Protocol;

    /// SSRC of the remote stream used throughout these tests.
    const SSRC: u32 = 12345;

    /// Builds an RTP packet on `ssrc` carrying sequence number `seq`.
    fn make_rtp_packet(ssrc: u32, seq: u16) -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number: seq,
                    ..Default::default()
                },
                ..Default::default()
            }),
        }
    }

    /// Remote stream description that negotiated generic NACK feedback.
    fn nack_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            clock_rate: 90000,
            rtcp_feedback: vec![RTCPFeedback {
                typ: "nack".to_string(),
                parameter: "".to_string(),
            }],
            ..Default::default()
        }
    }

    /// Extracts the media SSRC and the requested sequence numbers from a queued NACK.
    ///
    /// Panics unless `tagged` carries exactly one RTCP `TransportLayerNack`.
    fn decode_nack(tagged: TaggedPacket) -> (u32, Vec<u16>) {
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };
        assert_eq!(rtcp_packets.len(), 1);
        let nack = rtcp_packets[0]
            .as_any()
            .downcast_ref::<TransportLayerNack>()
            .expect("Expected TransportLayerNack");

        let mut seqs = Vec::new();
        for pair in &nack.nacks {
            seqs.push(pair.packet_id);
            for bit in 0..16u16 {
                if pair.lost_packets & (1 << bit) != 0 {
                    seqs.push(pair.packet_id.wrapping_add(bit + 1));
                }
            }
        }
        (nack.media_ssrc, seqs)
    }

    #[test]
    fn test_nack_generator_builder_defaults() {
        let chain = Registry::new()
            .with(NackGeneratorBuilder::default().build())
            .build();

        assert_eq!(chain.size, 512);
        assert_eq!(chain.interval, Duration::from_millis(100));
        assert_eq!(chain.skip_last_n, 0);
        assert_eq!(chain.max_nacks_per_packet, 0);
    }

    #[test]
    fn test_nack_generator_builder_custom() {
        let chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(1024)
                    .with_interval(Duration::from_millis(50))
                    .with_skip_last_n(3)
                    .with_max_nacks_per_packet(5)
                    .build(),
            )
            .build();

        assert_eq!(chain.size, 1024);
        assert_eq!(chain.interval, Duration::from_millis(50));
        assert_eq!(chain.skip_last_n, 3);
        assert_eq!(chain.max_nacks_per_packet, 5);
    }

    #[test]
    fn test_nack_generator_no_nack_without_binding() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        let now = Instant::now();

        // Receive packets without binding stream (no receive log)
        chain.handle_read(make_rtp_packet(SSRC, 0)).unwrap();
        chain.handle_read(make_rtp_packet(SSRC, 2)).unwrap(); // Gap at 1

        chain.handle_timeout(now + Duration::from_millis(200)).unwrap();

        // No NACK should be generated (stream not bound)
        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_nack_generator_generates_nack() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();
        chain.bind_remote_stream(&nack_stream_info(SSRC));

        let base_time = Instant::now();
        chain.handle_read(make_rtp_packet(SSRC, 10)).unwrap();
        chain.handle_read(make_rtp_packet(SSRC, 12)).unwrap(); // Gap at 11

        chain
            .handle_timeout(base_time + Duration::from_millis(200))
            .unwrap();

        let (media_ssrc, seqs) = decode_nack(chain.poll_write().expect("expected a NACK"));
        assert_eq!(media_ssrc, SSRC);
        assert_eq!(seqs, vec![11]);
        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_nack_generator_skip_last_n() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .with_skip_last_n(2)
                    .build(),
            )
            .build();
        chain.bind_remote_stream(&nack_stream_info(SSRC));

        let base_time = Instant::now();

        // Receive: 10, 11, 12, 14, 16, 18 (gaps at 13, 15, 17)
        for seq in [10u16, 11, 12, 14, 16, 18] {
            chain.handle_read(make_rtp_packet(SSRC, seq)).unwrap();
        }

        chain
            .handle_timeout(base_time + Duration::from_millis(200))
            .unwrap();

        // With skip_last_n=2, only 13 and 15 are NACKed (17 is too recent).
        let (media_ssrc, seqs) = decode_nack(chain.poll_write().expect("expected a NACK"));
        assert_eq!(media_ssrc, SSRC);
        assert_eq!(seqs, vec![13, 15]);
    }

    #[test]
    fn test_nack_generator_max_nacks_per_packet() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .with_max_nacks_per_packet(2)
                    .build(),
            )
            .build();
        chain.bind_remote_stream(&nack_stream_info(SSRC));

        let mut now = Instant::now();
        for seq in [10u16, 12] {
            chain.handle_read(make_rtp_packet(SSRC, seq)).unwrap();
        }

        // Sequence 11 is requested on the first two passes only.
        for _ in 0..2 {
            now += Duration::from_millis(200);
            chain.handle_timeout(now).unwrap();
            let (_, seqs) = decode_nack(chain.poll_write().expect("expected a NACK"));
            assert_eq!(seqs, vec![11]);
            assert!(chain.poll_write().is_none());
        }

        // Limit reached: no further NACK for 11.
        now += Duration::from_millis(200);
        chain.handle_timeout(now).unwrap();
        assert!(chain.poll_write().is_none());

        // 11 arrives and a new gap opens at 13: only 13 is NACKed and the
        // counter for 11 is dropped.
        chain.handle_read(make_rtp_packet(SSRC, 11)).unwrap();
        chain.handle_read(make_rtp_packet(SSRC, 14)).unwrap();
        now += Duration::from_millis(200);
        chain.handle_timeout(now).unwrap();
        let (_, seqs) = decode_nack(chain.poll_write().expect("expected a NACK"));
        assert_eq!(seqs, vec![13]);
        assert!(!chain.nack_counts[&SSRC].contains_key(&11));
    }

    #[test]
    fn test_nack_generator_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        let info = nack_stream_info(SSRC);

        chain.bind_remote_stream(&info);
        assert!(chain.receive_logs.contains_key(&SSRC));

        chain.unbind_remote_stream(&info);
        assert!(!chain.receive_logs.contains_key(&SSRC));
        assert!(!chain.nack_counts.contains_key(&SSRC));
    }

    #[test]
    fn test_nack_generator_no_nack_support() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        // Bind stream without NACK support
        let info = StreamInfo {
            ssrc: SSRC,
            clock_rate: 90000,
            rtcp_feedback: vec![], // No NACK support
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        // Should not create receive log
        assert!(!chain.receive_logs.contains_key(&SSRC));
    }
}