iroh_net/relay/server/
streams.rs

1//! Streams used in the server-side implementation of iroh relays.
2
3use std::{
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use anyhow::Result;
9use futures_lite::Stream;
10use futures_sink::Sink;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio_tungstenite::WebSocketStream;
13use tokio_util::codec::Framed;
14
15use crate::relay::codec::{DerpCodec, Frame};
16
17#[derive(Debug)]
18pub(crate) enum RelayIo {
19    Derp(Framed<MaybeTlsStream, DerpCodec>),
20    Ws(WebSocketStream<MaybeTlsStream>),
21}
22
23fn tung_to_io_err(e: tungstenite::Error) -> std::io::Error {
24    match e {
25        tungstenite::Error::Io(io_err) => io_err,
26        _ => std::io::Error::new(std::io::ErrorKind::Other, e.to_string()),
27    }
28}
29
30impl Sink<Frame> for RelayIo {
31    type Error = std::io::Error;
32
33    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
34        match *self {
35            Self::Derp(ref mut framed) => Pin::new(framed).poll_ready(cx),
36            Self::Ws(ref mut ws) => Pin::new(ws).poll_ready(cx).map_err(tung_to_io_err),
37        }
38    }
39
40    fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
41        match *self {
42            Self::Derp(ref mut framed) => Pin::new(framed).start_send(item),
43            Self::Ws(ref mut ws) => Pin::new(ws)
44                .start_send(tungstenite::Message::Binary(item.encode_for_ws_msg()))
45                .map_err(tung_to_io_err),
46        }
47    }
48
49    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50        match *self {
51            Self::Derp(ref mut framed) => Pin::new(framed).poll_flush(cx),
52            Self::Ws(ref mut ws) => Pin::new(ws).poll_flush(cx).map_err(tung_to_io_err),
53        }
54    }
55
56    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
57        match *self {
58            Self::Derp(ref mut framed) => Pin::new(framed).poll_close(cx),
59            Self::Ws(ref mut ws) => Pin::new(ws).poll_close(cx).map_err(tung_to_io_err),
60        }
61    }
62}
63
64impl Stream for RelayIo {
65    type Item = anyhow::Result<Frame>;
66
67    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
68        match *self {
69            Self::Derp(ref mut framed) => Pin::new(framed).poll_next(cx),
70            Self::Ws(ref mut ws) => match Pin::new(ws).poll_next(cx) {
71                Poll::Ready(Some(Ok(tungstenite::Message::Binary(vec)))) => {
72                    Poll::Ready(Some(Frame::decode_from_ws_msg(vec)))
73                }
74                Poll::Ready(Some(Ok(msg))) => {
75                    tracing::warn!(?msg, "Got websocket message of unsupported type, skipping.");
76                    Poll::Pending
77                }
78                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
79                Poll::Ready(None) => Poll::Ready(None),
80                Poll::Pending => Poll::Pending,
81            },
82        }
83    }
84}
85
86/// The main underlying IO stream type used for the relay server.
87///
88/// Allows choosing whether or not the underlying [`tokio::net::TcpStream`] is served over Tls
89#[derive(Debug)]
90pub enum MaybeTlsStream {
91    /// A plain non-Tls [`tokio::net::TcpStream`]
92    Plain(tokio::net::TcpStream),
93    /// A Tls wrapped [`tokio::net::TcpStream`]
94    Tls(tokio_rustls::server::TlsStream<tokio::net::TcpStream>),
95    /// An in-memory bidirectional pipe.
96    #[cfg(test)]
97    Test(tokio::io::DuplexStream),
98}
99
100impl AsyncRead for MaybeTlsStream {
101    fn poll_read(
102        mut self: Pin<&mut Self>,
103        cx: &mut Context<'_>,
104        buf: &mut tokio::io::ReadBuf<'_>,
105    ) -> Poll<std::io::Result<()>> {
106        match &mut *self {
107            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
108            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_read(cx, buf),
109            #[cfg(test)]
110            MaybeTlsStream::Test(ref mut s) => Pin::new(s).poll_read(cx, buf),
111        }
112    }
113}
114
115impl AsyncWrite for MaybeTlsStream {
116    fn poll_flush(
117        mut self: Pin<&mut Self>,
118        cx: &mut Context<'_>,
119    ) -> Poll<std::result::Result<(), std::io::Error>> {
120        match &mut *self {
121            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
122            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_flush(cx),
123            #[cfg(test)]
124            MaybeTlsStream::Test(ref mut s) => Pin::new(s).poll_flush(cx),
125        }
126    }
127
128    fn poll_shutdown(
129        mut self: Pin<&mut Self>,
130        cx: &mut Context<'_>,
131    ) -> Poll<std::result::Result<(), std::io::Error>> {
132        match &mut *self {
133            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
134            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_shutdown(cx),
135            #[cfg(test)]
136            MaybeTlsStream::Test(ref mut s) => Pin::new(s).poll_shutdown(cx),
137        }
138    }
139
140    fn poll_write(
141        mut self: Pin<&mut Self>,
142        cx: &mut Context<'_>,
143        buf: &[u8],
144    ) -> Poll<std::result::Result<usize, std::io::Error>> {
145        match &mut *self {
146            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
147            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_write(cx, buf),
148            #[cfg(test)]
149            MaybeTlsStream::Test(ref mut s) => Pin::new(s).poll_write(cx, buf),
150        }
151    }
152
153    fn poll_write_vectored(
154        mut self: Pin<&mut Self>,
155        cx: &mut Context<'_>,
156        bufs: &[std::io::IoSlice<'_>],
157    ) -> Poll<std::result::Result<usize, std::io::Error>> {
158        match &mut *self {
159            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write_vectored(cx, bufs),
160            MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_write_vectored(cx, bufs),
161            #[cfg(test)]
162            MaybeTlsStream::Test(ref mut s) => Pin::new(s).poll_write_vectored(cx, bufs),
163        }
164    }
165}