1use std::convert::Infallible;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use tokio::sync::{mpsc::Sender, watch};
5
6use crate::capture::loop_detection::LoopDetector;
7use crate::interceptor::{
8 BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
9};
10use crate::proxy::circuit_breaker::CircuitBreaker;
11use crate::proxy::http_utils::{
12 HttpsClient, build_forward_request, create_error_response, create_initial_flow,
13 mock_to_response, parse_request_meta, update_flow_with_response_headers,
14};
15use crate::proxy::tap::TapBody;
16use crate::proxy::tunnel;
17use crate::proxy::websocket::handle_websocket_handshake;
18use crate::tls::CertificateAuthority;
19use http_body_util::{BodyExt, Full};
20use hyper::body::{Body, Bytes, Incoming};
21use hyper::{Method, Request, Response, StatusCode};
22use relay_core_api::flow::{Direction, FlowUpdate, Layer};
23use relay_core_api::policy::ProxyPolicy;
24
25#[allow(clippy::too_many_arguments)]
27pub async fn handle_request(
28 req: Request<Incoming>,
29 client_addr: SocketAddr,
30 on_flow: Sender<FlowUpdate>,
31 ca: Arc<CertificateAuthority>,
32 client: Arc<HttpsClient>,
33 interceptor: Arc<dyn Interceptor>,
34 target_addr: Option<SocketAddr>,
35 policy_rx: watch::Receiver<ProxyPolicy>,
36 loop_detector: Arc<LoopDetector>,
37 circuit_breaker: Arc<CircuitBreaker>,
38) -> Result<Response<HttpBody>, Infallible> {
39 if req.method() == Method::CONNECT {
40 let host = if let Some(authority) = req.uri().authority() {
43 authority.to_string()
44 } else {
45 req.headers()
47 .get("Host")
48 .and_then(|v| v.to_str().ok())
49 .map(|s| s.to_string())
50 .unwrap_or_else(|| "unknown".to_string())
51 };
52
53 if host == "unknown" {
54 return Ok(create_error_response(
55 StatusCode::BAD_REQUEST,
56 "CONNECT must have authority",
57 ));
58 }
59
60 let loop_detector = loop_detector.clone();
61 let policy_rx = policy_rx.clone();
62
63 tokio::task::spawn(async move {
64 match hyper::upgrade::on(req).await {
65 Ok(upgraded) => {
66 if let Err(e) = tunnel::handle_tunnel(
67 upgraded,
68 host,
69 client_addr,
70 ca,
71 on_flow,
72 client,
73 interceptor,
74 policy_rx,
75 target_addr,
76 loop_detector,
77 circuit_breaker,
78 )
79 .await
80 {
81 tracing::error!("Tunnel error: {}", e);
82 }
83 }
84 Err(e) => tracing::error!("Upgrade error: {}", e),
85 }
86 });
87 return Ok(Response::new(
88 Full::new(Bytes::new()).map_err(|e| e.into()).boxed(),
89 ));
90 }
91
92 handle_http_request(
94 req,
95 client_addr,
96 on_flow,
97 client,
98 interceptor,
99 false,
100 policy_rx,
101 target_addr,
102 loop_detector,
103 circuit_breaker,
104 )
105 .await
106}
107
108#[allow(clippy::too_many_arguments)]
109pub(crate) async fn handle_http_request<B>(
110 req: Request<B>,
111 client_addr: SocketAddr,
112 on_flow: Sender<FlowUpdate>,
113 client: Arc<HttpsClient>,
114 interceptor: Arc<dyn Interceptor>,
115 is_mitm: bool,
116 policy_rx: watch::Receiver<ProxyPolicy>,
117 target_addr: Option<SocketAddr>,
118 loop_detector: Arc<LoopDetector>,
119 circuit_breaker: Arc<CircuitBreaker>,
120) -> Result<Response<HttpBody>, Infallible>
121where
122 B: Body + Send + Sync + Unpin + 'static,
123 B::Data: Send + Into<Bytes>,
124 B::Error: Into<BoxError>,
125{
126 let policy = policy_rx.borrow().clone();
127
128 if let Some(cl) = req.headers().get(hyper::header::CONTENT_LENGTH)
130 && let Ok(len) = cl.to_str().unwrap_or_default().parse::<usize>()
131 && len > policy.max_body_size
132 {
133 return Ok(create_error_response(
134 StatusCode::PAYLOAD_TOO_LARGE,
135 "Request body too large",
136 ));
137 }
138
139 let meta = parse_request_meta(&req, is_mitm);
141
142 let mut flow = create_initial_flow(meta, None, client_addr, is_mitm, false);
144
145 if hyper_tungstenite::is_upgrade_request(&req) {
147 return handle_websocket_handshake(
148 req,
149 client_addr,
150 on_flow,
151 client,
152 interceptor,
153 is_mitm,
154 policy_rx,
155 target_addr,
156 loop_detector,
157 )
158 .await;
159 }
160
161 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
162 tracing::error!("Failed to send flow update: {}", e);
163 }
164
165 match interceptor.on_request_headers(&mut flow).await {
167 InterceptionResult::Continue => {}
168 InterceptionResult::Drop => {
169 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
170 tracing::error!("Failed to send flow update on drop: {}", e);
171 }
172 return Ok(create_error_response(
173 StatusCode::FORBIDDEN,
174 "Request dropped by policy",
175 ));
176 }
177 InterceptionResult::MockResponse(resp) => {
178 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
179 tracing::error!("Failed to send flow update on mock: {}", e);
180 }
181 return Ok(mock_to_response(resp));
182 }
183 InterceptionResult::ModifiedRequest(_) => {}
184 InterceptionResult::ModifiedResponse(res) => {
185 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
186 tracing::error!("Failed to send flow update on modified response: {}", e);
187 }
188 return Ok(mock_to_response(res));
189 }
190 _ => {}
191 }
192
193 let (_, body) = req.into_parts();
195 let body: HttpBody = body
196 .map_frame(|f| f.map_data(|d| d.into()))
197 .map_err(|e| e.into())
198 .boxed();
199
200 let req_headers = if let Layer::Http(http) = &flow.layer {
202 http.request.headers.clone()
203 } else {
204 vec![]
205 };
206
207 let tap_body = TapBody::new(
208 body,
209 flow.id.to_string(),
210 on_flow.clone(),
211 Direction::ClientToServer,
212 policy.max_body_size,
213 req_headers,
214 );
215 let mut current_body = tap_body.boxed();
216
217 match interceptor.on_request(&mut flow, current_body).await {
218 Ok(RequestAction::Continue(new_body)) => {
219 current_body = new_body;
220 }
221 Ok(RequestAction::Drop) => {
222 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
223 tracing::error!("Failed to send flow update on request drop: {}", e);
224 }
225 return Ok(create_error_response(
226 StatusCode::FORBIDDEN,
227 "Request dropped by interceptor",
228 ));
229 }
230 Ok(RequestAction::MockResponse(res)) => {
231 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
232 tracing::error!("Failed to send flow update on request mock: {}", e);
233 }
234 let (parts, body) = res.into_parts();
235 return Ok(Response::from_parts(parts, body));
236 }
237 Err(e) => {
238 tracing::error!("Interceptor error on_request: {}", e);
239 return Ok(create_error_response(
240 StatusCode::INTERNAL_SERVER_ERROR,
241 format!("Interceptor Error: {}", e),
242 ));
243 }
244 }
245
246 let forward_req = match build_forward_request(
247 &mut flow,
248 current_body,
249 target_addr,
250 &policy,
251 &loop_detector,
252 ) {
253 Ok(req) => req,
254 Err(res) => return Ok(res),
255 };
256
257 let upstream_host = forward_req
259 .uri()
260 .authority()
261 .map(|a| a.to_string())
262 .unwrap_or_else(|| "unknown".to_string());
263 if !circuit_breaker.allow_request(&upstream_host).await {
264 tracing::warn!(
265 "Circuit breaker open for upstream {}, returning 503",
266 upstream_host
267 );
268 return Ok(create_error_response(
269 StatusCode::SERVICE_UNAVAILABLE,
270 format!("Circuit breaker open for upstream {}", upstream_host),
271 ));
272 }
273
274 let upstream_start = std::time::Instant::now();
276 let res = match tokio::time::timeout(
277 std::time::Duration::from_millis(policy.request_timeout_ms),
278 client.request(forward_req),
279 )
280 .await
281 {
282 Ok(Ok(res)) => {
283 circuit_breaker.record_success(&upstream_host).await;
284 res
285 }
286 Ok(Err(e)) => {
287 circuit_breaker.record_failure(&upstream_host).await;
288 tracing::error!("Upstream request failed: {}", e);
289 if let Layer::Http(http) = &mut flow.layer {
290 http.error = Some(format!("Upstream Error: {}", e));
291 }
292 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
293 tracing::error!("Failed to send flow update on upstream error: {}", e);
294 }
295 return Ok(create_error_response(
296 StatusCode::BAD_GATEWAY,
297 format!("Upstream Error: {}", e),
298 ));
299 }
300 Err(_) => {
301 circuit_breaker.record_failure(&upstream_host).await;
302 tracing::error!("Upstream request timed out");
303 if let Layer::Http(http) = &mut flow.layer {
304 http.error = Some("Upstream Request Timed Out".to_string());
305 }
306 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
307 tracing::error!("Failed to send flow update on upstream timeout: {}", e);
308 }
309 return Ok(create_error_response(
310 StatusCode::GATEWAY_TIMEOUT,
311 "Upstream Request Timed Out",
312 ));
313 }
314 };
315
316 let (mut res_parts, res_body) = res.into_parts();
318
319 apply_quic_downgrade(&mut res_parts, &mut flow, &policy);
321
322 update_flow_with_response_headers(
323 &mut flow,
324 res_parts.status,
325 res_parts.version,
326 &res_parts.headers,
327 );
328
329 let ttfbs_ms = upstream_start.elapsed().as_millis() as u64;
330 if let Layer::Http(http) = &mut flow.layer
331 && let Some(response) = &mut http.response
332 {
333 response.timing.time_to_first_byte = Some(ttfbs_ms);
334 }
335
336 match interceptor.on_response_headers(&mut flow).await {
337 InterceptionResult::Continue => {}
338 InterceptionResult::Drop => {
339 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
340 tracing::error!("Failed to send flow update on response drop: {}", e);
341 }
342 return Ok(create_error_response(
343 StatusCode::FORBIDDEN,
344 "Response dropped by policy",
345 ));
346 }
347 InterceptionResult::MockResponse(resp) => {
348 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
349 tracing::error!("Failed to send flow update on response mock: {}", e);
350 }
351 return Ok(mock_to_response(resp));
352 }
353 InterceptionResult::ModifiedResponse(resp) => {
354 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
355 tracing::error!("Failed to send flow update on response modification: {}", e);
356 }
357 return Ok(mock_to_response(resp));
358 }
359 _ => {}
360 }
361
362 let res_body: HttpBody = res_body
364 .map_frame(|f| f.map_data(|d| d))
365 .map_err(|e| e.into())
366 .boxed();
367
368 let res_headers = if let Layer::Http(http) = &flow.layer {
370 http.response
371 .as_ref()
372 .map(|r| r.headers.clone())
373 .unwrap_or_default()
374 } else {
375 vec![]
376 };
377
378 let tap_res_body = TapBody::new(
379 res_body,
380 flow.id.to_string(),
381 on_flow.clone(),
382 Direction::ServerToClient,
383 policy.max_body_size,
384 res_headers,
385 );
386 let mut current_res_body = tap_res_body.boxed();
387
388 match interceptor.on_response(&mut flow, current_res_body).await {
389 Ok(ResponseAction::Continue(new_body)) => {
390 current_res_body = new_body;
391 }
392 Ok(ResponseAction::Drop) => {
393 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
394 tracing::error!("Failed to send flow update on response body drop: {}", e);
395 }
396 return Ok(create_error_response(
397 StatusCode::FORBIDDEN,
398 "Response dropped by interceptor",
399 ));
400 }
401 Ok(ResponseAction::ModifiedResponse(res)) => {
402 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
403 tracing::error!(
404 "Failed to send flow update on response body modification: {}",
405 e
406 );
407 }
408 let (parts, body) = res.into_parts();
409 return Ok(Response::from_parts(parts, body));
410 }
411 Err(e) => {
412 tracing::error!("Interceptor error on_response: {}", e);
413 return Ok(create_error_response(
414 StatusCode::INTERNAL_SERVER_ERROR,
415 format!("Interceptor Error: {}", e),
416 ));
417 }
418 }
419
420 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
421 tracing::error!("Failed to send final flow update: {}", e);
422 }
423
424 if let Layer::Http(http) = &mut flow.layer
426 && let Some(response) = &mut http.response
427 {
428 response.timing.time_to_last_byte = Some(upstream_start.elapsed().as_millis() as u64);
429 }
430
431 Ok(Response::from_parts(res_parts, current_res_body))
432}
433
434pub(crate) fn apply_quic_downgrade(
435 parts: &mut hyper::http::response::Parts,
436 flow: &mut relay_core_api::flow::Flow,
437 policy: &ProxyPolicy,
438) {
439 use relay_core_api::policy::QuicMode;
440 if policy.quic_mode == QuicMode::Downgrade {
441 if parts.headers.remove("Alt-Svc").is_some() {
442 flow.tags.push("quic-downgraded".to_string());
443 }
444 if policy.quic_downgrade_clear_cache {
445 parts.headers.insert(
446 "Clear-Site-Data",
447 hyper::header::HeaderValue::from_static("\"cache\""),
448 );
449 }
450 }
451}
452
453#[cfg(test)]
454mod http_tests;