interceptor/nack/generator/
mod.rs1mod generator_stream;
2#[cfg(test)]
3mod generator_test;
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Duration;
8
9use async_trait::async_trait;
10use generator_stream::GeneratorStream;
11use rtcp::transport_feedbacks::transport_layer_nack::{
12 nack_pairs_from_sequence_numbers, TransportLayerNack,
13};
14use tokio::sync::{mpsc, Mutex};
15use waitgroup::WaitGroup;
16
17use crate::error::{Error, Result};
18use crate::nack::stream_support_nack;
19use crate::stream_info::StreamInfo;
20use crate::{
21 Attributes, Interceptor, InterceptorBuilder, RTCPReader, RTCPWriter, RTPReader, RTPWriter,
22};
23
24#[derive(Default)]
26pub struct GeneratorBuilder {
27 log2_size_minus_6: Option<u8>,
28 skip_last_n: Option<u16>,
29 interval: Option<Duration>,
30}
31
32impl GeneratorBuilder {
33 pub fn with_log2_size_minus_6(mut self, log2_size_minus_6: u8) -> GeneratorBuilder {
36 self.log2_size_minus_6 = Some(log2_size_minus_6);
37 self
38 }
39
40 pub fn with_skip_last_n(mut self, skip_last_n: u16) -> GeneratorBuilder {
43 self.skip_last_n = Some(skip_last_n);
44 self
45 }
46
47 pub fn with_interval(mut self, interval: Duration) -> GeneratorBuilder {
49 self.interval = Some(interval);
50 self
51 }
52}
53
54impl InterceptorBuilder for GeneratorBuilder {
55 fn build(&self, _id: &str) -> Result<Arc<dyn Interceptor + Send + Sync>> {
56 let (close_tx, close_rx) = mpsc::channel(1);
57 Ok(Arc::new(Generator {
58 internal: Arc::new(GeneratorInternal {
59 log2_size_minus_6: self.log2_size_minus_6.unwrap_or(13 - 6), skip_last_n: self.skip_last_n.unwrap_or_default(),
61 interval: if let Some(interval) = self.interval {
62 interval
63 } else {
64 Duration::from_millis(100)
65 },
66
67 streams: Mutex::new(HashMap::new()),
68 close_rx: Mutex::new(Some(close_rx)),
69 }),
70
71 wg: Mutex::new(Some(WaitGroup::new())),
72 close_tx: Mutex::new(Some(close_tx)),
73 }))
74 }
75}
76
77struct GeneratorInternal {
78 log2_size_minus_6: u8,
79 skip_last_n: u16,
80 interval: Duration,
81
82 streams: Mutex<HashMap<u32, Arc<GeneratorStream>>>,
83 close_rx: Mutex<Option<mpsc::Receiver<()>>>,
84}
85
86pub struct Generator {
88 internal: Arc<GeneratorInternal>,
89
90 pub(crate) wg: Mutex<Option<WaitGroup>>,
91 pub(crate) close_tx: Mutex<Option<mpsc::Sender<()>>>,
92}
93
94impl Generator {
95 pub fn builder() -> GeneratorBuilder {
97 GeneratorBuilder::default()
98 }
99
100 async fn is_closed(&self) -> bool {
101 let close_tx = self.close_tx.lock().await;
102 close_tx.is_none()
103 }
104
105 async fn run(
106 rtcp_writer: Arc<dyn RTCPWriter + Send + Sync>,
107 internal: Arc<GeneratorInternal>,
108 ) -> Result<()> {
109 let mut ticker = tokio::time::interval(internal.interval);
110 let mut close_rx = internal
111 .close_rx
112 .lock()
113 .await
114 .take()
115 .ok_or(Error::ErrInvalidCloseRx)?;
116
117 let sender_ssrc = rand::random::<u32>();
118 loop {
119 tokio::select! {
120 _ = ticker.tick() =>{
121 let nacks = {
122 let mut nacks = vec![];
123 let streams = internal.streams.lock().await;
124 for (ssrc, stream) in streams.iter() {
125 let missing = stream.missing_seq_numbers(internal.skip_last_n);
126 if missing.is_empty(){
127 continue;
128 }
129
130 nacks.push(TransportLayerNack{
131 sender_ssrc,
132 media_ssrc: *ssrc,
133 nacks: nack_pairs_from_sequence_numbers(&missing),
134 });
135 }
136 nacks
137 };
138
139 let a = Attributes::new();
140 for nack in nacks{
141 if let Err(err) = rtcp_writer.write(&[Box::new(nack)], &a).await{
142 log::warn!("failed sending nack: {err}");
143 }
144 }
145 }
146 _ = close_rx.recv() =>{
147 return Ok(());
148 }
149 }
150 }
151 }
152}
153
154#[async_trait]
155impl Interceptor for Generator {
156 async fn bind_rtcp_reader(
159 &self,
160 reader: Arc<dyn RTCPReader + Send + Sync>,
161 ) -> Arc<dyn RTCPReader + Send + Sync> {
162 reader
163 }
164
165 async fn bind_rtcp_writer(
168 &self,
169 writer: Arc<dyn RTCPWriter + Send + Sync>,
170 ) -> Arc<dyn RTCPWriter + Send + Sync> {
171 if self.is_closed().await {
172 return writer;
173 }
174
175 let mut w = {
176 let wait_group = self.wg.lock().await;
177 wait_group.as_ref().map(|wg| wg.worker())
178 };
179 let writer2 = Arc::clone(&writer);
180 let internal = Arc::clone(&self.internal);
181 tokio::spawn(async move {
182 let _d = w.take();
183 if let Err(err) = Generator::run(writer2, internal).await {
184 log::warn!("bind_rtcp_writer NACK Generator::run got error: {err}");
185 }
186 });
187
188 writer
189 }
190
191 async fn bind_local_stream(
194 &self,
195 _info: &StreamInfo,
196 writer: Arc<dyn RTPWriter + Send + Sync>,
197 ) -> Arc<dyn RTPWriter + Send + Sync> {
198 writer
199 }
200
201 async fn unbind_local_stream(&self, _info: &StreamInfo) {}
203
204 async fn bind_remote_stream(
207 &self,
208 info: &StreamInfo,
209 reader: Arc<dyn RTPReader + Send + Sync>,
210 ) -> Arc<dyn RTPReader + Send + Sync> {
211 if !stream_support_nack(info) {
212 return reader;
213 }
214
215 let stream = Arc::new(GeneratorStream::new(
216 self.internal.log2_size_minus_6,
217 reader,
218 ));
219 {
220 let mut streams = self.internal.streams.lock().await;
221 streams.insert(info.ssrc, Arc::clone(&stream));
222 }
223
224 stream
225 }
226
227 async fn unbind_remote_stream(&self, info: &StreamInfo) {
229 let mut receive_logs = self.internal.streams.lock().await;
230 receive_logs.remove(&info.ssrc);
231 }
232
233 async fn close(&self) -> Result<()> {
235 {
236 let mut close_tx = self.close_tx.lock().await;
237 close_tx.take();
238 }
239
240 {
241 let mut wait_group = self.wg.lock().await;
242 if let Some(wg) = wait_group.take() {
243 wg.wait().await;
244 }
245 }
246
247 Ok(())
248 }
249}