1use crate::report::receiver_stream::ReceiverStream;
4use crate::stream_info::StreamInfo;
5use crate::{Interceptor, Packet, TaggedPacket};
6use shared::TransportContext;
7use shared::error::Error;
8use std::collections::{HashMap, VecDeque};
9use std::marker::PhantomData;
10use std::time::{Duration, Instant};
11
/// Configures and constructs a [`ReceiverReportInterceptor`].
///
/// `P` is the type of the inner interceptor that the closure returned by
/// `build` will wrap. It is only carried at the type level, which is why the
/// builder stores a `PhantomData<P>` instead of a value.
pub struct ReceiverReportBuilder<P> {
    _phantom: PhantomData<P>,
    /// Period between receiver-report rounds (one second unless changed with
    /// `with_interval`).
    interval: Duration,
}
35
36impl<P> Default for ReceiverReportBuilder<P> {
37 fn default() -> Self {
38 Self {
39 interval: Duration::from_secs(1),
40 _phantom: PhantomData,
41 }
42 }
43}
44
45impl<P> ReceiverReportBuilder<P> {
46 pub fn new() -> Self {
50 Self::default()
51 }
52
53 pub fn with_interval(mut self, interval: Duration) -> Self {
65 self.interval = interval;
66 self
67 }
68
69 pub fn build(self) -> impl FnOnce(P) -> ReceiverReportInterceptor<P> {
82 move |inner| ReceiverReportInterceptor::new(inner, self.interval)
83 }
84}
85
/// Interceptor that keeps per-SSRC reception statistics for bound remote
/// streams and periodically queues RTCP receiver reports on the write path.
///
/// Inbound RTP packets and RTCP sender reports for tracked SSRCs update the
/// matching [`ReceiverStream`]. When `handle_timeout` is called at or after
/// the current deadline, one receiver-report packet per tracked stream is
/// queued. `poll_write` returns those packets before anything produced by
/// the inner interceptor.
pub struct ReceiverReportInterceptor<P> {
    /// Next interceptor in the chain. Reads, writes, timeouts and stream
    /// (un)bindings are all forwarded to it.
    inner: P,

    /// Time between receiver-report rounds.
    interval: Duration,
    /// Deadline of the next report round; moved to `now + interval` each
    /// time reports are generated.
    eto: Instant,

    /// Reception statistics, keyed by remote SSRC.
    streams: HashMap<u32, ReceiverStream>,

    // NOTE(review): never read or written after construction in this file;
    // `poll_read` delegates straight to `inner`.
    read_queue: VecDeque<TaggedPacket>,
    /// Generated receiver-report packets waiting to be drained by `poll_write`.
    write_queue: VecDeque<TaggedPacket>,
}
115
116impl<P> ReceiverReportInterceptor<P> {
117 fn new(inner: P, interval: Duration) -> Self {
119 Self {
120 inner,
121
122 interval,
123 eto: Instant::now(),
124
125 streams: HashMap::new(),
126
127 read_queue: VecDeque::new(),
128 write_queue: VecDeque::new(),
129 }
130 }
131
132 fn process_rtp(&mut self, now: Instant, ssrc: u32, seq: u16, timestamp: u32) {
134 let stream = self.streams.entry(ssrc).or_insert_with(|| {
136 ReceiverStream::new(ssrc, 90000)
138 });
139
140 let pkt = rtp::packet::Packet {
142 header: rtp::header::Header {
143 ssrc,
144 sequence_number: seq,
145 timestamp,
146 ..Default::default()
147 },
148 ..Default::default()
149 };
150
151 stream.process_rtp(now, &pkt);
152 }
153
154 fn process_sender_report(&mut self, now: Instant, sr: &rtcp::sender_report::SenderReport) {
156 if let Some(stream) = self.streams.get_mut(&sr.ssrc) {
157 stream.process_sender_report(now, sr);
158 }
159 }
160
161 fn generate_reports(&mut self, now: Instant) -> Vec<rtcp::receiver_report::ReceiverReport> {
163 self.streams
164 .values_mut()
165 .map(|stream| stream.generate_report(now))
166 .collect()
167 }
168
169 fn register_stream(&mut self, ssrc: u32, clock_rate: u32) {
171 self.streams
172 .entry(ssrc)
173 .or_insert_with(|| ReceiverStream::new(ssrc, clock_rate));
174 }
175}
176
177impl<P: Interceptor> sansio::Protocol<TaggedPacket, TaggedPacket, ()>
178 for ReceiverReportInterceptor<P>
179{
180 type Rout = TaggedPacket;
181 type Wout = TaggedPacket;
182 type Eout = ();
183 type Error = Error;
184 type Time = Instant;
185
186 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
187 if let Packet::Rtcp(rtcp_packets) = &msg.message {
188 for rtcp_packet in rtcp_packets {
189 if let Some(sr) = rtcp_packet
190 .as_any()
191 .downcast_ref::<rtcp::sender_report::SenderReport>()
192 && let Some(stream) = self.streams.get_mut(&sr.ssrc)
193 {
194 stream.process_sender_report(msg.now, sr);
195 }
196 }
197 } else if let Packet::Rtp(rtp_packet) = &msg.message
198 && let Some(stream) = self.streams.get_mut(&rtp_packet.header.ssrc)
199 {
200 stream.process_rtp(msg.now, rtp_packet);
201 }
202
203 self.inner.handle_read(msg)
204 }
205
206 fn poll_read(&mut self) -> Option<Self::Rout> {
207 self.inner.poll_read()
208 }
209
210 fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
211 self.inner.handle_write(msg)
212 }
213
214 fn poll_write(&mut self) -> Option<Self::Wout> {
215 if let Some(pkt) = self.write_queue.pop_front() {
217 return Some(pkt);
218 }
219 self.inner.poll_write()
220 }
221
222 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
223 if self.eto <= now {
224 self.eto = now + self.interval;
225
226 for stream in self.streams.values_mut() {
227 let rr = stream.generate_report(now);
228 self.write_queue.push_back(TaggedPacket {
229 now,
230 transport: TransportContext::default(),
231 message: Packet::Rtcp(vec![Box::new(rr)]),
232 });
233 }
234 }
235
236 self.inner.handle_timeout(now)
237 }
238
239 fn poll_timeout(&mut self) -> Option<Self::Time> {
240 if let Some(eto) = self.inner.poll_timeout()
241 && eto < self.eto
242 {
243 Some(eto)
244 } else {
245 Some(self.eto)
246 }
247 }
248}
249
250impl<P: Interceptor> Interceptor for ReceiverReportInterceptor<P> {
251 fn bind_local_stream(&mut self, info: &StreamInfo) {
252 self.inner.bind_local_stream(info);
253 }
254 fn unbind_local_stream(&mut self, info: &StreamInfo) {
255 self.inner.unbind_local_stream(info);
256 }
257 fn bind_remote_stream(&mut self, info: &StreamInfo) {
258 let stream = ReceiverStream::new(info.ssrc, info.clock_rate);
259 self.streams.insert(info.ssrc, stream);
260
261 self.inner.bind_remote_stream(info);
262 }
263 fn unbind_remote_stream(&mut self, info: &StreamInfo) {
264 self.streams.remove(&info.ssrc);
265
266 self.inner.unbind_remote_stream(info);
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use sansio::Protocol;

    fn dummy_rtp_packet() -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: crate::Packet::Rtp(rtp::Packet::default()),
        }
    }

    /// Builds an inbound RTP packet for `ssrc` with the given sequence number
    /// and RTP timestamp, stamped with arrival time `now`.
    fn rtp_packet_at(now: Instant, ssrc: u32, sequence_number: u16, timestamp: u32) -> TaggedPacket {
        TaggedPacket {
            now,
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number,
                    timestamp,
                    ..Default::default()
                },
                ..Default::default()
            }),
        }
    }

    /// Unwraps a value returned by `poll_write`, asserts it is an RTCP packet
    /// batch holding exactly one packet, downcasts that packet to a
    /// `ReceiverReport`, and hands it to `check`.
    fn check_receiver_report(
        polled: Option<TaggedPacket>,
        check: impl FnOnce(&rtcp::receiver_report::ReceiverReport),
    ) {
        let tagged = polled.expect("expected a queued receiver report");
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };
        assert_eq!(rtcp_packets.len(), 1);
        let rr = rtcp_packets[0]
            .as_any()
            .downcast_ref::<rtcp::receiver_report::ReceiverReport>()
            .expect("Expected ReceiverReport");
        check(rr);
    }

    #[test]
    fn test_receiver_report_builder_default() {
        let chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        assert_eq!(chain.interval, Duration::from_secs(1));
        assert!(chain.streams.is_empty());
    }

    #[test]
    fn test_receiver_report_builder_with_custom_interval() {
        let chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_millis(500))
                    .build(),
            )
            .build();

        assert_eq!(chain.interval, Duration::from_millis(500));
    }

    #[test]
    fn test_receiver_report_chain_handle_read_write() {
        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        let pkt = dummy_rtp_packet();
        chain.handle_read(pkt).unwrap();
        assert!(chain.poll_read().is_none());

        let pkt2 = dummy_rtp_packet();
        let pkt2_message = pkt2.message.clone();
        chain.handle_write(pkt2).unwrap();
        assert_eq!(chain.poll_write().unwrap().message, pkt2_message);
    }

    #[test]
    fn test_register_stream() {
        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        chain.register_stream(12345, 48000);
        assert!(chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_process_rtp() {
        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        let now = Instant::now();
        chain.process_rtp(now, 12345, 1, 1000);

        assert!(chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_generate_reports() {
        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .build();

        let now = Instant::now();
        chain.process_rtp(now, 12345, 1, 1000);
        chain.process_rtp(now, 12345, 2, 2000);

        let reports = chain.generate_reports(now);
        assert_eq!(reports.len(), 1);
    }

    #[test]
    fn test_chained_interceptors() {
        use crate::report::sender::SenderReportBuilder;

        let mut chain = Registry::new()
            .with(ReceiverReportBuilder::default().build())
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_millis(250))
                    .build(),
            )
            .build();

        let pkt = dummy_rtp_packet();
        chain.handle_read(pkt).unwrap();
        assert!(chain.poll_read().is_none());

        let pkt2 = dummy_rtp_packet();
        let pkt2_message = pkt2.message.clone();
        chain.handle_write(pkt2).unwrap();
        assert_eq!(chain.poll_write().unwrap().message, pkt2_message);
    }

    #[test]
    fn test_receiver_report_generation_on_timeout() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        let base_time = Instant::now();

        // Ten consecutive packets (seq 0..=9), 3000 ticks apart.
        for seq in 0..10u16 {
            chain
                .handle_read(rtp_packet_at(base_time, 123456, seq, u32::from(seq) * 3000))
                .unwrap();
            chain.poll_read();
        }

        // The chain was built before `base_time`, so the first deadline is
        // already due here; drain that round so the next poll sees only the
        // report produced by the later timeout.
        chain.handle_timeout(base_time).unwrap();
        while chain.poll_write().is_some() {}

        let later_time = base_time + Duration::from_secs(2);
        chain.handle_timeout(later_time).unwrap();

        check_receiver_report(chain.poll_write(), |rr| {
            assert_eq!(rr.reports.len(), 1);
            assert_eq!(rr.reports[0].ssrc, 123456);
            assert_eq!(rr.reports[0].last_sequence_number, 9);
            assert_eq!(rr.reports[0].fraction_lost, 0);
            assert_eq!(rr.reports[0].total_lost, 0);
        });
    }

    #[test]
    fn test_receiver_report_with_packet_loss() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        let base_time = Instant::now();

        // Sequence number 2 never arrives.
        chain.handle_read(rtp_packet_at(base_time, 123456, 1, 3000)).unwrap();
        chain.poll_read();
        chain.handle_read(rtp_packet_at(base_time, 123456, 3, 9000)).unwrap();
        chain.poll_read();

        let later_time = base_time + Duration::from_secs(2);
        chain.handle_timeout(later_time).unwrap();

        // 1 packet lost out of 3 expected (seq 1..=3), as an 8-bit fixed-point fraction.
        let (lost, expected) = (1u32, 3u32);
        check_receiver_report(chain.poll_write(), |rr| {
            assert_eq!(rr.reports[0].last_sequence_number, 3);
            assert_eq!(rr.reports[0].total_lost, 1);
            assert_eq!(rr.reports[0].fraction_lost, (256 * lost / expected) as u8);
        });
    }

    #[test]
    fn test_receiver_report_with_sender_report() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        let base_time = Instant::now();

        chain.handle_read(rtp_packet_at(base_time, 123456, 1, 3000)).unwrap();
        chain.poll_read();

        let sr = rtcp::sender_report::SenderReport {
            ssrc: 123456,
            ntp_time: 0x1234_5678_0000_0000,
            rtp_time: 3000,
            packet_count: 100,
            octet_count: 10000,
            ..Default::default()
        };
        let sr_pkt = TaggedPacket {
            now: base_time,
            transport: Default::default(),
            message: Packet::Rtcp(vec![Box::new(sr)]),
        };
        chain.handle_read(sr_pkt).unwrap();

        let later_time = base_time + Duration::from_secs(1);
        chain.handle_timeout(later_time).unwrap();

        check_receiver_report(chain.poll_write(), |rr| {
            // DLSR is in units of 1/65536 s (RFC 3550 §6.4.1): one second here.
            assert_eq!(rr.reports[0].delay, 65536);
            // LSR is the middle 32 bits of the sender report's NTP timestamp.
            assert_eq!(rr.reports[0].last_sender_report, 0x5678_0000);
        });
    }

    #[test]
    fn test_receiver_report_multiple_streams() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info1 = StreamInfo {
            ssrc: 111111,
            clock_rate: 90000,
            ..Default::default()
        };
        let info2 = StreamInfo {
            ssrc: 222222,
            clock_rate: 48000,
            ..Default::default()
        };
        chain.bind_remote_stream(&info1);
        chain.bind_remote_stream(&info2);

        let base_time = Instant::now();

        // Stream 111111: seq 0..=4 with no gaps.
        for seq in 0..5u16 {
            chain
                .handle_read(rtp_packet_at(base_time, 111111, seq, u32::from(seq) * 3000))
                .unwrap();
            chain.poll_read();
        }

        // Stream 222222: seq 0 then 5, so 1..=4 are missing.
        chain.handle_read(rtp_packet_at(base_time, 222222, 0, 0)).unwrap();
        chain.poll_read();
        chain.handle_read(rtp_packet_at(base_time, 222222, 5, 5 * 960)).unwrap();
        chain.poll_read();

        let later_time = base_time + Duration::from_secs(2);
        chain.handle_timeout(later_time).unwrap();

        // Collect (ssrc, total_lost) from every queued reception report block.
        let mut lost_by_ssrc = Vec::new();
        while let Some(tagged) = chain.poll_write() {
            let Packet::Rtcp(rtcp_packets) = tagged.message else {
                continue;
            };
            for rtcp_pkt in rtcp_packets {
                if let Some(rr) = rtcp_pkt
                    .as_any()
                    .downcast_ref::<rtcp::receiver_report::ReceiverReport>()
                {
                    lost_by_ssrc.extend(rr.reports.iter().map(|r| (r.ssrc, r.total_lost)));
                }
            }
        }

        assert_eq!(lost_by_ssrc.len(), 2);
        let lost_for = |ssrc: u32| {
            lost_by_ssrc
                .iter()
                .find(|(s, _)| *s == ssrc)
                .map(|&(_, lost)| lost)
        };
        assert_eq!(lost_for(111111), Some(0));
        assert_eq!(lost_for(222222), Some(4));
    }

    #[test]
    fn test_receiver_report_unbind_stream() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        let base_time = Instant::now();

        chain.handle_read(rtp_packet_at(base_time, 123456, 0, 0)).unwrap();
        chain.poll_read();

        // After unbinding, the stream must not produce any further reports.
        chain.unbind_remote_stream(&info);

        let later_time = base_time + Duration::from_secs(2);
        chain.handle_timeout(later_time).unwrap();

        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_receiver_report_sequence_wrap() {
        let mut chain = Registry::new()
            .with(
                ReceiverReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        let base_time = Instant::now();

        // 0xffff followed by 0x0000 is a wrap, not a loss.
        chain.handle_read(rtp_packet_at(base_time, 123456, 0xffff, 0)).unwrap();
        chain.poll_read();
        chain.handle_read(rtp_packet_at(base_time, 123456, 0x00, 3000)).unwrap();
        chain.poll_read();

        let later_time = base_time + Duration::from_secs(2);
        chain.handle_timeout(later_time).unwrap();

        check_receiver_report(chain.poll_write(), |rr| {
            // Extended highest sequence number: one wrap cycle, low bits 0.
            assert_eq!(rr.reports[0].last_sequence_number, 1 << 16);
            assert_eq!(rr.reports[0].fraction_lost, 0);
            assert_eq!(rr.reports[0].total_lost, 0);
        });
    }
}