1#![doc = include_str!("../README.md")]
2
3use http_body_util::{BodyExt, Empty, combinators::BoxBody};
4use hyper::{
5 Method, Request, Response, StatusCode,
6 body::{Body, Incoming},
7 server,
8 service::{HttpService, service_fn},
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use moka::sync::Cache;
12use std::{borrow::Borrow, error::Error as StdError, future::Future, sync::Arc};
13use tls::{CertifiedKeyDer, generate_cert};
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15use tokio_rustls::rustls;
16
17pub use futures;
18pub use hyper;
19pub use moka;
20
21#[cfg(feature = "native-tls-client")]
22pub use tokio_native_tls;
23
24#[cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
25pub mod default_client;
26mod tls;
27
28#[cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
29pub use default_client::DefaultClient;
30
/// Address of the TCP peer that opened the proxied connection.
///
/// The proxy inserts this into the extensions of every request it forwards,
/// so downstream services can look up where the client connected from.
#[derive(Debug, Clone, Copy)]
pub struct RemoteAddr(pub std::net::SocketAddr);
33
34#[derive(Clone)]
35pub struct MitmProxy<I> {
37 pub root_issuer: Option<I>,
41 pub cert_cache: Option<Cache<String, CertifiedKeyDer>>,
46}
47
48impl<I> MitmProxy<I> {
49 pub fn new(root_issuer: Option<I>, cache: Option<Cache<String, CertifiedKeyDer>>) -> Self {
51 Self {
52 root_issuer,
53 cert_cache: cache,
54 }
55 }
56}
57
58impl<I> MitmProxy<I>
59where
60 I: Borrow<rcgen::Issuer<'static, rcgen::KeyPair>> + Send + Sync + 'static,
61{
62 pub async fn bind<A: ToSocketAddrs, S>(
66 self,
67 addr: A,
68 service: S,
69 ) -> Result<impl Future<Output = ()>, std::io::Error>
70 where
71 S: HttpService<Incoming> + Clone + Send + 'static,
72 S::Error: Into<Box<dyn StdError + Send + Sync>>,
73 S::ResBody: Send + Sync + 'static,
74 <S::ResBody as Body>::Data: Send,
75 <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
76 S::Future: Send,
77 {
78 let listener = TcpListener::bind(addr).await?;
79
80 let proxy = Arc::new(self);
81
82 Ok(async move {
83 loop {
84 let (stream, remote_addr) = match listener.accept().await {
85 Ok(conn) => conn,
86 Err(err) => {
87 tracing::warn!("Failed to accept connection: {}", err);
88 continue;
89 }
90 };
91
92 let service = service.clone();
93
94 let proxy = proxy.clone();
95 tokio::spawn(async move {
96 if let Err(err) = server::conn::http1::Builder::new()
97 .preserve_header_case(true)
98 .title_case_headers(true)
99 .serve_connection(
100 TokioIo::new(stream),
101 service_fn(move |mut req| {
102 req.extensions_mut().insert(RemoteAddr(remote_addr));
103 Self::wrap_service(proxy.clone(), service.clone()).call(req)
104 }),
105 )
106 .with_upgrades()
107 .await
108 {
109 tracing::error!("Error in proxy: {}", err);
110 }
111 });
112 }
113 })
114 }
115
116 pub fn wrap_service<S>(
122 proxy: Arc<Self>,
123 service: S,
124 ) -> impl HttpService<
125 Incoming,
126 ResBody = BoxBody<<S::ResBody as Body>::Data, <S::ResBody as Body>::Error>,
127 Future: Send,
128 >
129 where
130 S: HttpService<Incoming> + Clone + Send + 'static,
131 S::Error: Into<Box<dyn StdError + Send + Sync>>,
132 S::ResBody: Send + Sync + 'static,
133 <S::ResBody as Body>::Data: Send,
134 <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
135 S::Future: Send,
136 {
137 service_fn(move |mut req| {
138 let proxy = proxy.clone();
139 let mut service = service.clone();
140
141 async move {
142 if req.method() == Method::CONNECT {
143 let Some(connect_authority) = req.uri().authority().cloned() else {
145 tracing::error!(
146 "Bad CONNECT request: {}, Reason: Invalid Authority",
147 req.uri()
148 );
149 return Ok(no_body(StatusCode::BAD_REQUEST)
150 .map(|b| b.boxed().map_err(|never| match never {}).boxed()));
151 };
152
153 tokio::spawn(async move {
154 let remote_addr: Option<RemoteAddr> = req.extensions_mut().remove();
155 let client = match hyper::upgrade::on(req).await {
156 Ok(client) => client,
157 Err(err) => {
158 tracing::error!(
159 "Failed to upgrade CONNECT request for {}: {}",
160 connect_authority,
161 err
162 );
163 return;
164 }
165 };
166 if let Some(server_config) =
167 proxy.server_config(connect_authority.host().to_string(), true)
168 {
169 let server_config = match server_config {
170 Ok(server_config) => server_config,
171 Err(err) => {
172 tracing::error!(
173 "Failed to create server config for {}, {}",
174 connect_authority.host(),
175 err
176 );
177 return;
178 }
179 };
180 let server_config = Arc::new(server_config);
181 let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config);
182 let client = match tls_acceptor.accept(TokioIo::new(client)).await {
183 Ok(client) => client,
184 Err(err) => {
185 tracing::error!(
186 "Failed to accept TLS connection for {}, {}",
187 connect_authority.host(),
188 err
189 );
190 return;
191 }
192 };
193 let f = move |mut req: Request<_>| {
194 let connect_authority = connect_authority.clone();
195 let mut service = service.clone();
196
197 async move {
198 if let Some(remote_addr) = remote_addr {
199 req.extensions_mut().insert(remote_addr);
200 }
201 inject_authority(&mut req, connect_authority.clone());
202 service.call(req).await
203 }
204 };
205 let res = if client.get_ref().1.alpn_protocol() == Some(b"h2") {
206 server::conn::http2::Builder::new(TokioExecutor::new())
207 .serve_connection(TokioIo::new(client), service_fn(f))
208 .await
209 } else {
210 server::conn::http1::Builder::new()
211 .preserve_header_case(true)
212 .title_case_headers(true)
213 .serve_connection(TokioIo::new(client), service_fn(f))
214 .with_upgrades()
215 .await
216 };
217
218 if let Err(err) = res {
219 tracing::debug!("Connection closed: {}", err);
220 }
221 } else {
222 let mut server =
223 match TcpStream::connect(connect_authority.as_str()).await {
224 Ok(server) => server,
225 Err(err) => {
226 tracing::error!(
227 "Failed to connect to {}: {}",
228 connect_authority,
229 err
230 );
231 return;
232 }
233 };
234 let _ = tokio::io::copy_bidirectional(
235 &mut TokioIo::new(client),
236 &mut server,
237 )
238 .await;
239 }
240 });
241
242 Ok(Response::new(
243 http_body_util::Empty::new()
244 .map_err(|never: std::convert::Infallible| match never {})
245 .boxed(),
246 ))
247 } else {
248 service.call(req).await.map(|res| res.map(|b| b.boxed()))
250 }
251 }
252 })
253 }
254
255 fn get_certified_key(&self, host: String) -> Option<CertifiedKeyDer> {
256 self.root_issuer.as_ref().and_then(|root_issuer| {
257 if let Some(cache) = self.cert_cache.as_ref() {
258 cache
260 .try_get_with(host.clone(), move || {
261 generate_cert(host, root_issuer.borrow())
262 })
263 .map_err(|err| {
264 tracing::error!("Failed to generate certificate for host: {}", err);
265 })
266 .ok()
267 } else {
268 generate_cert(host, root_issuer.borrow())
269 .map_err(|err| {
270 tracing::error!("Failed to generate certificate for host: {}", err);
271 })
272 .ok()
273 }
274 })
275 }
276
277 fn server_config(
278 &self,
279 host: String,
280 h2: bool,
281 ) -> Option<Result<rustls::ServerConfig, rustls::Error>> {
282 if let Some(cert) = self.get_certified_key(host) {
283 let config = rustls::ServerConfig::builder()
284 .with_no_client_auth()
285 .with_single_cert(
286 vec![rustls::pki_types::CertificateDer::from(cert.cert_der)],
287 rustls::pki_types::PrivateKeyDer::Pkcs8(
288 rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_der),
289 ),
290 );
291
292 Some(if h2 {
293 config.map(|mut server_config| {
294 server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
295 server_config
296 })
297 } else {
298 config
299 })
300 } else {
301 None
302 }
303 }
304}
305
306fn no_body<D>(status: StatusCode) -> Response<Empty<D>> {
307 let mut res = Response::new(Empty::new());
308 *res.status_mut() = status;
309 res
310}
311
312fn inject_authority<B>(request_middleman: &mut Request<B>, authority: hyper::http::uri::Authority) {
313 let mut parts = request_middleman.uri().clone().into_parts();
314 parts.scheme = Some(hyper::http::uri::Scheme::HTTPS);
315 if parts.authority.is_none() {
316 parts.authority = Some(authority.clone());
317 }
318
319 match hyper::http::uri::Uri::from_parts(parts) {
320 Ok(uri) => *request_middleman.uri_mut() = uri,
321 Err(err) => {
322 tracing::error!(
323 "Failed to inject authority '{}' into URI: {}",
324 authority,
325 err
326 );
327 }
329 }
330}