1use crate::report::receiver_stream::ReceiverStream;
4use crate::stream_info::StreamInfo;
5use crate::{Interceptor, Packet, TaggedPacket, interceptor};
6use shared::TransportContext;
7use shared::error::Error;
8use std::collections::{HashMap, VecDeque};
9use std::marker::PhantomData;
10use std::time::{Duration, Instant};
11
/// Builder for [`ReceiverReportInterceptor`].
///
/// `P` is the type of the next interceptor in the chain. It is fixed when the
/// closure returned by `build` is invoked with that interceptor.
pub struct ReceiverReportBuilder<P> {
    /// Marker tying the builder to the inner interceptor type `P`.
    _phantom: PhantomData<P>,
    /// Period between two rounds of receiver reports.
    interval: Duration,
}
35
36impl<P> Default for ReceiverReportBuilder<P> {
37 fn default() -> Self {
38 Self {
39 interval: Duration::from_secs(1),
40 _phantom: PhantomData,
41 }
42 }
43}
44
45impl<P> ReceiverReportBuilder<P> {
46 pub fn new() -> Self {
50 Self::default()
51 }
52
53 pub fn with_interval(mut self, interval: Duration) -> Self {
65 self.interval = interval;
66 self
67 }
68
69 pub fn build(self) -> impl FnOnce(P) -> ReceiverReportInterceptor<P> {
82 move |inner| ReceiverReportInterceptor::new(inner, self.interval)
83 }
84}
85
86#[derive(Interceptor)]
105pub struct ReceiverReportInterceptor<P> {
106 #[next]
107 inner: P,
108
109 interval: Duration,
110 eto: Instant,
111
112 streams: HashMap<u32, ReceiverStream>,
113
114 read_queue: VecDeque<TaggedPacket>,
115 write_queue: VecDeque<TaggedPacket>,
116}
117
118impl<P> ReceiverReportInterceptor<P> {
119 fn new(inner: P, interval: Duration) -> Self {
121 Self {
122 inner,
123
124 interval,
125 eto: Instant::now(),
126
127 streams: HashMap::new(),
128
129 read_queue: VecDeque::new(),
130 write_queue: VecDeque::new(),
131 }
132 }
133
134 fn process_rtp(&mut self, now: Instant, ssrc: u32, seq: u16, timestamp: u32) {
136 let stream = self.streams.entry(ssrc).or_insert_with(|| {
138 ReceiverStream::new(ssrc, 90000)
140 });
141
142 let pkt = rtp::packet::Packet {
144 header: rtp::header::Header {
145 ssrc,
146 sequence_number: seq,
147 timestamp,
148 ..Default::default()
149 },
150 ..Default::default()
151 };
152
153 stream.process_rtp(now, &pkt);
154 }
155
156 fn process_sender_report(&mut self, now: Instant, sr: &rtcp::sender_report::SenderReport) {
158 if let Some(stream) = self.streams.get_mut(&sr.ssrc) {
159 stream.process_sender_report(now, sr);
160 }
161 }
162
163 fn generate_reports(&mut self, now: Instant) -> Vec<rtcp::receiver_report::ReceiverReport> {
165 self.streams
166 .values_mut()
167 .map(|stream| stream.generate_report(now))
168 .collect()
169 }
170
171 fn register_stream(&mut self, ssrc: u32, clock_rate: u32) {
173 self.streams
174 .entry(ssrc)
175 .or_insert_with(|| ReceiverStream::new(ssrc, clock_rate));
176 }
177}
178
179#[interceptor]
180impl<P: Interceptor> ReceiverReportInterceptor<P> {
181 #[overrides]
182 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
183 if let Packet::Rtcp(rtcp_packets) = &msg.message {
184 for rtcp_packet in rtcp_packets {
185 if let Some(sr) = rtcp_packet
186 .as_any()
187 .downcast_ref::<rtcp::sender_report::SenderReport>()
188 && let Some(stream) = self.streams.get_mut(&sr.ssrc)
189 {
190 stream.process_sender_report(msg.now, sr);
191 }
192 }
193 } else if let Packet::Rtp(rtp_packet) = &msg.message
194 && let Some(stream) = self.streams.get_mut(&rtp_packet.header.ssrc)
195 {
196 stream.process_rtp(msg.now, rtp_packet);
197 }
198
199 self.inner.handle_read(msg)
200 }
201
202 #[overrides]
203 fn poll_write(&mut self) -> Option<Self::Wout> {
204 if let Some(pkt) = self.write_queue.pop_front() {
206 return Some(pkt);
207 }
208 self.inner.poll_write()
209 }
210
211 #[overrides]
212 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
213 if self.eto <= now {
214 self.eto = now + self.interval;
215
216 for stream in self.streams.values_mut() {
217 let rr = stream.generate_report(now);
218 self.write_queue.push_back(TaggedPacket {
219 now,
220 transport: TransportContext::default(),
221 message: Packet::Rtcp(vec![Box::new(rr)]),
222 });
223 }
224 }
225
226 self.inner.handle_timeout(now)
227 }
228
229 #[overrides]
230 fn poll_timeout(&mut self) -> Option<Self::Time> {
231 if let Some(eto) = self.inner.poll_timeout()
232 && eto < self.eto
233 {
234 Some(eto)
235 } else {
236 Some(self.eto)
237 }
238 }
239
240 #[overrides]
241 fn bind_remote_stream(&mut self, info: &StreamInfo) {
242 let stream = ReceiverStream::new(info.ssrc, info.clock_rate);
243 self.streams.insert(info.ssrc, stream);
244
245 self.inner.bind_remote_stream(info);
246 }
247
248 #[overrides]
249 fn unbind_remote_stream(&mut self, info: &StreamInfo) {
250 self.streams.remove(&info.ssrc);
251
252 self.inner.unbind_remote_stream(info);
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use sansio::Protocol;

    /// An RTP packet with every field defaulted, stamped with the current time.
    fn dummy_rtp_packet() -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: crate::Packet::Rtp(rtp::Packet::default()),
        }
    }

    /// Builds an inbound RTP packet for `ssrc` with the given sequence number
    /// and RTP timestamp.
    fn rtp_packet(now: Instant, ssrc: u32, sequence_number: u16, timestamp: u32) -> TaggedPacket {
        TaggedPacket {
            now,
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number,
                    timestamp,
                    ..Default::default()
                },
                ..Default::default()
            }),
        }
    }

    /// Description of a remote stream to bind in the tests.
    fn stream_info(ssrc: u32, clock_rate: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            clock_rate,
            ..Default::default()
        }
    }

    #[test]
    fn test_receiver_report_builder_default() {
        let chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        assert_eq!(chain.interval, Duration::from_secs(1));
        assert!(chain.streams.is_empty());
    }

    #[test]
    fn test_receiver_report_builder_with_custom_interval() {
        let chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_millis(500))
                    .build(),
            )
            .build();

        assert_eq!(chain.interval, Duration::from_millis(500));
    }

    #[test]
    fn test_receiver_report_chain_handle_read_write() {
        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        chain.handle_read(dummy_rtp_packet()).unwrap();
        assert!(chain.poll_read().is_none());

        let pkt2 = dummy_rtp_packet();
        let pkt2_message = pkt2.message.clone();
        chain.handle_write(pkt2).unwrap();
        assert_eq!(chain.poll_write().unwrap().message, pkt2_message);
    }

    #[test]
    fn test_register_stream() {
        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        chain.register_stream(12345, 48000);
        assert!(chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_process_rtp() {
        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        let now = Instant::now();
        chain.process_rtp(now, 12345, 1, 1000);

        assert!(chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_generate_reports() {
        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        let now = Instant::now();
        chain.process_rtp(now, 12345, 1, 1000);
        chain.process_rtp(now, 12345, 2, 2000);

        let reports = chain.generate_reports(now);
        assert_eq!(reports.len(), 1);
    }

    #[test]
    fn test_chained_interceptors() {
        use crate::report::sender::SenderReportBuilder;

        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_millis(250))
                    .build(),
            )
            .build();

        chain.handle_read(dummy_rtp_packet()).unwrap();
        assert!(chain.poll_read().is_none());

        let pkt2 = dummy_rtp_packet();
        let pkt2_message = pkt2.message.clone();
        chain.handle_write(pkt2).unwrap();
        assert_eq!(chain.poll_write().unwrap().message, pkt2_message);
    }

    #[test]
    fn test_receiver_report_generation_on_timeout() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();
        chain.bind_remote_stream(&stream_info(123456, 90000));

        let base_time = Instant::now();
        for i in 0..10u16 {
            chain
                .handle_read(rtp_packet(base_time, 123456, i, u32::from(i) * 3000))
                .unwrap();
            chain.poll_read();
        }

        // Flush the initial report round.
        chain.handle_timeout(base_time).unwrap();
        while chain.poll_write().is_some() {}

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        let tagged = chain.poll_write().expect("expected a receiver report");
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };
        assert_eq!(rtcp_packets.len(), 1);
        let rr = rtcp_packets[0]
            .as_any()
            .downcast_ref::<rtcp::receiver_report::ReceiverReport>()
            .expect("Expected ReceiverReport");
        assert_eq!(rr.reports.len(), 1);
        assert_eq!(rr.reports[0].ssrc, 123456);
        assert_eq!(rr.reports[0].last_sequence_number, 9);
        assert_eq!(rr.reports[0].fraction_lost, 0);
        assert_eq!(rr.reports[0].total_lost, 0);
    }

    #[test]
    fn test_receiver_report_with_packet_loss() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();
        chain.bind_remote_stream(&stream_info(123456, 90000));

        let base_time = Instant::now();
        // Sequence number 2 never arrives.
        for (seq, ts) in [(1u16, 3000u32), (3, 9000)] {
            chain
                .handle_read(rtp_packet(base_time, 123456, seq, ts))
                .unwrap();
            chain.poll_read();
        }

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        let tagged = chain.poll_write().expect("expected a receiver report");
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };
        let rr = rtcp_packets[0]
            .as_any()
            .downcast_ref::<rtcp::receiver_report::ReceiverReport>()
            .expect("Expected ReceiverReport");
        assert_eq!(rr.reports[0].last_sequence_number, 3);
        assert_eq!(rr.reports[0].total_lost, 1);
        // One of three expected packets lost, as an 8-bit fixed-point fraction.
        assert_eq!(rr.reports[0].fraction_lost, (256u32 / 3) as u8);
    }

    #[test]
    fn test_receiver_report_with_sender_report() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();
        chain.bind_remote_stream(&stream_info(123456, 90000));

        let base_time = Instant::now();
        chain
            .handle_read(rtp_packet(base_time, 123456, 1, 3000))
            .unwrap();
        chain.poll_read();

        let sr = rtcp::sender_report::SenderReport {
            ssrc: 123456,
            ntp_time: 0x1234_5678_0000_0000,
            rtp_time: 3000,
            packet_count: 100,
            octet_count: 10000,
            ..Default::default()
        };
        chain
            .handle_read(TaggedPacket {
                now: base_time,
                transport: Default::default(),
                message: Packet::Rtcp(vec![Box::new(sr)]),
            })
            .unwrap();

        chain
            .handle_timeout(base_time + Duration::from_secs(1))
            .unwrap();

        let tagged = chain.poll_write().expect("expected a receiver report");
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };
        let rr = rtcp_packets[0]
            .as_any()
            .downcast_ref::<rtcp::receiver_report::ReceiverReport>()
            .expect("Expected ReceiverReport");
        // One second elapsed since the SR, expressed in units of 1/65536 s.
        assert_eq!(rr.reports[0].delay, 65536);
        // Middle 32 bits of the SR's 64-bit NTP timestamp.
        assert_eq!(rr.reports[0].last_sender_report, 0x5678_0000);
    }

    #[test]
    fn test_receiver_report_multiple_streams() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();
        chain.bind_remote_stream(&stream_info(111111, 90000));
        chain.bind_remote_stream(&stream_info(222222, 48000));

        let base_time = Instant::now();
        for i in 0..5u16 {
            chain
                .handle_read(rtp_packet(base_time, 111111, i, u32::from(i) * 3000))
                .unwrap();
            chain.poll_read();
        }
        // Sequence numbers 1..=4 of the second stream never arrive.
        for (seq, ts) in [(0u16, 0u32), (5, 5 * 960)] {
            chain
                .handle_read(rtp_packet(base_time, 222222, seq, ts))
                .unwrap();
            chain.poll_read();
        }

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        let mut ssrcs = vec![];
        let mut total_lost = vec![];
        while let Some(tagged) = chain.poll_write() {
            let Packet::Rtcp(rtcp_packets) = tagged.message else {
                continue;
            };
            for rtcp_pkt in rtcp_packets {
                if let Some(rr) = rtcp_pkt
                    .as_any()
                    .downcast_ref::<rtcp::receiver_report::ReceiverReport>()
                {
                    for report in &rr.reports {
                        ssrcs.push(report.ssrc);
                        total_lost.push(report.total_lost);
                    }
                }
            }
        }

        assert_eq!(ssrcs.len(), 2);
        assert!(ssrcs.contains(&111111));
        assert!(ssrcs.contains(&222222));

        let idx1 = ssrcs.iter().position(|&s| s == 111111).unwrap();
        assert_eq!(total_lost[idx1], 0);

        let idx2 = ssrcs.iter().position(|&s| s == 222222).unwrap();
        assert_eq!(total_lost[idx2], 4);
    }

    #[test]
    fn test_receiver_report_unbind_stream() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();
        let info = stream_info(123456, 90000);
        chain.bind_remote_stream(&info);

        let base_time = Instant::now();
        chain
            .handle_read(rtp_packet(base_time, 123456, 0, 0))
            .unwrap();
        chain.poll_read();

        chain.unbind_remote_stream(&info);

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        // With the only stream unbound, no report is produced.
        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_receiver_report_sequence_wrap() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();
        chain.bind_remote_stream(&stream_info(123456, 90000));

        let base_time = Instant::now();
        // 0xffff followed by 0x0000 wraps the 16-bit sequence number once.
        for (seq, ts) in [(0xffffu16, 0u32), (0x0000, 3000)] {
            chain
                .handle_read(rtp_packet(base_time, 123456, seq, ts))
                .unwrap();
            chain.poll_read();
        }

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        let tagged = chain.poll_write().expect("expected a receiver report");
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };
        let rr = rtcp_packets[0]
            .as_any()
            .downcast_ref::<rtcp::receiver_report::ReceiverReport>()
            .expect("Expected ReceiverReport");
        // Extended highest sequence number: one wrap cycle, low 16 bits = 0.
        assert_eq!(rr.reports[0].last_sequence_number, 1 << 16);
        assert_eq!(rr.reports[0].fraction_lost, 0);
        assert_eq!(rr.reports[0].total_lost, 0);
    }
}