1use std::convert::Infallible;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use tokio::sync::{mpsc::Sender, watch};
5
6use crate::capture::loop_detection::LoopDetector;
7use crate::interceptor::{
8 BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, ResponseAction,
9};
10use crate::proxy::circuit_breaker::CircuitBreaker;
11use crate::proxy::http_utils::{
12 HttpsClient, build_forward_request, create_error_response, create_initial_flow,
13 mock_to_response, parse_request_meta, update_flow_with_response_headers,
14};
15use crate::proxy::tap::TapBody;
16use crate::proxy::tunnel;
17use crate::proxy::websocket::handle_websocket_handshake;
18use crate::tls::CertificateAuthority;
19use http_body_util::{BodyExt, Full};
20use hyper::body::{Body, Bytes, Incoming};
21use hyper::{Method, Request, Response, StatusCode};
22use relay_core_api::flow::{Direction, FlowUpdate, Layer, ResilienceTrace};
23use relay_core_api::policy::ProxyPolicy;
24
25#[allow(clippy::too_many_arguments)]
27pub async fn handle_request(
28 req: Request<Incoming>,
29 client_addr: SocketAddr,
30 on_flow: Sender<FlowUpdate>,
31 ca: Arc<CertificateAuthority>,
32 client: Arc<HttpsClient>,
33 interceptor: Arc<dyn Interceptor>,
34 target_addr: Option<SocketAddr>,
35 policy_rx: watch::Receiver<ProxyPolicy>,
36 loop_detector: Arc<LoopDetector>,
37 circuit_breaker: Arc<CircuitBreaker>,
38) -> Result<Response<HttpBody>, Infallible> {
39 if req.method() == Method::CONNECT {
40 let host = if let Some(authority) = req.uri().authority() {
43 authority.to_string()
44 } else {
45 req.headers()
47 .get("Host")
48 .and_then(|v| v.to_str().ok())
49 .map(|s| s.to_string())
50 .unwrap_or_else(|| "unknown".to_string())
51 };
52
53 if host == "unknown" {
54 return Ok(create_error_response(
55 StatusCode::BAD_REQUEST,
56 "CONNECT must have authority",
57 ));
58 }
59
60 let loop_detector = loop_detector.clone();
61 let policy_rx = policy_rx.clone();
62
63 tokio::task::spawn(async move {
64 match hyper::upgrade::on(req).await {
65 Ok(upgraded) => {
66 if let Err(e) = tunnel::handle_tunnel(
67 upgraded,
68 host,
69 client_addr,
70 ca,
71 on_flow,
72 client,
73 interceptor,
74 policy_rx,
75 target_addr,
76 loop_detector,
77 circuit_breaker,
78 )
79 .await
80 {
81 tracing::error!("Tunnel error: {}", e);
82 }
83 }
84 Err(e) => tracing::error!("Upgrade error: {}", e),
85 }
86 });
87 return Ok(Response::new(
88 Full::new(Bytes::new()).map_err(|e| e.into()).boxed(),
89 ));
90 }
91
92 handle_http_request(
94 req,
95 client_addr,
96 on_flow,
97 client,
98 interceptor,
99 false,
100 policy_rx,
101 target_addr,
102 loop_detector,
103 circuit_breaker,
104 )
105 .await
106}
107
108#[allow(clippy::too_many_arguments)]
109pub(crate) async fn handle_http_request<B>(
110 req: Request<B>,
111 client_addr: SocketAddr,
112 on_flow: Sender<FlowUpdate>,
113 client: Arc<HttpsClient>,
114 interceptor: Arc<dyn Interceptor>,
115 is_mitm: bool,
116 policy_rx: watch::Receiver<ProxyPolicy>,
117 target_addr: Option<SocketAddr>,
118 loop_detector: Arc<LoopDetector>,
119 circuit_breaker: Arc<CircuitBreaker>,
120) -> Result<Response<HttpBody>, Infallible>
121where
122 B: Body + Send + Sync + Unpin + 'static,
123 B::Data: Send + Into<Bytes>,
124 B::Error: Into<BoxError>,
125{
126 let policy = policy_rx.borrow().clone();
127
128 let request_budget_exceeded = if let Some(cl) = req.headers().get(hyper::header::CONTENT_LENGTH)
132 && let Ok(len) = cl.to_str().unwrap_or_default().parse::<usize>()
133 && len > policy.max_body_size
134 {
135 true
136 } else {
137 false
138 };
139
140 let meta = parse_request_meta(&req, is_mitm);
142
143 let mut flow = create_initial_flow(meta, None, client_addr, is_mitm, false);
145
146 if request_budget_exceeded {
148 flow.tags.push("budget-exceeded".to_string());
149 flow.resilience_trace = Some(ResilienceTrace {
150 budget_exceeded: true,
151 ..flow.resilience_trace.clone().unwrap_or_default()
152 });
153 }
154
155 if hyper_tungstenite::is_upgrade_request(&req) {
157 return handle_websocket_handshake(
158 req,
159 client_addr,
160 on_flow,
161 client,
162 interceptor,
163 is_mitm,
164 policy_rx,
165 target_addr,
166 loop_detector,
167 )
168 .await;
169 }
170
171 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
172 tracing::error!("Failed to send flow update: {}", e);
173 }
174
175 match interceptor.on_request_headers(&mut flow).await {
177 InterceptionResult::Continue => {}
178 InterceptionResult::Drop => {
179 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
180 tracing::error!("Failed to send flow update on drop: {}", e);
181 }
182 return Ok(create_error_response(
183 StatusCode::FORBIDDEN,
184 "Request dropped by policy",
185 ));
186 }
187 InterceptionResult::MockResponse(resp) => {
188 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
189 tracing::error!("Failed to send flow update on mock: {}", e);
190 }
191 return Ok(mock_to_response(resp));
192 }
193 InterceptionResult::ModifiedRequest(_) => {}
194 InterceptionResult::ModifiedResponse(res) => {
195 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
196 tracing::error!("Failed to send flow update on modified response: {}", e);
197 }
198 return Ok(mock_to_response(res));
199 }
200 _ => {}
201 }
202
203 let (_, body) = req.into_parts();
205 let body: HttpBody = body
206 .map_frame(|f| f.map_data(|d| d.into()))
207 .map_err(|e| e.into())
208 .boxed();
209
210 let req_headers = if let Layer::Http(http) = &flow.layer {
212 http.request.headers.clone()
213 } else {
214 vec![]
215 };
216
217 let tap_body = TapBody::new(
218 body,
219 flow.id.to_string(),
220 on_flow.clone(),
221 Direction::ClientToServer,
222 policy.max_body_size,
223 req_headers,
224 );
225 crate::metrics::inc_proxy_http_request();
226 let mut current_body = tap_body.boxed();
227
228 match interceptor.on_request(&mut flow, current_body).await {
229 Ok(RequestAction::Continue(new_body)) => {
230 current_body = new_body;
231 }
232 Ok(RequestAction::Drop) => {
233 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
234 tracing::error!("Failed to send flow update on request drop: {}", e);
235 }
236 return Ok(create_error_response(
237 StatusCode::FORBIDDEN,
238 "Request dropped by interceptor",
239 ));
240 }
241 Ok(RequestAction::MockResponse(res)) => {
242 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
243 tracing::error!("Failed to send flow update on request mock: {}", e);
244 }
245 let (parts, body) = res.into_parts();
246 return Ok(Response::from_parts(parts, body));
247 }
248 Err(e) => {
249 tracing::error!("Interceptor error on_request: {}", e);
250 return Ok(create_error_response(
251 StatusCode::INTERNAL_SERVER_ERROR,
252 format!("Interceptor Error: {}", e),
253 ));
254 }
255 }
256
257 if let Some(bps_str) = flow.meta.get("throttle_bytes_per_sec")
259 && let Ok(bps) = bps_str.parse::<u64>()
260 && bps > 0
261 {
262 current_body = crate::proxy::throttle::ThrottleBody::new(current_body, bps).boxed();
263 }
264
265 let forward_req = match build_forward_request(
266 &mut flow,
267 current_body,
268 target_addr,
269 &policy,
270 &loop_detector,
271 ) {
272 Ok(req) => req,
273 Err(res) => return Ok(res),
274 };
275
276 let upstream_host = forward_req
278 .uri()
279 .authority()
280 .map(|a| a.to_string())
281 .unwrap_or_else(|| "unknown".to_string());
282 if !circuit_breaker.allow_request(&upstream_host).await {
283 tracing::warn!(
284 "Circuit breaker open for upstream {}, returning 503",
285 upstream_host
286 );
287 flow.resilience_trace = Some(ResilienceTrace {
289 circuit_open: true,
290 ..flow.resilience_trace.clone().unwrap_or_default()
291 });
292 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
293 tracing::error!("Failed to send flow update on circuit breaker: {}", e);
294 }
295 return Ok(create_error_response(
296 StatusCode::SERVICE_UNAVAILABLE,
297 format!("Circuit breaker open for upstream {}", upstream_host),
298 ));
299 }
300
301 let upstream_start = std::time::Instant::now();
303 let res = match tokio::time::timeout(
304 std::time::Duration::from_millis(policy.request_timeout_ms),
305 client.request(forward_req),
306 )
307 .await
308 {
309 Ok(Ok(res)) => {
310 circuit_breaker.record_success(&upstream_host).await;
311 res
312 }
313 Ok(Err(e)) => {
314 circuit_breaker.record_failure(&upstream_host).await;
315 tracing::error!("Upstream request failed: {}", e);
316 flow.resilience_trace = Some(ResilienceTrace {
318 upstream_errors: vec![format!("Upstream Error: {}", e)],
319 ..flow.resilience_trace.clone().unwrap_or_default()
320 });
321 if let Layer::Http(http) = &mut flow.layer {
322 http.error = Some(format!("Upstream Error: {}", e));
323 }
324 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
325 tracing::error!("Failed to send flow update on upstream error: {}", e);
326 }
327 return Ok(create_error_response(
328 StatusCode::BAD_GATEWAY,
329 format!("Upstream Error: {}", e),
330 ));
331 }
332 Err(_) => {
333 circuit_breaker.record_failure(&upstream_host).await;
334 tracing::error!("Upstream request timed out");
335 flow.resilience_trace = Some(ResilienceTrace {
340 upstream_errors: vec!["Upstream Request Timed Out".to_string()],
341 timeout_type: Some("total".to_string()),
342 ..flow.resilience_trace.clone().unwrap_or_default()
343 });
344 if let Layer::Http(http) = &mut flow.layer {
345 http.error = Some("Upstream Request Timed Out".to_string());
346 }
347 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
348 tracing::error!("Failed to send flow update on upstream timeout: {}", e);
349 }
350 return Ok(create_error_response(
351 StatusCode::GATEWAY_TIMEOUT,
352 "Upstream Request Timed Out",
353 ));
354 }
355 };
356
357 let (mut res_parts, res_body) = res.into_parts();
359
360 apply_quic_downgrade(&mut res_parts, &mut flow, &policy);
362
363 update_flow_with_response_headers(
364 &mut flow,
365 res_parts.status,
366 res_parts.version,
367 &res_parts.headers,
368 );
369
370 let ttfbs_ms = upstream_start.elapsed().as_millis() as u64;
371 if let Layer::Http(http) = &mut flow.layer
372 && let Some(response) = &mut http.response
373 {
374 response.timing.time_to_first_byte = Some(ttfbs_ms);
375 }
376
377 match interceptor.on_response_headers(&mut flow).await {
378 InterceptionResult::Continue => {}
379 InterceptionResult::Drop => {
380 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
381 tracing::error!("Failed to send flow update on response drop: {}", e);
382 }
383 return Ok(create_error_response(
384 StatusCode::FORBIDDEN,
385 "Response dropped by policy",
386 ));
387 }
388 InterceptionResult::MockResponse(resp) => {
389 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
390 tracing::error!("Failed to send flow update on response mock: {}", e);
391 }
392 return Ok(mock_to_response(resp));
393 }
394 InterceptionResult::ModifiedResponse(resp) => {
395 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
396 tracing::error!("Failed to send flow update on response modification: {}", e);
397 }
398 return Ok(mock_to_response(resp));
399 }
400 _ => {}
401 }
402
403 let res_body: HttpBody = res_body
405 .map_frame(|f| f.map_data(|d| d))
406 .map_err(|e| e.into())
407 .boxed();
408
409 let res_headers = if let Layer::Http(http) = &flow.layer {
411 http.response
412 .as_ref()
413 .map(|r| r.headers.clone())
414 .unwrap_or_default()
415 } else {
416 vec![]
417 };
418
419 let tap_res_body = TapBody::new(
420 res_body,
421 flow.id.to_string(),
422 on_flow.clone(),
423 Direction::ServerToClient,
424 policy.max_body_size,
425 res_headers,
426 );
427 let mut current_res_body = tap_res_body.boxed();
428
429 match interceptor.on_response(&mut flow, current_res_body).await {
430 Ok(ResponseAction::Continue(new_body)) => {
431 current_res_body = new_body;
432 }
433 Ok(ResponseAction::Drop) => {
434 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
435 tracing::error!("Failed to send flow update on response body drop: {}", e);
436 }
437 return Ok(create_error_response(
438 StatusCode::FORBIDDEN,
439 "Response dropped by interceptor",
440 ));
441 }
442 Ok(ResponseAction::ModifiedResponse(res)) => {
443 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
444 tracing::error!(
445 "Failed to send flow update on response body modification: {}",
446 e
447 );
448 }
449 let (parts, body) = res.into_parts();
450 return Ok(Response::from_parts(parts, body));
451 }
452 Err(e) => {
453 tracing::error!("Interceptor error on_response: {}", e);
454 return Ok(create_error_response(
455 StatusCode::INTERNAL_SERVER_ERROR,
456 format!("Interceptor Error: {}", e),
457 ));
458 }
459 }
460
461 if let Some(bps_str) = flow.meta.get("throttle_bytes_per_sec")
463 && let Ok(bps) = bps_str.parse::<u64>()
464 && bps > 0
465 {
466 current_res_body = crate::proxy::throttle::ThrottleBody::new(current_res_body, bps).boxed();
467 }
468
469 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
470 tracing::error!("Failed to send final flow update: {}", e);
471 }
472
473 if let Layer::Http(http) = &mut flow.layer
475 && let Some(response) = &mut http.response
476 {
477 response.timing.time_to_last_byte = Some(upstream_start.elapsed().as_millis() as u64);
478 }
479
480 Ok(Response::from_parts(res_parts, current_res_body))
481}
482
483pub(crate) fn apply_quic_downgrade(
484 parts: &mut hyper::http::response::Parts,
485 flow: &mut relay_core_api::flow::Flow,
486 policy: &ProxyPolicy,
487) {
488 use relay_core_api::policy::QuicMode;
489 if policy.quic_mode == QuicMode::Downgrade {
490 if parts.headers.remove("Alt-Svc").is_some() {
491 flow.tags.push("quic-downgraded".to_string());
492 }
493 if policy.quic_downgrade_clear_cache {
494 parts.headers.insert(
495 "Clear-Site-Data",
496 hyper::header::HeaderValue::from_static("\"cache\""),
497 );
498 }
499 }
500}
501
// Unit tests for this module, declared as a separate file and compiled only
// for test builds.
#[cfg(test)]
mod http_tests;