mod generator_stream;
#[cfg(test)]
mod generator_test;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use generator_stream::GeneratorStream;
use rtcp::transport_feedbacks::transport_layer_nack::{
nack_pairs_from_sequence_numbers, TransportLayerNack,
};
use tokio::sync::{mpsc, Mutex};
use waitgroup::WaitGroup;
use crate::error::{Error, Result};
use crate::nack::stream_support_nack;
use crate::stream_info::StreamInfo;
use crate::{
Attributes, Interceptor, InterceptorBuilder, RTCPReader, RTCPWriter, RTPReader, RTPWriter,
};
/// Builder for the NACK [`Generator`] interceptor.
///
/// Every option starts out unset (`None`); `InterceptorBuilder::build` substitutes a
/// default for each option that is still `None`.
#[derive(Default)]
pub struct GeneratorBuilder {
    /// Receive-log size parameter forwarded to `GeneratorStream::new` for each remote stream
    /// (by its name, presumably log2 of the log size minus 6 — confirm in `GeneratorStream`).
    log2_size_minus_6: Option<u8>,
    /// Forwarded to `missing_seq_numbers` on every generation tick.
    skip_last_n: Option<u16>,
    /// Period of the background NACK generation loop.
    interval: Option<Duration>,
}

impl GeneratorBuilder {
    /// Overrides the receive-log size parameter used for every bound remote stream.
    pub fn with_log2_size_minus_6(self, log2_size_minus_6: u8) -> Self {
        Self {
            log2_size_minus_6: Some(log2_size_minus_6),
            ..self
        }
    }

    /// Overrides the value handed to `missing_seq_numbers` when looking for gaps.
    pub fn with_skip_last_n(self, skip_last_n: u16) -> Self {
        Self {
            skip_last_n: Some(skip_last_n),
            ..self
        }
    }

    /// Overrides how often the generator scans streams and emits NACKs.
    pub fn with_interval(self, interval: Duration) -> Self {
        Self {
            interval: Some(interval),
            ..self
        }
    }
}
impl InterceptorBuilder for GeneratorBuilder {
    /// Builds a NACK [`Generator`] from the configured options.
    ///
    /// Options left unset fall back to defaults: `log2_size_minus_6 = 13 - 6`,
    /// `skip_last_n = 0` and an `interval` of 100 ms. A zero `interval` is treated as unset
    /// as well: `tokio::time::interval` panics on a zero period, and that panic would happen
    /// inside the background task spawned by `bind_rtcp_writer`, silently disabling NACK
    /// generation.
    ///
    /// The `_id` argument is not used.
    ///
    /// # Errors
    ///
    /// Never fails; the `Result` return type is required by [`InterceptorBuilder`].
    fn build(&self, _id: &str) -> Result<Arc<dyn Interceptor + Send + Sync>> {
        // NOTE(review): by its name this presumably means a 2^13-entry receive log per
        // stream (size = 2^(value + 6)); confirm against `GeneratorStream::new`.
        const DEFAULT_LOG2_SIZE_MINUS_6: u8 = 13 - 6;
        const DEFAULT_INTERVAL: Duration = Duration::from_millis(100);

        let log2_size_minus_6 = self
            .log2_size_minus_6
            .unwrap_or(DEFAULT_LOG2_SIZE_MINUS_6);
        let skip_last_n = self.skip_last_n.unwrap_or_default();
        // Reject a zero period here rather than letting the spawned task panic later.
        let interval = self
            .interval
            .filter(|interval| !interval.is_zero())
            .unwrap_or(DEFAULT_INTERVAL);

        // Single-slot close channel: shutdown is signalled by dropping the sender, which
        // makes the receiver awaited in `Generator::run` yield `None`.
        let (close_tx, close_rx) = mpsc::channel(1);

        Ok(Arc::new(Generator {
            internal: Arc::new(GeneratorInternal {
                log2_size_minus_6,
                skip_last_n,
                interval,
                streams: Mutex::new(HashMap::new()),
                close_rx: Mutex::new(Some(close_rx)),
            }),
            wg: Mutex::new(Some(WaitGroup::new())),
            close_tx: Mutex::new(Some(close_tx)),
        }))
    }
}
/// State shared (via `Arc`) between the [`Generator`] handle and the background NACK task
/// started by `bind_rtcp_writer`.
struct GeneratorInternal {
    /// Receive-log size parameter handed to every `GeneratorStream` created in
    /// `bind_remote_stream`.
    log2_size_minus_6: u8,
    /// Passed to `missing_seq_numbers` for every stream on each tick of `run`.
    skip_last_n: u16,
    /// Tick period of the NACK loop in `run`.
    interval: Duration,
    /// NACK-capable remote streams, keyed by SSRC (inserted in `bind_remote_stream`,
    /// removed in `unbind_remote_stream`).
    streams: Mutex<HashMap<u32, Arc<GeneratorStream>>>,
    /// Close receiver; `run` takes it out (leaving `None`) and exits once it yields.
    close_rx: Mutex<Option<mpsc::Receiver<()>>>,
}
/// Interceptor that periodically writes RTCP transport-layer NACKs for sequence numbers
/// reported missing by NACK-capable remote streams.
pub struct Generator {
    /// State shared with the background NACK task.
    internal: Arc<GeneratorInternal>,
    /// Tracks the background task so `close` can wait for it; taken (`None`) by `close`.
    pub(crate) wg: Mutex<Option<WaitGroup>>,
    /// Sender half of the close channel; `close` takes and drops it to stop the task.
    /// `None` means the generator is closed.
    pub(crate) close_tx: Mutex<Option<mpsc::Sender<()>>>,
}
impl Generator {
pub fn builder() -> GeneratorBuilder {
GeneratorBuilder::default()
}
async fn is_closed(&self) -> bool {
let close_tx = self.close_tx.lock().await;
close_tx.is_none()
}
async fn run(
rtcp_writer: Arc<dyn RTCPWriter + Send + Sync>,
internal: Arc<GeneratorInternal>,
) -> Result<()> {
let mut ticker = tokio::time::interval(internal.interval);
let mut close_rx = internal
.close_rx
.lock()
.await
.take()
.ok_or(Error::ErrInvalidCloseRx)?;
let sender_ssrc = rand::random::<u32>();
loop {
tokio::select! {
_ = ticker.tick() =>{
let nacks = {
let mut nacks = vec![];
let streams = internal.streams.lock().await;
for (ssrc, stream) in streams.iter() {
let missing = stream.missing_seq_numbers(internal.skip_last_n);
if missing.is_empty(){
continue;
}
nacks.push(TransportLayerNack{
sender_ssrc,
media_ssrc: *ssrc,
nacks: nack_pairs_from_sequence_numbers(&missing),
});
}
nacks
};
let a = Attributes::new();
for nack in nacks{
if let Err(err) = rtcp_writer.write(&[Box::new(nack)], &a).await{
log::warn!("failed sending nack: {err}");
}
}
}
_ = close_rx.recv() =>{
return Ok(());
}
}
}
}
}
#[async_trait]
impl Interceptor for Generator {
    /// Incoming RTCP is not inspected; the reader is returned as-is.
    async fn bind_rtcp_reader(
        &self,
        reader: Arc<dyn RTCPReader + Send + Sync>,
    ) -> Arc<dyn RTCPReader + Send + Sync> {
        reader
    }

    /// Spawns the background NACK loop writing through `writer`, unless the generator is
    /// already closed. The writer itself is returned unchanged.
    async fn bind_rtcp_writer(
        &self,
        writer: Arc<dyn RTCPWriter + Send + Sync>,
    ) -> Arc<dyn RTCPWriter + Send + Sync> {
        if !self.is_closed().await {
            // Register a worker so `close` can wait for the spawned task to finish.
            let worker = self.wg.lock().await.as_ref().map(|wg| wg.worker());
            let task_writer = Arc::clone(&writer);
            let shared = Arc::clone(&self.internal);

            tokio::spawn(async move {
                // Hold the worker for the whole lifetime of the loop.
                let _worker = worker;
                if let Err(err) = Generator::run(task_writer, shared).await {
                    log::warn!("bind_rtcp_writer NACK Generator::run got error: {err}");
                }
            });
        }
        writer
    }

    /// Outgoing RTP is passed through untouched.
    async fn bind_local_stream(
        &self,
        _info: &StreamInfo,
        writer: Arc<dyn RTPWriter + Send + Sync>,
    ) -> Arc<dyn RTPWriter + Send + Sync> {
        writer
    }

    /// Nothing is tracked for local streams, so there is nothing to undo.
    async fn unbind_local_stream(&self, _info: &StreamInfo) {}

    /// Wraps NACK-capable remote streams in a `GeneratorStream` and tracks it by SSRC;
    /// other streams are returned unchanged.
    async fn bind_remote_stream(
        &self,
        info: &StreamInfo,
        reader: Arc<dyn RTPReader + Send + Sync>,
    ) -> Arc<dyn RTPReader + Send + Sync> {
        if !stream_support_nack(info) {
            return reader;
        }

        let stream = Arc::new(GeneratorStream::new(self.internal.log2_size_minus_6, reader));
        self.internal
            .streams
            .lock()
            .await
            .insert(info.ssrc, Arc::clone(&stream));
        stream
    }

    /// Stops tracking the remote stream with this SSRC, if it was tracked.
    async fn unbind_remote_stream(&self, info: &StreamInfo) {
        self.internal.streams.lock().await.remove(&info.ssrc);
    }

    /// Signals the background loop to stop and waits for it to finish.
    async fn close(&self) -> Result<()> {
        // Dropping the sender makes `close_rx.recv()` in `run` return.
        drop(self.close_tx.lock().await.take());

        // The wait-group lock stays held while waiting for the worker(s).
        let mut wait_group_slot = self.wg.lock().await;
        if let Some(wg) = wait_group_slot.take() {
            wg.wait().await;
        }
        Ok(())
    }
}