http_mitm_proxy/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use http_body_util::{BodyExt, Empty, combinators::BoxBody};
4use hyper::{
5    Method, Request, Response, StatusCode,
6    body::{Body, Incoming},
7    server,
8    service::{HttpService, service_fn},
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use moka::sync::Cache;
12use std::{borrow::Borrow, error::Error as StdError, future::Future, sync::Arc};
13use tls::{CertifiedKeyDer, generate_cert};
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15use tokio_rustls::rustls;
16
17pub use futures;
18pub use hyper;
19pub use moka;
20
21#[cfg(feature = "native-tls-client")]
22pub use tokio_native_tls;
23
24#[cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
25pub mod default_client;
26mod tls;
27
28#[cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
29pub use default_client::DefaultClient;
30
31#[derive(Clone)]
32/// The main struct to run proxy server
33pub struct MitmProxy<I> {
34    /// Root issuer to sign fake certificates. You may need to trust this issuer on client application to use HTTPS.
35    ///
36    /// If None, proxy will just tunnel HTTPS traffic and will not observe HTTPS traffic.
37    pub root_issuer: Option<I>,
38    /// Cache to store generated certificates. If None, cache will not be used.
39    /// If root_issuer is None, cache will not be used.
40    ///
41    /// The key of cache is hostname.
42    pub cert_cache: Option<Cache<String, CertifiedKeyDer>>,
43}
44
45impl<I> MitmProxy<I> {
46    /// Create a new MitmProxy
47    pub fn new(root_issuer: Option<I>, cache: Option<Cache<String, CertifiedKeyDer>>) -> Self {
48        Self {
49            root_issuer,
50            cert_cache: cache,
51        }
52    }
53}
54
55impl<I> MitmProxy<I>
56where
57    I: Borrow<rcgen::Issuer<'static, rcgen::KeyPair>> + Send + Sync + 'static,
58{
59    /// Bind to a socket address and return a future that runs the proxy server.
60    /// URL for requests that passed to service are full URL including scheme.
61    pub async fn bind<A: ToSocketAddrs, S>(
62        self,
63        addr: A,
64        service: S,
65    ) -> Result<impl Future<Output = ()>, std::io::Error>
66    where
67        S: HttpService<Incoming> + Clone + Send + 'static,
68        S::Error: Into<Box<dyn StdError + Send + Sync>>,
69        S::ResBody: Send + Sync + 'static,
70        <S::ResBody as Body>::Data: Send,
71        <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
72        S::Future: Send,
73    {
74        let listener = TcpListener::bind(addr).await?;
75
76        let proxy = Arc::new(self);
77
78        Ok(async move {
79            loop {
80                let (stream, _) = match listener.accept().await {
81                    Ok(conn) => conn,
82                    Err(err) => {
83                        tracing::warn!("Failed to accept connection: {}", err);
84                        continue;
85                    }
86                };
87
88                let service = service.clone();
89
90                let proxy = proxy.clone();
91                tokio::spawn(async move {
92                    if let Err(err) = server::conn::http1::Builder::new()
93                        .preserve_header_case(true)
94                        .title_case_headers(true)
95                        .serve_connection(
96                            TokioIo::new(stream),
97                            Self::wrap_service(proxy.clone(), service.clone()),
98                        )
99                        .with_upgrades()
100                        .await
101                    {
102                        tracing::error!("Error in proxy: {}", err);
103                    }
104                });
105            }
106        })
107    }
108
109    /// Transform a service to a service that can be used in hyper server.
110    /// URL for requests that passed to service are full URL including scheme.
111    /// See `examples/https.rs` for usage.
112    /// If you want to serve simple HTTP proxy server, you can use `bind` method instead.
113    /// `bind` will call this method internally.
114    pub fn wrap_service<S>(
115        proxy: Arc<Self>,
116        service: S,
117    ) -> impl HttpService<
118        Incoming,
119        ResBody = BoxBody<<S::ResBody as Body>::Data, <S::ResBody as Body>::Error>,
120        Future: Send,
121    >
122    where
123        S: HttpService<Incoming> + Clone + Send + 'static,
124        S::Error: Into<Box<dyn StdError + Send + Sync>>,
125        S::ResBody: Send + Sync + 'static,
126        <S::ResBody as Body>::Data: Send,
127        <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
128        S::Future: Send,
129    {
130        service_fn(move |req| {
131            let proxy = proxy.clone();
132            let mut service = service.clone();
133
134            async move {
135                if req.method() == Method::CONNECT {
136                    // https
137                    let Some(connect_authority) = req.uri().authority().cloned() else {
138                        tracing::error!(
139                            "Bad CONNECT request: {}, Reason: Invalid Authority",
140                            req.uri()
141                        );
142                        return Ok(no_body(StatusCode::BAD_REQUEST)
143                            .map(|b| b.boxed().map_err(|never| match never {}).boxed()));
144                    };
145
146                    tokio::spawn(async move {
147                        let client = match hyper::upgrade::on(req).await {
148                            Ok(client) => client,
149                            Err(err) => {
150                                tracing::error!(
151                                    "Failed to upgrade CONNECT request for {}: {}",
152                                    connect_authority,
153                                    err
154                                );
155                                return;
156                            }
157                        };
158                        if let Some(server_config) =
159                            proxy.server_config(connect_authority.host().to_string(), true)
160                        {
161                            let server_config = match server_config {
162                                Ok(server_config) => server_config,
163                                Err(err) => {
164                                    tracing::error!(
165                                        "Failed to create server config for {}, {}",
166                                        connect_authority.host(),
167                                        err
168                                    );
169                                    return;
170                                }
171                            };
172                            let server_config = Arc::new(server_config);
173                            let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config);
174                            let client = match tls_acceptor.accept(TokioIo::new(client)).await {
175                                Ok(client) => client,
176                                Err(err) => {
177                                    tracing::error!(
178                                        "Failed to accept TLS connection for {}, {}",
179                                        connect_authority.host(),
180                                        err
181                                    );
182                                    return;
183                                }
184                            };
185                            let f = move |mut req: Request<_>| {
186                                let connect_authority = connect_authority.clone();
187                                let mut service = service.clone();
188
189                                async move {
190                                    inject_authority(&mut req, connect_authority.clone());
191                                    service.call(req).await
192                                }
193                            };
194                            let res = if client.get_ref().1.alpn_protocol() == Some(b"h2") {
195                                server::conn::http2::Builder::new(TokioExecutor::new())
196                                    .serve_connection(TokioIo::new(client), service_fn(f))
197                                    .await
198                            } else {
199                                server::conn::http1::Builder::new()
200                                    .preserve_header_case(true)
201                                    .title_case_headers(true)
202                                    .serve_connection(TokioIo::new(client), service_fn(f))
203                                    .with_upgrades()
204                                    .await
205                            };
206
207                            if let Err(err) = res {
208                                tracing::debug!("Connection closed: {}", err);
209                            }
210                        } else {
211                            let mut server =
212                                match TcpStream::connect(connect_authority.as_str()).await {
213                                    Ok(server) => server,
214                                    Err(err) => {
215                                        tracing::error!(
216                                            "Failed to connect to {}: {}",
217                                            connect_authority,
218                                            err
219                                        );
220                                        return;
221                                    }
222                                };
223                            let _ = tokio::io::copy_bidirectional(
224                                &mut TokioIo::new(client),
225                                &mut server,
226                            )
227                            .await;
228                        }
229                    });
230
231                    Ok(Response::new(
232                        http_body_util::Empty::new()
233                            .map_err(|never: std::convert::Infallible| match never {})
234                            .boxed(),
235                    ))
236                } else {
237                    // http
238                    service.call(req).await.map(|res| res.map(|b| b.boxed()))
239                }
240            }
241        })
242    }
243
244    fn get_certified_key(&self, host: String) -> Option<CertifiedKeyDer> {
245        self.root_issuer.as_ref().and_then(|root_issuer| {
246            if let Some(cache) = self.cert_cache.as_ref() {
247                // Try to get from cache, but handle generation errors gracefully
248                cache
249                    .try_get_with(host.clone(), move || {
250                        generate_cert(host, root_issuer.borrow())
251                    })
252                    .map_err(|err| {
253                        tracing::error!("Failed to generate certificate for host: {}", err);
254                    })
255                    .ok()
256            } else {
257                generate_cert(host, root_issuer.borrow())
258                    .map_err(|err| {
259                        tracing::error!("Failed to generate certificate for host: {}", err);
260                    })
261                    .ok()
262            }
263        })
264    }
265
266    fn server_config(
267        &self,
268        host: String,
269        h2: bool,
270    ) -> Option<Result<rustls::ServerConfig, rustls::Error>> {
271        if let Some(cert) = self.get_certified_key(host) {
272            let config = rustls::ServerConfig::builder()
273                .with_no_client_auth()
274                .with_single_cert(
275                    vec![rustls::pki_types::CertificateDer::from(cert.cert_der)],
276                    rustls::pki_types::PrivateKeyDer::Pkcs8(
277                        rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_der),
278                    ),
279                );
280
281            Some(if h2 {
282                config.map(|mut server_config| {
283                    server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
284                    server_config
285                })
286            } else {
287                config
288            })
289        } else {
290            None
291        }
292    }
293}
294
295fn no_body<D>(status: StatusCode) -> Response<Empty<D>> {
296    let mut res = Response::new(Empty::new());
297    *res.status_mut() = status;
298    res
299}
300
301fn inject_authority<B>(request_middleman: &mut Request<B>, authority: hyper::http::uri::Authority) {
302    let mut parts = request_middleman.uri().clone().into_parts();
303    parts.scheme = Some(hyper::http::uri::Scheme::HTTPS);
304    if parts.authority.is_none() {
305        parts.authority = Some(authority.clone());
306    }
307
308    match hyper::http::uri::Uri::from_parts(parts) {
309        Ok(uri) => *request_middleman.uri_mut() = uri,
310        Err(err) => {
311            tracing::error!(
312                "Failed to inject authority '{}' into URI: {}",
313                authority,
314                err
315            );
316            // Keep the original URI if injection fails
317        }
318    }
319}