iroh_net/relay/server/
streams.rs1use std::{
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use anyhow::Result;
9use futures_lite::Stream;
10use futures_sink::Sink;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio_tungstenite::WebSocketStream;
13use tokio_util::codec::Framed;
14
15use crate::relay::codec::{DerpCodec, Frame};
16
17#[derive(Debug)]
18pub(crate) enum RelayIo {
19 Derp(Framed<MaybeTlsStream, DerpCodec>),
20 Ws(WebSocketStream<MaybeTlsStream>),
21}
22
23fn tung_to_io_err(e: tungstenite::Error) -> std::io::Error {
24 match e {
25 tungstenite::Error::Io(io_err) => io_err,
26 _ => std::io::Error::new(std::io::ErrorKind::Other, e.to_string()),
27 }
28}
29
30impl Sink<Frame> for RelayIo {
31 type Error = std::io::Error;
32
33 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
34 match *self {
35 Self::Derp(ref mut framed) => Pin::new(framed).poll_ready(cx),
36 Self::Ws(ref mut ws) => Pin::new(ws).poll_ready(cx).map_err(tung_to_io_err),
37 }
38 }
39
40 fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
41 match *self {
42 Self::Derp(ref mut framed) => Pin::new(framed).start_send(item),
43 Self::Ws(ref mut ws) => Pin::new(ws)
44 .start_send(tungstenite::Message::Binary(item.encode_for_ws_msg()))
45 .map_err(tung_to_io_err),
46 }
47 }
48
49 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50 match *self {
51 Self::Derp(ref mut framed) => Pin::new(framed).poll_flush(cx),
52 Self::Ws(ref mut ws) => Pin::new(ws).poll_flush(cx).map_err(tung_to_io_err),
53 }
54 }
55
56 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
57 match *self {
58 Self::Derp(ref mut framed) => Pin::new(framed).poll_close(cx),
59 Self::Ws(ref mut ws) => Pin::new(ws).poll_close(cx).map_err(tung_to_io_err),
60 }
61 }
62}
63
64impl Stream for RelayIo {
65 type Item = anyhow::Result<Frame>;
66
67 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
68 match *self {
69 Self::Derp(ref mut framed) => Pin::new(framed).poll_next(cx),
70 Self::Ws(ref mut ws) => match Pin::new(ws).poll_next(cx) {
71 Poll::Ready(Some(Ok(tungstenite::Message::Binary(vec)))) => {
72 Poll::Ready(Some(Frame::decode_from_ws_msg(vec)))
73 }
74 Poll::Ready(Some(Ok(msg))) => {
75 tracing::warn!(?msg, "Got websocket message of unsupported type, skipping.");
76 Poll::Pending
77 }
78 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
79 Poll::Ready(None) => Poll::Ready(None),
80 Poll::Pending => Poll::Pending,
81 },
82 }
83 }
84}
85
86#[derive(Debug)]
90pub enum MaybeTlsStream {
91 Plain(tokio::net::TcpStream),
93 Tls(tokio_rustls::server::TlsStream<tokio::net::TcpStream>),
95 #[cfg(test)]
97 Test(tokio::io::DuplexStream),
98}
99
100impl AsyncRead for MaybeTlsStream {
101 fn poll_read(
102 mut self: Pin<&mut Self>,
103 cx: &mut Context<'_>,
104 buf: &mut tokio::io::ReadBuf<'_>,
105 ) -> Poll<std::io::Result<()>> {
106 match &mut *self {
107 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
108 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_read(cx, buf),
109 #[cfg(test)]
110 MaybeTlsStream::Test(ref mut s) => Pin::new(s).poll_read(cx, buf),
111 }
112 }
113}
114
115impl AsyncWrite for MaybeTlsStream {
116 fn poll_flush(
117 mut self: Pin<&mut Self>,
118 cx: &mut Context<'_>,
119 ) -> Poll<std::result::Result<(), std::io::Error>> {
120 match &mut *self {
121 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
122 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_flush(cx),
123 #[cfg(test)]
124 MaybeTlsStream::Test(ref mut s) => Pin::new(s).poll_flush(cx),
125 }
126 }
127
128 fn poll_shutdown(
129 mut self: Pin<&mut Self>,
130 cx: &mut Context<'_>,
131 ) -> Poll<std::result::Result<(), std::io::Error>> {
132 match &mut *self {
133 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
134 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_shutdown(cx),
135 #[cfg(test)]
136 MaybeTlsStream::Test(ref mut s) => Pin::new(s).poll_shutdown(cx),
137 }
138 }
139
140 fn poll_write(
141 mut self: Pin<&mut Self>,
142 cx: &mut Context<'_>,
143 buf: &[u8],
144 ) -> Poll<std::result::Result<usize, std::io::Error>> {
145 match &mut *self {
146 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
147 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_write(cx, buf),
148 #[cfg(test)]
149 MaybeTlsStream::Test(ref mut s) => Pin::new(s).poll_write(cx, buf),
150 }
151 }
152
153 fn poll_write_vectored(
154 mut self: Pin<&mut Self>,
155 cx: &mut Context<'_>,
156 bufs: &[std::io::IoSlice<'_>],
157 ) -> Poll<std::result::Result<usize, std::io::Error>> {
158 match &mut *self {
159 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write_vectored(cx, bufs),
160 MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_write_vectored(cx, bufs),
161 #[cfg(test)]
162 MaybeTlsStream::Test(ref mut s) => Pin::new(s).poll_write_vectored(cx, bufs),
163 }
164 }
165}