octoproxy_lib/
proxy.rs

1use std::{fmt::Display, pin::Pin};
2
3use anyhow::bail;
4use bytes::{Bytes, BytesMut};
5use tokio::{
6    io::{AsyncRead, AsyncWrite},
7    net::TcpStream,
8};
9use tokio_stream::StreamExt;
10use tokio_util::codec::{Decoder, Framed};
11use tracing::{debug, trace};
12
13struct HttpCodec;
14
15struct ReqInfoGetter<I> {
16    inbound: I,
17}
18
19impl<I> ReqInfoGetter<I>
20where
21    I: AsyncRead + AsyncWrite + Unpin + Send,
22{
23    async fn get(self) -> anyhow::Result<(RequestInfo, I)> {
24        let mut transport = Framed::new(self.inbound, HttpCodec);
25        let request_info = loop {
26            match transport.next().await {
27                Some(Ok(req)) => {
28                    debug!("{}", req);
29                    break req;
30                }
31                Some(Err(e)) => {
32                    debug!("{:?}", e);
33                    bail!(e);
34                }
35                None => {}
36            }
37        };
38
39        let inbound = transport.into_inner();
40        Ok((request_info, inbound))
41    }
42}
43
44pub async fn tunnel<I>(inbound: I) -> anyhow::Result<()>
45where
46    I: AsyncRead + AsyncWrite + Unpin + Send,
47{
48    let (request_info, mut inbound) = ReqInfoGetter { inbound }.get().await?;
49
50    // dont have to check again if reusing a connection
51    if http::Method::CONNECT != request_info.method {
52        bail!("Only support CONNECT");
53    }
54
55    let mut outbound = TcpStream::connect(&request_info.path).await?;
56    debug!("Established tunnel: {}", request_info.path);
57    tokio::io::copy_bidirectional(&mut inbound, &mut outbound).await?;
58    Ok(())
59}
60
61#[derive(Debug)]
62struct RequestInfo {
63    host: Option<String>,
64    path: String,
65    #[allow(unused)]
66    header: Bytes,
67    method: http::Method,
68}
69
70impl Display for RequestInfo {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        write!(
73            f,
74            "host: {:?}\npath: {}\nmeth: {}\n",
75            self.host, self.path, self.method
76        )
77    }
78}
79
80impl Decoder for HttpCodec {
81    type Item = RequestInfo;
82
83    type Error = anyhow::Error;
84
85    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
86        if buf.is_empty() {
87            bail!("parse called with empty buf");
88        }
89
90        let path;
91        let host;
92        let slice;
93        let method;
94        let mut headers = [httparse::EMPTY_HEADER; 16];
95        let mut req = httparse::Request::new(&mut headers);
96
97        match req.parse(buf) {
98            Ok(httparse::Status::Complete(parsed_len)) => {
99                trace!("Request.parse Complete({})", parsed_len);
100                method = http::Method::from_bytes(req.method.unwrap().as_bytes())?;
101
102                path = match req.path {
103                    Some(path) => <&str>::clone(&path).to_owned(),
104                    None => String::from(""),
105                };
106                let hosts = req
107                    .headers
108                    .iter()
109                    .filter_map(|s| {
110                        if s.name.to_lowercase() == "host" {
111                            Some(String::from_utf8_lossy(s.value).to_string())
112                        } else {
113                            None
114                        }
115                    })
116                    .take(1)
117                    .collect::<Vec<_>>();
118
119                if hosts.len() == 1 {
120                    host = Some(hosts[0].to_owned());
121                } else {
122                    host = None;
123                }
124                slice = buf.split_to(parsed_len);
125            }
126            Ok(httparse::Status::Partial) => return Ok(None),
127            Err(err) => {
128                bail!(err);
129            }
130        };
131        Ok(Some(RequestInfo {
132            host,
133            header: slice.freeze(),
134            method,
135            path,
136        }))
137    }
138}
139
140#[derive(Clone)]
141pub struct TokioExec;
142impl<F> hyper::rt::Executor<F> for TokioExec
143where
144    F: std::future::Future + Send + 'static,
145    F::Output: Send + 'static,
146{
147    fn execute(&self, fut: F) {
148        tokio::spawn(fut);
149    }
150}
151
152pub struct QuicBidiStream {
153    pub send: quinn::SendStream,
154    pub recv: quinn::RecvStream,
155}
156
157impl AsyncWrite for QuicBidiStream {
158    fn poll_write(
159        mut self: Pin<&mut Self>,
160        cx: &mut std::task::Context<'_>,
161        buf: &[u8],
162    ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
163        Pin::new(&mut self.send).poll_write(cx, buf)
164    }
165
166    fn poll_flush(
167        mut self: Pin<&mut Self>,
168        cx: &mut std::task::Context<'_>,
169    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
170        Pin::new(&mut self.send).poll_flush(cx)
171    }
172
173    fn poll_shutdown(
174        mut self: Pin<&mut Self>,
175        cx: &mut std::task::Context<'_>,
176    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
177        Pin::new(&mut self.send).poll_shutdown(cx)
178    }
179}
180
181impl AsyncRead for QuicBidiStream {
182    fn poll_read(
183        mut self: Pin<&mut Self>,
184        cx: &mut std::task::Context<'_>,
185        buf: &mut tokio::io::ReadBuf<'_>,
186    ) -> std::task::Poll<std::io::Result<()>> {
187        Pin::new(&mut self.recv).poll_read(cx, buf)
188    }
189}