http_mitm_proxy/
default_client.rs

1#![cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
2
3use bytes::{Buf, Bytes};
4use http_body_util::{BodyExt, Empty, combinators::BoxBody};
5use hyper::{
6    Request, Response, StatusCode, Uri, Version,
7    body::{Body, Incoming},
8    client, header,
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use std::{
12    collections::HashMap,
13    future::poll_fn,
14    sync::Arc,
15    task::{Context, Poll},
16};
17use tokio::sync::Mutex;
18use tokio::{net::TcpStream, task::JoinHandle};
19
20#[cfg(all(feature = "native-tls-client", feature = "rustls-client"))]
21compile_error!(
22    "feature \"native-tls-client\" and feature \"rustls-client\" cannot be enabled at the same time"
23);
24
25#[derive(thiserror::Error, Debug)]
26pub enum Error {
27    #[error("{0} doesn't have an valid host")]
28    InvalidHost(Box<Uri>),
29    #[error(transparent)]
30    IoError(#[from] std::io::Error),
31    #[error(transparent)]
32    HyperError(#[from] hyper::Error),
33    #[error("Failed to connect to {0}, {1}")]
34    ConnectError(Box<Uri>, hyper::Error),
35
36    #[cfg(feature = "native-tls-client")]
37    #[error("Failed to connect with TLS to {0}, {1}")]
38    TlsConnectError(Box<Uri>, native_tls::Error),
39    #[cfg(feature = "native-tls-client")]
40    #[error(transparent)]
41    NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
42
43    #[cfg(feature = "rustls-client")]
44    #[error("Failed to connect with TLS to {0}, {1}")]
45    TlsConnectError(Box<Uri>, std::io::Error),
46
47    #[error("Failed to parse URI: {0}")]
48    UriParsingError(#[from] hyper::http::uri::InvalidUri),
49
50    #[error("Failed to parse URI parts: {0}")]
51    UriPartsError(#[from] hyper::http::uri::InvalidUriParts),
52
53    #[error("TLS connector initialization failed: {0}")]
54    TlsConnectorError(String),
55}
56
57/// Upgraded connections
58pub struct Upgraded {
59    /// A socket to Client
60    pub client: TokioIo<hyper::upgrade::Upgraded>,
61    /// A socket to Server
62    pub server: TokioIo<hyper::upgrade::Upgraded>,
63}
64
65type DynError = Box<dyn std::error::Error + Send + Sync>;
66type PooledBody = BoxBody<Bytes, DynError>;
67type Http1Sender = hyper::client::conn::http1::SendRequest<PooledBody>;
68type Http2Sender = hyper::client::conn::http2::SendRequest<PooledBody>;
69
70#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
71enum ConnectionProtocol {
72    Http1,
73    Http2,
74}
75
76#[derive(Clone, Debug, Eq, PartialEq, Hash)]
77struct ConnectionKey {
78    host: String,
79    port: u16,
80    is_tls: bool,
81    protocol: ConnectionProtocol,
82}
83
84impl ConnectionKey {
85    fn new(host: String, port: u16, is_tls: bool, protocol: ConnectionProtocol) -> Self {
86        Self {
87            host,
88            port,
89            is_tls,
90            protocol,
91        }
92    }
93
94    fn from_uri(uri: &Uri, protocol: ConnectionProtocol) -> Result<Self, Error> {
95        let (host, port, is_tls) = host_port(uri)?;
96        Ok(ConnectionKey::new(host, port, is_tls, protocol))
97    }
98}
99
100#[derive(Clone, Default)]
101struct ConnectionPool {
102    http1: Arc<Mutex<HashMap<ConnectionKey, Vec<Http1Sender>>>>,
103    http2: Arc<Mutex<HashMap<ConnectionKey, Http2Sender>>>,
104}
105
106impl ConnectionPool {
107    async fn take_http1(&self, key: &ConnectionKey) -> Option<Http1Sender> {
108        let mut guard = self.http1.lock().await;
109        let entry = guard.get_mut(key)?;
110        while let Some(mut conn) = entry.pop() {
111            if sender_alive_http1(&mut conn).await {
112                return Some(conn);
113            }
114        }
115        if entry.is_empty() {
116            guard.remove(key);
117        }
118        None
119    }
120
121    async fn put_http1(&self, key: ConnectionKey, sender: Http1Sender) {
122        let mut guard = self.http1.lock().await;
123        guard.entry(key).or_default().push(sender);
124    }
125
126    async fn get_http2(&self, key: &ConnectionKey) -> Option<Http2Sender> {
127        let mut guard = self.http2.lock().await;
128        let mut sender = guard.get(key).cloned()?;
129
130        let alive = sender_alive_http2(&mut sender).await;
131
132        if alive {
133            Some(sender)
134        } else {
135            guard.remove(key);
136            None
137        }
138    }
139
140    async fn insert_http2_if_absent(&self, key: ConnectionKey, sender: Http2Sender) {
141        let mut guard = self.http2.lock().await;
142        guard.entry(key).or_insert(sender);
143    }
144}
145
146async fn sender_alive_http1(sender: &mut Http1Sender) -> bool {
147    poll_fn(|cx| sender.poll_ready(cx)).await.is_ok()
148}
149
150async fn sender_alive_http2(sender: &mut Http2Sender) -> bool {
151    poll_fn(|cx| sender.poll_ready(cx)).await.is_ok()
152}
153
154#[derive(Clone)]
155/// Default HTTP client for this crate
156pub struct DefaultClient {
157    #[cfg(feature = "native-tls-client")]
158    tls_connector_no_alpn: tokio_native_tls::TlsConnector,
159    #[cfg(feature = "native-tls-client")]
160    tls_connector_alpn_h2: tokio_native_tls::TlsConnector,
161
162    #[cfg(feature = "rustls-client")]
163    tls_connector_no_alpn: tokio_rustls::TlsConnector,
164    #[cfg(feature = "rustls-client")]
165    tls_connector_alpn_h2: tokio_rustls::TlsConnector,
166
167    /// If true, send_request will returns an Upgraded struct when the response is an upgrade
168    /// If false, send_request never returns an Upgraded struct and just copy bidirectional when the response is an upgrade
169    pub with_upgrades: bool,
170
171    pool: ConnectionPool,
172}
173impl Default for DefaultClient {
174    fn default() -> Self {
175        Self::new()
176    }
177}
178
179impl DefaultClient {
180    #[cfg(feature = "native-tls-client")]
181    pub fn new() -> Self {
182        Self::try_new().unwrap_or_else(|err| {
183            panic!("Failed to create DefaultClient: {err}");
184        })
185    }
186
187    #[cfg(feature = "native-tls-client")]
188    pub fn try_new() -> Result<Self, Error> {
189        let tls_connector_no_alpn = native_tls::TlsConnector::builder().build().map_err(|e| {
190            Error::TlsConnectorError(format!("Failed to build no-ALPN connector: {e}"))
191        })?;
192        let tls_connector_alpn_h2 = native_tls::TlsConnector::builder()
193            .request_alpns(&["h2", "http/1.1"])
194            .build()
195            .map_err(|e| {
196                Error::TlsConnectorError(format!("Failed to build ALPN-H2 connector: {e}"))
197            })?;
198
199        Ok(Self {
200            tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
201            tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
202            with_upgrades: false,
203            pool: ConnectionPool::default(),
204        })
205    }
206
207    #[cfg(feature = "rustls-client")]
208    pub fn new() -> Self {
209        Self::try_new().unwrap_or_else(|err| {
210            panic!("Failed to create DefaultClient: {}", err);
211        })
212    }
213
214    #[cfg(feature = "rustls-client")]
215    pub fn try_new() -> Result<Self, Error> {
216        use std::sync::Arc;
217
218        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
219        root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
220
221        let tls_connector_no_alpn = tokio_rustls::rustls::ClientConfig::builder()
222            .with_root_certificates(root_cert_store.clone())
223            .with_no_client_auth();
224        let mut tls_connector_alpn_h2 = tokio_rustls::rustls::ClientConfig::builder()
225            .with_root_certificates(root_cert_store.clone())
226            .with_no_client_auth();
227        tls_connector_alpn_h2.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
228
229        Ok(Self {
230            tls_connector_no_alpn: tokio_rustls::TlsConnector::from(Arc::new(
231                tls_connector_no_alpn,
232            )),
233            tls_connector_alpn_h2: tokio_rustls::TlsConnector::from(Arc::new(
234                tls_connector_alpn_h2,
235            )),
236            with_upgrades: false,
237            pool: ConnectionPool::default(),
238        })
239    }
240
241    /// Enable HTTP upgrades
242    /// If you don't enable HTTP upgrades, send_request will just copy bidirectional when the response is an upgrade
243    pub fn with_upgrades(mut self) -> Self {
244        self.with_upgrades = true;
245        self
246    }
247
248    #[cfg(feature = "native-tls-client")]
249    fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector {
250        match http_version {
251            Version::HTTP_2 => &self.tls_connector_alpn_h2,
252            _ => &self.tls_connector_no_alpn,
253        }
254    }
255
256    #[cfg(feature = "rustls-client")]
257    fn tls_connector(&self, http_version: Version) -> &tokio_rustls::TlsConnector {
258        match http_version {
259            Version::HTTP_2 => &self.tls_connector_alpn_h2,
260            _ => &self.tls_connector_no_alpn,
261        }
262    }
263
264    /// Send a request and return a response.
265    /// If the response is an upgrade (= if status code is 101 Switching Protocols), it will return a response and an Upgrade struct.
266    /// Request should have a full URL including scheme.
267    pub async fn send_request<B>(
268        &self,
269        req: Request<B>,
270    ) -> Result<
271        (
272            Response<Incoming>,
273            Option<JoinHandle<Result<Upgraded, hyper::Error>>>,
274        ),
275        Error,
276    >
277    where
278        B: Body<Data = Bytes> + Send + Sync + 'static,
279        B::Data: Send + Buf,
280        B::Error: Into<DynError>,
281    {
282        let target_uri = req.uri().clone();
283        let mut send_request = if req.version() == Version::HTTP_2 {
284            match ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http2) {
285                Ok(pool_key) => {
286                    if let Some(conn) = self.pool.get_http2(&pool_key).await {
287                        SendRequest::Http2(conn)
288                    } else {
289                        self.connect(req.uri(), req.version(), Some(pool_key))
290                            .await?
291                    }
292                }
293                Err(err) => {
294                    tracing::warn!(
295                        "ConnectionKey::from_uri failed for HTTP/2 ({}): continuing without pool",
296                        err
297                    );
298                    self.connect(req.uri(), req.version(), None).await?
299                }
300            }
301        } else {
302            match ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1) {
303                Ok(pool_key) => {
304                    if let Some(conn) = self.pool.take_http1(&pool_key).await {
305                        SendRequest::Http1(conn)
306                    } else {
307                        self.connect(req.uri(), req.version(), Some(pool_key))
308                            .await?
309                    }
310                }
311                Err(err) => {
312                    tracing::warn!(
313                        "ConnectionKey::from_uri failed for HTTP/1 ({}): continuing without pool",
314                        err
315                    );
316                    self.connect(req.uri(), req.version(), None).await?
317                }
318            }
319        };
320
321        let (req_parts, req_body) = req.into_parts();
322
323        let boxed_req = Request::from_parts(req_parts.clone(), to_boxed_body(req_body));
324
325        let res = send_request.send_request(boxed_req).await?;
326
327        if res.status() == StatusCode::SWITCHING_PROTOCOLS {
328            let (res_parts, res_body) = res.into_parts();
329
330            let client_request = Request::from_parts(req_parts, Empty::<Bytes>::new());
331            let server_response = Response::from_parts(res_parts.clone(), Empty::<Bytes>::new());
332
333            let upgrade = if self.with_upgrades {
334                Some(tokio::task::spawn(async move {
335                    let client = hyper::upgrade::on(client_request).await?;
336                    let server = hyper::upgrade::on(server_response).await?;
337
338                    Ok(Upgraded {
339                        client: TokioIo::new(client),
340                        server: TokioIo::new(server),
341                    })
342                }))
343            } else {
344                tokio::task::spawn(async move {
345                    let client = hyper::upgrade::on(client_request).await?;
346                    let server = hyper::upgrade::on(server_response).await?;
347
348                    let _ = tokio::io::copy_bidirectional(
349                        &mut TokioIo::new(client),
350                        &mut TokioIo::new(server),
351                    )
352                    .await;
353
354                    Ok::<_, hyper::Error>(())
355                });
356                None
357            };
358
359            Ok((Response::from_parts(res_parts, res_body), upgrade))
360        } else {
361            match send_request {
362                SendRequest::Http1(sender) => {
363                    if let Ok(pool_key) =
364                        ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1)
365                    {
366                        self.pool.put_http1(pool_key, sender).await;
367                    } else {
368                        // If we couldn't build a pool key, skip pooling.
369                    }
370                }
371                SendRequest::Http2(_) => {
372                    // For HTTP/2 the pool retains a shared sender; no action needed.
373                }
374            }
375            Ok((res, None))
376        }
377    }
378
379    async fn connect(
380        &self,
381        uri: &Uri,
382        http_version: Version,
383        key: Option<ConnectionKey>,
384    ) -> Result<SendRequest, Error> {
385        let (host, port, is_tls) = host_port(uri)?;
386
387        let tcp = TcpStream::connect((host.as_str(), port)).await?;
388        // This is actually needed to some servers
389        let _ = tcp.set_nodelay(true);
390
391        if is_tls {
392            #[cfg(feature = "native-tls-client")]
393            let tls = self
394                .tls_connector(http_version)
395                .connect(&host, tcp)
396                .await
397                .map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?;
398            #[cfg(feature = "rustls-client")]
399            let tls = self
400                .tls_connector(http_version)
401                .connect(
402                    host.to_string()
403                        .try_into()
404                        .map_err(|_| Error::InvalidHost(Box::new(uri.clone())))?,
405                    tcp,
406                )
407                .await
408                .map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?;
409
410            #[cfg(feature = "native-tls-client")]
411            let is_h2 = matches!(
412                tls.get_ref()
413                    .negotiated_alpn()
414                    .map(|a| a.map(|b| b == b"h2")),
415                Ok(Some(true))
416            );
417
418            #[cfg(feature = "rustls-client")]
419            let is_h2 = tls.get_ref().1.alpn_protocol() == Some(b"h2");
420
421            if is_h2 {
422                let (sender, conn) = client::conn::http2::Builder::new(TokioExecutor::new())
423                    .handshake(TokioIo::new(tls))
424                    .await
425                    .map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
426
427                tokio::spawn(conn);
428
429                if let Some(ref k) = key
430                    && matches!(k.protocol, ConnectionProtocol::Http2)
431                {
432                    self.pool
433                        .insert_http2_if_absent(k.clone(), sender.clone())
434                        .await;
435                }
436
437                Ok(SendRequest::Http2(sender))
438            } else {
439                let (sender, conn) = client::conn::http1::Builder::new()
440                    .preserve_header_case(true)
441                    .title_case_headers(true)
442                    .handshake(TokioIo::new(tls))
443                    .await
444                    .map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
445
446                tokio::spawn(conn.with_upgrades());
447
448                Ok(SendRequest::Http1(sender))
449            }
450        } else {
451            let (sender, conn) = client::conn::http1::Builder::new()
452                .preserve_header_case(true)
453                .title_case_headers(true)
454                .handshake(TokioIo::new(tcp))
455                .await
456                .map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
457            tokio::spawn(conn.with_upgrades());
458            Ok(SendRequest::Http1(sender))
459        }
460    }
461}
462
463enum SendRequest {
464    Http1(Http1Sender),
465    Http2(Http2Sender),
466}
467
468impl SendRequest {
469    async fn send_request(
470        &mut self,
471        mut req: Request<PooledBody>,
472    ) -> Result<Response<Incoming>, hyper::Error> {
473        match self {
474            SendRequest::Http1(sender) => {
475                if req.version() == hyper::Version::HTTP_2
476                    && let Some(authority) = req.uri().authority().cloned()
477                {
478                    match authority.as_str().parse::<header::HeaderValue>() {
479                        Ok(host_value) => {
480                            req.headers_mut().insert(header::HOST, host_value);
481                        }
482                        Err(err) => {
483                            tracing::warn!(
484                                "Failed to parse authority '{}' as HOST header: {}",
485                                authority,
486                                err
487                            );
488                        }
489                    }
490                }
491                if let Err(err) = remove_authority(&mut req) {
492                    tracing::error!("Failed to remove authority from URI: {}", err);
493                    // Continue with the original request if URI modification fails
494                }
495                sender.send_request(req).await
496            }
497            SendRequest::Http2(sender) => {
498                if req.version() != hyper::Version::HTTP_2 {
499                    req.headers_mut().remove(header::HOST);
500                }
501                sender.send_request(req).await
502            }
503        }
504    }
505}
506
507impl SendRequest {
508    #[allow(dead_code)]
509    // TODO: connection pooling
510    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), hyper::Error>> {
511        match self {
512            SendRequest::Http1(sender) => sender.poll_ready(cx),
513            SendRequest::Http2(_sender) => Poll::Ready(Ok(())),
514        }
515    }
516}
517
518fn remove_authority<B>(req: &mut Request<B>) -> Result<(), hyper::http::uri::InvalidUriParts> {
519    let mut parts = req.uri().clone().into_parts();
520    parts.scheme = None;
521    parts.authority = None;
522    *req.uri_mut() = Uri::from_parts(parts)?;
523    Ok(())
524}
525
526fn to_boxed_body<B>(body: B) -> PooledBody
527where
528    B: Body<Data = Bytes> + Send + Sync + 'static,
529    B::Data: Send + Buf,
530    B::Error: Into<DynError>,
531{
532    body.map_err(|err| err.into()).boxed()
533}
534
535fn host_port(uri: &Uri) -> Result<(String, u16, bool), Error> {
536    let host = uri
537        .host()
538        .ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))?
539        .to_string();
540    let is_tls = uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS);
541    let port = uri.port_u16().unwrap_or(if is_tls { 443 } else { 80 });
542    Ok((host, port, is_tls))
543}
544
545impl DefaultClient {}