1use crate::capture::loop_detection::LoopDetector;
2use crate::interceptor::HttpBody;
3use crate::proxy::body_codec::process_body;
4use chrono::Utc;
5use cookie::Cookie as CookieCrate;
6use data_encoding::BASE64;
7use http_body_util::{BodyExt, Full};
8use hyper::body::Bytes;
9use hyper::header::{HeaderName, HeaderValue};
10use hyper::{Request, Response, StatusCode};
11use hyper_rustls::HttpsConnector;
12use hyper_util::client::legacy::Client;
13use hyper_util::client::legacy::connect::HttpConnector;
14use relay_core_api::flow::{
15 BodyData, Cookie, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
16 TransportProtocol, WebSocketLayer,
17};
18use relay_core_api::policy::ProxyPolicy;
19use std::net::SocketAddr;
20use url::Url;
21use uuid::Uuid;
22
/// Legacy hyper client whose connector is an [`HttpsConnector`] (hyper-rustls)
/// layered over a plain [`HttpConnector`], sending [`HttpBody`] request bodies.
pub type HttpsClient = Client<HttpsConnector<HttpConnector>, HttpBody>;
24
/// Request metadata extracted by [`parse_request_meta`] and consumed by
/// [`create_initial_flow`] to populate a flow's [`HttpRequest`].
#[derive(Clone, Debug)]
pub struct RequestMeta {
    /// HTTP method rendered as text (e.g. `GET`).
    pub method: String,
    /// Request URL. It is absolute whenever it could be rebuilt from the
    /// `Host` header; otherwise it is the raw request URI string.
    pub url_str: String,
    /// `Debug` rendering of the request's HTTP version.
    pub version: String,
    /// All request headers, with values decoded as lossy UTF-8.
    pub headers: Vec<(String, String)>,
    /// Decoded query-string pairs (empty when `url_str` does not parse as a URL).
    pub query: Vec<(String, String)>,
    /// Cookies parsed from the request's `Cookie` header field(s). Only
    /// `name` and `value` are populated; attribute fields are `None`.
    pub cookies: Vec<Cookie>,
}
34
35pub fn parse_request_meta<B>(req: &Request<B>, is_mitm: bool) -> RequestMeta {
36 let method = req.method().to_string();
37 let mut url_str = req.uri().to_string();
38
39 if Url::parse(&url_str).is_err()
41 && let Some(host) = req.headers().get("Host").and_then(|v| v.to_str().ok())
42 {
43 let scheme = if is_mitm { "https" } else { "http" };
44 let new_url = format!("{}://{}{}", scheme, host, url_str);
45 if Url::parse(&new_url).is_ok() {
46 url_str = new_url;
47 }
48 }
49
50 let version = format!("{:?}", req.version());
51
52 let headers: Vec<(String, String)> = req
53 .headers()
54 .iter()
55 .map(|(k, v)| {
56 (
57 k.to_string(),
58 String::from_utf8_lossy(v.as_bytes()).to_string(),
59 )
60 })
61 .collect();
62
63 let query: Vec<(String, String)> = if let Ok(parsed_url) = Url::parse(&url_str) {
64 parsed_url.query_pairs().into_owned().collect()
65 } else {
66 vec![]
67 };
68
69 let mut cookies = Vec::new();
70 if let Some(cookie_header) = req.headers().get(hyper::header::COOKIE)
71 && let Ok(cookie_str) = cookie_header.to_str()
72 {
73 for c in CookieCrate::split_parse(cookie_str).flatten() {
74 cookies.push(Cookie {
75 name: c.name().to_string(),
76 value: c.value().to_string(),
77 path: None,
78 domain: None,
79 expires: None,
80 http_only: None,
81 secure: None,
82 });
83 }
84 }
85
86 RequestMeta {
87 method,
88 url_str,
89 version,
90 headers,
91 query,
92 cookies,
93 }
94}
95
/// Returns `true` if `name` is a hop-by-hop header that a proxy must not
/// forward to the next hop.
///
/// The comparison is ASCII case-insensitive. Only the fixed, well-known set
/// is covered. Fields nominated per message by a `Connection` header
/// (RFC 9110 §7.6.1) must be handled by the caller.
pub fn is_hop_by_hop(name: &str) -> bool {
    const HOP_BY_HOP: [&str; 10] = [
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        // Non-standard, but still sent by clients that talk to a proxy.
        "proxy-connection",
        "te",
        // The field is named `Trailer` (RFC 7230 / RFC 9110). RFC 2616
        // §13.5.1 misspelled it as `Trailers`; both are matched so existing
        // behavior is preserved.
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    ];
    HOP_BY_HOP.iter().any(|h| name.eq_ignore_ascii_case(h))
}
106
107pub fn create_initial_flow(
108 meta: RequestMeta,
109 req_body: Option<BodyData>,
110 client_addr: SocketAddr,
111 is_mitm: bool,
112 is_websocket: bool,
113) -> Flow {
114 let flow_id = Uuid::new_v4();
115 let start_time = Utc::now();
116
117 let network_info = NetworkInfo {
118 client_ip: client_addr.ip().to_string(),
119 client_port: client_addr.port(),
120 server_ip: "0.0.0.0".to_string(), server_port: 0, protocol: TransportProtocol::TCP,
123 tls: is_mitm,
124 tls_version: None,
125 sni: None,
126 };
127
128 let http_request = HttpRequest {
129 method: meta.method,
130 url: Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown").unwrap()),
131 version: meta.version,
132 headers: meta.headers,
133 cookies: meta.cookies,
134 query: meta.query,
135 body: req_body,
136 };
137
138 let mut flow = if is_websocket {
139 Flow {
140 id: flow_id,
141 start_time,
142 end_time: None,
143 network: network_info,
144 layer: Layer::WebSocket(WebSocketLayer {
145 handshake_request: http_request,
146 handshake_response: HttpResponse {
147 status: 0,
148 status_text: "".to_string(),
149 version: "".to_string(),
150 headers: vec![],
151 cookies: vec![],
152 body: None,
153 timing: relay_core_api::flow::ResponseTiming {
154 time_to_first_byte: None,
155 time_to_last_byte: None,
156 connect_time_ms: None,
157 ssl_time_ms: None,
158 },
159 },
160 messages: vec![],
161 closed: false,
162 }),
163 tags: vec!["websocket".to_string()],
164 meta: std::collections::HashMap::new(),
165 resilience_trace: None,
166 rule_variables: std::collections::HashMap::new(),
167 matched_rules: vec![],
168 }
169 } else {
170 Flow {
171 id: flow_id,
172 start_time,
173 end_time: None,
174 network: network_info,
175 layer: Layer::Http(HttpLayer {
176 request: http_request,
177 response: None,
178 error: None,
179 }),
180 tags: vec!["proxy".to_string()],
181 meta: std::collections::HashMap::new(),
182 resilience_trace: None,
183 rule_variables: std::collections::HashMap::new(),
184 matched_rules: vec![],
185 }
186 };
187
188 if is_mitm {
189 flow.tags.push("mitm".to_string());
190 }
191
192 flow
193}
194
195pub fn create_error_response(status: StatusCode, message: impl Into<Bytes>) -> Response<HttpBody> {
196 Response::builder()
197 .status(status)
198 .body(Full::new(message.into()).map_err(|e| e.into()).boxed())
199 .unwrap_or_else(|_| {
200 Response::new(
201 Full::new(Bytes::from("Internal Error"))
202 .map_err(|e| e.into())
203 .boxed(),
204 )
205 })
206}
207
208pub fn mock_to_response(mock: HttpResponse) -> Response<HttpBody> {
209 let mut builder =
210 Response::builder().status(StatusCode::from_u16(mock.status).unwrap_or(StatusCode::OK));
211
212 for (k, v) in mock.headers {
213 if let (Ok(name), Ok(val)) = (
214 HeaderName::from_bytes(k.as_bytes()),
215 HeaderValue::from_str(&v),
216 ) {
217 builder = builder.header(name, val);
218 }
219 }
220
221 let body = if let Some(b) = mock.body {
222 Bytes::from(b.content)
223 } else {
224 Bytes::new()
225 };
226
227 builder
228 .body(Full::new(body).map_err(|e| e.into()).boxed())
229 .unwrap_or_else(|_| {
230 create_error_response(
231 StatusCode::INTERNAL_SERVER_ERROR,
232 "Failed to build mock response",
233 )
234 })
235}
236
237#[allow(clippy::result_large_err)]
238pub fn build_forward_request(
239 flow: &mut Flow,
240 body: HttpBody,
241 target_addr: Option<SocketAddr>,
242 policy: &ProxyPolicy,
243 loop_detector: &LoopDetector,
244) -> Result<Request<HttpBody>, Response<HttpBody>> {
245 let current_req = if let Layer::Http(http) = &flow.layer {
246 &http.request
247 } else {
248 return Err(create_error_response(
249 StatusCode::INTERNAL_SERVER_ERROR,
250 "Invalid Flow Layer State",
251 ));
252 };
253
254 let mut forward_req_builder = Request::builder().method(current_req.method.as_str());
255
256 let mut target_url = current_req.url.clone();
258
259 if policy.transparent_enabled
261 && let Some(addr) = target_addr
262 {
263 flow.tags.push("transparent".to_string());
264
265 flow.network.server_ip = addr.ip().to_string();
267 flow.network.server_port = addr.port();
268
269 if loop_detector.would_loop(addr) {
271 if let Layer::Http(http) = &mut flow.layer {
272 http.error = Some("Loop Detected".to_string());
273 }
274 return Err(create_error_response(
275 StatusCode::LOOP_DETECTED,
276 "Loop Detected",
277 ));
278 }
279
280 if target_url.set_ip_host(addr.ip()).is_ok() {
282 target_url.set_port(Some(addr.port())).ok();
283 }
284
285 if flow.network.tls && target_url.scheme() == "http" {
287 target_url.set_scheme("https").ok();
288 }
289 }
290
291 forward_req_builder = forward_req_builder.uri(target_url.as_str());
292
293 for (k, v) in ¤t_req.headers {
294 if is_hop_by_hop(k) {
296 continue;
297 }
298
299 if let (Ok(name), Ok(val)) = (
300 HeaderName::from_bytes(k.as_bytes()),
301 HeaderValue::from_str(v),
302 ) {
303 forward_req_builder = forward_req_builder.header(name, val);
304 }
305 }
306
307 match forward_req_builder.body(body) {
308 Ok(req) => Ok(req),
309 Err(e) => Err(create_error_response(
310 StatusCode::INTERNAL_SERVER_ERROR,
311 format!("Failed to build forward request: {}", e),
312 )),
313 }
314}
315
316pub fn update_flow_with_response_headers(
317 flow: &mut Flow,
318 status: StatusCode,
319 version: hyper::Version,
320 headers: &hyper::HeaderMap,
321) {
322 let mut response_cookies = Vec::new();
323 for (k, v) in headers.iter() {
324 if k == hyper::header::SET_COOKIE
325 && let Ok(v_str) = v.to_str()
326 && let Ok(c) = CookieCrate::parse(v_str)
327 {
328 response_cookies.push(Cookie {
329 name: c.name().to_string(),
330 value: c.value().to_string(),
331 path: c.path().map(|s| s.to_string()),
332 domain: c.domain().map(|s| s.to_string()),
333 expires: c.expires().map(|e| format!("{:?}", e)),
334 http_only: c.http_only(),
335 secure: c.secure(),
336 });
337 }
338 }
339
340 let resp_headers_vec: Vec<(String, String)> = headers
341 .iter()
342 .map(|(k, v)| {
343 (
344 k.to_string(),
345 String::from_utf8_lossy(v.as_bytes()).to_string(),
346 )
347 })
348 .collect();
349
350 let http_response = HttpResponse {
351 status: status.as_u16(),
352 status_text: status.to_string(),
353 version: format!("{:?}", version),
354 headers: resp_headers_vec,
355 cookies: response_cookies,
356 body: None,
357 timing: relay_core_api::flow::ResponseTiming {
358 time_to_first_byte: None,
359 time_to_last_byte: None,
360 connect_time_ms: None,
361 ssl_time_ms: None,
362 },
363 };
364
365 match &mut flow.layer {
366 Layer::Http(http) => {
367 http.response = Some(http_response);
368 }
369 Layer::WebSocket(ws) => {
370 ws.handshake_response = http_response;
371 }
372 _ => {}
373 }
374}
375
376pub fn update_flow_with_response_body(flow: &mut Flow, body_bytes: Bytes) {
377 let headers = match &flow.layer {
378 Layer::Http(http) => http
379 .response
380 .as_ref()
381 .map(|r| r.headers.clone())
382 .unwrap_or_default(),
383 Layer::WebSocket(ws) => ws.handshake_response.headers.clone(),
384 _ => Vec::new(),
385 };
386
387 let (resp_encoding, resp_content) = process_body(&body_bytes, &headers);
388
389 let body_data = BodyData {
390 encoding: resp_encoding,
391 content: resp_content,
392 size: body_bytes.len() as u64,
393 };
394
395 match &mut flow.layer {
396 Layer::Http(http) => {
397 if let Some(resp) = &mut http.response {
398 resp.body = Some(body_data);
399 }
400 }
401 Layer::WebSocket(ws) => {
402 ws.handshake_response.body = Some(body_data);
403 }
404 _ => {}
405 }
406}
407
/// Records a complete upstream response (status line, headers and body) on `flow`.
///
/// Applies [`update_flow_with_response_headers`] first, so that
/// [`update_flow_with_response_body`] finds the freshly stored headers (and,
/// for HTTP flows, an existing response to attach the body to).
pub fn update_flow_with_response(
    flow: &mut Flow,
    status: StatusCode,
    version: hyper::Version,
    headers: &hyper::HeaderMap,
    body_bytes: Bytes,
) {
    update_flow_with_response_headers(flow, status, version, headers);
    update_flow_with_response_body(flow, body_bytes);
}
418
419pub fn build_client_response_from_flow(
420 flow: &Flow,
421 default_version: hyper::Version,
422 strict_mode: bool,
423) -> Result<Response<Full<Bytes>>, String> {
424 if let Layer::Http(http) = &flow.layer {
425 if let Some(response) = &http.response {
426 let status = match StatusCode::from_u16(response.status) {
427 Ok(s) => s,
428 Err(_) => {
429 if strict_mode {
430 crate::metrics::inc_proxy_invalid_status();
431 return Err(format!("Invalid status code: {}", response.status));
432 }
433 StatusCode::OK
434 }
435 };
436
437 let mut builder = Response::builder().status(status).version(default_version); for (k, v) in &response.headers {
440 if k.eq_ignore_ascii_case("content-length")
442 || k.eq_ignore_ascii_case("transfer-encoding")
443 || k.eq_ignore_ascii_case("connection")
444 {
445 continue;
446 }
447
448 if let (Ok(name), Ok(val)) = (
449 HeaderName::from_bytes(k.as_bytes()),
450 HeaderValue::from_str(v),
451 ) {
452 builder = builder.header(name, val);
453 } else if strict_mode {
454 return Err(format!("Invalid header: {}: {}", k, v));
455 }
456 }
457
458 let body_bytes = if let Some(b) = &response.body {
459 if b.encoding == "base64" {
460 match BASE64.decode(b.content.as_bytes()) {
461 Ok(bytes) => Bytes::from(bytes),
462 Err(_e) => {
463 Bytes::from(b.content.clone())
465 }
466 }
467 } else {
468 Bytes::from(b.content.clone())
469 }
470 } else {
471 Bytes::new()
472 };
473
474 builder
475 .body(Full::new(body_bytes))
476 .map_err(|e| format!("Failed to build response: {}", e))
477 } else {
478 Err("No response in flow".to_string())
479 }
480 } else {
481 Err("Not HTTP layer".to_string())
482 }
483}
484
// Unit tests for URL reconstruction in `parse_request_meta` and for rebuilding
// client responses in `build_client_response_from_flow`.
#[cfg(test)]
mod tests {
    use super::{build_client_response_from_flow, parse_request_meta};
    use chrono::Utc;
    use http_body_util::BodyExt;
    use hyper::{Request, StatusCode, Version};
    use relay_core_api::flow::{
        Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
        TransportProtocol,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    // Builds a minimal HTTP flow whose recorded response has the given status.
    // The response carries one ordinary header (`X-Test`) plus
    // `content-length` and `connection`, which the builder is expected to strip.
    fn sample_flow_with_response(status: u16) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com/a").expect("url"),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: Some(HttpResponse {
                    status,
                    status_text: "X".to_string(),
                    version: "HTTP/2.0".to_string(),
                    headers: vec![
                        ("X-Test".to_string(), "1".to_string()),
                        ("content-length".to_string(), "999".to_string()),
                        ("connection".to_string(), "keep-alive".to_string()),
                    ],
                    cookies: vec![],
                    body: None,
                    timing: ResponseTiming {
                        time_to_first_byte: None,
                        time_to_last_byte: None,
                        connect_time_ms: None,
                        ssl_time_ms: None,
                    },
                }),
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
            resilience_trace: None,
            rule_variables: HashMap::new(),
            matched_rules: vec![],
        }
    }

    // An origin-form URI is made absolute from the Host header (port kept), and
    // its query string is decoded into `query`.
    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_http() {
        let req = Request::builder()
            .uri("/api/v1?q=1")
            .header("Host", "example.com:8080")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, false);
        assert_eq!(meta.url_str, "http://example.com:8080/api/v1?q=1");
        assert_eq!(meta.query, vec![("q".to_string(), "1".to_string())]);
    }

    // In MITM mode the rebuilt URL uses the `https` scheme.
    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_https_in_mitm() {
        let req = Request::builder()
            .uri("/secure")
            .header("Host", "secure.example.com")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, true);
        assert_eq!(meta.url_str, "https://secure.example.com/secure");
    }

    // Pins current behavior: the caller-supplied version wins over the
    // recorded "HTTP/2.0". Framing/connection headers are stripped and
    // ordinary headers are kept.
    #[test]
    fn test_build_client_response_from_flow_uses_default_version_currently() {
        let flow = sample_flow_with_response(201);
        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect("response should build");
        assert_eq!(resp.version(), Version::HTTP_11);
        assert_eq!(resp.status(), StatusCode::CREATED);
        assert_eq!(
            resp.headers().get("x-test").and_then(|v| v.to_str().ok()),
            Some("1")
        );
        assert!(
            resp.headers().get("content-length").is_none(),
            "content-length should be stripped from forwarded mock response"
        );
        assert!(resp.headers().get("connection").is_none());
    }

    // 1000 is outside the range `StatusCode::from_u16` accepts; strict mode
    // must surface that as an error.
    #[test]
    fn test_build_client_response_from_flow_invalid_status_strict_fails() {
        let flow = sample_flow_with_response(1000);
        let err = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect_err("strict mode should reject invalid status");
        assert!(err.contains("Invalid status code"));
    }

    // Non-strict mode falls back to 200 OK and still sends the (non-base64)
    // body verbatim.
    #[tokio::test]
    async fn test_build_client_response_from_flow_invalid_status_non_strict_fallback_ok() {
        let mut flow = sample_flow_with_response(1000);
        if let Layer::Http(http) = &mut flow.layer {
            if let Some(res) = &mut http.response {
                res.body = Some(relay_core_api::flow::BodyData {
                    encoding: "utf-8".to_string(),
                    content: "hello".to_string(),
                    size: 5,
                });
            }
        }

        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, false)
            .expect("non-strict should fallback");
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp
            .into_body()
            .collect()
            .await
            .expect("collect body")
            .to_bytes();
        assert_eq!(body.as_ref(), b"hello");
    }
}
625}