//! HTTP proxy request handling for `relay_core_lib/proxy/http.rs`: CONNECT tunnelling and
//! plain / MITM HTTP request forwarding.

use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;

use http_body_util::{BodyExt, Full};
use hyper::body::{Body, Bytes, Incoming};
use hyper::{Method, Request, Response, StatusCode};
use relay_core_api::flow::{Direction, FlowUpdate, Layer};
use relay_core_api::policy::ProxyPolicy;
use tokio::sync::{mpsc::Sender, watch};

use crate::capture::loop_detection::LoopDetector;
use crate::interceptor::{
    BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
};
use crate::proxy::http_utils::{
    build_forward_request, create_error_response, create_initial_flow, mock_to_response,
    parse_request_meta, update_flow_with_response_headers, HttpsClient,
};
use crate::proxy::tap::TapBody;
use crate::proxy::tunnel;
use crate::proxy::websocket::handle_websocket_handshake;
use crate::tls::CertificateAuthority;

23/// Main entry point for HTTP Proxy handling
24#[allow(clippy::too_many_arguments)]
25pub async fn handle_request(
26    req: Request<Incoming>,
27    client_addr: SocketAddr,
28    on_flow: Sender<FlowUpdate>,
29    ca: Arc<CertificateAuthority>,
30    client: Arc<HttpsClient>,
31    interceptor: Arc<dyn Interceptor>,
32    target_addr: Option<SocketAddr>,
33    policy_rx: watch::Receiver<ProxyPolicy>,
34    loop_detector: Arc<LoopDetector>,
35) -> Result<Response<HttpBody>, Infallible>
36{
37    if req.method() == Method::CONNECT {
38        // Handle CONNECT (HTTPS Tunnel)
39        // Extract host from authority
40        let host = if let Some(authority) = req.uri().authority() {
41            authority.to_string()
42        } else {
43            // Fallback: try to get from Host header
44             req.headers().get("Host")
45                .and_then(|v| v.to_str().ok())
46                .map(|s| s.to_string())
47                .unwrap_or_else(|| "unknown".to_string())
48        };
49
50        if host == "unknown" {
51             return Ok(create_error_response(StatusCode::BAD_REQUEST, "CONNECT must have authority"));
52        }
53
54        let loop_detector = loop_detector.clone();
55        let policy_rx = policy_rx.clone();
56
57        tokio::task::spawn(async move {
58            match hyper::upgrade::on(req).await {
59                Ok(upgraded) => {
60                    if let Err(e) = tunnel::handle_tunnel(
61                        upgraded,
62                        host,
63                        client_addr,
64                        ca,
65                        on_flow,
66                        client,
67                        interceptor,
68                        policy_rx,
69                        target_addr,
70                        loop_detector,
71                    ).await {
72                        tracing::error!("Tunnel error: {}", e);
73                    }
74                },
75                Err(e) => tracing::error!("Upgrade error: {}", e),
76            }
77        });
78        return Ok(Response::new(Full::new(Bytes::new()).map_err(|e| e.into()).boxed()));
79    }
80
81    // Handle Standard HTTP / WebSocket
82    handle_http_request(req, client_addr, on_flow, client, interceptor, false, policy_rx, target_addr, loop_detector).await
83}
84
85#[allow(clippy::too_many_arguments)]
86pub(crate) async fn handle_http_request<B>(
87    req: Request<B>,
88    client_addr: SocketAddr,
89    on_flow: Sender<FlowUpdate>,
90    client: Arc<HttpsClient>,
91    interceptor: Arc<dyn Interceptor>,
92    is_mitm: bool,
93    policy_rx: watch::Receiver<ProxyPolicy>,
94    target_addr: Option<SocketAddr>,
95    loop_detector: Arc<LoopDetector>,
96) -> Result<Response<HttpBody>, Infallible>
97where
98    B: Body + Send + Sync + Unpin + 'static,
99    B::Data: Send + Into<Bytes>,
100    B::Error: Into<BoxError>,
101{
102    let policy = policy_rx.borrow().clone();
103    
104    // Check Content-Length against policy
105    if let Some(cl) = req.headers().get(hyper::header::CONTENT_LENGTH) {
106        if let Ok(len) = cl.to_str().unwrap_or_default().parse::<usize>() {
107            if len > policy.max_body_size {
108                return Ok(create_error_response(StatusCode::PAYLOAD_TOO_LARGE, "Request body too large"));
109            }
110        }
111    }
112
113    // Create Flow
114    let meta = parse_request_meta(&req, is_mitm);
115    
116    // Note: We don't read body here for streaming support
117    let mut flow = create_initial_flow(meta, None, client_addr, is_mitm, false);
118    
119    // Check for WebSocket
120    if hyper_tungstenite::is_upgrade_request(&req) {
121        return handle_websocket_handshake(req, client_addr, on_flow, client, interceptor, is_mitm, policy_rx, target_addr, loop_detector).await;
122    }
123
124    if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
125        tracing::error!("Failed to send flow update: {}", e);
126    }
127
128    // Phase 1: Request Headers Interception
129    match interceptor.on_request_headers(&mut flow).await {
130        InterceptionResult::Continue => {},
131        InterceptionResult::Drop => {
132             if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
133                 tracing::error!("Failed to send flow update on drop: {}", e);
134             }
135             return Ok(create_error_response(StatusCode::FORBIDDEN, "Request dropped by policy"));
136        },
137        InterceptionResult::MockResponse(resp) => {
138             if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
139                 tracing::error!("Failed to send flow update on mock: {}", e);
140             }
141             return Ok(mock_to_response(resp));
142        },
143        InterceptionResult::ModifiedRequest(_) => {},
144        InterceptionResult::ModifiedResponse(res) => {
145             if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
146                 tracing::error!("Failed to send flow update on modified response: {}", e);
147             }
148             return Ok(mock_to_response(res));
149        },
150        _ => {}
151    }
152
153    // Phase 2: Request Body Streaming & Interception
154    let (parts, body) = req.into_parts();
155    let body: HttpBody = body.map_frame(|f| f.map_data(|d| d.into())).map_err(|e| e.into()).boxed();
156    
157    // Wrap in TapBody for streaming visualization BEFORE interception
158    let req_headers = if let Layer::Http(http) = &flow.layer {
159        http.request.headers.clone()
160    } else {
161        vec![]
162    };
163
164    let tap_body = TapBody::new(
165        body,
166        flow.id.to_string(),
167        on_flow.clone(),
168        Direction::ClientToServer,
169        policy.max_body_size,
170        req_headers,
171    );
172    let mut current_body = tap_body.boxed();
173    
174    match interceptor.on_request(&mut flow, current_body).await {
175        Ok(RequestAction::Continue(new_body)) => {
176            current_body = new_body;
177        },
178        Ok(RequestAction::Drop) => {
179             if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
180                 tracing::error!("Failed to send flow update on request drop: {}", e);
181             }
182             return Ok(create_error_response(StatusCode::FORBIDDEN, "Request dropped by interceptor"));
183        },
184        Ok(RequestAction::MockResponse(res)) => {
185             if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
186                 tracing::error!("Failed to send flow update on request mock: {}", e);
187             }
188             let (parts, body) = res.into_parts();
189             return Ok(Response::from_parts(parts, body));
190        },
191        Err(e) => {
192             tracing::error!("Interceptor error on_request: {}", e);
193             return Ok(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, format!("Interceptor Error: {}", e)));
194        }
195    }
196    
197    let forward_req = match build_forward_request(&mut flow, current_body, &parts, target_addr, &policy, &loop_detector) {
198        Ok(req) => req,
199        Err(res) => return Ok(res),
200    };
201    
202    // Send Request
203    let res = match tokio::time::timeout(std::time::Duration::from_millis(policy.request_timeout_ms), client.request(forward_req)).await {
204        Ok(Ok(res)) => res,
205        Ok(Err(e)) => {
206            tracing::error!("Upstream request failed: {}", e);
207             if let Layer::Http(http) = &mut flow.layer {
208                http.error = Some(format!("Upstream Error: {}", e));
209            }
210            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
211                tracing::error!("Failed to send flow update on upstream error: {}", e);
212            }
213            return Ok(create_error_response(StatusCode::BAD_GATEWAY, format!("Upstream Error: {}", e)));
214        },
215        Err(_) => {
216            tracing::error!("Upstream request timed out");
217             if let Layer::Http(http) = &mut flow.layer {
218                http.error = Some("Upstream Request Timed Out".to_string());
219            }
220            if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
221                tracing::error!("Failed to send flow update on upstream timeout: {}", e);
222            }
223            return Ok(create_error_response(StatusCode::GATEWAY_TIMEOUT, "Upstream Request Timed Out"));
224        }
225    };
226    
227    // Phase 3: Response Headers Interception
228    let (mut res_parts, res_body) = res.into_parts();
229
230    // Apply QUIC Downgrade
231    apply_quic_downgrade(&mut res_parts, &mut flow, &policy);
232    
233    update_flow_with_response_headers(&mut flow, res_parts.status, res_parts.version, &res_parts.headers);
234    
235    match interceptor.on_response_headers(&mut flow).await {
236        InterceptionResult::Continue => {},
237        InterceptionResult::Drop => {
238             if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
239                 tracing::error!("Failed to send flow update on response drop: {}", e);
240             }
241             return Ok(create_error_response(StatusCode::FORBIDDEN, "Response dropped by policy"));
242        },
243        InterceptionResult::MockResponse(resp) => {
244             if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
245                 tracing::error!("Failed to send flow update on response mock: {}", e);
246             }
247             return Ok(mock_to_response(resp));
248        },
249        InterceptionResult::ModifiedResponse(resp) => {
250             if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
251                 tracing::error!("Failed to send flow update on response modification: {}", e);
252             }
253             return Ok(mock_to_response(resp));
254        },
255        _ => {}
256    }
257    
258    // Phase 4: Response Body Streaming & Interception
259    let res_body: HttpBody = res_body.map_frame(|f| f.map_data(|d| d)).map_err(|e| e.into()).boxed();
260    
261    // Wrap in TapBody for streaming visualization BEFORE interception
262    let res_headers = if let Layer::Http(http) = &flow.layer {
263        http.response.as_ref().map(|r| r.headers.clone()).unwrap_or_default()
264    } else {
265        vec![]
266    };
267
268    let tap_res_body = TapBody::new(
269        res_body,
270        flow.id.to_string(),
271        on_flow.clone(),
272        Direction::ServerToClient,
273        policy.max_body_size,
274        res_headers,
275    );
276    let mut current_res_body = tap_res_body.boxed();
277
278    match interceptor.on_response(&mut flow, current_res_body).await {
279        Ok(ResponseAction::Continue(new_body)) => {
280            current_res_body = new_body;
281        },
282        Ok(ResponseAction::Drop) => {
283             if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
284                 tracing::error!("Failed to send flow update on response body drop: {}", e);
285             }
286             return Ok(create_error_response(StatusCode::FORBIDDEN, "Response dropped by interceptor"));
287        },
288        Ok(ResponseAction::ModifiedResponse(res)) => {
289             if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
290                 tracing::error!("Failed to send flow update on response body modification: {}", e);
291             }
292             let (parts, body) = res.into_parts();
293             return Ok(Response::from_parts(parts, body));
294        },
295        Err(e) => {
296             tracing::error!("Interceptor error on_response: {}", e);
297             return Ok(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, format!("Interceptor Error: {}", e)));
298        }
299    }
300
301    if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
302        tracing::error!("Failed to send final flow update: {}", e);
303    }
304    
305    Ok(Response::from_parts(res_parts, current_res_body))
306}
307
308pub(crate) fn apply_quic_downgrade(parts: &mut hyper::http::response::Parts, flow: &mut relay_core_api::flow::Flow, policy: &ProxyPolicy) {
309    use relay_core_api::policy::QuicMode;
310    if policy.quic_mode == QuicMode::Downgrade {
311         if parts.headers.remove("Alt-Svc").is_some() {
312             flow.tags.push("quic-downgraded".to_string());
313         }
314         if policy.quic_downgrade_clear_cache {
315             parts.headers.insert("Clear-Site-Data", hyper::header::HeaderValue::from_static("\"cache\""));
316        }
317    }
318}
319
// Unit tests for this module; compiled only for test builds.
#[cfg(test)]
mod http_tests;