use super::stream_supports_twcc;
use crate::stream_info::StreamInfo;
use crate::{Interceptor, Packet, TaggedPacket, interceptor};
use shared::error::Error;
use shared::marshal::Marshal;
use std::collections::HashMap;
use std::marker::PhantomData;
/// Builder for [`TwccSenderInterceptor`].
///
/// The builder carries no configuration of its own; `P` (the type of the next
/// interceptor in the chain) is only tracked at the type level.
pub struct TwccSenderBuilder<P> {
    _marker: PhantomData<P>,
}

impl<P> Default for TwccSenderBuilder<P> {
    // Written by hand: `#[derive(Default)]` would add an unnecessary
    // `P: Default` bound.
    fn default() -> Self {
        TwccSenderBuilder {
            _marker: PhantomData,
        }
    }
}

impl<P> TwccSenderBuilder<P> {
    /// Creates a builder; identical to [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns a constructor that wraps the next interceptor `P` in a fresh
    /// [`TwccSenderInterceptor`] (no bound streams, sequence counter at 0).
    pub fn build(self) -> impl FnOnce(P) -> TwccSenderInterceptor<P> {
        TwccSenderInterceptor::new
    }
}
/// Per-SSRC state for a local stream that should be stamped with
/// transport-wide sequence numbers.
struct LocalStream {
    /// RTP header-extension ID under which the transport-cc extension is
    /// written for this stream (never 0 when inserted by `bind_local_stream`).
    hdr_ext_id: u8,
}
/// Sender-side transport-wide congestion control (TWCC) interceptor.
///
/// For every outgoing RTP packet whose SSRC has been bound with a
/// transport-cc header extension, it writes a transport-wide sequence number
/// into that extension before passing the packet on. A single counter is
/// shared by all bound streams.
#[derive(Interceptor)]
pub struct TwccSenderInterceptor<P> {
    /// Next interceptor in the chain; `handle_write` forwards every packet to it.
    #[next]
    inner: P,
    /// Bound local streams keyed by SSRC.
    streams: HashMap<u32, LocalStream>,
    /// Sequence number to stamp on the next packet; wraps around after `u16::MAX`.
    next_sequence_number: u16,
}
impl<P> TwccSenderInterceptor<P> {
    /// Wraps `inner` with an empty stream table and the transport-wide
    /// sequence counter at zero.
    fn new(inner: P) -> Self {
        let streams = HashMap::default();
        Self {
            streams,
            next_sequence_number: 0,
            inner,
        }
    }
}
#[interceptor]
impl<P: Interceptor> TwccSenderInterceptor<P> {
    // Stamps RTP packets of bound streams with the next transport-wide
    // sequence number, then forwards the packet to the inner interceptor.
    // Non-RTP packets and packets from unbound SSRCs pass through untouched.
    //
    // Stamping is best effort: if the extension cannot be marshalled or
    // written, the packet is still forwarded, just without a sequence number.
    // In that case the counter is NOT advanced. A number that never reaches
    // the wire would show up as a gap in the transport-wide sequence, and
    // TWCC feedback reports such gaps as missing packets.
    //
    // Errors: only those returned by the inner interceptor's `handle_write`.
    #[overrides]
    fn handle_write(&mut self, mut msg: TaggedPacket) -> Result<(), Self::Error> {
        if let Packet::Rtp(ref mut rtp_packet) = msg.message
            && let Some(stream) = self.streams.get(&rtp_packet.header.ssrc)
        {
            let tcc_ext = rtp::extension::transport_cc_extension::TransportCcExtension {
                transport_sequence: self.next_sequence_number,
            };
            let stamped = match tcc_ext.marshal() {
                Ok(ext_data) => rtp_packet
                    .header
                    .set_extension(stream.hdr_ext_id, ext_data.freeze())
                    .is_ok(),
                Err(_) => false,
            };
            if stamped {
                // One counter for all streams ("transport-wide"); wraps after u16::MAX.
                self.next_sequence_number = self.next_sequence_number.wrapping_add(1);
            }
        }
        self.inner.handle_write(msg)
    }

    // Registers `info.ssrc` for stamping when `stream_supports_twcc` yields a
    // non-zero header-extension ID, then forwards the binding to the inner
    // interceptor. ID 0 is reserved (RFC 8285) and cannot identify an
    // extension, so such streams are not stamped.
    #[overrides]
    fn bind_local_stream(&mut self, info: &StreamInfo) {
        if let Some(hdr_ext_id) = stream_supports_twcc(info)
            && hdr_ext_id != 0
        {
            self.streams.insert(info.ssrc, LocalStream { hdr_ext_id });
        }
        self.inner.bind_local_stream(info);
    }

    // Stops stamping packets for `info.ssrc` (a no-op if it was never bound)
    // and forwards the unbinding to the inner interceptor.
    #[overrides]
    fn unbind_local_stream(&mut self, info: &StreamInfo) {
        self.streams.remove(&info.ssrc);
        self.inner.unbind_local_stream(info);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTPHeaderExtension;
    use sansio::Protocol;
    use shared::marshal::Unmarshal;
    use std::time::Instant;

    /// Builds an outgoing RTP packet with the given SSRC and RTP sequence number.
    fn make_rtp_packet(ssrc: u32, seq: u16) -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number: seq,
                    ..Default::default()
                },
                payload: vec![].into(),
            }),
        }
    }

    /// `StreamInfo` for `ssrc` advertising transport-cc with extension ID 5.
    fn twcc_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 5,
            }],
            ..Default::default()
        }
    }

    /// Returns the transport-wide sequence number stamped on `out`.
    ///
    /// Panics if `out` is not an RTP packet or carries no extension with ID 5,
    /// so callers can never silently skip their assertions.
    fn transport_sequence_of(out: TaggedPacket) -> u16 {
        let Packet::Rtp(rtp_pkt) = out.message else {
            panic!("Expected RTP packet");
        };
        let ext = rtp_pkt
            .header
            .get_extension(5)
            .expect("transport-cc extension missing");
        rtp::extension::transport_cc_extension::TransportCcExtension::unmarshal(
            &mut ext.as_ref(),
        )
        .unwrap()
        .transport_sequence
    }

    #[test]
    fn test_twcc_sender_builder_defaults() {
        let chain = Registry::new()
            .with(TwccSenderBuilder::default().build())
            .build();
        assert!(chain.streams.is_empty());
    }

    #[test]
    fn test_twcc_sender_adds_extension() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(12345));

        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out1 = chain.poll_write().unwrap();
        chain.handle_write(make_rtp_packet(12345, 2)).unwrap();
        let out2 = chain.poll_write().unwrap();

        assert_eq!(transport_sequence_of(out1), 0);
        assert_eq!(transport_sequence_of(out2), 1);
    }

    #[test]
    fn test_twcc_sender_no_extension_without_binding() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out = chain.poll_write().unwrap();
        let Packet::Rtp(rtp_pkt) = out.message else {
            panic!("Expected RTP packet");
        };
        assert!(rtp_pkt.header.get_extension(5).is_none());
    }

    #[test]
    fn test_twcc_sender_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        let info = twcc_stream_info(12345);
        chain.bind_local_stream(&info);
        assert!(chain.streams.contains_key(&12345));
        chain.unbind_local_stream(&info);
        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_twcc_sender_ignores_zero_extension_id() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        let info = StreamInfo {
            ssrc: 12345,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 0,
            }],
            ..Default::default()
        };
        chain.bind_local_stream(&info);
        assert!(!chain.streams.contains_key(&12345));

        // Packets of the unregistered stream must not consume sequence numbers.
        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        chain.poll_write().unwrap();
        assert_eq!(chain.next_sequence_number, 0);
    }

    #[test]
    fn test_twcc_sender_sequence_wraparound() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(12345));
        chain.next_sequence_number = 65534;
        for expected_seq in [65534u16, 65535, 0, 1] {
            chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
            let out = chain.poll_write().unwrap();
            assert_eq!(transport_sequence_of(out), expected_seq);
        }
    }

    #[test]
    fn test_twcc_sender_multiple_streams_share_counter() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(1111));
        chain.bind_local_stream(&twcc_stream_info(2222));
        for (expected_seq, ssrc) in (0u16..).zip([1111u32, 2222, 1111, 2222]) {
            chain.handle_write(make_rtp_packet(ssrc, 1)).unwrap();
            let out = chain.poll_write().unwrap();
            assert_eq!(transport_sequence_of(out), expected_seq);
        }
    }
}