1use std::{
2 cmp::min,
3 error::Error,
4 pin::Pin,
5 task::{self, Poll},
6};
7
8use bytes::{Buf, Bytes};
9use h2::{server::Connection, Reason, RecvStream, SendStream};
10use hyper::{
11 upgrade::{self, Upgraded},
12 Body, Client, Method, Request, Response, Uri,
13};
14use hyper_rustls::HttpsConnectorBuilder;
15use tokio::io::{self, AsyncRead, AsyncWrite};
16
17pub struct Tunnel {
18 proxy_url: Uri,
19 connection: Connection<Upgraded, Bytes>,
20}
21
22impl Tunnel {
23 pub fn builder() -> TunnelBuilder {
24 TunnelBuilder::default()
25 }
26
27 pub fn proxy_url(&self) -> &Uri {
28 &self.proxy_url
29 }
30
31 pub async fn accept(&mut self) -> Option<TunnelStream> {
32 match self.connection.accept().await {
33 Some(Ok((req, mut respond))) => {
34 let sender = respond.send_response(Response::new(()), false).ok()?;
35 Some(TunnelStream::new(req.into_body(), sender))
36 }
37 _ => None,
38 }
39 }
40}
41
42pub struct TunnelBuilder {
43 server_url: Result<Uri, Box<dyn Error + Send + Sync>>,
44 subdomain: Option<String>,
45 max_concurrent_streams: u32,
46}
47
48impl Default for TunnelBuilder {
49 fn default() -> Self {
50 Self {
51 server_url: Ok(Uri::from_static("https://tinytun.com:5555")),
52 max_concurrent_streams: 100,
53 subdomain: Default::default(),
54 }
55 }
56}
57
58impl TunnelBuilder {
59 pub fn new() -> Self {
60 Self::default()
61 }
62
63 pub fn server_url<T>(self, server_url: T) -> Self
64 where
65 Uri: TryFrom<T>,
66 <Uri as TryFrom<T>>::Error: Into<Box<dyn Error + Send + Sync>>,
67 {
68 Self {
69 server_url: server_url.try_into().map_err(Into::into),
70 ..self
71 }
72 }
73
74 pub fn subdomain(self, subdomain: impl Into<Option<String>>) -> Self {
75 Self {
76 subdomain: subdomain.into(),
77 ..self
78 }
79 }
80
81 pub fn max_concurrent_streams(self, streams: u32) -> Self {
82 Self {
83 max_concurrent_streams: streams,
84 ..self
85 }
86 }
87
88 pub async fn listen(self) -> Result<Tunnel, Box<dyn Error + Send + Sync>> {
89 let server_url = self.server_url?;
90 let res = Client::builder()
91 .build(
92 HttpsConnectorBuilder::new()
93 .with_native_roots()
94 .https_or_http()
95 .enable_http1()
96 .build(),
97 )
98 .request({
99 let req = Request::builder().uri(&server_url).method(Method::CONNECT);
100
101 match self.subdomain {
102 Some(subdomain) if !subdomain.trim().is_empty() => req
103 .header("x-tinytun-subdomain", subdomain)
104 .body(Body::empty())?,
105 _ => req.body(Body::empty())?,
106 }
107 })
108 .await?;
109
110 let domain = res
111 .headers()
112 .get("x-tinytun-domain")
113 .ok_or("Server didn't provide a connection id")?
114 .to_str()?;
115
116 let proxy_url = Uri::builder()
117 .scheme(
118 server_url
119 .scheme()
120 .map(|scheme| scheme.to_string())
121 .unwrap_or("http".to_string())
122 .as_str(),
123 )
124 .authority(domain)
125 .path_and_query("")
126 .build()?;
127
128 let remote = upgrade::on(res).await?;
129 let connection = h2::server::Builder::new()
130 .max_concurrent_streams(self.max_concurrent_streams)
131 .handshake(remote)
132 .await?;
133
134 Ok(Tunnel {
135 proxy_url,
136 connection,
137 })
138 }
139}
140
141pub struct TunnelStream {
142 receiver: RecvStream,
143 sender: SendStream<Bytes>,
144 buf: Bytes,
145}
146
147impl TunnelStream {
148 pub fn new(receiver: RecvStream, sender: SendStream<Bytes>) -> Self {
149 Self {
150 sender,
151 receiver,
152 buf: Bytes::new(),
153 }
154 }
155}
156
157impl AsyncRead for TunnelStream {
158 fn poll_read(
159 mut self: Pin<&mut Self>,
160 cx: &mut task::Context<'_>,
161 buf: &mut io::ReadBuf<'_>,
162 ) -> Poll<io::Result<()>> {
163 if self.buf.is_empty() {
164 self.buf = loop {
165 match task::ready!(self.receiver.poll_data(cx)) {
166 Some(Ok(buf)) if buf.is_empty() && !self.receiver.is_end_stream() => continue,
167 Some(Ok(buf)) => break buf,
168 Some(Err(err)) => {
169 return Poll::Ready(match err.reason() {
170 Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()),
171 Some(Reason::STREAM_CLOSED) => {
172 Err(io::Error::new(io::ErrorKind::BrokenPipe, err))
173 }
174 _ => Err(h2_error_to_io_error(err)),
175 })
176 }
177 None => return Poll::Ready(Ok(())),
178 }
179 };
180 }
181
182 let len = min(self.buf.len(), buf.remaining());
183 buf.put_slice(&self.buf[..len]);
184 self.buf.advance(len);
185 self.receiver.flow_control().release_capacity(len).ok();
186
187 Poll::Ready(Ok(()))
188 }
189}
190
191impl AsyncWrite for TunnelStream {
192 fn poll_write(
193 mut self: Pin<&mut Self>,
194 cx: &mut task::Context<'_>,
195 buf: &[u8],
196 ) -> Poll<io::Result<usize>> {
197 if buf.is_empty() {
198 return Poll::Ready(Ok(0));
199 }
200
201 self.sender.reserve_capacity(buf.len());
202
203 let written = match task::ready!(self.sender.poll_capacity(cx)) {
204 Some(Ok(capacity)) => self
205 .sender
206 .send_data(Bytes::copy_from_slice(&buf[..capacity]), false)
208 .ok()
209 .map(|_| capacity),
210 Some(Err(_)) => None,
211 None => Some(0),
212 };
213
214 if let Some(len) = written {
215 return Poll::Ready(Ok(len));
216 }
217
218 match task::ready!(self.sender.poll_reset(cx)) {
219 Ok(Reason::NO_ERROR) | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => {
220 Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
221 }
222 Ok(reason) => Poll::Ready(Err(h2_error_to_io_error(reason.into()))),
223 Err(err) => Poll::Ready(Err(h2_error_to_io_error(err))),
224 }
225 }
226
227 fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
228 Poll::Ready(Ok(()))
229 }
230
231 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
232 if self.sender.send_data(Bytes::new(), true).is_ok() {
233 return Poll::Ready(Ok(()));
234 }
235
236 match task::ready!(self.sender.poll_reset(cx)) {
237 Ok(Reason::NO_ERROR) => Poll::Ready(Ok(())),
238 Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => {
239 Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
240 }
241 Ok(reason) => Poll::Ready(Err(h2_error_to_io_error(reason.into()))),
242 Err(err) => Poll::Ready(Err(h2_error_to_io_error(err))),
243 }
244 }
245}
246
247fn h2_error_to_io_error(err: h2::Error) -> io::Error {
248 if err.is_io() {
249 err.into_io().unwrap()
250 } else {
251 io::Error::new(io::ErrorKind::Other, err)
252 }
253}