1use crate::cert::CertificateAuthority;
2use crate::error::ProxyError;
3use crate::handler::{boxed_body, BoxBody, RequestHandler};
4use crate::tls;
5use bytes::Bytes;
6use http_body_util::{BodyExt, Empty, Full};
7use hyper::client::conn::http1 as client_http1;
8use hyper::server::conn::http1 as server_http1;
9use hyper::service::service_fn;
10use hyper::upgrade::Upgraded;
11use hyper::{Method, Request, Response};
12use hyper_util::rt::TokioIo;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tokio::net::TcpStream;
16use tracing::{debug, error, info, warn};
17
18pub struct ProxyState {
20 pub ca: Arc<CertificateAuthority>,
21 pub mitm: bool,
22 pub handler: Arc<dyn RequestHandler>,
23}
24
25pub async fn handle_connection(
27 stream: TcpStream,
28 addr: SocketAddr,
29 state: Arc<ProxyState>,
30) {
31 debug!("New connection from {addr}");
32
33 let io = TokioIo::new(stream);
34 let state = state.clone();
35
36 let service = service_fn(move |req: Request<hyper::body::Incoming>| {
37 let state = state.clone();
38 async move { handle_request(req, state).await }
39 });
40
41 if let Err(e) = server_http1::Builder::new()
42 .preserve_header_case(true)
43 .title_case_headers(true)
44 .serve_connection(io, service)
45 .with_upgrades()
46 .await
47 {
48 if !e.to_string().contains("early eof")
49 && !e.to_string().contains("connection closed")
50 {
51 error!("Connection error from {addr}: {e}");
52 }
53 }
54}
55
56async fn handle_request(
58 req: Request<hyper::body::Incoming>,
59 state: Arc<ProxyState>,
60) -> Result<Response<BoxBody>, hyper::Error> {
61 if req.method() == Method::CONNECT {
62 handle_connect(req, state).await
63 } else {
64 handle_forward(req, state).await
65 }
66}
67
68async fn handle_forward(
72 req: Request<hyper::body::Incoming>,
73 state: Arc<ProxyState>,
74) -> Result<Response<BoxBody>, hyper::Error> {
75 let uri = req.uri().clone();
76 let host = match uri.host() {
77 Some(h) => h.to_string(),
78 None => {
79 warn!("Request with no host: {uri}");
80 return Ok(bad_request("Missing host in URI"));
81 }
82 };
83 let port = uri.port_u16().unwrap_or(80);
84 let addr = format!("{host}:{port}");
85
86 let (mut parts, body) = req.into_parts();
88 let path = parts
89 .uri
90 .path_and_query()
91 .map(|pq| pq.as_str())
92 .unwrap_or("/");
93 parts.uri = match path.parse() {
94 Ok(uri) => uri,
95 Err(_) => {
96 warn!("Invalid path: {path}");
97 return Ok(bad_request("Invalid request URI"));
98 }
99 };
100 strip_hop_by_hop_headers(&mut parts.headers);
101
102 let mut forwarded_req = Request::from_parts(parts, boxed_body(body));
103
104 state.handler.handle_request(&mut forwarded_req);
106
107 let upstream = match TcpStream::connect(&addr).await {
109 Ok(s) => s,
110 Err(e) => {
111 error!("Failed to connect to {addr}: {e}");
112 return Ok(bad_gateway(&format!("Failed to connect to {addr}")));
113 }
114 };
115
116 let io = TokioIo::new(upstream);
117 let (mut sender, conn) = match client_http1::handshake(io).await {
118 Ok(r) => r,
119 Err(e) => {
120 error!("Handshake with {addr} failed: {e}");
121 return Ok(bad_gateway("Upstream handshake failed"));
122 }
123 };
124
125 tokio::spawn(async move {
126 if let Err(e) = conn.await {
127 error!("Upstream connection error: {e}");
128 }
129 });
130
131 match sender.send_request(forwarded_req).await {
132 Ok(res) => {
133 let (parts, body) = res.into_parts();
134 let mut response = Response::from_parts(parts, boxed_body(body));
135 state.handler.handle_response(&mut response);
136 Ok(response)
137 }
138 Err(e) => {
139 error!("Upstream request failed: {e}");
140 Ok(bad_gateway("Upstream request failed"))
141 }
142 }
143}
144
145async fn handle_connect(
149 req: Request<hyper::body::Incoming>,
150 state: Arc<ProxyState>,
151) -> Result<Response<BoxBody>, hyper::Error> {
152 let target = match req.uri().authority() {
153 Some(auth) => auth.to_string(),
154 None => {
155 warn!("CONNECT without authority");
156 return Ok(bad_request("CONNECT target missing"));
157 }
158 };
159
160 let (host, port) = parse_host_port(&target);
161 let addr = format!("{host}:{port}");
162
163 info!("CONNECT {target}");
164
165 if state.mitm {
166 handle_mitm(req, host, addr, state).await
168 } else {
169 handle_tunnel(req, addr).await
171 }
172}
173
174async fn handle_tunnel(
176 req: Request<hyper::body::Incoming>,
177 addr: String,
178) -> Result<Response<BoxBody>, hyper::Error> {
179 tokio::spawn(async move {
180 match hyper::upgrade::on(req).await {
181 Ok(upgraded) => {
182 if let Err(e) = tunnel_bidirectional(upgraded, &addr).await {
183 error!("Tunnel error to {addr}: {e}");
184 }
185 }
186 Err(e) => {
187 error!("Upgrade failed: {e}");
188 }
189 }
190 });
191
192 Ok(Response::new(empty_body()))
194}
195
196async fn tunnel_bidirectional(
198 upgraded: Upgraded,
199 addr: &str,
200) -> crate::error::Result<()> {
201 let mut upstream = TcpStream::connect(addr).await?;
202
203 let mut client = TokioIo::new(upgraded);
204
205 let (client_to_server, server_to_client) =
206 tokio::io::copy_bidirectional(&mut client, &mut upstream).await?;
207
208 debug!(
209 "Tunnel closed: {addr} (client→server: {client_to_server}B, server→client: {server_to_client}B)"
210 );
211 Ok(())
212}
213
214async fn handle_mitm(
216 req: Request<hyper::body::Incoming>,
217 host: String,
218 addr: String,
219 state: Arc<ProxyState>,
220) -> Result<Response<BoxBody>, hyper::Error> {
221 let state = state.clone();
222
223 tokio::spawn(async move {
224 match hyper::upgrade::on(req).await {
225 Ok(upgraded) => {
226 if let Err(e) =
227 mitm_intercept(upgraded, &host, &addr, state).await
228 {
229 error!("MITM error for {host}: {e}");
230 }
231 }
232 Err(e) => {
233 error!("MITM upgrade failed: {e}");
234 }
235 }
236 });
237
238 Ok(Response::new(empty_body()))
239}
240
241async fn mitm_intercept(
243 upgraded: Upgraded,
244 host: &str,
245 addr: &str,
246 state: Arc<ProxyState>,
247) -> crate::error::Result<()> {
248 let acceptor = tls::make_tls_acceptor(&state.ca, host).await?;
250
251 let client_io = TokioIo::new(upgraded);
253 let client_tls = acceptor
254 .accept(client_io)
255 .await
256 .map_err(|e| ProxyError::Other(format!("Client TLS accept failed: {e}")))?;
257
258 let client_tls = TokioIo::new(client_tls);
259
260 let host = host.to_string();
262 let addr = addr.to_string();
263
264 let service = service_fn(move |req: Request<hyper::body::Incoming>| {
265 let host = host.clone();
266 let addr = addr.clone();
267 let state = state.clone();
268 async move {
269 mitm_forward_request(req, &host, &addr, state).await
270 }
271 });
272
273 if let Err(e) = server_http1::Builder::new()
274 .preserve_header_case(true)
275 .title_case_headers(true)
276 .serve_connection(client_tls, service)
277 .await
278 {
279 if !e.to_string().contains("early eof")
280 && !e.to_string().contains("connection closed")
281 {
282 debug!("MITM connection closed: {e}");
283 }
284 }
285
286 Ok(())
287}
288
289async fn mitm_forward_request(
291 req: Request<hyper::body::Incoming>,
292 host: &str,
293 addr: &str,
294 state: Arc<ProxyState>,
295) -> Result<Response<BoxBody>, hyper::Error> {
296 let (mut parts, body) = req.into_parts();
297 strip_hop_by_hop_headers(&mut parts.headers);
298
299 let mut forwarded_req = Request::from_parts(parts, boxed_body(body));
300
301 state.handler.handle_request(&mut forwarded_req);
303
304 let upstream_tls = match tls::connect_tls_upstream(host, addr).await {
306 Ok(s) => s,
307 Err(e) => {
308 error!("Failed TLS connect to {addr}: {e}");
309 return Ok(bad_gateway(&format!(
310 "Failed to connect to upstream: {e}"
311 )));
312 }
313 };
314
315 let io = TokioIo::new(upstream_tls);
316 let (mut sender, conn) = match client_http1::handshake(io).await {
317 Ok(r) => r,
318 Err(e) => {
319 error!("Upstream TLS handshake failed: {e}");
320 return Ok(bad_gateway("Upstream TLS handshake failed"));
321 }
322 };
323
324 tokio::spawn(async move {
325 if let Err(e) = conn.await {
326 debug!("Upstream TLS connection closed: {e}");
327 }
328 });
329
330 match sender.send_request(forwarded_req).await {
331 Ok(res) => {
332 let (parts, body) = res.into_parts();
333 let mut response = Response::from_parts(parts, boxed_body(body));
334 state.handler.handle_response(&mut response);
335 Ok(response)
336 }
337 Err(e) => {
338 error!("Upstream TLS request failed: {e}");
339 Ok(bad_gateway("Upstream request failed"))
340 }
341 }
342}
343
/// Header names (lowercase) that apply to a single transport hop and must not be
/// forwarded by a proxy, in addition to whatever the `Connection` header itself lists.
///
/// `proxy-connection` is a legacy, non-standard variant of `Connection` that some
/// clients still send to proxies; it is meaningless to an origin server.
/// NOTE(review): "trailers" follows RFC 2616 §13.5.1 verbatim; errata 4522 corrects it
/// to "trailer". It is kept unchanged pending a check of how trailers are forwarded.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "proxy-connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
357
/// Split a CONNECT target into `(host, port)`.
///
/// Accepted forms are `host:port`, `[ipv6]:port`, `[ipv6]`, a bare host, or a bare
/// unbracketed IPv6 literal. Brackets are removed from IPv6 literals. When no usable
/// port is present, the port defaults to `443`.
///
/// Special cases:
/// - An empty port (`host:`) falls back to the default port.
/// - A host with an unparseable port (e.g. `host:abc`) is returned whole as the host.
/// - For unbracketed text with several colons (an IPv6 literal), the last group is part
///   of the address, not a port.
pub fn parse_host_port(target: &str) -> (String, u16) {
    const DEFAULT_PORT: u16 = 443;

    if let Some((ip6, rest)) = target
        .strip_prefix('[')
        .and_then(|inner| inner.split_once(']'))
    {
        let port = rest
            .strip_prefix(':')
            .and_then(|p| p.parse().ok())
            .unwrap_or(DEFAULT_PORT);
        return (ip6.to_string(), port);
    }

    match target.rsplit_once(':') {
        // Still a colon on the left: an unbracketed IPv6 literal without a port.
        Some((host, _)) if host.contains(':') => (target.to_string(), DEFAULT_PORT),
        // `host:` — an empty port means "use the default".
        Some((host, "")) => (host.to_string(), DEFAULT_PORT),
        Some((host, port_str)) => match port_str.parse::<u16>() {
            Ok(port) => (host.to_string(), port),
            Err(_) => (target.to_string(), DEFAULT_PORT),
        },
        None => (target.to_string(), DEFAULT_PORT),
    }
}
379
380fn strip_hop_by_hop_headers(headers: &mut hyper::HeaderMap) {
381 if let Some(conn_val) = headers.get("connection").cloned() {
383 if let Ok(val) = conn_val.to_str() {
384 for name in val.split(',') {
385 let name = name.trim();
386 if !name.is_empty() {
387 headers.remove(name);
388 }
389 }
390 }
391 }
392
393 for name in HOP_BY_HOP_HEADERS {
394 headers.remove(*name);
395 }
396}
397
398fn empty_body() -> BoxBody {
399 Empty::<Bytes>::new()
400 .map_err(|never| match never {})
401 .boxed()
402}
403
404fn bad_request(msg: &str) -> Response<BoxBody> {
405 Response::builder()
406 .status(400)
407 .body(full_body(msg))
408 .unwrap()
409}
410
411fn bad_gateway(msg: &str) -> Response<BoxBody> {
412 Response::builder()
413 .status(502)
414 .body(full_body(msg))
415 .unwrap()
416}
417
418fn full_body(msg: &str) -> BoxBody {
419 Full::new(Bytes::from(msg.to_string()))
420 .map_err(|never| match never {})
421 .boxed()
422}