http_mitm_proxy/
default_client.rs

1use bytes::Bytes;
2use http_body_util::Empty;
3use hyper::{
4    Request, Response, StatusCode, Uri, Version,
5    body::{Body, Incoming},
6    client, header,
7};
8use hyper_util::rt::{TokioExecutor, TokioIo};
9use std::task::{Context, Poll};
10use tokio::{net::TcpStream, task::JoinHandle};
11
12#[cfg(all(feature = "native-tls-client", feature = "rustls-client"))]
13compile_error!(
14    "feature \"native-tls-client\" and feature \"rustls-client\" cannot be enabled at the same time"
15);
16
17#[cfg(all(not(feature = "native-tls-client"), not(feature = "rustls-client")))]
18compile_error!("feature \"native-tls-client\" or feature \"rustls-client\" must be enabled");
19
20#[derive(thiserror::Error, Debug)]
21pub enum Error {
22    #[error("{0} doesn't have an valid host")]
23    InvalidHost(Uri),
24    #[error(transparent)]
25    IoError(#[from] std::io::Error),
26    #[error(transparent)]
27    HyperError(#[from] hyper::Error),
28    #[error("Failed to connect to {0}, {1}")]
29    ConnectError(Uri, hyper::Error),
30
31    #[cfg(feature = "native-tls-client")]
32    #[error("Failed to connect with TLS to {0}, {1}")]
33    TlsConnectError(Uri, native_tls::Error),
34    #[cfg(feature = "native-tls-client")]
35    #[error(transparent)]
36    NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
37
38    #[cfg(feature = "rustls-client")]
39    #[error("Failed to connect with TLS to {0}, {1}")]
40    TlsConnectError(Uri, std::io::Error),
41}
42
43/// Upgraded connections
44pub struct Upgraded {
45    /// A socket to Client
46    pub client: TokioIo<hyper::upgrade::Upgraded>,
47    /// A socket to Server
48    pub server: TokioIo<hyper::upgrade::Upgraded>,
49}
/// Default HTTP client for this crate.
///
/// Holds two TLS connectors, one without ALPN and one advertising `h2` and
/// `http/1.1`. The request's HTTP version decides which one is used.
#[derive(Clone)]
pub struct DefaultClient {
    /// native-tls connector that does not request ALPN.
    #[cfg(feature = "native-tls-client")]
    tls_connector_no_alpn: tokio_native_tls::TlsConnector,
    /// native-tls connector that requests ALPN `h2`, then `http/1.1`.
    #[cfg(feature = "native-tls-client")]
    tls_connector_alpn_h2: tokio_native_tls::TlsConnector,

    /// rustls connector that does not advertise ALPN.
    #[cfg(feature = "rustls-client")]
    tls_connector_no_alpn: tokio_rustls::TlsConnector,
    /// rustls connector that advertises ALPN `h2`, then `http/1.1`.
    #[cfg(feature = "rustls-client")]
    tls_connector_alpn_h2: tokio_rustls::TlsConnector,

    /// Controls what `send_request` does with a `101 Switching Protocols` response.
    ///
    /// When `true`, it returns a handle that resolves to an [`Upgraded`] pair.
    /// When `false`, it never returns one and instead copies data
    /// bidirectionally between the two upgraded connections itself.
    pub with_upgrades: bool,
}
67impl DefaultClient {
68    #[cfg(feature = "native-tls-client")]
69    pub fn new() -> Self {
70        let tls_connector_no_alpn = native_tls::TlsConnector::builder().build().unwrap();
71        let tls_connector_alpn_h2 = native_tls::TlsConnector::builder()
72            .request_alpns(&["h2", "http/1.1"])
73            .build()
74            .unwrap();
75
76        Self {
77            tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
78            tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
79            with_upgrades: false,
80        }
81    }
82
83    #[cfg(feature = "rustls-client")]
84    pub fn new() -> Self {
85        use std::sync::Arc;
86
87        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
88        root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
89
90        let tls_connector_no_alpn = tokio_rustls::rustls::ClientConfig::builder()
91            .with_root_certificates(root_cert_store.clone())
92            .with_no_client_auth();
93        let mut tls_connector_alpn_h2 = tokio_rustls::rustls::ClientConfig::builder()
94            .with_root_certificates(root_cert_store.clone())
95            .with_no_client_auth();
96        tls_connector_alpn_h2.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
97
98        Self {
99            tls_connector_no_alpn: tokio_rustls::TlsConnector::from(Arc::new(
100                tls_connector_no_alpn,
101            )),
102            tls_connector_alpn_h2: tokio_rustls::TlsConnector::from(Arc::new(
103                tls_connector_alpn_h2,
104            )),
105            with_upgrades: false,
106        }
107    }
108
109    /// Enable HTTP upgrades
110    /// If you don't enable HTTP upgrades, send_request will just copy bidirectional when the response is an upgrade
111    pub fn with_upgrades(mut self) -> Self {
112        self.with_upgrades = true;
113        self
114    }
115
116    #[cfg(feature = "native-tls-client")]
117    fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector {
118        match http_version {
119            Version::HTTP_2 => &self.tls_connector_alpn_h2,
120            _ => &self.tls_connector_no_alpn,
121        }
122    }
123
124    #[cfg(feature = "rustls-client")]
125    fn tls_connector(&self, http_version: Version) -> &tokio_rustls::TlsConnector {
126        match http_version {
127            Version::HTTP_2 => &self.tls_connector_alpn_h2,
128            _ => &self.tls_connector_no_alpn,
129        }
130    }
131
132    /// Send a request and return a response.
133    /// If the response is an upgrade (= if status code is 101 Switching Protocols), it will return a response and an Upgrade struct.
134    /// Request should have a full URL including scheme.
135    pub async fn send_request<B>(
136        &self,
137        req: Request<B>,
138    ) -> Result<
139        (
140            Response<Incoming>,
141            Option<JoinHandle<Result<Upgraded, hyper::Error>>>,
142        ),
143        Error,
144    >
145    where
146        B: Body + Unpin + Send + 'static,
147        B::Data: Send,
148        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
149    {
150        let mut send_request = self.connect(req.uri(), req.version()).await?;
151
152        let (req_parts, req_body) = req.into_parts();
153
154        let res = send_request
155            .send_request(Request::from_parts(req_parts.clone(), req_body))
156            .await?;
157
158        if res.status() == StatusCode::SWITCHING_PROTOCOLS {
159            let (res_parts, res_body) = res.into_parts();
160
161            let client_request = Request::from_parts(req_parts, Empty::<Bytes>::new());
162            let server_response = Response::from_parts(res_parts.clone(), Empty::<Bytes>::new());
163
164            let upgrade = if self.with_upgrades {
165                Some(tokio::task::spawn(async move {
166                    let client = hyper::upgrade::on(client_request).await?;
167                    let server = hyper::upgrade::on(server_response).await?;
168
169                    Ok(Upgraded {
170                        client: TokioIo::new(client),
171                        server: TokioIo::new(server),
172                    })
173                }))
174            } else {
175                tokio::task::spawn(async move {
176                    let client = hyper::upgrade::on(client_request).await?;
177                    let server = hyper::upgrade::on(server_response).await?;
178
179                    let _ = tokio::io::copy_bidirectional(
180                        &mut TokioIo::new(client),
181                        &mut TokioIo::new(server),
182                    )
183                    .await;
184
185                    Ok::<_, hyper::Error>(())
186                });
187                None
188            };
189
190            Ok((Response::from_parts(res_parts, res_body), upgrade))
191        } else {
192            Ok((res, None))
193        }
194    }
195
196    async fn connect<B>(&self, uri: &Uri, http_version: Version) -> Result<SendRequest<B>, Error>
197    where
198        B: Body + Unpin + Send + 'static,
199        B::Data: Send,
200        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
201    {
202        let host = uri.host().ok_or_else(|| Error::InvalidHost(uri.clone()))?;
203        let port =
204            uri.port_u16()
205                .unwrap_or(if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
206                    443
207                } else {
208                    80
209                });
210
211        let tcp = TcpStream::connect((host, port)).await?;
212        // This is actually needed to some servers
213        let _ = tcp.set_nodelay(true);
214
215        if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
216            #[cfg(feature = "native-tls-client")]
217            let tls = self
218                .tls_connector(http_version)
219                .connect(host, tcp)
220                .await
221                .map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
222            #[cfg(feature = "rustls-client")]
223            let tls = self
224                .tls_connector(http_version)
225                .connect(host.to_string().try_into().expect("Invalid host"), tcp)
226                .await
227                .map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
228
229            #[cfg(feature = "native-tls-client")]
230            let is_h2 = matches!(
231                tls.get_ref()
232                    .negotiated_alpn()
233                    .map(|a| a.map(|b| b == b"h2")),
234                Ok(Some(true))
235            );
236
237            #[cfg(feature = "rustls-client")]
238            let is_h2 = tls.get_ref().1.alpn_protocol() == Some(b"h2");
239
240            if is_h2 {
241                let (sender, conn) = client::conn::http2::Builder::new(TokioExecutor::new())
242                    .handshake(TokioIo::new(tls))
243                    .await
244                    .map_err(|err| Error::ConnectError(uri.clone(), err))?;
245
246                tokio::spawn(conn);
247
248                Ok(SendRequest::Http2(sender))
249            } else {
250                let (sender, conn) = client::conn::http1::Builder::new()
251                    .preserve_header_case(true)
252                    .title_case_headers(true)
253                    .handshake(TokioIo::new(tls))
254                    .await
255                    .map_err(|err| Error::ConnectError(uri.clone(), err))?;
256
257                tokio::spawn(conn.with_upgrades());
258
259                Ok(SendRequest::Http1(sender))
260            }
261        } else {
262            let (sender, conn) = client::conn::http1::Builder::new()
263                .preserve_header_case(true)
264                .title_case_headers(true)
265                .handshake(TokioIo::new(tcp))
266                .await
267                .map_err(|err| Error::ConnectError(uri.clone(), err))?;
268            tokio::spawn(conn.with_upgrades());
269            Ok(SendRequest::Http1(sender))
270        }
271    }
272}
273
274enum SendRequest<B> {
275    Http1(hyper::client::conn::http1::SendRequest<B>),
276    Http2(hyper::client::conn::http2::SendRequest<B>),
277}
278
279impl<B> SendRequest<B>
280where
281    B: Body + 'static,
282{
283    async fn send_request(
284        &mut self,
285        mut req: Request<B>,
286    ) -> Result<Response<Incoming>, hyper::Error> {
287        match self {
288            SendRequest::Http1(sender) => {
289                if req.version() == hyper::Version::HTTP_2 {
290                    if let Some(authority) = req.uri().authority().cloned() {
291                        req.headers_mut().insert(
292                            header::HOST,
293                            authority.as_str().parse().expect("Invalid authority"),
294                        );
295                    }
296                }
297                remove_authority(&mut req);
298                sender.send_request(req).await
299            }
300            SendRequest::Http2(sender) => {
301                if req.version() != hyper::Version::HTTP_2 {
302                    req.headers_mut().remove(header::HOST);
303                }
304                sender.send_request(req).await
305            }
306        }
307    }
308}
309
310impl<B> SendRequest<B> {
311    #[allow(dead_code)]
312    // TODO: connection pooling
313    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), hyper::Error>> {
314        match self {
315            SendRequest::Http1(sender) => sender.poll_ready(cx),
316            SendRequest::Http2(sender) => sender.poll_ready(cx),
317        }
318    }
319}
320
321fn remove_authority<B>(req: &mut Request<B>) {
322    let mut parts = req.uri().clone().into_parts();
323    parts.scheme = None;
324    parts.authority = None;
325    *req.uri_mut() = Uri::from_parts(parts).unwrap();
326}