use super::send_buffer::SendBuffer;
use super::stream_supports_nack;
use crate::stream_info::StreamInfo;
use crate::{Interceptor, Packet, TaggedPacket, interceptor};
use shared::TransportContext;
use shared::error::Error;
use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;
use std::time::Instant;
/// Send-buffer capacity used when [`NackResponderBuilder::with_size`] is not called.
const DEFAULT_SEND_BUFFER_SIZE: u16 = 1024;

/// Builder for [`NackResponderInterceptor`].
///
/// `size` is the per-stream send-buffer capacity. It is handed to `SendBuffer::new`
/// each time a NACK-capable local stream is bound; if that constructor returns `None`
/// the stream is not tracked and NACKs for it are ignored.
/// NOTE(review): confirm which sizes `SendBuffer::new` accepts before exposing
/// `with_size` to untrusted configuration.
pub struct NackResponderBuilder<P> {
    size: u16,
    _phantom: PhantomData<P>,
}

impl<P> Default for NackResponderBuilder<P> {
    fn default() -> Self {
        Self {
            size: DEFAULT_SEND_BUFFER_SIZE,
            _phantom: PhantomData,
        }
    }
}

impl<P> NackResponderBuilder<P> {
    /// Creates a builder with the default send-buffer size (1024).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets how many recently sent packets each NACK-capable stream retains.
    #[must_use]
    pub fn with_size(mut self, size: u16) -> Self {
        self.size = size;
        self
    }

    /// Returns a constructor that wraps the next interceptor in the chain.
    #[must_use]
    pub fn build(self) -> impl FnOnce(P) -> NackResponderInterceptor<P> {
        move |inner| NackResponderInterceptor::new(inner, self.size)
    }
}
/// Per-SSRC state for a bound local stream that advertised NACK feedback.
struct LocalStream {
    /// Recently written RTP packets, looked up by sequence number when a NACK arrives.
    send_buffer: SendBuffer,
    /// RTX SSRC copied from `StreamInfo::ssrc_rtx` at bind time; RTX is used only
    /// when this and `payload_type_rtx` are both `Some`.
    ssrc_rtx: Option<u32>,
    /// RTX payload type copied from `StreamInfo::payload_type_rtx` at bind time.
    payload_type_rtx: Option<u8>,
    /// Sequence number for the next RTX packet; starts at 0 and wraps on overflow.
    rtx_sequence_number: u16,
}
/// Interceptor that answers inbound RTCP transport-layer NACKs by re-sending buffered
/// RTP packets, optionally re-encoded as RTX packets when the stream has RTX configured.
#[derive(Interceptor)]
pub struct NackResponderInterceptor<P> {
    /// Next interceptor in the chain (marked `#[next]` for the `Interceptor` derive).
    #[next]
    inner: P,
    /// Capacity passed to `SendBuffer::new` for each NACK-capable stream.
    size: u16,
    /// State for tracked local streams, keyed by media SSRC.
    streams: HashMap<u32, LocalStream>,
    /// Retransmissions waiting to be returned by `poll_write`, oldest first.
    write_queue: VecDeque<TaggedPacket>,
}
impl<P> NackResponderInterceptor<P> {
    /// Creates a responder wrapping `inner`; `size` is the capacity requested from
    /// `SendBuffer::new` for every NACK-capable stream bound later.
    fn new(inner: P, size: u16) -> Self {
        Self {
            inner,
            size,
            streams: HashMap::new(),
            write_queue: VecDeque::new(),
        }
    }

    /// Queues a retransmission for every sequence number requested by `nack`.
    ///
    /// Each NACK pair requests `packet_id` plus `packet_id + i + 1` for every bit `i`
    /// set in `lost_packets` (sequence arithmetic wraps). NACKs for untracked media
    /// SSRCs are ignored, and sequence numbers no longer in the send buffer are skipped.
    ///
    /// When the stream has both an RTX SSRC and an RTX payload type, each packet is
    /// re-encoded as an RFC 4588 retransmission:
    /// - the original header is copied;
    /// - SSRC, payload type and sequence number are replaced with the RTX values;
    /// - padding is cleared;
    /// - the payload is prefixed with the big-endian original sequence number.
    ///
    /// Otherwise the buffered packet is re-sent unchanged.
    fn handle_nack(
        &mut self,
        now: Instant,
        nack: &rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack,
    ) {
        // Resolve the stream first so NACKs for untracked SSRCs do no further work.
        let Some(stream) = self.streams.get_mut(&nack.media_ssrc) else {
            return;
        };

        // Lazily expand every (packet_id, bitmask) pair, preserving request order.
        let requested = nack.nacks.iter().flat_map(|pair| {
            let lost = (0..16u16)
                .filter(move |&bit| pair.lost_packets & (1u16 << bit) != 0)
                .map(move |bit| pair.packet_id.wrapping_add(bit + 1));
            std::iter::once(pair.packet_id).chain(lost)
        });

        for seq in requested {
            let Some(original_packet) = stream.send_buffer.get(seq) else {
                continue;
            };
            let packet = if let (Some(ssrc_rtx), Some(pt_rtx)) =
                (stream.ssrc_rtx, stream.payload_type_rtx)
            {
                // RFC 4588 §4: RTX payload = 2-byte original seq + original payload.
                let original_seq = original_packet.header.sequence_number;
                let mut rtx_payload = Vec::with_capacity(2 + original_packet.payload.len());
                rtx_payload.extend_from_slice(&original_seq.to_be_bytes());
                rtx_payload.extend_from_slice(&original_packet.payload);
                let rtx_seq = stream.rtx_sequence_number;
                stream.rtx_sequence_number = rtx_seq.wrapping_add(1);
                rtp::Packet {
                    header: rtp::header::Header {
                        ssrc: ssrc_rtx,
                        payload_type: pt_rtx,
                        sequence_number: rtx_seq,
                        // RFC 4588 §4: original padding MUST be removed; the RTX
                        // payload above is rebuilt without it.
                        padding: false,
                        // Copy the rest of the original header (version, timestamp,
                        // marker, CSRCs, extensions). The previous `..Default::default()`
                        // dropped extensions and presumably produced version 0 instead
                        // of RTP version 2.
                        ..original_packet.header.clone()
                    },
                    payload: rtx_payload.into(),
                }
            } else {
                original_packet.clone()
            };
            // NOTE(review): retransmissions carry a default TransportContext rather than
            // the original packet's; confirm the caller fills in the destination.
            self.write_queue.push_back(TaggedPacket {
                now,
                transport: TransportContext::default(),
                message: Packet::Rtp(packet),
            });
        }
    }
}
#[interceptor]
impl<P: Interceptor> NackResponderInterceptor<P> {
    // Inbound path: every transport-layer NACK inside an RTCP message is answered via
    // `handle_nack`; the message itself is then forwarded untouched.
    #[overrides]
    fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
        if let Packet::Rtcp(compound) = &msg.message {
            let received_at = msg.now;
            compound
                .iter()
                .filter_map(|report| {
                    report
                        .as_any()
                        .downcast_ref::<rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack>()
                })
                .for_each(|nack| self.handle_nack(received_at, nack));
        }
        self.inner.handle_read(msg)
    }

    // Outbound path: RTP for a tracked SSRC is copied into that stream's send buffer
    // so later NACKs can be served; the packet is forwarded either way.
    #[overrides]
    fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
        if let Packet::Rtp(outgoing) = &msg.message
            && let Some(local) = self.streams.get_mut(&outgoing.header.ssrc)
        {
            local.send_buffer.add(outgoing.clone());
        }
        self.inner.handle_write(msg)
    }

    // Pending retransmissions take priority over whatever the inner interceptor emits.
    #[overrides]
    fn poll_write(&mut self) -> Option<Self::Wout> {
        self.write_queue
            .pop_front()
            .or_else(|| self.inner.poll_write())
    }

    // Starts tracking a stream only if it advertises NACK support and a send buffer
    // of the configured size can be created; the bind is always forwarded.
    #[overrides]
    fn bind_local_stream(&mut self, info: &StreamInfo) {
        let buffer = if stream_supports_nack(info) {
            SendBuffer::new(self.size)
        } else {
            None
        };
        if let Some(send_buffer) = buffer {
            let state = LocalStream {
                send_buffer,
                ssrc_rtx: info.ssrc_rtx,
                payload_type_rtx: info.payload_type_rtx,
                rtx_sequence_number: 0,
            };
            self.streams.insert(info.ssrc, state);
        }
        self.inner.bind_local_stream(info);
    }

    // Drops the stream's state (and its buffered packets) before forwarding the unbind.
    #[overrides]
    fn unbind_local_stream(&mut self, info: &StreamInfo) {
        let _ = self.streams.remove(&info.ssrc);
        self.inner.unbind_local_stream(info);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTCPFeedback;
    use rtcp::transport_feedbacks::transport_layer_nack::{NackPair, TransportLayerNack};
    use sansio::Protocol;

    /// Builds an outgoing RTP packet on `ssrc` with sequence number `seq`.
    fn make_rtp_packet(ssrc: u32, seq: u16, payload: &[u8]) -> TaggedPacket {
        let header = rtp::header::Header {
            ssrc,
            sequence_number: seq,
            ..Default::default()
        };
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header,
                payload: payload.to_vec().into(),
            }),
        }
    }

    /// Builds an RTCP transport-layer NACK; each `(packet_id, lost_packets)` tuple
    /// becomes one NACK pair.
    fn make_nack_packet(sender_ssrc: u32, media_ssrc: u32, nacks: Vec<(u16, u16)>) -> TaggedPacket {
        let nacks: Vec<NackPair> = nacks
            .into_iter()
            .map(|(packet_id, lost_packets)| NackPair {
                packet_id,
                lost_packets,
            })
            .collect();
        let nack = TransportLayerNack {
            sender_ssrc,
            media_ssrc,
            nacks,
        };
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtcp(vec![Box::new(nack)]),
        }
    }

    /// Returns `pkt` re-stamped with `now`.
    fn stamped(mut pkt: TaggedPacket, now: Instant) -> TaggedPacket {
        pkt.now = now;
        pkt
    }

    /// Stream description for `ssrc` that advertises plain NACK feedback at 90 kHz.
    fn nack_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            clock_rate: 90000,
            rtcp_feedback: vec![RTCPFeedback {
                typ: "nack".to_string(),
                parameter: "".to_string(),
            }],
            ..Default::default()
        }
    }

    #[test]
    fn test_nack_responder_builder_defaults() {
        let chain = Registry::new()
            .with(NackResponderBuilder::default().build())
            .build();
        assert_eq!(chain.size, 1024);
    }

    #[test]
    fn test_nack_responder_builder_custom() {
        let chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(2048).build())
            .build();
        assert_eq!(chain.size, 2048);
    }

    #[test]
    fn test_nack_responder_retransmits_packet() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        chain.bind_local_stream(&nack_stream_info(12345));

        let now = Instant::now();
        // Sequence 13 is never sent, so it cannot be retransmitted.
        for seq in [10u16, 11, 12, 14, 15] {
            chain
                .handle_write(stamped(make_rtp_packet(12345, seq, &[seq as u8]), now))
                .unwrap();
            // Discard the pass-through copy so only retransmissions remain queued.
            chain.poll_write();
        }

        // packet_id 11 with bitmask 0b1011 requests 11, 12, 13 and 15.
        chain
            .handle_read(stamped(make_nack_packet(999, 12345, vec![(11, 0b1011)]), now))
            .unwrap();

        let retransmitted: Vec<u16> = std::iter::from_fn(|| chain.poll_write())
            .filter_map(|tagged| match tagged.message {
                Packet::Rtp(resent) => Some(resent.header.sequence_number),
                _ => None,
            })
            .collect();
        assert!(retransmitted.contains(&11));
        assert!(retransmitted.contains(&12));
        assert!(!retransmitted.contains(&13));
        assert!(retransmitted.contains(&15));
    }

    #[test]
    fn test_nack_responder_no_retransmit_without_binding() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let now = Instant::now();
        for seq in [10u16, 11, 12] {
            chain
                .handle_write(stamped(make_rtp_packet(12345, seq, &[seq as u8]), now))
                .unwrap();
            chain.poll_write();
        }

        chain
            .handle_read(stamped(make_nack_packet(999, 12345, vec![(11, 0)]), now))
            .unwrap();
        assert!(chain.poll_write().is_none());
    }

    #[test]
    fn test_nack_responder_no_retransmit_expired_packet() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        chain.bind_local_stream(&nack_stream_info(12345));

        let now = Instant::now();
        for seq in 0..16u16 {
            chain
                .handle_write(stamped(make_rtp_packet(12345, seq, &[seq as u8]), now))
                .unwrap();
            chain.poll_write();
        }

        // After 16 writes into an 8-slot buffer, seq 0 is no longer available...
        chain
            .handle_read(stamped(make_nack_packet(999, 12345, vec![(0, 0)]), now))
            .unwrap();
        assert!(chain.poll_write().is_none());

        // ...but a recent one such as seq 10 still is.
        chain
            .handle_read(stamped(make_nack_packet(999, 12345, vec![(10, 0)]), now))
            .unwrap();
        let resent = chain.poll_write();
        assert!(resent.is_some());
        if let Some(TaggedPacket {
            message: Packet::Rtp(p),
            ..
        }) = resent
        {
            assert_eq!(p.header.sequence_number, 10);
        }
    }

    #[test]
    fn test_nack_responder_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        let info = nack_stream_info(12345);

        chain.bind_local_stream(&info);
        assert!(chain.streams.contains_key(&12345));

        chain.unbind_local_stream(&info);
        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_nack_responder_no_nack_support() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        // Same stream as elsewhere, but with no RTCP feedback advertised.
        let info = StreamInfo {
            rtcp_feedback: vec![],
            ..nack_stream_info(12345)
        };

        chain.bind_local_stream(&info);
        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_nack_responder_passthrough() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();

        let now = Instant::now();
        chain
            .handle_write(stamped(make_rtp_packet(12345, 1, &[1]), now))
            .unwrap();
        assert!(chain.poll_write().is_some());

        chain
            .handle_read(stamped(make_nack_packet(999, 12345, vec![(1, 0)]), now))
            .unwrap();
        assert!(chain.poll_read().is_none());
    }

    #[test]
    fn test_nack_responder_rfc4588_rtx() {
        let mut chain = Registry::new()
            .with(NackResponderBuilder::new().with_size(8).build())
            .build();
        let info = StreamInfo {
            ssrc: 1,
            ssrc_rtx: Some(2),
            payload_type: 96,
            payload_type_rtx: Some(97),
            ..nack_stream_info(1)
        };
        chain.bind_local_stream(&info);

        let now = Instant::now();
        for seq in [10u16, 11, 12, 14, 15] {
            chain
                .handle_write(stamped(make_rtp_packet(1, seq, &[seq as u8]), now))
                .unwrap();
            chain.poll_write();
        }

        chain
            .handle_read(stamped(make_nack_packet(999, 1, vec![(11, 0b1011)]), now))
            .unwrap();

        // RTX sequence numbers start at 0 and advance by one per retransmission.
        for (rtx_seq, expected_original_seq) in (0u16..).zip([11u16, 12, 15]) {
            let Some(tagged) = chain.poll_write() else {
                panic!("Expected RTX packet for seq {}", expected_original_seq);
            };
            let Packet::Rtp(resent) = tagged.message else {
                panic!("Expected RTP packet");
            };

            assert_eq!(resent.header.ssrc, 2, "RTX packet should use RTX SSRC");
            assert_eq!(
                resent.header.payload_type, 97,
                "RTX packet should use RTX payload type"
            );
            assert_eq!(
                resent.header.sequence_number, rtx_seq,
                "RTX seq should be {}",
                rtx_seq
            );
            assert!(
                resent.payload.len() >= 2,
                "RTX payload should have at least 2 bytes"
            );
            let original_seq_from_payload =
                u16::from_be_bytes([resent.payload[0], resent.payload[1]]);
            assert_eq!(
                original_seq_from_payload, expected_original_seq,
                "RTX payload should contain original seq"
            );
            assert_eq!(
                resent.payload[2..],
                [expected_original_seq as u8],
                "Original payload should follow seq number"
            );
        }
        assert!(chain.poll_write().is_none());
    }
}