1use super::send_buffer::SendBuffer;
4use super::stream_supports_nack;
5use crate::stream_info::StreamInfo;
6use crate::{Interceptor, Packet, TaggedPacket, interceptor};
7use shared::TransportContext;
8use shared::error::Error;
9use std::collections::{HashMap, VecDeque};
10use std::marker::PhantomData;
11use std::time::Instant;
12
13pub struct NackResponderBuilder<P> {
27 size: u16,
29 _phantom: PhantomData<P>,
30}
31
32impl<P> Default for NackResponderBuilder<P> {
33 fn default() -> Self {
34 Self {
35 size: 1024,
36 _phantom: PhantomData,
37 }
38 }
39}
40
41impl<P> NackResponderBuilder<P> {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn with_size(mut self, size: u16) -> Self {
52 self.size = size;
53 self
54 }
55
56 pub fn build(self) -> impl FnOnce(P) -> NackResponderInterceptor<P> {
58 move |inner| NackResponderInterceptor::new(inner, self.size)
59 }
60}
61
62struct LocalStream {
64 send_buffer: SendBuffer,
66 ssrc_rtx: Option<u32>,
68 payload_type_rtx: Option<u8>,
70 rtx_sequence_number: u16,
72}
73
/// Interceptor that answers inbound RTCP Transport-Layer NACKs by retransmitting
/// buffered outbound RTP packets, optionally re-wrapped as RFC 4588 RTX packets
/// when the bound stream has an RTX SSRC and payload type.
#[derive(Interceptor)]
pub struct NackResponderInterceptor<P> {
    /// Next interceptor in the chain; every overridden hook forwards to it.
    /// (Presumably the derive also forwards the non-overridden hooks — confirm
    /// against the `Interceptor` derive.)
    #[next]
    inner: P,

    /// Capacity passed to `SendBuffer::new` for each NACK-capable stream bound.
    size: u16,

    /// State for bound local streams that support NACK, keyed by media SSRC.
    streams: HashMap<u32, LocalStream>,

    /// Retransmissions waiting to be returned by `poll_write`, ahead of `inner`'s output.
    write_queue: VecDeque<TaggedPacket>,
}
92
93impl<P> NackResponderInterceptor<P> {
94 fn new(inner: P, size: u16) -> Self {
95 Self {
96 inner,
97 size,
98 streams: HashMap::new(),
99 write_queue: VecDeque::new(),
100 }
101 }
102
103 fn handle_nack(
105 &mut self,
106 now: Instant,
107 nack: &rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack,
108 ) {
109 let mut seqs_to_retransmit = Vec::new();
111
112 for nack_pair in &nack.nacks {
113 seqs_to_retransmit.push(nack_pair.packet_id);
115
116 for i in 0..16 {
118 if nack_pair.lost_packets & (1 << i) != 0 {
119 let seq = nack_pair.packet_id.wrapping_add(i + 1);
120 seqs_to_retransmit.push(seq);
121 }
122 }
123 }
124
125 let Some(stream) = self.streams.get_mut(&nack.media_ssrc) else {
126 return;
127 };
128
129 for seq in seqs_to_retransmit {
131 let Some(original_packet) = stream.send_buffer.get(seq) else {
132 continue;
133 };
134
135 let packet = if let (Some(ssrc_rtx), Some(pt_rtx)) =
136 (stream.ssrc_rtx, stream.payload_type_rtx)
137 {
138 let original_seq = original_packet.header.sequence_number;
143 let mut rtx_payload = Vec::with_capacity(2 + original_packet.payload.len());
144 rtx_payload.extend_from_slice(&original_seq.to_be_bytes());
145 rtx_payload.extend_from_slice(&original_packet.payload);
146
147 let rtx_seq = stream.rtx_sequence_number;
148 stream.rtx_sequence_number = stream.rtx_sequence_number.wrapping_add(1);
149
150 rtp::Packet {
151 header: rtp::header::Header {
152 ssrc: ssrc_rtx,
153 payload_type: pt_rtx,
154 sequence_number: rtx_seq,
155 timestamp: original_packet.header.timestamp,
156 marker: original_packet.header.marker,
157 ..Default::default()
158 },
159 payload: rtx_payload.into(),
160 }
161 } else {
162 original_packet.clone()
164 };
165
166 self.write_queue.push_back(TaggedPacket {
167 now,
168 transport: TransportContext::default(),
169 message: Packet::Rtp(packet),
170 });
171 }
172 }
173}
174
175#[interceptor]
176impl<P: Interceptor> NackResponderInterceptor<P> {
177 #[overrides]
178 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
179 if let Packet::Rtcp(ref rtcp_packets) = msg.message {
181 for rtcp_packet in rtcp_packets {
182 if let Some(nack) = rtcp_packet
183 .as_any()
184 .downcast_ref::<rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack>()
185 {
186 self.handle_nack(msg.now, nack);
187 }
188 }
189 }
190
191 self.inner.handle_read(msg)
192 }
193
194 #[overrides]
195 fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
196 if let Packet::Rtp(ref rtp_packet) = msg.message
198 && let Some(stream) = self.streams.get_mut(&rtp_packet.header.ssrc)
199 {
200 stream.send_buffer.add(rtp_packet.clone());
201 }
202
203 self.inner.handle_write(msg)
204 }
205
206 #[overrides]
207 fn poll_write(&mut self) -> Option<Self::Wout> {
208 if let Some(pkt) = self.write_queue.pop_front() {
210 return Some(pkt);
211 }
212 self.inner.poll_write()
213 }
214
215 #[overrides]
216 fn bind_local_stream(&mut self, info: &StreamInfo) {
217 if stream_supports_nack(info)
218 && let Some(send_buffer) = SendBuffer::new(self.size)
219 {
220 self.streams.insert(
221 info.ssrc,
222 LocalStream {
223 send_buffer,
224 ssrc_rtx: info.ssrc_rtx,
225 payload_type_rtx: info.payload_type_rtx,
226 rtx_sequence_number: 0,
227 },
228 );
229 }
230 self.inner.bind_local_stream(info);
231 }
232
233 #[overrides]
234 fn unbind_local_stream(&mut self, info: &StreamInfo) {
235 self.streams.remove(&info.ssrc);
236 self.inner.unbind_local_stream(info);
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTCPFeedback;
    use sansio::Protocol;

    /// Builds an outbound RTP packet with the given SSRC, sequence number and payload.
    fn make_rtp_packet(ssrc: u32, seq: u16, payload: &[u8]) -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number: seq,
                    ..Default::default()
                },
                payload: payload.to_vec().into(),
            }),
        }
    }

    /// Builds an inbound RTCP Transport-Layer NACK from `(packet_id, lost_packets)` pairs.
    fn make_nack_packet(sender_ssrc: u32, media_ssrc: u32, nacks: Vec<(u16, u16)>) -> TaggedPacket {
        let nack_pairs: Vec<rtcp::transport_feedbacks::transport_layer_nack::NackPair> = nacks
            .into_iter()
            .map(|(packet_id, lost_packets)| {
                rtcp::transport_feedbacks::transport_layer_nack::NackPair {
                    packet_id,
                    lost_packets,
                }
            })
            .collect();

        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtcp(vec![Box::new(
                rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack {
                    sender_ssrc,
                    media_ssrc,
                    nacks: nack_pairs,
                },
            )]),
        }
    }

    /// `StreamInfo` for `ssrc` that advertises plain `nack` feedback (no RTX).
    fn nack_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            clock_rate: 90000,
            rtcp_feedback: vec![RTCPFeedback {
                typ: "nack".to_string(),
                parameter: "".to_string(),
            }],
            ..Default::default()
        }
    }

    #[test]
    fn test_nack_responder_builder_defaults() {
        let chain = Registry::new()
            .with(NackResponderBuilder::default().build())
            .build();

        assert_eq!(chain.size, 1024);
    }

    #[test]
    fn test_nack_responder_builder_custom() {
        let chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(2048).build())
            .build();

        assert_eq!(chain.size, 2048);
    }

    #[test]
    fn test_nack_responder_retransmits_packet() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        chain.bind_local_stream(&nack_stream_info(12345));

        let now = Instant::now();

        // Seq 13 is deliberately never sent.
        for seq in [10u16, 11, 12, 14, 15] {
            let mut pkt = make_rtp_packet(12345, seq, &[seq as u8]);
            pkt.now = now;
            chain.handle_write(pkt).unwrap();
            // Drain the pass-through copy so only retransmissions remain queued.
            chain.poll_write();
        }

        // packet_id 11 with BLP 0b1011 requests 11, 12, 13 and 15.
        let mut nack = make_nack_packet(999, 12345, vec![(11, 0b1011)]);
        nack.now = now;
        chain.handle_read(nack).unwrap();

        let mut retransmitted = Vec::new();
        while let Some(pkt) = chain.poll_write() {
            if let Packet::Rtp(rtp) = pkt.message {
                retransmitted.push((rtp.header.sequence_number, rtp.payload.to_vec()));
            }
        }

        // 13 was never sent and is skipped; the others are resent unchanged, once
        // each, in NACK order.
        assert_eq!(
            retransmitted,
            vec![(11, vec![11u8]), (12, vec![12]), (15, vec![15])]
        );
    }

    #[test]
    fn test_nack_responder_no_retransmit_without_binding() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let now = Instant::now();

        for seq in [10u16, 11, 12] {
            let mut pkt = make_rtp_packet(12345, seq, &[seq as u8]);
            pkt.now = now;
            chain.handle_write(pkt).unwrap();
            chain.poll_write();
        }

        let mut nack = make_nack_packet(999, 12345, vec![(11, 0)]);
        nack.now = now;
        chain.handle_read(nack).unwrap();

        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_nack_responder_no_retransmit_expired_packet() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        chain.bind_local_stream(&nack_stream_info(12345));

        let now = Instant::now();

        for seq in 0..16u16 {
            let mut pkt = make_rtp_packet(12345, seq, &[seq as u8]);
            pkt.now = now;
            chain.handle_write(pkt).unwrap();
            chain.poll_write();
        }

        // Seq 0 is older than the 8 most recent packets and must not be resent.
        let mut nack = make_nack_packet(999, 12345, vec![(0, 0)]);
        nack.now = now;
        chain.handle_read(nack).unwrap();

        assert!(chain.poll_write().is_none());

        // Seq 10 is still within the buffer window.
        let mut nack = make_nack_packet(999, 12345, vec![(10, 0)]);
        nack.now = now;
        chain.handle_read(nack).unwrap();

        let Some(TaggedPacket {
            message: Packet::Rtp(rtp),
            ..
        }) = chain.poll_write()
        else {
            panic!("expected seq 10 to be retransmitted as an RTP packet");
        };
        assert_eq!(rtp.header.sequence_number, 10);
        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_nack_responder_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let info = nack_stream_info(12345);

        chain.bind_local_stream(&info);
        assert!(chain.streams.contains_key(&12345));

        chain.unbind_local_stream(&info);
        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_nack_responder_no_nack_support() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let info = StreamInfo {
            rtcp_feedback: vec![],
            ..nack_stream_info(12345)
        };
        chain.bind_local_stream(&info);

        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_nack_responder_passthrough() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let now = Instant::now();

        let mut pkt = make_rtp_packet(12345, 1, &[1]);
        pkt.now = now;
        chain.handle_write(pkt).unwrap();
        let out = chain.poll_write();
        assert!(out.is_some());

        let mut nack = make_nack_packet(999, 12345, vec![(1, 0)]);
        nack.now = now;
        chain.handle_read(nack).unwrap();
        let out = chain.poll_read();
        assert!(out.is_none());
    }

    #[test]
    fn test_nack_responder_rfc4588_rtx() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let info = StreamInfo {
            ssrc_rtx: Some(2),
            payload_type: 96,
            payload_type_rtx: Some(97),
            ..nack_stream_info(1)
        };
        chain.bind_local_stream(&info);

        let now = Instant::now();

        for seq in [10u16, 11, 12, 14, 15] {
            let mut pkt = make_rtp_packet(1, seq, &[seq as u8]);
            pkt.now = now;
            chain.handle_write(pkt).unwrap();
            chain.poll_write();
        }

        let mut nack = make_nack_packet(999, 1, vec![(11, 0b1011)]);
        nack.now = now;
        chain.handle_read(nack).unwrap();

        // RTX sequence numbers start at 0 and increase by one per retransmission.
        let mut rtx_seq = 0u16;
        for expected_original_seq in [11u16, 12, 15] {
            let Some(TaggedPacket {
                message: Packet::Rtp(rtp),
                ..
            }) = chain.poll_write()
            else {
                panic!("Expected RTX packet for seq {}", expected_original_seq);
            };

            assert_eq!(rtp.header.ssrc, 2, "RTX packet should use RTX SSRC");
            assert_eq!(
                rtp.header.payload_type, 97,
                "RTX packet should use RTX payload type"
            );
            assert_eq!(
                rtp.header.sequence_number, rtx_seq,
                "RTX seq should be {}",
                rtx_seq
            );
            rtx_seq += 1;

            assert!(
                rtp.payload.len() >= 2,
                "RTX payload should have at least 2 bytes"
            );
            let original_seq_from_payload = u16::from_be_bytes([rtp.payload[0], rtp.payload[1]]);
            assert_eq!(
                original_seq_from_payload, expected_original_seq,
                "RTX payload should contain original seq"
            );
            assert_eq!(
                rtp.payload[2..],
                [expected_original_seq as u8],
                "Original payload should follow seq number"
            );
        }

        assert!(chain.poll_write().is_none());
    }
}