1use std::{
2 future::Future,
3 io::{self, Read, Write},
4 pin::Pin,
5 ptr,
6 task::{Context, Poll},
7};
8
9use bytes::{Bytes, BytesMut};
10use futures::{ready, Sink, SinkExt, Stream, StreamExt};
11use msf_rtp::{
12 rtcp::{CompoundRtcpPacket, RtcpContextHandle, RtcpPacketType},
13 transceiver::{RtpTransceiver, RtpTransceiverOptions, SSRCMode},
14 utils::PacketMux,
15 OrderedRtpPacket, RtpPacket,
16};
17use openssl::ssl::{HandshakeError, Ssl, SslStream};
18use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19
20use crate::{
21 session::{DecodingError, SrtpSession},
22 Error, InternalError,
23};
24
/// Establishes SRTP/SRTCP sessions by running an OpenSSL handshake over a packet
/// transport.
///
/// The transport is any `Stream<Item = io::Result<Bytes>> + Sink<Bytes, Error = io::Error>`.
/// After the handshake, the same transport carries the protected packets.
pub struct Connector {
    // OpenSSL session object; consumed by the handshake (`Ssl::connect` / `Ssl::accept`).
    inner: Ssl,
}
29
30impl Connector {
31 pub fn new(ssl: Ssl) -> Self {
33 Self { inner: ssl }
34 }
35
36 pub async fn connect_srtp<S>(
38 self,
39 mut stream: S,
40 options: RtpTransceiverOptions,
41 ) -> Result<SrtpStream<S>, Error>
42 where
43 S: Stream<Item = io::Result<Bytes>>,
44 S: Sink<Bytes, Error = io::Error>,
45 S: Unpin,
46 {
47 let session = self.connect(&mut stream, options).await?;
48
49 Ok(SrtpStream::new(session, stream))
50 }
51
52 pub async fn connect_srtcp<S>(self, mut stream: S) -> Result<SrtcpStream<S>, Error>
54 where
55 S: Stream<Item = io::Result<Bytes>>,
56 S: Sink<Bytes, Error = io::Error>,
57 S: Unpin,
58 {
59 let options = RtpTransceiverOptions::new()
64 .with_input_ssrc_mode(SSRCMode::Ignore)
65 .with_max_input_ssrcs(Some(1))
66 .with_reordering_buffer_depth(1);
67
68 let session = self.connect(&mut stream, options).await?;
69
70 Ok(SrtcpStream::new(session, stream))
71 }
72
73 pub async fn connect_muxed<S>(
75 self,
76 mut stream: S,
77 options: RtpTransceiverOptions,
78 ) -> Result<MuxedSrtpStream<S>, Error>
79 where
80 S: Stream<Item = io::Result<Bytes>>,
81 S: Sink<Bytes, Error = io::Error>,
82 S: Unpin,
83 {
84 let session = self.connect(&mut stream, options).await?;
85
86 Ok(MuxedSrtpStream::new(session, stream))
87 }
88
89 pub async fn accept_srtp<S>(
91 self,
92 mut stream: S,
93 options: RtpTransceiverOptions,
94 ) -> Result<SrtpStream<S>, Error>
95 where
96 S: Stream<Item = io::Result<Bytes>>,
97 S: Sink<Bytes, Error = io::Error>,
98 S: Unpin,
99 {
100 let session = self.accept(&mut stream, options).await?;
101
102 Ok(SrtpStream::new(session, stream))
103 }
104
105 pub async fn accept_srtcp<S>(self, mut stream: S) -> Result<SrtcpStream<S>, Error>
107 where
108 S: Stream<Item = io::Result<Bytes>>,
109 S: Sink<Bytes, Error = io::Error>,
110 S: Unpin,
111 {
112 let options = RtpTransceiverOptions::new()
117 .with_input_ssrc_mode(SSRCMode::Ignore)
118 .with_max_input_ssrcs(Some(1))
119 .with_reordering_buffer_depth(1);
120
121 let session = self.accept(&mut stream, options).await?;
122
123 Ok(SrtcpStream::new(session, stream))
124 }
125
126 pub async fn accept_muxed<S>(
128 self,
129 mut stream: S,
130 options: RtpTransceiverOptions,
131 ) -> Result<MuxedSrtpStream<S>, Error>
132 where
133 S: Stream<Item = io::Result<Bytes>>,
134 S: Sink<Bytes, Error = io::Error>,
135 S: Unpin,
136 {
137 let session = self.accept(&mut stream, options).await?;
138
139 Ok(MuxedSrtpStream::new(session, stream))
140 }
141
142 async fn connect<S>(
144 self,
145 stream: &mut S,
146 options: RtpTransceiverOptions,
147 ) -> Result<SrtpSession, Error>
148 where
149 S: Stream<Item = io::Result<Bytes>>,
150 S: Sink<Bytes, Error = io::Error>,
151 S: Unpin,
152 {
153 let mut ssl_stream = InnerSslStream::new(stream);
154
155 let connect = futures::future::lazy(move |cx| {
156 ssl_stream.set_async_context(Some(cx));
157
158 let mut res = HandshakeState::from(self.inner.connect(ssl_stream));
159
160 res.set_async_context(None);
161 res
162 });
163
164 let handshake = Handshake::from(connect.await);
165
166 let ssl_stream = handshake.await?;
167
168 let ssl = ssl_stream.ssl();
169
170 SrtpSession::client(ssl, options)
171 }
172
173 async fn accept<S>(
175 self,
176 stream: &mut S,
177 options: RtpTransceiverOptions,
178 ) -> Result<SrtpSession, Error>
179 where
180 S: Stream<Item = io::Result<Bytes>>,
181 S: Sink<Bytes, Error = io::Error>,
182 S: Unpin,
183 {
184 let mut ssl_stream = InnerSslStream::new(stream);
185
186 let accept = futures::future::lazy(move |cx| {
187 ssl_stream.set_async_context(Some(cx));
188
189 let mut res = HandshakeState::from(self.inner.accept(ssl_stream));
190
191 res.set_async_context(None);
192 res
193 });
194
195 let handshake = Handshake::from(accept.await);
196
197 let ssl_stream = handshake.await?;
198
199 let ssl = ssl_stream.ssl();
200
201 SrtpSession::server(ssl, options)
202 }
203}
204
pin_project_lite::pin_project! {
    /// SRTP stream that yields and accepts RTP packets only.
    ///
    /// RTCP packets demultiplexed from the transport are not yielded by this stream.
    pub struct SrtpStream<S> {
        // Underlying stream carrying both RTP and RTCP.
        #[pin]
        inner: MuxedSrtpStream<S>,
    }
}
212
213impl<S> SrtpStream<S> {
214 fn new(session: SrtpSession, stream: S) -> Self {
216 Self {
217 inner: MuxedSrtpStream::new(session, stream),
218 }
219 }
220}
221
222impl<S> Stream for SrtpStream<S>
223where
224 S: Stream<Item = io::Result<Bytes>>,
225{
226 type Item = Result<OrderedRtpPacket, Error>;
227
228 #[inline]
229 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
230 let mut this = self.project();
231
232 loop {
233 let inner = this.inner.as_mut();
234
235 let inner = inner.project();
236
237 if inner.session.end_of_stream() {
238 *inner.eof = true;
239 }
240
241 let inner = this.inner.as_mut();
242
243 let ready = ready!(inner.poll_next(cx));
244
245 match ready.transpose()? {
246 Some(PacketMux::Rtp(packet)) => return Poll::Ready(Some(Ok(packet))),
247 Some(PacketMux::Rtcp(_)) => (),
248 None => return Poll::Ready(None),
249 }
250 }
251 }
252}
253
254impl<S> Sink<RtpPacket> for SrtpStream<S>
255where
256 S: Sink<Bytes, Error = io::Error>,
257{
258 type Error = Error;
259
260 #[inline]
261 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262 let this = self.project();
263
264 this.inner.poll_ready(cx)
265 }
266
267 #[inline]
268 fn start_send(self: Pin<&mut Self>, packet: RtpPacket) -> Result<(), Self::Error> {
269 let this = self.project();
270
271 this.inner.start_send(packet.into())
272 }
273
274 #[inline]
275 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
276 let this = self.project();
277
278 this.inner.poll_flush(cx)
279 }
280
281 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
282 let mut this = self.project();
283
284 let inner = this.inner.as_mut();
285
286 ready!(inner.poll_close(cx))?;
287
288 let inner = this.inner.project();
289
290 inner.session.close();
291
292 Poll::Ready(Ok(()))
293 }
294}
295
296impl<S> RtpTransceiver for SrtpStream<S> {
297 #[inline]
298 fn rtcp_context(&self) -> RtcpContextHandle {
299 self.inner.session.rtcp_context()
300 }
301}
302
pin_project_lite::pin_project! {
    /// SRTCP stream that yields and accepts compound RTCP packets only.
    ///
    /// RTP packets demultiplexed from the transport are not yielded by this stream.
    pub struct SrtcpStream<S> {
        // Underlying stream carrying both RTP and RTCP.
        #[pin]
        inner: MuxedSrtpStream<S>,
    }
}
310
311impl<S> SrtcpStream<S> {
312 fn new(session: SrtpSession, stream: S) -> Self {
314 Self {
315 inner: MuxedSrtpStream::new(session, stream),
316 }
317 }
318}
319
320impl<S> Stream for SrtcpStream<S>
321where
322 S: Stream<Item = io::Result<Bytes>>,
323{
324 type Item = Result<CompoundRtcpPacket, Error>;
325
326 #[inline]
327 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
328 let mut this = self.project();
329
330 loop {
331 let inner = this.inner.as_mut();
332
333 let ready = ready!(inner.poll_next(cx));
334
335 match ready.transpose()? {
336 Some(PacketMux::Rtp(_)) => (),
337 Some(PacketMux::Rtcp(packet)) => return Poll::Ready(Some(Ok(packet))),
338 None => return Poll::Ready(None),
339 }
340 }
341 }
342}
343
344impl<S> Sink<CompoundRtcpPacket> for SrtcpStream<S>
345where
346 S: Sink<Bytes, Error = io::Error>,
347{
348 type Error = Error;
349
350 #[inline]
351 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
352 let this = self.project();
353
354 this.inner.poll_ready(cx)
355 }
356
357 #[inline]
358 fn start_send(self: Pin<&mut Self>, packet: CompoundRtcpPacket) -> Result<(), Self::Error> {
359 let this = self.project();
360
361 this.inner.start_send(packet.into())
362 }
363
364 #[inline]
365 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
366 let this = self.project();
367
368 this.inner.poll_flush(cx)
369 }
370
371 #[inline]
372 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
373 let this = self.project();
374
375 this.inner.poll_close(cx)
376 }
377}
378
pin_project_lite::pin_project! {
    /// SRTP stream carrying both RTP and RTCP packets over a single transport.
    ///
    /// Incoming frames are decoded by the SRTP session and yielded as `PacketMux` values.
    /// Outgoing `PacketMux` values are encoded by the session and sent as frames.
    pub struct MuxedSrtpStream<S> {
        // Underlying transport of encoded frames.
        #[pin]
        inner: S,
        // SRTP session used to decode incoming frames and encode outgoing packets.
        session: SrtpSession,
        // Error that stopped input; reported after the session's buffered packets are drained.
        error: Option<Error>,
        // Set once reading from `inner` has stopped: the transport ended or failed, a
        // `DecodingError::Other` occurred, or `SrtpStream` saw session end-of-stream.
        eof: bool,
    }
}
400
401impl<S> MuxedSrtpStream<S> {
402 fn new(session: SrtpSession, stream: S) -> Self {
404 Self {
405 inner: stream,
406 session,
407 error: None,
408 eof: false,
409 }
410 }
411}
412
413impl<S> Stream for MuxedSrtpStream<S>
414where
415 S: Stream<Item = io::Result<Bytes>>,
416{
417 type Item = Result<PacketMux<OrderedRtpPacket>, Error>;
418
419 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
420 let mut this = self.project();
421
422 loop {
423 if let Some(packet) = this.session.poll_next_ordered_packet() {
424 return Poll::Ready(Some(Ok(packet)));
425 }
426
427 let inner = this.inner.as_mut();
428
429 if !*this.eof {
430 let res = match ready!(inner.poll_next(cx)) {
431 Some(Ok(frame)) => match this.session.decode(frame) {
432 Err(DecodingError::Other(err)) => Some(Err(err)),
433 _ => Some(Ok(())),
434 },
435 Some(Err(err)) => Some(Err(err.into())),
436 None => None,
437 };
438
439 if !matches!(res, Some(Ok(()))) {
440 if let Some(Err(err)) = res {
441 *this.error = Some(err);
442 }
443
444 *this.eof = true;
445 }
446 } else if let Some(packet) = this.session.take_next_ordered_packet() {
447 return Poll::Ready(Some(Ok(packet)));
448 } else if let Some(err) = this.error.take() {
449 return Poll::Ready(Some(Err(err)));
450 } else {
451 return Poll::Ready(None);
452 }
453 }
454 }
455}
456
457impl<S> Sink<PacketMux> for MuxedSrtpStream<S>
458where
459 S: Sink<Bytes, Error = io::Error>,
460{
461 type Error = Error;
462
463 #[inline]
464 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
465 let this = self.project();
466
467 ready!(this.inner.poll_ready(cx))?;
468
469 Poll::Ready(Ok(()))
470 }
471
472 fn start_send(self: Pin<&mut Self>, packet: PacketMux) -> Result<(), Self::Error> {
473 let this = self.project();
474
475 let frame = match packet {
476 PacketMux::Rtp(packet) => this.session.encode_rtp_packet(packet)?,
477 PacketMux::Rtcp(packet) => {
478 if let Some(first) = packet.first() {
479 match first.packet_type() {
480 RtcpPacketType::SR | RtcpPacketType::RR => {
481 this.session.encode_rtcp_packet(packet)?
482 }
483 _ => return Err(Error::from(InternalError::InvalidPacketType)),
484 }
485 } else {
486 return Ok(());
487 }
488 }
489 };
490
491 this.inner.start_send(frame)?;
492
493 Ok(())
494 }
495
496 #[inline]
497 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
498 let this = self.project();
499
500 ready!(this.inner.poll_flush(cx))?;
501
502 Poll::Ready(Ok(()))
503 }
504
505 #[inline]
506 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
507 let this = self.project();
508
509 ready!(this.inner.poll_close(cx))?;
510
511 Poll::Ready(Ok(()))
512 }
513}
514
515impl<S> RtpTransceiver for MuxedSrtpStream<S> {
516 #[inline]
517 fn rtcp_context(&self) -> RtcpContextHandle {
518 self.session.rtcp_context()
519 }
520}
521
/// Future that drives an in-progress OpenSSL handshake to completion.
struct Handshake<'a, S> {
    // Current handshake state; `None` once the future has resolved.
    inner: Option<HandshakeState<'a, S>>,
}
526
527impl<'a, S> From<HandshakeState<'a, S>> for Handshake<'a, S> {
528 fn from(state: HandshakeState<'a, S>) -> Self {
529 Self { inner: Some(state) }
530 }
531}
532
533impl<'a, S> Future for Handshake<'a, S>
534where
535 S: Stream<Item = io::Result<Bytes>> + Sink<Bytes, Error = io::Error> + Unpin,
536{
537 type Output = Result<SslStream<InnerSslStream<'a, S>>, InternalError>;
538
539 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
540 let mut state = self
541 .inner
542 .take()
543 .expect("the future has been already resolved");
544
545 state.set_async_context(Some(cx));
546
547 match state.inner {
548 Ok(stream) => Poll::Ready(Ok(stream)),
549 Err(HandshakeError::SetupFailure(err)) => Poll::Ready(Err(err.into())),
550 Err(HandshakeError::Failure(m)) => {
551 Poll::Ready(Err(InternalError::from(m.into_error())))
552 }
553 Err(HandshakeError::WouldBlock(m)) => match m.handshake() {
554 Ok(stream) => Poll::Ready(Ok(stream)),
555 Err(HandshakeError::SetupFailure(err)) => Poll::Ready(Err(err.into())),
556 Err(HandshakeError::Failure(m)) => {
557 Poll::Ready(Err(InternalError::from(m.into_error())))
558 }
559 Err(HandshakeError::WouldBlock(m)) => {
560 let mut state = HandshakeState::from(HandshakeError::WouldBlock(m));
561
562 state.set_async_context(None);
563
564 self.inner = Some(state);
565
566 Poll::Pending
567 }
568 },
569 }
570 }
571}
572
/// Result of an OpenSSL handshake step performed over an `InnerSslStream`.
type HandshakeResult<'a, S> =
    Result<SslStream<InnerSslStream<'a, S>>, HandshakeError<InnerSslStream<'a, S>>>;
576
/// Wrapper around a `HandshakeResult` that can attach or detach the async task context
/// on whichever stream the result currently holds.
struct HandshakeState<'a, S> {
    // The wrapped handshake result.
    inner: HandshakeResult<'a, S>,
}
581
582impl<S> HandshakeState<'_, S> {
583 fn set_async_context(&mut self, cx: Option<&mut Context<'_>>) {
585 let ssl_stream = match &mut self.inner {
586 Ok(ssl_stream) => Some(ssl_stream.get_mut()),
587 Err(HandshakeError::Failure(m)) => Some(m.get_mut()),
588 Err(HandshakeError::WouldBlock(m)) => Some(m.get_mut()),
589 _ => None,
590 };
591
592 if let Some(s) = ssl_stream {
593 s.set_async_context(cx);
594 }
595 }
596}
597
598impl<'a, S> From<HandshakeResult<'a, S>> for HandshakeState<'a, S> {
599 fn from(res: HandshakeResult<'a, S>) -> Self {
600 Self { inner: res }
601 }
602}
603
604impl<'a, S> From<HandshakeError<InnerSslStream<'a, S>>> for HandshakeState<'a, S> {
605 fn from(err: HandshakeError<InnerSslStream<'a, S>>) -> Self {
606 Self::from(Err(err))
607 }
608}
609
/// Synchronous `Read`/`Write` adapter handed to OpenSSL.
///
/// OpenSSL drives I/O through the synchronous `Read`/`Write` traits. The current task
/// `Context` is therefore passed in through a type-erased raw pointer (`context`) while
/// OpenSSL runs, and `Poll::Pending` from the transport is reported as
/// `io::ErrorKind::WouldBlock`.
struct InnerSslStream<'a, S> {
    // Stream/sink adapter implementing `AsyncRead`/`AsyncWrite`.
    inner: RWStreamRef<'a, S>,
    // Type-erased `*mut Context<'_>`; null when no context is attached.
    context: *mut (),
}
615
616impl<'a, S> InnerSslStream<'a, S> {
617 fn new(stream: &'a mut S) -> Self {
619 Self {
620 inner: RWStreamRef::new(stream),
621 context: ptr::null_mut(),
622 }
623 }
624}
625
626impl<S> InnerSslStream<'_, S> {
627 fn set_async_context(&mut self, cx: Option<&mut Context<'_>>) {
629 if let Some(cx) = cx {
630 self.context = cx as *mut _ as *mut ();
631 } else {
632 self.context = ptr::null_mut();
633 }
634 }
635}
636
// SAFETY: `InnerSslStream` is only `!Send`/`!Sync` because of the raw `context` pointer.
// That pointer is taken from a live `&mut Context` via `set_async_context` and must only
// be dereferenced while that borrow is active, i.e. within the poll call that set it.
// The only code that dereferences it is `Read`/`Write`, which take `&mut self`, so a
// shared reference cannot reach it. The remaining field borrows `S` mutably, so the
// `S: Send` / `S: Sync` bounds mirror those of `&mut S`.
unsafe impl<S> Send for InnerSslStream<'_, S> where S: Send {}
unsafe impl<S> Sync for InnerSslStream<'_, S> where S: Sync {}
639
640impl<S> Read for InnerSslStream<'_, S>
641where
642 S: Stream<Item = io::Result<Bytes>> + Unpin,
643{
644 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
645 debug_assert!(!self.context.is_null());
646
647 let cx = unsafe { &mut *(self.context as *mut Context<'_>) };
648
649 let pinned = Pin::new(&mut self.inner);
650
651 let mut buf = ReadBuf::new(buf);
652
653 let data = match pinned.poll_read(cx, &mut buf) {
654 Poll::Ready(Ok(())) => buf.filled(),
655 Poll::Ready(Err(err)) => return Err(err),
656 Poll::Pending => return Err(io::Error::from(io::ErrorKind::WouldBlock)),
657 };
658
659 Ok(data.len())
660 }
661}
662
663impl<S> Write for InnerSslStream<'_, S>
664where
665 S: Sink<Bytes, Error = io::Error> + Unpin,
666{
667 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
668 debug_assert!(!self.context.is_null());
669
670 let cx = unsafe { &mut *(self.context as *mut Context<'_>) };
671
672 let pinned = Pin::new(&mut self.inner);
673
674 if let Poll::Ready(res) = pinned.poll_write(cx, buf) {
675 res
676 } else {
677 Err(io::Error::from(io::ErrorKind::WouldBlock))
678 }
679 }
680
681 fn flush(&mut self) -> io::Result<()> {
682 debug_assert!(!self.context.is_null());
683
684 let cx = unsafe { &mut *(self.context as *mut Context<'_>) };
685
686 let pinned = Pin::new(&mut self.inner);
687
688 if let Poll::Ready(res) = AsyncWrite::poll_flush(pinned, cx) {
689 res
690 } else {
691 Err(io::Error::from(io::ErrorKind::WouldBlock))
692 }
693 }
694}
695
/// Adapts a borrowed `Stream` of byte chunks and `Sink` of byte frames to tokio's
/// `AsyncRead`/`AsyncWrite`.
struct RWStreamRef<'a, S> {
    // Borrowed transport.
    stream: &'a mut S,
    // Unconsumed remainder of the most recently received chunk.
    input: Bytes,
    // Scratch buffer used to build outgoing frames.
    output: BytesMut,
}
702
703impl<'a, S> RWStreamRef<'a, S> {
704 fn new(stream: &'a mut S) -> Self {
706 Self {
707 stream,
708 input: Bytes::new(),
709 output: BytesMut::new(),
710 }
711 }
712}
713
714impl<S> AsyncRead for RWStreamRef<'_, S>
715where
716 S: Stream<Item = io::Result<Bytes>> + Unpin,
717{
718 fn poll_read(
719 mut self: Pin<&mut Self>,
720 cx: &mut Context<'_>,
721 buf: &mut ReadBuf<'_>,
722 ) -> Poll<io::Result<()>> {
723 loop {
724 if !self.input.is_empty() {
725 let remaining = buf.remaining();
726 let take = remaining.min(self.input.len());
727 let data = self.input.split_to(take);
728
729 buf.put_slice(&data);
730
731 return Poll::Ready(Ok(()));
732 } else if let Poll::Ready(ready) = self.stream.poll_next_unpin(cx) {
733 if let Some(chunk) = ready.transpose()? {
734 self.input = chunk;
735 } else {
736 return Poll::Ready(Ok(()));
737 }
738 } else {
739 return Poll::Pending;
740 }
741 }
742 }
743}
744
745impl<S> AsyncWrite for RWStreamRef<'_, S>
746where
747 S: Sink<Bytes, Error = io::Error> + Unpin,
748{
749 fn poll_write(
750 mut self: Pin<&mut Self>,
751 cx: &mut Context,
752 buf: &[u8],
753 ) -> Poll<io::Result<usize>> {
754 let this = &mut *self;
755
756 ready!(this.stream.poll_ready_unpin(cx))?;
757
758 this.output.extend_from_slice(buf);
759
760 let data = this.output.split_to(this.output.len());
761
762 this.stream.start_send_unpin(data.freeze())?;
763
764 Poll::Ready(Ok(buf.len()))
765 }
766
767 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
768 self.stream.poll_flush_unpin(cx)
769 }
770
771 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
772 self.stream.poll_close_unpin(cx)
773 }
774}