use std::{fmt::Display, pin::Pin};

use anyhow::bail;
use bytes::{Bytes, BytesMut};
use tokio::{
    io::{AsyncRead, AsyncWrite, AsyncWriteExt},
    net::TcpStream,
};
use tokio_stream::StreamExt;
use tokio_util::codec::{Decoder, Framed};
use tracing::{debug, trace};
12
/// Stateless tokio-util codec that parses an HTTP/1.x request head (via `httparse`)
/// into a [`RequestInfo`]; see its `Decoder` impl.
struct HttpCodec;
14
/// Holds an inbound byte stream until its HTTP request head has been read and decoded
/// with [`HttpCodec`].
struct ReqInfoGetter<I> {
    /// Client-side stream; handed back to the caller after the request head is read.
    inbound: I,
}
18
19impl<I> ReqInfoGetter<I>
20where
21 I: AsyncRead + AsyncWrite + Unpin + Send,
22{
23 async fn get(self) -> anyhow::Result<(RequestInfo, I)> {
24 let mut transport = Framed::new(self.inbound, HttpCodec);
25 let request_info = loop {
26 match transport.next().await {
27 Some(Ok(req)) => {
28 debug!("{}", req);
29 break req;
30 }
31 Some(Err(e)) => {
32 debug!("{:?}", e);
33 bail!(e);
34 }
35 None => {}
36 }
37 };
38
39 let inbound = transport.into_inner();
40 Ok((request_info, inbound))
41 }
42}
43
44pub async fn tunnel<I>(inbound: I) -> anyhow::Result<()>
45where
46 I: AsyncRead + AsyncWrite + Unpin + Send,
47{
48 let (request_info, mut inbound) = ReqInfoGetter { inbound }.get().await?;
49
50 if http::Method::CONNECT != request_info.method {
52 bail!("Only support CONNECT");
53 }
54
55 let mut outbound = TcpStream::connect(&request_info.path).await?;
56 debug!("Established tunnel: {}", request_info.path);
57 tokio::io::copy_bidirectional(&mut inbound, &mut outbound).await?;
58 Ok(())
59}
60
/// Fields extracted from one parsed HTTP request head.
#[derive(Debug)]
struct RequestInfo {
    /// Value of the first header whose name is `host` (case-insensitive), decoded lossily
    /// as UTF-8; `None` when no such header is present.
    host: Option<String>,
    /// Request target exactly as sent; `tunnel` passes it to `TcpStream::connect`, so for
    /// `CONNECT` it is expected to be a `host:port` authority. Empty if none was parsed.
    path: String,
    // Not read anywhere in this file; the allow keeps the unused-field warning quiet.
    #[allow(unused)]
    /// Raw bytes of the request head that were split off the read buffer.
    header: Bytes,
    /// Parsed request method.
    method: http::Method,
}
69
70impl Display for RequestInfo {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 write!(
73 f,
74 "host: {:?}\npath: {}\nmeth: {}\n",
75 self.host, self.path, self.method
76 )
77 }
78}
79
80impl Decoder for HttpCodec {
81 type Item = RequestInfo;
82
83 type Error = anyhow::Error;
84
85 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
86 if buf.is_empty() {
87 bail!("parse called with empty buf");
88 }
89
90 let path;
91 let host;
92 let slice;
93 let method;
94 let mut headers = [httparse::EMPTY_HEADER; 16];
95 let mut req = httparse::Request::new(&mut headers);
96
97 match req.parse(buf) {
98 Ok(httparse::Status::Complete(parsed_len)) => {
99 trace!("Request.parse Complete({})", parsed_len);
100 method = http::Method::from_bytes(req.method.unwrap().as_bytes())?;
101
102 path = match req.path {
103 Some(path) => <&str>::clone(&path).to_owned(),
104 None => String::from(""),
105 };
106 let hosts = req
107 .headers
108 .iter()
109 .filter_map(|s| {
110 if s.name.to_lowercase() == "host" {
111 Some(String::from_utf8_lossy(s.value).to_string())
112 } else {
113 None
114 }
115 })
116 .take(1)
117 .collect::<Vec<_>>();
118
119 if hosts.len() == 1 {
120 host = Some(hosts[0].to_owned());
121 } else {
122 host = None;
123 }
124 slice = buf.split_to(parsed_len);
125 }
126 Ok(httparse::Status::Partial) => return Ok(None),
127 Err(err) => {
128 bail!(err);
129 }
130 };
131 Ok(Some(RequestInfo {
132 host,
133 header: slice.freeze(),
134 method,
135 path,
136 }))
137 }
138}
139
140#[derive(Clone)]
141pub struct TokioExec;
142impl<F> hyper::rt::Executor<F> for TokioExec
143where
144 F: std::future::Future + Send + 'static,
145 F::Output: Send + 'static,
146{
147 fn execute(&self, fut: F) {
148 tokio::spawn(fut);
149 }
150}
151
/// Pairs the send and receive halves of a `quinn` stream into one value implementing
/// tokio's `AsyncWrite` (forwarded to `send`) and `AsyncRead` (forwarded to `recv`).
pub struct QuicBidiStream {
    /// Outgoing half; every `AsyncWrite` call is delegated here.
    pub send: quinn::SendStream,
    /// Incoming half; every `AsyncRead` call is delegated here.
    pub recv: quinn::RecvStream,
}
156
157impl AsyncWrite for QuicBidiStream {
158 fn poll_write(
159 mut self: Pin<&mut Self>,
160 cx: &mut std::task::Context<'_>,
161 buf: &[u8],
162 ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
163 Pin::new(&mut self.send).poll_write(cx, buf)
164 }
165
166 fn poll_flush(
167 mut self: Pin<&mut Self>,
168 cx: &mut std::task::Context<'_>,
169 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
170 Pin::new(&mut self.send).poll_flush(cx)
171 }
172
173 fn poll_shutdown(
174 mut self: Pin<&mut Self>,
175 cx: &mut std::task::Context<'_>,
176 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
177 Pin::new(&mut self.send).poll_shutdown(cx)
178 }
179}
180
181impl AsyncRead for QuicBidiStream {
182 fn poll_read(
183 mut self: Pin<&mut Self>,
184 cx: &mut std::task::Context<'_>,
185 buf: &mut tokio::io::ReadBuf<'_>,
186 ) -> std::task::Poll<std::io::Result<()>> {
187 Pin::new(&mut self.recv).poll_read(cx, buf)
188 }
189}