1use std::convert::Infallible;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use tokio::sync::{mpsc::Sender, watch};
5
6use hyper::{Request, Response, Method, StatusCode};
7use hyper::body::{Bytes, Incoming, Body};
8use http_body_util::{BodyExt, Full};
9use relay_core_api::flow::{FlowUpdate, Layer, Direction};
10use relay_core_api::policy::ProxyPolicy;
11use crate::interceptor::{Interceptor, InterceptionResult, RequestAction, ResponseAction, HttpBody, BoxError};
12use crate::tls::CertificateAuthority;
13use crate::proxy::http_utils::{
14 create_initial_flow,
15 mock_to_response, parse_request_meta, create_error_response, HttpsClient,
16 build_forward_request, update_flow_with_response_headers,
17};
18use crate::proxy::tunnel;
19use crate::proxy::websocket::handle_websocket_handshake;
20use crate::capture::loop_detection::LoopDetector;
21use crate::proxy::tap::TapBody;
22
23#[allow(clippy::too_many_arguments)]
25pub async fn handle_request(
26 req: Request<Incoming>,
27 client_addr: SocketAddr,
28 on_flow: Sender<FlowUpdate>,
29 ca: Arc<CertificateAuthority>,
30 client: Arc<HttpsClient>,
31 interceptor: Arc<dyn Interceptor>,
32 target_addr: Option<SocketAddr>,
33 policy_rx: watch::Receiver<ProxyPolicy>,
34 loop_detector: Arc<LoopDetector>,
35) -> Result<Response<HttpBody>, Infallible>
36{
37 if req.method() == Method::CONNECT {
38 let host = if let Some(authority) = req.uri().authority() {
41 authority.to_string()
42 } else {
43 req.headers().get("Host")
45 .and_then(|v| v.to_str().ok())
46 .map(|s| s.to_string())
47 .unwrap_or_else(|| "unknown".to_string())
48 };
49
50 if host == "unknown" {
51 return Ok(create_error_response(StatusCode::BAD_REQUEST, "CONNECT must have authority"));
52 }
53
54 let loop_detector = loop_detector.clone();
55 let policy_rx = policy_rx.clone();
56
57 tokio::task::spawn(async move {
58 match hyper::upgrade::on(req).await {
59 Ok(upgraded) => {
60 if let Err(e) = tunnel::handle_tunnel(
61 upgraded,
62 host,
63 client_addr,
64 ca,
65 on_flow,
66 client,
67 interceptor,
68 policy_rx,
69 target_addr,
70 loop_detector,
71 ).await {
72 tracing::error!("Tunnel error: {}", e);
73 }
74 },
75 Err(e) => tracing::error!("Upgrade error: {}", e),
76 }
77 });
78 return Ok(Response::new(Full::new(Bytes::new()).map_err(|e| e.into()).boxed()));
79 }
80
81 handle_http_request(req, client_addr, on_flow, client, interceptor, false, policy_rx, target_addr, loop_detector).await
83}
84
85#[allow(clippy::too_many_arguments)]
86pub(crate) async fn handle_http_request<B>(
87 req: Request<B>,
88 client_addr: SocketAddr,
89 on_flow: Sender<FlowUpdate>,
90 client: Arc<HttpsClient>,
91 interceptor: Arc<dyn Interceptor>,
92 is_mitm: bool,
93 policy_rx: watch::Receiver<ProxyPolicy>,
94 target_addr: Option<SocketAddr>,
95 loop_detector: Arc<LoopDetector>,
96) -> Result<Response<HttpBody>, Infallible>
97where
98 B: Body + Send + Sync + Unpin + 'static,
99 B::Data: Send + Into<Bytes>,
100 B::Error: Into<BoxError>,
101{
102 let policy = policy_rx.borrow().clone();
103
104 if let Some(cl) = req.headers().get(hyper::header::CONTENT_LENGTH)
106 && let Ok(len) = cl.to_str().unwrap_or_default().parse::<usize>()
107 && len > policy.max_body_size {
108 return Ok(create_error_response(StatusCode::PAYLOAD_TOO_LARGE, "Request body too large"));
109 }
110
111 let meta = parse_request_meta(&req, is_mitm);
113
114 let mut flow = create_initial_flow(meta, None, client_addr, is_mitm, false);
116
117 if hyper_tungstenite::is_upgrade_request(&req) {
119 return handle_websocket_handshake(req, client_addr, on_flow, client, interceptor, is_mitm, policy_rx, target_addr, loop_detector).await;
120 }
121
122 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
123 tracing::error!("Failed to send flow update: {}", e);
124 }
125
126 match interceptor.on_request_headers(&mut flow).await {
128 InterceptionResult::Continue => {},
129 InterceptionResult::Drop => {
130 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
131 tracing::error!("Failed to send flow update on drop: {}", e);
132 }
133 return Ok(create_error_response(StatusCode::FORBIDDEN, "Request dropped by policy"));
134 },
135 InterceptionResult::MockResponse(resp) => {
136 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
137 tracing::error!("Failed to send flow update on mock: {}", e);
138 }
139 return Ok(mock_to_response(resp));
140 },
141 InterceptionResult::ModifiedRequest(_) => {},
142 InterceptionResult::ModifiedResponse(res) => {
143 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
144 tracing::error!("Failed to send flow update on modified response: {}", e);
145 }
146 return Ok(mock_to_response(res));
147 },
148 _ => {}
149 }
150
151 let (_, body) = req.into_parts();
153 let body: HttpBody = body.map_frame(|f| f.map_data(|d| d.into())).map_err(|e| e.into()).boxed();
154
155 let req_headers = if let Layer::Http(http) = &flow.layer {
157 http.request.headers.clone()
158 } else {
159 vec![]
160 };
161
162 let tap_body = TapBody::new(
163 body,
164 flow.id.to_string(),
165 on_flow.clone(),
166 Direction::ClientToServer,
167 policy.max_body_size,
168 req_headers,
169 );
170 let mut current_body = tap_body.boxed();
171
172 match interceptor.on_request(&mut flow, current_body).await {
173 Ok(RequestAction::Continue(new_body)) => {
174 current_body = new_body;
175 },
176 Ok(RequestAction::Drop) => {
177 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
178 tracing::error!("Failed to send flow update on request drop: {}", e);
179 }
180 return Ok(create_error_response(StatusCode::FORBIDDEN, "Request dropped by interceptor"));
181 },
182 Ok(RequestAction::MockResponse(res)) => {
183 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
184 tracing::error!("Failed to send flow update on request mock: {}", e);
185 }
186 let (parts, body) = res.into_parts();
187 return Ok(Response::from_parts(parts, body));
188 },
189 Err(e) => {
190 tracing::error!("Interceptor error on_request: {}", e);
191 return Ok(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, format!("Interceptor Error: {}", e)));
192 }
193 }
194
195 let forward_req = match build_forward_request(&mut flow, current_body, target_addr, &policy, &loop_detector) {
196 Ok(req) => req,
197 Err(res) => return Ok(res),
198 };
199
200 let upstream_start = std::time::Instant::now();
202 let res = match tokio::time::timeout(std::time::Duration::from_millis(policy.request_timeout_ms), client.request(forward_req)).await {
203 Ok(Ok(res)) => res,
204 Ok(Err(e)) => {
205 tracing::error!("Upstream request failed: {}", e);
206 if let Layer::Http(http) = &mut flow.layer {
207 http.error = Some(format!("Upstream Error: {}", e));
208 }
209 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
210 tracing::error!("Failed to send flow update on upstream error: {}", e);
211 }
212 return Ok(create_error_response(StatusCode::BAD_GATEWAY, format!("Upstream Error: {}", e)));
213 },
214 Err(_) => {
215 tracing::error!("Upstream request timed out");
216 if let Layer::Http(http) = &mut flow.layer {
217 http.error = Some("Upstream Request Timed Out".to_string());
218 }
219 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
220 tracing::error!("Failed to send flow update on upstream timeout: {}", e);
221 }
222 return Ok(create_error_response(StatusCode::GATEWAY_TIMEOUT, "Upstream Request Timed Out"));
223 }
224 };
225
226 let (mut res_parts, res_body) = res.into_parts();
228
229 apply_quic_downgrade(&mut res_parts, &mut flow, &policy);
231
232 update_flow_with_response_headers(&mut flow, res_parts.status, res_parts.version, &res_parts.headers);
233
234 let ttfbs_ms = upstream_start.elapsed().as_millis() as u64;
235 if let Layer::Http(http) = &mut flow.layer
236 && let Some(response) = &mut http.response {
237 response.timing.time_to_first_byte = Some(ttfbs_ms);
238 }
239
240 match interceptor.on_response_headers(&mut flow).await {
241 InterceptionResult::Continue => {},
242 InterceptionResult::Drop => {
243 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
244 tracing::error!("Failed to send flow update on response drop: {}", e);
245 }
246 return Ok(create_error_response(StatusCode::FORBIDDEN, "Response dropped by policy"));
247 },
248 InterceptionResult::MockResponse(resp) => {
249 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
250 tracing::error!("Failed to send flow update on response mock: {}", e);
251 }
252 return Ok(mock_to_response(resp));
253 },
254 InterceptionResult::ModifiedResponse(resp) => {
255 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
256 tracing::error!("Failed to send flow update on response modification: {}", e);
257 }
258 return Ok(mock_to_response(resp));
259 },
260 _ => {}
261 }
262
263 let res_body: HttpBody = res_body.map_frame(|f| f.map_data(|d| d)).map_err(|e| e.into()).boxed();
265
266 let res_headers = if let Layer::Http(http) = &flow.layer {
268 http.response.as_ref().map(|r| r.headers.clone()).unwrap_or_default()
269 } else {
270 vec![]
271 };
272
273 let tap_res_body = TapBody::new(
274 res_body,
275 flow.id.to_string(),
276 on_flow.clone(),
277 Direction::ServerToClient,
278 policy.max_body_size,
279 res_headers,
280 );
281 let mut current_res_body = tap_res_body.boxed();
282
283 match interceptor.on_response(&mut flow, current_res_body).await {
284 Ok(ResponseAction::Continue(new_body)) => {
285 current_res_body = new_body;
286 },
287 Ok(ResponseAction::Drop) => {
288 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
289 tracing::error!("Failed to send flow update on response body drop: {}", e);
290 }
291 return Ok(create_error_response(StatusCode::FORBIDDEN, "Response dropped by interceptor"));
292 },
293 Ok(ResponseAction::ModifiedResponse(res)) => {
294 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
295 tracing::error!("Failed to send flow update on response body modification: {}", e);
296 }
297 let (parts, body) = res.into_parts();
298 return Ok(Response::from_parts(parts, body));
299 },
300 Err(e) => {
301 tracing::error!("Interceptor error on_response: {}", e);
302 return Ok(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, format!("Interceptor Error: {}", e)));
303 }
304 }
305
306 if let Err(e) = on_flow.send(FlowUpdate::Full(Box::new(flow.clone()))).await {
307 tracing::error!("Failed to send final flow update: {}", e);
308 }
309
310 if let Layer::Http(http) = &mut flow.layer
312 && let Some(response) = &mut http.response {
313 response.timing.time_to_last_byte = Some(upstream_start.elapsed().as_millis() as u64);
314 }
315
316 Ok(Response::from_parts(res_parts, current_res_body))
317}
318
319pub(crate) fn apply_quic_downgrade(parts: &mut hyper::http::response::Parts, flow: &mut relay_core_api::flow::Flow, policy: &ProxyPolicy) {
320 use relay_core_api::policy::QuicMode;
321 if policy.quic_mode == QuicMode::Downgrade {
322 if parts.headers.remove("Alt-Svc").is_some() {
323 flow.tags.push("quic-downgraded".to_string());
324 }
325 if policy.quic_downgrade_clear_cache {
326 parts.headers.insert("Clear-Site-Data", hyper::header::HeaderValue::from_static("\"cache\""));
327 }
328 }
329}
330
331#[cfg(test)]
332mod http_tests;