http_mitm_proxy/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use http_body_util::{BodyExt, Empty, combinators::BoxBody};
4use hyper::{
5    Method, Request, Response, StatusCode,
6    body::{Body, Incoming},
7    server,
8    service::{HttpService, service_fn},
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use moka::sync::Cache;
12use std::{borrow::Borrow, error::Error as StdError, future::Future, sync::Arc};
13use tls::{CertifiedKeyDer, generate_cert};
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15use tokio_rustls::rustls;
16
17pub use futures;
18pub use hyper;
19pub use moka;
20
21#[cfg(feature = "native-tls-client")]
22pub use tokio_native_tls;
23
24#[cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
25pub mod default_client;
26mod tls;
27
28#[cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
29pub use default_client::DefaultClient;
30
31#[derive(Clone)]
32/// The main struct to run proxy server
33pub struct MitmProxy<C> {
34    /// Root certificate to sign fake certificates. You may need to trust this certificate on client application to use HTTPS.
35    ///
36    /// If None, proxy will just tunnel HTTPS traffic and will not observe HTTPS traffic.
37    pub root_cert: Option<C>,
38    /// Cache to store generated certificates. If None, cache will not be used.
39    /// If root_cert is None, cache will not be used.
40    ///
41    /// The key of cache is hostname.
42    pub cert_cache: Option<Cache<String, CertifiedKeyDer>>,
43}
44
45impl<C> MitmProxy<C> {
46    /// Create a new MitmProxy
47    pub fn new(root_cert: Option<C>, cache: Option<Cache<String, CertifiedKeyDer>>) -> Self {
48        Self {
49            root_cert,
50            cert_cache: cache,
51        }
52    }
53}
54
55impl<C> MitmProxy<C>
56where
57    C: Borrow<rcgen::CertifiedKey> + Send + Sync + 'static,
58{
59    /// Bind to a socket address and return a future that runs the proxy server.
60    /// URL for requests that passed to service are full URL including scheme.
61    pub async fn bind<A: ToSocketAddrs, S>(
62        self,
63        addr: A,
64        service: S,
65    ) -> Result<impl Future<Output = ()>, std::io::Error>
66    where
67        S: HttpService<Incoming> + Clone + Send + 'static,
68        S::Error: Into<Box<dyn StdError + Send + Sync>>,
69        S::ResBody: Send + Sync + 'static,
70        <S::ResBody as Body>::Data: Send,
71        <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
72        S::Future: Send,
73    {
74        let listener = TcpListener::bind(addr).await?;
75
76        let proxy = Arc::new(self);
77
78        Ok(async move {
79            loop {
80                let Ok((stream, _)) = listener.accept().await else {
81                    continue;
82                };
83
84                let service = service.clone();
85
86                let proxy = proxy.clone();
87                tokio::spawn(async move {
88                    if let Err(err) = server::conn::http1::Builder::new()
89                        .preserve_header_case(true)
90                        .title_case_headers(true)
91                        .serve_connection(
92                            TokioIo::new(stream),
93                            Self::wrap_service(proxy.clone(), service.clone()),
94                        )
95                        .with_upgrades()
96                        .await
97                    {
98                        tracing::error!("Error in proxy: {}", err);
99                    }
100                });
101            }
102        })
103    }
104
105    /// Transform a service to a service that can be used in hyper server.
106    /// URL for requests that passed to service are full URL including scheme.
107    /// See `examples/https.rs` for usage.
108    /// If you want to serve simple HTTP proxy server, you can use `bind` method instead.
109    /// `bind` will call this method internally.
110    pub fn wrap_service<S>(
111        proxy: Arc<Self>,
112        service: S,
113    ) -> impl HttpService<
114        Incoming,
115        ResBody = BoxBody<<S::ResBody as Body>::Data, <S::ResBody as Body>::Error>,
116        Future: Send,
117    >
118    where
119        S: HttpService<Incoming> + Clone + Send + 'static,
120        S::Error: Into<Box<dyn StdError + Send + Sync>>,
121        S::ResBody: Send + Sync + 'static,
122        <S::ResBody as Body>::Data: Send,
123        <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
124        S::Future: Send,
125    {
126        service_fn(move |req| {
127            let proxy = proxy.clone();
128            let mut service = service.clone();
129
130            async move {
131                if req.method() == Method::CONNECT {
132                    // https
133                    let Some(connect_authority) = req.uri().authority().cloned() else {
134                        tracing::error!(
135                            "Bad CONNECT request: {}, Reason: Invalid Authority",
136                            req.uri()
137                        );
138                        return Ok(no_body(StatusCode::BAD_REQUEST)
139                            .map(|b| b.boxed().map_err(|never| match never {}).boxed()));
140                    };
141
142                    tokio::spawn(async move {
143                        let Ok(client) = hyper::upgrade::on(req).await else {
144                            tracing::error!(
145                                "Bad CONNECT request: {}, Reason: Invalid Upgrade",
146                                connect_authority
147                            );
148                            return;
149                        };
150                        if let Some(server_config) =
151                            proxy.server_config(connect_authority.host().to_string(), true)
152                        {
153                            let server_config = match server_config {
154                                Ok(server_config) => server_config,
155                                Err(err) => {
156                                    tracing::error!(
157                                        "Failed to create server config for {}, {}",
158                                        connect_authority.host(),
159                                        err
160                                    );
161                                    return;
162                                }
163                            };
164                            let server_config = Arc::new(server_config);
165                            let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config);
166                            let client = match tls_acceptor.accept(TokioIo::new(client)).await {
167                                Ok(client) => client,
168                                Err(err) => {
169                                    tracing::error!(
170                                        "Failed to accept TLS connection for {}, {}",
171                                        connect_authority.host(),
172                                        err
173                                    );
174                                    return;
175                                }
176                            };
177                            let f = move |mut req: Request<_>| {
178                                let connect_authority = connect_authority.clone();
179                                let mut service = service.clone();
180
181                                async move {
182                                    inject_authority(&mut req, connect_authority.clone());
183                                    service.call(req).await
184                                }
185                            };
186                            let res = if client.get_ref().1.alpn_protocol() == Some(b"h2") {
187                                server::conn::http2::Builder::new(TokioExecutor::new())
188                                    .serve_connection(TokioIo::new(client), service_fn(f))
189                                    .await
190                            } else {
191                                server::conn::http1::Builder::new()
192                                    .preserve_header_case(true)
193                                    .title_case_headers(true)
194                                    .serve_connection(TokioIo::new(client), service_fn(f))
195                                    .with_upgrades()
196                                    .await
197                            };
198
199                            if let Err(_err) = res {
200                                // Suppress error because if we serving HTTPS proxy server and forward to HTTPS server, it will always error when closing connection.
201                                // tracing::error!("Error in proxy: {}", err);
202                            }
203                        } else {
204                            let Ok(mut server) =
205                                TcpStream::connect(connect_authority.as_str()).await
206                            else {
207                                tracing::error!("Failed to connect to {}", connect_authority);
208                                return;
209                            };
210                            let _ = tokio::io::copy_bidirectional(
211                                &mut TokioIo::new(client),
212                                &mut server,
213                            )
214                            .await;
215                        }
216                    });
217
218                    Ok(Response::new(
219                        http_body_util::Empty::new()
220                            .map_err(|never: std::convert::Infallible| match never {})
221                            .boxed(),
222                    ))
223                } else {
224                    // http
225                    service.call(req).await.map(|res| res.map(|b| b.boxed()))
226                }
227            }
228        })
229    }
230
231    fn get_certified_key(&self, host: String) -> Option<CertifiedKeyDer> {
232        self.root_cert.as_ref().map(|root_cert| {
233            if let Some(cache) = self.cert_cache.as_ref() {
234                cache.get_with(host.clone(), move || {
235                    generate_cert(host, root_cert.borrow())
236                })
237            } else {
238                generate_cert(host, root_cert.borrow())
239            }
240        })
241    }
242
243    fn server_config(
244        &self,
245        host: String,
246        h2: bool,
247    ) -> Option<Result<rustls::ServerConfig, rustls::Error>> {
248        if let Some(cert) = self.get_certified_key(host) {
249            let config = rustls::ServerConfig::builder()
250                .with_no_client_auth()
251                .with_single_cert(
252                    vec![rustls::pki_types::CertificateDer::from(cert.cert_der)],
253                    rustls::pki_types::PrivateKeyDer::Pkcs8(
254                        rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_der),
255                    ),
256                );
257
258            Some(if h2 {
259                config.map(|mut server_config| {
260                    server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
261                    server_config
262                })
263            } else {
264                config
265            })
266        } else {
267            None
268        }
269    }
270}
271
272fn no_body<D>(status: StatusCode) -> Response<Empty<D>> {
273    let mut res = Response::new(Empty::new());
274    *res.status_mut() = status;
275    res
276}
277
278fn inject_authority<B>(request_middleman: &mut Request<B>, authority: hyper::http::uri::Authority) {
279    let mut parts = request_middleman.uri().clone().into_parts();
280    parts.scheme = Some(hyper::http::uri::Scheme::HTTPS);
281    if parts.authority.is_none() {
282        parts.authority = Some(authority);
283    }
284    *request_middleman.uri_mut() = hyper::http::uri::Uri::from_parts(parts).unwrap();
285}