// interceptor/twcc/receiver/mod.rs

mod receiver_stream;
#[cfg(test)]
mod receiver_test;

use std::time::Duration;

use receiver_stream::ReceiverStream;
use rtp::extension::transport_cc_extension::TransportCcExtension;
use tokio::sync::{mpsc, Mutex};
use tokio::time::MissedTickBehavior;
use util::Unmarshal;
use waitgroup::WaitGroup;

use crate::twcc::sender::TRANSPORT_CC_URI;
use crate::twcc::Recorder;
use crate::*;

/// ReceiverBuilder is an InterceptorBuilder for a TWCC [`Receiver`].
///
/// The only tunable is how often transport-wide congestion control feedback is
/// generated; when no interval is set, `build` falls back to 100ms.
#[derive(Default)]
pub struct ReceiverBuilder {
    /// Feedback period; `None` means "use the default".
    interval: Option<Duration>,
}

impl ReceiverBuilder {
    /// with_interval sets the interval at which the interceptor builds and sends
    /// TWCC feedback packets. Calling it again overrides the previous value.
    #[must_use]
    pub fn with_interval(mut self, interval: Duration) -> ReceiverBuilder {
        self.interval = Some(interval);
        self
    }
}

32impl InterceptorBuilder for ReceiverBuilder {
33    fn build(&self, _id: &str) -> Result<Arc<dyn Interceptor + Send + Sync>> {
34        let (close_tx, close_rx) = mpsc::channel(1);
35        let (packet_chan_tx, packet_chan_rx) = mpsc::channel(1);
36        Ok(Arc::new(Receiver {
37            internal: Arc::new(ReceiverInternal {
38                interval: if let Some(interval) = &self.interval {
39                    *interval
40                } else {
41                    Duration::from_millis(100)
42                },
43                recorder: Mutex::new(Recorder::default()),
44                packet_chan_rx: Mutex::new(Some(packet_chan_rx)),
45                streams: Mutex::new(HashMap::new()),
46                close_rx: Mutex::new(Some(close_rx)),
47            }),
48            start_time: tokio::time::Instant::now(),
49            packet_chan_tx,
50            wg: Mutex::new(Some(WaitGroup::new())),
51            close_tx: Mutex::new(Some(close_tx)),
52        }))
53    }
54}
55
/// A received RTP packet, as forwarded to the background `Receiver::run` loop.
///
/// Values of this type travel over the `packet_chan_tx` / `packet_chan_rx`
/// channel; `run` passes `ssrc`, `sequence_number` and `arrival_time` to
/// `Recorder::record`.
struct Packet {
    // RTP header of the packet; not read anywhere in this file.
    hdr: rtp::header::Header,
    // Sequence number handed to `Recorder::record`. Presumably the transport-wide
    // sequence number from the transport-cc header extension — TODO confirm in
    // receiver_stream.
    sequence_number: u16,
    // Arrival timestamp handed to `Recorder::record`. Presumably relative to
    // `Receiver::start_time`; units are not visible here — TODO confirm.
    arrival_time: i64,
    // SSRC handed to `Recorder::record`.
    ssrc: u32,
}

/// State shared between a `Receiver` and the task spawned by `bind_rtcp_writer`.
struct ReceiverInternal {
    // Period of the ticker in `Receiver::run`; every tick builds and writes feedback.
    interval: Duration,
    // Records packet arrivals and builds feedback packets; replaced by
    // `bind_rtcp_writer` before the run task is spawned.
    recorder: Mutex<Recorder>,
    // Taken by `Receiver::run`; `ErrInvalidPacketRx` is returned if it is already gone.
    packet_chan_rx: Mutex<Option<mpsc::Receiver<Packet>>>,
    // Streams added by `bind_remote_stream` and removed by `unbind_remote_stream`,
    // keyed by SSRC.
    streams: Mutex<HashMap<u32, Arc<ReceiverStream>>>,
    // Taken by `Receiver::run`; `ErrInvalidCloseRx` is returned if it is already gone.
    // Resolves once `close` drops the matching sender.
    close_rx: Mutex<Option<mpsc::Receiver<()>>>,
}

/// Receiver sends transport-wide congestion control reports as specified in:
/// <https://datatracker.ietf.org/doc/html/draft-holmer-rmcat-transport-wide-cc-extensions-01>
pub struct Receiver {
    // Shared (via Arc clone) with the task spawned in `bind_rtcp_writer`.
    internal: Arc<ReceiverInternal>,

    // we use tokio's Instant because it makes testing easier via `tokio::time::advance`.
    // Captured in `build` and passed to every `ReceiverStream`.
    start_time: tokio::time::Instant,
    // Cloned into each `ReceiverStream`; the matching receiver is drained by `run`.
    packet_chan_tx: mpsc::Sender<Packet>,

    // Each spawned run task holds a worker from this; `close` takes it and waits on it.
    wg: Mutex<Option<WaitGroup>>,
    // `Some` while open. `close` takes (drops) it, which ends `run`; afterwards
    // `is_closed` returns true.
    close_tx: Mutex<Option<mpsc::Sender<()>>>,
}

84impl Receiver {
85    /// builder returns a new ReceiverBuilder.
86    pub fn builder() -> ReceiverBuilder {
87        ReceiverBuilder::default()
88    }
89
90    async fn is_closed(&self) -> bool {
91        let close_tx = self.close_tx.lock().await;
92        close_tx.is_none()
93    }
94
95    async fn run(
96        rtcp_writer: Arc<dyn RTCPWriter + Send + Sync>,
97        internal: Arc<ReceiverInternal>,
98    ) -> Result<()> {
99        let mut close_rx = {
100            let mut close_rx = internal.close_rx.lock().await;
101            if let Some(close_rx) = close_rx.take() {
102                close_rx
103            } else {
104                return Err(Error::ErrInvalidCloseRx);
105            }
106        };
107        let mut packet_chan_rx = {
108            let mut packet_chan_rx = internal.packet_chan_rx.lock().await;
109            if let Some(packet_chan_rx) = packet_chan_rx.take() {
110                packet_chan_rx
111            } else {
112                return Err(Error::ErrInvalidPacketRx);
113            }
114        };
115
116        let a = Attributes::new();
117        let mut ticker = tokio::time::interval(internal.interval);
118        ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
119        loop {
120            tokio::select! {
121                _ = close_rx.recv() =>{
122                    return Ok(());
123                }
124                p = packet_chan_rx.recv() => {
125                    if let Some(p) = p {
126                        let mut recorder = internal.recorder.lock().await;
127                        recorder.record(p.ssrc, p.sequence_number, p.arrival_time);
128                    }
129                }
130                _ = ticker.tick() =>{
131                    // build and send twcc
132                    let pkts = {
133                        let mut recorder = internal.recorder.lock().await;
134                        recorder.build_feedback_packet()
135                    };
136
137                    if pkts.is_empty() {
138                        continue;
139                    }
140
141                    if let Err(err) = rtcp_writer.write(&pkts, &a).await{
142                        log::error!("rtcp_writer.write got err: {}", err);
143                    }
144                }
145            }
146        }
147    }
148}
149
150#[async_trait]
151impl Interceptor for Receiver {
152    /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might
153    /// change in the future. The returned method will be called once per packet batch.
154    async fn bind_rtcp_reader(
155        &self,
156        reader: Arc<dyn RTCPReader + Send + Sync>,
157    ) -> Arc<dyn RTCPReader + Send + Sync> {
158        reader
159    }
160
161    /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method
162    /// will be called once per packet batch.
163    async fn bind_rtcp_writer(
164        &self,
165        writer: Arc<dyn RTCPWriter + Send + Sync>,
166    ) -> Arc<dyn RTCPWriter + Send + Sync> {
167        if self.is_closed().await {
168            return writer;
169        }
170
171        {
172            let mut recorder = self.internal.recorder.lock().await;
173            *recorder = Recorder::new(rand::random::<u32>());
174        }
175
176        let mut w = {
177            let wait_group = self.wg.lock().await;
178            wait_group.as_ref().map(|wg| wg.worker())
179        };
180        let writer2 = Arc::clone(&writer);
181        let internal = Arc::clone(&self.internal);
182        tokio::spawn(async move {
183            let _d = w.take();
184            if let Err(err) = Receiver::run(writer2, internal).await {
185                log::warn!("bind_rtcp_writer TWCC Sender::run got error: {}", err);
186            }
187        });
188
189        writer
190    }
191
192    /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method
193    /// will be called once per rtp packet.
194    async fn bind_local_stream(
195        &self,
196        _info: &StreamInfo,
197        writer: Arc<dyn RTPWriter + Send + Sync>,
198    ) -> Arc<dyn RTPWriter + Send + Sync> {
199        writer
200    }
201
202    /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track.
203    async fn unbind_local_stream(&self, _info: &StreamInfo) {}
204
205    /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method
206    /// will be called once per rtp packet.
207    async fn bind_remote_stream(
208        &self,
209        info: &StreamInfo,
210        reader: Arc<dyn RTPReader + Send + Sync>,
211    ) -> Arc<dyn RTPReader + Send + Sync> {
212        let mut hdr_ext_id = 0u8;
213        for e in &info.rtp_header_extensions {
214            if e.uri == TRANSPORT_CC_URI {
215                hdr_ext_id = e.id as u8;
216                break;
217            }
218        }
219        if hdr_ext_id == 0 {
220            // Don't try to read header extension if ID is 0, because 0 is an invalid extension ID
221            return reader;
222        }
223
224        let stream = Arc::new(ReceiverStream::new(
225            reader,
226            hdr_ext_id,
227            info.ssrc,
228            self.packet_chan_tx.clone(),
229            self.start_time,
230        ));
231
232        {
233            let mut streams = self.internal.streams.lock().await;
234            streams.insert(info.ssrc, Arc::clone(&stream));
235        }
236
237        stream
238    }
239
240    /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track.
241    async fn unbind_remote_stream(&self, info: &StreamInfo) {
242        let mut streams = self.internal.streams.lock().await;
243        streams.remove(&info.ssrc);
244    }
245
246    /// close closes the Interceptor, cleaning up any data if necessary.
247    async fn close(&self) -> Result<()> {
248        {
249            let mut close_tx = self.close_tx.lock().await;
250            close_tx.take();
251        }
252
253        {
254            let mut wait_group = self.wg.lock().await;
255            if let Some(wg) = wait_group.take() {
256                wg.wait().await;
257            }
258        }
259
260        Ok(())
261    }
262}