http_mitm_proxy/
default_client.rs

1#![cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
2
3use bytes::Bytes;
4use http_body_util::Empty;
5use hyper::{
6    Request, Response, StatusCode, Uri, Version,
7    body::{Body, Incoming},
8    client, header,
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use std::task::{Context, Poll};
12use tokio::{net::TcpStream, task::JoinHandle};
13
// The two TLS backends define the same items (`new`, `tls_connector`, the
// `TlsConnectError` variant, the connector fields), so exactly one may be selected.
#[cfg(all(feature = "rustls-client", feature = "native-tls-client"))]
compile_error!(
    "feature \"native-tls-client\" and feature \"rustls-client\" cannot be enabled at the same time"
);
18
19#[derive(thiserror::Error, Debug)]
20pub enum Error {
21    #[error("{0} doesn't have an valid host")]
22    InvalidHost(Uri),
23    #[error(transparent)]
24    IoError(#[from] std::io::Error),
25    #[error(transparent)]
26    HyperError(#[from] hyper::Error),
27    #[error("Failed to connect to {0}, {1}")]
28    ConnectError(Uri, hyper::Error),
29
30    #[cfg(feature = "native-tls-client")]
31    #[error("Failed to connect with TLS to {0}, {1}")]
32    TlsConnectError(Uri, native_tls::Error),
33    #[cfg(feature = "native-tls-client")]
34    #[error(transparent)]
35    NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
36
37    #[cfg(feature = "rustls-client")]
38    #[error("Failed to connect with TLS to {0}, {1}")]
39    TlsConnectError(Uri, std::io::Error),
40}
41
42/// Upgraded connections
43pub struct Upgraded {
44    /// A socket to Client
45    pub client: TokioIo<hyper::upgrade::Upgraded>,
46    /// A socket to Server
47    pub server: TokioIo<hyper::upgrade::Upgraded>,
48}
/// Default HTTP client for this crate.
///
/// Holds two TLS connectors for whichever TLS backend feature is enabled.
#[derive(Clone)]
pub struct DefaultClient {
    // Connector used for requests that are not HTTP/2.
    #[cfg(feature = "native-tls-client")]
    tls_connector_no_alpn: tokio_native_tls::TlsConnector,
    // Connector used for HTTP/2 requests.
    #[cfg(feature = "native-tls-client")]
    tls_connector_alpn_h2: tokio_native_tls::TlsConnector,

    // Same pair of connectors for the rustls backend.
    #[cfg(feature = "rustls-client")]
    tls_connector_no_alpn: tokio_rustls::TlsConnector,
    #[cfg(feature = "rustls-client")]
    tls_connector_alpn_h2: tokio_rustls::TlsConnector,

    /// When `true`, `send_request` hands back the upgraded client/server streams for an
    /// upgrade response. When `false`, it never does and instead copies bytes between the
    /// two upgraded streams in both directions.
    pub with_upgrades: bool,
}
66impl DefaultClient {
67    #[cfg(feature = "native-tls-client")]
68    pub fn new() -> Self {
69        let tls_connector_no_alpn = native_tls::TlsConnector::builder().build().unwrap();
70        let tls_connector_alpn_h2 = native_tls::TlsConnector::builder()
71            .request_alpns(&["h2", "http/1.1"])
72            .build()
73            .unwrap();
74
75        Self {
76            tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
77            tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
78            with_upgrades: false,
79        }
80    }
81
82    #[cfg(feature = "rustls-client")]
83    pub fn new() -> Self {
84        use std::sync::Arc;
85
86        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
87        root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
88
89        let tls_connector_no_alpn = tokio_rustls::rustls::ClientConfig::builder()
90            .with_root_certificates(root_cert_store.clone())
91            .with_no_client_auth();
92        let mut tls_connector_alpn_h2 = tokio_rustls::rustls::ClientConfig::builder()
93            .with_root_certificates(root_cert_store.clone())
94            .with_no_client_auth();
95        tls_connector_alpn_h2.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
96
97        Self {
98            tls_connector_no_alpn: tokio_rustls::TlsConnector::from(Arc::new(
99                tls_connector_no_alpn,
100            )),
101            tls_connector_alpn_h2: tokio_rustls::TlsConnector::from(Arc::new(
102                tls_connector_alpn_h2,
103            )),
104            with_upgrades: false,
105        }
106    }
107
108    /// Enable HTTP upgrades
109    /// If you don't enable HTTP upgrades, send_request will just copy bidirectional when the response is an upgrade
110    pub fn with_upgrades(mut self) -> Self {
111        self.with_upgrades = true;
112        self
113    }
114
115    #[cfg(feature = "native-tls-client")]
116    fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector {
117        match http_version {
118            Version::HTTP_2 => &self.tls_connector_alpn_h2,
119            _ => &self.tls_connector_no_alpn,
120        }
121    }
122
123    #[cfg(feature = "rustls-client")]
124    fn tls_connector(&self, http_version: Version) -> &tokio_rustls::TlsConnector {
125        match http_version {
126            Version::HTTP_2 => &self.tls_connector_alpn_h2,
127            _ => &self.tls_connector_no_alpn,
128        }
129    }
130
131    /// Send a request and return a response.
132    /// If the response is an upgrade (= if status code is 101 Switching Protocols), it will return a response and an Upgrade struct.
133    /// Request should have a full URL including scheme.
134    pub async fn send_request<B>(
135        &self,
136        req: Request<B>,
137    ) -> Result<
138        (
139            Response<Incoming>,
140            Option<JoinHandle<Result<Upgraded, hyper::Error>>>,
141        ),
142        Error,
143    >
144    where
145        B: Body + Unpin + Send + 'static,
146        B::Data: Send,
147        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
148    {
149        let mut send_request = self.connect(req.uri(), req.version()).await?;
150
151        let (req_parts, req_body) = req.into_parts();
152
153        let res = send_request
154            .send_request(Request::from_parts(req_parts.clone(), req_body))
155            .await?;
156
157        if res.status() == StatusCode::SWITCHING_PROTOCOLS {
158            let (res_parts, res_body) = res.into_parts();
159
160            let client_request = Request::from_parts(req_parts, Empty::<Bytes>::new());
161            let server_response = Response::from_parts(res_parts.clone(), Empty::<Bytes>::new());
162
163            let upgrade = if self.with_upgrades {
164                Some(tokio::task::spawn(async move {
165                    let client = hyper::upgrade::on(client_request).await?;
166                    let server = hyper::upgrade::on(server_response).await?;
167
168                    Ok(Upgraded {
169                        client: TokioIo::new(client),
170                        server: TokioIo::new(server),
171                    })
172                }))
173            } else {
174                tokio::task::spawn(async move {
175                    let client = hyper::upgrade::on(client_request).await?;
176                    let server = hyper::upgrade::on(server_response).await?;
177
178                    let _ = tokio::io::copy_bidirectional(
179                        &mut TokioIo::new(client),
180                        &mut TokioIo::new(server),
181                    )
182                    .await;
183
184                    Ok::<_, hyper::Error>(())
185                });
186                None
187            };
188
189            Ok((Response::from_parts(res_parts, res_body), upgrade))
190        } else {
191            Ok((res, None))
192        }
193    }
194
195    async fn connect<B>(&self, uri: &Uri, http_version: Version) -> Result<SendRequest<B>, Error>
196    where
197        B: Body + Unpin + Send + 'static,
198        B::Data: Send,
199        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
200    {
201        let host = uri.host().ok_or_else(|| Error::InvalidHost(uri.clone()))?;
202        let port =
203            uri.port_u16()
204                .unwrap_or(if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
205                    443
206                } else {
207                    80
208                });
209
210        let tcp = TcpStream::connect((host, port)).await?;
211        // This is actually needed to some servers
212        let _ = tcp.set_nodelay(true);
213
214        if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
215            #[cfg(feature = "native-tls-client")]
216            let tls = self
217                .tls_connector(http_version)
218                .connect(host, tcp)
219                .await
220                .map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
221            #[cfg(feature = "rustls-client")]
222            let tls = self
223                .tls_connector(http_version)
224                .connect(host.to_string().try_into().expect("Invalid host"), tcp)
225                .await
226                .map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
227
228            #[cfg(feature = "native-tls-client")]
229            let is_h2 = matches!(
230                tls.get_ref()
231                    .negotiated_alpn()
232                    .map(|a| a.map(|b| b == b"h2")),
233                Ok(Some(true))
234            );
235
236            #[cfg(feature = "rustls-client")]
237            let is_h2 = tls.get_ref().1.alpn_protocol() == Some(b"h2");
238
239            if is_h2 {
240                let (sender, conn) = client::conn::http2::Builder::new(TokioExecutor::new())
241                    .handshake(TokioIo::new(tls))
242                    .await
243                    .map_err(|err| Error::ConnectError(uri.clone(), err))?;
244
245                tokio::spawn(conn);
246
247                Ok(SendRequest::Http2(sender))
248            } else {
249                let (sender, conn) = client::conn::http1::Builder::new()
250                    .preserve_header_case(true)
251                    .title_case_headers(true)
252                    .handshake(TokioIo::new(tls))
253                    .await
254                    .map_err(|err| Error::ConnectError(uri.clone(), err))?;
255
256                tokio::spawn(conn.with_upgrades());
257
258                Ok(SendRequest::Http1(sender))
259            }
260        } else {
261            let (sender, conn) = client::conn::http1::Builder::new()
262                .preserve_header_case(true)
263                .title_case_headers(true)
264                .handshake(TokioIo::new(tcp))
265                .await
266                .map_err(|err| Error::ConnectError(uri.clone(), err))?;
267            tokio::spawn(conn.with_upgrades());
268            Ok(SendRequest::Http1(sender))
269        }
270    }
271}
272
273enum SendRequest<B> {
274    Http1(hyper::client::conn::http1::SendRequest<B>),
275    Http2(hyper::client::conn::http2::SendRequest<B>),
276}
277
278impl<B> SendRequest<B>
279where
280    B: Body + 'static,
281{
282    async fn send_request(
283        &mut self,
284        mut req: Request<B>,
285    ) -> Result<Response<Incoming>, hyper::Error> {
286        match self {
287            SendRequest::Http1(sender) => {
288                if req.version() == hyper::Version::HTTP_2 {
289                    if let Some(authority) = req.uri().authority().cloned() {
290                        req.headers_mut().insert(
291                            header::HOST,
292                            authority.as_str().parse().expect("Invalid authority"),
293                        );
294                    }
295                }
296                remove_authority(&mut req);
297                sender.send_request(req).await
298            }
299            SendRequest::Http2(sender) => {
300                if req.version() != hyper::Version::HTTP_2 {
301                    req.headers_mut().remove(header::HOST);
302                }
303                sender.send_request(req).await
304            }
305        }
306    }
307}
308
309impl<B> SendRequest<B> {
310    #[allow(dead_code)]
311    // TODO: connection pooling
312    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), hyper::Error>> {
313        match self {
314            SendRequest::Http1(sender) => sender.poll_ready(cx),
315            SendRequest::Http2(sender) => sender.poll_ready(cx),
316        }
317    }
318}
319
320fn remove_authority<B>(req: &mut Request<B>) {
321    let mut parts = req.uri().clone().into_parts();
322    parts.scheme = None;
323    parts.authority = None;
324    *req.uri_mut() = Uri::from_parts(parts).unwrap();
325}