Skip to main content

relay_core_lib/proxy/http.rs

1use std::convert::Infallible;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use tokio::sync::{mpsc::Sender, watch};
5
6use crate::capture::loop_detection::LoopDetector;
7use crate::interceptor::{
8    BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
9};
10use crate::proxy::circuit_breaker::CircuitBreaker;
11use crate::proxy::http_utils::{
12    HttpsClient, build_forward_request, create_error_response, create_initial_flow,
13    mock_to_response, parse_request_meta, update_flow_with_response_headers,
14};
15use crate::proxy::tap::TapBody;
16use crate::proxy::tunnel;
17use crate::proxy::websocket::handle_websocket_handshake;
18use crate::tls::CertificateAuthority;
19use http_body_util::{BodyExt, Full};
20use hyper::body::{Body, Bytes, Incoming};
21use hyper::{Method, Request, Response, StatusCode};
22use relay_core_api::flow::{Direction, FlowUpdate, Layer};
23use relay_core_api::policy::ProxyPolicy;
24
25/// Main entry point for HTTP Proxy handling
26#[allow(clippy::too_many_arguments)]
27pub async fn handle_request(
28    req: Request<Incoming>,
29    client_addr: SocketAddr,
30    on_flow: Sender<FlowUpdate>,
31    ca: Arc<CertificateAuthority>,
32    client: Arc<HttpsClient>,
33    interceptor: Arc<dyn Interceptor>,
34    target_addr: Option<SocketAddr>,
35    policy_rx: watch::Receiver<ProxyPolicy>,
36    loop_detector: Arc<LoopDetector>,
37    circuit_breaker: Arc<CircuitBreaker>,
38) -> Result<Response<HttpBody>, Infallible> {
39    if req.method() == Method::CONNECT {
40        // Handle CONNECT (HTTPS Tunnel)
41        // Extract host from authority
42        let host = if let Some(authority) = req.uri().authority() {
43            authority.to_string()
44        } else {
45            // Fallback: try to get from Host header
46            req.headers()
47                .get("Host")
48                .and_then(|v| v.to_str().ok())
49                .map(|s| s.to_string())
50                .unwrap_or_else(|| "unknown".to_string())
51        };
52
53        if host == "unknown" {
54            return Ok(create_error_response(
55                StatusCode::BAD_REQUEST,
56                "CONNECT must have authority",
57            ));
58        }
59
60        let loop_detector = loop_detector.clone();
61        let policy_rx = policy_rx.clone();
62
63        tokio::task::spawn(async move {
64            match hyper::upgrade::on(req).await {
65                Ok(upgraded) => {
66                    if let Err(e) = tunnel::handle_tunnel(
67                        upgraded,
68                        host,
69                        client_addr,
70                        ca,
71                        on_flow,
72                        client,
73                        interceptor,
74                        policy_rx,
75                        target_addr,
76                        loop_detector,
77                        circuit_breaker,
78                    )
79                    .await
80                    {
81                        tracing::error!("Tunnel error: {}", e);
82                    }
83                }
84                Err(e) => tracing::error!("Upgrade error: {}", e),
85            }
86        });
87        return Ok(Response::new(
88            Full::new(Bytes::new()).map_err(|e| e.into()).boxed(),
89        ));
90    }
91
92    // Handle Standard HTTP / WebSocket
93    handle_http_request(
94        req,
95        client_addr,
96        on_flow,
97        client,
98        interceptor,
99        false,
100        policy_rx,
101        target_addr,
102        loop_detector,
103        circuit_breaker,
104    )
105    .await
106}
107
108#[allow(clippy::too_many_arguments)]
109pub(crate) async fn handle_http_request<B>(
110    req: Request<B>,
111    client_addr: SocketAddr,
112    on_flow: Sender<FlowUpdate>,
113    client: Arc<HttpsClient>,
114    interceptor: Arc<dyn Interceptor>,
115    is_mitm: bool,
116    policy_rx: watch::Receiver<ProxyPolicy>,
117    target_addr: Option<SocketAddr>,
118    loop_detector: Arc<LoopDetector>,
119    circuit_breaker: Arc<CircuitBreaker>,
120) -> Result<Response<HttpBody>, Infallible>
121where
122    B: Body + Send + Sync + Unpin + 'static,
123    B::Data: Send + Into<Bytes>,
124    B::Error: Into<BoxError>,
125{
126    let policy = policy_rx.borrow().clone();
127
128    // Check Content-Length against policy
129    if let Some(cl) = req.headers().get(hyper::header::CONTENT_LENGTH)
130        && let Ok(len) = cl.to_str().unwrap_or_default().parse::<usize>()
131        && len > policy.max_body_size
132    {
133        return Ok(create_error_response(
134            StatusCode::PAYLOAD_TOO_LARGE,
135            "Request body too large",
136        ));
137    }
138
139    // Create Flow
140    let meta = parse_request_meta(&req, is_mitm);
141
142    // Note: We don't read body here for streaming support
143    let mut flow = create_initial_flow(meta, None, client_addr, is_mitm, false);
144
145    // Check for WebSocket
146    if hyper_tungstenite::is_upgrade_request(&req) {
147        return handle_websocket_handshake(
148            req,
149            client_addr,
150            on_flow,
151            client,
152            interceptor,
153            is_mitm,
154            policy_rx,
155            target_addr,
156            loop_detector,
157        )
158        .await;
159    }
160
161    if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
162        tracing::error!("Failed to send flow update: {}", e);
163    }
164
165    // Phase 1: Request Headers Interception
166    match interceptor.on_request_headers(&mut flow).await {
167        InterceptionResult::Continue => {}
168        InterceptionResult::Drop => {
169            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
170                tracing::error!("Failed to send flow update on drop: {}", e);
171            }
172            return Ok(create_error_response(
173                StatusCode::FORBIDDEN,
174                "Request dropped by policy",
175            ));
176        }
177        InterceptionResult::MockResponse(resp) => {
178            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
179                tracing::error!("Failed to send flow update on mock: {}", e);
180            }
181            return Ok(mock_to_response(resp));
182        }
183        InterceptionResult::ModifiedRequest(_) => {}
184        InterceptionResult::ModifiedResponse(res) => {
185            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
186                tracing::error!("Failed to send flow update on modified response: {}", e);
187            }
188            return Ok(mock_to_response(res));
189        }
190        _ => {}
191    }
192
193    // Phase 2: Request Body Streaming & Interception
194    let (_, body) = req.into_parts();
195    let body: HttpBody = body
196        .map_frame(|f| f.map_data(|d| d.into()))
197        .map_err(|e| e.into())
198        .boxed();
199
200    // Wrap in TapBody for streaming visualization BEFORE interception
201    let req_headers = if let Layer::Http(http) = &flow.layer {
202        http.request.headers.clone()
203    } else {
204        vec![]
205    };
206
207    let tap_body = TapBody::new(
208        body,
209        flow.id.to_string(),
210        on_flow.clone(),
211        Direction::ClientToServer,
212        policy.max_body_size,
213        req_headers,
214    );
215    let mut current_body = tap_body.boxed();
216
217    match interceptor.on_request(&mut flow, current_body).await {
218        Ok(RequestAction::Continue(new_body)) => {
219            current_body = new_body;
220        }
221        Ok(RequestAction::Drop) => {
222            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
223                tracing::error!("Failed to send flow update on request drop: {}", e);
224            }
225            return Ok(create_error_response(
226                StatusCode::FORBIDDEN,
227                "Request dropped by interceptor",
228            ));
229        }
230        Ok(RequestAction::MockResponse(res)) => {
231            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
232                tracing::error!("Failed to send flow update on request mock: {}", e);
233            }
234            let (parts, body) = res.into_parts();
235            return Ok(Response::from_parts(parts, body));
236        }
237        Err(e) => {
238            tracing::error!("Interceptor error on_request: {}", e);
239            return Ok(create_error_response(
240                StatusCode::INTERNAL_SERVER_ERROR,
241                format!("Interceptor Error: {}", e),
242            ));
243        }
244    }
245
246    let forward_req = match build_forward_request(
247        &mut flow,
248        current_body,
249        target_addr,
250        &policy,
251        &loop_detector,
252    ) {
253        Ok(req) => req,
254        Err(res) => return Ok(res),
255    };
256
257    // P3: Circuit breaker check before upstream request
258    let upstream_host = forward_req
259        .uri()
260        .authority()
261        .map(|a| a.to_string())
262        .unwrap_or_else(|| "unknown".to_string());
263    if !circuit_breaker.allow_request(&upstream_host).await {
264        tracing::warn!(
265            "Circuit breaker open for upstream {}, returning 503",
266            upstream_host
267        );
268        return Ok(create_error_response(
269            StatusCode::SERVICE_UNAVAILABLE,
270            format!("Circuit breaker open for upstream {}", upstream_host),
271        ));
272    }
273
274    // Send Request
275    let upstream_start = std::time::Instant::now();
276    let res = match tokio::time::timeout(
277        std::time::Duration::from_millis(policy.request_timeout_ms),
278        client.request(forward_req),
279    )
280    .await
281    {
282        Ok(Ok(res)) => {
283            circuit_breaker.record_success(&upstream_host).await;
284            res
285        }
286        Ok(Err(e)) => {
287            circuit_breaker.record_failure(&upstream_host).await;
288            tracing::error!("Upstream request failed: {}", e);
289            if let Layer::Http(http) = &mut flow.layer {
290                http.error = Some(format!("Upstream Error: {}", e));
291            }
292            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
293                tracing::error!("Failed to send flow update on upstream error: {}", e);
294            }
295            return Ok(create_error_response(
296                StatusCode::BAD_GATEWAY,
297                format!("Upstream Error: {}", e),
298            ));
299        }
300        Err(_) => {
301            circuit_breaker.record_failure(&upstream_host).await;
302            tracing::error!("Upstream request timed out");
303            if let Layer::Http(http) = &mut flow.layer {
304                http.error = Some("Upstream Request Timed Out".to_string());
305            }
306            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
307                tracing::error!("Failed to send flow update on upstream timeout: {}", e);
308            }
309            return Ok(create_error_response(
310                StatusCode::GATEWAY_TIMEOUT,
311                "Upstream Request Timed Out",
312            ));
313        }
314    };
315
316    // Phase 3: Response Headers Interception
317    let (mut res_parts, res_body) = res.into_parts();
318
319    // Apply QUIC Downgrade
320    apply_quic_downgrade(&mut res_parts, &mut flow, &policy);
321
322    update_flow_with_response_headers(
323        &mut flow,
324        res_parts.status,
325        res_parts.version,
326        &res_parts.headers,
327    );
328
329    let ttfbs_ms = upstream_start.elapsed().as_millis() as u64;
330    if let Layer::Http(http) = &mut flow.layer
331        && let Some(response) = &mut http.response
332    {
333        response.timing.time_to_first_byte = Some(ttfbs_ms);
334    }
335
336    match interceptor.on_response_headers(&mut flow).await {
337        InterceptionResult::Continue => {}
338        InterceptionResult::Drop => {
339            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
340                tracing::error!("Failed to send flow update on response drop: {}", e);
341            }
342            return Ok(create_error_response(
343                StatusCode::FORBIDDEN,
344                "Response dropped by policy",
345            ));
346        }
347        InterceptionResult::MockResponse(resp) => {
348            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
349                tracing::error!("Failed to send flow update on response mock: {}", e);
350            }
351            return Ok(mock_to_response(resp));
352        }
353        InterceptionResult::ModifiedResponse(resp) => {
354            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
355                tracing::error!("Failed to send flow update on response modification: {}", e);
356            }
357            return Ok(mock_to_response(resp));
358        }
359        _ => {}
360    }
361
362    // Phase 4: Response Body Streaming & Interception
363    let res_body: HttpBody = res_body
364        .map_frame(|f| f.map_data(|d| d))
365        .map_err(|e| e.into())
366        .boxed();
367
368    // Wrap in TapBody for streaming visualization BEFORE interception
369    let res_headers = if let Layer::Http(http) = &flow.layer {
370        http.response
371            .as_ref()
372            .map(|r| r.headers.clone())
373            .unwrap_or_default()
374    } else {
375        vec![]
376    };
377
378    let tap_res_body = TapBody::new(
379        res_body,
380        flow.id.to_string(),
381        on_flow.clone(),
382        Direction::ServerToClient,
383        policy.max_body_size,
384        res_headers,
385    );
386    let mut current_res_body = tap_res_body.boxed();
387
388    match interceptor.on_response(&mut flow, current_res_body).await {
389        Ok(ResponseAction::Continue(new_body)) => {
390            current_res_body = new_body;
391        }
392        Ok(ResponseAction::Drop) => {
393            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
394                tracing::error!("Failed to send flow update on response body drop: {}", e);
395            }
396            return Ok(create_error_response(
397                StatusCode::FORBIDDEN,
398                "Response dropped by interceptor",
399            ));
400        }
401        Ok(ResponseAction::ModifiedResponse(res)) => {
402            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
403                tracing::error!(
404                    "Failed to send flow update on response body modification: {}",
405                    e
406                );
407            }
408            let (parts, body) = res.into_parts();
409            return Ok(Response::from_parts(parts, body));
410        }
411        Err(e) => {
412            tracing::error!("Interceptor error on_response: {}", e);
413            return Ok(create_error_response(
414                StatusCode::INTERNAL_SERVER_ERROR,
415                format!("Interceptor Error: {}", e),
416            ));
417        }
418    }
419
420    if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
421        tracing::error!("Failed to send final flow update: {}", e);
422    }
423
424    // Record time-to-last-byte as total upstream-to-client latency
425    if let Layer::Http(http) = &mut flow.layer
426        && let Some(response) = &mut http.response
427    {
428        response.timing.time_to_last_byte = Some(upstream_start.elapsed().as_millis() as u64);
429    }
430
431    Ok(Response::from_parts(res_parts, current_res_body))
432}
433
434pub(crate) fn apply_quic_downgrade(
435    parts: &mut hyper::http::response::Parts,
436    flow: &mut relay_core_api::flow::Flow,
437    policy: &ProxyPolicy,
438) {
439    use relay_core_api::policy::QuicMode;
440    if policy.quic_mode == QuicMode::Downgrade {
441        if parts.headers.remove("Alt-Svc").is_some() {
442            flow.tags.push("quic-downgraded".to_string());
443        }
444        if policy.quic_downgrade_clear_cache {
445            parts.headers.insert(
446                "Clear-Site-Data",
447                hyper::header::HeaderValue::from_static("\"cache\""),
448            );
449        }
450    }
451}
452
// Unit tests for this module live in a separate, file-backed submodule that is
// only compiled for test builds.
#[cfg(test)]
mod http_tests;