Skip to main content

relay_core_lib/proxy/
http.rs

1use std::convert::Infallible;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use tokio::sync::{mpsc::Sender, watch};
5
6use crate::capture::loop_detection::LoopDetector;
7use crate::interceptor::{
8    BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
9};
10use crate::proxy::http_utils::{
11    HttpsClient, build_forward_request, create_error_response, create_initial_flow,
12    mock_to_response, parse_request_meta, update_flow_with_response_headers,
13};
14use crate::proxy::tap::TapBody;
15use crate::proxy::tunnel;
16use crate::proxy::websocket::handle_websocket_handshake;
17use crate::tls::CertificateAuthority;
18use http_body_util::{BodyExt, Full};
19use hyper::body::{Body, Bytes, Incoming};
20use hyper::{Method, Request, Response, StatusCode};
21use relay_core_api::flow::{Direction, FlowUpdate, Layer};
22use relay_core_api::policy::ProxyPolicy;
23
24/// Main entry point for HTTP Proxy handling
25#[allow(clippy::too_many_arguments)]
26pub async fn handle_request(
27    req: Request<Incoming>,
28    client_addr: SocketAddr,
29    on_flow: Sender<FlowUpdate>,
30    ca: Arc<CertificateAuthority>,
31    client: Arc<HttpsClient>,
32    interceptor: Arc<dyn Interceptor>,
33    target_addr: Option<SocketAddr>,
34    policy_rx: watch::Receiver<ProxyPolicy>,
35    loop_detector: Arc<LoopDetector>,
36) -> Result<Response<HttpBody>, Infallible> {
37    if req.method() == Method::CONNECT {
38        // Handle CONNECT (HTTPS Tunnel)
39        // Extract host from authority
40        let host = if let Some(authority) = req.uri().authority() {
41            authority.to_string()
42        } else {
43            // Fallback: try to get from Host header
44            req.headers()
45                .get("Host")
46                .and_then(|v| v.to_str().ok())
47                .map(|s| s.to_string())
48                .unwrap_or_else(|| "unknown".to_string())
49        };
50
51        if host == "unknown" {
52            return Ok(create_error_response(
53                StatusCode::BAD_REQUEST,
54                "CONNECT must have authority",
55            ));
56        }
57
58        let loop_detector = loop_detector.clone();
59        let policy_rx = policy_rx.clone();
60
61        tokio::task::spawn(async move {
62            match hyper::upgrade::on(req).await {
63                Ok(upgraded) => {
64                    if let Err(e) = tunnel::handle_tunnel(
65                        upgraded,
66                        host,
67                        client_addr,
68                        ca,
69                        on_flow,
70                        client,
71                        interceptor,
72                        policy_rx,
73                        target_addr,
74                        loop_detector,
75                    )
76                    .await
77                    {
78                        tracing::error!("Tunnel error: {}", e);
79                    }
80                }
81                Err(e) => tracing::error!("Upgrade error: {}", e),
82            }
83        });
84        return Ok(Response::new(
85            Full::new(Bytes::new()).map_err(|e| e.into()).boxed(),
86        ));
87    }
88
89    // Handle Standard HTTP / WebSocket
90    handle_http_request(
91        req,
92        client_addr,
93        on_flow,
94        client,
95        interceptor,
96        false,
97        policy_rx,
98        target_addr,
99        loop_detector,
100    )
101    .await
102}
103
104#[allow(clippy::too_many_arguments)]
105pub(crate) async fn handle_http_request<B>(
106    req: Request<B>,
107    client_addr: SocketAddr,
108    on_flow: Sender<FlowUpdate>,
109    client: Arc<HttpsClient>,
110    interceptor: Arc<dyn Interceptor>,
111    is_mitm: bool,
112    policy_rx: watch::Receiver<ProxyPolicy>,
113    target_addr: Option<SocketAddr>,
114    loop_detector: Arc<LoopDetector>,
115) -> Result<Response<HttpBody>, Infallible>
116where
117    B: Body + Send + Sync + Unpin + 'static,
118    B::Data: Send + Into<Bytes>,
119    B::Error: Into<BoxError>,
120{
121    let policy = policy_rx.borrow().clone();
122
123    // Check Content-Length against policy
124    if let Some(cl) = req.headers().get(hyper::header::CONTENT_LENGTH)
125        && let Ok(len) = cl.to_str().unwrap_or_default().parse::<usize>()
126        && len > policy.max_body_size
127    {
128        return Ok(create_error_response(
129            StatusCode::PAYLOAD_TOO_LARGE,
130            "Request body too large",
131        ));
132    }
133
134    // Create Flow
135    let meta = parse_request_meta(&req, is_mitm);
136
137    // Note: We don't read body here for streaming support
138    let mut flow = create_initial_flow(meta, None, client_addr, is_mitm, false);
139
140    // Check for WebSocket
141    if hyper_tungstenite::is_upgrade_request(&req) {
142        return handle_websocket_handshake(
143            req,
144            client_addr,
145            on_flow,
146            client,
147            interceptor,
148            is_mitm,
149            policy_rx,
150            target_addr,
151            loop_detector,
152        )
153        .await;
154    }
155
156    if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
157        tracing::error!("Failed to send flow update: {}", e);
158    }
159
160    // Phase 1: Request Headers Interception
161    match interceptor.on_request_headers(&mut flow).await {
162        InterceptionResult::Continue => {}
163        InterceptionResult::Drop => {
164            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
165                tracing::error!("Failed to send flow update on drop: {}", e);
166            }
167            return Ok(create_error_response(
168                StatusCode::FORBIDDEN,
169                "Request dropped by policy",
170            ));
171        }
172        InterceptionResult::MockResponse(resp) => {
173            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
174                tracing::error!("Failed to send flow update on mock: {}", e);
175            }
176            return Ok(mock_to_response(resp));
177        }
178        InterceptionResult::ModifiedRequest(_) => {}
179        InterceptionResult::ModifiedResponse(res) => {
180            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
181                tracing::error!("Failed to send flow update on modified response: {}", e);
182            }
183            return Ok(mock_to_response(res));
184        }
185        _ => {}
186    }
187
188    // Phase 2: Request Body Streaming & Interception
189    let (_, body) = req.into_parts();
190    let body: HttpBody = body
191        .map_frame(|f| f.map_data(|d| d.into()))
192        .map_err(|e| e.into())
193        .boxed();
194
195    // Wrap in TapBody for streaming visualization BEFORE interception
196    let req_headers = if let Layer::Http(http) = &flow.layer {
197        http.request.headers.clone()
198    } else {
199        vec![]
200    };
201
202    let tap_body = TapBody::new(
203        body,
204        flow.id.to_string(),
205        on_flow.clone(),
206        Direction::ClientToServer,
207        policy.max_body_size,
208        req_headers,
209    );
210    let mut current_body = tap_body.boxed();
211
212    match interceptor.on_request(&mut flow, current_body).await {
213        Ok(RequestAction::Continue(new_body)) => {
214            current_body = new_body;
215        }
216        Ok(RequestAction::Drop) => {
217            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
218                tracing::error!("Failed to send flow update on request drop: {}", e);
219            }
220            return Ok(create_error_response(
221                StatusCode::FORBIDDEN,
222                "Request dropped by interceptor",
223            ));
224        }
225        Ok(RequestAction::MockResponse(res)) => {
226            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
227                tracing::error!("Failed to send flow update on request mock: {}", e);
228            }
229            let (parts, body) = res.into_parts();
230            return Ok(Response::from_parts(parts, body));
231        }
232        Err(e) => {
233            tracing::error!("Interceptor error on_request: {}", e);
234            return Ok(create_error_response(
235                StatusCode::INTERNAL_SERVER_ERROR,
236                format!("Interceptor Error: {}", e),
237            ));
238        }
239    }
240
241    let forward_req = match build_forward_request(
242        &mut flow,
243        current_body,
244        target_addr,
245        &policy,
246        &loop_detector,
247    ) {
248        Ok(req) => req,
249        Err(res) => return Ok(res),
250    };
251
252    // Send Request
253    let upstream_start = std::time::Instant::now();
254    let res = match tokio::time::timeout(
255        std::time::Duration::from_millis(policy.request_timeout_ms),
256        client.request(forward_req),
257    )
258    .await
259    {
260        Ok(Ok(res)) => res,
261        Ok(Err(e)) => {
262            tracing::error!("Upstream request failed: {}", e);
263            if let Layer::Http(http) = &mut flow.layer {
264                http.error = Some(format!("Upstream Error: {}", e));
265            }
266            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
267                tracing::error!("Failed to send flow update on upstream error: {}", e);
268            }
269            return Ok(create_error_response(
270                StatusCode::BAD_GATEWAY,
271                format!("Upstream Error: {}", e),
272            ));
273        }
274        Err(_) => {
275            tracing::error!("Upstream request timed out");
276            if let Layer::Http(http) = &mut flow.layer {
277                http.error = Some("Upstream Request Timed Out".to_string());
278            }
279            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
280                tracing::error!("Failed to send flow update on upstream timeout: {}", e);
281            }
282            return Ok(create_error_response(
283                StatusCode::GATEWAY_TIMEOUT,
284                "Upstream Request Timed Out",
285            ));
286        }
287    };
288
289    // Phase 3: Response Headers Interception
290    let (mut res_parts, res_body) = res.into_parts();
291
292    // Apply QUIC Downgrade
293    apply_quic_downgrade(&mut res_parts, &mut flow, &policy);
294
295    update_flow_with_response_headers(
296        &mut flow,
297        res_parts.status,
298        res_parts.version,
299        &res_parts.headers,
300    );
301
302    let ttfbs_ms = upstream_start.elapsed().as_millis() as u64;
303    if let Layer::Http(http) = &mut flow.layer
304        && let Some(response) = &mut http.response
305    {
306        response.timing.time_to_first_byte = Some(ttfbs_ms);
307    }
308
309    match interceptor.on_response_headers(&mut flow).await {
310        InterceptionResult::Continue => {}
311        InterceptionResult::Drop => {
312            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
313                tracing::error!("Failed to send flow update on response drop: {}", e);
314            }
315            return Ok(create_error_response(
316                StatusCode::FORBIDDEN,
317                "Response dropped by policy",
318            ));
319        }
320        InterceptionResult::MockResponse(resp) => {
321            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
322                tracing::error!("Failed to send flow update on response mock: {}", e);
323            }
324            return Ok(mock_to_response(resp));
325        }
326        InterceptionResult::ModifiedResponse(resp) => {
327            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
328                tracing::error!("Failed to send flow update on response modification: {}", e);
329            }
330            return Ok(mock_to_response(resp));
331        }
332        _ => {}
333    }
334
335    // Phase 4: Response Body Streaming & Interception
336    let res_body: HttpBody = res_body
337        .map_frame(|f| f.map_data(|d| d))
338        .map_err(|e| e.into())
339        .boxed();
340
341    // Wrap in TapBody for streaming visualization BEFORE interception
342    let res_headers = if let Layer::Http(http) = &flow.layer {
343        http.response
344            .as_ref()
345            .map(|r| r.headers.clone())
346            .unwrap_or_default()
347    } else {
348        vec![]
349    };
350
351    let tap_res_body = TapBody::new(
352        res_body,
353        flow.id.to_string(),
354        on_flow.clone(),
355        Direction::ServerToClient,
356        policy.max_body_size,
357        res_headers,
358    );
359    let mut current_res_body = tap_res_body.boxed();
360
361    match interceptor.on_response(&mut flow, current_res_body).await {
362        Ok(ResponseAction::Continue(new_body)) => {
363            current_res_body = new_body;
364        }
365        Ok(ResponseAction::Drop) => {
366            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
367                tracing::error!("Failed to send flow update on response body drop: {}", e);
368            }
369            return Ok(create_error_response(
370                StatusCode::FORBIDDEN,
371                "Response dropped by interceptor",
372            ));
373        }
374        Ok(ResponseAction::ModifiedResponse(res)) => {
375            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
376                tracing::error!(
377                    "Failed to send flow update on response body modification: {}",
378                    e
379                );
380            }
381            let (parts, body) = res.into_parts();
382            return Ok(Response::from_parts(parts, body));
383        }
384        Err(e) => {
385            tracing::error!("Interceptor error on_response: {}", e);
386            return Ok(create_error_response(
387                StatusCode::INTERNAL_SERVER_ERROR,
388                format!("Interceptor Error: {}", e),
389            ));
390        }
391    }
392
393    if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
394        tracing::error!("Failed to send final flow update: {}", e);
395    }
396
397    // Record time-to-last-byte as total upstream-to-client latency
398    if let Layer::Http(http) = &mut flow.layer
399        && let Some(response) = &mut http.response
400    {
401        response.timing.time_to_last_byte = Some(upstream_start.elapsed().as_millis() as u64);
402    }
403
404    Ok(Response::from_parts(res_parts, current_res_body))
405}
406
407pub(crate) fn apply_quic_downgrade(
408    parts: &mut hyper::http::response::Parts,
409    flow: &mut relay_core_api::flow::Flow,
410    policy: &ProxyPolicy,
411) {
412    use relay_core_api::policy::QuicMode;
413    if policy.quic_mode == QuicMode::Downgrade {
414        if parts.headers.remove("Alt-Svc").is_some() {
415            flow.tags.push("quic-downgraded".to_string());
416        }
417        if policy.quic_downgrade_clear_cache {
418            parts.headers.insert(
419                "Clear-Site-Data",
420                hyper::header::HeaderValue::from_static("\"cache\""),
421            );
422        }
423    }
424}
425
426#[cfg(test)]
427mod http_tests;