rtc_interceptor/nack/
generator.rs

1//! NACK Generator Interceptor - Generates NACK requests for missing packets.
2
3use super::receive_log::ReceiveLog;
4use super::stream_supports_nack;
5use crate::stream_info::StreamInfo;
6use crate::{Interceptor, Packet, TaggedPacket, interceptor};
7use shared::TransportContext;
8use shared::error::Error;
9use std::collections::{HashMap, VecDeque};
10use std::marker::PhantomData;
11use std::time::{Duration, Instant};
12
/// Builder that configures and produces a `NackGeneratorInterceptor`.
///
/// Every setting is optional; the `Default` implementation supplies the
/// values used for any setter that is not called.
///
/// # Example
///
/// ```ignore
/// use rtc_interceptor::{Registry, NackGeneratorBuilder};
/// use std::time::Duration;
///
/// let chain = Registry::new()
///     .with(NackGeneratorBuilder::new()
///         .with_size(512)
///         .with_interval(Duration::from_millis(100))
///         .with_skip_last_n(2)
///         .build())
///     .build();
/// ```
pub struct NackGeneratorBuilder<P> {
    /// Receive-log capacity; expected to be a power of two from 64 to 32768.
    size: u16,
    /// Time between successive NACK generation cycles.
    interval: Duration,
    /// How many of the most recent packets are left out when generating NACKs.
    skip_last_n: u16,
    /// Upper bound on NACKs sent for any single missing packet (0 = no limit).
    max_nacks_per_packet: u16,
    /// Binds the builder to the inner interceptor type `P` without storing one.
    _phantom: PhantomData<P>,
}
40
41impl<P> Default for NackGeneratorBuilder<P> {
42    fn default() -> Self {
43        Self {
44            size: 512,
45            interval: Duration::from_millis(100),
46            skip_last_n: 0,
47            max_nacks_per_packet: 0,
48            _phantom: PhantomData,
49        }
50    }
51}
52
53impl<P> NackGeneratorBuilder<P> {
54    /// Create a new builder with default settings.
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /// Set the size of the receive log.
60    ///
61    /// Size must be a power of 2 between 64 and 32768 (inclusive).
62    pub fn with_size(mut self, size: u16) -> Self {
63        self.size = size;
64        self
65    }
66
67    /// Set the interval between NACK generation cycles.
68    pub fn with_interval(mut self, interval: Duration) -> Self {
69        self.interval = interval;
70        self
71    }
72
73    /// Set the number of most recent packets to skip when generating NACKs.
74    ///
75    /// This helps avoid generating NACKs for packets that are simply delayed
76    /// and haven't arrived yet.
77    pub fn with_skip_last_n(mut self, skip_last_n: u16) -> Self {
78        self.skip_last_n = skip_last_n;
79        self
80    }
81
82    /// Set the maximum number of NACKs to send per missing packet.
83    ///
84    /// Set to 0 (default) for unlimited NACKs.
85    pub fn with_max_nacks_per_packet(mut self, max: u16) -> Self {
86        self.max_nacks_per_packet = max;
87        self
88    }
89
90    /// Build the interceptor factory function.
91    pub fn build(self) -> impl FnOnce(P) -> NackGeneratorInterceptor<P> {
92        move |inner| {
93            NackGeneratorInterceptor::new(
94                inner,
95                self.size,
96                self.interval,
97                self.skip_last_n,
98                self.max_nacks_per_packet,
99            )
100        }
101    }
102}
103
104/// Interceptor that generates NACK requests for missing RTP packets.
105///
106/// This interceptor monitors incoming RTP packets on remote streams,
107/// tracks which sequence numbers have been received, and periodically
108/// generates RTCP TransportLayerNack packets for missing sequences.
109#[derive(Interceptor)]
110pub struct NackGeneratorInterceptor<P> {
111    #[next]
112    inner: P,
113
114    /// Configuration
115    size: u16,
116    interval: Duration,
117    skip_last_n: u16,
118    max_nacks_per_packet: u16,
119
120    /// Next timeout for NACK generation
121    eto: Instant,
122
123    /// Sender SSRC for NACK packets
124    sender_ssrc: u32,
125
126    /// Receive logs per remote stream SSRC
127    receive_logs: HashMap<u32, ReceiveLog>,
128
129    /// NACK count per (SSRC, sequence number) for max_nacks_per_packet limiting
130    nack_counts: HashMap<u32, HashMap<u16, u16>>,
131
132    /// Queue for outgoing NACK packets
133    write_queue: VecDeque<TaggedPacket>,
134}
135
136impl<P> NackGeneratorInterceptor<P> {
137    fn new(
138        inner: P,
139        size: u16,
140        interval: Duration,
141        skip_last_n: u16,
142        max_nacks_per_packet: u16,
143    ) -> Self {
144        Self {
145            inner,
146            size,
147            interval,
148            skip_last_n,
149            max_nacks_per_packet,
150            eto: Instant::now(),
151            sender_ssrc: rand::random(),
152            receive_logs: HashMap::new(),
153            nack_counts: HashMap::new(),
154            write_queue: VecDeque::new(),
155        }
156    }
157
158    /// Generate NACKs for all streams with missing packets.
159    fn generate_nacks(&mut self, now: Instant) {
160        for (&ssrc, receive_log) in &self.receive_logs {
161            let missing = receive_log.missing_seq_numbers(self.skip_last_n);
162            if missing.is_empty() {
163                // Clear nack counts for this SSRC if no missing packets
164                self.nack_counts.remove(&ssrc);
165                continue;
166            }
167
168            // Initialize nack counts for this SSRC if needed
169            let nack_count = self.nack_counts.entry(ssrc).or_default();
170
171            // Filter by max_nacks_per_packet if configured
172            let filtered: Vec<u16> = if self.max_nacks_per_packet > 0 {
173                missing
174                    .iter()
175                    .filter(|&&seq| {
176                        let count = nack_count.entry(seq).or_insert(0);
177                        if *count < self.max_nacks_per_packet {
178                            *count += 1;
179                            true
180                        } else {
181                            false
182                        }
183                    })
184                    .copied()
185                    .collect()
186            } else {
187                missing.clone()
188            };
189
190            if filtered.is_empty() {
191                continue;
192            }
193
194            // Clean up nack counts for packets no longer missing
195            nack_count.retain(|seq, _| missing.contains(seq));
196
197            // Create NACK packet
198            let nack = rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack {
199                sender_ssrc: self.sender_ssrc,
200                media_ssrc: ssrc,
201                nacks: rtcp::transport_feedbacks::transport_layer_nack::nack_pairs_from_sequence_numbers(
202                    &filtered,
203                ),
204            };
205
206            self.write_queue.push_back(TaggedPacket {
207                now,
208                transport: TransportContext::default(),
209                message: Packet::Rtcp(vec![Box::new(nack)]),
210            });
211        }
212    }
213}
214
215#[interceptor]
216impl<P: Interceptor> NackGeneratorInterceptor<P> {
217    #[overrides]
218    fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
219        // Track incoming RTP packets
220        if let Packet::Rtp(ref rtp_packet) = msg.message
221            && let Some(receive_log) = self.receive_logs.get_mut(&rtp_packet.header.ssrc)
222        {
223            receive_log.add(rtp_packet.header.sequence_number);
224        }
225
226        self.inner.handle_read(msg)
227    }
228
229    #[overrides]
230    fn poll_write(&mut self) -> Option<Self::Wout> {
231        // First drain generated NACK packets
232        if let Some(pkt) = self.write_queue.pop_front() {
233            return Some(pkt);
234        }
235        self.inner.poll_write()
236    }
237
238    #[overrides]
239    fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
240        if self.eto <= now {
241            self.eto = now + self.interval;
242            self.generate_nacks(now);
243        }
244
245        self.inner.handle_timeout(now)
246    }
247
248    #[overrides]
249    fn poll_timeout(&mut self) -> Option<Self::Time> {
250        if let Some(inner_eto) = self.inner.poll_timeout()
251            && inner_eto < self.eto
252        {
253            return Some(inner_eto);
254        }
255        Some(self.eto)
256    }
257
258    #[overrides]
259    fn bind_remote_stream(&mut self, info: &StreamInfo) {
260        if stream_supports_nack(info)
261            && let Some(receive_log) = ReceiveLog::new(self.size)
262        {
263            self.receive_logs.insert(info.ssrc, receive_log);
264        }
265        self.inner.bind_remote_stream(info);
266    }
267
268    #[overrides]
269    fn unbind_remote_stream(&mut self, info: &StreamInfo) {
270        self.receive_logs.remove(&info.ssrc);
271        self.nack_counts.remove(&info.ssrc);
272        self.inner.unbind_remote_stream(info);
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTCPFeedback;
    use rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack;
    use sansio::Protocol;

    const SSRC: u32 = 12345;

    fn make_rtp_packet(ssrc: u32, seq: u16) -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number: seq,
                    ..Default::default()
                },
                ..Default::default()
            }),
        }
    }

    /// Stream info for a remote stream that negotiated generic NACK feedback.
    fn nack_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            clock_rate: 90000,
            rtcp_feedback: vec![RTCPFeedback {
                typ: "nack".to_string(),
                parameter: "".to_string(),
            }],
            ..Default::default()
        }
    }

    /// Expands the (packet_id, lost_packets bitmask) pairs of a NACK into the
    /// individual sequence numbers it requests.
    fn nacked_sequence_numbers(nack: &TransportLayerNack) -> Vec<u16> {
        let mut seqs = Vec::new();
        for pair in &nack.nacks {
            seqs.push(pair.packet_id);
            for i in 0..16u16 {
                if pair.lost_packets & (1 << i) != 0 {
                    seqs.push(pair.packet_id.wrapping_add(i + 1));
                }
            }
        }
        seqs
    }

    /// Asserts `tagged` is an RTCP packet holding exactly one
    /// `TransportLayerNack`. Returns its media SSRC and the sequence numbers
    /// it requests. Panics on any other shape so tests cannot pass silently.
    fn decode_nack(tagged: TaggedPacket) -> (u32, Vec<u16>) {
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };
        assert_eq!(rtcp_packets.len(), 1);
        let nack = rtcp_packets[0]
            .as_any()
            .downcast_ref::<TransportLayerNack>()
            .expect("Expected TransportLayerNack");
        (nack.media_ssrc, nacked_sequence_numbers(nack))
    }

    #[test]
    fn test_nack_generator_builder_defaults() {
        let chain = Registry::new()
            .with(NackGeneratorBuilder::default().build())
            .build();

        assert_eq!(chain.size, 512);
        assert_eq!(chain.interval, Duration::from_millis(100));
        assert_eq!(chain.skip_last_n, 0);
        assert_eq!(chain.max_nacks_per_packet, 0);
    }

    #[test]
    fn test_nack_generator_builder_custom() {
        let chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(1024)
                    .with_interval(Duration::from_millis(50))
                    .with_skip_last_n(3)
                    .with_max_nacks_per_packet(5)
                    .build(),
            )
            .build();

        assert_eq!(chain.size, 1024);
        assert_eq!(chain.interval, Duration::from_millis(50));
        assert_eq!(chain.skip_last_n, 3);
        assert_eq!(chain.max_nacks_per_packet, 5);
    }

    #[test]
    fn test_nack_generator_no_nack_without_binding() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        let now = Instant::now();

        // Receive packets without binding stream (no receive log)
        chain.handle_read(make_rtp_packet(SSRC, 0)).unwrap();
        chain.handle_read(make_rtp_packet(SSRC, 2)).unwrap(); // Gap at 1

        // Trigger timeout
        let later = now + Duration::from_millis(200);
        chain.handle_timeout(later).unwrap();

        // No NACK should be generated (stream not bound)
        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_nack_generator_generates_nack() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        chain.bind_remote_stream(&nack_stream_info(SSRC));

        let base_time = Instant::now();

        // Receive packets with a gap at 11
        for seq in [10u16, 12] {
            let mut pkt = make_rtp_packet(SSRC, seq);
            pkt.now = base_time;
            chain.handle_read(pkt).unwrap();
        }

        chain.poll_read();

        // Trigger timeout
        chain
            .handle_timeout(base_time + Duration::from_millis(200))
            .unwrap();

        // Should generate NACK for seq 11
        let (media_ssrc, nacked) =
            decode_nack(chain.poll_write().expect("Expected a NACK packet"));
        assert_eq!(media_ssrc, SSRC);
        assert!(nacked.contains(&11));
    }

    #[test]
    fn test_nack_generator_skip_last_n() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .with_skip_last_n(2)
                    .build(),
            )
            .build();

        chain.bind_remote_stream(&nack_stream_info(SSRC));

        let base_time = Instant::now();

        // Receive: 10, 11, 12, 14, 16, 18 (gaps at 13, 15, 17)
        for seq in [10u16, 11, 12, 14, 16, 18] {
            let mut pkt = make_rtp_packet(SSRC, seq);
            pkt.now = base_time;
            chain.handle_read(pkt).unwrap();
        }

        // Trigger timeout
        chain
            .handle_timeout(base_time + Duration::from_millis(200))
            .unwrap();

        // With skip_last_n=2, should only NACK for 13, 15 (not 17)
        let (_, nacked) = decode_nack(chain.poll_write().expect("Expected a NACK packet"));
        assert!(nacked.contains(&13));
        assert!(nacked.contains(&15));
        assert!(!nacked.contains(&17));
    }

    #[test]
    fn test_nack_generator_max_nacks_per_packet() {
        let interval = Duration::from_millis(100);
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(interval)
                    .with_max_nacks_per_packet(1)
                    .build(),
            )
            .build();

        chain.bind_remote_stream(&nack_stream_info(SSRC));

        let base_time = Instant::now();
        for seq in [10u16, 12] {
            let mut pkt = make_rtp_packet(SSRC, seq);
            pkt.now = base_time;
            chain.handle_read(pkt).unwrap();
        }

        // First cycle: seq 11 is requested exactly once.
        let first = base_time + Duration::from_millis(200);
        chain.handle_timeout(first).unwrap();
        let (_, nacked) = decode_nack(chain.poll_write().expect("Expected a NACK packet"));
        assert!(nacked.contains(&11));
        assert!(chain.poll_write().is_none());
        assert_eq!(
            chain.nack_counts.get(&SSRC).and_then(|counts| counts.get(&11)),
            Some(&1)
        );

        // Second cycle: the budget for seq 11 is exhausted, so nothing is sent.
        chain.handle_timeout(first + interval * 2).unwrap();
        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_nack_generator_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        let info = nack_stream_info(SSRC);

        chain.bind_remote_stream(&info);
        assert!(chain.receive_logs.contains_key(&SSRC));

        chain.unbind_remote_stream(&info);
        assert!(!chain.receive_logs.contains_key(&SSRC));
        assert!(!chain.nack_counts.contains_key(&SSRC));
    }

    #[test]
    fn test_nack_generator_no_nack_support() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        // Bind stream without NACK support
        let info = StreamInfo {
            ssrc: SSRC,
            clock_rate: 90000,
            rtcp_feedback: vec![], // No NACK support
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        // Should not create receive log
        assert!(!chain.receive_logs.contains_key(&SSRC));
    }
}