1use super::send_buffer::SendBuffer;
4use super::stream_supports_nack;
5use crate::stream_info::StreamInfo;
6use crate::{Interceptor, Packet, TaggedPacket};
7use shared::TransportContext;
8use shared::error::Error;
9use std::collections::{HashMap, VecDeque};
10use std::marker::PhantomData;
11use std::time::Instant;
12
/// Builder for [`NackResponderInterceptor`].
///
/// The only tunable is `size`, the per-stream send buffer capacity (1024 by default).
pub struct NackResponderBuilder<P> {
    /// Capacity handed to `SendBuffer::new` for each NACK-enabled local stream.
    size: u16,
    /// Ties the builder to the inner interceptor type `P` without storing a value of it.
    _phantom: PhantomData<P>,
}
31
32impl<P> Default for NackResponderBuilder<P> {
33 fn default() -> Self {
34 Self {
35 size: 1024,
36 _phantom: PhantomData,
37 }
38 }
39}
40
41impl<P> NackResponderBuilder<P> {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn with_size(mut self, size: u16) -> Self {
52 self.size = size;
53 self
54 }
55
56 pub fn build(self) -> impl FnOnce(P) -> NackResponderInterceptor<P> {
58 move |inner| NackResponderInterceptor::new(inner, self.size)
59 }
60}
61
62struct LocalStream {
64 send_buffer: SendBuffer,
66 ssrc_rtx: Option<u32>,
68 payload_type_rtx: Option<u8>,
70 rtx_sequence_number: u16,
72}
73
74pub struct NackResponderInterceptor<P> {
79 inner: P,
80
81 size: u16,
83
84 streams: HashMap<u32, LocalStream>,
86
87 write_queue: VecDeque<TaggedPacket>,
89}
90
91impl<P> NackResponderInterceptor<P> {
92 fn new(inner: P, size: u16) -> Self {
93 Self {
94 inner,
95 size,
96 streams: HashMap::new(),
97 write_queue: VecDeque::new(),
98 }
99 }
100
101 fn handle_nack(
103 &mut self,
104 now: Instant,
105 nack: &rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack,
106 ) {
107 let mut seqs_to_retransmit = Vec::new();
109
110 for nack_pair in &nack.nacks {
111 seqs_to_retransmit.push(nack_pair.packet_id);
113
114 for i in 0..16 {
116 if nack_pair.lost_packets & (1 << i) != 0 {
117 let seq = nack_pair.packet_id.wrapping_add(i + 1);
118 seqs_to_retransmit.push(seq);
119 }
120 }
121 }
122
123 let Some(stream) = self.streams.get_mut(&nack.media_ssrc) else {
124 return;
125 };
126
127 for seq in seqs_to_retransmit {
129 let Some(original_packet) = stream.send_buffer.get(seq) else {
130 continue;
131 };
132
133 let packet = if let (Some(ssrc_rtx), Some(pt_rtx)) =
134 (stream.ssrc_rtx, stream.payload_type_rtx)
135 {
136 let original_seq = original_packet.header.sequence_number;
141 let mut rtx_payload = Vec::with_capacity(2 + original_packet.payload.len());
142 rtx_payload.extend_from_slice(&original_seq.to_be_bytes());
143 rtx_payload.extend_from_slice(&original_packet.payload);
144
145 let rtx_seq = stream.rtx_sequence_number;
146 stream.rtx_sequence_number = stream.rtx_sequence_number.wrapping_add(1);
147
148 rtp::Packet {
149 header: rtp::header::Header {
150 ssrc: ssrc_rtx,
151 payload_type: pt_rtx,
152 sequence_number: rtx_seq,
153 timestamp: original_packet.header.timestamp,
154 marker: original_packet.header.marker,
155 ..Default::default()
156 },
157 payload: rtx_payload.into(),
158 }
159 } else {
160 original_packet.clone()
162 };
163
164 self.write_queue.push_back(TaggedPacket {
165 now,
166 transport: TransportContext::default(),
167 message: Packet::Rtp(packet),
168 });
169 }
170 }
171}
172
173impl<P: Interceptor> sansio::Protocol<TaggedPacket, TaggedPacket, ()>
174 for NackResponderInterceptor<P>
175{
176 type Rout = TaggedPacket;
177 type Wout = TaggedPacket;
178 type Eout = ();
179 type Error = Error;
180 type Time = Instant;
181
182 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
183 if let Packet::Rtcp(ref rtcp_packets) = msg.message {
185 for rtcp_packet in rtcp_packets {
186 if let Some(nack) = rtcp_packet
187 .as_any()
188 .downcast_ref::<rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack>()
189 {
190 self.handle_nack(msg.now, nack);
191 }
192 }
193 }
194
195 self.inner.handle_read(msg)
196 }
197
198 fn poll_read(&mut self) -> Option<Self::Rout> {
199 self.inner.poll_read()
200 }
201
202 fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
203 if let Packet::Rtp(ref rtp_packet) = msg.message
205 && let Some(stream) = self.streams.get_mut(&rtp_packet.header.ssrc)
206 {
207 stream.send_buffer.add(rtp_packet.clone());
208 }
209
210 self.inner.handle_write(msg)
211 }
212
213 fn poll_write(&mut self) -> Option<Self::Wout> {
214 if let Some(pkt) = self.write_queue.pop_front() {
216 return Some(pkt);
217 }
218 self.inner.poll_write()
219 }
220
221 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
222 self.inner.handle_timeout(now)
223 }
224
225 fn poll_timeout(&mut self) -> Option<Self::Time> {
226 self.inner.poll_timeout()
227 }
228}
229
230impl<P: Interceptor> Interceptor for NackResponderInterceptor<P> {
231 fn bind_local_stream(&mut self, info: &StreamInfo) {
232 if stream_supports_nack(info)
233 && let Some(send_buffer) = SendBuffer::new(self.size)
234 {
235 self.streams.insert(
236 info.ssrc,
237 LocalStream {
238 send_buffer,
239 ssrc_rtx: info.ssrc_rtx,
240 payload_type_rtx: info.payload_type_rtx,
241 rtx_sequence_number: 0,
242 },
243 );
244 }
245 self.inner.bind_local_stream(info);
246 }
247
248 fn unbind_local_stream(&mut self, info: &StreamInfo) {
249 self.streams.remove(&info.ssrc);
250 self.inner.unbind_local_stream(info);
251 }
252
253 fn bind_remote_stream(&mut self, info: &StreamInfo) {
254 self.inner.bind_remote_stream(info);
255 }
256
257 fn unbind_remote_stream(&mut self, info: &StreamInfo) {
258 self.inner.unbind_remote_stream(info);
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTCPFeedback;
    use sansio::Protocol;

    /// Builds an outgoing RTP packet with the given SSRC, sequence number and payload.
    fn make_rtp_packet(ssrc: u32, seq: u16, payload: &[u8]) -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number: seq,
                    ..Default::default()
                },
                payload: payload.to_vec().into(),
            }),
        }
    }

    /// Builds an incoming RTCP message holding a single Generic NACK. Each
    /// `(packet_id, lost_packets)` tuple becomes one NACK pair.
    fn make_nack_packet(sender_ssrc: u32, media_ssrc: u32, nacks: Vec<(u16, u16)>) -> TaggedPacket {
        let nack_pairs: Vec<rtcp::transport_feedbacks::transport_layer_nack::NackPair> = nacks
            .into_iter()
            .map(|(packet_id, lost_packets)| {
                rtcp::transport_feedbacks::transport_layer_nack::NackPair {
                    packet_id,
                    lost_packets,
                }
            })
            .collect();

        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtcp(vec![Box::new(
                rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack {
                    sender_ssrc,
                    media_ssrc,
                    nacks: nack_pairs,
                },
            )]),
        }
    }

    /// Describes a local stream that negotiated plain `nack` RTCP feedback (no RTX).
    fn nack_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            clock_rate: 90000,
            rtcp_feedback: vec![RTCPFeedback {
                typ: "nack".to_string(),
                parameter: "".to_string(),
            }],
            ..Default::default()
        }
    }

    /// Writes one RTP packet per sequence number, with the low byte of the sequence
    /// number as payload. The pass-through output is discarded so that only
    /// retransmissions remain to be polled.
    fn send_rtp<P: Interceptor>(
        chain: &mut NackResponderInterceptor<P>,
        ssrc: u32,
        seqs: &[u16],
        now: Instant,
    ) {
        for &seq in seqs {
            let mut pkt = make_rtp_packet(ssrc, seq, &[seq as u8]);
            pkt.now = now;
            chain.handle_write(pkt).unwrap();
            chain.poll_write();
        }
    }

    /// Feeds a Generic NACK for `media_ssrc` into the read side.
    fn receive_nack<P: Interceptor>(
        chain: &mut NackResponderInterceptor<P>,
        media_ssrc: u32,
        nacks: Vec<(u16, u16)>,
        now: Instant,
    ) {
        let mut nack = make_nack_packet(999, media_ssrc, nacks);
        nack.now = now;
        chain.handle_read(nack).unwrap();
    }

    /// Drains the write side, panicking if anything other than RTP comes out.
    fn drain_rtp_writes<P: Interceptor>(chain: &mut NackResponderInterceptor<P>) -> Vec<rtp::Packet> {
        std::iter::from_fn(|| chain.poll_write())
            .map(|tagged| match tagged.message {
                Packet::Rtp(rtp) => rtp,
                _ => panic!("expected only RTP packets on the write side"),
            })
            .collect()
    }

    #[test]
    fn test_nack_responder_builder_defaults() {
        let chain = Registry::new()
            .with(NackResponderBuilder::default().build())
            .build();

        assert_eq!(chain.size, 1024);
    }

    #[test]
    fn test_nack_responder_builder_custom() {
        let chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(2048).build())
            .build();

        assert_eq!(chain.size, 2048);
    }

    #[test]
    fn test_nack_responder_retransmits_packet() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        chain.bind_local_stream(&nack_stream_info(12345));

        let now = Instant::now();
        send_rtp(&mut chain, 12345, &[10, 11, 12, 14, 15], now);

        // packet_id 11 with bitmask 0b1011 requests 11, 12, 13 and 15. Seq 13 was never
        // sent and 14 was not requested, so exactly three packets come back, in order.
        receive_nack(&mut chain, 12345, vec![(11, 0b1011)], now);

        let retransmitted = drain_rtp_writes(&mut chain);
        let seqs: Vec<u16> = retransmitted
            .iter()
            .map(|rtp| rtp.header.sequence_number)
            .collect();
        assert_eq!(seqs, vec![11, 12, 15]);

        // Without RTX the original payload is resent unchanged.
        for rtp in &retransmitted {
            assert_eq!(rtp.header.ssrc, 12345);
            assert_eq!(rtp.payload[..], [rtp.header.sequence_number as u8]);
        }
    }

    #[test]
    fn test_nack_responder_no_retransmit_without_binding() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let now = Instant::now();
        send_rtp(&mut chain, 12345, &[10, 11, 12], now);
        receive_nack(&mut chain, 12345, vec![(11, 0)], now);

        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_nack_responder_no_retransmit_expired_packet() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        chain.bind_local_stream(&nack_stream_info(12345));

        let now = Instant::now();
        let seqs: Vec<u16> = (0..16).collect();
        send_rtp(&mut chain, 12345, &seqs, now);

        // The buffer holds only 8 packets: seq 0 has been pushed out by later writes...
        receive_nack(&mut chain, 12345, vec![(0, 0)], now);
        assert!(chain.poll_write().is_none());

        // ...while seq 10 is still available.
        receive_nack(&mut chain, 12345, vec![(10, 0)], now);
        let retransmitted = drain_rtp_writes(&mut chain);
        assert_eq!(retransmitted.len(), 1);
        assert_eq!(retransmitted[0].header.sequence_number, 10);
    }

    #[test]
    fn test_nack_responder_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        let info = nack_stream_info(12345);

        chain.bind_local_stream(&info);
        assert!(chain.streams.contains_key(&12345));

        chain.unbind_local_stream(&info);
        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_nack_responder_no_nack_support() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let info = StreamInfo {
            ssrc: 12345,
            clock_rate: 90000,
            rtcp_feedback: vec![],
            ..Default::default()
        };
        chain.bind_local_stream(&info);

        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_nack_responder_passthrough() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let now = Instant::now();

        let mut pkt = make_rtp_packet(12345, 1, &[1]);
        pkt.now = now;
        chain.handle_write(pkt).unwrap();
        assert!(chain.poll_write().is_some());

        receive_nack(&mut chain, 12345, vec![(1, 0)], now);
        assert!(chain.poll_read().is_some());
    }

    #[test]
    fn test_nack_responder_rfc4588_rtx() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let info = StreamInfo {
            ssrc: 1,
            ssrc_rtx: Some(2),
            payload_type: 96,
            payload_type_rtx: Some(97),
            ..nack_stream_info(1)
        };
        chain.bind_local_stream(&info);

        let now = Instant::now();
        send_rtp(&mut chain, 1, &[10, 11, 12, 14, 15], now);
        receive_nack(&mut chain, 1, vec![(11, 0b1011)], now);

        let rtx_packets = drain_rtp_writes(&mut chain);
        let expected_original_seqs = [11u16, 12, 15];
        assert_eq!(rtx_packets.len(), expected_original_seqs.len());

        for (rtx_seq, (rtp, expected_original_seq)) in rtx_packets
            .iter()
            .zip(expected_original_seqs)
            .enumerate()
        {
            assert_eq!(rtp.header.ssrc, 2, "RTX packet should use RTX SSRC");
            assert_eq!(
                rtp.header.payload_type, 97,
                "RTX packet should use RTX payload type"
            );
            assert_eq!(
                usize::from(rtp.header.sequence_number),
                rtx_seq,
                "RTX seq should be {}",
                rtx_seq
            );

            assert!(
                rtp.payload.len() >= 2,
                "RTX payload should have at least 2 bytes"
            );
            let original_seq_from_payload = u16::from_be_bytes([rtp.payload[0], rtp.payload[1]]);
            assert_eq!(
                original_seq_from_payload, expected_original_seq,
                "RTX payload should contain original seq"
            );
            assert_eq!(
                rtp.payload[2..],
                [expected_original_seq as u8],
                "Original payload should follow seq number"
            );
        }
    }
}