1use std::net::SocketAddr;
2
3use bytes::Bytes;
4use http::header::{HOST, PROXY_AUTHORIZATION};
5use http::Request;
6use http::Uri;
7use http_body_util::Empty;
8use hyper::client::conn::http1;
9use hyper::upgrade::Upgraded;
10use hyper_util::rt::TokioIo;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::net::TcpStream;
13
14pub mod auth;
15
16pub mod io_ext;
17
18#[derive(Debug, thiserror::Error)]
19pub enum Error {
20 #[error("I/o error: {0}")]
21 Io(
22 #[from]
23 #[source]
24 std::io::Error,
25 ),
26 #[error("HTTP error: {0}")]
27 Http(
28 #[from]
29 #[source]
30 http::Error,
31 ),
32 #[error("HTTP wire error: {0}")]
33 Hyper(
34 #[from]
35 #[source]
36 hyper::Error,
37 ),
38 #[error("invalid URI: {0}")]
39 InvalidUri(Uri),
40 #[error("connection error when connecting to {0}: {1}")]
41 Connect(Host, tokio::io::Error),
42 #[error("handshake error with {0} via {1}: {2}")]
43 Handshake(Host, SocketAddr, #[source] hyper::Error),
44 #[error("authentication error: {0}")]
45 Auth(
46 #[from]
47 #[source]
48 crate::auth::Error,
49 ),
50}
51
52#[derive(Debug)]
53pub struct Connection {
54 io: Upgraded,
55}
56
57impl Connection {
58 pub async fn connect(
59 proxy: Uri,
60 target_uri: Uri,
61 authorization: auth::Authenticator,
62 ) -> Result<Connection, Error> {
63 let proxy = Host::from_uri(&proxy)?;
64 let target = Host::from_uri(&target_uri)?;
65
66 let stream = TcpStream::connect(proxy.addr()).await.map_err({
67 let p = proxy.clone();
68 move |e| Error::Connect(p, e)
69 })?;
70
71 let proxy_addr = stream.peer_addr()?;
72
73 let (mut request_sender, connection) =
74 http1::handshake(TokioIo::new(stream)).await.map_err({
75 let t = target.clone();
76 move |e| Error::Handshake(t, proxy_addr, e)
77 })?;
78
79 let request = Request::builder()
80 .method("CONNECT")
81 .uri(&target_uri)
82 .header(HOST, target.addr().0);
83 let request = if let Some(token) = authorization.for_host(&proxy.0)? {
84 request.header(PROXY_AUTHORIZATION, token.as_str())
85 } else {
86 request
87 };
88 let request = request.body(Empty::<Bytes>::new())?;
89
90 let send_request = async move {
91 let response = request_sender.send_request(request).await?;
92 hyper::upgrade::on(response).await
93 };
94
95 let (response, _connection) = tokio::join!(send_request, connection.with_upgrades());
96
97 Ok(Self { io: response? })
98 }
99
100 pub async fn copy_bidirectional<T>(self, mut other: T) -> std::io::Result<(u64, u64)>
101 where
102 T: AsyncRead + AsyncWrite + std::marker::Unpin,
103 {
104 let mut stream = Box::pin(TokioIo::new(self.io));
105 tokio::io::copy_bidirectional(&mut stream, &mut other).await
106 }
107}
108
109#[derive(Debug, Clone)]
110pub struct Host(String, u16);
111
112impl std::fmt::Display for Host {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 write!(f, "{}:{}", self.0, self.1)
115 }
116}
117
118impl Host {
119 fn from_uri(uri: &Uri) -> Result<Host, Error> {
120 let authority = uri.authority();
121 let host = authority.map(|a| a.host()).ok_or_else({
122 let u = uri.to_owned();
123 move || Error::InvalidUri(u)
124 })?;
125 let port = authority.and_then(|a| a.port_u16()).ok_or_else({
126 let u = uri.to_owned();
127 move || Error::InvalidUri(u)
128 })?;
129
130 Ok(Host(host.into(), port))
131 }
132
133 fn addr(&self) -> (String, u16) {
134 (self.0.clone(), self.1)
135 }
136}