interceptor/nack/generator/
mod.rs

1mod generator_stream;
2#[cfg(test)]
3mod generator_test;
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Duration;
8
9use async_trait::async_trait;
10use generator_stream::GeneratorStream;
11use rtcp::transport_feedbacks::transport_layer_nack::{
12    nack_pairs_from_sequence_numbers, TransportLayerNack,
13};
14use tokio::sync::{mpsc, Mutex};
15use waitgroup::WaitGroup;
16
17use crate::error::{Error, Result};
18use crate::nack::stream_support_nack;
19use crate::stream_info::StreamInfo;
20use crate::{
21    Attributes, Interceptor, InterceptorBuilder, RTCPReader, RTCPWriter, RTPReader, RTPWriter,
22};
23
/// GeneratorBuilder can be used to configure Generator Interceptor.
///
/// Every option is optional; options left unset fall back to the defaults
/// applied when the interceptor is built.
#[derive(Debug, Default)]
pub struct GeneratorBuilder {
    log2_size_minus_6: Option<u8>,
    skip_last_n: Option<u16>,
    interval: Option<Duration>,
}

impl GeneratorBuilder {
    /// with_log2_size_minus_6 sets the per-stream receive-log size, expressed
    /// as `log2(size) - 6`.
    ///
    /// The size must be one of 64, 128, 256, 512, 1024, 2048, 4096, 8192,
    /// 16384 or 32768, so valid values are `0` (64) through `9` (32768).
    /// For example `7` selects 8192, which is also the default when unset.
    ///
    /// NOTE(review): the value is not range-checked here; callers must stay
    /// within `0..=9`.
    #[must_use]
    pub fn with_log2_size_minus_6(mut self, log2_size_minus_6: u8) -> GeneratorBuilder {
        self.log2_size_minus_6 = Some(log2_size_minus_6);
        self
    }

    /// with_skip_last_n sets the number of packets (n-1 packets before the last
    /// received packet) to ignore when generating NACK requests. Defaults to 0.
    #[must_use]
    pub fn with_skip_last_n(mut self, skip_last_n: u16) -> GeneratorBuilder {
        self.skip_last_n = Some(skip_last_n);
        self
    }

    /// with_interval sets the NACK send interval for the interceptor.
    /// Defaults to 100 ms when unset.
    ///
    /// A zero interval is not accepted: `tokio::time::interval` panics on a
    /// zero period, which would abort the background NACK task. Passing
    /// `Duration::ZERO` therefore clears any previously configured interval,
    /// and the default is used.
    #[must_use]
    pub fn with_interval(mut self, interval: Duration) -> GeneratorBuilder {
        self.interval = if interval.is_zero() {
            None
        } else {
            Some(interval)
        };
        self
    }
}
53
54impl InterceptorBuilder for GeneratorBuilder {
55    fn build(&self, _id: &str) -> Result<Arc<dyn Interceptor + Send + Sync>> {
56        let (close_tx, close_rx) = mpsc::channel(1);
57        Ok(Arc::new(Generator {
58            internal: Arc::new(GeneratorInternal {
59                log2_size_minus_6: self.log2_size_minus_6.unwrap_or(13 - 6), // 8192 = 1 << 13
60                skip_last_n: self.skip_last_n.unwrap_or_default(),
61                interval: if let Some(interval) = self.interval {
62                    interval
63                } else {
64                    Duration::from_millis(100)
65                },
66
67                streams: Mutex::new(HashMap::new()),
68                close_rx: Mutex::new(Some(close_rx)),
69            }),
70
71            wg: Mutex::new(Some(WaitGroup::new())),
72            close_tx: Mutex::new(Some(close_tx)),
73        }))
74    }
75}
76
77struct GeneratorInternal {
78    log2_size_minus_6: u8,
79    skip_last_n: u16,
80    interval: Duration,
81
82    streams: Mutex<HashMap<u32, Arc<GeneratorStream>>>,
83    close_rx: Mutex<Option<mpsc::Receiver<()>>>,
84}
85
86/// Generator interceptor generates nack feedback messages.
87pub struct Generator {
88    internal: Arc<GeneratorInternal>,
89
90    pub(crate) wg: Mutex<Option<WaitGroup>>,
91    pub(crate) close_tx: Mutex<Option<mpsc::Sender<()>>>,
92}
93
94impl Generator {
95    /// builder returns a new GeneratorBuilder.
96    pub fn builder() -> GeneratorBuilder {
97        GeneratorBuilder::default()
98    }
99
100    async fn is_closed(&self) -> bool {
101        let close_tx = self.close_tx.lock().await;
102        close_tx.is_none()
103    }
104
105    async fn run(
106        rtcp_writer: Arc<dyn RTCPWriter + Send + Sync>,
107        internal: Arc<GeneratorInternal>,
108    ) -> Result<()> {
109        let mut ticker = tokio::time::interval(internal.interval);
110        let mut close_rx = internal
111            .close_rx
112            .lock()
113            .await
114            .take()
115            .ok_or(Error::ErrInvalidCloseRx)?;
116
117        let sender_ssrc = rand::random::<u32>();
118        loop {
119            tokio::select! {
120                _ = ticker.tick() =>{
121                    let nacks = {
122                        let mut nacks = vec![];
123                        let streams = internal.streams.lock().await;
124                        for (ssrc, stream) in streams.iter() {
125                            let missing = stream.missing_seq_numbers(internal.skip_last_n);
126                            if missing.is_empty(){
127                                continue;
128                            }
129
130                            nacks.push(TransportLayerNack{
131                                sender_ssrc,
132                                media_ssrc: *ssrc,
133                                nacks:  nack_pairs_from_sequence_numbers(&missing),
134                            });
135                        }
136                        nacks
137                    };
138
139                    let a = Attributes::new();
140                    for nack in nacks{
141                        if let Err(err) = rtcp_writer.write(&[Box::new(nack)], &a).await{
142                            log::warn!("failed sending nack: {err}");
143                        }
144                    }
145                }
146                _ = close_rx.recv() =>{
147                    return Ok(());
148                }
149            }
150        }
151    }
152}
153
154#[async_trait]
155impl Interceptor for Generator {
156    /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might
157    /// change in the future. The returned method will be called once per packet batch.
158    async fn bind_rtcp_reader(
159        &self,
160        reader: Arc<dyn RTCPReader + Send + Sync>,
161    ) -> Arc<dyn RTCPReader + Send + Sync> {
162        reader
163    }
164
165    /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method
166    /// will be called once per packet batch.
167    async fn bind_rtcp_writer(
168        &self,
169        writer: Arc<dyn RTCPWriter + Send + Sync>,
170    ) -> Arc<dyn RTCPWriter + Send + Sync> {
171        if self.is_closed().await {
172            return writer;
173        }
174
175        let mut w = {
176            let wait_group = self.wg.lock().await;
177            wait_group.as_ref().map(|wg| wg.worker())
178        };
179        let writer2 = Arc::clone(&writer);
180        let internal = Arc::clone(&self.internal);
181        tokio::spawn(async move {
182            let _d = w.take();
183            if let Err(err) = Generator::run(writer2, internal).await {
184                log::warn!("bind_rtcp_writer NACK Generator::run got error: {err}");
185            }
186        });
187
188        writer
189    }
190
191    /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method
192    /// will be called once per rtp packet.
193    async fn bind_local_stream(
194        &self,
195        _info: &StreamInfo,
196        writer: Arc<dyn RTPWriter + Send + Sync>,
197    ) -> Arc<dyn RTPWriter + Send + Sync> {
198        writer
199    }
200
201    /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track.
202    async fn unbind_local_stream(&self, _info: &StreamInfo) {}
203
204    /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method
205    /// will be called once per rtp packet.
206    async fn bind_remote_stream(
207        &self,
208        info: &StreamInfo,
209        reader: Arc<dyn RTPReader + Send + Sync>,
210    ) -> Arc<dyn RTPReader + Send + Sync> {
211        if !stream_support_nack(info) {
212            return reader;
213        }
214
215        let stream = Arc::new(GeneratorStream::new(
216            self.internal.log2_size_minus_6,
217            reader,
218        ));
219        {
220            let mut streams = self.internal.streams.lock().await;
221            streams.insert(info.ssrc, Arc::clone(&stream));
222        }
223
224        stream
225    }
226
227    /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track.
228    async fn unbind_remote_stream(&self, info: &StreamInfo) {
229        let mut receive_logs = self.internal.streams.lock().await;
230        receive_logs.remove(&info.ssrc);
231    }
232
233    /// close closes the Interceptor, cleaning up any data if necessary.
234    async fn close(&self) -> Result<()> {
235        {
236            let mut close_tx = self.close_tx.lock().await;
237            close_tx.take();
238        }
239
240        {
241            let mut wait_group = self.wg.lock().await;
242            if let Some(wg) = wait_group.take() {
243                wg.wait().await;
244            }
245        }
246
247        Ok(())
248    }
249}