rtc_interceptor/nack/
responder.rs

1//! NACK Responder Interceptor - Responds to NACK requests by retransmitting packets.
2
3use super::send_buffer::SendBuffer;
4use super::stream_supports_nack;
5use crate::stream_info::StreamInfo;
6use crate::{Interceptor, Packet, TaggedPacket, interceptor};
7use shared::TransportContext;
8use shared::error::Error;
9use std::collections::{HashMap, VecDeque};
10use std::marker::PhantomData;
11use std::time::Instant;
12
/// Builder for [`NackResponderInterceptor`].
///
/// The only tunable is the size of the per-stream send buffer, i.e. how many
/// recently sent RTP packets stay available for retransmission.
///
/// # Example
///
/// ```ignore
/// use rtc_interceptor::{Registry, NackResponderBuilder};
///
/// let chain = Registry::new()
///     .with(NackResponderBuilder::new()
///         .with_size(1024)
///         .build())
///     .build();
/// ```
#[must_use = "a builder does nothing until `build()` is called"]
pub struct NackResponderBuilder<P> {
    /// Size of the send buffer (must be power of 2: 1, 2, 4, ..., 32768).
    ///
    /// The value is not validated by the builder. It is handed to
    /// `SendBuffer::new` when a local stream is bound, and no buffer is
    /// created for that stream if the constructor returns `None`.
    size: u16,
    /// Ties the builder to the inner interceptor type `P` without storing one.
    _phantom: PhantomData<P>,
}

// Implemented by hand: `#[derive(Default)]` would add a `P: Default` bound,
// although no `P` value is ever created here.
impl<P> Default for NackResponderBuilder<P> {
    /// Creates a builder with a send buffer size of 1024.
    fn default() -> Self {
        Self {
            size: 1024,
            _phantom: PhantomData,
        }
    }
}
40
41impl<P> NackResponderBuilder<P> {
42    /// Create a new builder with default settings.
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Set the size of the send buffer.
48    ///
49    /// Size must be a power of 2 between 1 and 32768 (inclusive).
50    /// Larger buffers can retransmit older packets but use more memory.
51    pub fn with_size(mut self, size: u16) -> Self {
52        self.size = size;
53        self
54    }
55
56    /// Build the interceptor factory function.
57    pub fn build(self) -> impl FnOnce(P) -> NackResponderInterceptor<P> {
58        move |inner| NackResponderInterceptor::new(inner, self.size)
59    }
60}
61
62/// Per-stream state for the responder.
63struct LocalStream {
64    /// Buffer of sent packets for retransmission.
65    send_buffer: SendBuffer,
66    /// RTX SSRC for RFC4588 retransmission (if configured).
67    ssrc_rtx: Option<u32>,
68    /// RTX payload type for RFC4588 retransmission (if configured).
69    payload_type_rtx: Option<u8>,
70    /// Sequence number counter for RTX packets.
71    rtx_sequence_number: u16,
72}
73
74/// Interceptor that responds to NACK requests by retransmitting packets.
75///
76/// This interceptor buffers outgoing RTP packets on local streams and
77/// retransmits them when RTCP TransportLayerNack packets are received.
78#[derive(Interceptor)]
79pub struct NackResponderInterceptor<P> {
80    #[next]
81    inner: P,
82
83    /// Configuration
84    size: u16,
85
86    /// Send buffers per local stream SSRC
87    streams: HashMap<u32, LocalStream>,
88
89    /// Queue for retransmitted packets
90    write_queue: VecDeque<TaggedPacket>,
91}
92
93impl<P> NackResponderInterceptor<P> {
94    fn new(inner: P, size: u16) -> Self {
95        Self {
96            inner,
97            size,
98            streams: HashMap::new(),
99            write_queue: VecDeque::new(),
100        }
101    }
102
103    /// Handle a NACK request by queuing retransmissions.
104    fn handle_nack(
105        &mut self,
106        now: Instant,
107        nack: &rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack,
108    ) {
109        // Collect sequence numbers to retransmit
110        let mut seqs_to_retransmit = Vec::new();
111
112        for nack_pair in &nack.nacks {
113            // Check the base packet ID
114            seqs_to_retransmit.push(nack_pair.packet_id);
115
116            // Check each bit in lost_packets bitmap
117            for i in 0..16 {
118                if nack_pair.lost_packets & (1 << i) != 0 {
119                    let seq = nack_pair.packet_id.wrapping_add(i + 1);
120                    seqs_to_retransmit.push(seq);
121                }
122            }
123        }
124
125        let Some(stream) = self.streams.get_mut(&nack.media_ssrc) else {
126            return;
127        };
128
129        // Queue retransmissions
130        for seq in seqs_to_retransmit {
131            let Some(original_packet) = stream.send_buffer.get(seq) else {
132                continue;
133            };
134
135            let packet = if let (Some(ssrc_rtx), Some(pt_rtx)) =
136                (stream.ssrc_rtx, stream.payload_type_rtx)
137            {
138                // RFC4588: Create RTX packet
139                // - Use RTX SSRC and payload type
140                // - Prepend original sequence number (2 bytes big-endian) to payload
141                // - Use separate RTX sequence number counter
142                let original_seq = original_packet.header.sequence_number;
143                let mut rtx_payload = Vec::with_capacity(2 + original_packet.payload.len());
144                rtx_payload.extend_from_slice(&original_seq.to_be_bytes());
145                rtx_payload.extend_from_slice(&original_packet.payload);
146
147                let rtx_seq = stream.rtx_sequence_number;
148                stream.rtx_sequence_number = stream.rtx_sequence_number.wrapping_add(1);
149
150                rtp::Packet {
151                    header: rtp::header::Header {
152                        ssrc: ssrc_rtx,
153                        payload_type: pt_rtx,
154                        sequence_number: rtx_seq,
155                        timestamp: original_packet.header.timestamp,
156                        marker: original_packet.header.marker,
157                        ..Default::default()
158                    },
159                    payload: rtx_payload.into(),
160                }
161            } else {
162                // No RTX: retransmit original packet as-is
163                original_packet.clone()
164            };
165
166            self.write_queue.push_back(TaggedPacket {
167                now,
168                transport: TransportContext::default(),
169                message: Packet::Rtp(packet),
170            });
171        }
172    }
173}
174
175#[interceptor]
176impl<P: Interceptor> NackResponderInterceptor<P> {
177    #[overrides]
178    fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
179        // Process NACK packets
180        if let Packet::Rtcp(ref rtcp_packets) = msg.message {
181            for rtcp_packet in rtcp_packets {
182                if let Some(nack) = rtcp_packet
183                    .as_any()
184                    .downcast_ref::<rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack>()
185                {
186                    self.handle_nack(msg.now, nack);
187                }
188            }
189        }
190
191        self.inner.handle_read(msg)
192    }
193
194    #[overrides]
195    fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
196        // Buffer outgoing RTP packets
197        if let Packet::Rtp(ref rtp_packet) = msg.message
198            && let Some(stream) = self.streams.get_mut(&rtp_packet.header.ssrc)
199        {
200            stream.send_buffer.add(rtp_packet.clone());
201        }
202
203        self.inner.handle_write(msg)
204    }
205
206    #[overrides]
207    fn poll_write(&mut self) -> Option<Self::Wout> {
208        // First drain retransmitted packets
209        if let Some(pkt) = self.write_queue.pop_front() {
210            return Some(pkt);
211        }
212        self.inner.poll_write()
213    }
214
215    #[overrides]
216    fn bind_local_stream(&mut self, info: &StreamInfo) {
217        if stream_supports_nack(info)
218            && let Some(send_buffer) = SendBuffer::new(self.size)
219        {
220            self.streams.insert(
221                info.ssrc,
222                LocalStream {
223                    send_buffer,
224                    ssrc_rtx: info.ssrc_rtx,
225                    payload_type_rtx: info.payload_type_rtx,
226                    rtx_sequence_number: 0,
227                },
228            );
229        }
230        self.inner.bind_local_stream(info);
231    }
232
233    #[overrides]
234    fn unbind_local_stream(&mut self, info: &StreamInfo) {
235        self.streams.remove(&info.ssrc);
236        self.inner.unbind_local_stream(info);
237    }
238}
239
240#[cfg(test)]
241mod tests {
242    use super::*;
243    use crate::Registry;
244    use crate::stream_info::RTCPFeedback;
245    use sansio::Protocol;
246
247    fn make_rtp_packet(ssrc: u32, seq: u16, payload: &[u8]) -> TaggedPacket {
248        TaggedPacket {
249            now: Instant::now(),
250            transport: Default::default(),
251            message: Packet::Rtp(rtp::Packet {
252                header: rtp::header::Header {
253                    ssrc,
254                    sequence_number: seq,
255                    ..Default::default()
256                },
257                payload: payload.to_vec().into(),
258            }),
259        }
260    }
261
262    fn make_nack_packet(sender_ssrc: u32, media_ssrc: u32, nacks: Vec<(u16, u16)>) -> TaggedPacket {
263        let nack_pairs: Vec<rtcp::transport_feedbacks::transport_layer_nack::NackPair> = nacks
264            .into_iter()
265            .map(|(packet_id, lost_packets)| {
266                rtcp::transport_feedbacks::transport_layer_nack::NackPair {
267                    packet_id,
268                    lost_packets,
269                }
270            })
271            .collect();
272
273        TaggedPacket {
274            now: Instant::now(),
275            transport: Default::default(),
276            message: Packet::Rtcp(vec![Box::new(
277                rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack {
278                    sender_ssrc,
279                    media_ssrc,
280                    nacks: nack_pairs,
281                },
282            )]),
283        }
284    }
285
286    #[test]
287    fn test_nack_responder_builder_defaults() {
288        let chain = Registry::new()
289            .with(NackResponderBuilder::default().build())
290            .build();
291
292        assert_eq!(chain.size, 1024);
293    }
294
295    #[test]
296    fn test_nack_responder_builder_custom() {
297        let chain = Registry::new()
298            .with(NackResponderBuilder::new().with_size(2048).build())
299            .build();
300
301        assert_eq!(chain.size, 2048);
302    }
303
304    #[test]
305    fn test_nack_responder_retransmits_packet() {
306        let mut chain = Registry::new()
307            .with(NackResponderBuilder::new().with_size(8).build())
308            .build();
309
310        // Bind local stream with NACK support
311        let info = StreamInfo {
312            ssrc: 12345,
313            clock_rate: 90000,
314            rtcp_feedback: vec![RTCPFeedback {
315                typ: "nack".to_string(),
316                parameter: "".to_string(),
317            }],
318            ..Default::default()
319        };
320        chain.bind_local_stream(&info);
321
322        let now = Instant::now();
323
324        // Send packets 10, 11, 12, 14, 15 (missing 13)
325        for seq in [10u16, 11, 12, 14, 15] {
326            let mut pkt = make_rtp_packet(12345, seq, &[seq as u8]);
327            pkt.now = now;
328            chain.handle_write(pkt).unwrap();
329            chain.poll_write(); // Drain normal write
330        }
331
332        // Receive NACK for 11, 12, 13, 15
333        // nack_pair: packet_id=11, lost_packets=0b1011 means 11, 12, 13, 15
334        let mut nack = make_nack_packet(999, 12345, vec![(11, 0b1011)]);
335        nack.now = now;
336        chain.handle_read(nack).unwrap();
337
338        // Should retransmit 11, 12, 15 (13 was never sent)
339        let mut retransmitted = Vec::new();
340        while let Some(pkt) = chain.poll_write() {
341            if let Packet::Rtp(rtp) = pkt.message {
342                retransmitted.push(rtp.header.sequence_number);
343            }
344        }
345
346        assert!(retransmitted.contains(&11));
347        assert!(retransmitted.contains(&12));
348        assert!(!retransmitted.contains(&13)); // Never sent
349        assert!(retransmitted.contains(&15));
350    }
351
352    #[test]
353    fn test_nack_responder_no_retransmit_without_binding() {
354        let mut chain = Registry::new()
355            .with(NackResponderBuilder::new().with_size(8).build())
356            .build();
357
358        let now = Instant::now();
359
360        // Send packets without binding stream (no buffer)
361        for seq in [10u16, 11, 12] {
362            let mut pkt = make_rtp_packet(12345, seq, &[seq as u8]);
363            pkt.now = now;
364            chain.handle_write(pkt).unwrap();
365            chain.poll_write();
366        }
367
368        // Receive NACK
369        let mut nack = make_nack_packet(999, 12345, vec![(11, 0)]);
370        nack.now = now;
371        chain.handle_read(nack).unwrap();
372
373        // No retransmissions (stream not bound)
374        assert!(chain.poll_write().is_none());
375    }
376
377    #[test]
378    fn test_nack_responder_no_retransmit_expired_packet() {
379        let mut chain = Registry::new()
380            .with(NackResponderBuilder::new().with_size(8).build())
381            .build();
382
383        let info = StreamInfo {
384            ssrc: 12345,
385            clock_rate: 90000,
386            rtcp_feedback: vec![RTCPFeedback {
387                typ: "nack".to_string(),
388                parameter: "".to_string(),
389            }],
390            ..Default::default()
391        };
392        chain.bind_local_stream(&info);
393
394        let now = Instant::now();
395
396        // Send packets 0-15 (buffer size is 8, so 0-7 will be pushed out)
397        for seq in 0..16u16 {
398            let mut pkt = make_rtp_packet(12345, seq, &[seq as u8]);
399            pkt.now = now;
400            chain.handle_write(pkt).unwrap();
401            chain.poll_write();
402        }
403
404        // Request retransmit of seq 0 (should be expired from buffer)
405        let mut nack = make_nack_packet(999, 12345, vec![(0, 0)]);
406        nack.now = now;
407        chain.handle_read(nack).unwrap();
408
409        // No retransmission (packet too old)
410        assert!(chain.poll_write().is_none());
411
412        // But seq 10 should still be available
413        let mut nack = make_nack_packet(999, 12345, vec![(10, 0)]);
414        nack.now = now;
415        chain.handle_read(nack).unwrap();
416
417        let pkt = chain.poll_write();
418        assert!(pkt.is_some());
419        if let Some(tagged) = pkt
420            && let Packet::Rtp(rtp) = tagged.message
421        {
422            assert_eq!(rtp.header.sequence_number, 10);
423        }
424    }
425
426    #[test]
427    fn test_nack_responder_unbind_removes_stream() {
428        let mut chain = Registry::new()
429            .with(NackResponderBuilder::new().with_size(8).build())
430            .build();
431
432        let info = StreamInfo {
433            ssrc: 12345,
434            clock_rate: 90000,
435            rtcp_feedback: vec![RTCPFeedback {
436                typ: "nack".to_string(),
437                parameter: "".to_string(),
438            }],
439            ..Default::default()
440        };
441
442        chain.bind_local_stream(&info);
443        assert!(chain.streams.contains_key(&12345));
444
445        chain.unbind_local_stream(&info);
446        assert!(!chain.streams.contains_key(&12345));
447    }
448
449    #[test]
450    fn test_nack_responder_no_nack_support() {
451        let mut chain = Registry::new()
452            .with(NackResponderBuilder::new().with_size(8).build())
453            .build();
454
455        // Bind stream without NACK support
456        let info = StreamInfo {
457            ssrc: 12345,
458            clock_rate: 90000,
459            rtcp_feedback: vec![], // No NACK support
460            ..Default::default()
461        };
462        chain.bind_local_stream(&info);
463
464        // Should not create send buffer
465        assert!(!chain.streams.contains_key(&12345));
466    }
467
468    #[test]
469    fn test_nack_responder_passthrough() {
470        let mut chain = Registry::new()
471            .with(NackResponderBuilder::new().with_size(8).build())
472            .build();
473
474        let now = Instant::now();
475
476        // RTP packets should pass through
477        let mut pkt = make_rtp_packet(12345, 1, &[1]);
478        pkt.now = now;
479        chain.handle_write(pkt).unwrap();
480        let out = chain.poll_write();
481        assert!(out.is_some());
482
483        // RTCP packets should pass through to read
484        let mut nack = make_nack_packet(999, 12345, vec![(1, 0)]);
485        nack.now = now;
486        chain.handle_read(nack).unwrap();
487        let out = chain.poll_read();
488        assert!(out.is_none());
489    }
490
491    #[test]
492    fn test_nack_responder_rfc4588_rtx() {
493        let mut chain = Registry::new()
494            .with(NackResponderBuilder::new().with_size(8).build())
495            .build();
496
497        // Bind local stream with NACK support AND RTX configured
498        let info = StreamInfo {
499            ssrc: 1,
500            ssrc_rtx: Some(2), // RTX SSRC
501            payload_type: 96,
502            payload_type_rtx: Some(97), // RTX payload type
503            clock_rate: 90000,
504            rtcp_feedback: vec![RTCPFeedback {
505                typ: "nack".to_string(),
506                parameter: "".to_string(),
507            }],
508            ..Default::default()
509        };
510        chain.bind_local_stream(&info);
511
512        let now = Instant::now();
513
514        // Send packets 10, 11, 12, 14, 15 (missing 13)
515        for seq in [10u16, 11, 12, 14, 15] {
516            let mut pkt = make_rtp_packet(1, seq, &[seq as u8]);
517            pkt.now = now;
518            chain.handle_write(pkt).unwrap();
519            chain.poll_write(); // Drain normal write
520        }
521
522        // Receive NACK for 11, 12, 13, 15
523        // nack_pair: packet_id=11, lost_packets=0b1011 means 11, 12, 13, 15
524        let mut nack = make_nack_packet(999, 1, vec![(11, 0b1011)]);
525        nack.now = now;
526        chain.handle_read(nack).unwrap();
527
528        // Should retransmit 11, 12, 15 (13 was never sent) using RTX format
529        let mut rtx_seq = 0u16;
530        for expected_original_seq in [11u16, 12, 15] {
531            let pkt = chain.poll_write();
532            assert!(
533                pkt.is_some(),
534                "Expected RTX packet for seq {}",
535                expected_original_seq
536            );
537
538            if let Some(tagged) = pkt {
539                if let Packet::Rtp(rtp) = tagged.message {
540                    // Verify RTX SSRC
541                    assert_eq!(rtp.header.ssrc, 2, "RTX packet should use RTX SSRC");
542                    // Verify RTX payload type
543                    assert_eq!(
544                        rtp.header.payload_type, 97,
545                        "RTX packet should use RTX payload type"
546                    );
547                    // Verify RTX sequence number (increments separately)
548                    assert_eq!(
549                        rtp.header.sequence_number, rtx_seq,
550                        "RTX seq should be {}",
551                        rtx_seq
552                    );
553                    rtx_seq += 1;
554
555                    // Verify payload: first 2 bytes should be original sequence number (big-endian)
556                    assert!(
557                        rtp.payload.len() >= 2,
558                        "RTX payload should have at least 2 bytes"
559                    );
560                    let original_seq_from_payload =
561                        u16::from_be_bytes([rtp.payload[0], rtp.payload[1]]);
562                    assert_eq!(
563                        original_seq_from_payload, expected_original_seq,
564                        "RTX payload should contain original seq"
565                    );
566
567                    // Verify original payload follows
568                    assert_eq!(
569                        rtp.payload[2..],
570                        [expected_original_seq as u8],
571                        "Original payload should follow seq number"
572                    );
573                } else {
574                    panic!("Expected RTP packet");
575                }
576            }
577        }
578
579        // No more packets
580        assert!(chain.poll_write().is_none());
581    }
582}