1use super::receive_log::ReceiveLog;
4use super::stream_supports_nack;
5use crate::stream_info::StreamInfo;
6use crate::{Interceptor, Packet, TaggedPacket, interceptor};
7use shared::TransportContext;
8use shared::error::Error;
9use std::collections::{HashMap, VecDeque};
10use std::marker::PhantomData;
11use std::time::{Duration, Instant};
12
/// Collects the tuning parameters for a [`NackGeneratorInterceptor`].
///
/// `P` is the type of the next interceptor in the chain. The builder never
/// stores a `P`; it only needs the type so that `build` can return a
/// constructor accepting one.
pub struct NackGeneratorBuilder<P> {
    /// Size handed to `ReceiveLog::new` for every stream that gets tracked.
    size: u16,
    /// Spacing between two NACK generation passes.
    interval: Duration,
    /// Forwarded to `ReceiveLog::missing_seq_numbers`.
    // NOTE(review): presumably the number of most recent sequence numbers
    // excluded from loss detection — confirm against `ReceiveLog`.
    skip_last_n: u16,
    /// Maximum number of times a single sequence number is NACKed.
    /// `0` turns the limit off.
    max_nacks_per_packet: u16,
    /// Binds the builder to `P` without holding a value of that type.
    _phantom: PhantomData<P>,
}
40
41impl<P> Default for NackGeneratorBuilder<P> {
42 fn default() -> Self {
43 Self {
44 size: 512,
45 interval: Duration::from_millis(100),
46 skip_last_n: 0,
47 max_nacks_per_packet: 0,
48 _phantom: PhantomData,
49 }
50 }
51}
52
53impl<P> NackGeneratorBuilder<P> {
54 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn with_size(mut self, size: u16) -> Self {
63 self.size = size;
64 self
65 }
66
67 pub fn with_interval(mut self, interval: Duration) -> Self {
69 self.interval = interval;
70 self
71 }
72
73 pub fn with_skip_last_n(mut self, skip_last_n: u16) -> Self {
78 self.skip_last_n = skip_last_n;
79 self
80 }
81
82 pub fn with_max_nacks_per_packet(mut self, max: u16) -> Self {
86 self.max_nacks_per_packet = max;
87 self
88 }
89
90 pub fn build(self) -> impl FnOnce(P) -> NackGeneratorInterceptor<P> {
92 move |inner| {
93 NackGeneratorInterceptor::new(
94 inner,
95 self.size,
96 self.interval,
97 self.skip_last_n,
98 self.max_nacks_per_packet,
99 )
100 }
101 }
102}
103
104#[derive(Interceptor)]
110pub struct NackGeneratorInterceptor<P> {
111 #[next]
112 inner: P,
113
114 size: u16,
116 interval: Duration,
117 skip_last_n: u16,
118 max_nacks_per_packet: u16,
119
120 eto: Instant,
122
123 sender_ssrc: u32,
125
126 receive_logs: HashMap<u32, ReceiveLog>,
128
129 nack_counts: HashMap<u32, HashMap<u16, u16>>,
131
132 write_queue: VecDeque<TaggedPacket>,
134}
135
136impl<P> NackGeneratorInterceptor<P> {
137 fn new(
138 inner: P,
139 size: u16,
140 interval: Duration,
141 skip_last_n: u16,
142 max_nacks_per_packet: u16,
143 ) -> Self {
144 Self {
145 inner,
146 size,
147 interval,
148 skip_last_n,
149 max_nacks_per_packet,
150 eto: Instant::now(),
151 sender_ssrc: rand::random(),
152 receive_logs: HashMap::new(),
153 nack_counts: HashMap::new(),
154 write_queue: VecDeque::new(),
155 }
156 }
157
158 fn generate_nacks(&mut self, now: Instant) {
160 for (&ssrc, receive_log) in &self.receive_logs {
161 let missing = receive_log.missing_seq_numbers(self.skip_last_n);
162 if missing.is_empty() {
163 self.nack_counts.remove(&ssrc);
165 continue;
166 }
167
168 let nack_count = self.nack_counts.entry(ssrc).or_default();
170
171 let filtered: Vec<u16> = if self.max_nacks_per_packet > 0 {
173 missing
174 .iter()
175 .filter(|&&seq| {
176 let count = nack_count.entry(seq).or_insert(0);
177 if *count < self.max_nacks_per_packet {
178 *count += 1;
179 true
180 } else {
181 false
182 }
183 })
184 .copied()
185 .collect()
186 } else {
187 missing.clone()
188 };
189
190 if filtered.is_empty() {
191 continue;
192 }
193
194 nack_count.retain(|seq, _| missing.contains(seq));
196
197 let nack = rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack {
199 sender_ssrc: self.sender_ssrc,
200 media_ssrc: ssrc,
201 nacks: rtcp::transport_feedbacks::transport_layer_nack::nack_pairs_from_sequence_numbers(
202 &filtered,
203 ),
204 };
205
206 self.write_queue.push_back(TaggedPacket {
207 now,
208 transport: TransportContext::default(),
209 message: Packet::Rtcp(vec![Box::new(nack)]),
210 });
211 }
212 }
213}
214
215#[interceptor]
216impl<P: Interceptor> NackGeneratorInterceptor<P> {
217 #[overrides]
218 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
219 if let Packet::Rtp(ref rtp_packet) = msg.message
221 && let Some(receive_log) = self.receive_logs.get_mut(&rtp_packet.header.ssrc)
222 {
223 receive_log.add(rtp_packet.header.sequence_number);
224 }
225
226 self.inner.handle_read(msg)
227 }
228
229 #[overrides]
230 fn poll_write(&mut self) -> Option<Self::Wout> {
231 if let Some(pkt) = self.write_queue.pop_front() {
233 return Some(pkt);
234 }
235 self.inner.poll_write()
236 }
237
238 #[overrides]
239 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
240 if self.eto <= now {
241 self.eto = now + self.interval;
242 self.generate_nacks(now);
243 }
244
245 self.inner.handle_timeout(now)
246 }
247
248 #[overrides]
249 fn poll_timeout(&mut self) -> Option<Self::Time> {
250 if let Some(inner_eto) = self.inner.poll_timeout()
251 && inner_eto < self.eto
252 {
253 return Some(inner_eto);
254 }
255 Some(self.eto)
256 }
257
258 #[overrides]
259 fn bind_remote_stream(&mut self, info: &StreamInfo) {
260 if stream_supports_nack(info)
261 && let Some(receive_log) = ReceiveLog::new(self.size)
262 {
263 self.receive_logs.insert(info.ssrc, receive_log);
264 }
265 self.inner.bind_remote_stream(info);
266 }
267
268 #[overrides]
269 fn unbind_remote_stream(&mut self, info: &StreamInfo) {
270 self.receive_logs.remove(&info.ssrc);
271 self.nack_counts.remove(&info.ssrc);
272 self.inner.unbind_remote_stream(info);
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTCPFeedback;
    use rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack;
    use sansio::Protocol;

    /// SSRC shared by every stream and packet in these tests.
    const TEST_SSRC: u32 = 12345;

    /// Builds an RTP packet stamped with the current time.
    fn make_rtp_packet(ssrc: u32, seq: u16) -> TaggedPacket {
        make_rtp_packet_at(ssrc, seq, Instant::now())
    }

    /// Builds an RTP packet with only SSRC and sequence number set, stamped with `now`.
    fn make_rtp_packet_at(ssrc: u32, seq: u16, now: Instant) -> TaggedPacket {
        TaggedPacket {
            now,
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number: seq,
                    ..Default::default()
                },
                ..Default::default()
            }),
        }
    }

    /// Stream description that advertises plain `nack` feedback.
    fn nack_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            clock_rate: 90000,
            rtcp_feedback: vec![RTCPFeedback {
                typ: "nack".to_string(),
                parameter: "".to_string(),
            }],
            ..Default::default()
        }
    }

    /// Expands the NACK pairs into the individual sequence numbers they name:
    /// the packet id itself plus one entry per bit set in the bitmask.
    fn nacked_sequence_numbers(nack: &TransportLayerNack) -> Vec<u16> {
        let mut seqs = Vec::new();
        for pair in &nack.nacks {
            seqs.push(pair.packet_id);
            for bit in 0..16 {
                if pair.lost_packets & (1 << bit) != 0 {
                    seqs.push(pair.packet_id.wrapping_add(bit + 1));
                }
            }
        }
        seqs
    }

    #[test]
    fn test_nack_generator_builder_defaults() {
        let chain = Registry::new()
            .with(NackGeneratorBuilder::default().build())
            .build();

        assert_eq!(chain.size, 512);
        assert_eq!(chain.interval, Duration::from_millis(100));
        assert_eq!(chain.skip_last_n, 0);
        assert_eq!(chain.max_nacks_per_packet, 0);
    }

    #[test]
    fn test_nack_generator_builder_custom() {
        let chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(1024)
                    .with_interval(Duration::from_millis(50))
                    .with_skip_last_n(3)
                    .with_max_nacks_per_packet(5)
                    .build(),
            )
            .build();

        assert_eq!(chain.size, 1024);
        assert_eq!(chain.interval, Duration::from_millis(50));
        assert_eq!(chain.skip_last_n, 3);
        assert_eq!(chain.max_nacks_per_packet, 5);
    }

    #[test]
    fn test_nack_generator_no_nack_without_binding() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        let now = Instant::now();

        // The stream was never bound, so the gap at sequence number 1 must be ignored.
        chain.handle_read(make_rtp_packet(TEST_SSRC, 0)).unwrap();
        chain.handle_read(make_rtp_packet(TEST_SSRC, 2)).unwrap();

        let later = now + Duration::from_millis(200);
        chain.handle_timeout(later).unwrap();

        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_nack_generator_generates_nack() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        chain.bind_remote_stream(&nack_stream_info(TEST_SSRC));

        let base_time = Instant::now();

        // Sequence number 11 is never delivered.
        for seq in [10u16, 12] {
            chain
                .handle_read(make_rtp_packet_at(TEST_SSRC, seq, base_time))
                .unwrap();
        }

        chain.poll_read();

        let later = base_time + Duration::from_millis(200);
        chain.handle_timeout(later).unwrap();

        let tagged = chain
            .poll_write()
            .expect("a NACK packet should have been queued");
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };

        assert_eq!(rtcp_packets.len(), 1);
        let nack = rtcp_packets[0]
            .as_any()
            .downcast_ref::<TransportLayerNack>()
            .expect("Expected TransportLayerNack");
        assert_eq!(nack.media_ssrc, TEST_SSRC);
        assert!(!nack.nacks.is_empty());
    }

    #[test]
    fn test_nack_generator_skip_last_n() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .with_skip_last_n(2)
                    .build(),
            )
            .build();

        chain.bind_remote_stream(&nack_stream_info(TEST_SSRC));

        let base_time = Instant::now();

        // Gaps at 13, 15 and 17; the newest gap falls inside the skipped tail.
        for seq in [10u16, 11, 12, 14, 16, 18] {
            chain
                .handle_read(make_rtp_packet_at(TEST_SSRC, seq, base_time))
                .unwrap();
        }

        let later = base_time + Duration::from_millis(200);
        chain.handle_timeout(later).unwrap();

        let tagged = chain
            .poll_write()
            .expect("a NACK packet should have been queued");
        // Fail loudly instead of skipping the assertions on an unexpected packet kind.
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };

        let nack = rtcp_packets[0]
            .as_any()
            .downcast_ref::<TransportLayerNack>()
            .expect("Expected TransportLayerNack");

        let nacked_seqs = nacked_sequence_numbers(nack);

        assert!(nacked_seqs.contains(&13));
        assert!(nacked_seqs.contains(&15));
        assert!(!nacked_seqs.contains(&17));
    }

    #[test]
    fn test_nack_generator_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        let info = nack_stream_info(TEST_SSRC);

        chain.bind_remote_stream(&info);
        assert!(chain.receive_logs.contains_key(&TEST_SSRC));

        chain.unbind_remote_stream(&info);
        assert!(!chain.receive_logs.contains_key(&TEST_SSRC));
        assert!(!chain.nack_counts.contains_key(&TEST_SSRC));
    }

    #[test]
    fn test_nack_generator_no_nack_support() {
        let mut chain = Registry::new()
            .with(
                NackGeneratorBuilder::new()
                    .with_size(64)
                    .with_interval(Duration::from_millis(100))
                    .build(),
            )
            .build();

        // No feedback mechanisms advertised, so the stream must not be tracked.
        let info = StreamInfo {
            ssrc: TEST_SSRC,
            clock_rate: 90000,
            rtcp_feedback: vec![],
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        assert!(!chain.receive_logs.contains_key(&TEST_SSRC));
    }
}