http_mitm_proxy/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use http_body_util::{BodyExt, Empty, combinators::BoxBody};
4use hyper::{
5    Method, Request, Response, StatusCode,
6    body::{Body, Incoming},
7    server,
8    service::{HttpService, service_fn},
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use moka::sync::Cache;
12use std::{borrow::Borrow, error::Error as StdError, future::Future, sync::Arc};
13use tls::{CertifiedKeyDer, generate_cert};
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15use tokio_rustls::rustls;
16
17pub use futures;
18pub use hyper;
19pub use moka;
20
21#[cfg(feature = "native-tls-client")]
22pub use tokio_native_tls;
23
24pub mod default_client;
25mod tls;
26
27pub use default_client::DefaultClient;
28
29#[derive(Clone)]
30/// The main struct to run proxy server
31pub struct MitmProxy<C> {
32    /// Root certificate to sign fake certificates. You may need to trust this certificate on client application to use HTTPS.
33    ///
34    /// If None, proxy will just tunnel HTTPS traffic and will not observe HTTPS traffic.
35    pub root_cert: Option<C>,
36    /// Cache to store generated certificates. If None, cache will not be used.
37    /// If root_cert is None, cache will not be used.
38    ///
39    /// The key of cache is hostname.
40    pub cert_cache: Option<Cache<String, CertifiedKeyDer>>,
41}
42
43impl<C> MitmProxy<C> {
44    /// Create a new MitmProxy
45    pub fn new(root_cert: Option<C>, cache: Option<Cache<String, CertifiedKeyDer>>) -> Self {
46        Self {
47            root_cert,
48            cert_cache: cache,
49        }
50    }
51}
52
53impl<C> MitmProxy<C>
54where
55    C: Borrow<rcgen::CertifiedKey> + Send + Sync + 'static,
56{
57    /// Bind to a socket address and return a future that runs the proxy server.
58    /// URL for requests that passed to service are full URL including scheme.
59    pub async fn bind<A: ToSocketAddrs, S>(
60        self,
61        addr: A,
62        service: S,
63    ) -> Result<impl Future<Output = ()>, std::io::Error>
64    where
65        S: HttpService<Incoming> + Clone + Send + 'static,
66        S::Error: Into<Box<dyn StdError + Send + Sync>>,
67        S::ResBody: Send + Sync + 'static,
68        <S::ResBody as Body>::Data: Send,
69        <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
70        S::Future: Send,
71    {
72        let listener = TcpListener::bind(addr).await?;
73
74        let proxy = Arc::new(self);
75
76        Ok(async move {
77            loop {
78                let Ok((stream, _)) = listener.accept().await else {
79                    continue;
80                };
81
82                let service = service.clone();
83
84                let proxy = proxy.clone();
85                tokio::spawn(async move {
86                    if let Err(err) = server::conn::http1::Builder::new()
87                        .preserve_header_case(true)
88                        .title_case_headers(true)
89                        .serve_connection(
90                            TokioIo::new(stream),
91                            Self::wrap_service(proxy.clone(), service.clone()),
92                        )
93                        .with_upgrades()
94                        .await
95                    {
96                        tracing::error!("Error in proxy: {}", err);
97                    }
98                });
99            }
100        })
101    }
102
103    /// Transform a service to a service that can be used in hyper server.
104    /// URL for requests that passed to service are full URL including scheme.
105    /// See `examples/https.rs` for usage.
106    /// If you want to serve simple HTTP proxy server, you can use `bind` method instead.
107    /// `bind` will call this method internally.
108    pub fn wrap_service<S>(
109        proxy: Arc<Self>,
110        service: S,
111    ) -> impl HttpService<
112        Incoming,
113        ResBody = BoxBody<<S::ResBody as Body>::Data, <S::ResBody as Body>::Error>,
114        Future: Send,
115    >
116    where
117        S: HttpService<Incoming> + Clone + Send + 'static,
118        S::Error: Into<Box<dyn StdError + Send + Sync>>,
119        S::ResBody: Send + Sync + 'static,
120        <S::ResBody as Body>::Data: Send,
121        <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
122        S::Future: Send,
123    {
124        service_fn(move |req| {
125            let proxy = proxy.clone();
126            let mut service = service.clone();
127
128            async move {
129                if req.method() == Method::CONNECT {
130                    // https
131                    let Some(connect_authority) = req.uri().authority().cloned() else {
132                        tracing::error!(
133                            "Bad CONNECT request: {}, Reason: Invalid Authority",
134                            req.uri()
135                        );
136                        return Ok(no_body(StatusCode::BAD_REQUEST)
137                            .map(|b| b.boxed().map_err(|never| match never {}).boxed()));
138                    };
139
140                    tokio::spawn(async move {
141                        let Ok(client) = hyper::upgrade::on(req).await else {
142                            tracing::error!(
143                                "Bad CONNECT request: {}, Reason: Invalid Upgrade",
144                                connect_authority
145                            );
146                            return;
147                        };
148                        if let Some(server_config) =
149                            proxy.server_config(connect_authority.host().to_string(), true)
150                        {
151                            let server_config = match server_config {
152                                Ok(server_config) => server_config,
153                                Err(err) => {
154                                    tracing::error!(
155                                        "Failed to create server config for {}, {}",
156                                        connect_authority.host(),
157                                        err
158                                    );
159                                    return;
160                                }
161                            };
162                            let server_config = Arc::new(server_config);
163                            let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config);
164                            let client = match tls_acceptor.accept(TokioIo::new(client)).await {
165                                Ok(client) => client,
166                                Err(err) => {
167                                    tracing::error!(
168                                        "Failed to accept TLS connection for {}, {}",
169                                        connect_authority.host(),
170                                        err
171                                    );
172                                    return;
173                                }
174                            };
175                            let f = move |mut req: Request<_>| {
176                                let connect_authority = connect_authority.clone();
177                                let mut service = service.clone();
178
179                                async move {
180                                    inject_authority(&mut req, connect_authority.clone());
181                                    service.call(req).await
182                                }
183                            };
184                            let res = if client.get_ref().1.alpn_protocol() == Some(b"h2") {
185                                server::conn::http2::Builder::new(TokioExecutor::new())
186                                    .serve_connection(TokioIo::new(client), service_fn(f))
187                                    .await
188                            } else {
189                                server::conn::http1::Builder::new()
190                                    .preserve_header_case(true)
191                                    .title_case_headers(true)
192                                    .serve_connection(TokioIo::new(client), service_fn(f))
193                                    .with_upgrades()
194                                    .await
195                            };
196
197                            if let Err(_err) = res {
198                                // Suppress error because if we serving HTTPS proxy server and forward to HTTPS server, it will always error when closing connection.
199                                // tracing::error!("Error in proxy: {}", err);
200                            }
201                        } else {
202                            let Ok(mut server) =
203                                TcpStream::connect(connect_authority.as_str()).await
204                            else {
205                                tracing::error!("Failed to connect to {}", connect_authority);
206                                return;
207                            };
208                            let _ = tokio::io::copy_bidirectional(
209                                &mut TokioIo::new(client),
210                                &mut server,
211                            )
212                            .await;
213                        }
214                    });
215
216                    Ok(Response::new(
217                        http_body_util::Empty::new()
218                            .map_err(|never: std::convert::Infallible| match never {})
219                            .boxed(),
220                    ))
221                } else {
222                    // http
223                    service.call(req).await.map(|res| res.map(|b| b.boxed()))
224                }
225            }
226        })
227    }
228
229    fn get_certified_key(&self, host: String) -> Option<CertifiedKeyDer> {
230        self.root_cert.as_ref().map(|root_cert| {
231            if let Some(cache) = self.cert_cache.as_ref() {
232                cache.get_with(host.clone(), move || {
233                    generate_cert(host, root_cert.borrow())
234                })
235            } else {
236                generate_cert(host, root_cert.borrow())
237            }
238        })
239    }
240
241    fn server_config(
242        &self,
243        host: String,
244        h2: bool,
245    ) -> Option<Result<rustls::ServerConfig, rustls::Error>> {
246        if let Some(cert) = self.get_certified_key(host) {
247            let config = rustls::ServerConfig::builder()
248                .with_no_client_auth()
249                .with_single_cert(
250                    vec![rustls::pki_types::CertificateDer::from(cert.cert_der)],
251                    rustls::pki_types::PrivateKeyDer::Pkcs8(
252                        rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_der),
253                    ),
254                );
255
256            Some(if h2 {
257                config.map(|mut server_config| {
258                    server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
259                    server_config
260                })
261            } else {
262                config
263            })
264        } else {
265            None
266        }
267    }
268}
269
270fn no_body<D>(status: StatusCode) -> Response<Empty<D>> {
271    let mut res = Response::new(Empty::new());
272    *res.status_mut() = status;
273    res
274}
275
276fn inject_authority<B>(request_middleman: &mut Request<B>, authority: hyper::http::uri::Authority) {
277    let mut parts = request_middleman.uri().clone().into_parts();
278    parts.scheme = Some(hyper::http::uri::Scheme::HTTPS);
279    if parts.authority.is_none() {
280        parts.authority = Some(authority);
281    }
282    *request_middleman.uri_mut() = hyper::http::uri::Uri::from_parts(parts).unwrap();
283}