http_mitm_proxy/
default_client.rs

1#![cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
2
3use bytes::Bytes;
4use http_body_util::Empty;
5use hyper::{
6    Request, Response, StatusCode, Uri, Version,
7    body::{Body, Incoming},
8    client, header,
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use std::task::{Context, Poll};
12use tokio::{net::TcpStream, task::JoinHandle};
13
// Exactly one TLS backend may be selected: the `Error::TlsConnectError` variant and the
// `DefaultClient` connector fields are declared once per backend and would collide otherwise.
#[cfg(all(feature = "rustls-client", feature = "native-tls-client"))]
compile_error!("feature \"native-tls-client\" and feature \"rustls-client\" cannot be enabled at the same time");
18
19#[derive(thiserror::Error, Debug)]
20pub enum Error {
21    #[error("{0} doesn't have an valid host")]
22    InvalidHost(Box<Uri>),
23    #[error(transparent)]
24    IoError(#[from] std::io::Error),
25    #[error(transparent)]
26    HyperError(#[from] hyper::Error),
27    #[error("Failed to connect to {0}, {1}")]
28    ConnectError(Box<Uri>, hyper::Error),
29
30    #[cfg(feature = "native-tls-client")]
31    #[error("Failed to connect with TLS to {0}, {1}")]
32    TlsConnectError(Box<Uri>, native_tls::Error),
33    #[cfg(feature = "native-tls-client")]
34    #[error(transparent)]
35    NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
36
37    #[cfg(feature = "rustls-client")]
38    #[error("Failed to connect with TLS to {0}, {1}")]
39    TlsConnectError(Box<Uri>, std::io::Error),
40
41    #[error("Failed to parse URI: {0}")]
42    UriParsingError(#[from] hyper::http::uri::InvalidUri),
43
44    #[error("Failed to parse URI parts: {0}")]
45    UriPartsError(#[from] hyper::http::uri::InvalidUriParts),
46
47    #[error("TLS connector initialization failed: {0}")]
48    TlsConnectorError(String),
49}
50
/// Upgraded connections
///
/// Both sides of a connection after a `101 Switching Protocols` response. It is produced by
/// `DefaultClient::send_request` when upgrades are enabled, so the caller can relay the
/// traffic itself.
pub struct Upgraded {
    /// The upgraded connection to the client, i.e. the side that issued the original request.
    pub client: TokioIo<hyper::upgrade::Upgraded>,
    /// The upgraded connection to the server that answered with `101 Switching Protocols`.
    pub server: TokioIo<hyper::upgrade::Upgraded>,
}
/// Default HTTP client for this crate.
///
/// A new connection is opened for every request (there is no pooling yet). For HTTPS, requests
/// whose version is HTTP/2 use a TLS connector that offers ALPN `h2`/`http/1.1`; all other
/// requests use a connector that offers no ALPN.
#[derive(Clone)]
pub struct DefaultClient {
    /// TLS connector offering no ALPN; used for non-HTTP/2 requests.
    #[cfg(feature = "native-tls-client")]
    tls_connector_no_alpn: tokio_native_tls::TlsConnector,
    /// TLS connector offering ALPN `h2` and `http/1.1`; used for HTTP/2 requests.
    #[cfg(feature = "native-tls-client")]
    tls_connector_alpn_h2: tokio_native_tls::TlsConnector,

    /// TLS connector offering no ALPN; used for non-HTTP/2 requests.
    #[cfg(feature = "rustls-client")]
    tls_connector_no_alpn: tokio_rustls::TlsConnector,
    /// TLS connector offering ALPN `h2` and `http/1.1`; used for HTTP/2 requests.
    #[cfg(feature = "rustls-client")]
    tls_connector_alpn_h2: tokio_rustls::TlsConnector,

    /// If true, `send_request` returns an [`Upgraded`] handle when the response is an upgrade.
    /// If false, `send_request` never returns one; it copies data bidirectionally between
    /// client and server on its own instead.
    pub with_upgrades: bool,
}
75impl Default for DefaultClient {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81impl DefaultClient {
82    #[cfg(feature = "native-tls-client")]
83    pub fn new() -> Self {
84        Self::try_new().unwrap_or_else(|err| {
85            panic!("Failed to create DefaultClient: {err}");
86        })
87    }
88
89    #[cfg(feature = "native-tls-client")]
90    pub fn try_new() -> Result<Self, Error> {
91        let tls_connector_no_alpn = native_tls::TlsConnector::builder().build().map_err(|e| {
92            Error::TlsConnectorError(format!("Failed to build no-ALPN connector: {e}"))
93        })?;
94        let tls_connector_alpn_h2 = native_tls::TlsConnector::builder()
95            .request_alpns(&["h2", "http/1.1"])
96            .build()
97            .map_err(|e| {
98                Error::TlsConnectorError(format!("Failed to build ALPN-H2 connector: {e}"))
99            })?;
100
101        Ok(Self {
102            tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
103            tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
104            with_upgrades: false,
105        })
106    }
107
108    #[cfg(feature = "rustls-client")]
109    pub fn new() -> Self {
110        Self::try_new().unwrap_or_else(|err| {
111            panic!("Failed to create DefaultClient: {}", err);
112        })
113    }
114
115    #[cfg(feature = "rustls-client")]
116    pub fn try_new() -> Result<Self, Error> {
117        use std::sync::Arc;
118
119        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
120        root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
121
122        let tls_connector_no_alpn = tokio_rustls::rustls::ClientConfig::builder()
123            .with_root_certificates(root_cert_store.clone())
124            .with_no_client_auth();
125        let mut tls_connector_alpn_h2 = tokio_rustls::rustls::ClientConfig::builder()
126            .with_root_certificates(root_cert_store.clone())
127            .with_no_client_auth();
128        tls_connector_alpn_h2.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
129
130        Ok(Self {
131            tls_connector_no_alpn: tokio_rustls::TlsConnector::from(Arc::new(
132                tls_connector_no_alpn,
133            )),
134            tls_connector_alpn_h2: tokio_rustls::TlsConnector::from(Arc::new(
135                tls_connector_alpn_h2,
136            )),
137            with_upgrades: false,
138        })
139    }
140
141    /// Enable HTTP upgrades
142    /// If you don't enable HTTP upgrades, send_request will just copy bidirectional when the response is an upgrade
143    pub fn with_upgrades(mut self) -> Self {
144        self.with_upgrades = true;
145        self
146    }
147
148    #[cfg(feature = "native-tls-client")]
149    fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector {
150        match http_version {
151            Version::HTTP_2 => &self.tls_connector_alpn_h2,
152            _ => &self.tls_connector_no_alpn,
153        }
154    }
155
156    #[cfg(feature = "rustls-client")]
157    fn tls_connector(&self, http_version: Version) -> &tokio_rustls::TlsConnector {
158        match http_version {
159            Version::HTTP_2 => &self.tls_connector_alpn_h2,
160            _ => &self.tls_connector_no_alpn,
161        }
162    }
163
164    /// Send a request and return a response.
165    /// If the response is an upgrade (= if status code is 101 Switching Protocols), it will return a response and an Upgrade struct.
166    /// Request should have a full URL including scheme.
167    pub async fn send_request<B>(
168        &self,
169        req: Request<B>,
170    ) -> Result<
171        (
172            Response<Incoming>,
173            Option<JoinHandle<Result<Upgraded, hyper::Error>>>,
174        ),
175        Error,
176    >
177    where
178        B: Body + Unpin + Send + 'static,
179        B::Data: Send,
180        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
181    {
182        let mut send_request = self.connect(req.uri(), req.version()).await?;
183
184        let (req_parts, req_body) = req.into_parts();
185
186        let res = send_request
187            .send_request(Request::from_parts(req_parts.clone(), req_body))
188            .await?;
189
190        if res.status() == StatusCode::SWITCHING_PROTOCOLS {
191            let (res_parts, res_body) = res.into_parts();
192
193            let client_request = Request::from_parts(req_parts, Empty::<Bytes>::new());
194            let server_response = Response::from_parts(res_parts.clone(), Empty::<Bytes>::new());
195
196            let upgrade = if self.with_upgrades {
197                Some(tokio::task::spawn(async move {
198                    let client = hyper::upgrade::on(client_request).await?;
199                    let server = hyper::upgrade::on(server_response).await?;
200
201                    Ok(Upgraded {
202                        client: TokioIo::new(client),
203                        server: TokioIo::new(server),
204                    })
205                }))
206            } else {
207                tokio::task::spawn(async move {
208                    let client = hyper::upgrade::on(client_request).await?;
209                    let server = hyper::upgrade::on(server_response).await?;
210
211                    let _ = tokio::io::copy_bidirectional(
212                        &mut TokioIo::new(client),
213                        &mut TokioIo::new(server),
214                    )
215                    .await;
216
217                    Ok::<_, hyper::Error>(())
218                });
219                None
220            };
221
222            Ok((Response::from_parts(res_parts, res_body), upgrade))
223        } else {
224            Ok((res, None))
225        }
226    }
227
228    async fn connect<B>(&self, uri: &Uri, http_version: Version) -> Result<SendRequest<B>, Error>
229    where
230        B: Body + Unpin + Send + 'static,
231        B::Data: Send,
232        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
233    {
234        let host = uri
235            .host()
236            .ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))?;
237        let port =
238            uri.port_u16()
239                .unwrap_or(if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
240                    443
241                } else {
242                    80
243                });
244
245        let tcp = TcpStream::connect((host, port)).await?;
246        // This is actually needed to some servers
247        let _ = tcp.set_nodelay(true);
248
249        if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
250            #[cfg(feature = "native-tls-client")]
251            let tls = self
252                .tls_connector(http_version)
253                .connect(host, tcp)
254                .await
255                .map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?;
256            #[cfg(feature = "rustls-client")]
257            let tls = self
258                .tls_connector(http_version)
259                .connect(
260                    host.to_string()
261                        .try_into()
262                        .map_err(|_| Error::InvalidHost(Box::new(uri.clone())))?,
263                    tcp,
264                )
265                .await
266                .map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?;
267
268            #[cfg(feature = "native-tls-client")]
269            let is_h2 = matches!(
270                tls.get_ref()
271                    .negotiated_alpn()
272                    .map(|a| a.map(|b| b == b"h2")),
273                Ok(Some(true))
274            );
275
276            #[cfg(feature = "rustls-client")]
277            let is_h2 = tls.get_ref().1.alpn_protocol() == Some(b"h2");
278
279            if is_h2 {
280                let (sender, conn) = client::conn::http2::Builder::new(TokioExecutor::new())
281                    .handshake(TokioIo::new(tls))
282                    .await
283                    .map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
284
285                tokio::spawn(conn);
286
287                Ok(SendRequest::Http2(sender))
288            } else {
289                let (sender, conn) = client::conn::http1::Builder::new()
290                    .preserve_header_case(true)
291                    .title_case_headers(true)
292                    .handshake(TokioIo::new(tls))
293                    .await
294                    .map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
295
296                tokio::spawn(conn.with_upgrades());
297
298                Ok(SendRequest::Http1(sender))
299            }
300        } else {
301            let (sender, conn) = client::conn::http1::Builder::new()
302                .preserve_header_case(true)
303                .title_case_headers(true)
304                .handshake(TokioIo::new(tcp))
305                .await
306                .map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
307            tokio::spawn(conn.with_upgrades());
308            Ok(SendRequest::Http1(sender))
309        }
310    }
311}
312
/// Request sender for a single connection, tagged by the protocol negotiated in
/// `DefaultClient::connect`.
enum SendRequest<B> {
    /// HTTP/1 connection (plain TCP, or TLS without `h2` negotiated).
    Http1(hyper::client::conn::http1::SendRequest<B>),
    /// HTTP/2 connection (TLS with ALPN `h2` negotiated).
    Http2(hyper::client::conn::http2::SendRequest<B>),
}
317
318impl<B> SendRequest<B>
319where
320    B: Body + 'static,
321{
322    async fn send_request(
323        &mut self,
324        mut req: Request<B>,
325    ) -> Result<Response<Incoming>, hyper::Error> {
326        match self {
327            SendRequest::Http1(sender) => {
328                if req.version() == hyper::Version::HTTP_2 {
329                    if let Some(authority) = req.uri().authority().cloned() {
330                        match authority.as_str().parse::<header::HeaderValue>() {
331                            Ok(host_value) => {
332                                req.headers_mut().insert(header::HOST, host_value);
333                            }
334                            Err(err) => {
335                                tracing::warn!(
336                                    "Failed to parse authority '{}' as HOST header: {}",
337                                    authority,
338                                    err
339                                );
340                            }
341                        }
342                    }
343                }
344                if let Err(err) = remove_authority(&mut req) {
345                    tracing::error!("Failed to remove authority from URI: {}", err);
346                    // Continue with the original request if URI modification fails
347                }
348                sender.send_request(req).await
349            }
350            SendRequest::Http2(sender) => {
351                if req.version() != hyper::Version::HTTP_2 {
352                    req.headers_mut().remove(header::HOST);
353                }
354                sender.send_request(req).await
355            }
356        }
357    }
358}
359
360impl<B> SendRequest<B> {
361    #[allow(dead_code)]
362    // TODO: connection pooling
363    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), hyper::Error>> {
364        match self {
365            SendRequest::Http1(sender) => sender.poll_ready(cx),
366            SendRequest::Http2(sender) => sender.poll_ready(cx),
367        }
368    }
369}
370
371fn remove_authority<B>(req: &mut Request<B>) -> Result<(), hyper::http::uri::InvalidUriParts> {
372    let mut parts = req.uri().clone().into_parts();
373    parts.scheme = None;
374    parts.authority = None;
375    *req.uri_mut() = Uri::from_parts(parts)?;
376    Ok(())
377}