1#![doc = include_str!("../README.md")]
2
3use http_body_util::{BodyExt, Empty, combinators::BoxBody};
4use hyper::{
5 Method, Request, Response, StatusCode,
6 body::{Body, Incoming},
7 server,
8 service::{HttpService, service_fn},
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use moka::sync::Cache;
12use std::{borrow::Borrow, error::Error as StdError, future::Future, sync::Arc};
13use tls::{CertifiedKeyDer, generate_cert};
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15use tokio_rustls::rustls;
16
17pub use futures;
18pub use hyper;
19pub use moka;
20
21#[cfg(feature = "native-tls-client")]
22pub use tokio_native_tls;
23
24#[cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
25pub mod default_client;
26mod tls;
27
28#[cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
29pub use default_client::DefaultClient;
30
/// An HTTP proxy that can intercept (man-in-the-middle) HTTPS traffic tunneled
/// through `CONNECT` requests.
///
/// `C` is the root certificate type. The methods that mint per-host
/// certificates require `C: Borrow<rcgen::CertifiedKey>`.
#[derive(Clone)]
pub struct MitmProxy<C> {
    /// Root certificate (and key) used to sign a leaf certificate for each
    /// intercepted host.
    ///
    /// When `None`, `CONNECT` tunnels are not decrypted: the proxy opens a TCP
    /// connection to the requested authority and relays bytes unmodified.
    pub root_cert: Option<C>,
    /// Optional cache of generated leaf certificates, keyed by host name.
    ///
    /// When `None`, a new leaf certificate is generated for every intercepted
    /// `CONNECT`.
    pub cert_cache: Option<Cache<String, CertifiedKeyDer>>,
}
44
45impl<C> MitmProxy<C> {
46 pub fn new(root_cert: Option<C>, cache: Option<Cache<String, CertifiedKeyDer>>) -> Self {
48 Self {
49 root_cert,
50 cert_cache: cache,
51 }
52 }
53}
54
55impl<C> MitmProxy<C>
56where
57 C: Borrow<rcgen::CertifiedKey> + Send + Sync + 'static,
58{
59 pub async fn bind<A: ToSocketAddrs, S>(
62 self,
63 addr: A,
64 service: S,
65 ) -> Result<impl Future<Output = ()>, std::io::Error>
66 where
67 S: HttpService<Incoming> + Clone + Send + 'static,
68 S::Error: Into<Box<dyn StdError + Send + Sync>>,
69 S::ResBody: Send + Sync + 'static,
70 <S::ResBody as Body>::Data: Send,
71 <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
72 S::Future: Send,
73 {
74 let listener = TcpListener::bind(addr).await?;
75
76 let proxy = Arc::new(self);
77
78 Ok(async move {
79 loop {
80 let Ok((stream, _)) = listener.accept().await else {
81 continue;
82 };
83
84 let service = service.clone();
85
86 let proxy = proxy.clone();
87 tokio::spawn(async move {
88 if let Err(err) = server::conn::http1::Builder::new()
89 .preserve_header_case(true)
90 .title_case_headers(true)
91 .serve_connection(
92 TokioIo::new(stream),
93 Self::wrap_service(proxy.clone(), service.clone()),
94 )
95 .with_upgrades()
96 .await
97 {
98 tracing::error!("Error in proxy: {}", err);
99 }
100 });
101 }
102 })
103 }
104
105 pub fn wrap_service<S>(
111 proxy: Arc<Self>,
112 service: S,
113 ) -> impl HttpService<
114 Incoming,
115 ResBody = BoxBody<<S::ResBody as Body>::Data, <S::ResBody as Body>::Error>,
116 Future: Send,
117 >
118 where
119 S: HttpService<Incoming> + Clone + Send + 'static,
120 S::Error: Into<Box<dyn StdError + Send + Sync>>,
121 S::ResBody: Send + Sync + 'static,
122 <S::ResBody as Body>::Data: Send,
123 <S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
124 S::Future: Send,
125 {
126 service_fn(move |req| {
127 let proxy = proxy.clone();
128 let mut service = service.clone();
129
130 async move {
131 if req.method() == Method::CONNECT {
132 let Some(connect_authority) = req.uri().authority().cloned() else {
134 tracing::error!(
135 "Bad CONNECT request: {}, Reason: Invalid Authority",
136 req.uri()
137 );
138 return Ok(no_body(StatusCode::BAD_REQUEST)
139 .map(|b| b.boxed().map_err(|never| match never {}).boxed()));
140 };
141
142 tokio::spawn(async move {
143 let Ok(client) = hyper::upgrade::on(req).await else {
144 tracing::error!(
145 "Bad CONNECT request: {}, Reason: Invalid Upgrade",
146 connect_authority
147 );
148 return;
149 };
150 if let Some(server_config) =
151 proxy.server_config(connect_authority.host().to_string(), true)
152 {
153 let server_config = match server_config {
154 Ok(server_config) => server_config,
155 Err(err) => {
156 tracing::error!(
157 "Failed to create server config for {}, {}",
158 connect_authority.host(),
159 err
160 );
161 return;
162 }
163 };
164 let server_config = Arc::new(server_config);
165 let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config);
166 let client = match tls_acceptor.accept(TokioIo::new(client)).await {
167 Ok(client) => client,
168 Err(err) => {
169 tracing::error!(
170 "Failed to accept TLS connection for {}, {}",
171 connect_authority.host(),
172 err
173 );
174 return;
175 }
176 };
177 let f = move |mut req: Request<_>| {
178 let connect_authority = connect_authority.clone();
179 let mut service = service.clone();
180
181 async move {
182 inject_authority(&mut req, connect_authority.clone());
183 service.call(req).await
184 }
185 };
186 let res = if client.get_ref().1.alpn_protocol() == Some(b"h2") {
187 server::conn::http2::Builder::new(TokioExecutor::new())
188 .serve_connection(TokioIo::new(client), service_fn(f))
189 .await
190 } else {
191 server::conn::http1::Builder::new()
192 .preserve_header_case(true)
193 .title_case_headers(true)
194 .serve_connection(TokioIo::new(client), service_fn(f))
195 .with_upgrades()
196 .await
197 };
198
199 if let Err(_err) = res {
200 }
203 } else {
204 let Ok(mut server) =
205 TcpStream::connect(connect_authority.as_str()).await
206 else {
207 tracing::error!("Failed to connect to {}", connect_authority);
208 return;
209 };
210 let _ = tokio::io::copy_bidirectional(
211 &mut TokioIo::new(client),
212 &mut server,
213 )
214 .await;
215 }
216 });
217
218 Ok(Response::new(
219 http_body_util::Empty::new()
220 .map_err(|never: std::convert::Infallible| match never {})
221 .boxed(),
222 ))
223 } else {
224 service.call(req).await.map(|res| res.map(|b| b.boxed()))
226 }
227 }
228 })
229 }
230
231 fn get_certified_key(&self, host: String) -> Option<CertifiedKeyDer> {
232 self.root_cert.as_ref().map(|root_cert| {
233 if let Some(cache) = self.cert_cache.as_ref() {
234 cache.get_with(host.clone(), move || {
235 generate_cert(host, root_cert.borrow())
236 })
237 } else {
238 generate_cert(host, root_cert.borrow())
239 }
240 })
241 }
242
243 fn server_config(
244 &self,
245 host: String,
246 h2: bool,
247 ) -> Option<Result<rustls::ServerConfig, rustls::Error>> {
248 if let Some(cert) = self.get_certified_key(host) {
249 let config = rustls::ServerConfig::builder()
250 .with_no_client_auth()
251 .with_single_cert(
252 vec![rustls::pki_types::CertificateDer::from(cert.cert_der)],
253 rustls::pki_types::PrivateKeyDer::Pkcs8(
254 rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_der),
255 ),
256 );
257
258 Some(if h2 {
259 config.map(|mut server_config| {
260 server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
261 server_config
262 })
263 } else {
264 config
265 })
266 } else {
267 None
268 }
269 }
270}
271
272fn no_body<D>(status: StatusCode) -> Response<Empty<D>> {
273 let mut res = Response::new(Empty::new());
274 *res.status_mut() = status;
275 res
276}
277
278fn inject_authority<B>(request_middleman: &mut Request<B>, authority: hyper::http::uri::Authority) {
279 let mut parts = request_middleman.uri().clone().into_parts();
280 parts.scheme = Some(hyper::http::uri::Scheme::HTTPS);
281 if parts.authority.is_none() {
282 parts.authority = Some(authority);
283 }
284 *request_middleman.uri_mut() = hyper::http::uri::Uri::from_parts(parts).unwrap();
285}