use super::recorder::Recorder;
use super::stream_supports_twcc;
use crate::stream_info::StreamInfo;
use crate::{Interceptor, Packet, TaggedPacket, interceptor};
use shared::TransportContext;
use shared::error::Error;
use shared::marshal::Unmarshal;
use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;
use std::time::{Duration, Instant};
/// Default period between TWCC feedback reports queued by [`TwccReceiverInterceptor`]
/// when the builder's interval is not overridden.
const DEFAULT_INTERVAL: Duration = Duration::from_millis(100);
/// Builder for [`TwccReceiverInterceptor`].
///
/// The interceptor records arrival times of incoming RTP packets that carry the
/// transport-wide sequence number header extension and periodically queues RTCP
/// feedback built from them. The only tunable is how often that feedback is
/// generated; it defaults to 100 ms.
#[derive(Debug)]
pub struct TwccReceiverBuilder<P> {
    // Period between feedback reports.
    interval: Duration,
    // The builder is generic over the inner interceptor type it will wrap, but it
    // never stores one.
    _phantom: PhantomData<P>,
}
impl<P> Default for TwccReceiverBuilder<P> {
fn default() -> Self {
Self {
interval: DEFAULT_INTERVAL,
_phantom: PhantomData,
}
}
}
impl<P> TwccReceiverBuilder<P> {
    /// Creates a builder with the default feedback interval (100 ms).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets how often TWCC feedback is generated.
    ///
    /// The first deadline is one interval after the first RTP packet from a bound
    /// TWCC stream arrives. Each subsequent deadline is one interval after the
    /// timeout at which feedback was produced. A zero interval makes feedback due
    /// on every `handle_timeout` call, so a positive duration should be used.
    #[must_use]
    pub fn with_interval(mut self, interval: Duration) -> Self {
        self.interval = interval;
        self
    }

    /// Finalizes the configuration into a constructor closure that wraps an inner
    /// interceptor `P` in a [`TwccReceiverInterceptor`].
    #[must_use]
    pub fn build(self) -> impl FnOnce(P) -> TwccReceiverInterceptor<P> {
        let interval = self.interval;
        move |inner| TwccReceiverInterceptor::new(inner, interval)
    }
}
/// Per-SSRC state for a bound remote stream that negotiated transport-cc.
#[derive(Debug, Clone, Copy)]
struct RemoteStream {
    /// Negotiated RTP header-extension id carrying the transport-wide sequence number.
    hdr_ext_id: u8,
}
/// Receiver-side transport-wide congestion control (TWCC) interceptor.
///
/// For bound remote streams that negotiated the transport-cc header extension,
/// it records each incoming RTP packet's transport-wide sequence number and
/// arrival time. Once per `interval` it queues the RTCP feedback produced by its
/// recorder, to be returned from `poll_write` ahead of the inner interceptor's output.
#[derive(Interceptor)]
pub struct TwccReceiverInterceptor<P> {
    // Next interceptor in the chain (marked `#[next]` for the `Interceptor` derive).
    #[next]
    inner: P,
    // Period between feedback reports.
    interval: Duration,
    // Time of the first RTP packet seen from a bound stream; arrival times are
    // measured relative to it.
    start_time: Option<Instant>,
    // Created lazily on the first RTP packet from a bound stream.
    recorder: Option<Recorder>,
    // TWCC-enabled remote streams keyed by SSRC.
    streams: HashMap<u32, RemoteStream>,
    // RTCP feedback waiting to be returned from `poll_write`.
    write_queue: VecDeque<TaggedPacket>,
    // Deadline for the next feedback report; `None` until the first RTP packet
    // from a bound stream arrives.
    next_timeout: Option<Instant>,
}
impl<P> TwccReceiverInterceptor<P> {
    /// Wraps `inner` in an interceptor that emits feedback every `interval`.
    ///
    /// The recorder, arrival-time origin and feedback timer all start out empty;
    /// they are set up when the first packet from a bound stream is read.
    fn new(inner: P, interval: Duration) -> Self {
        Self {
            inner,
            interval,
            start_time: None,
            recorder: None,
            streams: HashMap::default(),
            write_queue: VecDeque::default(),
            next_timeout: None,
        }
    }

    /// Asks the recorder for feedback packets and appends each one to
    /// `write_queue` as its own single-packet RTCP message stamped with `now`.
    /// Does nothing if no recorder has been created yet.
    fn generate_feedback(&mut self, now: Instant) {
        if let Some(recorder) = &mut self.recorder {
            let feedback = recorder.build_feedback_packet();
            self.write_queue.extend(feedback.into_iter().map(|pkt| TaggedPacket {
                now,
                transport: TransportContext::default(),
                message: Packet::Rtcp(vec![pkt]),
            }));
        }
    }
}
#[interceptor]
impl<P: Interceptor> TwccReceiverInterceptor<P> {
    // Records the transport-wide sequence number and arrival time of RTP packets
    // on bound TWCC streams, then forwards every message to the inner interceptor.
    // The first RTP packet from a bound stream lazily creates the recorder (seeded
    // with `rand::random()`, presumably the feedback sender SSRC; confirm in
    // `Recorder::new`). That same packet fixes the arrival-time origin and arms
    // the feedback timer.
    #[overrides]
    fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
        if let Packet::Rtp(ref rtp_packet) = msg.message
            && let Some(stream) = self.streams.get(&rtp_packet.header.ssrc)
        {
            if self.recorder.is_none() {
                self.recorder = Some(Recorder::new(rand::random()));
                self.start_time = Some(msg.now);
                self.next_timeout = Some(msg.now + self.interval);
            }
            if let Some(ext_data) = rtp_packet.header.get_extension(stream.hdr_ext_id)
                && let Ok(tcc) =
                    rtp::extension::transport_cc_extension::TransportCcExtension::unmarshal(
                        &mut ext_data.as_ref(),
                    )
            {
                // Arrival time in microseconds since `start_time`.
                let arrival_time = self
                    .start_time
                    .map(|start| msg.now.duration_since(start).as_micros() as i64)
                    .unwrap_or(0);
                if let Some(recorder) = self.recorder.as_mut() {
                    recorder.record(rtp_packet.header.ssrc, tcc.transport_sequence, arrival_time);
                }
            }
        }
        self.inner.handle_read(msg)
    }

    // Queued feedback takes priority over whatever the inner interceptor wants to
    // write; the inner interceptor is only polled once the queue is empty.
    #[overrides]
    fn poll_write(&mut self) -> Option<Self::Wout> {
        self.write_queue
            .pop_front()
            .or_else(|| self.inner.poll_write())
    }

    // Generates feedback once the deadline has passed, then reschedules relative
    // to `now`, so a late tick does not trigger a burst of catch-up reports.
    #[overrides]
    fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
        if let Some(timeout) = self.next_timeout
            && now >= timeout
        {
            self.generate_feedback(now);
            self.next_timeout = Some(now + self.interval);
        }
        self.inner.handle_timeout(now)
    }

    // Earliest of our feedback deadline and the inner interceptor's deadline.
    #[overrides]
    fn poll_timeout(&mut self) -> Option<Self::Time> {
        let inner_timeout = self.inner.poll_timeout();
        match (self.next_timeout, inner_timeout) {
            (Some(ours), Some(theirs)) => Some(ours.min(theirs)),
            (ours, theirs) => ours.or(theirs),
        }
    }

    // Tracks the stream only if it negotiated transport-cc with a usable extension
    // id (0 is reserved for padding by RFC 8285). A re-bind that no longer
    // qualifies drops any stale entry, so packets stop being recorded under an
    // outdated extension id.
    #[overrides]
    fn bind_remote_stream(&mut self, info: &StreamInfo) {
        match stream_supports_twcc(info) {
            Some(hdr_ext_id) if hdr_ext_id != 0 => {
                self.streams.insert(info.ssrc, RemoteStream { hdr_ext_id });
            }
            _ => {
                self.streams.remove(&info.ssrc);
            }
        }
        self.inner.bind_remote_stream(info);
    }

    // Stops recording packets for the stream's SSRC.
    #[overrides]
    fn unbind_remote_stream(&mut self, info: &StreamInfo) {
        self.streams.remove(&info.ssrc);
        self.inner.unbind_remote_stream(info);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTPHeaderExtension;
    use sansio::Protocol;
    use shared::marshal::Marshal;

    /// SSRC of the remote stream used throughout the tests.
    const TEST_SSRC: u32 = 12345;
    /// Transport-cc extension id; must match the `id` in [`twcc_stream_info`].
    const TWCC_EXT_ID: u8 = 5;

    /// Builds an RTP packet carrying `twcc_seq` in a transport-cc header extension
    /// with id `hdr_ext_id`. Panics if the extension cannot be encoded, so tests
    /// never run silently against a packet without the extension.
    fn make_rtp_packet_with_twcc(
        ssrc: u32,
        seq: u16,
        twcc_seq: u16,
        hdr_ext_id: u8,
    ) -> rtp::Packet {
        let mut pkt = rtp::Packet {
            header: rtp::header::Header {
                ssrc,
                sequence_number: seq,
                ..Default::default()
            },
            payload: vec![].into(),
        };
        let tcc_ext = rtp::extension::transport_cc_extension::TransportCcExtension {
            transport_sequence: twcc_seq,
        };
        let ext_data = tcc_ext
            .marshal()
            .expect("transport-cc extension should marshal");
        pkt.header
            .set_extension(hdr_ext_id, ext_data.freeze())
            .expect("transport-cc extension should be settable on the header");
        pkt
    }

    /// `StreamInfo` for a remote stream that negotiated transport-cc with id `TWCC_EXT_ID`.
    fn twcc_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 5,
            }],
            ..Default::default()
        }
    }

    /// Wraps an RTP packet in a `TaggedPacket` received at `now`.
    fn tagged_rtp(now: Instant, rtp: rtp::Packet) -> TaggedPacket {
        TaggedPacket {
            now,
            transport: Default::default(),
            message: Packet::Rtp(rtp),
        }
    }

    #[test]
    fn test_twcc_receiver_builder_defaults() {
        let chain = Registry::new()
            .with(TwccReceiverBuilder::default().build())
            .build();
        assert_eq!(chain.interval, DEFAULT_INTERVAL);
        assert!(chain.recorder.is_none());
    }

    #[test]
    fn test_twcc_receiver_builder_custom_interval() {
        let chain = Registry::new()
            .with(
                TwccReceiverBuilder::new()
                    .with_interval(Duration::from_millis(50))
                    .build(),
            )
            .build();
        assert_eq!(chain.interval, Duration::from_millis(50));
    }

    #[test]
    fn test_twcc_receiver_records_packets() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        chain.bind_remote_stream(&twcc_stream_info(TEST_SSRC));

        let now = Instant::now();
        let rtp = make_rtp_packet_with_twcc(TEST_SSRC, 1, 0, TWCC_EXT_ID);
        chain.handle_read(tagged_rtp(now, rtp)).unwrap();

        assert!(chain.recorder.is_some());
        // The first packet fixes the arrival-time origin and arms the timer.
        assert_eq!(chain.start_time, Some(now));
        assert_eq!(chain.next_timeout, Some(now + DEFAULT_INTERVAL));
    }

    #[test]
    fn test_twcc_receiver_generates_feedback_on_timeout() {
        let interval = Duration::from_millis(100);
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().with_interval(interval).build())
            .build();
        chain.bind_remote_stream(&twcc_stream_info(TEST_SSRC));

        let start = Instant::now();
        for i in 0..5u16 {
            let now = start + Duration::from_millis(u64::from(i) * 10);
            let rtp = make_rtp_packet_with_twcc(TEST_SSRC, i, i, TWCC_EXT_ID);
            chain.handle_read(tagged_rtp(now, rtp)).unwrap();
        }

        // A tick before the deadline must not produce feedback.
        chain
            .handle_timeout(start + Duration::from_millis(50))
            .unwrap();
        assert!(chain.write_queue.is_empty());

        let timeout_time = start + Duration::from_millis(150);
        chain.handle_timeout(timeout_time).unwrap();
        // The next report is scheduled one interval after the tick that fired.
        assert_eq!(chain.next_timeout, Some(timeout_time + interval));

        let tagged = chain
            .poll_write()
            .expect("feedback should be queued once the interval has elapsed");
        match tagged.message {
            Packet::Rtcp(rtcp_packets) => assert!(!rtcp_packets.is_empty()),
            _ => panic!("Expected RTCP packet"),
        }
    }

    #[test]
    fn test_twcc_receiver_no_feedback_without_binding() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        let rtp = make_rtp_packet_with_twcc(TEST_SSRC, 1, 0, TWCC_EXT_ID);
        chain.handle_read(tagged_rtp(Instant::now(), rtp)).unwrap();
        assert!(chain.recorder.is_none());
        assert!(chain.next_timeout.is_none());
    }

    #[test]
    fn test_twcc_receiver_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        let info = twcc_stream_info(TEST_SSRC);
        chain.bind_remote_stream(&info);
        assert!(chain.streams.contains_key(&TEST_SSRC));
        chain.unbind_remote_stream(&info);
        assert!(!chain.streams.contains_key(&TEST_SSRC));
    }

    #[test]
    fn test_twcc_receiver_poll_timeout() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        assert!(chain.poll_timeout().is_none());

        chain.bind_remote_stream(&twcc_stream_info(TEST_SSRC));
        let now = Instant::now();
        let rtp = make_rtp_packet_with_twcc(TEST_SSRC, 1, 0, TWCC_EXT_ID);
        chain.handle_read(tagged_rtp(now, rtp)).unwrap();

        // Our own deadline is `now + DEFAULT_INTERVAL`; the reported timeout is
        // the earliest of ours and the inner chain's, so it can be no later.
        let timeout = chain.poll_timeout();
        assert!(timeout.is_some_and(|t| t <= now + DEFAULT_INTERVAL));
    }
}