Skip to main content

rustgate/
proxy.rs

1use crate::cert::CertificateAuthority;
2use crate::error::ProxyError;
3use crate::handler::{boxed_body, BoxBody, RequestHandler};
4use crate::tls;
5use bytes::Bytes;
6use http_body_util::{BodyExt, Empty, Full};
7use hyper::client::conn::http1 as client_http1;
8use hyper::server::conn::http1 as server_http1;
9use hyper::service::service_fn;
10use hyper::upgrade::Upgraded;
11use hyper::{Method, Request, Response};
12use hyper_util::rt::TokioIo;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tokio::net::TcpStream;
16use tracing::{debug, error, info, warn};
17
/// Shared state passed to each connection handler.
///
/// One instance is wrapped in an `Arc` and cloned into every connection and
/// request task.
pub struct ProxyState {
    /// Certificate authority handed to `tls::make_tls_acceptor` to build the
    /// per-host TLS acceptor used when intercepting CONNECT tunnels.
    pub ca: Arc<CertificateAuthority>,
    /// When `true`, CONNECT requests are intercepted (TLS terminated on both
    /// sides); when `false`, they are relayed as opaque byte tunnels.
    pub mitm: bool,
    /// Hook that may inspect/modify every forwarded request and every
    /// upstream response (plain HTTP and MITM-decrypted traffic alike).
    pub handler: Arc<dyn RequestHandler>,
}
24
25/// Handle a single accepted TCP connection.
26pub async fn handle_connection(
27    stream: TcpStream,
28    addr: SocketAddr,
29    state: Arc<ProxyState>,
30) {
31    debug!("New connection from {addr}");
32
33    let io = TokioIo::new(stream);
34    let state = state.clone();
35
36    let service = service_fn(move |req: Request<hyper::body::Incoming>| {
37        let state = state.clone();
38        async move { handle_request(req, state).await }
39    });
40
41    if let Err(e) = server_http1::Builder::new()
42        .preserve_header_case(true)
43        .title_case_headers(true)
44        .serve_connection(io, service)
45        .with_upgrades()
46        .await
47    {
48        if !e.to_string().contains("early eof")
49            && !e.to_string().contains("connection closed")
50        {
51            error!("Connection error from {addr}: {e}");
52        }
53    }
54}
55
56/// Route a request: CONNECT goes to tunnel/MITM, everything else gets forwarded.
57async fn handle_request(
58    req: Request<hyper::body::Incoming>,
59    state: Arc<ProxyState>,
60) -> Result<Response<BoxBody>, hyper::Error> {
61    if req.method() == Method::CONNECT {
62        handle_connect(req, state).await
63    } else {
64        handle_forward(req, state).await
65    }
66}
67
68// ─── HTTP Forwarding ───────────────────────────────────────────────────────────
69
70/// Forward a plain HTTP request to the upstream server.
71async fn handle_forward(
72    req: Request<hyper::body::Incoming>,
73    state: Arc<ProxyState>,
74) -> Result<Response<BoxBody>, hyper::Error> {
75    let uri = req.uri().clone();
76    let host = match uri.host() {
77        Some(h) => h.to_string(),
78        None => {
79            warn!("Request with no host: {uri}");
80            return Ok(bad_request("Missing host in URI"));
81        }
82    };
83    let port = uri.port_u16().unwrap_or(80);
84    let addr = format!("{host}:{port}");
85
86    // Build the request to forward (path-only URI, strip hop-by-hop headers)
87    let (mut parts, body) = req.into_parts();
88    let path = parts
89        .uri
90        .path_and_query()
91        .map(|pq| pq.as_str())
92        .unwrap_or("/");
93    parts.uri = match path.parse() {
94        Ok(uri) => uri,
95        Err(_) => {
96            warn!("Invalid path: {path}");
97            return Ok(bad_request("Invalid request URI"));
98        }
99    };
100    strip_hop_by_hop_headers(&mut parts.headers);
101
102    let mut forwarded_req = Request::from_parts(parts, boxed_body(body));
103
104    // Let the handler inspect/modify the request
105    state.handler.handle_request(&mut forwarded_req);
106
107    // Connect to upstream
108    let upstream = match TcpStream::connect(&addr).await {
109        Ok(s) => s,
110        Err(e) => {
111            error!("Failed to connect to {addr}: {e}");
112            return Ok(bad_gateway(&format!("Failed to connect to {addr}")));
113        }
114    };
115
116    let io = TokioIo::new(upstream);
117    let (mut sender, conn) = match client_http1::handshake(io).await {
118        Ok(r) => r,
119        Err(e) => {
120            error!("Handshake with {addr} failed: {e}");
121            return Ok(bad_gateway("Upstream handshake failed"));
122        }
123    };
124
125    tokio::spawn(async move {
126        if let Err(e) = conn.await {
127            error!("Upstream connection error: {e}");
128        }
129    });
130
131    match sender.send_request(forwarded_req).await {
132        Ok(res) => {
133            let (parts, body) = res.into_parts();
134            let mut response = Response::from_parts(parts, boxed_body(body));
135            state.handler.handle_response(&mut response);
136            Ok(response)
137        }
138        Err(e) => {
139            error!("Upstream request failed: {e}");
140            Ok(bad_gateway("Upstream request failed"))
141        }
142    }
143}
144
145// ─── CONNECT Handling ──────────────────────────────────────────────────────────
146
147/// Handle a CONNECT request: either tunnel (passthrough) or MITM.
148async fn handle_connect(
149    req: Request<hyper::body::Incoming>,
150    state: Arc<ProxyState>,
151) -> Result<Response<BoxBody>, hyper::Error> {
152    let target = match req.uri().authority() {
153        Some(auth) => auth.to_string(),
154        None => {
155            warn!("CONNECT without authority");
156            return Ok(bad_request("CONNECT target missing"));
157        }
158    };
159
160    let (host, port) = parse_host_port(&target);
161    let addr = format!("{host}:{port}");
162
163    info!("CONNECT {target}");
164
165    if state.mitm {
166        // MITM mode: intercept the TLS connection
167        handle_mitm(req, host, addr, state).await
168    } else {
169        // Passthrough mode: just tunnel bytes
170        handle_tunnel(req, addr).await
171    }
172}
173
174/// Passthrough tunneling: bidirectional copy between client and upstream.
175async fn handle_tunnel(
176    req: Request<hyper::body::Incoming>,
177    addr: String,
178) -> Result<Response<BoxBody>, hyper::Error> {
179    tokio::spawn(async move {
180        match hyper::upgrade::on(req).await {
181            Ok(upgraded) => {
182                if let Err(e) = tunnel_bidirectional(upgraded, &addr).await {
183                    error!("Tunnel error to {addr}: {e}");
184                }
185            }
186            Err(e) => {
187                error!("Upgrade failed: {e}");
188            }
189        }
190    });
191
192    // Respond with 200 to tell the client the tunnel is established
193    Ok(Response::new(empty_body()))
194}
195
196/// Copy data bidirectionally between the upgraded client connection and upstream.
197async fn tunnel_bidirectional(
198    upgraded: Upgraded,
199    addr: &str,
200) -> crate::error::Result<()> {
201    let mut upstream = TcpStream::connect(addr).await?;
202
203    let mut client = TokioIo::new(upgraded);
204
205    let (client_to_server, server_to_client) =
206        tokio::io::copy_bidirectional(&mut client, &mut upstream).await?;
207
208    debug!(
209        "Tunnel closed: {addr} (client→server: {client_to_server}B, server→client: {server_to_client}B)"
210    );
211    Ok(())
212}
213
214/// MITM mode: terminate TLS with both ends, intercept HTTP traffic.
215async fn handle_mitm(
216    req: Request<hyper::body::Incoming>,
217    host: String,
218    addr: String,
219    state: Arc<ProxyState>,
220) -> Result<Response<BoxBody>, hyper::Error> {
221    let state = state.clone();
222
223    tokio::spawn(async move {
224        match hyper::upgrade::on(req).await {
225            Ok(upgraded) => {
226                if let Err(e) =
227                    mitm_intercept(upgraded, &host, &addr, state).await
228                {
229                    error!("MITM error for {host}: {e}");
230                }
231            }
232            Err(e) => {
233                error!("MITM upgrade failed: {e}");
234            }
235        }
236    });
237
238    Ok(Response::new(empty_body()))
239}
240
241/// Perform MITM interception on an upgraded connection.
242async fn mitm_intercept(
243    upgraded: Upgraded,
244    host: &str,
245    addr: &str,
246    state: Arc<ProxyState>,
247) -> crate::error::Result<()> {
248    // Create a TLS acceptor with a fake cert for this domain
249    let acceptor = tls::make_tls_acceptor(&state.ca, host).await?;
250
251    // Accept TLS from the client side
252    let client_io = TokioIo::new(upgraded);
253    let client_tls = acceptor
254        .accept(client_io)
255        .await
256        .map_err(|e| ProxyError::Other(format!("Client TLS accept failed: {e}")))?;
257
258    let client_tls = TokioIo::new(client_tls);
259
260    // Serve HTTP on the decrypted client stream
261    let host = host.to_string();
262    let addr = addr.to_string();
263
264    let service = service_fn(move |req: Request<hyper::body::Incoming>| {
265        let host = host.clone();
266        let addr = addr.clone();
267        let state = state.clone();
268        async move {
269            mitm_forward_request(req, &host, &addr, state).await
270        }
271    });
272
273    if let Err(e) = server_http1::Builder::new()
274        .preserve_header_case(true)
275        .title_case_headers(true)
276        .serve_connection(client_tls, service)
277        .await
278    {
279        if !e.to_string().contains("early eof")
280            && !e.to_string().contains("connection closed")
281        {
282            debug!("MITM connection closed: {e}");
283        }
284    }
285
286    Ok(())
287}
288
289/// Forward a request from the MITM-decrypted stream to the real upstream over TLS.
290async fn mitm_forward_request(
291    req: Request<hyper::body::Incoming>,
292    host: &str,
293    addr: &str,
294    state: Arc<ProxyState>,
295) -> Result<Response<BoxBody>, hyper::Error> {
296    let (mut parts, body) = req.into_parts();
297    strip_hop_by_hop_headers(&mut parts.headers);
298
299    let mut forwarded_req = Request::from_parts(parts, boxed_body(body));
300
301    // Let the handler inspect/modify
302    state.handler.handle_request(&mut forwarded_req);
303
304    // Connect to upstream over TLS
305    let upstream_tls = match tls::connect_tls_upstream(host, addr).await {
306        Ok(s) => s,
307        Err(e) => {
308            error!("Failed TLS connect to {addr}: {e}");
309            return Ok(bad_gateway(&format!(
310                "Failed to connect to upstream: {e}"
311            )));
312        }
313    };
314
315    let io = TokioIo::new(upstream_tls);
316    let (mut sender, conn) = match client_http1::handshake(io).await {
317        Ok(r) => r,
318        Err(e) => {
319            error!("Upstream TLS handshake failed: {e}");
320            return Ok(bad_gateway("Upstream TLS handshake failed"));
321        }
322    };
323
324    tokio::spawn(async move {
325        if let Err(e) = conn.await {
326            debug!("Upstream TLS connection closed: {e}");
327        }
328    });
329
330    match sender.send_request(forwarded_req).await {
331        Ok(res) => {
332            let (parts, body) = res.into_parts();
333            let mut response = Response::from_parts(parts, boxed_body(body));
334            state.handler.handle_response(&mut response);
335            Ok(response)
336        }
337        Err(e) => {
338            error!("Upstream TLS request failed: {e}");
339            Ok(bad_gateway("Upstream request failed"))
340        }
341    }
342}
343
344// ─── Helpers ───────────────────────────────────────────────────────────────────
345
/// Hop-by-hop headers that must not be forwarded (RFC 9110 §7.6.1).
///
/// All names are lowercase.
/// - `trailer` is the real field name; RFC 2616's list spelled it "Trailers",
///   which is a known erratum.
/// - `trailers` is kept so that anything stripped before is still stripped.
/// - `proxy-connection` is non-standard, but clients still send it to
///   proxies, and RFC 9110 §7.6.1 names it for removal.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "proxy-connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
357
/// Parse host and port from a CONNECT target, handling IPv6 bracket notation.
///
/// Examples:
/// - `"example.com:8443"` → `("example.com", 8443)`
/// - `"[::1]:443"` → `("::1", 443)` (brackets removed)
/// - `"example.com"` → `("example.com", 443)`
///
/// Port 443 is used whenever the port is missing, empty, or not a valid
/// `u16`. An unbracketed target containing more than one colon
/// (e.g. `"2001:db8::1"`) is treated as a bare IPv6 literal with the default
/// port, rather than being split at its last colon.
pub fn parse_host_port(target: &str) -> (String, u16) {
    const DEFAULT_PORT: u16 = 443;

    if let Some(bracketed) = target.strip_prefix('[') {
        // IPv6: [::1]:port
        if let Some((ip6, rest)) = bracketed.split_once(']') {
            let port = rest
                .strip_prefix(':')
                .and_then(|p| p.parse().ok())
                .unwrap_or(DEFAULT_PORT);
            return (ip6.to_string(), port);
        }
    }

    // IPv4 / hostname: host:port
    if let Some((host, port_str)) = target.rsplit_once(':') {
        // Another colon before the last one can only mean an unbracketed IPv6
        // literal; splitting it would produce a bogus host and port.
        if host.contains(':') {
            return (target.to_string(), DEFAULT_PORT);
        }
        // "host:" with an empty port: keep the host, use the default port.
        if port_str.is_empty() {
            return (host.to_string(), DEFAULT_PORT);
        }
        if let Ok(port) = port_str.parse::<u16>() {
            return (host.to_string(), port);
        }
    }
    (target.to_string(), DEFAULT_PORT)
}
379
380fn strip_hop_by_hop_headers(headers: &mut hyper::HeaderMap) {
381    // Also remove headers listed in the Connection header value
382    if let Some(conn_val) = headers.get("connection").cloned() {
383        if let Ok(val) = conn_val.to_str() {
384            for name in val.split(',') {
385                let name = name.trim();
386                if !name.is_empty() {
387                    headers.remove(name);
388                }
389            }
390        }
391    }
392
393    for name in HOP_BY_HOP_HEADERS {
394        headers.remove(*name);
395    }
396}
397
398fn empty_body() -> BoxBody {
399    Empty::<Bytes>::new()
400        .map_err(|never| match never {})
401        .boxed()
402}
403
404fn bad_request(msg: &str) -> Response<BoxBody> {
405    Response::builder()
406        .status(400)
407        .body(full_body(msg))
408        .unwrap()
409}
410
411fn bad_gateway(msg: &str) -> Response<BoxBody> {
412    Response::builder()
413        .status(502)
414        .body(full_body(msg))
415        .unwrap()
416}
417
418fn full_body(msg: &str) -> BoxBody {
419    Full::new(Bytes::from(msg.to_string()))
420        .map_err(|never| match never {})
421        .boxed()
422}