rtc_interceptor/nack/
responder.rs

1//! NACK Responder Interceptor - Responds to NACK requests by retransmitting packets.
2
3use super::send_buffer::SendBuffer;
4use super::stream_supports_nack;
5use crate::stream_info::StreamInfo;
6use crate::{Interceptor, Packet, TaggedPacket};
7use shared::TransportContext;
8use shared::error::Error;
9use std::collections::{HashMap, VecDeque};
10use std::marker::PhantomData;
11use std::time::Instant;
12
/// Builder for [`NackResponderInterceptor`].
///
/// Configure the per-stream send-buffer size, then call `build` to obtain the
/// factory closure that `Registry::with` expects.
///
/// # Example
///
/// ```ignore
/// use rtc_interceptor::{Registry, NackResponderBuilder};
///
/// let chain = Registry::new()
///     .with(NackResponderBuilder::new()
///         .with_size(1024)
///         .build())
///     .build();
/// ```
pub struct NackResponderBuilder<P> {
    /// Send-buffer capacity per stream; documented as a power of two from 1 to 32768.
    size: u16,
    /// Ties the builder to the inner interceptor type `P` without storing a value of it.
    _phantom: PhantomData<P>,
}
31
32impl<P> Default for NackResponderBuilder<P> {
33    fn default() -> Self {
34        Self {
35            size: 1024,
36            _phantom: PhantomData,
37        }
38    }
39}
40
41impl<P> NackResponderBuilder<P> {
42    /// Create a new builder with default settings.
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Set the size of the send buffer.
48    ///
49    /// Size must be a power of 2 between 1 and 32768 (inclusive).
50    /// Larger buffers can retransmit older packets but use more memory.
51    pub fn with_size(mut self, size: u16) -> Self {
52        self.size = size;
53        self
54    }
55
56    /// Build the interceptor factory function.
57    pub fn build(self) -> impl FnOnce(P) -> NackResponderInterceptor<P> {
58        move |inner| NackResponderInterceptor::new(inner, self.size)
59    }
60}
61
62/// Per-stream state for the responder.
63struct LocalStream {
64    /// Buffer of sent packets for retransmission.
65    send_buffer: SendBuffer,
66    /// RTX SSRC for RFC4588 retransmission (if configured).
67    ssrc_rtx: Option<u32>,
68    /// RTX payload type for RFC4588 retransmission (if configured).
69    payload_type_rtx: Option<u8>,
70    /// Sequence number counter for RTX packets.
71    rtx_sequence_number: u16,
72}
73
74/// Interceptor that responds to NACK requests by retransmitting packets.
75///
76/// This interceptor buffers outgoing RTP packets on local streams and
77/// retransmits them when RTCP TransportLayerNack packets are received.
78pub struct NackResponderInterceptor<P> {
79    inner: P,
80
81    /// Configuration
82    size: u16,
83
84    /// Send buffers per local stream SSRC
85    streams: HashMap<u32, LocalStream>,
86
87    /// Queue for retransmitted packets
88    write_queue: VecDeque<TaggedPacket>,
89}
90
91impl<P> NackResponderInterceptor<P> {
92    fn new(inner: P, size: u16) -> Self {
93        Self {
94            inner,
95            size,
96            streams: HashMap::new(),
97            write_queue: VecDeque::new(),
98        }
99    }
100
101    /// Handle a NACK request by queuing retransmissions.
102    fn handle_nack(
103        &mut self,
104        now: Instant,
105        nack: &rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack,
106    ) {
107        // Collect sequence numbers to retransmit
108        let mut seqs_to_retransmit = Vec::new();
109
110        for nack_pair in &nack.nacks {
111            // Check the base packet ID
112            seqs_to_retransmit.push(nack_pair.packet_id);
113
114            // Check each bit in lost_packets bitmap
115            for i in 0..16 {
116                if nack_pair.lost_packets & (1 << i) != 0 {
117                    let seq = nack_pair.packet_id.wrapping_add(i + 1);
118                    seqs_to_retransmit.push(seq);
119                }
120            }
121        }
122
123        let Some(stream) = self.streams.get_mut(&nack.media_ssrc) else {
124            return;
125        };
126
127        // Queue retransmissions
128        for seq in seqs_to_retransmit {
129            let Some(original_packet) = stream.send_buffer.get(seq) else {
130                continue;
131            };
132
133            let packet = if let (Some(ssrc_rtx), Some(pt_rtx)) =
134                (stream.ssrc_rtx, stream.payload_type_rtx)
135            {
136                // RFC4588: Create RTX packet
137                // - Use RTX SSRC and payload type
138                // - Prepend original sequence number (2 bytes big-endian) to payload
139                // - Use separate RTX sequence number counter
140                let original_seq = original_packet.header.sequence_number;
141                let mut rtx_payload = Vec::with_capacity(2 + original_packet.payload.len());
142                rtx_payload.extend_from_slice(&original_seq.to_be_bytes());
143                rtx_payload.extend_from_slice(&original_packet.payload);
144
145                let rtx_seq = stream.rtx_sequence_number;
146                stream.rtx_sequence_number = stream.rtx_sequence_number.wrapping_add(1);
147
148                rtp::Packet {
149                    header: rtp::header::Header {
150                        ssrc: ssrc_rtx,
151                        payload_type: pt_rtx,
152                        sequence_number: rtx_seq,
153                        timestamp: original_packet.header.timestamp,
154                        marker: original_packet.header.marker,
155                        ..Default::default()
156                    },
157                    payload: rtx_payload.into(),
158                }
159            } else {
160                // No RTX: retransmit original packet as-is
161                original_packet.clone()
162            };
163
164            self.write_queue.push_back(TaggedPacket {
165                now,
166                transport: TransportContext::default(),
167                message: Packet::Rtp(packet),
168            });
169        }
170    }
171}
172
173impl<P: Interceptor> sansio::Protocol<TaggedPacket, TaggedPacket, ()>
174    for NackResponderInterceptor<P>
175{
176    type Rout = TaggedPacket;
177    type Wout = TaggedPacket;
178    type Eout = ();
179    type Error = Error;
180    type Time = Instant;
181
182    fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
183        // Process NACK packets
184        if let Packet::Rtcp(ref rtcp_packets) = msg.message {
185            for rtcp_packet in rtcp_packets {
186                if let Some(nack) = rtcp_packet
187                    .as_any()
188                    .downcast_ref::<rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack>()
189                {
190                    self.handle_nack(msg.now, nack);
191                }
192            }
193        }
194
195        self.inner.handle_read(msg)
196    }
197
198    fn poll_read(&mut self) -> Option<Self::Rout> {
199        self.inner.poll_read()
200    }
201
202    fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
203        // Buffer outgoing RTP packets
204        if let Packet::Rtp(ref rtp_packet) = msg.message
205            && let Some(stream) = self.streams.get_mut(&rtp_packet.header.ssrc)
206        {
207            stream.send_buffer.add(rtp_packet.clone());
208        }
209
210        self.inner.handle_write(msg)
211    }
212
213    fn poll_write(&mut self) -> Option<Self::Wout> {
214        // First drain retransmitted packets
215        if let Some(pkt) = self.write_queue.pop_front() {
216            return Some(pkt);
217        }
218        self.inner.poll_write()
219    }
220
221    fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
222        self.inner.handle_timeout(now)
223    }
224
225    fn poll_timeout(&mut self) -> Option<Self::Time> {
226        self.inner.poll_timeout()
227    }
228}
229
230impl<P: Interceptor> Interceptor for NackResponderInterceptor<P> {
231    fn bind_local_stream(&mut self, info: &StreamInfo) {
232        if stream_supports_nack(info)
233            && let Some(send_buffer) = SendBuffer::new(self.size)
234        {
235            self.streams.insert(
236                info.ssrc,
237                LocalStream {
238                    send_buffer,
239                    ssrc_rtx: info.ssrc_rtx,
240                    payload_type_rtx: info.payload_type_rtx,
241                    rtx_sequence_number: 0,
242                },
243            );
244        }
245        self.inner.bind_local_stream(info);
246    }
247
248    fn unbind_local_stream(&mut self, info: &StreamInfo) {
249        self.streams.remove(&info.ssrc);
250        self.inner.unbind_local_stream(info);
251    }
252
253    fn bind_remote_stream(&mut self, info: &StreamInfo) {
254        self.inner.bind_remote_stream(info);
255    }
256
257    fn unbind_remote_stream(&mut self, info: &StreamInfo) {
258        self.inner.unbind_remote_stream(info);
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTCPFeedback;
    use sansio::Protocol;

    /// Outgoing RTP packet with the given SSRC, sequence number and payload.
    fn make_rtp_packet(ssrc: u32, seq: u16, payload: &[u8]) -> TaggedPacket {
        let header = rtp::header::Header {
            ssrc,
            sequence_number: seq,
            ..Default::default()
        };
        let rtp_packet = rtp::Packet {
            header,
            payload: payload.to_vec().into(),
        };
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtp(rtp_packet),
        }
    }

    /// Incoming RTCP message holding one TransportLayerNack; every
    /// `(packet_id, lost_packets)` tuple becomes one NackPair.
    fn make_nack_packet(sender_ssrc: u32, media_ssrc: u32, nacks: Vec<(u16, u16)>) -> TaggedPacket {
        use rtcp::transport_feedbacks::transport_layer_nack::{NackPair, TransportLayerNack};

        let nack = TransportLayerNack {
            sender_ssrc,
            media_ssrc,
            nacks: nacks
                .into_iter()
                .map(|(packet_id, lost_packets)| NackPair {
                    packet_id,
                    lost_packets,
                })
                .collect(),
        };
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtcp(vec![Box::new(nack)]),
        }
    }

    /// StreamInfo for a 90 kHz local stream that negotiated generic NACK.
    fn nack_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            clock_rate: 90000,
            rtcp_feedback: vec![RTCPFeedback {
                typ: "nack".to_string(),
                parameter: "".to_string(),
            }],
            ..Default::default()
        }
    }

    #[test]
    fn test_nack_responder_builder_defaults() {
        let chain = Registry::new()
            .with(NackResponderBuilder::default().build())
            .build();

        assert_eq!(chain.size, 1024);
    }

    #[test]
    fn test_nack_responder_builder_custom() {
        let chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(2048).build())
            .build();

        assert_eq!(chain.size, 2048);
    }

    #[test]
    fn test_nack_responder_retransmits_packet() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        chain.bind_local_stream(&nack_stream_info(12345));

        let now = Instant::now();

        // Sequence 13 is deliberately never sent.
        for seq in [10u16, 11, 12, 14, 15] {
            let pkt = TaggedPacket {
                now,
                ..make_rtp_packet(12345, seq, &[seq as u8])
            };
            chain.handle_write(pkt).unwrap();
            chain.poll_write(); // discard the pass-through copy
        }

        // packet_id=11 with bitmask 0b1011 requests 11, 12, 13 and 15.
        let nack = TaggedPacket {
            now,
            ..make_nack_packet(999, 12345, vec![(11, 0b1011)])
        };
        chain.handle_read(nack).unwrap();

        let retransmitted: Vec<u16> = std::iter::from_fn(|| chain.poll_write())
            .filter_map(|pkt| match pkt.message {
                Packet::Rtp(rtp) => Some(rtp.header.sequence_number),
                _ => None,
            })
            .collect();

        assert!(retransmitted.contains(&11));
        assert!(retransmitted.contains(&12));
        assert!(!retransmitted.contains(&13)); // never sent, so nothing to resend
        assert!(retransmitted.contains(&15));
    }

    #[test]
    fn test_nack_responder_no_retransmit_without_binding() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let now = Instant::now();

        // Without a bound stream nothing is buffered.
        for seq in [10u16, 11, 12] {
            let pkt = TaggedPacket {
                now,
                ..make_rtp_packet(12345, seq, &[seq as u8])
            };
            chain.handle_write(pkt).unwrap();
            chain.poll_write();
        }

        let nack = TaggedPacket {
            now,
            ..make_nack_packet(999, 12345, vec![(11, 0)])
        };
        chain.handle_read(nack).unwrap();

        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_nack_responder_no_retransmit_expired_packet() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        chain.bind_local_stream(&nack_stream_info(12345));

        let now = Instant::now();

        // Sixteen packets through an 8-entry buffer evicts 0..=7.
        for seq in 0..16u16 {
            let pkt = TaggedPacket {
                now,
                ..make_rtp_packet(12345, seq, &[seq as u8])
            };
            chain.handle_write(pkt).unwrap();
            chain.poll_write();
        }

        // Seq 0 has been evicted, so it cannot be resent.
        let nack = TaggedPacket {
            now,
            ..make_nack_packet(999, 12345, vec![(0, 0)])
        };
        chain.handle_read(nack).unwrap();
        assert!(chain.poll_write().is_none());

        // Seq 10 is still buffered.
        let nack = TaggedPacket {
            now,
            ..make_nack_packet(999, 12345, vec![(10, 0)])
        };
        chain.handle_read(nack).unwrap();

        let pkt = chain.poll_write();
        assert!(pkt.is_some());
        if let Some(Packet::Rtp(rtp)) = pkt.map(|tagged| tagged.message) {
            assert_eq!(rtp.header.sequence_number, 10);
        }
    }

    #[test]
    fn test_nack_responder_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        let info = nack_stream_info(12345);

        chain.bind_local_stream(&info);
        assert!(chain.streams.contains_key(&12345));

        chain.unbind_local_stream(&info);
        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_nack_responder_no_nack_support() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        // Same stream, but without any RTCP feedback negotiated.
        let info = StreamInfo {
            rtcp_feedback: vec![],
            ..nack_stream_info(12345)
        };
        chain.bind_local_stream(&info);

        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_nack_responder_passthrough() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let now = Instant::now();

        // RTP on an unbound stream still reaches the write side.
        let pkt = TaggedPacket {
            now,
            ..make_rtp_packet(12345, 1, &[1])
        };
        chain.handle_write(pkt).unwrap();
        assert!(chain.poll_write().is_some());

        // RTCP is forwarded to the read side after NACK inspection.
        let nack = TaggedPacket {
            now,
            ..make_nack_packet(999, 12345, vec![(1, 0)])
        };
        chain.handle_read(nack).unwrap();
        assert!(chain.poll_read().is_some());
    }

    #[test]
    fn test_nack_responder_rfc4588_rtx() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        // Media SSRC 1 / PT 96 with RTX on SSRC 2 / PT 97.
        let info = StreamInfo {
            ssrc_rtx: Some(2),
            payload_type: 96,
            payload_type_rtx: Some(97),
            ..nack_stream_info(1)
        };
        chain.bind_local_stream(&info);

        let now = Instant::now();

        // Sequence 13 is deliberately never sent.
        for seq in [10u16, 11, 12, 14, 15] {
            let pkt = TaggedPacket {
                now,
                ..make_rtp_packet(1, seq, &[seq as u8])
            };
            chain.handle_write(pkt).unwrap();
            chain.poll_write(); // discard the pass-through copy
        }

        // packet_id=11 with bitmask 0b1011 requests 11, 12, 13 and 15.
        let nack = TaggedPacket {
            now,
            ..make_nack_packet(999, 1, vec![(11, 0b1011)])
        };
        chain.handle_read(nack).unwrap();

        // Expect RTX packets for 11, 12 and 15 with RTX seqs 0, 1, 2.
        for (rtx_seq, expected_original_seq) in (0u16..).zip([11u16, 12, 15]) {
            let pkt = chain.poll_write();
            assert!(
                pkt.is_some(),
                "Expected RTX packet for seq {}",
                expected_original_seq
            );
            let Some(tagged) = pkt else { continue };
            let Packet::Rtp(rtp) = tagged.message else {
                panic!("Expected RTP packet");
            };

            assert_eq!(rtp.header.ssrc, 2, "RTX packet should use RTX SSRC");
            assert_eq!(
                rtp.header.payload_type, 97,
                "RTX packet should use RTX payload type"
            );
            assert_eq!(
                rtp.header.sequence_number, rtx_seq,
                "RTX seq should be {}",
                rtx_seq
            );

            // Payload layout: 2-byte big-endian OSN, then the original payload.
            assert!(
                rtp.payload.len() >= 2,
                "RTX payload should have at least 2 bytes"
            );
            let original_seq_from_payload = u16::from_be_bytes([rtp.payload[0], rtp.payload[1]]);
            assert_eq!(
                original_seq_from_payload, expected_original_seq,
                "RTX payload should contain original seq"
            );
            assert_eq!(
                rtp.payload[2..],
                [expected_original_seq as u8],
                "Original payload should follow seq number"
            );
        }

        assert!(chain.poll_write().is_none());
    }
}