1use std::convert::Infallible;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use tokio::sync::{mpsc::Sender, watch};
5
6use crate::capture::loop_detection::LoopDetector;
7use crate::interceptor::{
8 BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
9};
10use crate::proxy::circuit_breaker::CircuitBreaker;
11use crate::proxy::http_utils::{
12 build_forward_request, create_error_response, create_initial_flow, mock_to_response,
13 parse_request_meta, update_flow_with_response_headers,
14};
15use crate::proxy::outbound::OutboundConnector;
16use crate::proxy::tap::TapBody;
17use crate::proxy::tunnel;
18use crate::proxy::websocket::handle_websocket_handshake;
19use crate::tls::CertificateAuthority;
20use http_body_util::{BodyExt, Full};
21use hyper::body::{Body, Bytes, Incoming};
22use hyper::{Method, Request, Response, StatusCode};
23use relay_core_api::flow::{Direction, FlowUpdate, Layer, ResilienceTrace};
24use relay_core_api::policy::ProxyPolicy;
25
26#[allow(clippy::too_many_arguments)]
28pub async fn handle_request(
29 req: Request<Incoming>,
30 client_addr: SocketAddr,
31 on_flow: Sender<FlowUpdate>,
32 ca: Arc<CertificateAuthority>,
33 connector: Arc<dyn OutboundConnector>,
34 interceptor: Arc<dyn Interceptor>,
35 target_addr: Option<SocketAddr>,
36 policy_rx: watch::Receiver<ProxyPolicy>,
37 loop_detector: Arc<LoopDetector>,
38 circuit_breaker: Arc<CircuitBreaker>,
39) -> Result<Response<HttpBody>, Infallible> {
40 if req.method() == Method::CONNECT {
41 let host = if let Some(authority) = req.uri().authority() {
44 authority.to_string()
45 } else {
46 req.headers()
48 .get("Host")
49 .and_then(|v| v.to_str().ok())
50 .map(|s| s.to_string())
51 .unwrap_or_else(|| "unknown".to_string())
52 };
53
54 if host == "unknown" {
55 return Ok(create_error_response(
56 StatusCode::BAD_REQUEST,
57 "CONNECT must have authority",
58 ));
59 }
60
61 let loop_detector = loop_detector.clone();
62 let policy_rx = policy_rx.clone();
63
64 tokio::task::spawn(async move {
65 match hyper::upgrade::on(req).await {
66 Ok(upgraded) => {
67 if let Err(e) = tunnel::handle_tunnel(
68 upgraded,
69 host,
70 client_addr,
71 ca,
72 on_flow,
73 connector,
74 interceptor,
75 policy_rx,
76 target_addr,
77 loop_detector,
78 circuit_breaker,
79 )
80 .await
81 {
82 tracing::error!("Tunnel error: {}", e);
83 }
84 }
85 Err(e) => tracing::error!("Upgrade error: {}", e),
86 }
87 });
88 return Ok(Response::new(
89 Full::new(Bytes::new()).map_err(|e| e.into()).boxed(),
90 ));
91 }
92
93 handle_http_request(
95 req,
96 client_addr,
97 on_flow,
98 connector,
99 interceptor,
100 false,
101 policy_rx,
102 target_addr,
103 loop_detector,
104 circuit_breaker,
105 )
106 .await
107}
108
109#[allow(clippy::too_many_arguments)]
110pub(crate) async fn handle_http_request<B>(
111 req: Request<B>,
112 client_addr: SocketAddr,
113 on_flow: Sender<FlowUpdate>,
114 connector: Arc<dyn OutboundConnector>,
115 interceptor: Arc<dyn Interceptor>,
116 is_mitm: bool,
117 policy_rx: watch::Receiver<ProxyPolicy>,
118 target_addr: Option<SocketAddr>,
119 loop_detector: Arc<LoopDetector>,
120 circuit_breaker: Arc<CircuitBreaker>,
121) -> Result<Response<HttpBody>, Infallible>
122where
123 B: Body + Send + Sync + Unpin + 'static,
124 B::Data: Send + Into<Bytes>,
125 B::Error: Into<BoxError>,
126{
127 let policy = policy_rx.borrow().clone();
128
129 let request_budget_exceeded = if let Some(cl) = req.headers().get(hyper::header::CONTENT_LENGTH)
133 && let Ok(len) = cl.to_str().unwrap_or_default().parse::<usize>()
134 && len > policy.max_body_size
135 {
136 true
137 } else {
138 false
139 };
140
141 let meta = parse_request_meta(&req, is_mitm);
143
144 let mut flow = create_initial_flow(meta, None, client_addr, is_mitm, false);
146
147 if request_budget_exceeded {
149 flow.tags.push("budget-exceeded".to_string());
150 flow.resilience_trace = Some(ResilienceTrace {
151 budget_exceeded: true,
152 ..flow.resilience_trace.clone().unwrap_or_default()
153 });
154 }
155
156 if hyper_tungstenite::is_upgrade_request(&req) {
158 return handle_websocket_handshake(
159 req,
160 client_addr,
161 on_flow,
162 connector,
163 interceptor,
164 is_mitm,
165 policy_rx,
166 target_addr,
167 loop_detector,
168 )
169 .await;
170 }
171
172 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
173 tracing::error!("Failed to send flow update: {}", e);
174 }
175
176 match interceptor.on_request_headers(&mut flow).await {
178 InterceptionResult::Continue => {}
179 InterceptionResult::Drop => {
180 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
181 tracing::error!("Failed to send flow update on drop: {}", e);
182 }
183 return Ok(create_error_response(
184 StatusCode::FORBIDDEN,
185 "Request dropped by policy",
186 ));
187 }
188 InterceptionResult::MockResponse(resp) => {
189 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
190 tracing::error!("Failed to send flow update on mock: {}", e);
191 }
192 return Ok(mock_to_response(resp));
193 }
194 InterceptionResult::ModifiedRequest(_) => {}
195 InterceptionResult::ModifiedResponse(res) => {
196 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
197 tracing::error!("Failed to send flow update on modified response: {}", e);
198 }
199 return Ok(mock_to_response(res));
200 }
201 _ => {}
202 }
203
204 let (_, body) = req.into_parts();
206 let body: HttpBody = body
207 .map_frame(|f| f.map_data(|d| d.into()))
208 .map_err(|e| e.into())
209 .boxed();
210
211 let req_headers = if let Layer::Http(http) = &flow.layer {
213 http.request.headers.clone()
214 } else {
215 vec![]
216 };
217
218 let tap_body = TapBody::new(
219 body,
220 flow.id.to_string(),
221 on_flow.clone(),
222 Direction::ClientToServer,
223 policy.max_body_size,
224 req_headers,
225 );
226 crate::metrics::inc_proxy_http_request();
227 let mut current_body = tap_body.boxed();
228
229 match interceptor.on_request(&mut flow, current_body).await {
230 Ok(RequestAction::Continue(new_body)) => {
231 current_body = new_body;
232 }
233 Ok(RequestAction::Drop) => {
234 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
235 tracing::error!("Failed to send flow update on request drop: {}", e);
236 }
237 return Ok(create_error_response(
238 StatusCode::FORBIDDEN,
239 "Request dropped by interceptor",
240 ));
241 }
242 Ok(RequestAction::MockResponse(res)) => {
243 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
244 tracing::error!("Failed to send flow update on request mock: {}", e);
245 }
246 let (parts, body) = res.into_parts();
247 return Ok(Response::from_parts(parts, body));
248 }
249 Err(e) => {
250 tracing::error!("Interceptor error on_request: {}", e);
251 return Ok(create_error_response(
252 StatusCode::INTERNAL_SERVER_ERROR,
253 format!("Interceptor Error: {}", e),
254 ));
255 }
256 }
257
258 if let Some(bps_str) = flow.meta.get("throttle_bytes_per_sec")
260 && let Ok(bps) = bps_str.parse::<u64>()
261 && bps > 0
262 {
263 current_body = crate::proxy::throttle::ThrottleBody::new(current_body, bps).boxed();
264 }
265
266 let forward_req = match build_forward_request(
267 &mut flow,
268 current_body,
269 target_addr,
270 &policy,
271 &loop_detector,
272 ) {
273 Ok(req) => req,
274 Err(res) => return Ok(res),
275 };
276
277 let circuit_breaker_key = connector
281 .upstream_proxy_url()
282 .map(|u| u.to_string())
283 .unwrap_or_else(|| {
284 forward_req
285 .uri()
286 .authority()
287 .map(|a| a.to_string())
288 .unwrap_or_else(|| "unknown".to_string())
289 });
290 if !circuit_breaker.allow_request(&circuit_breaker_key).await {
291 tracing::warn!(
292 "Circuit breaker open for upstream {}, returning 503",
293 circuit_breaker_key
294 );
295 flow.resilience_trace = Some(ResilienceTrace {
297 circuit_open: true,
298 ..flow.resilience_trace.clone().unwrap_or_default()
299 });
300 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
301 tracing::error!("Failed to send flow update on circuit breaker: {}", e);
302 }
303 return Ok(create_error_response(
304 StatusCode::SERVICE_UNAVAILABLE,
305 format!("Circuit breaker open for upstream {}", circuit_breaker_key),
306 ));
307 }
308
309 let upstream_start = std::time::Instant::now();
311 let target_host = forward_req.uri().host().unwrap_or("unknown").to_string();
312 let target_port = forward_req.uri().port_u16().unwrap_or(
313 if forward_req.uri().scheme_str() == Some("https") {
314 443
315 } else {
316 80
317 },
318 );
319 let res = match tokio::time::timeout(
320 std::time::Duration::from_millis(policy.request_timeout_ms),
321 connector.send_request(forward_req, &target_host, target_port, &mut flow),
322 )
323 .await
324 {
325 Ok(Ok(res)) => {
326 circuit_breaker.record_success(&circuit_breaker_key).await;
327 res
328 }
329 Ok(Err(e)) => {
330 circuit_breaker.record_failure(&circuit_breaker_key).await;
331 tracing::error!("Upstream request failed: {}", e);
332 flow.resilience_trace = Some(ResilienceTrace {
334 upstream_errors: vec![format!("Upstream Error: {}", e)],
335 ..flow.resilience_trace.clone().unwrap_or_default()
336 });
337 if let Layer::Http(http) = &mut flow.layer {
338 http.error = Some(format!("Upstream Error: {}", e));
339 }
340 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
341 tracing::error!("Failed to send flow update on upstream error: {}", e);
342 }
343 return Ok(create_error_response(
344 StatusCode::BAD_GATEWAY,
345 format!("Upstream Error: {}", e),
346 ));
347 }
348 Err(_) => {
349 circuit_breaker.record_failure(&circuit_breaker_key).await;
350 tracing::error!("Upstream request timed out");
351 flow.resilience_trace = Some(ResilienceTrace {
356 upstream_errors: vec!["Upstream Request Timed Out".to_string()],
357 timeout_type: Some("total".to_string()),
358 ..flow.resilience_trace.clone().unwrap_or_default()
359 });
360 if let Layer::Http(http) = &mut flow.layer {
361 http.error = Some("Upstream Request Timed Out".to_string());
362 }
363 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
364 tracing::error!("Failed to send flow update on upstream timeout: {}", e);
365 }
366 return Ok(create_error_response(
367 StatusCode::GATEWAY_TIMEOUT,
368 "Upstream Request Timed Out",
369 ));
370 }
371 };
372
373 let (mut res_parts, res_body) = res.into_parts();
375
376 apply_quic_downgrade(&mut res_parts, &mut flow, &policy);
378
379 update_flow_with_response_headers(
380 &mut flow,
381 res_parts.status,
382 res_parts.version,
383 &res_parts.headers,
384 );
385
386 let ttfbs_ms = upstream_start.elapsed().as_millis() as u64;
387 if let Layer::Http(http) = &mut flow.layer
388 && let Some(response) = &mut http.response
389 {
390 response.timing.time_to_first_byte = Some(ttfbs_ms);
391 }
392
393 match interceptor.on_response_headers(&mut flow).await {
394 InterceptionResult::Continue => {}
395 InterceptionResult::Drop => {
396 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
397 tracing::error!("Failed to send flow update on response drop: {}", e);
398 }
399 return Ok(create_error_response(
400 StatusCode::FORBIDDEN,
401 "Response dropped by policy",
402 ));
403 }
404 InterceptionResult::MockResponse(resp) => {
405 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
406 tracing::error!("Failed to send flow update on response mock: {}", e);
407 }
408 return Ok(mock_to_response(resp));
409 }
410 InterceptionResult::ModifiedResponse(resp) => {
411 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
412 tracing::error!("Failed to send flow update on response modification: {}", e);
413 }
414 return Ok(mock_to_response(resp));
415 }
416 _ => {}
417 }
418
419 let res_body: HttpBody = res_body
421 .map_frame(|f| f.map_data(|d| d))
422 .map_err(|e| e.into())
423 .boxed();
424
425 let res_headers = if let Layer::Http(http) = &flow.layer {
427 http.response
428 .as_ref()
429 .map(|r| r.headers.clone())
430 .unwrap_or_default()
431 } else {
432 vec![]
433 };
434
435 let tap_res_body = TapBody::new(
436 res_body,
437 flow.id.to_string(),
438 on_flow.clone(),
439 Direction::ServerToClient,
440 policy.max_body_size,
441 res_headers,
442 );
443 let mut current_res_body = tap_res_body.boxed();
444
445 match interceptor.on_response(&mut flow, current_res_body).await {
446 Ok(ResponseAction::Continue(new_body)) => {
447 current_res_body = new_body;
448 }
449 Ok(ResponseAction::Drop) => {
450 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
451 tracing::error!("Failed to send flow update on response body drop: {}", e);
452 }
453 return Ok(create_error_response(
454 StatusCode::FORBIDDEN,
455 "Response dropped by interceptor",
456 ));
457 }
458 Ok(ResponseAction::ModifiedResponse(res)) => {
459 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
460 tracing::error!(
461 "Failed to send flow update on response body modification: {}",
462 e
463 );
464 }
465 let (parts, body) = res.into_parts();
466 return Ok(Response::from_parts(parts, body));
467 }
468 Err(e) => {
469 tracing::error!("Interceptor error on_response: {}", e);
470 return Ok(create_error_response(
471 StatusCode::INTERNAL_SERVER_ERROR,
472 format!("Interceptor Error: {}", e),
473 ));
474 }
475 }
476
477 if let Some(bps_str) = flow.meta.get("throttle_bytes_per_sec")
479 && let Ok(bps) = bps_str.parse::<u64>()
480 && bps > 0
481 {
482 current_res_body = crate::proxy::throttle::ThrottleBody::new(current_res_body, bps).boxed();
483 }
484
485 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
486 tracing::error!("Failed to send final flow update: {}", e);
487 }
488
489 if let Layer::Http(http) = &mut flow.layer
491 && let Some(response) = &mut http.response
492 {
493 response.timing.time_to_last_byte = Some(upstream_start.elapsed().as_millis() as u64);
494 }
495
496 Ok(Response::from_parts(res_parts, current_res_body))
497}
498
499pub(crate) fn apply_quic_downgrade(
500 parts: &mut hyper::http::response::Parts,
501 flow: &mut relay_core_api::flow::Flow,
502 policy: &ProxyPolicy,
503) {
504 use relay_core_api::policy::QuicMode;
505 if policy.quic_mode == QuicMode::Downgrade {
506 if parts.headers.remove("Alt-Svc").is_some() {
507 flow.tags.push("quic-downgraded".to_string());
508 }
509 if policy.quic_downgrade_clear_cache {
510 parts.headers.insert(
511 "Clear-Site-Data",
512 hyper::header::HeaderValue::from_static("\"cache\""),
513 );
514 }
515 }
516}
517
518#[cfg(test)]
519mod http_tests;