1use super::receive_log::ReceiveLog;
4use super::stream_supports_nack;
5use crate::stream_info::StreamInfo;
6use crate::{Interceptor, Packet, TaggedPacket};
7use shared::TransportContext;
8use shared::error::Error;
9use std::collections::{HashMap, VecDeque};
10use std::marker::PhantomData;
11use std::time::{Duration, Instant};
12
/// Collects the tunables for a [`NackGeneratorInterceptor`] before it is
/// wrapped around an inner interceptor of type `P`.
///
/// The builder never stores a `P`; the type parameter only pins down which
/// inner interceptor the closure returned by `build` accepts.
pub struct NackGeneratorBuilder<P> {
    /// Size handed to `ReceiveLog::new` for every NACK-capable remote stream.
    size: u16,
    /// Spacing between two consecutive NACK generation passes.
    interval: Duration,
    /// Forwarded to `ReceiveLog::missing_seq_numbers` on every pass.
    skip_last_n: u16,
    /// Upper bound on how often one sequence number is requested; `0` disables the cap.
    max_nacks_per_packet: u16,
    /// Zero-sized marker tying the builder to the inner interceptor type.
    _phantom: PhantomData<P>,
}
40
41impl<P> Default for NackGeneratorBuilder<P> {
42 fn default() -> Self {
43 Self {
44 size: 512,
45 interval: Duration::from_millis(100),
46 skip_last_n: 0,
47 max_nacks_per_packet: 0,
48 _phantom: PhantomData,
49 }
50 }
51}
52
53impl<P> NackGeneratorBuilder<P> {
54 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn with_size(mut self, size: u16) -> Self {
63 self.size = size;
64 self
65 }
66
67 pub fn with_interval(mut self, interval: Duration) -> Self {
69 self.interval = interval;
70 self
71 }
72
73 pub fn with_skip_last_n(mut self, skip_last_n: u16) -> Self {
78 self.skip_last_n = skip_last_n;
79 self
80 }
81
82 pub fn with_max_nacks_per_packet(mut self, max: u16) -> Self {
86 self.max_nacks_per_packet = max;
87 self
88 }
89
90 pub fn build(self) -> impl FnOnce(P) -> NackGeneratorInterceptor<P> {
92 move |inner| {
93 NackGeneratorInterceptor::new(
94 inner,
95 self.size,
96 self.interval,
97 self.skip_last_n,
98 self.max_nacks_per_packet,
99 )
100 }
101 }
102}
103
104pub struct NackGeneratorInterceptor<P> {
110 inner: P,
111
112 size: u16,
114 interval: Duration,
115 skip_last_n: u16,
116 max_nacks_per_packet: u16,
117
118 eto: Instant,
120
121 sender_ssrc: u32,
123
124 receive_logs: HashMap<u32, ReceiveLog>,
126
127 nack_counts: HashMap<u32, HashMap<u16, u16>>,
129
130 write_queue: VecDeque<TaggedPacket>,
132}
133
134impl<P> NackGeneratorInterceptor<P> {
135 fn new(
136 inner: P,
137 size: u16,
138 interval: Duration,
139 skip_last_n: u16,
140 max_nacks_per_packet: u16,
141 ) -> Self {
142 Self {
143 inner,
144 size,
145 interval,
146 skip_last_n,
147 max_nacks_per_packet,
148 eto: Instant::now(),
149 sender_ssrc: rand::random(),
150 receive_logs: HashMap::new(),
151 nack_counts: HashMap::new(),
152 write_queue: VecDeque::new(),
153 }
154 }
155
156 fn generate_nacks(&mut self, now: Instant) {
158 for (&ssrc, receive_log) in &self.receive_logs {
159 let missing = receive_log.missing_seq_numbers(self.skip_last_n);
160 if missing.is_empty() {
161 self.nack_counts.remove(&ssrc);
163 continue;
164 }
165
166 let nack_count = self.nack_counts.entry(ssrc).or_default();
168
169 let filtered: Vec<u16> = if self.max_nacks_per_packet > 0 {
171 missing
172 .iter()
173 .filter(|&&seq| {
174 let count = nack_count.entry(seq).or_insert(0);
175 if *count < self.max_nacks_per_packet {
176 *count += 1;
177 true
178 } else {
179 false
180 }
181 })
182 .copied()
183 .collect()
184 } else {
185 missing.clone()
186 };
187
188 if filtered.is_empty() {
189 continue;
190 }
191
192 nack_count.retain(|seq, _| missing.contains(seq));
194
195 let nack = rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack {
197 sender_ssrc: self.sender_ssrc,
198 media_ssrc: ssrc,
199 nacks: rtcp::transport_feedbacks::transport_layer_nack::nack_pairs_from_sequence_numbers(
200 &filtered,
201 ),
202 };
203
204 self.write_queue.push_back(TaggedPacket {
205 now,
206 transport: TransportContext::default(),
207 message: Packet::Rtcp(vec![Box::new(nack)]),
208 });
209 }
210 }
211}
212
213impl<P: Interceptor> sansio::Protocol<TaggedPacket, TaggedPacket, ()>
214 for NackGeneratorInterceptor<P>
215{
216 type Rout = TaggedPacket;
217 type Wout = TaggedPacket;
218 type Eout = ();
219 type Error = Error;
220 type Time = Instant;
221
222 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
223 if let Packet::Rtp(ref rtp_packet) = msg.message
225 && let Some(receive_log) = self.receive_logs.get_mut(&rtp_packet.header.ssrc)
226 {
227 receive_log.add(rtp_packet.header.sequence_number);
228 }
229
230 self.inner.handle_read(msg)
231 }
232
233 fn poll_read(&mut self) -> Option<Self::Rout> {
234 self.inner.poll_read()
235 }
236
237 fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
238 self.inner.handle_write(msg)
239 }
240
241 fn poll_write(&mut self) -> Option<Self::Wout> {
242 if let Some(pkt) = self.write_queue.pop_front() {
244 return Some(pkt);
245 }
246 self.inner.poll_write()
247 }
248
249 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
250 if self.eto <= now {
251 self.eto = now + self.interval;
252 self.generate_nacks(now);
253 }
254
255 self.inner.handle_timeout(now)
256 }
257
258 fn poll_timeout(&mut self) -> Option<Self::Time> {
259 if let Some(inner_eto) = self.inner.poll_timeout()
260 && inner_eto < self.eto
261 {
262 return Some(inner_eto);
263 }
264 Some(self.eto)
265 }
266}
267
268impl<P: Interceptor> Interceptor for NackGeneratorInterceptor<P> {
269 fn bind_local_stream(&mut self, info: &StreamInfo) {
270 self.inner.bind_local_stream(info);
271 }
272
273 fn unbind_local_stream(&mut self, info: &StreamInfo) {
274 self.inner.unbind_local_stream(info);
275 }
276
277 fn bind_remote_stream(&mut self, info: &StreamInfo) {
278 if stream_supports_nack(info)
279 && let Some(receive_log) = ReceiveLog::new(self.size)
280 {
281 self.receive_logs.insert(info.ssrc, receive_log);
282 }
283 self.inner.bind_remote_stream(info);
284 }
285
286 fn unbind_remote_stream(&mut self, info: &StreamInfo) {
287 self.receive_logs.remove(&info.ssrc);
288 self.nack_counts.remove(&info.ssrc);
289 self.inner.unbind_remote_stream(info);
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTCPFeedback;
    use rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack;
    use sansio::Protocol;

    /// SSRC used by every test stream.
    const TEST_SSRC: u32 = 12345;

    /// Builds an otherwise-default RTP packet with the given SSRC, sequence
    /// number and receive timestamp.
    fn make_rtp_packet_at(ssrc: u32, seq: u16, now: Instant) -> TaggedPacket {
        let header = rtp::header::Header {
            ssrc,
            sequence_number: seq,
            ..Default::default()
        };

        TaggedPacket {
            now,
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header,
                ..Default::default()
            }),
        }
    }

    /// Same as [`make_rtp_packet_at`], stamped with the current time.
    fn make_rtp_packet(ssrc: u32, seq: u16) -> TaggedPacket {
        make_rtp_packet_at(ssrc, seq, Instant::now())
    }

    /// Describes a 90 kHz stream that advertises the given RTCP feedback.
    fn stream_info_with_feedback(ssrc: u32, rtcp_feedback: Vec<RTCPFeedback>) -> StreamInfo {
        StreamInfo {
            ssrc,
            clock_rate: 90000,
            rtcp_feedback,
            ..Default::default()
        }
    }

    /// Describes a 90 kHz stream that advertises plain `nack` feedback.
    fn nack_stream_info(ssrc: u32) -> StreamInfo {
        stream_info_with_feedback(
            ssrc,
            vec![RTCPFeedback {
                typ: "nack".to_string(),
                parameter: "".to_string(),
            }],
        )
    }

    #[test]
    fn test_nack_generator_builder_defaults() {
        let chain = Registry::new()
            .with(NackGeneratorBuilder::default().build())
            .build();

        assert_eq!(chain.size, 512);
        assert_eq!(chain.interval, Duration::from_millis(100));
        assert_eq!(chain.skip_last_n, 0);
        assert_eq!(chain.max_nacks_per_packet, 0);
    }

    #[test]
    fn test_nack_generator_builder_custom() {
        let chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(1024)
                    .with_interval(Duration::from_millis(50))
                    .with_skip_last_n(3)
                    .with_max_nacks_per_packet(5)
                    .build(),
            )
            .build();

        assert_eq!(chain.size, 1024);
        assert_eq!(chain.interval, Duration::from_millis(50));
        assert_eq!(chain.skip_last_n, 3);
        assert_eq!(chain.max_nacks_per_packet, 5);
    }

    #[test]
    fn test_nack_generator_no_nack_without_binding() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        let now = Instant::now();

        // The stream was never bound, so the gap at sequence number 1 must go
        // unnoticed.
        for seq in [0u16, 2] {
            chain.handle_read(make_rtp_packet(TEST_SSRC, seq)).unwrap();
        }

        chain
            .handle_timeout(now + Duration::from_millis(200))
            .unwrap();

        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_nack_generator_generates_nack() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        chain.bind_remote_stream(&nack_stream_info(TEST_SSRC));

        let base_time = Instant::now();

        // Sequence number 11 is never delivered.
        for seq in [10u16, 12] {
            chain
                .handle_read(make_rtp_packet_at(TEST_SSRC, seq, base_time))
                .unwrap();
        }

        chain.poll_read();

        chain
            .handle_timeout(base_time + Duration::from_millis(200))
            .unwrap();

        let tagged = chain
            .poll_write()
            .expect("a NACK should be queued after the timeout");
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };

        assert_eq!(rtcp_packets.len(), 1);
        let nack = rtcp_packets[0]
            .as_any()
            .downcast_ref::<TransportLayerNack>()
            .expect("Expected TransportLayerNack");
        assert_eq!(nack.media_ssrc, TEST_SSRC);
        assert!(!nack.nacks.is_empty());
    }

    #[test]
    fn test_nack_generator_skip_last_n() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .with_skip_last_n(2)
                    .build(),
            )
            .build();

        chain.bind_remote_stream(&nack_stream_info(TEST_SSRC));

        let base_time = Instant::now();

        // Sequence numbers 13, 15 and 17 are never delivered.
        for seq in [10u16, 11, 12, 14, 16, 18] {
            chain
                .handle_read(make_rtp_packet_at(TEST_SSRC, seq, base_time))
                .unwrap();
        }

        chain
            .handle_timeout(base_time + Duration::from_millis(200))
            .unwrap();

        let tagged = chain
            .poll_write()
            .expect("a NACK should be queued after the timeout");
        // Fail loudly on anything but RTCP; otherwise the assertions below
        // would be skipped and the test would pass without checking anything.
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };

        let nack = rtcp_packets[0]
            .as_any()
            .downcast_ref::<TransportLayerNack>()
            .expect("Expected TransportLayerNack");

        // Expand every (packet_id, bitmask) pair into plain sequence numbers:
        // bit `i` of the mask stands for `packet_id + i + 1`.
        let mut nacked_seqs = Vec::new();
        for nack_pair in &nack.nacks {
            nacked_seqs.push(nack_pair.packet_id);
            for i in 0..16 {
                if nack_pair.lost_packets & (1 << i) != 0 {
                    nacked_seqs.push(nack_pair.packet_id.wrapping_add(i + 1));
                }
            }
        }

        assert!(nacked_seqs.contains(&13));
        assert!(nacked_seqs.contains(&15));
        // 17 must not be requested with `skip_last_n = 2` (presumably because
        // it lies within the last two sequence numbers before 18).
        assert!(!nacked_seqs.contains(&17));
    }

    #[test]
    fn test_nack_generator_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        let info = nack_stream_info(TEST_SSRC);

        chain.bind_remote_stream(&info);
        assert!(chain.receive_logs.contains_key(&TEST_SSRC));

        chain.unbind_remote_stream(&info);
        assert!(!chain.receive_logs.contains_key(&TEST_SSRC));
        assert!(!chain.nack_counts.contains_key(&TEST_SSRC));
    }

    #[test]
    fn test_nack_generator_no_nack_support() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        // No RTCP feedback advertised, so the stream must not be tracked.
        let info = stream_info_with_feedback(TEST_SSRC, vec![]);
        chain.bind_remote_stream(&info);

        assert!(!chain.receive_logs.contains_key(&TEST_SSRC));
    }
}