1#![doc = include_str!("../README.md")]
2
3use http_body_util::{BodyExt, Empty, combinators::BoxBody};
4use hyper::{
5 Method, Request, Response, StatusCode,
6 body::{Body, Incoming},
7 server,
8 service::{HttpService, service_fn},
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use moka::sync::Cache;
12use std::{borrow::Borrow, error::Error as StdError, future::Future, sync::Arc};
13use tls::{CertifiedKeyDer, generate_cert};
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15use tokio_rustls::rustls;
16
17pub use futures;
18pub use hyper;
19pub use moka;
20
21#[cfg(feature = "native-tls-client")]
22pub use tokio_native_tls;
23
24pub mod default_client;
25mod tls;
26
27pub use default_client::DefaultClient;
28
29#[derive(Clone)]
30pub struct MitmProxy<C> {
32 pub root_cert: Option<C>,
36 pub cert_cache: Option<Cache<String, CertifiedKeyDer>>,
41}
42
43impl<C> MitmProxy<C> {
44 pub fn new(root_cert: Option<C>, cache: Option<Cache<String, CertifiedKeyDer>>) -> Self {
46 Self {
47 root_cert,
48 cert_cache: cache,
49 }
50 }
51}
52
53impl<C> MitmProxy<C>
54where
55 C: Borrow<rcgen::CertifiedKey> + Send + Sync + 'static,
56{
57 pub async fn bind<A: ToSocketAddrs, S>(
60 self,
61 addr: A,
62 service: S,
63 ) -> Result<impl Future<Output = ()>, std::io::Error>
64 where
65 S: HttpService<Incoming> + Clone + Send + 'static,
66 S::Error: Into<Box<dyn StdError + Send + Sync>>,
67 S::ResBody: Send + Sync + 'static,
68 <S::ResBody as Body>::Data: Send,
69 <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
70 S::Future: Send,
71 {
72 let listener = TcpListener::bind(addr).await?;
73
74 let proxy = Arc::new(self);
75
76 Ok(async move {
77 loop {
78 let Ok((stream, _)) = listener.accept().await else {
79 continue;
80 };
81
82 let service = service.clone();
83
84 let proxy = proxy.clone();
85 tokio::spawn(async move {
86 if let Err(err) = server::conn::http1::Builder::new()
87 .preserve_header_case(true)
88 .title_case_headers(true)
89 .serve_connection(
90 TokioIo::new(stream),
91 Self::wrap_service(proxy.clone(), service.clone()),
92 )
93 .with_upgrades()
94 .await
95 {
96 tracing::error!("Error in proxy: {}", err);
97 }
98 });
99 }
100 })
101 }
102
103 pub fn wrap_service<S>(
109 proxy: Arc<Self>,
110 service: S,
111 ) -> impl HttpService<
112 Incoming,
113 ResBody = BoxBody<<S::ResBody as Body>::Data, <S::ResBody as Body>::Error>,
114 Future: Send,
115 >
116 where
117 S: HttpService<Incoming> + Clone + Send + 'static,
118 S::Error: Into<Box<dyn StdError + Send + Sync>>,
119 S::ResBody: Send + Sync + 'static,
120 <S::ResBody as Body>::Data: Send,
121 <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
122 S::Future: Send,
123 {
124 service_fn(move |req| {
125 let proxy = proxy.clone();
126 let mut service = service.clone();
127
128 async move {
129 if req.method() == Method::CONNECT {
130 let Some(connect_authority) = req.uri().authority().cloned() else {
132 tracing::error!(
133 "Bad CONNECT request: {}, Reason: Invalid Authority",
134 req.uri()
135 );
136 return Ok(no_body(StatusCode::BAD_REQUEST)
137 .map(|b| b.boxed().map_err(|never| match never {}).boxed()));
138 };
139
140 tokio::spawn(async move {
141 let Ok(client) = hyper::upgrade::on(req).await else {
142 tracing::error!(
143 "Bad CONNECT request: {}, Reason: Invalid Upgrade",
144 connect_authority
145 );
146 return;
147 };
148 if let Some(server_config) =
149 proxy.server_config(connect_authority.host().to_string(), true)
150 {
151 let server_config = match server_config {
152 Ok(server_config) => server_config,
153 Err(err) => {
154 tracing::error!(
155 "Failed to create server config for {}, {}",
156 connect_authority.host(),
157 err
158 );
159 return;
160 }
161 };
162 let server_config = Arc::new(server_config);
163 let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config);
164 let client = match tls_acceptor.accept(TokioIo::new(client)).await {
165 Ok(client) => client,
166 Err(err) => {
167 tracing::error!(
168 "Failed to accept TLS connection for {}, {}",
169 connect_authority.host(),
170 err
171 );
172 return;
173 }
174 };
175 let f = move |mut req: Request<_>| {
176 let connect_authority = connect_authority.clone();
177 let mut service = service.clone();
178
179 async move {
180 inject_authority(&mut req, connect_authority.clone());
181 service.call(req).await
182 }
183 };
184 let res = if client.get_ref().1.alpn_protocol() == Some(b"h2") {
185 server::conn::http2::Builder::new(TokioExecutor::new())
186 .serve_connection(TokioIo::new(client), service_fn(f))
187 .await
188 } else {
189 server::conn::http1::Builder::new()
190 .preserve_header_case(true)
191 .title_case_headers(true)
192 .serve_connection(TokioIo::new(client), service_fn(f))
193 .with_upgrades()
194 .await
195 };
196
197 if let Err(_err) = res {
198 }
201 } else {
202 let Ok(mut server) =
203 TcpStream::connect(connect_authority.as_str()).await
204 else {
205 tracing::error!("Failed to connect to {}", connect_authority);
206 return;
207 };
208 let _ = tokio::io::copy_bidirectional(
209 &mut TokioIo::new(client),
210 &mut server,
211 )
212 .await;
213 }
214 });
215
216 Ok(Response::new(
217 http_body_util::Empty::new()
218 .map_err(|never: std::convert::Infallible| match never {})
219 .boxed(),
220 ))
221 } else {
222 service.call(req).await.map(|res| res.map(|b| b.boxed()))
224 }
225 }
226 })
227 }
228
229 fn get_certified_key(&self, host: String) -> Option<CertifiedKeyDer> {
230 self.root_cert.as_ref().map(|root_cert| {
231 if let Some(cache) = self.cert_cache.as_ref() {
232 cache.get_with(host.clone(), move || {
233 generate_cert(host, root_cert.borrow())
234 })
235 } else {
236 generate_cert(host, root_cert.borrow())
237 }
238 })
239 }
240
241 fn server_config(
242 &self,
243 host: String,
244 h2: bool,
245 ) -> Option<Result<rustls::ServerConfig, rustls::Error>> {
246 if let Some(cert) = self.get_certified_key(host) {
247 let config = rustls::ServerConfig::builder()
248 .with_no_client_auth()
249 .with_single_cert(
250 vec![rustls::pki_types::CertificateDer::from(cert.cert_der)],
251 rustls::pki_types::PrivateKeyDer::Pkcs8(
252 rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_der),
253 ),
254 );
255
256 Some(if h2 {
257 config.map(|mut server_config| {
258 server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
259 server_config
260 })
261 } else {
262 config
263 })
264 } else {
265 None
266 }
267 }
268}
269
270fn no_body<D>(status: StatusCode) -> Response<Empty<D>> {
271 let mut res = Response::new(Empty::new());
272 *res.status_mut() = status;
273 res
274}
275
276fn inject_authority<B>(request_middleman: &mut Request<B>, authority: hyper::http::uri::Authority) {
277 let mut parts = request_middleman.uri().clone().into_parts();
278 parts.scheme = Some(hyper::http::uri::Scheme::HTTPS);
279 if parts.authority.is_none() {
280 parts.authority = Some(authority);
281 }
282 *request_middleman.uri_mut() = hyper::http::uri::Uri::from_parts(parts).unwrap();
283}