1use super::sender_stream::SenderStream;
4use crate::stream_info::StreamInfo;
5use crate::{Interceptor, Packet, TaggedPacket};
6use rtcp::header::PacketType;
7use shared::TransportContext;
8use shared::error::Error;
9use std::collections::{HashMap, VecDeque};
10use std::marker::PhantomData;
11use std::time::{Duration, Instant};
12
/// Builder for [`SenderReportInterceptor`].
///
/// Defaults: a 1 second report interval and `use_latest_packet` disabled.
///
/// ```ignore
/// let chain = Registry::new()
///     .with(
///         SenderReportBuilder::default()
///             .with_interval(Duration::from_millis(500))
///             .build(),
///     )
///     .build();
/// ```
#[must_use = "a builder does nothing unless `build()` is called and its result is used"]
pub struct SenderReportBuilder<P> {
    /// Time between two consecutive batches of sender reports.
    interval: Duration,
    /// Option forwarded unchanged to every per-stream sender state created
    /// when a local stream is bound.
    use_latest_packet: bool,
    /// Ties the builder to the inner interceptor type `P` it will wrap.
    _phantom: PhantomData<P>,
}

impl<P> Default for SenderReportBuilder<P> {
    /// A builder with a 1 second interval and `use_latest_packet` disabled.
    fn default() -> Self {
        Self {
            interval: Duration::from_secs(1),
            use_latest_packet: false,
            _phantom: PhantomData,
        }
    }
}
53
54impl<P> SenderReportBuilder<P> {
55 pub fn new() -> Self {
59 Self::default()
60 }
61
62 pub fn with_interval(mut self, interval: Duration) -> Self {
74 self.interval = interval;
75 self
76 }
77
78 pub fn with_use_latest_packet(mut self) -> Self {
99 self.use_latest_packet = true;
100 self
101 }
102
103 pub fn build(self) -> impl FnOnce(P) -> SenderReportInterceptor<P> {
116 move |inner| SenderReportInterceptor::new(inner, self.interval, self.use_latest_packet)
117 }
118}
119
120pub struct SenderReportInterceptor<P> {
140 inner: P,
141
142 interval: Duration,
143 eto: Instant,
144
145 use_latest_packet: bool,
147
148 streams: HashMap<u32, SenderStream>,
149
150 read_queue: VecDeque<TaggedPacket>,
151 write_queue: VecDeque<TaggedPacket>,
152}
153
154impl<P> SenderReportInterceptor<P> {
155 fn new(inner: P, interval: Duration, use_latest_packet: bool) -> Self {
157 Self {
158 inner,
159
160 interval,
161 eto: Instant::now(),
162
163 use_latest_packet,
164
165 streams: HashMap::new(),
166
167 read_queue: VecDeque::new(),
168 write_queue: VecDeque::new(),
169 }
170 }
171
172 fn should_filter(packet_type: PacketType) -> bool {
178 packet_type == PacketType::ReceiverReport
179 || (packet_type == PacketType::TransportSpecificFeedback)
180 }
181
182 fn inner(&self) -> &P {
184 &self.inner
185 }
186
187 fn inner_mut(&mut self) -> &mut P {
189 &mut self.inner
190 }
191}
192
193impl<P: Interceptor> sansio::Protocol<TaggedPacket, TaggedPacket, ()>
194 for SenderReportInterceptor<P>
195{
196 type Rout = TaggedPacket;
197 type Wout = TaggedPacket;
198 type Eout = ();
199 type Error = Error;
200 type Time = Instant;
201
202 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
203 self.inner.handle_read(msg)
204 }
205
206 fn poll_read(&mut self) -> Option<Self::Rout> {
207 self.inner.poll_read()
208 }
209
210 fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
211 if let Packet::Rtp(rtp_packet) = &msg.message
212 && let Some(stream) = self.streams.get_mut(&rtp_packet.header.ssrc)
213 {
214 stream.process_rtp(msg.now, rtp_packet);
215 }
216
217 self.inner.handle_write(msg)
218 }
219
220 fn poll_write(&mut self) -> Option<Self::Wout> {
221 if let Some(pkt) = self.write_queue.pop_front() {
223 return Some(pkt);
224 }
225 self.inner.poll_write()
226 }
227
228 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
229 if self.eto <= now {
230 self.eto = now + self.interval;
231
232 for stream in self.streams.values_mut() {
233 let rr = stream.generate_report(now);
234 self.write_queue.push_back(TaggedPacket {
235 now,
236 transport: TransportContext::default(),
237 message: Packet::Rtcp(vec![Box::new(rr)]),
238 });
239 }
240 }
241
242 self.inner.handle_timeout(now)
243 }
244
245 fn poll_timeout(&mut self) -> Option<Self::Time> {
246 if let Some(eto) = self.inner.poll_timeout()
247 && eto < self.eto
248 {
249 Some(eto)
250 } else {
251 Some(self.eto)
252 }
253 }
254}
255
256impl<P: Interceptor> Interceptor for SenderReportInterceptor<P> {
257 fn bind_local_stream(&mut self, info: &StreamInfo) {
258 let stream = SenderStream::new(info.ssrc, info.clock_rate, self.use_latest_packet);
259 self.streams.insert(info.ssrc, stream);
260
261 self.inner.bind_local_stream(info);
262 }
263 fn unbind_local_stream(&mut self, info: &StreamInfo) {
264 self.streams.remove(&info.ssrc);
265
266 self.inner.unbind_local_stream(info);
267 }
268 fn bind_remote_stream(&mut self, info: &StreamInfo) {
269 self.inner.bind_remote_stream(info);
270 }
271 fn unbind_remote_stream(&mut self, info: &StreamInfo) {
272 self.inner.unbind_remote_stream(info);
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{NoopInterceptor, Registry};
    use sansio::Protocol;

    /// A default RTP packet stamped with the current time.
    fn dummy_rtp_packet() -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: crate::Packet::Rtp(rtp::Packet::default()),
        }
    }

    /// An RTP packet for `ssrc` whose payload is `payload_len` zero bytes.
    fn rtp_packet(
        now: Instant,
        ssrc: u32,
        sequence_number: u16,
        timestamp: u32,
        payload_len: usize,
    ) -> TaggedPacket {
        TaggedPacket {
            now,
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number,
                    timestamp,
                    ..Default::default()
                },
                payload: vec![0u8; payload_len].into(),
                ..Default::default()
            }),
        }
    }

    #[test]
    fn test_sender_report_builder_default() {
        let chain = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();

        assert_eq!(chain.interval, Duration::from_secs(1));
    }

    #[test]
    fn test_sender_report_builder_with_custom_interval() {
        let chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_millis(500))
                    .build(),
            )
            .build();

        assert_eq!(chain.interval, Duration::from_millis(500));
    }

    #[test]
    fn test_sender_report_chain_handle_read_write() {
        let mut chain = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();

        let pkt = dummy_rtp_packet();
        chain.handle_read(pkt).unwrap();
        assert!(chain.poll_read().is_none());

        let pkt2 = dummy_rtp_packet();
        let pkt2_message = pkt2.message.clone();
        chain.handle_write(pkt2).unwrap();
        assert_eq!(chain.poll_write().unwrap().message, pkt2_message);
    }

    #[test]
    fn test_should_filter() {
        for filtered in [
            PacketType::ReceiverReport,
            PacketType::TransportSpecificFeedback,
        ] {
            assert!(SenderReportInterceptor::<NoopInterceptor>::should_filter(
                filtered
            ));
        }

        for passed in [
            PacketType::SenderReport,
            PacketType::SourceDescription,
            PacketType::Goodbye,
        ] {
            assert!(!SenderReportInterceptor::<NoopInterceptor>::should_filter(
                passed
            ));
        }
    }

    #[test]
    fn test_inner_access() {
        let mut chain = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();

        let _ = chain.inner();

        let pkt = dummy_rtp_packet();
        let pkt_message = pkt.message.clone();
        chain.inner_mut().handle_write(pkt).unwrap();
        assert_eq!(chain.inner_mut().poll_write().unwrap().message, pkt_message);
    }

    #[test]
    fn test_use_latest_packet_option() {
        let chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_use_latest_packet()
                    .build(),
            )
            .build();

        assert!(chain.use_latest_packet);

        let chain_default = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();

        assert!(!chain_default.use_latest_packet);
    }

    #[test]
    fn test_use_latest_packet_combined_options() {
        let chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_millis(250))
                    .with_use_latest_packet()
                    .build(),
            )
            .build();

        assert_eq!(chain.interval, Duration::from_millis(250));
        assert!(chain.use_latest_packet);
    }

    #[test]
    fn test_sender_report_generation_on_timeout() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_local_stream(&info);

        let base_time = Instant::now();

        for i in 0..5u16 {
            chain
                .handle_write(rtp_packet(base_time, 123456, i, u32::from(i) * 3000, 100))
                .unwrap();
            // Discard the forwarded RTP packet; only the reports matter here.
            chain.poll_write();
        }

        // The first deadline is the construction instant, so this fires now.
        chain.handle_timeout(base_time).unwrap();
        while chain.poll_write().is_some() {}

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        let tagged = chain
            .poll_write()
            .expect("expected a sender report once the interval elapsed");
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };
        assert_eq!(rtcp_packets.len(), 1);
        let sr = rtcp_packets[0]
            .as_any()
            .downcast_ref::<rtcp::sender_report::SenderReport>()
            .expect("Expected SenderReport");
        assert_eq!(sr.ssrc, 123456);
        assert_eq!(sr.packet_count, 5);
        assert_eq!(sr.octet_count, 500);
    }

    #[test]
    fn test_no_report_before_interval_elapses() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_local_stream(&info);

        let base_time = Instant::now();
        chain
            .handle_write(rtp_packet(base_time, 123456, 0, 0, 100))
            .unwrap();
        chain.poll_write();

        // Initial deadline has passed: a report is produced and the next one
        // is scheduled one interval later.
        chain.handle_timeout(base_time).unwrap();
        assert!(chain.poll_write().is_some());
        while chain.poll_write().is_some() {}

        // Half an interval later nothing new is due.
        chain
            .handle_timeout(base_time + Duration::from_millis(500))
            .unwrap();
        assert!(chain.poll_write().is_none());

        // A full interval after the previous report, the next one appears.
        chain
            .handle_timeout(base_time + Duration::from_secs(1))
            .unwrap();
        assert!(chain.poll_write().is_some());
    }

    #[test]
    fn test_sender_report_multiple_streams() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info1 = StreamInfo {
            ssrc: 111111,
            clock_rate: 90000,
            ..Default::default()
        };
        let info2 = StreamInfo {
            ssrc: 222222,
            clock_rate: 48000,
            ..Default::default()
        };
        chain.bind_local_stream(&info1);
        chain.bind_local_stream(&info2);

        let base_time = Instant::now();

        for i in 0..3u16 {
            chain
                .handle_write(rtp_packet(base_time, 111111, i, u32::from(i) * 3000, 50))
                .unwrap();
            chain.poll_write();
        }

        for i in 0..7u16 {
            chain
                .handle_write(rtp_packet(base_time, 222222, i, u32::from(i) * 960, 200))
                .unwrap();
            chain.poll_write();
        }

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        let mut ssrcs = vec![];
        let mut packet_counts = vec![];
        let mut octet_counts = vec![];

        while let Some(tagged) = chain.poll_write() {
            if let Packet::Rtcp(rtcp_packets) = tagged.message {
                for rtcp_pkt in rtcp_packets {
                    if let Some(sr) = rtcp_pkt
                        .as_any()
                        .downcast_ref::<rtcp::sender_report::SenderReport>()
                    {
                        ssrcs.push(sr.ssrc);
                        packet_counts.push(sr.packet_count);
                        octet_counts.push(sr.octet_count);
                    }
                }
            }
        }

        assert_eq!(ssrcs.len(), 2);
        assert!(ssrcs.contains(&111111));
        assert!(ssrcs.contains(&222222));

        let idx1 = ssrcs.iter().position(|&s| s == 111111).unwrap();
        assert_eq!(packet_counts[idx1], 3);
        assert_eq!(octet_counts[idx1], 150);

        let idx2 = ssrcs.iter().position(|&s| s == 222222).unwrap();
        assert_eq!(packet_counts[idx2], 7);
        assert_eq!(octet_counts[idx2], 1400);
    }

    #[test]
    fn test_sender_report_unbind_stream() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_local_stream(&info);

        let base_time = Instant::now();

        chain
            .handle_write(rtp_packet(base_time, 123456, 0, 0, 100))
            .unwrap();
        chain.poll_write();

        chain.unbind_local_stream(&info);

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        // No stream is bound any more, so no report may be produced.
        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_poll_timeout_returns_earliest() {
        let interval = Duration::from_secs(5);
        let mut chain = Registry::new()
            .with(SenderReportBuilder::default().with_interval(interval).build())
            .build();

        // Never later than this interceptor's own report deadline.
        let timeout = chain
            .poll_timeout()
            .expect("the sender report interceptor always schedules a timeout");
        assert!(timeout <= chain.eto);

        // Firing the current deadline pushes the next one a full interval out.
        let now = Instant::now();
        chain.handle_timeout(now).unwrap();
        assert_eq!(chain.eto, now + interval);
        assert!(chain.poll_timeout().unwrap() <= now + interval);
    }
}