1use crate::report::receiver_stream::ReceiverStream;
4use crate::stream_info::StreamInfo;
5use crate::{Interceptor, Packet, TaggedPacket};
6use shared::TransportContext;
7use shared::error::Error;
8use std::collections::{HashMap, VecDeque};
9use std::marker::PhantomData;
10use std::time::{Duration, Instant};
11
/// Builder for [`ReceiverReportInterceptor`].
///
/// The only configurable setting is the report interval. `P` is the type of
/// the inner interceptor that the closure returned by `build` will wrap.
pub struct ReceiverReportBuilder<P> {
    /// How often receiver reports are generated (one second when built via `Default`).
    interval: Duration,
    /// Ties the builder to the inner interceptor type `P` without storing a value of it.
    _phantom: PhantomData<P>,
}
35
36impl<P> Default for ReceiverReportBuilder<P> {
37 fn default() -> Self {
38 Self {
39 interval: Duration::from_secs(1),
40 _phantom: PhantomData,
41 }
42 }
43}
44
45impl<P> ReceiverReportBuilder<P> {
46 pub fn new() -> Self {
50 Self::default()
51 }
52
53 pub fn with_interval(mut self, interval: Duration) -> Self {
65 self.interval = interval;
66 self
67 }
68
69 pub fn build(self) -> impl FnOnce(P) -> ReceiverReportInterceptor<P> {
82 move |inner| ReceiverReportInterceptor::new(inner, self.interval)
83 }
84}
85
86pub struct ReceiverReportInterceptor<P> {
105 inner: P,
106
107 interval: Duration,
108 eto: Instant,
109
110 streams: HashMap<u32, ReceiverStream>,
111
112 read_queue: VecDeque<TaggedPacket>,
113 write_queue: VecDeque<TaggedPacket>,
114}
115
116impl<P> ReceiverReportInterceptor<P> {
117 fn new(inner: P, interval: Duration) -> Self {
119 Self {
120 inner,
121
122 interval,
123 eto: Instant::now(),
124
125 streams: HashMap::new(),
126
127 read_queue: VecDeque::new(),
128 write_queue: VecDeque::new(),
129 }
130 }
131
132 fn process_rtp(&mut self, now: Instant, ssrc: u32, seq: u16, timestamp: u32) {
134 let stream = self.streams.entry(ssrc).or_insert_with(|| {
136 ReceiverStream::new(ssrc, 90000)
138 });
139
140 let pkt = rtp::packet::Packet {
142 header: rtp::header::Header {
143 ssrc,
144 sequence_number: seq,
145 timestamp,
146 ..Default::default()
147 },
148 ..Default::default()
149 };
150
151 stream.process_rtp(now, &pkt);
152 }
153
154 fn process_sender_report(&mut self, now: Instant, sr: &rtcp::sender_report::SenderReport) {
156 if let Some(stream) = self.streams.get_mut(&sr.ssrc) {
157 stream.process_sender_report(now, sr);
158 }
159 }
160
161 fn generate_reports(&mut self, now: Instant) -> Vec<rtcp::receiver_report::ReceiverReport> {
163 self.streams
164 .values_mut()
165 .map(|stream| stream.generate_report(now))
166 .collect()
167 }
168
169 fn register_stream(&mut self, ssrc: u32, clock_rate: u32) {
171 self.streams
172 .entry(ssrc)
173 .or_insert_with(|| ReceiverStream::new(ssrc, clock_rate));
174 }
175}
176
177impl<P: Interceptor> sansio::Protocol<TaggedPacket, TaggedPacket, ()>
178 for ReceiverReportInterceptor<P>
179{
180 type Rout = TaggedPacket;
181 type Wout = TaggedPacket;
182 type Eout = ();
183 type Error = Error;
184 type Time = Instant;
185
186 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
187 if let Packet::Rtcp(rtcp_packets) = &msg.message {
188 for rtcp_packet in rtcp_packets {
189 if let Some(sr) = rtcp_packet
190 .as_any()
191 .downcast_ref::<rtcp::sender_report::SenderReport>()
192 && let Some(stream) = self.streams.get_mut(&sr.ssrc)
193 {
194 stream.process_sender_report(msg.now, sr);
195 }
196 }
197 } else if let Packet::Rtp(rtp_packet) = &msg.message
198 && let Some(stream) = self.streams.get_mut(&rtp_packet.header.ssrc)
199 {
200 stream.process_rtp(msg.now, rtp_packet);
201 }
202
203 self.inner.handle_read(msg)
204 }
205
206 fn poll_read(&mut self) -> Option<Self::Rout> {
207 self.inner.poll_read()
208 }
209
210 fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
211 self.inner.handle_write(msg)
212 }
213
214 fn poll_write(&mut self) -> Option<Self::Wout> {
215 if let Some(pkt) = self.write_queue.pop_front() {
217 return Some(pkt);
218 }
219 self.inner.poll_write()
220 }
221
222 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
223 if self.eto <= now {
224 self.eto = now + self.interval;
225
226 for stream in self.streams.values_mut() {
227 let rr = stream.generate_report(now);
228 self.write_queue.push_back(TaggedPacket {
229 now,
230 transport: TransportContext::default(),
231 message: Packet::Rtcp(vec![Box::new(rr)]),
232 });
233 }
234 }
235
236 self.inner.handle_timeout(now)
237 }
238
239 fn poll_timeout(&mut self) -> Option<Self::Time> {
240 if let Some(eto) = self.inner.poll_timeout()
241 && eto < self.eto
242 {
243 Some(eto)
244 } else {
245 Some(self.eto)
246 }
247 }
248}
249
250impl<P: Interceptor> Interceptor for ReceiverReportInterceptor<P> {
251 fn bind_local_stream(&mut self, info: &StreamInfo) {
252 self.inner.bind_local_stream(info);
253 }
254 fn unbind_local_stream(&mut self, info: &StreamInfo) {
255 self.inner.unbind_local_stream(info);
256 }
257 fn bind_remote_stream(&mut self, info: &StreamInfo) {
258 let stream = ReceiverStream::new(info.ssrc, info.clock_rate);
259 self.streams.insert(info.ssrc, stream);
260
261 self.inner.bind_remote_stream(info);
262 }
263 fn unbind_remote_stream(&mut self, info: &StreamInfo) {
264 self.streams.remove(&info.ssrc);
265
266 self.inner.unbind_remote_stream(info);
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use sansio::Protocol;

    /// An RTP packet with an all-default header, stamped with the current time.
    fn dummy_rtp_packet() -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: crate::Packet::Rtp(rtp::Packet::default()),
        }
    }

    /// An RTP packet from `ssrc` with the given sequence number and RTP
    /// timestamp, received at `now`.
    fn rtp_packet(now: Instant, ssrc: u32, sequence_number: u16, timestamp: u32) -> TaggedPacket {
        TaggedPacket {
            now,
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number,
                    timestamp,
                    ..Default::default()
                },
                ..Default::default()
            }),
        }
    }

    /// Asserts that `written` is an RTCP packet holding exactly one receiver
    /// report, then runs `check` against that report.
    fn check_receiver_report(
        written: Option<TaggedPacket>,
        check: impl FnOnce(&rtcp::receiver_report::ReceiverReport),
    ) {
        let tagged = written.expect("expected a queued RTCP packet");
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };
        assert_eq!(rtcp_packets.len(), 1);
        let rr = rtcp_packets[0]
            .as_any()
            .downcast_ref::<rtcp::receiver_report::ReceiverReport>()
            .expect("Expected ReceiverReport");
        check(rr);
    }

    #[test]
    fn test_receiver_report_builder_default() {
        let chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        assert_eq!(chain.interval, Duration::from_secs(1));
        assert!(chain.streams.is_empty());
    }

    #[test]
    fn test_receiver_report_builder_with_custom_interval() {
        let chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_millis(500))
                    .build(),
            )
            .build();

        assert_eq!(chain.interval, Duration::from_millis(500));
    }

    #[test]
    fn test_receiver_report_chain_handle_read_write() {
        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        // Reads are forwarded unchanged.
        let pkt = dummy_rtp_packet();
        let pkt_message = pkt.message.clone();
        chain.handle_read(pkt).unwrap();
        assert_eq!(chain.poll_read().unwrap().message, pkt_message);

        // Writes are forwarded unchanged.
        let pkt2 = dummy_rtp_packet();
        let pkt2_message = pkt2.message.clone();
        chain.handle_write(pkt2).unwrap();
        assert_eq!(chain.poll_write().unwrap().message, pkt2_message);
    }

    #[test]
    fn test_register_stream() {
        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        chain.register_stream(12345, 48000);
        assert!(chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_process_rtp() {
        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        let now = Instant::now();
        chain.process_rtp(now, 12345, 1, 1000);

        assert!(chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_generate_reports() {
        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        let now = Instant::now();
        chain.process_rtp(now, 12345, 1, 1000);
        chain.process_rtp(now, 12345, 2, 2000);

        let reports = chain.generate_reports(now);
        assert_eq!(reports.len(), 1);
    }

    #[test]
    fn test_chained_interceptors() {
        use crate::report::sender::SenderReportBuilder;

        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_millis(250))
                    .build(),
            )
            .build();

        let pkt = dummy_rtp_packet();
        let pkt_message = pkt.message.clone();
        chain.handle_read(pkt).unwrap();
        assert_eq!(chain.poll_read().unwrap().message, pkt_message);

        let pkt2 = dummy_rtp_packet();
        let pkt2_message = pkt2.message.clone();
        chain.handle_write(pkt2).unwrap();
        assert_eq!(chain.poll_write().unwrap().message, pkt2_message);
    }

    #[test]
    fn test_receiver_report_generation_on_timeout() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        let base_time = Instant::now();

        // Ten consecutive packets, no gaps.
        for seq in 0..10u16 {
            chain
                .handle_read(rtp_packet(base_time, 123456, seq, u32::from(seq) * 3000))
                .unwrap();
            chain.poll_read();
        }

        // The first deadline is already due; discard that round so the checks
        // below look at the report from the following interval.
        chain.handle_timeout(base_time).unwrap();
        while chain.poll_write().is_some() {}

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        check_receiver_report(chain.poll_write(), |rr| {
            assert_eq!(rr.reports.len(), 1);
            assert_eq!(rr.reports[0].ssrc, 123456);
            assert_eq!(rr.reports[0].last_sequence_number, 9);
            assert_eq!(rr.reports[0].fraction_lost, 0);
            assert_eq!(rr.reports[0].total_lost, 0);
        });
    }

    #[test]
    fn test_receiver_report_with_packet_loss() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        let base_time = Instant::now();

        // Sequence number 2 never arrives.
        for (seq, timestamp) in [(1, 3000), (3, 9000)] {
            chain
                .handle_read(rtp_packet(base_time, 123456, seq, timestamp))
                .unwrap();
            chain.poll_read();
        }

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        check_receiver_report(chain.poll_write(), |rr| {
            assert_eq!(rr.reports[0].last_sequence_number, 3);
            assert_eq!(rr.reports[0].total_lost, 1);
            // One of three expected packets is missing: 256 * 1 / 3.
            assert_eq!(rr.reports[0].fraction_lost, (256u32 / 3) as u8);
        });
    }

    #[test]
    fn test_receiver_report_with_sender_report() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        let base_time = Instant::now();

        chain
            .handle_read(rtp_packet(base_time, 123456, 1, 3000))
            .unwrap();
        chain.poll_read();

        let sr = rtcp::sender_report::SenderReport {
            ssrc: 123456,
            ntp_time: 0x1234_5678_0000_0000,
            rtp_time: 3000,
            packet_count: 100,
            octet_count: 10000,
            ..Default::default()
        };
        let sr_pkt = TaggedPacket {
            now: base_time,
            transport: Default::default(),
            message: Packet::Rtcp(vec![Box::new(sr)]),
        };
        chain.handle_read(sr_pkt).unwrap();

        chain
            .handle_timeout(base_time + Duration::from_secs(1))
            .unwrap();

        check_receiver_report(chain.poll_write(), |rr| {
            // DLSR is in units of 1/65536 s: one second has passed since the SR.
            assert_eq!(rr.reports[0].delay, 65536);
            // LSR is the middle 32 bits of the SR's 64-bit NTP timestamp.
            assert_eq!(rr.reports[0].last_sender_report, 0x5678_0000);
        });
    }

    #[test]
    fn test_receiver_report_multiple_streams() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info1 = StreamInfo {
            ssrc: 111111,
            clock_rate: 90000,
            ..Default::default()
        };
        let info2 = StreamInfo {
            ssrc: 222222,
            clock_rate: 48000,
            ..Default::default()
        };
        chain.bind_remote_stream(&info1);
        chain.bind_remote_stream(&info2);

        let base_time = Instant::now();

        // First stream: five consecutive packets, no loss.
        for seq in 0..5u16 {
            chain
                .handle_read(rtp_packet(base_time, 111111, seq, u32::from(seq) * 3000))
                .unwrap();
            chain.poll_read();
        }

        // Second stream: sequence numbers 1..=4 never arrive.
        for (seq, timestamp) in [(0, 0), (5, 5 * 960)] {
            chain
                .handle_read(rtp_packet(base_time, 222222, seq, timestamp))
                .unwrap();
            chain.poll_read();
        }

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        let mut ssrcs = vec![];
        let mut total_lost = vec![];

        while let Some(tagged) = chain.poll_write() {
            let Packet::Rtcp(rtcp_packets) = tagged.message else {
                continue;
            };
            for rtcp_pkt in &rtcp_packets {
                if let Some(rr) = rtcp_pkt
                    .as_any()
                    .downcast_ref::<rtcp::receiver_report::ReceiverReport>()
                {
                    for report in &rr.reports {
                        ssrcs.push(report.ssrc);
                        total_lost.push(report.total_lost);
                    }
                }
            }
        }

        assert_eq!(ssrcs.len(), 2);
        assert!(ssrcs.contains(&111111));
        assert!(ssrcs.contains(&222222));

        let idx1 = ssrcs.iter().position(|&s| s == 111111).unwrap();
        assert_eq!(total_lost[idx1], 0);

        let idx2 = ssrcs.iter().position(|&s| s == 222222).unwrap();
        assert_eq!(total_lost[idx2], 4);
    }

    #[test]
    fn test_receiver_report_unbind_stream() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        let base_time = Instant::now();

        chain.handle_read(rtp_packet(base_time, 123456, 0, 0)).unwrap();
        chain.poll_read();

        // Once unbound, the stream must no longer be reported on.
        chain.unbind_remote_stream(&info);

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_receiver_report_sequence_wrap() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        let base_time = Instant::now();

        // 0xffff followed by 0x0000 is a wrap of the 16-bit sequence number, not loss.
        for (seq, timestamp) in [(0xffff, 0), (0x0000, 3000)] {
            chain
                .handle_read(rtp_packet(base_time, 123456, seq, timestamp))
                .unwrap();
            chain.poll_read();
        }

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        check_receiver_report(chain.poll_write(), |rr| {
            // Extended highest sequence number: one cycle, sequence 0.
            assert_eq!(rr.reports[0].last_sequence_number, 1 << 16);
            assert_eq!(rr.reports[0].fraction_lost, 0);
            assert_eq!(rr.reports[0].total_lost, 0);
        });
    }
}