1use crate::capture::loop_detection::LoopDetector;
2use crate::interceptor::HttpBody;
3use crate::proxy::body_codec::process_body;
4use chrono::Utc;
5use cookie::Cookie as CookieCrate;
6use data_encoding::BASE64;
7use http_body_util::{BodyExt, Full};
8use hyper::body::Bytes;
9use hyper::header::{HeaderName, HeaderValue};
10use hyper::{Request, Response, StatusCode};
11use hyper_rustls::HttpsConnector;
12use hyper_util::client::legacy::Client;
13use hyper_util::client::legacy::connect::HttpConnector;
14use relay_core_api::flow::{
15 BodyData, Cookie, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
16 TransportProtocol, WebSocketLayer,
17};
18use relay_core_api::policy::ProxyPolicy;
19use std::net::SocketAddr;
20use url::Url;
21use uuid::Uuid;
22
/// Outbound client type: a hyper legacy [`Client`] over an [`HttpsConnector`] wrapping a
/// plain [`HttpConnector`], sending request bodies of type [`HttpBody`].
pub type HttpsClient = Client<HttpsConnector<HttpConnector>, HttpBody>;
24
/// Request metadata extracted from an incoming request by [`parse_request_meta`] and
/// consumed by [`create_initial_flow`] to build the flow's [`HttpRequest`].
#[derive(Clone, Debug)]
pub struct RequestMeta {
    /// HTTP method as text (e.g. `"GET"`).
    pub method: String,
    /// Request URL. Relative request targets are made absolute from the `Host` header when
    /// that yields a parseable URL; otherwise this is the raw request URI string.
    pub url_str: String,
    /// `Debug` rendering of the hyper `Version` (e.g. `"HTTP/1.1"`).
    pub version: String,
    /// Request headers as `(name, value)` string pairs, in header-map iteration order.
    pub headers: Vec<(String, String)>,
    /// Decoded query-string pairs of `url_str` (empty when it does not parse as a URL).
    pub query: Vec<(String, String)>,
    /// Cookies parsed from the request's `Cookie` header(s); only name and value are set.
    pub cookies: Vec<Cookie>,
}
34
35pub fn parse_request_meta<B>(req: &Request<B>, is_mitm: bool) -> RequestMeta {
36 let method = req.method().to_string();
37 let mut url_str = req.uri().to_string();
38
39 if Url::parse(&url_str).is_err()
41 && let Some(host) = req.headers().get("Host").and_then(|v| v.to_str().ok())
42 {
43 let scheme = if is_mitm { "https" } else { "http" };
44 let new_url = format!("{}://{}{}", scheme, host, url_str);
45 if Url::parse(&new_url).is_ok() {
46 url_str = new_url;
47 }
48 }
49
50 let version = format!("{:?}", req.version());
51
52 let headers: Vec<(String, String)> = req
53 .headers()
54 .iter()
55 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
56 .collect();
57
58 let query: Vec<(String, String)> = if let Ok(parsed_url) = Url::parse(&url_str) {
59 parsed_url.query_pairs().into_owned().collect()
60 } else {
61 vec![]
62 };
63
64 let mut cookies = Vec::new();
65 if let Some(cookie_header) = req.headers().get(hyper::header::COOKIE)
66 && let Ok(cookie_str) = cookie_header.to_str()
67 {
68 for c in CookieCrate::split_parse(cookie_str).flatten() {
69 cookies.push(Cookie {
70 name: c.name().to_string(),
71 value: c.value().to_string(),
72 path: None,
73 domain: None,
74 expires: None,
75 http_only: None,
76 secure: None,
77 });
78 }
79 }
80
81 RequestMeta {
82 method,
83 url_str,
84 version,
85 headers,
86 query,
87 cookies,
88 }
89}
90
/// Returns `true` for header names this proxy must not copy onto a forwarded message.
///
/// The list covers the following:
/// * The hop-by-hop fields of RFC 2616 §13.5.1 / RFC 9110 §7.6.1: `Connection`,
///   `Keep-Alive`, `Proxy-Authenticate`, `Proxy-Authorization`, `TE`, `Trailer`,
///   `Transfer-Encoding` and `Upgrade`.
/// * The legacy, non-standard `Proxy-Connection`.
/// * `Content-Length`, which is not hop-by-hop per the RFCs. Presumably it is stripped because
///   the forwarded body may differ from the captured one; hyper derives message framing from
///   the body when the header is absent.
///
/// RFC 2616 misspelled `Trailer` as `Trailers` (erratum 4522). Both spellings are matched so
/// existing behavior is kept.
///
/// Matching is ASCII case-insensitive. Headers *nominated* by a `Connection` header value are
/// also hop-by-hop (RFC 9110 §7.6.1). Callers must strip those separately, because that
/// requires the message's header values.
pub fn is_hop_by_hop(name: &str) -> bool {
    const HOP_BY_HOP: [&str; 11] = [
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-length",
    ];
    HOP_BY_HOP
        .iter()
        .any(|candidate| name.eq_ignore_ascii_case(candidate))
}
102
103pub fn create_initial_flow(
104 meta: RequestMeta,
105 req_body: Option<BodyData>,
106 client_addr: SocketAddr,
107 is_mitm: bool,
108 is_websocket: bool,
109) -> Flow {
110 let flow_id = Uuid::new_v4();
111 let start_time = Utc::now();
112
113 let network_info = NetworkInfo {
114 client_ip: client_addr.ip().to_string(),
115 client_port: client_addr.port(),
116 server_ip: "0.0.0.0".to_string(), server_port: 0, protocol: TransportProtocol::TCP,
119 tls: is_mitm,
120 tls_version: None,
121 sni: None,
122 };
123
124 let http_request = HttpRequest {
125 method: meta.method,
126 url: Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown").unwrap()),
127 version: meta.version,
128 headers: meta.headers,
129 cookies: meta.cookies,
130 query: meta.query,
131 body: req_body,
132 };
133
134 let mut flow = if is_websocket {
135 Flow {
136 id: flow_id,
137 start_time,
138 end_time: None,
139 network: network_info,
140 layer: Layer::WebSocket(WebSocketLayer {
141 handshake_request: http_request,
142 handshake_response: HttpResponse {
143 status: 0,
144 status_text: "".to_string(),
145 version: "".to_string(),
146 headers: vec![],
147 cookies: vec![],
148 body: None,
149 timing: relay_core_api::flow::ResponseTiming {
150 time_to_first_byte: None,
151 time_to_last_byte: None,
152 connect_time_ms: None,
153 ssl_time_ms: None,
154 },
155 },
156 messages: vec![],
157 closed: false,
158 }),
159 tags: vec!["websocket".to_string()],
160 meta: std::collections::HashMap::new(),
161 }
162 } else {
163 Flow {
164 id: flow_id,
165 start_time,
166 end_time: None,
167 network: network_info,
168 layer: Layer::Http(HttpLayer {
169 request: http_request,
170 response: None,
171 error: None,
172 }),
173 tags: vec!["proxy".to_string()],
174 meta: std::collections::HashMap::new(),
175 }
176 };
177
178 if is_mitm {
179 flow.tags.push("mitm".to_string());
180 }
181
182 flow
183}
184
185pub fn create_error_response(status: StatusCode, message: impl Into<Bytes>) -> Response<HttpBody> {
186 Response::builder()
187 .status(status)
188 .body(Full::new(message.into()).map_err(|e| e.into()).boxed())
189 .unwrap_or_else(|_| {
190 Response::new(
191 Full::new(Bytes::from("Internal Error"))
192 .map_err(|e| e.into())
193 .boxed(),
194 )
195 })
196}
197
198pub fn mock_to_response(mock: HttpResponse) -> Response<HttpBody> {
199 let mut builder =
200 Response::builder().status(StatusCode::from_u16(mock.status).unwrap_or(StatusCode::OK));
201
202 for (k, v) in mock.headers {
203 if let (Ok(name), Ok(val)) = (
204 HeaderName::from_bytes(k.as_bytes()),
205 HeaderValue::from_str(&v),
206 ) {
207 builder = builder.header(name, val);
208 }
209 }
210
211 let body = if let Some(b) = mock.body {
212 Bytes::from(b.content)
213 } else {
214 Bytes::new()
215 };
216
217 builder
218 .body(Full::new(body).map_err(|e| e.into()).boxed())
219 .unwrap_or_else(|_| {
220 create_error_response(
221 StatusCode::INTERNAL_SERVER_ERROR,
222 "Failed to build mock response",
223 )
224 })
225}
226
227#[allow(clippy::result_large_err)]
228pub fn build_forward_request(
229 flow: &mut Flow,
230 body: HttpBody,
231 target_addr: Option<SocketAddr>,
232 policy: &ProxyPolicy,
233 loop_detector: &LoopDetector,
234) -> Result<Request<HttpBody>, Response<HttpBody>> {
235 let current_req = if let Layer::Http(http) = &flow.layer {
236 &http.request
237 } else {
238 return Err(create_error_response(
239 StatusCode::INTERNAL_SERVER_ERROR,
240 "Invalid Flow Layer State",
241 ));
242 };
243
244 let mut forward_req_builder = Request::builder().method(current_req.method.as_str());
245
246 let mut target_url = current_req.url.clone();
248
249 if policy.transparent_enabled
251 && let Some(addr) = target_addr
252 {
253 flow.tags.push("transparent".to_string());
254
255 flow.network.server_ip = addr.ip().to_string();
257 flow.network.server_port = addr.port();
258
259 if loop_detector.would_loop(addr) {
261 if let Layer::Http(http) = &mut flow.layer {
262 http.error = Some("Loop Detected".to_string());
263 }
264 return Err(create_error_response(
265 StatusCode::LOOP_DETECTED,
266 "Loop Detected",
267 ));
268 }
269
270 if target_url.set_ip_host(addr.ip()).is_ok() {
272 target_url.set_port(Some(addr.port())).ok();
273 }
274
275 if flow.network.tls && target_url.scheme() == "http" {
277 target_url.set_scheme("https").ok();
278 }
279 }
280
281 forward_req_builder = forward_req_builder.uri(target_url.as_str());
282
283 for (k, v) in ¤t_req.headers {
284 if is_hop_by_hop(k) {
286 continue;
287 }
288
289 if let (Ok(name), Ok(val)) = (
290 HeaderName::from_bytes(k.as_bytes()),
291 HeaderValue::from_str(v),
292 ) {
293 forward_req_builder = forward_req_builder.header(name, val);
294 }
295 }
296
297 match forward_req_builder.body(body) {
298 Ok(req) => Ok(req),
299 Err(e) => Err(create_error_response(
300 StatusCode::INTERNAL_SERVER_ERROR,
301 format!("Failed to build forward request: {}", e),
302 )),
303 }
304}
305
306pub fn update_flow_with_response_headers(
307 flow: &mut Flow,
308 status: StatusCode,
309 version: hyper::Version,
310 headers: &hyper::HeaderMap,
311) {
312 let mut response_cookies = Vec::new();
313 for (k, v) in headers.iter() {
314 if k == hyper::header::SET_COOKIE
315 && let Ok(v_str) = v.to_str()
316 && let Ok(c) = CookieCrate::parse(v_str)
317 {
318 response_cookies.push(Cookie {
319 name: c.name().to_string(),
320 value: c.value().to_string(),
321 path: c.path().map(|s| s.to_string()),
322 domain: c.domain().map(|s| s.to_string()),
323 expires: c.expires().map(|e| format!("{:?}", e)),
324 http_only: c.http_only(),
325 secure: c.secure(),
326 });
327 }
328 }
329
330 let resp_headers_vec: Vec<(String, String)> = headers
331 .iter()
332 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
333 .collect();
334
335 let http_response = HttpResponse {
336 status: status.as_u16(),
337 status_text: status.to_string(),
338 version: format!("{:?}", version),
339 headers: resp_headers_vec,
340 cookies: response_cookies,
341 body: None,
342 timing: relay_core_api::flow::ResponseTiming {
343 time_to_first_byte: None,
344 time_to_last_byte: None,
345 connect_time_ms: None,
346 ssl_time_ms: None,
347 },
348 };
349
350 match &mut flow.layer {
351 Layer::Http(http) => {
352 http.response = Some(http_response);
353 }
354 Layer::WebSocket(ws) => {
355 ws.handshake_response = http_response;
356 }
357 _ => {}
358 }
359}
360
361pub fn update_flow_with_response_body(flow: &mut Flow, body_bytes: Bytes) {
362 let headers = match &flow.layer {
363 Layer::Http(http) => http
364 .response
365 .as_ref()
366 .map(|r| r.headers.clone())
367 .unwrap_or_default(),
368 Layer::WebSocket(ws) => ws.handshake_response.headers.clone(),
369 _ => Vec::new(),
370 };
371
372 let (resp_encoding, resp_content) = process_body(&body_bytes, &headers);
373
374 let body_data = BodyData {
375 encoding: resp_encoding,
376 content: resp_content,
377 size: body_bytes.len() as u64,
378 };
379
380 match &mut flow.layer {
381 Layer::Http(http) => {
382 if let Some(resp) = &mut http.response {
383 resp.body = Some(body_data);
384 }
385 }
386 Layer::WebSocket(ws) => {
387 ws.handshake_response.body = Some(body_data);
388 }
389 _ => {}
390 }
391}
392
/// Records a complete upstream response on `flow` in one call.
///
/// It stores the status line and headers via [`update_flow_with_response_headers`], then the
/// body via [`update_flow_with_response_body`]. Headers are stored first because the body
/// step reads the just-recorded headers.
pub fn update_flow_with_response(
    flow: &mut Flow,
    status: StatusCode,
    version: hyper::Version,
    headers: &hyper::HeaderMap,
    body_bytes: Bytes,
) {
    update_flow_with_response_headers(flow, status, version, headers);
    update_flow_with_response_body(flow, body_bytes);
}
403
404pub fn build_client_response_from_flow(
405 flow: &Flow,
406 default_version: hyper::Version,
407 strict_mode: bool,
408) -> Result<Response<Full<Bytes>>, String> {
409 if let Layer::Http(http) = &flow.layer {
410 if let Some(response) = &http.response {
411 let status = match StatusCode::from_u16(response.status) {
412 Ok(s) => s,
413 Err(_) => {
414 if strict_mode {
415 return Err(format!("Invalid status code: {}", response.status));
416 }
417 StatusCode::OK
418 }
419 };
420
421 let mut builder = Response::builder().status(status).version(default_version); for (k, v) in &response.headers {
424 if k.eq_ignore_ascii_case("content-length")
426 || k.eq_ignore_ascii_case("transfer-encoding")
427 || k.eq_ignore_ascii_case("connection")
428 {
429 continue;
430 }
431
432 if let (Ok(name), Ok(val)) = (
433 HeaderName::from_bytes(k.as_bytes()),
434 HeaderValue::from_str(v),
435 ) {
436 builder = builder.header(name, val);
437 } else if strict_mode {
438 return Err(format!("Invalid header: {}: {}", k, v));
439 }
440 }
441
442 let body_bytes = if let Some(b) = &response.body {
443 if b.encoding == "base64" {
444 match BASE64.decode(b.content.as_bytes()) {
445 Ok(bytes) => Bytes::from(bytes),
446 Err(_e) => {
447 Bytes::from(b.content.clone())
449 }
450 }
451 } else {
452 Bytes::from(b.content.clone())
453 }
454 } else {
455 Bytes::new()
456 };
457
458 builder
459 .body(Full::new(body_bytes))
460 .map_err(|e| format!("Failed to build response: {}", e))
461 } else {
462 Err("No response in flow".to_string())
463 }
464 } else {
465 Err("Not HTTP layer".to_string())
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::{build_client_response_from_flow, parse_request_meta};
    use chrono::Utc;
    use http_body_util::BodyExt;
    use hyper::{Request, StatusCode, Version};
    use relay_core_api::flow::{
        Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
        TransportProtocol,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Builds a minimal plain-HTTP `GET http://example.com/a` flow with a recorded response.
    ///
    /// The response has the given `status`, version `"HTTP/2.0"` and no body. It has three
    /// headers: `X-Test: 1`, plus `content-length: 999` and `connection: keep-alive`. The
    /// latter two let tests check that framing headers are stripped.
    fn sample_flow_with_response(status: u16) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com/a").expect("url"),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: Some(HttpResponse {
                    status,
                    status_text: "X".to_string(),
                    version: "HTTP/2.0".to_string(),
                    headers: vec![
                        ("X-Test".to_string(), "1".to_string()),
                        ("content-length".to_string(), "999".to_string()),
                        ("connection".to_string(), "keep-alive".to_string()),
                    ],
                    cookies: vec![],
                    body: None,
                    timing: ResponseTiming {
                        time_to_first_byte: None,
                        time_to_last_byte: None,
                        connect_time_ms: None,
                        ssl_time_ms: None,
                    },
                }),
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    // A relative request target plus a Host header (with port) becomes an absolute `http`
    // URL outside MITM mode, and the query string is decoded into `query`.
    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_http() {
        let req = Request::builder()
            .uri("/api/v1?q=1")
            .header("Host", "example.com:8080")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, false);
        assert_eq!(meta.url_str, "http://example.com:8080/api/v1?q=1");
        assert_eq!(meta.query, vec![("q".to_string(), "1".to_string())]);
    }

    // In MITM mode the rebuilt URL uses the `https` scheme.
    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_https_in_mitm() {
        let req = Request::builder()
            .uri("/secure")
            .header("Host", "secure.example.com")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, true);
        assert_eq!(meta.url_str, "https://secure.example.com/secure");
    }

    // The client response uses the caller-supplied version, not the recorded "HTTP/2.0".
    // It keeps ordinary headers and strips the recorded content-length and connection.
    #[test]
    fn test_build_client_response_from_flow_uses_default_version_currently() {
        let flow = sample_flow_with_response(201);
        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect("response should build");
        assert_eq!(resp.version(), Version::HTTP_11);
        assert_eq!(resp.status(), StatusCode::CREATED);
        assert_eq!(
            resp.headers().get("x-test").and_then(|v| v.to_str().ok()),
            Some("1")
        );
        assert!(
            resp.headers().get("content-length").is_none(),
            "content-length should be stripped from forwarded mock response"
        );
        assert!(resp.headers().get("connection").is_none());
    }

    // Strict mode rejects an out-of-range status code (1000) with an error.
    #[test]
    fn test_build_client_response_from_flow_invalid_status_strict_fails() {
        let flow = sample_flow_with_response(1000);
        let err = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect_err("strict mode should reject invalid status");
        assert!(err.contains("Invalid status code"));
    }

    // Non-strict mode falls back to 200 OK for an invalid status and still sends the
    // recorded utf-8 body unchanged.
    #[tokio::test]
    async fn test_build_client_response_from_flow_invalid_status_non_strict_fallback_ok() {
        let mut flow = sample_flow_with_response(1000);
        if let Layer::Http(http) = &mut flow.layer {
            if let Some(res) = &mut http.response {
                res.body = Some(relay_core_api::flow::BodyData {
                    encoding: "utf-8".to_string(),
                    content: "hello".to_string(),
                    size: 5,
                });
            }
        }

        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, false)
            .expect("non-strict should fallback");
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp
            .into_body()
            .collect()
            .await
            .expect("collect body")
            .to_bytes();
        assert_eq!(body.as_ref(), b"hello");
    }
}