1use super::sender_stream::SenderStream;
4use crate::stream_info::StreamInfo;
5use crate::{Interceptor, Packet, TaggedPacket, interceptor};
6use rtcp::header::PacketType;
7use shared::TransportContext;
8use shared::error::Error;
9use std::collections::{HashMap, VecDeque};
10use std::marker::PhantomData;
11use std::time::{Duration, Instant};
12
/// Builder for a `SenderReportInterceptor`.
///
/// Start from `SenderReportBuilder::new()` / `Default::default()`, adjust it with the
/// `with_*` methods, and call `build()` to obtain a factory that wraps an inner
/// interceptor of type `P`.
pub struct SenderReportBuilder<P> {
    /// Flag handed to every `SenderStream` the built interceptor creates.
    use_latest_packet: bool,
    /// Period between two rounds of generated sender reports.
    interval: Duration,
    /// Ties the builder to the inner interceptor type `P` without storing a value of it.
    _phantom: PhantomData<P>,
}

impl<P> Default for SenderReportBuilder<P> {
    /// Defaults: one report round per second, `use_latest_packet` disabled.
    fn default() -> Self {
        Self {
            use_latest_packet: false,
            interval: Duration::from_millis(1_000),
            _phantom: PhantomData,
        }
    }
}
53
54impl<P> SenderReportBuilder<P> {
55 pub fn new() -> Self {
59 Self::default()
60 }
61
62 pub fn with_interval(mut self, interval: Duration) -> Self {
74 self.interval = interval;
75 self
76 }
77
78 pub fn with_use_latest_packet(mut self) -> Self {
99 self.use_latest_packet = true;
100 self
101 }
102
103 pub fn build(self) -> impl FnOnce(P) -> SenderReportInterceptor<P> {
116 move |inner| SenderReportInterceptor::new(inner, self.interval, self.use_latest_packet)
117 }
118}
119
#[derive(Interceptor)]
/// Interceptor that periodically queues an RTCP sender report for every bound local stream.
///
/// Outgoing RTP packets whose SSRC belongs to a bound stream are fed to that stream's
/// `SenderStream` before being forwarded to the inner interceptor. When `handle_timeout`
/// sees that the report deadline (`eto`) has been reached, it queues one report per bound
/// stream on the write path and reschedules the deadline `interval` later.
pub struct SenderReportInterceptor<P> {
    /// Next interceptor in the chain (marked `#[next]` for the `Interceptor` derive).
    #[next]
    inner: P,

    /// Period between two report rounds.
    interval: Duration,
    /// Deadline of the next report round; initialized to the construction time.
    eto: Instant,

    /// Flag passed to each `SenderStream` created in `bind_local_stream`.
    use_latest_packet: bool,

    /// Per-SSRC sender state for the currently bound local streams.
    streams: HashMap<u32, SenderStream>,

    /// NOTE(review): not referenced by any method in this file — confirm whether it is
    /// still needed (possibly used by the derive or kept for symmetry with `write_queue`).
    read_queue: VecDeque<TaggedPacket>,
    /// Generated sender-report packets waiting to be returned by `poll_write`.
    write_queue: VecDeque<TaggedPacket>,
}
155
156impl<P> SenderReportInterceptor<P> {
157 fn new(inner: P, interval: Duration, use_latest_packet: bool) -> Self {
159 Self {
160 inner,
161
162 interval,
163 eto: Instant::now(),
164
165 use_latest_packet,
166
167 streams: HashMap::new(),
168
169 read_queue: VecDeque::new(),
170 write_queue: VecDeque::new(),
171 }
172 }
173
174 fn should_filter(packet_type: PacketType) -> bool {
180 packet_type == PacketType::ReceiverReport
181 || (packet_type == PacketType::TransportSpecificFeedback)
182 }
183
184 fn inner(&self) -> &P {
186 &self.inner
187 }
188
189 fn inner_mut(&mut self) -> &mut P {
191 &mut self.inner
192 }
193}
194
195#[interceptor]
196impl<P: Interceptor> SenderReportInterceptor<P> {
197 #[overrides]
198 fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
199 if let Packet::Rtp(rtp_packet) = &msg.message
200 && let Some(stream) = self.streams.get_mut(&rtp_packet.header.ssrc)
201 {
202 stream.process_rtp(msg.now, rtp_packet);
203 }
204
205 self.inner.handle_write(msg)
206 }
207
208 #[overrides]
209 fn poll_write(&mut self) -> Option<Self::Wout> {
210 if let Some(pkt) = self.write_queue.pop_front() {
212 return Some(pkt);
213 }
214 self.inner.poll_write()
215 }
216
217 #[overrides]
218 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
219 if self.eto <= now {
220 self.eto = now + self.interval;
221
222 for stream in self.streams.values_mut() {
223 let rr = stream.generate_report(now);
224 self.write_queue.push_back(TaggedPacket {
225 now,
226 transport: TransportContext::default(),
227 message: Packet::Rtcp(vec![Box::new(rr)]),
228 });
229 }
230 }
231
232 self.inner.handle_timeout(now)
233 }
234
235 #[overrides]
236 fn poll_timeout(&mut self) -> Option<Self::Time> {
237 if let Some(eto) = self.inner.poll_timeout()
238 && eto < self.eto
239 {
240 Some(eto)
241 } else {
242 Some(self.eto)
243 }
244 }
245
246 #[overrides]
247 fn bind_local_stream(&mut self, info: &StreamInfo) {
248 let stream = SenderStream::new(info.ssrc, info.clock_rate, self.use_latest_packet);
249 self.streams.insert(info.ssrc, stream);
250
251 self.inner.bind_local_stream(info);
252 }
253
254 #[overrides]
255 fn unbind_local_stream(&mut self, info: &StreamInfo) {
256 self.streams.remove(&info.ssrc);
257
258 self.inner.unbind_local_stream(info);
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{NoopInterceptor, Registry};
    use sansio::Protocol;

    /// Builds an RTP packet with a default header and empty payload, stamped with the current time.
    fn dummy_rtp_packet() -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: crate::Packet::Rtp(rtp::Packet::default()),
        }
    }

    /// A default builder produces an interceptor with a one-second report interval.
    #[test]
    fn test_sender_report_builder_default() {
        let chain = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();

        assert_eq!(chain.interval, Duration::from_secs(1));
    }

    /// `with_interval` is carried through to the built interceptor.
    #[test]
    fn test_sender_report_builder_with_custom_interval() {
        let chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_millis(500))
                    .build(),
            )
            .build();

        assert_eq!(chain.interval, Duration::from_millis(500));
    }

    /// With no streams bound, the read path surfaces nothing from `poll_read`, and a
    /// written packet comes back unchanged from `poll_write`.
    #[test]
    fn test_sender_report_chain_handle_read_write() {
        let mut chain = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();

        let pkt = dummy_rtp_packet();
        chain.handle_read(pkt).unwrap();
        assert!(chain.poll_read().is_none());

        let pkt2 = dummy_rtp_packet();
        // Keep a copy of the payload so the forwarded packet can be compared afterwards.
        let pkt2_message = pkt2.message.clone();
        chain.handle_write(pkt2).unwrap();
        assert_eq!(chain.poll_write().unwrap().message, pkt2_message);
    }

    /// Only receiver reports and transport-specific feedback are selected by `should_filter`.
    #[test]
    fn test_should_filter() {
        assert!(SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::ReceiverReport
        ));

        assert!(SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::TransportSpecificFeedback
        ));

        assert!(!SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::SenderReport
        ));

        assert!(!SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::SourceDescription
        ));

        assert!(!SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::Goodbye
        ));
    }

    /// `inner()` / `inner_mut()` expose the wrapped interceptor; a packet written straight
    /// to it is returned by its own `poll_write`.
    #[test]
    fn test_inner_access() {
        let mut chain = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();

        let _ = chain.inner();

        let pkt = dummy_rtp_packet();
        let pkt_message = pkt.message.clone();
        chain.inner_mut().handle_write(pkt).unwrap();
        assert_eq!(chain.inner_mut().poll_write().unwrap().message, pkt_message);
    }

    /// `with_use_latest_packet` sets the flag; it is off by default.
    #[test]
    fn test_use_latest_packet_option() {
        let chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_use_latest_packet()
                    .build(),
            )
            .build();

        assert!(chain.use_latest_packet);

        let chain_default = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();

        assert!(!chain_default.use_latest_packet);
    }

    /// Builder options compose: interval and `use_latest_packet` are both applied.
    #[test]
    fn test_use_latest_packet_combined_options() {
        let chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_millis(250))
                    .with_use_latest_packet()
                    .build(),
            )
            .build();

        assert_eq!(chain.interval, Duration::from_millis(250));
        assert!(chain.use_latest_packet);
    }

    /// After five 100-byte RTP packets on a bound stream, a report round produces one RTCP
    /// packet with a `SenderReport` carrying packet_count 5 and octet_count 500.
    #[test]
    fn test_sender_report_generation_on_timeout() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_local_stream(&info);

        let base_time = Instant::now();

        for i in 0..5u16 {
            let pkt = TaggedPacket {
                now: base_time,
                transport: Default::default(),
                message: Packet::Rtp(rtp::Packet {
                    header: rtp::header::Header {
                        ssrc: 123456,
                        sequence_number: i,
                        timestamp: i as u32 * 3000,
                        ..Default::default()
                    },
                    payload: vec![0u8; 100].into(),
                    ..Default::default()
                }),
            };
            chain.handle_write(pkt).unwrap();
            // Drain the forwarded RTP packet so only reports remain on the write path.
            chain.poll_write();
        }

        // The initial deadline is the construction time (<= base_time), so this fires a round.
        chain.handle_timeout(base_time).unwrap();

        // Discard that first round; the assertions target the next one.
        while chain.poll_write().is_some() {}

        // Two seconds later is past the rescheduled deadline (base_time + 1s).
        let later_time = base_time + Duration::from_secs(2);
        chain.handle_timeout(later_time).unwrap();

        let report = chain.poll_write();
        assert!(report.is_some());

        if let Some(tagged) = report {
            if let Packet::Rtcp(rtcp_packets) = tagged.message {
                assert_eq!(rtcp_packets.len(), 1);
                let sr = rtcp_packets[0]
                    .as_any()
                    .downcast_ref::<rtcp::sender_report::SenderReport>()
                    .expect("Expected SenderReport");
                assert_eq!(sr.ssrc, 123456);
                assert_eq!(sr.packet_count, 5);
                assert_eq!(sr.octet_count, 500);
            } else {
                panic!("Expected RTCP packet");
            }
        }
    }

    /// Each bound stream gets its own report with its own counters
    /// (3 x 50 bytes on one SSRC, 7 x 200 bytes on the other).
    #[test]
    fn test_sender_report_multiple_streams() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info1 = StreamInfo {
            ssrc: 111111,
            clock_rate: 90000,
            ..Default::default()
        };
        let info2 = StreamInfo {
            ssrc: 222222,
            clock_rate: 48000,
            ..Default::default()
        };
        chain.bind_local_stream(&info1);
        chain.bind_local_stream(&info2);

        let base_time = Instant::now();

        for i in 0..3u16 {
            let pkt = TaggedPacket {
                now: base_time,
                transport: Default::default(),
                message: Packet::Rtp(rtp::Packet {
                    header: rtp::header::Header {
                        ssrc: 111111,
                        sequence_number: i,
                        timestamp: i as u32 * 3000,
                        ..Default::default()
                    },
                    payload: vec![0u8; 50].into(),
                    ..Default::default()
                }),
            };
            chain.handle_write(pkt).unwrap();
            chain.poll_write();
        }

        for i in 0..7u16 {
            let pkt = TaggedPacket {
                now: base_time,
                transport: Default::default(),
                message: Packet::Rtp(rtp::Packet {
                    header: rtp::header::Header {
                        ssrc: 222222,
                        sequence_number: i,
                        timestamp: i as u32 * 960,
                        ..Default::default()
                    },
                    payload: vec![0u8; 200].into(),
                    ..Default::default()
                }),
            };
            chain.handle_write(pkt).unwrap();
            chain.poll_write();
        }

        let later_time = base_time + Duration::from_secs(2);
        chain.handle_timeout(later_time).unwrap();

        // Reports come out in HashMap iteration order, so collect them and look up by SSRC.
        let mut ssrcs = vec![];
        let mut packet_counts = vec![];
        let mut octet_counts = vec![];

        while let Some(tagged) = chain.poll_write() {
            if let Packet::Rtcp(rtcp_packets) = tagged.message {
                for rtcp_pkt in rtcp_packets {
                    if let Some(sr) = rtcp_pkt
                        .as_any()
                        .downcast_ref::<rtcp::sender_report::SenderReport>()
                    {
                        ssrcs.push(sr.ssrc);
                        packet_counts.push(sr.packet_count);
                        octet_counts.push(sr.octet_count);
                    }
                }
            }
        }

        assert_eq!(ssrcs.len(), 2);
        assert!(ssrcs.contains(&111111));
        assert!(ssrcs.contains(&222222));

        let idx1 = ssrcs.iter().position(|&s| s == 111111).unwrap();
        assert_eq!(packet_counts[idx1], 3);
        assert_eq!(octet_counts[idx1], 150);

        let idx2 = ssrcs.iter().position(|&s| s == 222222).unwrap();
        assert_eq!(packet_counts[idx2], 7);
        assert_eq!(octet_counts[idx2], 1400);
    }

    /// Once a stream is unbound, a report round queues nothing for it.
    #[test]
    fn test_sender_report_unbind_stream() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();

        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_local_stream(&info);

        let base_time = Instant::now();

        let pkt = TaggedPacket {
            now: base_time,
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc: 123456,
                    sequence_number: 0,
                    timestamp: 0,
                    ..Default::default()
                },
                payload: vec![0u8; 100].into(),
                ..Default::default()
            }),
        };
        chain.handle_write(pkt).unwrap();
        chain.poll_write();

        chain.unbind_local_stream(&info);

        let later_time = base_time + Duration::from_secs(2);
        chain.handle_timeout(later_time).unwrap();

        assert!(chain.poll_write().is_none());
    }

    /// NOTE(review): despite its name, this only checks that a deadline is reported at all;
    /// it does not verify that the earlier of the inner and own deadlines is chosen.
    #[test]
    fn test_poll_timeout_returns_earliest() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(5))
                    .build(),
            )
            .build();

        let timeout = chain.poll_timeout();
        assert!(timeout.is_some());
    }
}