Skip to main content

relay_core_lib/proxy/
http.rs

1use std::convert::Infallible;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use tokio::sync::{mpsc::Sender, watch};
5
6use crate::capture::loop_detection::LoopDetector;
7use crate::interceptor::{
8    BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
9};
10use crate::proxy::circuit_breaker::CircuitBreaker;
11use crate::proxy::http_utils::{
12    build_forward_request, create_error_response, create_initial_flow, mock_to_response,
13    parse_request_meta, update_flow_with_response_headers,
14};
15use crate::proxy::outbound::OutboundConnector;
16use crate::proxy::tap::TapBody;
17use crate::proxy::tunnel;
18use crate::proxy::websocket::handle_websocket_handshake;
19use crate::tls::CertificateAuthority;
20use http_body_util::{BodyExt, Full};
21use hyper::body::{Body, Bytes, Incoming};
22use hyper::{Method, Request, Response, StatusCode};
23use relay_core_api::flow::{Direction, FlowUpdate, Layer, ResilienceTrace};
24use relay_core_api::policy::ProxyPolicy;
25
26/// Main entry point for HTTP Proxy handling
27#[allow(clippy::too_many_arguments)]
28pub async fn handle_request(
29    req: Request<Incoming>,
30    client_addr: SocketAddr,
31    on_flow: Sender<FlowUpdate>,
32    ca: Arc<CertificateAuthority>,
33    connector: Arc<dyn OutboundConnector>,
34    interceptor: Arc<dyn Interceptor>,
35    target_addr: Option<SocketAddr>,
36    policy_rx: watch::Receiver<ProxyPolicy>,
37    loop_detector: Arc<LoopDetector>,
38    circuit_breaker: Arc<CircuitBreaker>,
39) -> Result<Response<HttpBody>, Infallible> {
40    if req.method() == Method::CONNECT {
41        // Handle CONNECT (HTTPS Tunnel)
42        // Extract host from authority
43        let host = if let Some(authority) = req.uri().authority() {
44            authority.to_string()
45        } else {
46            // Fallback: try to get from Host header
47            req.headers()
48                .get("Host")
49                .and_then(|v| v.to_str().ok())
50                .map(|s| s.to_string())
51                .unwrap_or_else(|| "unknown".to_string())
52        };
53
54        if host == "unknown" {
55            return Ok(create_error_response(
56                StatusCode::BAD_REQUEST,
57                "CONNECT must have authority",
58            ));
59        }
60
61        let loop_detector = loop_detector.clone();
62        let policy_rx = policy_rx.clone();
63
64        tokio::task::spawn(async move {
65            match hyper::upgrade::on(req).await {
66                Ok(upgraded) => {
67                    if let Err(e) = tunnel::handle_tunnel(
68                        upgraded,
69                        host,
70                        client_addr,
71                        ca,
72                        on_flow,
73                        connector,
74                        interceptor,
75                        policy_rx,
76                        target_addr,
77                        loop_detector,
78                        circuit_breaker,
79                    )
80                    .await
81                    {
82                        tracing::error!("Tunnel error: {}", e);
83                    }
84                }
85                Err(e) => tracing::error!("Upgrade error: {}", e),
86            }
87        });
88        return Ok(Response::new(
89            Full::new(Bytes::new()).map_err(|e| e.into()).boxed(),
90        ));
91    }
92
93    // Handle Standard HTTP / WebSocket
94    handle_http_request(
95        req,
96        client_addr,
97        on_flow,
98        connector,
99        interceptor,
100        false,
101        policy_rx,
102        target_addr,
103        loop_detector,
104        circuit_breaker,
105    )
106    .await
107}
108
109#[allow(clippy::too_many_arguments)]
110pub(crate) async fn handle_http_request<B>(
111    req: Request<B>,
112    client_addr: SocketAddr,
113    on_flow: Sender<FlowUpdate>,
114    connector: Arc<dyn OutboundConnector>,
115    interceptor: Arc<dyn Interceptor>,
116    is_mitm: bool,
117    policy_rx: watch::Receiver<ProxyPolicy>,
118    target_addr: Option<SocketAddr>,
119    loop_detector: Arc<LoopDetector>,
120    circuit_breaker: Arc<CircuitBreaker>,
121) -> Result<Response<HttpBody>, Infallible>
122where
123    B: Body + Send + Sync + Unpin + 'static,
124    B::Data: Send + Into<Bytes>,
125    B::Error: Into<BoxError>,
126{
127    let policy = policy_rx.borrow().clone();
128
129    // P1a: Track oversized requests for streaming-first pipeline.
130    // Instead of hard-failing with PAYLOAD_TOO_LARGE, we allow the request
131    // through and mark budget_exceeded so rules that need full body are skipped.
132    let request_budget_exceeded = if let Some(cl) = req.headers().get(hyper::header::CONTENT_LENGTH)
133        && let Ok(len) = cl.to_str().unwrap_or_default().parse::<usize>()
134        && len > policy.max_body_size
135    {
136        true
137    } else {
138        false
139    };
140
141    // Create Flow
142    let meta = parse_request_meta(&req, is_mitm);
143
144    // Note: We don't read body here for streaming support
145    let mut flow = create_initial_flow(meta, None, client_addr, is_mitm, false);
146
147    // P1a: Mark budget exceeded for oversized requests
148    if request_budget_exceeded {
149        flow.tags.push("budget-exceeded".to_string());
150        flow.resilience_trace = Some(ResilienceTrace {
151            budget_exceeded: true,
152            ..flow.resilience_trace.clone().unwrap_or_default()
153        });
154    }
155
156    // Check for WebSocket
157    if hyper_tungstenite::is_upgrade_request(&req) {
158        return handle_websocket_handshake(
159            req,
160            client_addr,
161            on_flow,
162            connector,
163            interceptor,
164            is_mitm,
165            policy_rx,
166            target_addr,
167            loop_detector,
168        )
169        .await;
170    }
171
172    if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
173        tracing::error!("Failed to send flow update: {}", e);
174    }
175
176    // Phase 1: Request Headers Interception
177    match interceptor.on_request_headers(&mut flow).await {
178        InterceptionResult::Continue => {}
179        InterceptionResult::Drop => {
180            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
181                tracing::error!("Failed to send flow update on drop: {}", e);
182            }
183            return Ok(create_error_response(
184                StatusCode::FORBIDDEN,
185                "Request dropped by policy",
186            ));
187        }
188        InterceptionResult::MockResponse(resp) => {
189            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
190                tracing::error!("Failed to send flow update on mock: {}", e);
191            }
192            return Ok(mock_to_response(resp));
193        }
194        InterceptionResult::ModifiedRequest(_) => {}
195        InterceptionResult::ModifiedResponse(res) => {
196            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
197                tracing::error!("Failed to send flow update on modified response: {}", e);
198            }
199            return Ok(mock_to_response(res));
200        }
201        _ => {}
202    }
203
204    // Phase 2: Request Body Streaming & Interception
205    let (_, body) = req.into_parts();
206    let body: HttpBody = body
207        .map_frame(|f| f.map_data(|d| d.into()))
208        .map_err(|e| e.into())
209        .boxed();
210
211    // Wrap in TapBody for streaming visualization BEFORE interception
212    let req_headers = if let Layer::Http(http) = &flow.layer {
213        http.request.headers.clone()
214    } else {
215        vec![]
216    };
217
218    let tap_body = TapBody::new(
219        body,
220        flow.id.to_string(),
221        on_flow.clone(),
222        Direction::ClientToServer,
223        policy.max_body_size,
224        req_headers,
225    );
226    crate::metrics::inc_proxy_http_request();
227    let mut current_body = tap_body.boxed();
228
229    match interceptor.on_request(&mut flow, current_body).await {
230        Ok(RequestAction::Continue(new_body)) => {
231            current_body = new_body;
232        }
233        Ok(RequestAction::Drop) => {
234            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
235                tracing::error!("Failed to send flow update on request drop: {}", e);
236            }
237            return Ok(create_error_response(
238                StatusCode::FORBIDDEN,
239                "Request dropped by interceptor",
240            ));
241        }
242        Ok(RequestAction::MockResponse(res)) => {
243            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
244                tracing::error!("Failed to send flow update on request mock: {}", e);
245            }
246            let (parts, body) = res.into_parts();
247            return Ok(Response::from_parts(parts, body));
248        }
249        Err(e) => {
250            tracing::error!("Interceptor error on_request: {}", e);
251            return Ok(create_error_response(
252                StatusCode::INTERNAL_SERVER_ERROR,
253                format!("Interceptor Error: {}", e),
254            ));
255        }
256    }
257
258    // RE2: Apply ThrottleBody if a Throttle rule set the rate in flow.meta
259    if let Some(bps_str) = flow.meta.get("throttle_bytes_per_sec")
260        && let Ok(bps) = bps_str.parse::<u64>()
261        && bps > 0
262    {
263        current_body = crate::proxy::throttle::ThrottleBody::new(current_body, bps).boxed();
264    }
265
266    let forward_req = match build_forward_request(
267        &mut flow,
268        current_body,
269        target_addr,
270        &policy,
271        &loop_detector,
272    ) {
273        Ok(req) => req,
274        Err(res) => return Ok(res),
275    };
276
277    // P3: Circuit breaker check before upstream request.
278    // When going through an upstream proxy, key on the proxy address so
279    // that a failing upstream proxy isolates correctly from target hosts.
280    let circuit_breaker_key = connector
281        .upstream_proxy_url()
282        .map(|u| u.to_string())
283        .unwrap_or_else(|| {
284            forward_req
285                .uri()
286                .authority()
287                .map(|a| a.to_string())
288                .unwrap_or_else(|| "unknown".to_string())
289        });
290    if !circuit_breaker.allow_request(&circuit_breaker_key).await {
291        tracing::warn!(
292            "Circuit breaker open for upstream {}, returning 503",
293            circuit_breaker_key
294        );
295        // P4: Record circuit breaker open in resilience trace
296        flow.resilience_trace = Some(ResilienceTrace {
297            circuit_open: true,
298            ..flow.resilience_trace.clone().unwrap_or_default()
299        });
300        if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
301            tracing::error!("Failed to send flow update on circuit breaker: {}", e);
302        }
303        return Ok(create_error_response(
304            StatusCode::SERVICE_UNAVAILABLE,
305            format!("Circuit breaker open for upstream {}", circuit_breaker_key),
306        ));
307    }
308
309    // Send Request
310    let upstream_start = std::time::Instant::now();
311    let target_host = forward_req.uri().host().unwrap_or("unknown").to_string();
312    let target_port = forward_req.uri().port_u16().unwrap_or(
313        if forward_req.uri().scheme_str() == Some("https") {
314            443
315        } else {
316            80
317        },
318    );
319    let res = match tokio::time::timeout(
320        std::time::Duration::from_millis(policy.request_timeout_ms),
321        connector.send_request(forward_req, &target_host, target_port, &mut flow),
322    )
323    .await
324    {
325        Ok(Ok(res)) => {
326            circuit_breaker.record_success(&circuit_breaker_key).await;
327            res
328        }
329        Ok(Err(e)) => {
330            circuit_breaker.record_failure(&circuit_breaker_key).await;
331            tracing::error!("Upstream request failed: {}", e);
332            // P4: Record upstream error in resilience trace
333            flow.resilience_trace = Some(ResilienceTrace {
334                upstream_errors: vec![format!("Upstream Error: {}", e)],
335                ..flow.resilience_trace.clone().unwrap_or_default()
336            });
337            if let Layer::Http(http) = &mut flow.layer {
338                http.error = Some(format!("Upstream Error: {}", e));
339            }
340            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
341                tracing::error!("Failed to send flow update on upstream error: {}", e);
342            }
343            return Ok(create_error_response(
344                StatusCode::BAD_GATEWAY,
345                format!("Upstream Error: {}", e),
346            ));
347        }
348        Err(_) => {
349            circuit_breaker.record_failure(&circuit_breaker_key).await;
350            tracing::error!("Upstream request timed out");
351            // Record timeout in resilience trace
352            // We use a single tokio::time::timeout wrapping the entire upstream
353            // request, so we cannot reliably distinguish connect vs read.
354            // Mark as "total" rather than guessing.
355            flow.resilience_trace = Some(ResilienceTrace {
356                upstream_errors: vec!["Upstream Request Timed Out".to_string()],
357                timeout_type: Some("total".to_string()),
358                ..flow.resilience_trace.clone().unwrap_or_default()
359            });
360            if let Layer::Http(http) = &mut flow.layer {
361                http.error = Some("Upstream Request Timed Out".to_string());
362            }
363            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
364                tracing::error!("Failed to send flow update on upstream timeout: {}", e);
365            }
366            return Ok(create_error_response(
367                StatusCode::GATEWAY_TIMEOUT,
368                "Upstream Request Timed Out",
369            ));
370        }
371    };
372
373    // Phase 3: Response Headers Interception
374    let (mut res_parts, res_body) = res.into_parts();
375
376    // Apply QUIC Downgrade
377    apply_quic_downgrade(&mut res_parts, &mut flow, &policy);
378
379    update_flow_with_response_headers(
380        &mut flow,
381        res_parts.status,
382        res_parts.version,
383        &res_parts.headers,
384    );
385
386    let ttfbs_ms = upstream_start.elapsed().as_millis() as u64;
387    if let Layer::Http(http) = &mut flow.layer
388        && let Some(response) = &mut http.response
389    {
390        response.timing.time_to_first_byte = Some(ttfbs_ms);
391    }
392
393    match interceptor.on_response_headers(&mut flow).await {
394        InterceptionResult::Continue => {}
395        InterceptionResult::Drop => {
396            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
397                tracing::error!("Failed to send flow update on response drop: {}", e);
398            }
399            return Ok(create_error_response(
400                StatusCode::FORBIDDEN,
401                "Response dropped by policy",
402            ));
403        }
404        InterceptionResult::MockResponse(resp) => {
405            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
406                tracing::error!("Failed to send flow update on response mock: {}", e);
407            }
408            return Ok(mock_to_response(resp));
409        }
410        InterceptionResult::ModifiedResponse(resp) => {
411            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
412                tracing::error!("Failed to send flow update on response modification: {}", e);
413            }
414            return Ok(mock_to_response(resp));
415        }
416        _ => {}
417    }
418
419    // Phase 4: Response Body Streaming & Interception
420    let res_body: HttpBody = res_body
421        .map_frame(|f| f.map_data(|d| d))
422        .map_err(|e| e.into())
423        .boxed();
424
425    // Wrap in TapBody for streaming visualization BEFORE interception
426    let res_headers = if let Layer::Http(http) = &flow.layer {
427        http.response
428            .as_ref()
429            .map(|r| r.headers.clone())
430            .unwrap_or_default()
431    } else {
432        vec![]
433    };
434
435    let tap_res_body = TapBody::new(
436        res_body,
437        flow.id.to_string(),
438        on_flow.clone(),
439        Direction::ServerToClient,
440        policy.max_body_size,
441        res_headers,
442    );
443    let mut current_res_body = tap_res_body.boxed();
444
445    match interceptor.on_response(&mut flow, current_res_body).await {
446        Ok(ResponseAction::Continue(new_body)) => {
447            current_res_body = new_body;
448        }
449        Ok(ResponseAction::Drop) => {
450            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
451                tracing::error!("Failed to send flow update on response body drop: {}", e);
452            }
453            return Ok(create_error_response(
454                StatusCode::FORBIDDEN,
455                "Response dropped by interceptor",
456            ));
457        }
458        Ok(ResponseAction::ModifiedResponse(res)) => {
459            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
460                tracing::error!(
461                    "Failed to send flow update on response body modification: {}",
462                    e
463                );
464            }
465            let (parts, body) = res.into_parts();
466            return Ok(Response::from_parts(parts, body));
467        }
468        Err(e) => {
469            tracing::error!("Interceptor error on_response: {}", e);
470            return Ok(create_error_response(
471                StatusCode::INTERNAL_SERVER_ERROR,
472                format!("Interceptor Error: {}", e),
473            ));
474        }
475    }
476
477    // RE2: Apply ThrottleBody to response if a Throttle rule set the rate in flow.meta
478    if let Some(bps_str) = flow.meta.get("throttle_bytes_per_sec")
479        && let Ok(bps) = bps_str.parse::<u64>()
480        && bps > 0
481    {
482        current_res_body = crate::proxy::throttle::ThrottleBody::new(current_res_body, bps).boxed();
483    }
484
485    if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
486        tracing::error!("Failed to send final flow update: {}", e);
487    }
488
489    // Record time-to-last-byte as total upstream-to-client latency
490    if let Layer::Http(http) = &mut flow.layer
491        && let Some(response) = &mut http.response
492    {
493        response.timing.time_to_last_byte = Some(upstream_start.elapsed().as_millis() as u64);
494    }
495
496    Ok(Response::from_parts(res_parts, current_res_body))
497}
498
499pub(crate) fn apply_quic_downgrade(
500    parts: &mut hyper::http::response::Parts,
501    flow: &mut relay_core_api::flow::Flow,
502    policy: &ProxyPolicy,
503) {
504    use relay_core_api::policy::QuicMode;
505    if policy.quic_mode == QuicMode::Downgrade {
506        if parts.headers.remove("Alt-Svc").is_some() {
507            flow.tags.push("quic-downgraded".to_string());
508        }
509        if policy.quic_downgrade_clear_cache {
510            parts.headers.insert(
511                "Clear-Site-Data",
512                hyper::header::HeaderValue::from_static("\"cache\""),
513            );
514        }
515    }
516}
517
// Tests for this module are declared out-of-line in `http_tests` and are
// compiled only for test builds.
#[cfg(test)]
mod http_tests;