Skip to main content

relay_core_lib/proxy/
http.rs

1use std::convert::Infallible;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use tokio::sync::{mpsc::Sender, watch};
5
6use crate::capture::loop_detection::LoopDetector;
7use crate::interceptor::{
8    BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
9};
10use crate::proxy::circuit_breaker::CircuitBreaker;
11use crate::proxy::http_utils::{
12    HttpsClient, build_forward_request, create_error_response, create_initial_flow,
13    mock_to_response, parse_request_meta, update_flow_with_response_headers,
14};
15use crate::proxy::tap::TapBody;
16use crate::proxy::tunnel;
17use crate::proxy::websocket::handle_websocket_handshake;
18use crate::tls::CertificateAuthority;
19use http_body_util::{BodyExt, Full};
20use hyper::body::{Body, Bytes, Incoming};
21use hyper::{Method, Request, Response, StatusCode};
22use relay_core_api::flow::{Direction, FlowUpdate, Layer, ResilienceTrace};
23use relay_core_api::policy::ProxyPolicy;
24
25/// Main entry point for HTTP Proxy handling
26#[allow(clippy::too_many_arguments)]
27pub async fn handle_request(
28    req: Request<Incoming>,
29    client_addr: SocketAddr,
30    on_flow: Sender<FlowUpdate>,
31    ca: Arc<CertificateAuthority>,
32    client: Arc<HttpsClient>,
33    interceptor: Arc<dyn Interceptor>,
34    target_addr: Option<SocketAddr>,
35    policy_rx: watch::Receiver<ProxyPolicy>,
36    loop_detector: Arc<LoopDetector>,
37    circuit_breaker: Arc<CircuitBreaker>,
38) -> Result<Response<HttpBody>, Infallible> {
39    if req.method() == Method::CONNECT {
40        // Handle CONNECT (HTTPS Tunnel)
41        // Extract host from authority
42        let host = if let Some(authority) = req.uri().authority() {
43            authority.to_string()
44        } else {
45            // Fallback: try to get from Host header
46            req.headers()
47                .get("Host")
48                .and_then(|v| v.to_str().ok())
49                .map(|s| s.to_string())
50                .unwrap_or_else(|| "unknown".to_string())
51        };
52
53        if host == "unknown" {
54            return Ok(create_error_response(
55                StatusCode::BAD_REQUEST,
56                "CONNECT must have authority",
57            ));
58        }
59
60        let loop_detector = loop_detector.clone();
61        let policy_rx = policy_rx.clone();
62
63        tokio::task::spawn(async move {
64            match hyper::upgrade::on(req).await {
65                Ok(upgraded) => {
66                    if let Err(e) = tunnel::handle_tunnel(
67                        upgraded,
68                        host,
69                        client_addr,
70                        ca,
71                        on_flow,
72                        client,
73                        interceptor,
74                        policy_rx,
75                        target_addr,
76                        loop_detector,
77                        circuit_breaker,
78                    )
79                    .await
80                    {
81                        tracing::error!("Tunnel error: {}", e);
82                    }
83                }
84                Err(e) => tracing::error!("Upgrade error: {}", e),
85            }
86        });
87        return Ok(Response::new(
88            Full::new(Bytes::new()).map_err(|e| e.into()).boxed(),
89        ));
90    }
91
92    // Handle Standard HTTP / WebSocket
93    handle_http_request(
94        req,
95        client_addr,
96        on_flow,
97        client,
98        interceptor,
99        false,
100        policy_rx,
101        target_addr,
102        loop_detector,
103        circuit_breaker,
104    )
105    .await
106}
107
108#[allow(clippy::too_many_arguments)]
109pub(crate) async fn handle_http_request<B>(
110    req: Request<B>,
111    client_addr: SocketAddr,
112    on_flow: Sender<FlowUpdate>,
113    client: Arc<HttpsClient>,
114    interceptor: Arc<dyn Interceptor>,
115    is_mitm: bool,
116    policy_rx: watch::Receiver<ProxyPolicy>,
117    target_addr: Option<SocketAddr>,
118    loop_detector: Arc<LoopDetector>,
119    circuit_breaker: Arc<CircuitBreaker>,
120) -> Result<Response<HttpBody>, Infallible>
121where
122    B: Body + Send + Sync + Unpin + 'static,
123    B::Data: Send + Into<Bytes>,
124    B::Error: Into<BoxError>,
125{
126    let policy = policy_rx.borrow().clone();
127
128    // P1a: Track oversized requests for streaming-first pipeline.
129    // Instead of hard-failing with PAYLOAD_TOO_LARGE, we allow the request
130    // through and mark budget_exceeded so rules that need full body are skipped.
131    let request_budget_exceeded = if let Some(cl) = req.headers().get(hyper::header::CONTENT_LENGTH)
132        && let Ok(len) = cl.to_str().unwrap_or_default().parse::<usize>()
133        && len > policy.max_body_size
134    {
135        true
136    } else {
137        false
138    };
139
140    // Create Flow
141    let meta = parse_request_meta(&req, is_mitm);
142
143    // Note: We don't read body here for streaming support
144    let mut flow = create_initial_flow(meta, None, client_addr, is_mitm, false);
145
146    // P1a: Mark budget exceeded for oversized requests
147    if request_budget_exceeded {
148        flow.tags.push("budget-exceeded".to_string());
149        flow.resilience_trace = Some(ResilienceTrace {
150            budget_exceeded: true,
151            ..flow.resilience_trace.clone().unwrap_or_default()
152        });
153    }
154
155    // Check for WebSocket
156    if hyper_tungstenite::is_upgrade_request(&req) {
157        return handle_websocket_handshake(
158            req,
159            client_addr,
160            on_flow,
161            client,
162            interceptor,
163            is_mitm,
164            policy_rx,
165            target_addr,
166            loop_detector,
167        )
168        .await;
169    }
170
171    if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
172        tracing::error!("Failed to send flow update: {}", e);
173    }
174
175    // Phase 1: Request Headers Interception
176    match interceptor.on_request_headers(&mut flow).await {
177        InterceptionResult::Continue => {}
178        InterceptionResult::Drop => {
179            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
180                tracing::error!("Failed to send flow update on drop: {}", e);
181            }
182            return Ok(create_error_response(
183                StatusCode::FORBIDDEN,
184                "Request dropped by policy",
185            ));
186        }
187        InterceptionResult::MockResponse(resp) => {
188            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
189                tracing::error!("Failed to send flow update on mock: {}", e);
190            }
191            return Ok(mock_to_response(resp));
192        }
193        InterceptionResult::ModifiedRequest(_) => {}
194        InterceptionResult::ModifiedResponse(res) => {
195            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
196                tracing::error!("Failed to send flow update on modified response: {}", e);
197            }
198            return Ok(mock_to_response(res));
199        }
200        _ => {}
201    }
202
203    // Phase 2: Request Body Streaming & Interception
204    let (_, body) = req.into_parts();
205    let body: HttpBody = body
206        .map_frame(|f| f.map_data(|d| d.into()))
207        .map_err(|e| e.into())
208        .boxed();
209
210    // Wrap in TapBody for streaming visualization BEFORE interception
211    let req_headers = if let Layer::Http(http) = &flow.layer {
212        http.request.headers.clone()
213    } else {
214        vec![]
215    };
216
217    let tap_body = TapBody::new(
218        body,
219        flow.id.to_string(),
220        on_flow.clone(),
221        Direction::ClientToServer,
222        policy.max_body_size,
223        req_headers,
224    );
225    crate::metrics::inc_proxy_http_request();
226    let mut current_body = tap_body.boxed();
227
228    match interceptor.on_request(&mut flow, current_body).await {
229        Ok(RequestAction::Continue(new_body)) => {
230            current_body = new_body;
231        }
232        Ok(RequestAction::Drop) => {
233            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
234                tracing::error!("Failed to send flow update on request drop: {}", e);
235            }
236            return Ok(create_error_response(
237                StatusCode::FORBIDDEN,
238                "Request dropped by interceptor",
239            ));
240        }
241        Ok(RequestAction::MockResponse(res)) => {
242            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
243                tracing::error!("Failed to send flow update on request mock: {}", e);
244            }
245            let (parts, body) = res.into_parts();
246            return Ok(Response::from_parts(parts, body));
247        }
248        Err(e) => {
249            tracing::error!("Interceptor error on_request: {}", e);
250            return Ok(create_error_response(
251                StatusCode::INTERNAL_SERVER_ERROR,
252                format!("Interceptor Error: {}", e),
253            ));
254        }
255    }
256
257    // RE2: Apply ThrottleBody if a Throttle rule set the rate in flow.meta
258    if let Some(bps_str) = flow.meta.get("throttle_bytes_per_sec")
259        && let Ok(bps) = bps_str.parse::<u64>()
260        && bps > 0
261    {
262        current_body = crate::proxy::throttle::ThrottleBody::new(current_body, bps).boxed();
263    }
264
265    let forward_req = match build_forward_request(
266        &mut flow,
267        current_body,
268        target_addr,
269        &policy,
270        &loop_detector,
271    ) {
272        Ok(req) => req,
273        Err(res) => return Ok(res),
274    };
275
276    // P3: Circuit breaker check before upstream request
277    let upstream_host = forward_req
278        .uri()
279        .authority()
280        .map(|a| a.to_string())
281        .unwrap_or_else(|| "unknown".to_string());
282    if !circuit_breaker.allow_request(&upstream_host).await {
283        tracing::warn!(
284            "Circuit breaker open for upstream {}, returning 503",
285            upstream_host
286        );
287        // P4: Record circuit breaker open in resilience trace
288        flow.resilience_trace = Some(ResilienceTrace {
289            circuit_open: true,
290            ..flow.resilience_trace.clone().unwrap_or_default()
291        });
292        if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
293            tracing::error!("Failed to send flow update on circuit breaker: {}", e);
294        }
295        return Ok(create_error_response(
296            StatusCode::SERVICE_UNAVAILABLE,
297            format!("Circuit breaker open for upstream {}", upstream_host),
298        ));
299    }
300
301    // Send Request
302    let upstream_start = std::time::Instant::now();
303    let res = match tokio::time::timeout(
304        std::time::Duration::from_millis(policy.request_timeout_ms),
305        client.request(forward_req),
306    )
307    .await
308    {
309        Ok(Ok(res)) => {
310            circuit_breaker.record_success(&upstream_host).await;
311            res
312        }
313        Ok(Err(e)) => {
314            circuit_breaker.record_failure(&upstream_host).await;
315            tracing::error!("Upstream request failed: {}", e);
316            // P4: Record upstream error in resilience trace
317            flow.resilience_trace = Some(ResilienceTrace {
318                upstream_errors: vec![format!("Upstream Error: {}", e)],
319                ..flow.resilience_trace.clone().unwrap_or_default()
320            });
321            if let Layer::Http(http) = &mut flow.layer {
322                http.error = Some(format!("Upstream Error: {}", e));
323            }
324            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
325                tracing::error!("Failed to send flow update on upstream error: {}", e);
326            }
327            return Ok(create_error_response(
328                StatusCode::BAD_GATEWAY,
329                format!("Upstream Error: {}", e),
330            ));
331        }
332        Err(_) => {
333            circuit_breaker.record_failure(&upstream_host).await;
334            tracing::error!("Upstream request timed out");
335            // Record timeout in resilience trace
336            // We use a single tokio::time::timeout wrapping the entire upstream
337            // request, so we cannot reliably distinguish connect vs read.
338            // Mark as "total" rather than guessing.
339            flow.resilience_trace = Some(ResilienceTrace {
340                upstream_errors: vec!["Upstream Request Timed Out".to_string()],
341                timeout_type: Some("total".to_string()),
342                ..flow.resilience_trace.clone().unwrap_or_default()
343            });
344            if let Layer::Http(http) = &mut flow.layer {
345                http.error = Some("Upstream Request Timed Out".to_string());
346            }
347            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
348                tracing::error!("Failed to send flow update on upstream timeout: {}", e);
349            }
350            return Ok(create_error_response(
351                StatusCode::GATEWAY_TIMEOUT,
352                "Upstream Request Timed Out",
353            ));
354        }
355    };
356
357    // Phase 3: Response Headers Interception
358    let (mut res_parts, res_body) = res.into_parts();
359
360    // Apply QUIC Downgrade
361    apply_quic_downgrade(&mut res_parts, &mut flow, &policy);
362
363    update_flow_with_response_headers(
364        &mut flow,
365        res_parts.status,
366        res_parts.version,
367        &res_parts.headers,
368    );
369
370    let ttfbs_ms = upstream_start.elapsed().as_millis() as u64;
371    if let Layer::Http(http) = &mut flow.layer
372        && let Some(response) = &mut http.response
373    {
374        response.timing.time_to_first_byte = Some(ttfbs_ms);
375    }
376
377    match interceptor.on_response_headers(&mut flow).await {
378        InterceptionResult::Continue => {}
379        InterceptionResult::Drop => {
380            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
381                tracing::error!("Failed to send flow update on response drop: {}", e);
382            }
383            return Ok(create_error_response(
384                StatusCode::FORBIDDEN,
385                "Response dropped by policy",
386            ));
387        }
388        InterceptionResult::MockResponse(resp) => {
389            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
390                tracing::error!("Failed to send flow update on response mock: {}", e);
391            }
392            return Ok(mock_to_response(resp));
393        }
394        InterceptionResult::ModifiedResponse(resp) => {
395            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
396                tracing::error!("Failed to send flow update on response modification: {}", e);
397            }
398            return Ok(mock_to_response(resp));
399        }
400        _ => {}
401    }
402
403    // Phase 4: Response Body Streaming & Interception
404    let res_body: HttpBody = res_body
405        .map_frame(|f| f.map_data(|d| d))
406        .map_err(|e| e.into())
407        .boxed();
408
409    // Wrap in TapBody for streaming visualization BEFORE interception
410    let res_headers = if let Layer::Http(http) = &flow.layer {
411        http.response
412            .as_ref()
413            .map(|r| r.headers.clone())
414            .unwrap_or_default()
415    } else {
416        vec![]
417    };
418
419    let tap_res_body = TapBody::new(
420        res_body,
421        flow.id.to_string(),
422        on_flow.clone(),
423        Direction::ServerToClient,
424        policy.max_body_size,
425        res_headers,
426    );
427    let mut current_res_body = tap_res_body.boxed();
428
429    match interceptor.on_response(&mut flow, current_res_body).await {
430        Ok(ResponseAction::Continue(new_body)) => {
431            current_res_body = new_body;
432        }
433        Ok(ResponseAction::Drop) => {
434            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
435                tracing::error!("Failed to send flow update on response body drop: {}", e);
436            }
437            return Ok(create_error_response(
438                StatusCode::FORBIDDEN,
439                "Response dropped by interceptor",
440            ));
441        }
442        Ok(ResponseAction::ModifiedResponse(res)) => {
443            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
444                tracing::error!(
445                    "Failed to send flow update on response body modification: {}",
446                    e
447                );
448            }
449            let (parts, body) = res.into_parts();
450            return Ok(Response::from_parts(parts, body));
451        }
452        Err(e) => {
453            tracing::error!("Interceptor error on_response: {}", e);
454            return Ok(create_error_response(
455                StatusCode::INTERNAL_SERVER_ERROR,
456                format!("Interceptor Error: {}", e),
457            ));
458        }
459    }
460
461    // RE2: Apply ThrottleBody to response if a Throttle rule set the rate in flow.meta
462    if let Some(bps_str) = flow.meta.get("throttle_bytes_per_sec")
463        && let Ok(bps) = bps_str.parse::<u64>()
464        && bps > 0
465    {
466        current_res_body = crate::proxy::throttle::ThrottleBody::new(current_res_body, bps).boxed();
467    }
468
469    if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
470        tracing::error!("Failed to send final flow update: {}", e);
471    }
472
473    // Record time-to-last-byte as total upstream-to-client latency
474    if let Layer::Http(http) = &mut flow.layer
475        && let Some(response) = &mut http.response
476    {
477        response.timing.time_to_last_byte = Some(upstream_start.elapsed().as_millis() as u64);
478    }
479
480    Ok(Response::from_parts(res_parts, current_res_body))
481}
482
483pub(crate) fn apply_quic_downgrade(
484    parts: &mut hyper::http::response::Parts,
485    flow: &mut relay_core_api::flow::Flow,
486    policy: &ProxyPolicy,
487) {
488    use relay_core_api::policy::QuicMode;
489    if policy.quic_mode == QuicMode::Downgrade {
490        if parts.headers.remove("Alt-Svc").is_some() {
491            flow.tags.push("quic-downgraded".to_string());
492        }
493        if policy.quic_downgrade_clear_cache {
494            parts.headers.insert(
495                "Clear-Site-Data",
496                hyper::header::HeaderValue::from_static("\"cache\""),
497            );
498        }
499    }
500}
501
502#[cfg(test)]
503mod http_tests;