Skip to main content

rustgate/
proxy.rs

1use crate::cert::CertificateAuthority;
2use crate::error::ProxyError;
3use crate::handler::{boxed_body, full_boxed_body, Buffered, BoxBody, Dropped, RequestHandler};
4use crate::tls;
5use bytes::Bytes;
6use http_body_util::{BodyExt, Empty, Full};
7use hyper::client::conn::http1 as client_http1;
8use hyper::server::conn::http1 as server_http1;
9use hyper::service::service_fn;
10use hyper::upgrade::Upgraded;
11use hyper::{Method, Request, Response};
12use hyper_util::rt::TokioIo;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tokio::net::TcpStream;
16use tracing::{debug, error, info, warn};
17
/// Largest body, in bytes, that will be buffered in memory for interception
/// (10 MiB). Anything bigger is forwarded as a stream instead.
const MAX_INTERCEPT_BODY: usize = 10 << 20;
20
21/// Check if a body should be intercepted based on Content-Length header.
22/// Returns true ONLY if Content-Length is explicitly present and within the limit.
23/// All other cases (chunked, close-delimited, unknown-length) skip interception
24/// to avoid consuming streaming bodies.
25fn should_intercept_body(headers: &hyper::HeaderMap) -> bool {
26    if let Some(cl) = headers.get(hyper::header::CONTENT_LENGTH) {
27        if let Ok(s) = cl.to_str() {
28            if let Ok(len) = s.parse::<usize>() {
29                return len <= MAX_INTERCEPT_BODY;
30            }
31        }
32    }
33    false
34}
35
36/// Collect a body into Bytes. Returns None on failure or size exceeded.
37async fn try_collect_body<B>(body: B) -> Option<Bytes>
38where
39    B: hyper::body::Body<Data = Bytes, Error = hyper::Error>,
40{
41    use http_body_util::Limited;
42    let limited = Limited::new(body, MAX_INTERCEPT_BODY);
43    BodyExt::collect(limited)
44        .await
45        .ok()
46        .map(|c| c.to_bytes())
47}
48
49/// Shared state passed to each connection handler.
50pub struct ProxyState {
51    pub ca: Arc<CertificateAuthority>,
52    pub mitm: bool,
53    pub intercept: bool,
54    pub handler: Arc<dyn RequestHandler>,
55}
56
57/// Handle a single accepted TCP connection.
58pub async fn handle_connection(
59    stream: TcpStream,
60    addr: SocketAddr,
61    state: Arc<ProxyState>,
62) {
63    debug!("New connection from {addr}");
64
65    let io = TokioIo::new(stream);
66    let state = state.clone();
67
68    let service = service_fn(move |req: Request<hyper::body::Incoming>| {
69        let state = state.clone();
70        async move { handle_request(req, state).await }
71    });
72
73    if let Err(e) = server_http1::Builder::new()
74        .preserve_header_case(true)
75        .title_case_headers(true)
76        .serve_connection(io, service)
77        .with_upgrades()
78        .await
79    {
80        if !e.to_string().contains("early eof")
81            && !e.to_string().contains("connection closed")
82        {
83            error!("Connection error from {addr}: {e}");
84        }
85    }
86}
87
88/// Route a request: CONNECT goes to tunnel/MITM, everything else gets forwarded.
89async fn handle_request(
90    req: Request<hyper::body::Incoming>,
91    state: Arc<ProxyState>,
92) -> Result<Response<BoxBody>, hyper::Error> {
93    if req.method() == Method::CONNECT {
94        handle_connect(req, state).await
95    } else {
96        handle_forward(req, state).await
97    }
98}
99
100// ─── HTTP Forwarding ───────────────────────────────────────────────────────────
101
102/// Forward a plain HTTP request to the upstream server.
103async fn handle_forward(
104    req: Request<hyper::body::Incoming>,
105    state: Arc<ProxyState>,
106) -> Result<Response<BoxBody>, hyper::Error> {
107    let uri = req.uri().clone();
108    let host = match uri.host() {
109        Some(h) => h.to_string(),
110        None => {
111            warn!("Request with no host: {uri}");
112            return Ok(bad_request("Missing host in URI"));
113        }
114    };
115    let port = uri.port_u16().unwrap_or(80);
116    let addr = format!("{host}:{port}");
117
118    // Build the request to forward
119    let (mut parts, body) = req.into_parts();
120    let path = parts
121        .uri
122        .path_and_query()
123        .map(|pq| pq.as_str())
124        .unwrap_or("/");
125    parts.uri = match path.parse() {
126        Ok(uri) => uri,
127        Err(_) => {
128            warn!("Invalid path: {path}");
129            return Ok(bad_request("Invalid request URI"));
130        }
131    };
132
133    // Check intercept eligibility BEFORE stripping hop-by-hop headers
134    // so Transfer-Encoding: chunked is still visible for the decision.
135    let do_intercept = state.intercept && should_intercept_body(&parts.headers);
136
137    strip_hop_by_hop_headers(&mut parts.headers);
138
139    let mut forwarded_req = if do_intercept {
140        match try_collect_body(body).await {
141            Some(bytes) => {
142                let mut req = Request::from_parts(parts, full_boxed_body(bytes));
143                req.extensions_mut().insert(Buffered);
144                req
145            }
146            None => {
147                // Collection failed (read error on a supposedly small body).
148                // Body is consumed; log and return error.
149                error!("Request body collection failed despite acceptable Content-Length");
150                return Ok(bad_gateway("Request body read error"));
151            }
152        }
153    } else {
154        Request::from_parts(parts, boxed_body(body))
155    };
156
157    state.handler.handle_request(&mut forwarded_req);
158
159    if forwarded_req.extensions().get::<Dropped>().is_some() {
160        return Ok(bad_gateway("Request dropped by interceptor"));
161    }
162
163    // Connect to upstream
164    let upstream = match TcpStream::connect(&addr).await {
165        Ok(s) => s,
166        Err(e) => {
167            error!("Failed to connect to {addr}: {e}");
168            return Ok(bad_gateway(&format!("Failed to connect to {addr}")));
169        }
170    };
171
172    let io = TokioIo::new(upstream);
173    let (mut sender, conn) = match client_http1::handshake(io).await {
174        Ok(r) => r,
175        Err(e) => {
176            error!("Handshake with {addr} failed: {e}");
177            return Ok(bad_gateway("Upstream handshake failed"));
178        }
179    };
180
181    tokio::spawn(async move {
182        if let Err(e) = conn.await {
183            error!("Upstream connection error: {e}");
184        }
185    });
186
187    match sender.send_request(forwarded_req).await {
188        Ok(res) => {
189            let (parts, body) = res.into_parts();
190            let mut response = if state.intercept && should_intercept_body(&parts.headers) {
191                match try_collect_body(body).await {
192                    Some(bytes) => {
193                        let mut res = Response::from_parts(parts, full_boxed_body(bytes));
194                        res.extensions_mut().insert(Buffered);
195                        res
196                    }
197                    None => {
198                        error!("Response body collection failed");
199                        return Ok(bad_gateway("Response body collection failed"));
200                    }
201                }
202            } else {
203                Response::from_parts(parts, boxed_body(body))
204            };
205            state.handler.handle_response(&mut response);
206            if response.extensions().get::<Dropped>().is_some() {
207                return Ok(interceptor_dropped_response());
208            }
209            Ok(response)
210        }
211        Err(e) => {
212            error!("Upstream request failed: {e}");
213            Ok(bad_gateway("Upstream request failed"))
214        }
215    }
216}
217
218// ─── CONNECT Handling ──────────────────────────────────────────────────────────
219
220/// Handle a CONNECT request: either tunnel (passthrough) or MITM.
221async fn handle_connect(
222    req: Request<hyper::body::Incoming>,
223    state: Arc<ProxyState>,
224) -> Result<Response<BoxBody>, hyper::Error> {
225    let target = match req.uri().authority() {
226        Some(auth) => auth.to_string(),
227        None => {
228            warn!("CONNECT without authority");
229            return Ok(bad_request("CONNECT target missing"));
230        }
231    };
232
233    let (host, port) = parse_host_port(&target);
234    let addr = format!("{host}:{port}");
235
236    info!("CONNECT {target}");
237
238    if state.mitm {
239        // MITM mode: intercept the TLS connection
240        handle_mitm(req, host, addr, state).await
241    } else {
242        // Passthrough mode: just tunnel bytes
243        handle_tunnel(req, addr).await
244    }
245}
246
247/// Passthrough tunneling: bidirectional copy between client and upstream.
248async fn handle_tunnel(
249    req: Request<hyper::body::Incoming>,
250    addr: String,
251) -> Result<Response<BoxBody>, hyper::Error> {
252    tokio::spawn(async move {
253        match hyper::upgrade::on(req).await {
254            Ok(upgraded) => {
255                if let Err(e) = tunnel_bidirectional(upgraded, &addr).await {
256                    error!("Tunnel error to {addr}: {e}");
257                }
258            }
259            Err(e) => {
260                error!("Upgrade failed: {e}");
261            }
262        }
263    });
264
265    // Respond with 200 to tell the client the tunnel is established
266    Ok(Response::new(empty_body()))
267}
268
269/// Copy data bidirectionally between the upgraded client connection and upstream.
270async fn tunnel_bidirectional(
271    upgraded: Upgraded,
272    addr: &str,
273) -> crate::error::Result<()> {
274    let mut upstream = TcpStream::connect(addr).await?;
275
276    let mut client = TokioIo::new(upgraded);
277
278    let (client_to_server, server_to_client) =
279        tokio::io::copy_bidirectional(&mut client, &mut upstream).await?;
280
281    debug!(
282        "Tunnel closed: {addr} (client→server: {client_to_server}B, server→client: {server_to_client}B)"
283    );
284    Ok(())
285}
286
287/// MITM mode: terminate TLS with both ends, intercept HTTP traffic.
288async fn handle_mitm(
289    req: Request<hyper::body::Incoming>,
290    host: String,
291    addr: String,
292    state: Arc<ProxyState>,
293) -> Result<Response<BoxBody>, hyper::Error> {
294    let state = state.clone();
295
296    tokio::spawn(async move {
297        match hyper::upgrade::on(req).await {
298            Ok(upgraded) => {
299                if let Err(e) =
300                    mitm_intercept(upgraded, &host, &addr, state).await
301                {
302                    error!("MITM error for {host}: {e}");
303                }
304            }
305            Err(e) => {
306                error!("MITM upgrade failed: {e}");
307            }
308        }
309    });
310
311    Ok(Response::new(empty_body()))
312}
313
314/// Perform MITM interception on an upgraded connection.
315async fn mitm_intercept(
316    upgraded: Upgraded,
317    host: &str,
318    addr: &str,
319    state: Arc<ProxyState>,
320) -> crate::error::Result<()> {
321    // Create a TLS acceptor with a fake cert for this domain
322    let acceptor = tls::make_tls_acceptor(&state.ca, host).await?;
323
324    // Accept TLS from the client side
325    let client_io = TokioIo::new(upgraded);
326    let client_tls = acceptor
327        .accept(client_io)
328        .await
329        .map_err(|e| ProxyError::Other(format!("Client TLS accept failed: {e}")))?;
330
331    let client_tls = TokioIo::new(client_tls);
332
333    // Serve HTTP on the decrypted client stream
334    let host = host.to_string();
335    let addr = addr.to_string();
336
337    let service = service_fn(move |req: Request<hyper::body::Incoming>| {
338        let host = host.clone();
339        let addr = addr.clone();
340        let state = state.clone();
341        async move {
342            mitm_forward_request(req, &host, &addr, state).await
343        }
344    });
345
346    if let Err(e) = server_http1::Builder::new()
347        .preserve_header_case(true)
348        .title_case_headers(true)
349        .serve_connection(client_tls, service)
350        .await
351    {
352        if !e.to_string().contains("early eof")
353            && !e.to_string().contains("connection closed")
354        {
355            debug!("MITM connection closed: {e}");
356        }
357    }
358
359    Ok(())
360}
361
362/// Forward a request from the MITM-decrypted stream to the real upstream over TLS.
363async fn mitm_forward_request(
364    req: Request<hyper::body::Incoming>,
365    host: &str,
366    addr: &str,
367    state: Arc<ProxyState>,
368) -> Result<Response<BoxBody>, hyper::Error> {
369    let (mut parts, body) = req.into_parts();
370
371    let do_intercept = state.intercept && should_intercept_body(&parts.headers);
372    strip_hop_by_hop_headers(&mut parts.headers);
373
374    let mut forwarded_req = if do_intercept {
375        match try_collect_body(body).await {
376            Some(bytes) => {
377                let mut req = Request::from_parts(parts, full_boxed_body(bytes));
378                req.extensions_mut().insert(Buffered);
379                req
380            }
381            None => {
382                error!("MITM request body collection failed");
383                return Ok(bad_gateway("Request body read error"));
384            }
385        }
386    } else {
387        Request::from_parts(parts, boxed_body(body))
388    };
389
390    state.handler.handle_request(&mut forwarded_req);
391
392    if forwarded_req.extensions().get::<Dropped>().is_some() {
393        return Ok(bad_gateway("Request dropped by interceptor"));
394    }
395
396    // Connect to upstream over TLS
397    let upstream_tls = match tls::connect_tls_upstream(host, addr).await {
398        Ok(s) => s,
399        Err(e) => {
400            error!("Failed TLS connect to {addr}: {e}");
401            return Ok(bad_gateway(&format!(
402                "Failed to connect to upstream: {e}"
403            )));
404        }
405    };
406
407    let io = TokioIo::new(upstream_tls);
408    let (mut sender, conn) = match client_http1::handshake(io).await {
409        Ok(r) => r,
410        Err(e) => {
411            error!("Upstream TLS handshake failed: {e}");
412            return Ok(bad_gateway("Upstream TLS handshake failed"));
413        }
414    };
415
416    tokio::spawn(async move {
417        if let Err(e) = conn.await {
418            debug!("Upstream TLS connection closed: {e}");
419        }
420    });
421
422    match sender.send_request(forwarded_req).await {
423        Ok(res) => {
424            let (parts, body) = res.into_parts();
425            let mut response = if state.intercept && should_intercept_body(&parts.headers) {
426                match try_collect_body(body).await {
427                    Some(bytes) => {
428                        let mut res = Response::from_parts(parts, full_boxed_body(bytes));
429                        res.extensions_mut().insert(Buffered);
430                        res
431                    }
432                    None => {
433                        error!("MITM response body collection failed");
434                        return Ok(bad_gateway("Response body collection failed"));
435                    }
436                }
437            } else {
438                Response::from_parts(parts, boxed_body(body))
439            };
440            state.handler.handle_response(&mut response);
441            if response.extensions().get::<Dropped>().is_some() {
442                return Ok(interceptor_dropped_response());
443            }
444            Ok(response)
445        }
446        Err(e) => {
447            error!("Upstream TLS request failed: {e}");
448            Ok(bad_gateway("Upstream request failed"))
449        }
450    }
451}
452
453// ─── Helpers ───────────────────────────────────────────────────────────────────
454
/// Hop-by-hop header names (lowercase) that must not be forwarded.
///
/// These are removed whether or not a `Connection` header nominates them.
/// `proxy-connection` is non-standard, but clients send it to proxies and
/// intermediaries are expected to drop it like the others.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
466
/// Split a CONNECT target into host and port.
///
/// Accepted shapes: `"example.com:443"`, `"[::1]:443"` and `"example.com"`.
/// A bracketed IPv6 literal is returned without its brackets. When the port
/// is missing or does not parse as a `u16`, port 443 is used; for
/// unbracketed targets the whole input is then returned as the host.
pub fn parse_host_port(target: &str) -> (String, u16) {
    const DEFAULT_PORT: u16 = 443;

    // Bracketed IPv6 literal: "[addr]" optionally followed by ":port".
    let bracketed = target
        .strip_prefix('[')
        .and_then(|inner| inner.split_once(']'));
    if let Some((ip6, tail)) = bracketed {
        let port = match tail.strip_prefix(':') {
            Some(digits) => digits.parse().unwrap_or(DEFAULT_PORT),
            None => DEFAULT_PORT,
        };
        return (ip6.to_owned(), port);
    }

    // Hostname or IPv4: split on the last colon, if what follows is a port.
    match target.rsplit_once(':') {
        Some((host, digits)) => match digits.parse::<u16>() {
            Ok(port) => (host.to_owned(), port),
            Err(_) => (target.to_owned(), DEFAULT_PORT),
        },
        None => (target.to_owned(), DEFAULT_PORT),
    }
}
488
489fn strip_hop_by_hop_headers(headers: &mut hyper::HeaderMap) {
490    // Also remove headers listed in the Connection header value
491    if let Some(conn_val) = headers.get("connection").cloned() {
492        if let Ok(val) = conn_val.to_str() {
493            for name in val.split(',') {
494                let name = name.trim();
495                if !name.is_empty() {
496                    headers.remove(name);
497                }
498            }
499        }
500    }
501
502    for name in HOP_BY_HOP_HEADERS {
503        headers.remove(*name);
504    }
505}
506
507fn empty_body() -> BoxBody {
508    Empty::<Bytes>::new()
509        .map_err(|never| match never {})
510        .boxed()
511}
512
513fn bad_request(msg: &str) -> Response<BoxBody> {
514    Response::builder()
515        .status(400)
516        .body(full_body(msg))
517        .unwrap()
518}
519
520fn bad_gateway(msg: &str) -> Response<BoxBody> {
521    Response::builder()
522        .status(502)
523        .body(full_body(msg))
524        .unwrap()
525}
526
527/// Non-retryable response for interceptor-dropped responses.
528/// Uses 444 (No Response, nginx convention) + Connection: close to signal
529/// that the response was locally suppressed and the client should NOT retry.
530/// The upstream request was already executed.
531fn interceptor_dropped_response() -> Response<BoxBody> {
532    Response::builder()
533        .status(444)
534        .header("Connection", "close")
535        .header("X-RustGate-Interceptor", "response-dropped")
536        .body(full_body(
537            "Response dropped by interceptor. The upstream request was already executed. Do not retry.",
538        ))
539        .unwrap()
540}
541
542fn full_body(msg: &str) -> BoxBody {
543    Full::new(Bytes::from(msg.to_string()))
544        .map_err(|never| match never {})
545        .boxed()
546}