hconnect/
lib.rs

1use std::net::SocketAddr;
2
3use bytes::Bytes;
4use http::header::{HOST, PROXY_AUTHORIZATION};
5use http::Request;
6use http::Uri;
7use http_body_util::Empty;
8use hyper::client::conn::http1;
9use hyper::upgrade::Upgraded;
10use hyper_util::rt::TokioIo;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::net::TcpStream;
13
14pub mod auth;
15
16pub mod io_ext;
17
18#[derive(Debug, thiserror::Error)]
19pub enum Error {
20    #[error("I/o error: {0}")]
21    Io(
22        #[from]
23        #[source]
24        std::io::Error,
25    ),
26    #[error("HTTP error: {0}")]
27    Http(
28        #[from]
29        #[source]
30        http::Error,
31    ),
32    #[error("HTTP wire error: {0}")]
33    Hyper(
34        #[from]
35        #[source]
36        hyper::Error,
37    ),
38    #[error("invalid URI: {0}")]
39    InvalidUri(Uri),
40    #[error("connection error when connecting to {0}: {1}")]
41    Connect(Host, tokio::io::Error),
42    #[error("handshake error with {0} via {1}: {2}")]
43    Handshake(Host, SocketAddr, #[source] hyper::Error),
44    #[error("authentication error: {0}")]
45    Auth(
46        #[from]
47        #[source]
48        crate::auth::Error,
49    ),
50}
51
/// A tunnel established through an HTTP proxy.
///
/// Obtained from `Connection::connect`; the tunnelled bytes are exchanged with
/// another stream via `Connection::copy_bidirectional`.
#[derive(Debug)]
pub struct Connection {
    /// The upgraded I/O stream returned by hyper once the `CONNECT` request
    /// has been answered.
    io: Upgraded,
}
56
57impl Connection {
58    pub async fn connect(
59        proxy: Uri,
60        target_uri: Uri,
61        authorization: auth::Authenticator,
62    ) -> Result<Connection, Error> {
63        let proxy = Host::from_uri(&proxy)?;
64        let target = Host::from_uri(&target_uri)?;
65
66        let stream = TcpStream::connect(proxy.addr()).await.map_err({
67            let p = proxy.clone();
68            move |e| Error::Connect(p, e)
69        })?;
70
71        let proxy_addr = stream.peer_addr()?;
72
73        let (mut request_sender, connection) =
74            http1::handshake(TokioIo::new(stream)).await.map_err({
75                let t = target.clone();
76                move |e| Error::Handshake(t, proxy_addr, e)
77            })?;
78
79        let request = Request::builder()
80            .method("CONNECT")
81            .uri(&target_uri)
82            .header(HOST, target.addr().0);
83        let request = if let Some(token) = authorization.for_host(&proxy.0)? {
84            request.header(PROXY_AUTHORIZATION, token.as_str())
85        } else {
86            request
87        };
88        let request = request.body(Empty::<Bytes>::new())?;
89
90        let send_request = async move {
91            let response = request_sender.send_request(request).await?;
92            hyper::upgrade::on(response).await
93        };
94
95        let (response, _connection) = tokio::join!(send_request, connection.with_upgrades());
96
97        Ok(Self { io: response? })
98    }
99
100    pub async fn copy_bidirectional<T>(self, mut other: T) -> std::io::Result<(u64, u64)>
101    where
102        T: AsyncRead + AsyncWrite + std::marker::Unpin,
103    {
104        let mut stream = Box::pin(TokioIo::new(self.io));
105        tokio::io::copy_bidirectional(&mut stream, &mut other).await
106    }
107}
108
/// A host name paired with a port number.
///
/// Its `Display` output has the form `host:port`.
#[derive(Debug, Clone)]
pub struct Host(String, u16);

impl std::fmt::Display for Host {
    /// Writes the host name and the port separated by a colon.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Host(name, port) = self;
        write!(out, "{name}:{port}")
    }
}
117
118impl Host {
119    fn from_uri(uri: &Uri) -> Result<Host, Error> {
120        let authority = uri.authority();
121        let host = authority.map(|a| a.host()).ok_or_else({
122            let u = uri.to_owned();
123            move || Error::InvalidUri(u)
124        })?;
125        let port = authority.and_then(|a| a.port_u16()).ok_or_else({
126            let u = uri.to_owned();
127            move || Error::InvalidUri(u)
128        })?;
129
130        Ok(Host(host.into(), port))
131    }
132
133    fn addr(&self) -> (String, u16) {
134        (self.0.clone(), self.1)
135    }
136}