1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{Arc, Mutex},
5 task::{Context, Poll},
6};
7
8use futures::{
9 channel::{mpsc, oneshot},
10 ready, FutureExt, Sink, SinkExt, Stream, StreamExt,
11};
12use tokio::task::JoinHandle;
13
14use crate::{CompoundRtcpPacket, PacketMux, RtpPacket};
15
/// RTP transport wrapper that feeds every passing RTP packet into a shared
/// RTCP context.
///
/// `T` is the wrapped RTP stream and/or sink. The RTCP side of the session is
/// driven by two background tasks that are spawned by `RtcpHandler::new`.
pub struct RtcpHandler<T> {
    /// RTCP state shared with the background sender and receiver tasks.
    context: RtcpContext,
    /// The wrapped RTP stream/sink.
    stream: T,
    /// Handle of the task that consumes incoming RTCP packets. It is aborted
    /// when the handler is dropped.
    receiver: JoinHandle<()>,
    /// Close signal for the RTCP sender task. It is taken and fired when the
    /// handler is dropped.
    sender: Option<oneshot::Sender<()>>,
}
28
29impl<T> RtcpHandler<T> {
30 pub fn new<U, E>(rtp: T, rtcp: U) -> Self
32 where
33 U: Stream<Item = Result<CompoundRtcpPacket, E>> + Sink<CompoundRtcpPacket> + Send + 'static,
34 {
35 let context = RtcpContext::new();
36
37 let (rtcp_tx, rtcp_rx) = rtcp.split();
38
39 let (close_tx, close_rx) = oneshot::channel();
40
41 let sender = RtcpSender {
42 context: context.clone(),
43 sink: rtcp_tx,
44 close_rx: Some(close_rx),
45 pending: None,
46 };
47
48 let receiver = RtcpReceiver {
49 context: context.clone(),
50 stream: rtcp_rx,
51 };
52
53 tokio::spawn(async move { sender.await.unwrap_or_default() });
54
55 let receiver = tokio::spawn(async move { receiver.await.unwrap_or_default() });
56
57 Self {
58 context,
59 stream: rtp,
60 receiver,
61 sender: Some(close_tx),
62 }
63 }
64}
65
66impl<T> Drop for RtcpHandler<T> {
67 #[inline]
68 fn drop(&mut self) {
69 self.receiver.abort();
71
72 if let Some(close_tx) = self.sender.take() {
74 close_tx.send(()).unwrap_or_default();
75 }
76 }
77}
78
79impl<T, E> Stream for RtcpHandler<T>
80where
81 T: Stream<Item = Result<RtpPacket, E>> + Unpin,
82{
83 type Item = Result<RtpPacket, E>;
84
85 #[inline]
86 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
87 if let Poll::Ready(ready) = self.stream.poll_next_unpin(cx) {
88 if let Some(packet) = ready.transpose()? {
89 self.context.process_incoming_rtp_packet(&packet);
90
91 Poll::Ready(Some(Ok(packet)))
92 } else {
93 Poll::Ready(None)
94 }
95 } else {
96 Poll::Pending
97 }
98 }
99}
100
101impl<T, E> Sink<RtpPacket> for RtcpHandler<T>
102where
103 T: Sink<RtpPacket, Error = E> + Unpin,
104{
105 type Error = E;
106
107 #[inline]
108 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109 self.stream.poll_ready_unpin(cx)
110 }
111
112 #[inline]
113 fn start_send(mut self: Pin<&mut Self>, packet: RtpPacket) -> Result<(), Self::Error> {
114 self.context.process_outgoing_rtp_packet(&packet);
115 self.stream.start_send_unpin(packet)
116 }
117
118 #[inline]
119 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120 self.stream.poll_flush_unpin(cx)
121 }
122
123 #[inline]
124 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
125 self.stream.poll_close_unpin(cx)
126 }
127}
128
/// Receiving end of the channel that carries demultiplexed incoming RTP
/// packets (or the read error of the muxed stream).
type DemuxingRtpStream<E> = mpsc::Receiver<Result<RtpPacket, E>>;

/// Sink that converts its items into `PacketMux` and pushes them into the
/// outgoing channel.
type MuxingRtpSink = PacketMuxer<mpsc::Sender<PacketMux>>;

/// RTP stream/sink pair handed to the inner `RtcpHandler` of a
/// `MuxedRtcpHandler`.
type RtpComponent<E> = StreamSink<DemuxingRtpStream<E>, MuxingRtpSink>;
137
138pub struct MuxedRtcpHandler<E> {
145 inner: RtcpHandler<RtpComponent<E>>,
146 reader: JoinHandle<()>,
147 writer: JoinHandle<Result<(), E>>,
148 sink_error: bool,
149}
150
151impl<E> MuxedRtcpHandler<E> {
152 pub fn new<T>(stream: T) -> Self
154 where
155 T: Stream<Item = Result<PacketMux, E>> + Sink<PacketMux, Error = E> + Send + 'static,
156 E: Send + 'static,
157 {
158 let (mut muxed_tx, mut muxed_rx) = stream.split();
159
160 let (mut input_rtp_tx, input_rtp_rx) = mpsc::channel(4);
161 let (output_rtp_tx, output_rtp_rx) = mpsc::channel(4);
162 let (mut input_rtcp_tx, input_rtcp_rx) = mpsc::channel(4);
163 let (output_rtcp_tx, output_rtcp_rx) = mpsc::channel(4);
164
165 let output_rtp_tx = PacketMuxer::new(output_rtp_tx);
166 let output_rtcp_tx = PacketMuxer::new(output_rtcp_tx);
167
168 let rtp = StreamSink::new(input_rtp_rx, output_rtp_tx);
169 let rtcp = StreamSink::new(input_rtcp_rx, output_rtcp_tx);
170
171 let reader = tokio::spawn(async move {
172 while let Some(item) = muxed_rx.next().await {
173 match item {
174 Ok(PacketMux::Rtp(packet)) => {
175 input_rtp_tx.send(Ok(packet)).await.unwrap_or_default();
176 }
177 Ok(PacketMux::Rtcp(packet)) => {
178 input_rtcp_tx
179 .send(Ok(packet) as Result<_, E>)
180 .await
181 .unwrap_or_default();
182 }
183 Err(err) => {
184 input_rtp_tx.send(Err(err)).await.unwrap_or_default();
186
187 break;
189 }
190 }
191 }
192 });
193
194 let writer = tokio::spawn(async move {
195 let mut stream = futures::stream::select(output_rtp_rx, output_rtcp_rx);
196
197 while let Some(item) = stream.next().await {
198 muxed_tx.send(item).await?;
199 }
200
201 Ok(()) as Result<(), T::Error>
202 });
203
204 Self {
205 inner: RtcpHandler::new(rtp, rtcp),
206 reader,
207 writer,
208 sink_error: false,
209 }
210 }
211
212 fn poll_writer_result(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
214 match ready!(self.writer.poll_unpin(cx)) {
215 Ok(Ok(_)) => Poll::Ready(Ok(())),
216 Ok(Err(err)) => Poll::Ready(Err(err)),
217 Err(_) => Poll::Ready(Ok(())),
218 }
219 }
220}
221
impl<E> Drop for MuxedRtcpHandler<E> {
    /// Abort the background reader task.
    #[inline]
    fn drop(&mut self) {
        // Only the reader is aborted explicitly. The writer task is not
        // aborted here; presumably it is left to finish on its own once its
        // input channels are closed (NOTE(review): confirm this is intended).
        self.reader.abort();
    }
}
228
229impl<E> Stream for MuxedRtcpHandler<E> {
230 type Item = Result<RtpPacket, E>;
231
232 #[inline]
233 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
234 self.inner.poll_next_unpin(cx)
235 }
236}
237
238impl<E> Sink<RtpPacket> for MuxedRtcpHandler<E> {
239 type Error = E;
240
241 #[inline]
242 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
243 loop {
244 if self.sink_error {
245 return self.poll_writer_result(cx);
246 }
247
248 let res = ready!(self.inner.poll_ready_unpin(cx));
249
250 if res.is_ok() {
251 return Poll::Ready(Ok(()));
252 } else {
253 self.sink_error = true;
254 }
255 }
256 }
257
258 #[inline]
259 fn start_send(mut self: Pin<&mut Self>, item: RtpPacket) -> Result<(), Self::Error> {
260 let res = self.inner.start_send_unpin(item);
261
262 if res.is_err() {
265 self.sink_error = true;
266 }
267
268 Ok(())
269 }
270
271 #[inline]
272 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
273 loop {
274 if self.sink_error {
275 return self.poll_writer_result(cx);
276 }
277
278 let res = ready!(self.inner.poll_flush_unpin(cx));
279
280 if res.is_ok() {
281 return Poll::Ready(Ok(()));
282 } else {
283 self.sink_error = true;
284 }
285 }
286 }
287
288 #[inline]
289 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
290 loop {
291 if self.sink_error {
292 return self.poll_writer_result(cx);
293 }
294
295 let res = ready!(self.inner.poll_close_unpin(cx));
296
297 if res.is_ok() {
298 return Poll::Ready(Ok(()));
299 } else {
300 self.sink_error = true;
301 }
302 }
303 }
304}
305
/// Adapter that glues an independent stream and an independent sink into a
/// single object implementing both `Stream` and `Sink`.
struct StreamSink<T, U> {
    /// Source of the items yielded by the `Stream` implementation.
    stream: T,
    /// Target of the items accepted by the `Sink` implementation.
    sink: U,
}
311
312impl<T, U> StreamSink<T, U> {
313 fn new(stream: T, sink: U) -> Self {
315 Self { stream, sink }
316 }
317}
318
319impl<T, U> Stream for StreamSink<T, U>
320where
321 T: Stream + Unpin,
322 U: Unpin,
323{
324 type Item = T::Item;
325
326 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
327 self.stream.poll_next_unpin(cx)
328 }
329}
330
331impl<T, U, I> Sink<I> for StreamSink<T, U>
332where
333 T: Unpin,
334 U: Sink<I> + Unpin,
335{
336 type Error = U::Error;
337
338 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
339 self.sink.poll_ready_unpin(cx)
340 }
341
342 fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
343 self.sink.start_send_unpin(item)
344 }
345
346 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
347 self.sink.poll_flush_unpin(cx)
348 }
349
350 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
351 self.sink.poll_close_unpin(cx)
352 }
353}
354
/// Sink adapter that converts every accepted item into `PacketMux` before
/// passing it to the wrapped sink.
struct PacketMuxer<T> {
    /// The wrapped sink of `PacketMux` items.
    inner: T,
}
359
360impl<T> PacketMuxer<T> {
361 fn new(sink: T) -> Self {
363 Self { inner: sink }
364 }
365}
366
367impl<T, I> Sink<I> for PacketMuxer<T>
368where
369 T: Sink<PacketMux> + Unpin,
370 I: Into<PacketMux>,
371{
372 type Error = T::Error;
373
374 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
375 self.inner.poll_ready_unpin(cx)
376 }
377
378 fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
379 self.inner.start_send_unpin(item.into())
380 }
381
382 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
383 self.inner.poll_flush_unpin(cx)
384 }
385
386 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
387 self.inner.poll_close_unpin(cx)
388 }
389}
390
/// Future that consumes incoming RTCP packets and feeds them into the RTCP
/// context.
struct RtcpReceiver<T> {
    /// RTCP state updated with every received packet.
    context: RtcpContext,
    /// Source of the incoming RTCP packets.
    stream: T,
}
396
397impl<T, E> Future for RtcpReceiver<T>
398where
399 T: Stream<Item = Result<CompoundRtcpPacket, E>> + Unpin,
400{
401 type Output = Result<(), E>;
402
403 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
404 while let Poll::Ready(ready) = self.stream.poll_next_unpin(cx) {
405 if let Some(packet) = ready.transpose()? {
406 self.context.process_incoming_rtcp_packet(&packet);
407 } else {
408 return Poll::Ready(Ok(()));
409 }
410 }
411
412 Poll::Pending
413 }
414}
415
/// Future that pushes outgoing RTCP packets into a sink until it is asked to
/// close.
struct RtcpSender<T> {
    /// RTCP state updated with every packet handed to the sink.
    context: RtcpContext,
    /// Target of the outgoing RTCP packets.
    sink: T,
    /// Close signal. It is set to `None` once the signal has fired or its
    /// sending end has been dropped.
    close_rx: Option<oneshot::Receiver<()>>,
    /// Packet that could not be sent yet because the sink was not ready.
    pending: Option<CompoundRtcpPacket>,
}
423
424impl<T> RtcpSender<T> {
425 fn poll_next_packet(&mut self, cx: &mut Context) -> Poll<Option<CompoundRtcpPacket>> {
427 if let Some(close_rx) = self.close_rx.as_mut() {
428 if close_rx.poll_unpin(cx).is_ready() {
429 self.close_rx = None;
432 }
433 }
434
435 if let Some(packet) = self.pending.take() {
436 Poll::Ready(Some(packet))
437 } else if self.close_rx.is_none() {
438 Poll::Ready(None)
439 } else {
440 Poll::Pending
443 }
444 }
445}
446
447impl<T> Future for RtcpSender<T>
448where
449 T: Sink<CompoundRtcpPacket> + Unpin,
450{
451 type Output = Result<(), T::Error>;
452
453 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
454 while let Poll::Ready(ready) = self.poll_next_packet(cx) {
455 if let Some(packet) = ready {
456 let poll = self.sink.poll_ready_unpin(cx)?;
457
458 if poll.is_ready() {
459 self.context.process_outgoing_rtcp_packet(&packet);
460 self.sink.start_send_unpin(packet)?;
461 } else {
462 self.pending = Some(packet);
464
465 return Poll::Pending;
467 }
468 } else {
469 return self.sink.poll_close_unpin(cx);
470 }
471 }
472
473 let _ = self.sink.poll_flush_unpin(cx);
475
476 Poll::Pending
477 }
478}
479
/// Cheaply cloneable handle to the RTCP state.
///
/// All clones refer to the same mutex-protected `InnerRtcpContext`.
#[derive(Clone)]
struct RtcpContext {
    /// The shared state.
    inner: Arc<Mutex<InnerRtcpContext>>,
}
485
486impl RtcpContext {
487 fn new() -> Self {
489 Self {
490 inner: Arc::new(Mutex::new(InnerRtcpContext::new())),
491 }
492 }
493
494 fn process_incoming_rtp_packet(&mut self, packet: &RtpPacket) {
496 self.inner
497 .lock()
498 .unwrap()
499 .process_incoming_rtp_packet(packet);
500 }
501
502 fn process_incoming_rtcp_packet(&mut self, packet: &CompoundRtcpPacket) {
504 self.inner
505 .lock()
506 .unwrap()
507 .process_incoming_rtcp_packet(packet);
508 }
509
510 fn process_outgoing_rtp_packet(&mut self, packet: &RtpPacket) {
512 self.inner
513 .lock()
514 .unwrap()
515 .process_outgoing_rtp_packet(packet);
516 }
517
518 fn process_outgoing_rtcp_packet(&mut self, packet: &CompoundRtcpPacket) {
520 self.inner
521 .lock()
522 .unwrap()
523 .process_outgoing_rtcp_packet(packet);
524 }
525}
526
/// RTCP state behind the `RtcpContext` mutex. It currently has no fields.
struct InnerRtcpContext {}
529
530impl InnerRtcpContext {
531 fn new() -> Self {
533 Self {}
534 }
535
536 fn process_incoming_rtp_packet(&mut self, _: &RtpPacket) {
538 }
540
541 fn process_incoming_rtcp_packet(&mut self, _: &CompoundRtcpPacket) {
543 }
545
546 fn process_outgoing_rtp_packet(&mut self, _: &RtpPacket) {
548 }
550
551 fn process_outgoing_rtcp_packet(&mut self, _: &CompoundRtcpPacket) {
553 }
555}