1use crate::cert::CertificateAuthority;
2use crate::error::ProxyError;
3use crate::handler::{boxed_body, full_boxed_body, Buffered, BoxBody, Dropped, RequestHandler};
4use crate::logging::{LogId, UpstreamTarget};
5use crate::tls;
6use bytes::Bytes;
7use http_body_util::{BodyExt, Empty, Full};
8use hyper::client::conn::http1 as client_http1;
9use hyper::server::conn::http1 as server_http1;
10use hyper::service::service_fn;
11use hyper::upgrade::Upgraded;
12use hyper::{Method, Request, Response};
13use hyper_util::rt::TokioIo;
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tokio::net::TcpStream;
17use tracing::{debug, error, info, warn};
18
/// Largest body, in bytes, that the proxy will buffer in memory for interception
/// or logging (10 MiB). Larger or unsized bodies are streamed through untouched.
const MAX_INTERCEPT_BODY: usize = 10 << 20;
21
22fn should_intercept_body(headers: &hyper::HeaderMap) -> bool {
27 if let Some(cl) = headers.get(hyper::header::CONTENT_LENGTH) {
28 if let Ok(s) = cl.to_str() {
29 if let Ok(len) = s.parse::<usize>() {
30 return len <= MAX_INTERCEPT_BODY;
31 }
32 }
33 }
34 false
35}
36
37async fn try_collect_body<B>(body: B) -> Option<Bytes>
39where
40 B: hyper::body::Body<Data = Bytes, Error = hyper::Error>,
41{
42 use http_body_util::Limited;
43 let limited = Limited::new(body, MAX_INTERCEPT_BODY);
44 BodyExt::collect(limited)
45 .await
46 .ok()
47 .map(|c| c.to_bytes())
48}
49
/// Shared configuration and collaborators used by every proxied connection.
///
/// Passed around as `Arc<ProxyState>`; `handle_connection` hands each request its
/// own clone of that `Arc`.
pub struct ProxyState {
    /// Certificate authority handed to `tls::make_tls_acceptor` to build the
    /// per-host TLS acceptor used when intercepting `CONNECT` tunnels.
    pub ca: Arc<CertificateAuthority>,
    /// When `true`, `CONNECT` requests are TLS-intercepted (`handle_mitm`);
    /// otherwise they are relayed as opaque TCP tunnels (`handle_tunnel`).
    pub mitm: bool,
    /// When `true` (or when `log_traffic` is `true`), request/response bodies whose
    /// `Content-Length` is within `MAX_INTERCEPT_BODY` are buffered before the
    /// handler sees them.
    pub intercept: bool,
    /// Also enables body buffering; see `intercept`.
    pub log_traffic: bool,
    /// Hook called with each forwarded request and each response, including the
    /// synthetic empty responses produced by `flush_log_on_error`.
    pub handler: Arc<dyn RequestHandler>,
}
58
59fn flush_log_on_error(
61 handler: &Arc<dyn RequestHandler>,
62 log_id: Option<LogId>,
63 status: u16,
64) {
65 if let Some(id) = log_id {
66 let mut res = Response::builder()
67 .status(status)
68 .body(full_boxed_body(Bytes::new()))
69 .unwrap();
70 res.extensions_mut().insert(id);
71 handler.handle_response(&mut res);
72 }
73}
74
75pub async fn handle_connection(
77 stream: TcpStream,
78 addr: SocketAddr,
79 state: Arc<ProxyState>,
80) {
81 debug!("New connection from {addr}");
82
83 let io = TokioIo::new(stream);
84 let state = state.clone();
85
86 let service = service_fn(move |req: Request<hyper::body::Incoming>| {
87 let state = state.clone();
88 async move { handle_request(req, state).await }
89 });
90
91 if let Err(e) = server_http1::Builder::new()
92 .preserve_header_case(true)
93 .title_case_headers(true)
94 .serve_connection(io, service)
95 .with_upgrades()
96 .await
97 {
98 if !e.to_string().contains("early eof")
99 && !e.to_string().contains("connection closed")
100 {
101 error!("Connection error from {addr}: {e}");
102 }
103 }
104}
105
106async fn handle_request(
108 req: Request<hyper::body::Incoming>,
109 state: Arc<ProxyState>,
110) -> Result<Response<BoxBody>, hyper::Error> {
111 if req.method() == Method::CONNECT {
112 handle_connect(req, state).await
113 } else {
114 handle_forward(req, state).await
115 }
116}
117
118async fn handle_forward(
122 req: Request<hyper::body::Incoming>,
123 state: Arc<ProxyState>,
124) -> Result<Response<BoxBody>, hyper::Error> {
125 let uri = req.uri().clone();
126 let host = match uri.host() {
127 Some(h) => h.to_string(),
128 None => {
129 warn!("Request with no host: {uri}");
130 return Ok(bad_request("Missing host in URI"));
131 }
132 };
133 let port = uri.port_u16().unwrap_or(80);
134 let addr = format!("{host}:{port}");
135
136 let (mut parts, body) = req.into_parts();
138 let path = parts
139 .uri
140 .path_and_query()
141 .map(|pq| pq.as_str())
142 .unwrap_or("/");
143 parts.uri = match path.parse() {
144 Ok(uri) => uri,
145 Err(_) => {
146 warn!("Invalid path: {path}");
147 return Ok(bad_request("Invalid request URI"));
148 }
149 };
150
151 parts.extensions.insert(UpstreamTarget {
153 scheme: "http".into(),
154 host: host.to_string(),
155 port,
156 });
157
158 let do_buffer = (state.intercept || state.log_traffic) && should_intercept_body(&parts.headers);
160
161 strip_hop_by_hop_headers(&mut parts.headers);
162
163 let mut forwarded_req = if do_buffer {
164 match try_collect_body(body).await {
165 Some(bytes) => {
166 let mut req = Request::from_parts(parts, full_boxed_body(bytes));
167 req.extensions_mut().insert(Buffered);
168 req
169 }
170 None => {
171 error!("Request body collection failed");
172 return Ok(bad_gateway("Request body read error"));
173 }
174 }
175 } else {
176 Request::from_parts(parts, boxed_body(body))
177 };
178
179 state.handler.handle_request(&mut forwarded_req);
180 let log_id = forwarded_req.extensions().get::<LogId>().cloned();
181
182 if forwarded_req.extensions().get::<Dropped>().is_some() {
183 return Ok(bad_gateway("Request dropped by interceptor"));
184 }
185
186 let upstream = match TcpStream::connect(&addr).await {
188 Ok(s) => s,
189 Err(e) => {
190 error!("Failed to connect to {addr}: {e}");
191 flush_log_on_error(&state.handler, log_id, 502);
192 return Ok(bad_gateway(&format!("Failed to connect to {addr}")));
193 }
194 };
195
196 let io = TokioIo::new(upstream);
197 let (mut sender, conn) = match client_http1::handshake(io).await {
198 Ok(r) => r,
199 Err(e) => {
200 error!("Handshake with {addr} failed: {e}");
201 flush_log_on_error(&state.handler, log_id, 502);
202 return Ok(bad_gateway("Upstream handshake failed"));
203 }
204 };
205
206 tokio::spawn(async move {
207 if let Err(e) = conn.await {
208 error!("Upstream connection error: {e}");
209 }
210 });
211
212 match sender.send_request(forwarded_req).await {
213 Ok(res) => {
214 let (parts, body) = res.into_parts();
215 let mut response = if (state.intercept || state.log_traffic) && should_intercept_body(&parts.headers) {
216 match try_collect_body(body).await {
217 Some(bytes) => {
218 let mut res = Response::from_parts(parts, full_boxed_body(bytes));
219 res.extensions_mut().insert(Buffered);
220 res
221 }
222 None => {
223 error!("Response body collection failed");
224 flush_log_on_error(&state.handler, log_id, 502);
225 return Ok(bad_gateway("Response body read error"));
226 }
227 }
228 } else {
229 Response::from_parts(parts, boxed_body(body))
230 };
231 if let Some(id) = log_id.clone() { response.extensions_mut().insert(id); }
232 state.handler.handle_response(&mut response);
233 if response.extensions().get::<Dropped>().is_some() {
234 return Ok(interceptor_dropped_response());
235 }
236 Ok(response)
237 }
238 Err(e) => {
239 error!("Upstream request failed: {e}");
240 flush_log_on_error(&state.handler, log_id, 502);
241 Ok(bad_gateway("Upstream request failed"))
242 }
243 }
244}
245
246async fn handle_connect(
250 req: Request<hyper::body::Incoming>,
251 state: Arc<ProxyState>,
252) -> Result<Response<BoxBody>, hyper::Error> {
253 let target = match req.uri().authority() {
254 Some(auth) => auth.to_string(),
255 None => {
256 warn!("CONNECT without authority");
257 return Ok(bad_request("CONNECT target missing"));
258 }
259 };
260
261 let (host, port) = parse_host_port(&target);
262 let addr = format!("{host}:{port}");
263
264 info!("CONNECT {target}");
265
266 if state.mitm {
267 handle_mitm(req, host, addr, state).await
269 } else {
270 handle_tunnel(req, addr).await
272 }
273}
274
275async fn handle_tunnel(
277 req: Request<hyper::body::Incoming>,
278 addr: String,
279) -> Result<Response<BoxBody>, hyper::Error> {
280 tokio::spawn(async move {
281 match hyper::upgrade::on(req).await {
282 Ok(upgraded) => {
283 if let Err(e) = tunnel_bidirectional(upgraded, &addr).await {
284 error!("Tunnel error to {addr}: {e}");
285 }
286 }
287 Err(e) => {
288 error!("Upgrade failed: {e}");
289 }
290 }
291 });
292
293 Ok(Response::new(empty_body()))
295}
296
297async fn tunnel_bidirectional(
299 upgraded: Upgraded,
300 addr: &str,
301) -> crate::error::Result<()> {
302 let mut upstream = TcpStream::connect(addr).await?;
303
304 let mut client = TokioIo::new(upgraded);
305
306 let (client_to_server, server_to_client) =
307 tokio::io::copy_bidirectional(&mut client, &mut upstream).await?;
308
309 debug!(
310 "Tunnel closed: {addr} (client→server: {client_to_server}B, server→client: {server_to_client}B)"
311 );
312 Ok(())
313}
314
315async fn handle_mitm(
317 req: Request<hyper::body::Incoming>,
318 host: String,
319 addr: String,
320 state: Arc<ProxyState>,
321) -> Result<Response<BoxBody>, hyper::Error> {
322 let state = state.clone();
323
324 tokio::spawn(async move {
325 match hyper::upgrade::on(req).await {
326 Ok(upgraded) => {
327 if let Err(e) =
328 mitm_intercept(upgraded, &host, &addr, state).await
329 {
330 error!("MITM error for {host}: {e}");
331 }
332 }
333 Err(e) => {
334 error!("MITM upgrade failed: {e}");
335 }
336 }
337 });
338
339 Ok(Response::new(empty_body()))
340}
341
342async fn mitm_intercept(
344 upgraded: Upgraded,
345 host: &str,
346 addr: &str,
347 state: Arc<ProxyState>,
348) -> crate::error::Result<()> {
349 let acceptor = tls::make_tls_acceptor(&state.ca, host).await?;
351
352 let client_io = TokioIo::new(upgraded);
354 let client_tls = acceptor
355 .accept(client_io)
356 .await
357 .map_err(|e| ProxyError::Other(format!("Client TLS accept failed: {e}")))?;
358
359 let client_tls = TokioIo::new(client_tls);
360
361 let host = host.to_string();
363 let addr = addr.to_string();
364
365 let service = service_fn(move |req: Request<hyper::body::Incoming>| {
366 let host = host.clone();
367 let addr = addr.clone();
368 let state = state.clone();
369 async move {
370 mitm_forward_request(req, &host, &addr, state).await
371 }
372 });
373
374 if let Err(e) = server_http1::Builder::new()
375 .preserve_header_case(true)
376 .title_case_headers(true)
377 .serve_connection(client_tls, service)
378 .await
379 {
380 if !e.to_string().contains("early eof")
381 && !e.to_string().contains("connection closed")
382 {
383 debug!("MITM connection closed: {e}");
384 }
385 }
386
387 Ok(())
388}
389
390async fn mitm_forward_request(
392 req: Request<hyper::body::Incoming>,
393 host: &str,
394 addr: &str,
395 state: Arc<ProxyState>,
396) -> Result<Response<BoxBody>, hyper::Error> {
397 let (mut parts, body) = req.into_parts();
398
399 parts.extensions.insert(UpstreamTarget {
400 scheme: "https".into(),
401 host: host.to_string(),
402 port: addr.rsplit_once(':').and_then(|(_, p)| p.parse().ok()).unwrap_or(443),
403 });
404
405 let do_buffer = (state.intercept || state.log_traffic) && should_intercept_body(&parts.headers);
406 strip_hop_by_hop_headers(&mut parts.headers);
407
408 let mut forwarded_req = if do_buffer {
409 match try_collect_body(body).await {
410 Some(bytes) => {
411 let mut req = Request::from_parts(parts, full_boxed_body(bytes));
412 req.extensions_mut().insert(Buffered);
413 req
414 }
415 None => {
416 error!("MITM request body collection failed");
417 return Ok(bad_gateway("Request body read error"));
418 }
419 }
420 } else {
421 Request::from_parts(parts, boxed_body(body))
422 };
423
424 state.handler.handle_request(&mut forwarded_req);
425 let log_id = forwarded_req.extensions().get::<LogId>().cloned();
426
427 if forwarded_req.extensions().get::<Dropped>().is_some() {
428 return Ok(bad_gateway("Request dropped by interceptor"));
429 }
430
431 let upstream_tls = match tls::connect_tls_upstream(host, addr).await {
433 Ok(s) => s,
434 Err(e) => {
435 error!("Failed TLS connect to {addr}: {e}");
436 flush_log_on_error(&state.handler, log_id.clone(), 502);
437 return Ok(bad_gateway(&format!(
438 "Failed to connect to upstream: {e}"
439 )));
440 }
441 };
442
443 let io = TokioIo::new(upstream_tls);
444 let (mut sender, conn) = match client_http1::handshake(io).await {
445 Ok(r) => r,
446 Err(e) => {
447 error!("Upstream TLS handshake failed: {e}");
448 flush_log_on_error(&state.handler, log_id.clone(), 502);
449 return Ok(bad_gateway("Upstream TLS handshake failed"));
450 }
451 };
452
453 tokio::spawn(async move {
454 if let Err(e) = conn.await {
455 debug!("Upstream TLS connection closed: {e}");
456 }
457 });
458
459 match sender.send_request(forwarded_req).await {
460 Ok(res) => {
461 let (parts, body) = res.into_parts();
462 let mut response = if (state.intercept || state.log_traffic) && should_intercept_body(&parts.headers) {
463 match try_collect_body(body).await {
464 Some(bytes) => {
465 let mut res = Response::from_parts(parts, full_boxed_body(bytes));
466 res.extensions_mut().insert(Buffered);
467 res
468 }
469 None => {
470 error!("MITM response body collection failed");
471 flush_log_on_error(&state.handler, log_id, 502);
472 return Ok(bad_gateway("Response body read error"));
473 }
474 }
475 } else {
476 Response::from_parts(parts, boxed_body(body))
477 };
478 if let Some(id) = log_id.clone() { response.extensions_mut().insert(id); }
479 state.handler.handle_response(&mut response);
480 if response.extensions().get::<Dropped>().is_some() {
481 return Ok(interceptor_dropped_response());
482 }
483 Ok(response)
484 }
485 Err(e) => {
486 error!("Upstream TLS request failed: {e}");
487 flush_log_on_error(&state.handler, log_id, 502);
488 Ok(bad_gateway("Upstream request failed"))
489 }
490 }
491}
492
/// Lowercase names of header fields that describe a single transport hop and must
/// not be forwarded by an intermediary (RFC 9110 §7.6.1).
///
/// `proxy-connection` is non-standard, but some clients still send it to proxies
/// in place of `Connection`, so it is stripped for the same reason.
// NOTE(review): "trailers" follows RFC 2616's list; the real field is "Trailer".
// It is left unchanged here because stripping "trailer" may affect trailer
// forwarding in hyper. Confirm before changing.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
506
/// Splits a `CONNECT` authority into `(host, port)`, defaulting the port to 443.
///
/// Accepted forms:
/// - `host:port` / `host`: a hostname or IPv4 address, with or without a port.
/// - `[v6]:port` / `[v6]`: a bracketed IPv6 literal; brackets are removed from the
///   returned host.
/// - A bare, unbracketed IPv6 literal such as `2001:db8::1`. Because the part before
///   the last `:` itself contains a colon, the whole string is treated as the host
///   and no port is taken from it.
///
/// An empty port (`host:`) yields the default port. A port that does not parse as
/// `u16` leaves the whole input as the host with the default port.
pub fn parse_host_port(target: &str) -> (String, u16) {
    const DEFAULT_PORT: u16 = 443;

    if let Some(bracketed) = target.strip_prefix('[') {
        if let Some((ip6, rest)) = bracketed.split_once(']') {
            let port = rest
                .strip_prefix(':')
                .and_then(|p| p.parse().ok())
                .unwrap_or(DEFAULT_PORT);
            return (ip6.to_string(), port);
        }
    }

    if let Some((host, port_str)) = target.rsplit_once(':') {
        // More than one colon without brackets can only be an IPv6 literal with no
        // port; splitting it would mangle the address.
        if host.contains(':') {
            return (target.to_string(), DEFAULT_PORT);
        }
        if port_str.is_empty() {
            return (host.to_string(), DEFAULT_PORT);
        }
        if let Ok(port) = port_str.parse::<u16>() {
            return (host.to_string(), port);
        }
    }

    (target.to_string(), DEFAULT_PORT)
}
528
529fn strip_hop_by_hop_headers(headers: &mut hyper::HeaderMap) {
530 if let Some(conn_val) = headers.get("connection").cloned() {
532 if let Ok(val) = conn_val.to_str() {
533 for name in val.split(',') {
534 let name = name.trim();
535 if !name.is_empty() {
536 headers.remove(name);
537 }
538 }
539 }
540 }
541
542 for name in HOP_BY_HOP_HEADERS {
543 headers.remove(*name);
544 }
545}
546
547fn empty_body() -> BoxBody {
548 Empty::<Bytes>::new()
549 .map_err(|never| match never {})
550 .boxed()
551}
552
553fn bad_request(msg: &str) -> Response<BoxBody> {
554 Response::builder()
555 .status(400)
556 .body(full_body(msg))
557 .unwrap()
558}
559
560fn bad_gateway(msg: &str) -> Response<BoxBody> {
561 Response::builder()
562 .status(502)
563 .body(full_body(msg))
564 .unwrap()
565}
566
567fn interceptor_dropped_response() -> Response<BoxBody> {
572 Response::builder()
573 .status(444)
574 .header("Connection", "close")
575 .header("X-RustGate-Interceptor", "response-dropped")
576 .body(full_body(
577 "Response dropped by interceptor. The upstream request was already executed. Do not retry.",
578 ))
579 .unwrap()
580}
581
582fn full_body(msg: &str) -> BoxBody {
583 Full::new(Bytes::from(msg.to_string()))
584 .map_err(|never| match never {})
585 .boxed()
586}