mod responder_stream;
#[cfg(test)]
mod responder_test;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use responder_stream::ResponderStream;
use rtcp::transport_feedbacks::transport_layer_nack::TransportLayerNack;
use tokio::sync::Mutex;
use crate::error::Result;
use crate::nack::stream_support_nack;
use crate::stream_info::StreamInfo;
use crate::{
Attributes, Interceptor, InterceptorBuilder, RTCPReader, RTCPWriter, RTPReader, RTPWriter,
};
#[derive(Default)]
pub struct ResponderBuilder {
log2_size: Option<u8>,
}
impl ResponderBuilder {
pub fn with_log2_size(mut self, log2_size: u8) -> ResponderBuilder {
self.log2_size = Some(log2_size);
self
}
}
impl InterceptorBuilder for ResponderBuilder {
fn build(&self, _id: &str) -> Result<Arc<dyn Interceptor + Send + Sync>> {
Ok(Arc::new(Responder {
internal: Arc::new(ResponderInternal {
log2_size: if let Some(log2_size) = self.log2_size {
log2_size
} else {
13 },
streams: Arc::new(Mutex::new(HashMap::new())),
}),
}))
}
}
pub struct ResponderInternal {
log2_size: u8,
streams: Arc<Mutex<HashMap<u32, Arc<ResponderStream>>>>,
}
impl ResponderInternal {
async fn resend_packets(
streams: Arc<Mutex<HashMap<u32, Arc<ResponderStream>>>>,
nack: TransportLayerNack,
) {
let stream = {
let m = streams.lock().await;
if let Some(stream) = m.get(&nack.media_ssrc) {
stream.clone()
} else {
return;
}
};
for n in &nack.nacks {
let stream2 = Arc::clone(&stream);
let f = Box::new(
move |seq: u16| -> Pin<Box<dyn Future<Output = bool> + Send + 'static>> {
let stream3 = Arc::clone(&stream2);
Box::pin(async move {
if let Some(p) = stream3.get(seq).await {
let a = Attributes::new();
if let Err(err) = stream3.next_rtp_writer.write(&p, &a).await {
log::warn!("failed resending nacked packet: {}", err);
}
}
true
})
},
);
for packet_id in n.into_iter() {
if !f(packet_id).await {
return;
}
}
}
}
}
pub struct ResponderRtcpReader {
parent_rtcp_reader: Arc<dyn RTCPReader + Send + Sync>,
internal: Arc<ResponderInternal>,
}
#[async_trait]
impl RTCPReader for ResponderRtcpReader {
async fn read(
&self,
buf: &mut [u8],
a: &Attributes,
) -> Result<(Vec<Box<dyn rtcp::packet::Packet + Send + Sync>>, Attributes)> {
let (pkts, attr) = { self.parent_rtcp_reader.read(buf, a).await? };
for p in &pkts {
if let Some(nack) = p.as_any().downcast_ref::<TransportLayerNack>() {
let nack = nack.clone();
let streams = Arc::clone(&self.internal.streams);
tokio::spawn(async move {
ResponderInternal::resend_packets(streams, nack).await;
});
}
}
Ok((pkts, attr))
}
}
pub struct Responder {
internal: Arc<ResponderInternal>,
}
impl Responder {
pub fn builder() -> ResponderBuilder {
ResponderBuilder::default()
}
}
#[async_trait]
impl Interceptor for Responder {
async fn bind_rtcp_reader(
&self,
reader: Arc<dyn RTCPReader + Send + Sync>,
) -> Arc<dyn RTCPReader + Send + Sync> {
Arc::new(ResponderRtcpReader {
internal: Arc::clone(&self.internal),
parent_rtcp_reader: reader,
}) as Arc<dyn RTCPReader + Send + Sync>
}
async fn bind_rtcp_writer(
&self,
writer: Arc<dyn RTCPWriter + Send + Sync>,
) -> Arc<dyn RTCPWriter + Send + Sync> {
writer
}
async fn bind_local_stream(
&self,
info: &StreamInfo,
writer: Arc<dyn RTPWriter + Send + Sync>,
) -> Arc<dyn RTPWriter + Send + Sync> {
if !stream_support_nack(info) {
return writer;
}
let stream = Arc::new(ResponderStream::new(self.internal.log2_size, writer));
{
let mut streams = self.internal.streams.lock().await;
streams.insert(info.ssrc, Arc::clone(&stream));
}
stream
}
async fn unbind_local_stream(&self, info: &StreamInfo) {
let mut streams = self.internal.streams.lock().await;
streams.remove(&info.ssrc);
}
async fn bind_remote_stream(
&self,
_info: &StreamInfo,
reader: Arc<dyn RTPReader + Send + Sync>,
) -> Arc<dyn RTPReader + Send + Sync> {
reader
}
async fn unbind_remote_stream(&self, _info: &StreamInfo) {}
async fn close(&self) -> Result<()> {
Ok(())
}
}