1use super::sender_stream::SenderStream;
4use crate::stream_info::StreamInfo;
5use crate::{Interceptor, Packet, TaggedPacket};
6use rtcp::header::PacketType;
7use shared::TransportContext;
8use shared::error::Error;
9use std::collections::{HashMap, VecDeque};
10use std::marker::PhantomData;
11use std::time::{Duration, Instant};
12
/// Builder that configures and creates a [`SenderReportInterceptor`].
///
/// `P` is the type of the inner interceptor that the closure returned by
/// `build` will wrap. The builder itself never stores a `P`.
pub struct SenderReportBuilder<P> {
    /// Time between report rounds.
    interval: Duration,
    /// Flag forwarded to every `SenderStream` the interceptor creates.
    use_latest_packet: bool,
    /// `fn() -> P` instead of `P`: the builder does not own a `P`, so it must not
    /// inherit `P`'s `Send`/`Sync` or drop-check constraints (still covariant in `P`).
    _phantom: PhantomData<fn() -> P>,
}
43
44impl<P> Default for SenderReportBuilder<P> {
45 fn default() -> Self {
46 Self {
47 interval: Duration::from_secs(1),
48 use_latest_packet: false,
49 _phantom: PhantomData,
50 }
51 }
52}
53
54impl<P> SenderReportBuilder<P> {
55 pub fn new() -> Self {
59 Self::default()
60 }
61
62 pub fn with_interval(mut self, interval: Duration) -> Self {
74 self.interval = interval;
75 self
76 }
77
78 pub fn with_use_latest_packet(mut self) -> Self {
99 self.use_latest_packet = true;
100 self
101 }
102
103 pub fn build(self) -> impl FnOnce(P) -> SenderReportInterceptor<P> {
116 move |inner| SenderReportInterceptor::new(inner, self.interval, self.use_latest_packet)
117 }
118}
119
/// Interceptor that periodically emits RTCP sender reports for local streams.
///
/// Outgoing RTP packets whose SSRC belongs to a stream registered with
/// `bind_local_stream` are recorded in that stream's `SenderStream`. Whenever
/// the report deadline passes, `handle_timeout` asks every bound stream for a
/// report and queues it for `poll_write`. Every read and write is still
/// forwarded to the wrapped interceptor.
pub struct SenderReportInterceptor<P> {
    // The wrapped (next) interceptor in the chain.
    inner: P,

    // Time between consecutive report rounds.
    interval: Duration,
    // Deadline at which the next round of sender reports is due.
    eto: Instant,

    // Passed to every `SenderStream` created in `bind_local_stream`.
    use_latest_packet: bool,

    // Per-SSRC sender state for the locally bound streams.
    streams: HashMap<u32, SenderStream>,

    // NOTE(review): initialised but never read or written in the visible code;
    // `handle_read`/`poll_read` forward directly to `inner`.
    read_queue: VecDeque<TaggedPacket>,
    // Generated sender reports waiting to be returned by `poll_write`.
    write_queue: VecDeque<TaggedPacket>,
}
153
154impl<P> SenderReportInterceptor<P> {
155 fn new(inner: P, interval: Duration, use_latest_packet: bool) -> Self {
157 Self {
158 inner,
159
160 interval,
161 eto: Instant::now(),
162
163 use_latest_packet,
164
165 streams: HashMap::new(),
166
167 read_queue: VecDeque::new(),
168 write_queue: VecDeque::new(),
169 }
170 }
171
172 fn should_filter(packet_type: PacketType) -> bool {
178 packet_type == PacketType::ReceiverReport
179 || (packet_type == PacketType::TransportSpecificFeedback)
180 }
181
182 fn inner(&self) -> &P {
184 &self.inner
185 }
186
187 fn inner_mut(&mut self) -> &mut P {
189 &mut self.inner
190 }
191}
192
193impl<P: Interceptor> sansio::Protocol<TaggedPacket, TaggedPacket, ()>
194 for SenderReportInterceptor<P>
195{
196 type Rout = TaggedPacket;
197 type Wout = TaggedPacket;
198 type Eout = ();
199 type Error = Error;
200 type Time = Instant;
201
202 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
203 self.inner.handle_read(msg)
204 }
205
206 fn poll_read(&mut self) -> Option<Self::Rout> {
207 self.inner.poll_read()
208 }
209
210 fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
211 if let Packet::Rtp(rtp_packet) = &msg.message
212 && let Some(stream) = self.streams.get_mut(&rtp_packet.header.ssrc)
213 {
214 stream.process_rtp(msg.now, rtp_packet);
215 }
216
217 self.inner.handle_write(msg)
218 }
219
220 fn poll_write(&mut self) -> Option<Self::Wout> {
221 if let Some(pkt) = self.write_queue.pop_front() {
223 return Some(pkt);
224 }
225 self.inner.poll_write()
226 }
227
228 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
229 if self.eto <= now {
230 self.eto = now + self.interval;
231
232 for stream in self.streams.values_mut() {
233 let rr = stream.generate_report(now);
234 self.write_queue.push_back(TaggedPacket {
235 now,
236 transport: TransportContext::default(),
237 message: Packet::Rtcp(vec![Box::new(rr)]),
238 });
239 }
240 }
241
242 self.inner.handle_timeout(now)
243 }
244
245 fn poll_timeout(&mut self) -> Option<Self::Time> {
246 if let Some(eto) = self.inner.poll_timeout()
247 && eto < self.eto
248 {
249 Some(eto)
250 } else {
251 Some(self.eto)
252 }
253 }
254}
255
256impl<P: Interceptor> Interceptor for SenderReportInterceptor<P> {
257 fn bind_local_stream(&mut self, info: &StreamInfo) {
258 let stream = SenderStream::new(info.ssrc, info.clock_rate, self.use_latest_packet);
259 self.streams.insert(info.ssrc, stream);
260
261 self.inner.bind_local_stream(info);
262 }
263 fn unbind_local_stream(&mut self, info: &StreamInfo) {
264 self.streams.remove(&info.ssrc);
265
266 self.inner.unbind_local_stream(info);
267 }
268 fn bind_remote_stream(&mut self, info: &StreamInfo) {
269 self.inner.bind_remote_stream(info);
270 }
271 fn unbind_remote_stream(&mut self, info: &StreamInfo) {
272 self.inner.unbind_remote_stream(info);
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{NoopInterceptor, Registry};
    use sansio::Protocol;

    /// An RTP packet with a default header and empty payload, stamped "now".
    fn dummy_rtp_packet() -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: crate::Packet::Rtp(rtp::Packet::default()),
        }
    }

    /// An outgoing RTP packet for `ssrc` whose payload is `payload_len` zero bytes.
    fn rtp_packet(
        now: Instant,
        ssrc: u32,
        sequence_number: u16,
        timestamp: u32,
        payload_len: usize,
    ) -> TaggedPacket {
        TaggedPacket {
            now,
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number,
                    timestamp,
                    ..Default::default()
                },
                payload: vec![0u8; payload_len].into(),
                ..Default::default()
            }),
        }
    }

    #[test]
    fn test_sender_report_builder_default() {
        let chain = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();

        assert_eq!(chain.interval, Duration::from_secs(1));
    }

    #[test]
    fn test_sender_report_builder_with_custom_interval() {
        let chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_millis(500))
                    .build(),
            )
            .build();

        assert_eq!(chain.interval, Duration::from_millis(500));
    }

    #[test]
    fn test_sender_report_chain_handle_read_write() {
        let mut chain = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();

        let pkt = dummy_rtp_packet();
        let pkt_message = pkt.message.clone();
        chain.handle_read(pkt).unwrap();
        assert_eq!(chain.poll_read().unwrap().message, pkt_message);

        let pkt2 = dummy_rtp_packet();
        let pkt2_message = pkt2.message.clone();
        chain.handle_write(pkt2).unwrap();
        assert_eq!(chain.poll_write().unwrap().message, pkt2_message);
    }

    #[test]
    fn test_should_filter() {
        assert!(SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::ReceiverReport
        ));
        assert!(SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::TransportSpecificFeedback
        ));

        assert!(!SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::SenderReport
        ));
        assert!(!SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::SourceDescription
        ));
        assert!(!SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::Goodbye
        ));
    }

    #[test]
    fn test_inner_access() {
        let mut chain = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();

        let _ = chain.inner();

        let pkt = dummy_rtp_packet();
        let pkt_message = pkt.message.clone();
        chain.inner_mut().handle_write(pkt).unwrap();
        assert_eq!(chain.inner_mut().poll_write().unwrap().message, pkt_message);
    }

    #[test]
    fn test_use_latest_packet_option() {
        let chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_use_latest_packet()
                    .build(),
            )
            .build();
        assert!(chain.use_latest_packet);

        let chain_default = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();
        assert!(!chain_default.use_latest_packet);
    }

    #[test]
    fn test_use_latest_packet_combined_options() {
        let chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_millis(250))
                    .with_use_latest_packet()
                    .build(),
            )
            .build();

        assert_eq!(chain.interval, Duration::from_millis(250));
        assert!(chain.use_latest_packet);
    }

    #[test]
    fn test_sender_report_generation_on_timeout() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        chain.bind_local_stream(&StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        });

        let base_time = Instant::now();
        for i in 0..5u16 {
            chain
                .handle_write(rtp_packet(base_time, 123456, i, u32::from(i) * 3000, 100))
                .unwrap();
            // Drain the forwarded RTP packet.
            chain.poll_write();
        }

        // The first deadline is the construction instant, so this emits an initial round.
        chain.handle_timeout(base_time).unwrap();
        while chain.poll_write().is_some() {}

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        let tagged = chain.poll_write().expect("expected a sender report");
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };
        assert_eq!(rtcp_packets.len(), 1);
        let sr = rtcp_packets[0]
            .as_any()
            .downcast_ref::<rtcp::sender_report::SenderReport>()
            .expect("Expected SenderReport");
        assert_eq!(sr.ssrc, 123456);
        assert_eq!(sr.packet_count, 5);
        assert_eq!(sr.octet_count, 500);
    }

    #[test]
    fn test_sender_report_multiple_streams() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        chain.bind_local_stream(&StreamInfo {
            ssrc: 111111,
            clock_rate: 90000,
            ..Default::default()
        });
        chain.bind_local_stream(&StreamInfo {
            ssrc: 222222,
            clock_rate: 48000,
            ..Default::default()
        });

        let base_time = Instant::now();
        for i in 0..3u16 {
            chain
                .handle_write(rtp_packet(base_time, 111111, i, u32::from(i) * 3000, 50))
                .unwrap();
            chain.poll_write();
        }
        for i in 0..7u16 {
            chain
                .handle_write(rtp_packet(base_time, 222222, i, u32::from(i) * 960, 200))
                .unwrap();
            chain.poll_write();
        }

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        let mut ssrcs = vec![];
        let mut packet_counts = vec![];
        let mut octet_counts = vec![];
        while let Some(tagged) = chain.poll_write() {
            if let Packet::Rtcp(rtcp_packets) = tagged.message {
                for rtcp_pkt in rtcp_packets {
                    if let Some(sr) = rtcp_pkt
                        .as_any()
                        .downcast_ref::<rtcp::sender_report::SenderReport>()
                    {
                        ssrcs.push(sr.ssrc);
                        packet_counts.push(sr.packet_count);
                        octet_counts.push(sr.octet_count);
                    }
                }
            }
        }

        // Exactly one report per stream; report order follows HashMap iteration.
        assert_eq!(ssrcs.len(), 2);
        assert!(ssrcs.contains(&111111));
        assert!(ssrcs.contains(&222222));

        let idx1 = ssrcs.iter().position(|&s| s == 111111).unwrap();
        assert_eq!(packet_counts[idx1], 3);
        assert_eq!(octet_counts[idx1], 150);

        let idx2 = ssrcs.iter().position(|&s| s == 222222).unwrap();
        assert_eq!(packet_counts[idx2], 7);
        assert_eq!(octet_counts[idx2], 1400);
    }

    #[test]
    fn test_sender_report_unbind_stream() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_local_stream(&info);

        let base_time = Instant::now();
        chain
            .handle_write(rtp_packet(base_time, 123456, 0, 0, 100))
            .unwrap();
        chain.poll_write();

        chain.unbind_local_stream(&info);

        chain
            .handle_timeout(base_time + Duration::from_secs(2))
            .unwrap();

        // No bound streams left, so nothing is generated.
        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_sender_report_respects_interval() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(5))
                    .build(),
            )
            .build();

        chain.bind_local_stream(&StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        });

        let base_time = Instant::now();
        chain
            .handle_write(rtp_packet(base_time, 123456, 0, 0, 100))
            .unwrap();
        chain.poll_write();

        // Initial deadline already passed: a report round is produced.
        chain.handle_timeout(base_time).unwrap();
        assert!(chain.poll_write().is_some());
        while chain.poll_write().is_some() {}

        // Before the interval elapses: nothing new.
        chain
            .handle_timeout(base_time + Duration::from_secs(1))
            .unwrap();
        assert!(chain.poll_write().is_none());

        // Once the interval elapses: the next round is produced.
        chain
            .handle_timeout(base_time + Duration::from_secs(5))
            .unwrap();
        assert!(chain.poll_write().is_some());
    }

    #[test]
    fn test_poll_timeout_returns_earliest() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(5))
                    .build(),
            )
            .build();

        assert!(chain.poll_timeout().is_some());

        // After a timeout, the next report round is one interval later. The inner
        // interceptor's deadline can only make the result earlier, never later.
        let base_time = Instant::now();
        chain.handle_timeout(base_time).unwrap();
        let timeout = chain.poll_timeout().expect("expected a deadline");
        assert!(timeout <= base_time + Duration::from_secs(5));
    }
}