1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
use std::sync::Arc;

use async_trait::async_trait;
use tokio::sync::{mpsc, Mutex};
use util::Marshal;

use crate::error::{Error, Result};
use crate::stream_info::StreamInfo;
use crate::{Attributes, Interceptor, RTCPReader, RTCPWriter, RTPReader, RTPWriter};

/// A batch of RTCP packets, in the shape used by the RTCP reader/writer traits.
type RTCPPackets = Vec<Box<dyn rtcp::packet::Packet + Sync + Send>>;

/// Test double that sits underneath an [`Interceptor`] so its behaviour can be observed.
///
/// Packets written through the interceptor land in the `*_out_modified` channels, packets
/// injected by the test are served to the interceptor's readers from the `*_in` channels, and
/// whatever the interceptor's readers hand back is collected in the `*_in_modified` channels.
pub struct MockStream {
    /// The interceptor under test.
    interceptor: Arc<dyn Interceptor + Sync + Send>,

    /// Interceptor-wrapped RTCP writer; filled in while the stream is being constructed.
    rtcp_writer: Mutex<Option<Arc<dyn RTCPWriter + Sync + Send>>>,
    /// Interceptor-wrapped RTP writer; filled in while the stream is being constructed.
    rtp_writer: Mutex<Option<Arc<dyn RTPWriter + Sync + Send>>>,

    // Outbound side: what reaches the stream after passing through the interceptor's writers,
    // plus the queues the stream's own readers pull inbound packets from.
    rtcp_out_modified_tx: mpsc::Sender<RTCPPackets>,
    rtp_out_modified_tx: mpsc::Sender<rtp::packet::Packet>,
    rtcp_in_rx: Mutex<mpsc::Receiver<RTCPPackets>>,
    rtp_in_rx: Mutex<mpsc::Receiver<rtp::packet::Packet>>,

    // Test-facing ends of the channels above. The inbound senders are wrapped in `Option` so
    // they can be dropped on close, which ends the inbound streams.
    rtcp_out_modified_rx: Mutex<mpsc::Receiver<RTCPPackets>>,
    rtp_out_modified_rx: Mutex<mpsc::Receiver<rtp::packet::Packet>>,
    rtcp_in_tx: Mutex<Option<mpsc::Sender<RTCPPackets>>>,
    rtp_in_tx: Mutex<Option<mpsc::Sender<rtp::packet::Packet>>>,

    // Results produced by the interceptor-wrapped readers (packets or the terminal error).
    rtcp_in_modified_rx: Mutex<mpsc::Receiver<Result<RTCPPackets>>>,
    rtp_in_modified_rx: Mutex<mpsc::Receiver<Result<rtp::packet::Packet>>>,
}

impl MockStream {
    /// new creates a new MockStream
    pub async fn new(
        info: &StreamInfo,
        interceptor: Arc<dyn Interceptor + Send + Sync>,
    ) -> Arc<Self> {
        let (rtcp_in_tx, rtcp_in_rx) = mpsc::channel(1000);
        let (rtp_in_tx, rtp_in_rx) = mpsc::channel(1000);
        let (rtcp_out_modified_tx, rtcp_out_modified_rx) = mpsc::channel(1000);
        let (rtp_out_modified_tx, rtp_out_modified_rx) = mpsc::channel(1000);
        let (rtcp_in_modified_tx, rtcp_in_modified_rx) = mpsc::channel(1000);
        let (rtp_in_modified_tx, rtp_in_modified_rx) = mpsc::channel(1000);

        let stream = Arc::new(MockStream {
            interceptor: Arc::clone(&interceptor),

            rtcp_writer: Mutex::new(None),
            rtp_writer: Mutex::new(None),

            rtcp_in_tx: Mutex::new(Some(rtcp_in_tx)),
            rtp_in_tx: Mutex::new(Some(rtp_in_tx)),
            rtcp_in_rx: Mutex::new(rtcp_in_rx),
            rtp_in_rx: Mutex::new(rtp_in_rx),

            rtcp_out_modified_tx,
            rtp_out_modified_tx,
            rtcp_out_modified_rx: Mutex::new(rtcp_out_modified_rx),
            rtp_out_modified_rx: Mutex::new(rtp_out_modified_rx),

            rtcp_in_modified_rx: Mutex::new(rtcp_in_modified_rx),
            rtp_in_modified_rx: Mutex::new(rtp_in_modified_rx),
        });

        let rtcp_writer = interceptor
            .bind_rtcp_writer(Arc::clone(&stream) as Arc<dyn RTCPWriter + Send + Sync>)
            .await;
        {
            let mut rw = stream.rtcp_writer.lock().await;
            *rw = Some(rtcp_writer);
        }
        let rtp_writer = interceptor
            .bind_local_stream(
                info,
                Arc::clone(&stream) as Arc<dyn RTPWriter + Send + Sync>,
            )
            .await;
        {
            let mut rw = stream.rtp_writer.lock().await;
            *rw = Some(rtp_writer);
        }

        let rtcp_reader = interceptor
            .bind_rtcp_reader(Arc::clone(&stream) as Arc<dyn RTCPReader + Send + Sync>)
            .await;
        tokio::spawn(async move {
            let mut buf = vec![0u8; 1500];
            let a = Attributes::new();
            loop {
                let pkts = match rtcp_reader.read(&mut buf, &a).await {
                    Ok((n, _)) => n,
                    Err(err) => {
                        let _ = rtcp_in_modified_tx.send(Err(err)).await;
                        break;
                    }
                };

                let _ = rtcp_in_modified_tx.send(Ok(pkts)).await;
            }
        });

        let rtp_reader = interceptor
            .bind_remote_stream(
                info,
                Arc::clone(&stream) as Arc<dyn RTPReader + Send + Sync>,
            )
            .await;
        tokio::spawn(async move {
            let mut buf = vec![0u8; 1500];
            let a = Attributes::new();
            loop {
                let pkt = match rtp_reader.read(&mut buf, &a).await {
                    Ok((pkt, _)) => pkt,
                    Err(err) => {
                        let _ = rtp_in_modified_tx.send(Err(err)).await;
                        break;
                    }
                };

                let _ = rtp_in_modified_tx.send(Ok(pkt)).await;
            }
        });

        stream
    }

    /// write_rtcp writes a batch of rtcp packet to the stream, using the interceptor
    pub async fn write_rtcp(
        &self,
        pkt: &[Box<dyn rtcp::packet::Packet + Send + Sync>],
    ) -> Result<usize> {
        let a = Attributes::new();
        let rtcp_writer = self.rtcp_writer.lock().await;
        if let Some(writer) = &*rtcp_writer {
            writer.write(pkt, &a).await
        } else {
            Err(Error::Other("invalid rtcp_writer".to_owned()))
        }
    }

    /// write_rtp writes an rtp packet to the stream, using the interceptor
    pub async fn write_rtp(&self, pkt: &rtp::packet::Packet) -> Result<usize> {
        let a = Attributes::new();
        let rtp_writer = self.rtp_writer.lock().await;
        if let Some(writer) = &*rtp_writer {
            writer.write(pkt, &a).await
        } else {
            Err(Error::Other("invalid rtp_writer".to_owned()))
        }
    }

    /// receive_rtcp schedules a new rtcp batch, so it can be read be the stream
    pub async fn receive_rtcp(&self, pkts: Vec<Box<dyn rtcp::packet::Packet + Send + Sync>>) {
        let rtcp_in_tx = self.rtcp_in_tx.lock().await;
        if let Some(tx) = &*rtcp_in_tx {
            let _ = tx.send(pkts).await;
        }
    }

    /// receive_rtp schedules a rtp packet, so it can be read be the stream
    pub async fn receive_rtp(&self, pkt: rtp::packet::Packet) {
        let rtp_in_tx = self.rtp_in_tx.lock().await;
        if let Some(tx) = &*rtp_in_tx {
            let _ = tx.send(pkt).await;
        }
    }

    /// written_rtcp returns a channel containing the rtcp batches written, modified by the interceptor
    pub async fn written_rtcp(&self) -> Option<Vec<Box<dyn rtcp::packet::Packet + Send + Sync>>> {
        let mut rtcp_out_modified_rx = self.rtcp_out_modified_rx.lock().await;
        rtcp_out_modified_rx.recv().await
    }

    /// Returns the last rtcp packet bacth that was written, modified by the interceptor.
    ///
    /// NB: This method discards all other previously recoreded packet batches.
    pub async fn last_written_rtcp(
        &self,
    ) -> Option<Vec<Box<dyn rtcp::packet::Packet + Send + Sync>>> {
        let mut last = None;
        let mut rtcp_out_modified_rx = self.rtcp_out_modified_rx.lock().await;

        while let Ok(v) = rtcp_out_modified_rx.try_recv() {
            last = Some(v);
        }

        last
    }

    /// written_rtp returns a channel containing rtp packets written, modified by the interceptor
    pub async fn written_rtp(&self) -> Option<rtp::packet::Packet> {
        let mut rtp_out_modified_rx = self.rtp_out_modified_rx.lock().await;
        rtp_out_modified_rx.recv().await
    }

    /// read_rtcp returns a channel containing the rtcp batched read, modified by the interceptor
    pub async fn read_rtcp(
        &self,
    ) -> Option<Result<Vec<Box<dyn rtcp::packet::Packet + Send + Sync>>>> {
        let mut rtcp_in_modified_rx = self.rtcp_in_modified_rx.lock().await;
        rtcp_in_modified_rx.recv().await
    }

    /// read_rtp returns a channel containing the rtp packets read, modified by the interceptor
    pub async fn read_rtp(&self) -> Option<Result<rtp::packet::Packet>> {
        let mut rtp_in_modified_rx = self.rtp_in_modified_rx.lock().await;
        rtp_in_modified_rx.recv().await
    }

    /// close closes the stream and the underlying interceptor
    pub async fn close(&self) -> Result<()> {
        {
            let mut rtcp_in_tx = self.rtcp_in_tx.lock().await;
            rtcp_in_tx.take();
        }
        {
            let mut rtp_in_tx = self.rtp_in_tx.lock().await;
            rtp_in_tx.take();
        }
        self.interceptor.close().await
    }
}

#[async_trait]
impl RTCPWriter for MockStream {
    /// Records a copy of `packets` so the test can fetch it via `written_rtcp`.
    ///
    /// Recording is best-effort: if the batch cannot be queued it is discarded. The result is
    /// always `Ok(0)`.
    async fn write(
        &self,
        packets: &[Box<dyn rtcp::packet::Packet + Send + Sync>],
        _attributes: &Attributes,
    ) -> Result<usize> {
        let recorded = packets.to_vec();
        let _ = self.rtcp_out_modified_tx.send(recorded).await;
        Ok(0)
    }
}

#[async_trait]
impl RTCPReader for MockStream {
    async fn read(
        &self,
        buf: &mut [u8],
        a: &Attributes,
    ) -> Result<(Vec<Box<dyn rtcp::packet::Packet + Send + Sync>>, Attributes)> {
        let pkts = {
            let mut rtcp_in = self.rtcp_in_rx.lock().await;
            rtcp_in.recv().await.ok_or(Error::ErrIoEOF)?
        };

        let marshaled = rtcp::packet::marshal(&pkts)?;
        let n = marshaled.len();
        if n > buf.len() {
            return Err(Error::ErrShortBuffer);
        }

        buf[..n].copy_from_slice(&marshaled);
        Ok((pkts, a.clone()))
    }
}

#[async_trait]
impl RTPWriter for MockStream {
    /// Records a copy of `packet` so the test can fetch it via `written_rtp`.
    ///
    /// Recording is best-effort: if the packet cannot be queued it is discarded. The result is
    /// always `Ok(0)`.
    async fn write(&self, packet: &rtp::packet::Packet, _a: &Attributes) -> Result<usize> {
        let recorded = packet.clone();
        let _ = self.rtp_out_modified_tx.send(recorded).await;
        Ok(0)
    }
}

#[async_trait]
impl RTPReader for MockStream {
    async fn read(
        &self,
        buf: &mut [u8],
        a: &Attributes,
    ) -> Result<(rtp::packet::Packet, Attributes)> {
        let pkt = {
            let mut rtp_in = self.rtp_in_rx.lock().await;
            rtp_in.recv().await.ok_or(Error::ErrIoEOF)?
        };

        let marshaled = pkt.marshal()?;
        let n = marshaled.len();
        if n > buf.len() {
            return Err(Error::ErrShortBuffer);
        }

        buf[..n].copy_from_slice(&marshaled);
        Ok((pkt, a.clone()))
    }
}

#[cfg(test)]
mod test {
    use tokio::time::Duration;

    use super::*;
    use crate::noop::NoOp;
    use crate::test::timeout_or_fail;

    /// Pushes one item of each kind through a `MockStream` that wraps the `NoOp` interceptor.
    /// Checks that exactly one item arrives on the matching channel and that reads succeed.
    #[tokio::test]
    async fn test_mock_stream() -> Result<()> {
        use rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication;

        // How long to wait for something that should (or should not) arrive.
        let wait = Duration::from_millis(10);
        let s = MockStream::new(&StreamInfo::default(), Arc::new(NoOp)).await;

        s.write_rtcp(&[Box::<PictureLossIndication>::default()])
            .await?;
        timeout_or_fail(wait, s.written_rtcp()).await;
        let result = tokio::time::timeout(wait, s.written_rtcp()).await;
        assert!(
            result.is_err(),
            "single rtcp packet written, but multiple found"
        );

        s.write_rtp(&rtp::packet::Packet::default()).await?;
        timeout_or_fail(wait, s.written_rtp()).await;
        let result = tokio::time::timeout(wait, s.written_rtp()).await;
        assert!(
            result.is_err(),
            "single rtp packet written, but multiple found"
        );

        s.receive_rtcp(vec![Box::<PictureLossIndication>::default()])
            .await;
        // `Some(Err(_))` must fail too, not just `None`.
        assert!(
            matches!(timeout_or_fail(wait, s.read_rtcp()).await, Some(Ok(_))),
            "read rtcp returned error",
        );
        let result = tokio::time::timeout(wait, s.read_rtcp()).await;
        assert!(
            result.is_err(),
            "single rtcp packet received, but multiple read"
        );

        s.receive_rtp(rtp::packet::Packet::default()).await;
        assert!(
            matches!(timeout_or_fail(wait, s.read_rtp()).await, Some(Ok(_))),
            "read rtp returned error",
        );
        let result = tokio::time::timeout(wait, s.read_rtp()).await;
        assert!(
            result.is_err(),
            "single rtp packet received, but multiple read"
        );

        // `last_written_rtcp` keeps only the most recent batch and drains the others.
        s.write_rtcp(&[Box::<PictureLossIndication>::default()])
            .await?;
        s.write_rtcp(&[
            Box::<PictureLossIndication>::default(),
            Box::<PictureLossIndication>::default(),
        ])
        .await?;
        assert_eq!(
            s.last_written_rtcp().await.map(|batch| batch.len()),
            Some(2),
            "last_written_rtcp did not return the most recent batch"
        );
        assert!(
            s.last_written_rtcp().await.is_none(),
            "last_written_rtcp did not drain earlier batches"
        );

        s.close().await?;

        Ok(())
    }
}