1#![doc = include_str!("../README.md")]
2
3use http_body_util::{BodyExt, Empty, combinators::BoxBody};
4use hyper::{
5 Method, Request, Response, StatusCode,
6 body::{Body, Incoming},
7 server,
8 service::{HttpService, service_fn},
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use moka::sync::Cache;
12use std::{borrow::Borrow, error::Error as StdError, future::Future, sync::Arc};
13use tls::{CertifiedKeyDer, generate_cert};
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15use tokio_rustls::rustls;
16
17pub use futures;
18pub use hyper;
19pub use moka;
20
21#[cfg(feature = "native-tls-client")]
22pub use tokio_native_tls;
23
24#[cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
25pub mod default_client;
26mod tls;
27
28#[cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
29pub use default_client::DefaultClient;
30
/// A man-in-the-middle HTTP proxy.
///
/// `CONNECT` tunnels are intercepted: when a certificate can be produced for
/// the target host (see `root_issuer`), TLS is terminated inside the tunnel
/// and the decrypted requests are handed to the wrapped service; otherwise the
/// tunnel is relayed byte-for-byte to the target.
#[derive(Clone)]
pub struct MitmProxy<I> {
    /// Issuer that signs the per-host certificates generated for intercepted
    /// `CONNECT` tunnels. With `None`, no certificates are generated and
    /// tunnels are relayed without decryption.
    pub root_issuer: Option<I>,
    /// Optional cache of generated certificates, keyed by host. With `None`,
    /// a certificate is generated for every intercepted tunnel.
    pub cert_cache: Option<Cache<String, CertifiedKeyDer>>,
}
44
45impl<I> MitmProxy<I> {
46 pub fn new(root_issuer: Option<I>, cache: Option<Cache<String, CertifiedKeyDer>>) -> Self {
48 Self {
49 root_issuer,
50 cert_cache: cache,
51 }
52 }
53}
54
55impl<I> MitmProxy<I>
56where
57 I: Borrow<rcgen::Issuer<'static, rcgen::KeyPair>> + Send + Sync + 'static,
58{
59 pub async fn bind<A: ToSocketAddrs, S>(
62 self,
63 addr: A,
64 service: S,
65 ) -> Result<impl Future<Output = ()>, std::io::Error>
66 where
67 S: HttpService<Incoming> + Clone + Send + 'static,
68 S::Error: Into<Box<dyn StdError + Send + Sync>>,
69 S::ResBody: Send + Sync + 'static,
70 <S::ResBody as Body>::Data: Send,
71 <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
72 S::Future: Send,
73 {
74 let listener = TcpListener::bind(addr).await?;
75
76 let proxy = Arc::new(self);
77
78 Ok(async move {
79 loop {
80 let (stream, _) = match listener.accept().await {
81 Ok(conn) => conn,
82 Err(err) => {
83 tracing::warn!("Failed to accept connection: {}", err);
84 continue;
85 }
86 };
87
88 let service = service.clone();
89
90 let proxy = proxy.clone();
91 tokio::spawn(async move {
92 if let Err(err) = server::conn::http1::Builder::new()
93 .preserve_header_case(true)
94 .title_case_headers(true)
95 .serve_connection(
96 TokioIo::new(stream),
97 Self::wrap_service(proxy.clone(), service.clone()),
98 )
99 .with_upgrades()
100 .await
101 {
102 tracing::error!("Error in proxy: {}", err);
103 }
104 });
105 }
106 })
107 }
108
109 pub fn wrap_service<S>(
115 proxy: Arc<Self>,
116 service: S,
117 ) -> impl HttpService<
118 Incoming,
119 ResBody = BoxBody<<S::ResBody as Body>::Data, <S::ResBody as Body>::Error>,
120 Future: Send,
121 >
122 where
123 S: HttpService<Incoming> + Clone + Send + 'static,
124 S::Error: Into<Box<dyn StdError + Send + Sync>>,
125 S::ResBody: Send + Sync + 'static,
126 <S::ResBody as Body>::Data: Send,
127 <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
128 S::Future: Send,
129 {
130 service_fn(move |req| {
131 let proxy = proxy.clone();
132 let mut service = service.clone();
133
134 async move {
135 if req.method() == Method::CONNECT {
136 let Some(connect_authority) = req.uri().authority().cloned() else {
138 tracing::error!(
139 "Bad CONNECT request: {}, Reason: Invalid Authority",
140 req.uri()
141 );
142 return Ok(no_body(StatusCode::BAD_REQUEST)
143 .map(|b| b.boxed().map_err(|never| match never {}).boxed()));
144 };
145
146 tokio::spawn(async move {
147 let client = match hyper::upgrade::on(req).await {
148 Ok(client) => client,
149 Err(err) => {
150 tracing::error!(
151 "Failed to upgrade CONNECT request for {}: {}",
152 connect_authority,
153 err
154 );
155 return;
156 }
157 };
158 if let Some(server_config) =
159 proxy.server_config(connect_authority.host().to_string(), true)
160 {
161 let server_config = match server_config {
162 Ok(server_config) => server_config,
163 Err(err) => {
164 tracing::error!(
165 "Failed to create server config for {}, {}",
166 connect_authority.host(),
167 err
168 );
169 return;
170 }
171 };
172 let server_config = Arc::new(server_config);
173 let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config);
174 let client = match tls_acceptor.accept(TokioIo::new(client)).await {
175 Ok(client) => client,
176 Err(err) => {
177 tracing::error!(
178 "Failed to accept TLS connection for {}, {}",
179 connect_authority.host(),
180 err
181 );
182 return;
183 }
184 };
185 let f = move |mut req: Request<_>| {
186 let connect_authority = connect_authority.clone();
187 let mut service = service.clone();
188
189 async move {
190 inject_authority(&mut req, connect_authority.clone());
191 service.call(req).await
192 }
193 };
194 let res = if client.get_ref().1.alpn_protocol() == Some(b"h2") {
195 server::conn::http2::Builder::new(TokioExecutor::new())
196 .serve_connection(TokioIo::new(client), service_fn(f))
197 .await
198 } else {
199 server::conn::http1::Builder::new()
200 .preserve_header_case(true)
201 .title_case_headers(true)
202 .serve_connection(TokioIo::new(client), service_fn(f))
203 .with_upgrades()
204 .await
205 };
206
207 if let Err(err) = res {
208 tracing::debug!("Connection closed: {}", err);
209 }
210 } else {
211 let mut server =
212 match TcpStream::connect(connect_authority.as_str()).await {
213 Ok(server) => server,
214 Err(err) => {
215 tracing::error!(
216 "Failed to connect to {}: {}",
217 connect_authority,
218 err
219 );
220 return;
221 }
222 };
223 let _ = tokio::io::copy_bidirectional(
224 &mut TokioIo::new(client),
225 &mut server,
226 )
227 .await;
228 }
229 });
230
231 Ok(Response::new(
232 http_body_util::Empty::new()
233 .map_err(|never: std::convert::Infallible| match never {})
234 .boxed(),
235 ))
236 } else {
237 service.call(req).await.map(|res| res.map(|b| b.boxed()))
239 }
240 }
241 })
242 }
243
244 fn get_certified_key(&self, host: String) -> Option<CertifiedKeyDer> {
245 self.root_issuer.as_ref().and_then(|root_issuer| {
246 if let Some(cache) = self.cert_cache.as_ref() {
247 cache
249 .try_get_with(host.clone(), move || {
250 generate_cert(host, root_issuer.borrow())
251 })
252 .map_err(|err| {
253 tracing::error!("Failed to generate certificate for host: {}", err);
254 })
255 .ok()
256 } else {
257 generate_cert(host, root_issuer.borrow())
258 .map_err(|err| {
259 tracing::error!("Failed to generate certificate for host: {}", err);
260 })
261 .ok()
262 }
263 })
264 }
265
266 fn server_config(
267 &self,
268 host: String,
269 h2: bool,
270 ) -> Option<Result<rustls::ServerConfig, rustls::Error>> {
271 if let Some(cert) = self.get_certified_key(host) {
272 let config = rustls::ServerConfig::builder()
273 .with_no_client_auth()
274 .with_single_cert(
275 vec![rustls::pki_types::CertificateDer::from(cert.cert_der)],
276 rustls::pki_types::PrivateKeyDer::Pkcs8(
277 rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_der),
278 ),
279 );
280
281 Some(if h2 {
282 config.map(|mut server_config| {
283 server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
284 server_config
285 })
286 } else {
287 config
288 })
289 } else {
290 None
291 }
292 }
293}
294
295fn no_body<D>(status: StatusCode) -> Response<Empty<D>> {
296 let mut res = Response::new(Empty::new());
297 *res.status_mut() = status;
298 res
299}
300
301fn inject_authority<B>(request_middleman: &mut Request<B>, authority: hyper::http::uri::Authority) {
302 let mut parts = request_middleman.uri().clone().into_parts();
303 parts.scheme = Some(hyper::http::uri::Scheme::HTTPS);
304 if parts.authority.is_none() {
305 parts.authority = Some(authority.clone());
306 }
307
308 match hyper::http::uri::Uri::from_parts(parts) {
309 Ok(uri) => *request_middleman.uri_mut() = uri,
310 Err(err) => {
311 tracing::error!(
312 "Failed to inject authority '{}' into URI: {}",
313 authority,
314 err
315 );
316 }
318 }
319}