http_mitm_proxy/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use http_body_util::{BodyExt, Empty, combinators::BoxBody};
4use hyper::{
5    Method, Request, Response, StatusCode,
6    body::{Body, Incoming},
7    server,
8    service::{HttpService, service_fn},
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use moka::sync::Cache;
12use std::{borrow::Borrow, error::Error as StdError, future::Future, sync::Arc};
13use tls::{CertifiedKeyDer, generate_cert};
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15use tokio_rustls::rustls;
16
17pub use futures;
18pub use hyper;
19pub use moka;
20
21#[cfg(feature = "native-tls-client")]
22pub use tokio_native_tls;
23
24#[cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
25pub mod default_client;
26mod tls;
27
28#[cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
29pub use default_client::DefaultClient;
30
/// Socket address of the client connected to the proxy.
///
/// The proxy inserts this into the extensions of every request it hands to the user's
/// service, so handlers can retrieve it with `req.extensions().get::<RemoteAddr>()`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct RemoteAddr(pub std::net::SocketAddr);
33
34#[derive(Clone)]
35/// The main struct to run proxy server
36pub struct MitmProxy<I> {
37    /// Root issuer to sign fake certificates. You may need to trust this issuer on client application to use HTTPS.
38    ///
39    /// If None, proxy will just tunnel HTTPS traffic and will not observe HTTPS traffic.
40    pub root_issuer: Option<I>,
41    /// Cache to store generated certificates. If None, cache will not be used.
42    /// If root_issuer is None, cache will not be used.
43    ///
44    /// The key of cache is hostname.
45    pub cert_cache: Option<Cache<String, CertifiedKeyDer>>,
46}
47
48impl<I> MitmProxy<I> {
49    /// Create a new MitmProxy
50    pub fn new(root_issuer: Option<I>, cache: Option<Cache<String, CertifiedKeyDer>>) -> Self {
51        Self {
52            root_issuer,
53            cert_cache: cache,
54        }
55    }
56}
57
58impl<I> MitmProxy<I>
59where
60    I: Borrow<rcgen::Issuer<'static, rcgen::KeyPair>> + Send + Sync + 'static,
61{
62    /// Bind to a socket address and return a future that runs the proxy server.
63    /// URL for requests that passed to service are full URL including scheme.
64    /// remote address of client is stored in request extensions as `RemoteAddr`.
65    pub async fn bind<A: ToSocketAddrs, S>(
66        self,
67        addr: A,
68        service: S,
69    ) -> Result<impl Future<Output = ()>, std::io::Error>
70    where
71        S: HttpService<Incoming> + Clone + Send + 'static,
72        S::Error: Into<Box<dyn StdError + Send + Sync>>,
73        S::ResBody: Send + Sync + 'static,
74        <S::ResBody as Body>::Data: Send,
75        <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
76        S::Future: Send,
77    {
78        let listener = TcpListener::bind(addr).await?;
79
80        let proxy = Arc::new(self);
81
82        Ok(async move {
83            loop {
84                let (stream, remote_addr) = match listener.accept().await {
85                    Ok(conn) => conn,
86                    Err(err) => {
87                        tracing::warn!("Failed to accept connection: {}", err);
88                        continue;
89                    }
90                };
91
92                let service = service.clone();
93
94                let proxy = proxy.clone();
95                tokio::spawn(async move {
96                    if let Err(err) = server::conn::http1::Builder::new()
97                        .preserve_header_case(true)
98                        .title_case_headers(true)
99                        .serve_connection(
100                            TokioIo::new(stream),
101                            service_fn(move |mut req| {
102                                req.extensions_mut().insert(RemoteAddr(remote_addr));
103                                Self::wrap_service(proxy.clone(), service.clone()).call(req)
104                            }),
105                        )
106                        .with_upgrades()
107                        .await
108                    {
109                        tracing::error!("Error in proxy: {}", err);
110                    }
111                });
112            }
113        })
114    }
115
116    /// Transform a service to a service that can be used in hyper server.
117    /// URL for requests that passed to service are full URL including scheme.
118    /// See `examples/https.rs` for usage.
119    /// If you want to serve simple HTTP proxy server, you can use `bind` method instead.
120    /// `bind` will call this method internally.
121    pub fn wrap_service<S>(
122        proxy: Arc<Self>,
123        service: S,
124    ) -> impl HttpService<
125        Incoming,
126        ResBody = BoxBody<<S::ResBody as Body>::Data, <S::ResBody as Body>::Error>,
127        Future: Send,
128    >
129    where
130        S: HttpService<Incoming> + Clone + Send + 'static,
131        S::Error: Into<Box<dyn StdError + Send + Sync>>,
132        S::ResBody: Send + Sync + 'static,
133        <S::ResBody as Body>::Data: Send,
134        <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
135        S::Future: Send,
136    {
137        service_fn(move |mut req| {
138            let proxy = proxy.clone();
139            let mut service = service.clone();
140
141            async move {
142                if req.method() == Method::CONNECT {
143                    // https
144                    let Some(connect_authority) = req.uri().authority().cloned() else {
145                        tracing::error!(
146                            "Bad CONNECT request: {}, Reason: Invalid Authority",
147                            req.uri()
148                        );
149                        return Ok(no_body(StatusCode::BAD_REQUEST)
150                            .map(|b| b.boxed().map_err(|never| match never {}).boxed()));
151                    };
152
153                    tokio::spawn(async move {
154                        let remote_addr: Option<RemoteAddr> = req.extensions_mut().remove();
155                        let client = match hyper::upgrade::on(req).await {
156                            Ok(client) => client,
157                            Err(err) => {
158                                tracing::error!(
159                                    "Failed to upgrade CONNECT request for {}: {}",
160                                    connect_authority,
161                                    err
162                                );
163                                return;
164                            }
165                        };
166                        if let Some(server_config) =
167                            proxy.server_config(connect_authority.host().to_string(), true)
168                        {
169                            let server_config = match server_config {
170                                Ok(server_config) => server_config,
171                                Err(err) => {
172                                    tracing::error!(
173                                        "Failed to create server config for {}, {}",
174                                        connect_authority.host(),
175                                        err
176                                    );
177                                    return;
178                                }
179                            };
180                            let server_config = Arc::new(server_config);
181                            let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config);
182                            let client = match tls_acceptor.accept(TokioIo::new(client)).await {
183                                Ok(client) => client,
184                                Err(err) => {
185                                    tracing::error!(
186                                        "Failed to accept TLS connection for {}, {}",
187                                        connect_authority.host(),
188                                        err
189                                    );
190                                    return;
191                                }
192                            };
193                            let f = move |mut req: Request<_>| {
194                                let connect_authority = connect_authority.clone();
195                                let mut service = service.clone();
196
197                                async move {
198                                    if let Some(remote_addr) = remote_addr {
199                                        req.extensions_mut().insert(remote_addr);
200                                    }
201                                    inject_authority(&mut req, connect_authority.clone());
202                                    service.call(req).await
203                                }
204                            };
205                            let res = if client.get_ref().1.alpn_protocol() == Some(b"h2") {
206                                server::conn::http2::Builder::new(TokioExecutor::new())
207                                    .serve_connection(TokioIo::new(client), service_fn(f))
208                                    .await
209                            } else {
210                                server::conn::http1::Builder::new()
211                                    .preserve_header_case(true)
212                                    .title_case_headers(true)
213                                    .serve_connection(TokioIo::new(client), service_fn(f))
214                                    .with_upgrades()
215                                    .await
216                            };
217
218                            if let Err(err) = res {
219                                tracing::debug!("Connection closed: {}", err);
220                            }
221                        } else {
222                            let mut server =
223                                match TcpStream::connect(connect_authority.as_str()).await {
224                                    Ok(server) => server,
225                                    Err(err) => {
226                                        tracing::error!(
227                                            "Failed to connect to {}: {}",
228                                            connect_authority,
229                                            err
230                                        );
231                                        return;
232                                    }
233                                };
234                            let _ = tokio::io::copy_bidirectional(
235                                &mut TokioIo::new(client),
236                                &mut server,
237                            )
238                            .await;
239                        }
240                    });
241
242                    Ok(Response::new(
243                        http_body_util::Empty::new()
244                            .map_err(|never: std::convert::Infallible| match never {})
245                            .boxed(),
246                    ))
247                } else {
248                    // http
249                    service.call(req).await.map(|res| res.map(|b| b.boxed()))
250                }
251            }
252        })
253    }
254
255    fn get_certified_key(&self, host: String) -> Option<CertifiedKeyDer> {
256        self.root_issuer.as_ref().and_then(|root_issuer| {
257            if let Some(cache) = self.cert_cache.as_ref() {
258                // Try to get from cache, but handle generation errors gracefully
259                cache
260                    .try_get_with(host.clone(), move || {
261                        generate_cert(host, root_issuer.borrow())
262                    })
263                    .map_err(|err| {
264                        tracing::error!("Failed to generate certificate for host: {}", err);
265                    })
266                    .ok()
267            } else {
268                generate_cert(host, root_issuer.borrow())
269                    .map_err(|err| {
270                        tracing::error!("Failed to generate certificate for host: {}", err);
271                    })
272                    .ok()
273            }
274        })
275    }
276
277    fn server_config(
278        &self,
279        host: String,
280        h2: bool,
281    ) -> Option<Result<rustls::ServerConfig, rustls::Error>> {
282        if let Some(cert) = self.get_certified_key(host) {
283            let config = rustls::ServerConfig::builder()
284                .with_no_client_auth()
285                .with_single_cert(
286                    vec![rustls::pki_types::CertificateDer::from(cert.cert_der)],
287                    rustls::pki_types::PrivateKeyDer::Pkcs8(
288                        rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_der),
289                    ),
290                );
291
292            Some(if h2 {
293                config.map(|mut server_config| {
294                    server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
295                    server_config
296                })
297            } else {
298                config
299            })
300        } else {
301            None
302        }
303    }
304}
305
306fn no_body<D>(status: StatusCode) -> Response<Empty<D>> {
307    let mut res = Response::new(Empty::new());
308    *res.status_mut() = status;
309    res
310}
311
312fn inject_authority<B>(request_middleman: &mut Request<B>, authority: hyper::http::uri::Authority) {
313    let mut parts = request_middleman.uri().clone().into_parts();
314    parts.scheme = Some(hyper::http::uri::Scheme::HTTPS);
315    if parts.authority.is_none() {
316        parts.authority = Some(authority.clone());
317    }
318
319    match hyper::http::uri::Uri::from_parts(parts) {
320        Ok(uri) => *request_middleman.uri_mut() = uri,
321        Err(err) => {
322            tracing::error!(
323                "Failed to inject authority '{}' into URI: {}",
324                authority,
325                err
326            );
327            // Keep the original URI if injection fails
328        }
329    }
330}