use super::sender_stream::SenderStream;
use crate::stream_info::StreamInfo;
use crate::{Interceptor, Packet, TaggedPacket, interceptor};
use rtcp::header::PacketType;
use shared::TransportContext;
use shared::error::Error;
use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;
use std::time::{Duration, Instant};
/// Builder for [`SenderReportInterceptor`].
///
/// Holds the report interval and the `use_latest_packet` flag until
/// `build` turns them into a closure that wraps an inner interceptor.
pub struct SenderReportBuilder<P> {
    /// Value forwarded to every `SenderStream` the interceptor creates.
    use_latest_packet: bool,
    /// Time between two rounds of sender reports.
    interval: Duration,
    /// Ties the builder to the inner interceptor type `P` without storing one.
    _phantom: PhantomData<P>,
}

impl<P> Default for SenderReportBuilder<P> {
    /// A one-second report interval with `use_latest_packet` disabled.
    fn default() -> Self {
        let interval = Duration::from_secs(1);
        Self {
            _phantom: PhantomData,
            use_latest_packet: false,
            interval,
        }
    }
}
impl<P> SenderReportBuilder<P> {
    /// Creates a builder with the default settings (1 s interval,
    /// `use_latest_packet` disabled). Equivalent to `Default::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets how often a round of sender reports is generated for every bound
    /// local stream.
    ///
    /// The interval should be non-zero: with `Duration::ZERO` the next report
    /// deadline is set to the very instant that fired it, so every subsequent
    /// `handle_timeout` at that instant emits another round of reports.
    #[must_use]
    pub fn with_interval(mut self, interval: Duration) -> Self {
        self.interval = interval;
        self
    }

    /// Enables the `use_latest_packet` option, which is passed to
    /// `SenderStream::new` for every stream bound afterwards.
    ///
    /// NOTE(review): presumably this makes reports reflect the most recently
    /// written RTP packet even if it appears out of order; confirm against
    /// `SenderStream`.
    #[must_use]
    pub fn with_use_latest_packet(mut self) -> Self {
        self.use_latest_packet = true;
        self
    }

    /// Consumes the builder and returns a constructor that wraps an inner
    /// interceptor `P` in a [`SenderReportInterceptor`] using these settings.
    #[must_use]
    pub fn build(self) -> impl FnOnce(P) -> SenderReportInterceptor<P> {
        move |inner| SenderReportInterceptor::new(inner, self.interval, self.use_latest_packet)
    }
}
/// Interceptor that periodically emits RTCP sender reports for locally bound
/// streams.
///
/// Outgoing RTP whose SSRC was registered with `bind_local_stream` is fed to
/// that SSRC's `SenderStream`. Once `handle_timeout` reaches the report
/// deadline, one sender report per bound stream is queued and handed out by
/// `poll_write`.
#[derive(Interceptor)]
pub struct SenderReportInterceptor<P> {
    // The wrapped (next) interceptor in the chain that this one delegates to.
    #[next]
    inner: P,
    /// Time between two rounds of sender reports.
    interval: Duration,
    /// Deadline at which `handle_timeout` next generates reports. It is
    /// initialised to the construction instant, so the first timeout fires
    /// immediately.
    eto: Instant,
    /// Flag forwarded to `SenderStream::new` for every bound stream.
    use_latest_packet: bool,
    /// Per-SSRC sender statistics for streams registered via `bind_local_stream`.
    streams: HashMap<u32, SenderStream>,
    /// Inbound packet queue. Not used by the methods in this file;
    /// NOTE(review): presumably required by macro-generated code. Confirm.
    read_queue: VecDeque<TaggedPacket>,
    /// Generated sender reports waiting to be returned by `poll_write`.
    write_queue: VecDeque<TaggedPacket>,
}
impl<P> SenderReportInterceptor<P> {
    /// Wraps `inner`, generating a round of sender reports every `interval`.
    ///
    /// The first deadline is the construction instant, so the first call to
    /// `handle_timeout` produces reports straight away.
    fn new(inner: P, interval: Duration, use_latest_packet: bool) -> Self {
        let first_deadline = Instant::now();
        Self {
            read_queue: VecDeque::new(),
            write_queue: VecDeque::new(),
            streams: HashMap::new(),
            use_latest_packet,
            interval,
            eto: first_deadline,
            inner,
        }
    }

    /// Returns `true` for RTCP receiver reports and transport-specific
    /// feedback, and `false` for every other packet type.
    fn should_filter(packet_type: PacketType) -> bool {
        matches!(
            packet_type,
            PacketType::ReceiverReport | PacketType::TransportSpecificFeedback
        )
    }

    /// Shared access to the wrapped interceptor.
    fn inner(&self) -> &P {
        let next = &self.inner;
        next
    }

    /// Exclusive access to the wrapped interceptor.
    fn inner_mut(&mut self) -> &mut P {
        let next = &mut self.inner;
        next
    }
}
#[interceptor]
impl<P: Interceptor> SenderReportInterceptor<P> {
    // Records outgoing RTP on bound local streams, then passes the packet on.
    // Non-RTP packets and RTP for unbound SSRCs go through untouched.
    #[overrides]
    fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
        if let Packet::Rtp(rtp) = &msg.message {
            let ssrc = rtp.header.ssrc;
            if let Some(sender) = self.streams.get_mut(&ssrc) {
                sender.process_rtp(msg.now, rtp);
            }
        }
        self.inner.handle_write(msg)
    }

    // Hands out queued sender reports first; only when none are pending does
    // the inner interceptor get polled.
    #[overrides]
    fn poll_write(&mut self) -> Option<Self::Wout> {
        match self.write_queue.pop_front() {
            Some(report) => Some(report),
            None => self.inner.poll_write(),
        }
    }

    // When the report deadline has been reached, queues one sender report per
    // bound stream and schedules the next round one `interval` after `now`
    // (not after the missed deadline). The timeout is always forwarded.
    #[overrides]
    fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
        let due = now >= self.eto;
        if due {
            self.eto = now + self.interval;
            let reports = self.streams.values_mut().map(|sender| TaggedPacket {
                now,
                transport: TransportContext::default(),
                message: Packet::Rtcp(vec![Box::new(sender.generate_report(now))]),
            });
            self.write_queue.extend(reports);
        }
        self.inner.handle_timeout(now)
    }

    // Reports whichever deadline comes first: the inner interceptor's or the
    // next sender-report round. Never returns `None`.
    #[overrides]
    fn poll_timeout(&mut self) -> Option<Self::Time> {
        let report_deadline = self.eto;
        let earliest = match self.inner.poll_timeout() {
            Some(inner_deadline) if inner_deadline < report_deadline => inner_deadline,
            _ => report_deadline,
        };
        Some(earliest)
    }

    // Starts tracking statistics for a newly bound local stream, replacing any
    // previous tracker for the same SSRC, then notifies the inner interceptor.
    #[overrides]
    fn bind_local_stream(&mut self, info: &StreamInfo) {
        self.streams.insert(
            info.ssrc,
            SenderStream::new(info.ssrc, info.clock_rate, self.use_latest_packet),
        );
        self.inner.bind_local_stream(info);
    }

    // Stops generating reports for the stream, then notifies the inner
    // interceptor.
    #[overrides]
    fn unbind_local_stream(&mut self, info: &StreamInfo) {
        let _removed = self.streams.remove(&info.ssrc);
        self.inner.unbind_local_stream(info);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{NoopInterceptor, Registry};
    use sansio::Protocol;

    // RTP packet with default header and payload, stamped with the current instant.
    fn dummy_rtp_packet() -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: crate::Packet::Rtp(rtp::Packet::default()),
        }
    }

    #[test]
    fn test_sender_report_builder_default() {
        let chain = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();
        assert_eq!(chain.interval, Duration::from_secs(1));
    }

    #[test]
    fn test_sender_report_builder_with_custom_interval() {
        let chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_millis(500))
                    .build(),
            )
            .build();
        assert_eq!(chain.interval, Duration::from_millis(500));
    }

    #[test]
    fn test_sender_report_chain_handle_read_write() {
        let mut chain = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();
        let pkt = dummy_rtp_packet();
        chain.handle_read(pkt).unwrap();
        assert!(chain.poll_read().is_some());
        let pkt2 = dummy_rtp_packet();
        let pkt2_message = pkt2.message.clone();
        chain.handle_write(pkt2).unwrap();
        assert_eq!(chain.poll_write().unwrap().message, pkt2_message);
    }

    #[test]
    fn test_should_filter() {
        assert!(SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::ReceiverReport
        ));
        assert!(SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::TransportSpecificFeedback
        ));
        assert!(!SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::SenderReport
        ));
        assert!(!SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::SourceDescription
        ));
        assert!(!SenderReportInterceptor::<NoopInterceptor>::should_filter(
            PacketType::Goodbye
        ));
    }

    #[test]
    fn test_inner_access() {
        let mut chain = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();
        let _ = chain.inner();
        let pkt = dummy_rtp_packet();
        let pkt_message = pkt.message.clone();
        chain.inner_mut().handle_write(pkt).unwrap();
        assert_eq!(chain.inner_mut().poll_write().unwrap().message, pkt_message);
    }

    #[test]
    fn test_use_latest_packet_option() {
        let chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_use_latest_packet()
                    .build(),
            )
            .build();
        assert!(chain.use_latest_packet);
        let chain_default = Registry::new()
            .with(SenderReportBuilder::default().build())
            .build();
        assert!(!chain_default.use_latest_packet);
    }

    #[test]
    fn test_use_latest_packet_combined_options() {
        let chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_millis(250))
                    .with_use_latest_packet()
                    .build(),
            )
            .build();
        assert_eq!(chain.interval, Duration::from_millis(250));
        assert!(chain.use_latest_packet);
    }

    #[test]
    fn test_sender_report_generation_on_timeout() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();
        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_local_stream(&info);
        let base_time = Instant::now();
        for i in 0..5u16 {
            let pkt = TaggedPacket {
                now: base_time,
                transport: Default::default(),
                message: Packet::Rtp(rtp::Packet {
                    header: rtp::header::Header {
                        ssrc: 123456,
                        sequence_number: i,
                        timestamp: i as u32 * 3000,
                        ..Default::default()
                    },
                    payload: vec![0u8; 100].into(),
                    ..Default::default()
                }),
            };
            chain.handle_write(pkt).unwrap();
            chain.poll_write();
        }
        // The first deadline is the construction instant, so this round fires
        // immediately; drain it before checking the next one.
        chain.handle_timeout(base_time).unwrap();
        while chain.poll_write().is_some() {}
        let later_time = base_time + Duration::from_secs(2);
        chain.handle_timeout(later_time).unwrap();
        let report = chain.poll_write();
        assert!(report.is_some());
        if let Some(tagged) = report {
            if let Packet::Rtcp(rtcp_packets) = tagged.message {
                assert_eq!(rtcp_packets.len(), 1);
                let sr = rtcp_packets[0]
                    .as_any()
                    .downcast_ref::<rtcp::sender_report::SenderReport>()
                    .expect("Expected SenderReport");
                assert_eq!(sr.ssrc, 123456);
                assert_eq!(sr.packet_count, 5);
                assert_eq!(sr.octet_count, 500);
            } else {
                panic!("Expected RTCP packet");
            }
        }
    }

    #[test]
    fn test_sender_report_multiple_streams() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();
        let info1 = StreamInfo {
            ssrc: 111111,
            clock_rate: 90000,
            ..Default::default()
        };
        let info2 = StreamInfo {
            ssrc: 222222,
            clock_rate: 48000,
            ..Default::default()
        };
        chain.bind_local_stream(&info1);
        chain.bind_local_stream(&info2);
        let base_time = Instant::now();
        for i in 0..3u16 {
            let pkt = TaggedPacket {
                now: base_time,
                transport: Default::default(),
                message: Packet::Rtp(rtp::Packet {
                    header: rtp::header::Header {
                        ssrc: 111111,
                        sequence_number: i,
                        timestamp: i as u32 * 3000,
                        ..Default::default()
                    },
                    payload: vec![0u8; 50].into(),
                    ..Default::default()
                }),
            };
            chain.handle_write(pkt).unwrap();
            chain.poll_write();
        }
        for i in 0..7u16 {
            let pkt = TaggedPacket {
                now: base_time,
                transport: Default::default(),
                message: Packet::Rtp(rtp::Packet {
                    header: rtp::header::Header {
                        ssrc: 222222,
                        sequence_number: i,
                        timestamp: i as u32 * 960,
                        ..Default::default()
                    },
                    payload: vec![0u8; 200].into(),
                    ..Default::default()
                }),
            };
            chain.handle_write(pkt).unwrap();
            chain.poll_write();
        }
        let later_time = base_time + Duration::from_secs(2);
        chain.handle_timeout(later_time).unwrap();
        let mut ssrcs = vec![];
        let mut packet_counts = vec![];
        let mut octet_counts = vec![];
        while let Some(tagged) = chain.poll_write() {
            if let Packet::Rtcp(rtcp_packets) = tagged.message {
                for rtcp_pkt in rtcp_packets {
                    if let Some(sr) = rtcp_pkt
                        .as_any()
                        .downcast_ref::<rtcp::sender_report::SenderReport>()
                    {
                        ssrcs.push(sr.ssrc);
                        packet_counts.push(sr.packet_count);
                        octet_counts.push(sr.octet_count);
                    }
                }
            }
        }
        // Report order follows HashMap iteration order, so look each SSRC up.
        assert_eq!(ssrcs.len(), 2);
        assert!(ssrcs.contains(&111111));
        assert!(ssrcs.contains(&222222));
        let idx1 = ssrcs.iter().position(|&s| s == 111111).unwrap();
        assert_eq!(packet_counts[idx1], 3);
        assert_eq!(octet_counts[idx1], 150);
        let idx2 = ssrcs.iter().position(|&s| s == 222222).unwrap();
        assert_eq!(packet_counts[idx2], 7);
        assert_eq!(octet_counts[idx2], 1400);
    }

    #[test]
    fn test_sender_report_unbind_stream() {
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(Duration::from_secs(1))
                    .build(),
            )
            .build();
        let info = StreamInfo {
            ssrc: 123456,
            clock_rate: 90000,
            ..Default::default()
        };
        chain.bind_local_stream(&info);
        let base_time = Instant::now();
        let pkt = TaggedPacket {
            now: base_time,
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc: 123456,
                    sequence_number: 0,
                    timestamp: 0,
                    ..Default::default()
                },
                payload: vec![0u8; 100].into(),
                ..Default::default()
            }),
        };
        chain.handle_write(pkt).unwrap();
        chain.poll_write();
        chain.unbind_local_stream(&info);
        let later_time = base_time + Duration::from_secs(2);
        chain.handle_timeout(later_time).unwrap();
        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_poll_timeout_returns_earliest() {
        let interval = Duration::from_secs(5);
        let mut chain = Registry::new()
            .with(
                SenderReportBuilder::default()
                    .with_interval(interval)
                    .build(),
            )
            .build();
        // The no-op inner interceptor is expected to have no deadline of its
        // own, so the earliest deadline is the sender-report deadline.
        let first = chain
            .poll_timeout()
            .expect("sender report interceptor always reports a deadline");

        // Firing at the deadline schedules the next round one interval later.
        chain.handle_timeout(first).unwrap();
        assert_eq!(chain.poll_timeout(), Some(first + interval));

        // A timeout before the new deadline must not move it.
        chain
            .handle_timeout(first + Duration::from_secs(1))
            .unwrap();
        assert_eq!(chain.poll_timeout(), Some(first + interval));

        // Reaching the deadline advances it again.
        chain.handle_timeout(first + interval).unwrap();
        assert_eq!(chain.poll_timeout(), Some(first + interval + interval));
    }
}