tinytun/
lib.rs

1use std::{
2    cmp::min,
3    error::Error,
4    pin::Pin,
5    task::{self, Poll},
6};
7
8use bytes::{Buf, Bytes};
9use h2::{server::Connection, Reason, RecvStream, SendStream};
10use hyper::{
11    upgrade::{self, Upgraded},
12    Body, Client, Method, Request, Response, Uri,
13};
14use hyper_rustls::HttpsConnectorBuilder;
15use tokio::io::{self, AsyncRead, AsyncWrite};
16
17pub struct Tunnel {
18    proxy_url: Uri,
19    connection: Connection<Upgraded, Bytes>,
20}
21
22impl Tunnel {
23    pub fn builder() -> TunnelBuilder {
24        TunnelBuilder::default()
25    }
26
27    pub fn proxy_url(&self) -> &Uri {
28        &self.proxy_url
29    }
30
31    pub async fn accept(&mut self) -> Option<TunnelStream> {
32        match self.connection.accept().await {
33            Some(Ok((req, mut respond))) => {
34                let sender = respond.send_response(Response::new(()), false).ok()?;
35                Some(TunnelStream::new(req.into_body(), sender))
36            }
37            _ => None,
38        }
39    }
40}
41
42pub struct TunnelBuilder {
43    server_url: Result<Uri, Box<dyn Error + Send + Sync>>,
44    subdomain: Option<String>,
45    max_concurrent_streams: u32,
46}
47
48impl Default for TunnelBuilder {
49    fn default() -> Self {
50        Self {
51            server_url: Ok(Uri::from_static("https://tinytun.com:5555")),
52            max_concurrent_streams: 100,
53            subdomain: Default::default(),
54        }
55    }
56}
57
58impl TunnelBuilder {
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    pub fn server_url<T>(self, server_url: T) -> Self
64    where
65        Uri: TryFrom<T>,
66        <Uri as TryFrom<T>>::Error: Into<Box<dyn Error + Send + Sync>>,
67    {
68        Self {
69            server_url: server_url.try_into().map_err(Into::into),
70            ..self
71        }
72    }
73
74    pub fn subdomain(self, subdomain: impl Into<Option<String>>) -> Self {
75        Self {
76            subdomain: subdomain.into(),
77            ..self
78        }
79    }
80
81    pub fn max_concurrent_streams(self, streams: u32) -> Self {
82        Self {
83            max_concurrent_streams: streams,
84            ..self
85        }
86    }
87
88    pub async fn listen(self) -> Result<Tunnel, Box<dyn Error + Send + Sync>> {
89        let server_url = self.server_url?;
90        let res = Client::builder()
91            .build(
92                HttpsConnectorBuilder::new()
93                    .with_native_roots()
94                    .https_or_http()
95                    .enable_http1()
96                    .build(),
97            )
98            .request({
99                let req = Request::builder().uri(&server_url).method(Method::CONNECT);
100
101                match self.subdomain {
102                    Some(subdomain) if !subdomain.trim().is_empty() => req
103                        .header("x-tinytun-subdomain", subdomain)
104                        .body(Body::empty())?,
105                    _ => req.body(Body::empty())?,
106                }
107            })
108            .await?;
109
110        let domain = res
111            .headers()
112            .get("x-tinytun-domain")
113            .ok_or("Server didn't provide a connection id")?
114            .to_str()?;
115
116        let proxy_url = Uri::builder()
117            .scheme(
118                server_url
119                    .scheme()
120                    .map(|scheme| scheme.to_string())
121                    .unwrap_or("http".to_string())
122                    .as_str(),
123            )
124            .authority(domain)
125            .path_and_query("")
126            .build()?;
127
128        let remote = upgrade::on(res).await?;
129        let connection = h2::server::Builder::new()
130            .max_concurrent_streams(self.max_concurrent_streams)
131            .handshake(remote)
132            .await?;
133
134        Ok(Tunnel {
135            proxy_url,
136            connection,
137        })
138    }
139}
140
141pub struct TunnelStream {
142    receiver: RecvStream,
143    sender: SendStream<Bytes>,
144    buf: Bytes,
145}
146
147impl TunnelStream {
148    pub fn new(receiver: RecvStream, sender: SendStream<Bytes>) -> Self {
149        Self {
150            sender,
151            receiver,
152            buf: Bytes::new(),
153        }
154    }
155}
156
157impl AsyncRead for TunnelStream {
158    fn poll_read(
159        mut self: Pin<&mut Self>,
160        cx: &mut task::Context<'_>,
161        buf: &mut io::ReadBuf<'_>,
162    ) -> Poll<io::Result<()>> {
163        if self.buf.is_empty() {
164            self.buf = loop {
165                match task::ready!(self.receiver.poll_data(cx)) {
166                    Some(Ok(buf)) if buf.is_empty() && !self.receiver.is_end_stream() => continue,
167                    Some(Ok(buf)) => break buf,
168                    Some(Err(err)) => {
169                        return Poll::Ready(match err.reason() {
170                            Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()),
171                            Some(Reason::STREAM_CLOSED) => {
172                                Err(io::Error::new(io::ErrorKind::BrokenPipe, err))
173                            }
174                            _ => Err(h2_error_to_io_error(err)),
175                        })
176                    }
177                    None => return Poll::Ready(Ok(())),
178                }
179            };
180        }
181
182        let len = min(self.buf.len(), buf.remaining());
183        buf.put_slice(&self.buf[..len]);
184        self.buf.advance(len);
185        self.receiver.flow_control().release_capacity(len).ok();
186
187        Poll::Ready(Ok(()))
188    }
189}
190
191impl AsyncWrite for TunnelStream {
192    fn poll_write(
193        mut self: Pin<&mut Self>,
194        cx: &mut task::Context<'_>,
195        buf: &[u8],
196    ) -> Poll<io::Result<usize>> {
197        if buf.is_empty() {
198            return Poll::Ready(Ok(0));
199        }
200
201        self.sender.reserve_capacity(buf.len());
202
203        let written = match task::ready!(self.sender.poll_capacity(cx)) {
204            Some(Ok(capacity)) => self
205                .sender
206                // TODO: try to figure out a way to avoid this copy
207                .send_data(Bytes::copy_from_slice(&buf[..capacity]), false)
208                .ok()
209                .map(|_| capacity),
210            Some(Err(_)) => None,
211            None => Some(0),
212        };
213
214        if let Some(len) = written {
215            return Poll::Ready(Ok(len));
216        }
217
218        match task::ready!(self.sender.poll_reset(cx)) {
219            Ok(Reason::NO_ERROR) | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => {
220                Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
221            }
222            Ok(reason) => Poll::Ready(Err(h2_error_to_io_error(reason.into()))),
223            Err(err) => Poll::Ready(Err(h2_error_to_io_error(err))),
224        }
225    }
226
227    fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
228        Poll::Ready(Ok(()))
229    }
230
231    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
232        if self.sender.send_data(Bytes::new(), true).is_ok() {
233            return Poll::Ready(Ok(()));
234        }
235
236        match task::ready!(self.sender.poll_reset(cx)) {
237            Ok(Reason::NO_ERROR) => Poll::Ready(Ok(())),
238            Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => {
239                Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
240            }
241            Ok(reason) => Poll::Ready(Err(h2_error_to_io_error(reason.into()))),
242            Err(err) => Poll::Ready(Err(h2_error_to_io_error(err))),
243        }
244    }
245}
246
247fn h2_error_to_io_error(err: h2::Error) -> io::Error {
248    if err.is_io() {
249        err.into_io().unwrap()
250    } else {
251        io::Error::new(io::ErrorKind::Other, err)
252    }
253}