1use crate::capture::loop_detection::LoopDetector;
2use crate::interceptor::HttpBody;
3use crate::proxy::body_codec::process_body;
4use chrono::Utc;
5use cookie::Cookie as CookieCrate;
6use data_encoding::BASE64;
7use http_body_util::{BodyExt, Full};
8use hyper::body::Bytes;
9use hyper::header::{HeaderName, HeaderValue};
10use hyper::{Request, Response, StatusCode};
11use hyper_rustls::HttpsConnector;
12use hyper_util::client::legacy::Client;
13use hyper_util::client::legacy::connect::HttpConnector;
14use relay_core_api::flow::{
15 BodyData, Cookie, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
16 TransportProtocol, WebSocketLayer,
17};
18use relay_core_api::policy::ProxyPolicy;
19use std::net::SocketAddr;
20use url::Url;
21use uuid::Uuid;
22
/// Client used for upstream requests: hyper-util's legacy `Client` over a
/// hyper-rustls `HttpsConnector` wrapping a plain `HttpConnector`, with
/// [`HttpBody`] as the request body type.
pub type HttpsClient = Client<HttpsConnector<HttpConnector>, HttpBody>;
24
/// Request metadata extracted by [`parse_request_meta`] and consumed by
/// [`create_initial_flow`] to build the flow's recorded request.
#[derive(Clone, Debug)]
pub struct RequestMeta {
    /// HTTP method as text (e.g. `GET`).
    pub method: String,
    /// Absolute URL when one could be built, otherwise the raw request target.
    pub url_str: String,
    /// `Debug` form of the HTTP version (e.g. `HTTP/1.1`).
    pub version: String,
    /// Request headers as `(name, value)` pairs, in header-map iteration order.
    pub headers: Vec<(String, String)>,
    /// Decoded query-string pairs; empty when `url_str` is not a valid URL.
    pub query: Vec<(String, String)>,
    /// Cookies parsed from the request's `Cookie` header(s); only name and value are set.
    pub cookies: Vec<Cookie>,
}
34
35pub fn parse_request_meta<B>(req: &Request<B>, is_mitm: bool) -> RequestMeta {
36 let method = req.method().to_string();
37 let mut url_str = req.uri().to_string();
38
39 if Url::parse(&url_str).is_err()
41 && let Some(host) = req.headers().get("Host").and_then(|v| v.to_str().ok())
42 {
43 let scheme = if is_mitm { "https" } else { "http" };
44 let new_url = format!("{}://{}{}", scheme, host, url_str);
45 if Url::parse(&new_url).is_ok() {
46 url_str = new_url;
47 }
48 }
49
50 let version = format!("{:?}", req.version());
51
52 let headers: Vec<(String, String)> = req
53 .headers()
54 .iter()
55 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
56 .collect();
57
58 let query: Vec<(String, String)> = if let Ok(parsed_url) = Url::parse(&url_str) {
59 parsed_url.query_pairs().into_owned().collect()
60 } else {
61 vec![]
62 };
63
64 let mut cookies = Vec::new();
65 if let Some(cookie_header) = req.headers().get(hyper::header::COOKIE)
66 && let Ok(cookie_str) = cookie_header.to_str()
67 {
68 for c in CookieCrate::split_parse(cookie_str).flatten() {
69 cookies.push(Cookie {
70 name: c.name().to_string(),
71 value: c.value().to_string(),
72 path: None,
73 domain: None,
74 expires: None,
75 http_only: None,
76 secure: None,
77 });
78 }
79 }
80
81 RequestMeta {
82 method,
83 url_str,
84 version,
85 headers,
86 query,
87 cookies,
88 }
89}
90
/// Returns `true` if `name` (compared ASCII case-insensitively) is a hop-by-hop
/// header that a proxy must not forward to the next hop.
///
/// The list follows RFC 2616 §13.5.1 and RFC 9110 §7.6.1, with two additions:
/// * `trailer` is the real header name. RFC 2616 misspelled it as `Trailers`,
///   so both spellings are matched.
/// * `proxy-connection` is a non-standard, connection-specific header that
///   RFC 9110 §7.6.1 lists among fields intermediaries should remove.
///
/// Headers nominated by a `Connection` header value are not covered here,
/// because detecting them needs the full header list.
pub fn is_hop_by_hop(name: &str) -> bool {
    const HOP_BY_HOP: [&str; 10] = [
        "connection",
        "proxy-connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    ];
    HOP_BY_HOP.iter().any(|h| name.eq_ignore_ascii_case(h))
}
101
102pub fn create_initial_flow(
103 meta: RequestMeta,
104 req_body: Option<BodyData>,
105 client_addr: SocketAddr,
106 is_mitm: bool,
107 is_websocket: bool,
108) -> Flow {
109 let flow_id = Uuid::new_v4();
110 let start_time = Utc::now();
111
112 let network_info = NetworkInfo {
113 client_ip: client_addr.ip().to_string(),
114 client_port: client_addr.port(),
115 server_ip: "0.0.0.0".to_string(), server_port: 0, protocol: TransportProtocol::TCP,
118 tls: is_mitm,
119 tls_version: None,
120 sni: None,
121 };
122
123 let http_request = HttpRequest {
124 method: meta.method,
125 url: Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown").unwrap()),
126 version: meta.version,
127 headers: meta.headers,
128 cookies: meta.cookies,
129 query: meta.query,
130 body: req_body,
131 };
132
133 let mut flow = if is_websocket {
134 Flow {
135 id: flow_id,
136 start_time,
137 end_time: None,
138 network: network_info,
139 layer: Layer::WebSocket(WebSocketLayer {
140 handshake_request: http_request,
141 handshake_response: HttpResponse {
142 status: 0,
143 status_text: "".to_string(),
144 version: "".to_string(),
145 headers: vec![],
146 cookies: vec![],
147 body: None,
148 timing: relay_core_api::flow::ResponseTiming {
149 time_to_first_byte: None,
150 time_to_last_byte: None,
151 connect_time_ms: None,
152 ssl_time_ms: None,
153 },
154 },
155 messages: vec![],
156 closed: false,
157 }),
158 tags: vec!["websocket".to_string()],
159 meta: std::collections::HashMap::new(),
160 }
161 } else {
162 Flow {
163 id: flow_id,
164 start_time,
165 end_time: None,
166 network: network_info,
167 layer: Layer::Http(HttpLayer {
168 request: http_request,
169 response: None,
170 error: None,
171 }),
172 tags: vec!["proxy".to_string()],
173 meta: std::collections::HashMap::new(),
174 }
175 };
176
177 if is_mitm {
178 flow.tags.push("mitm".to_string());
179 }
180
181 flow
182}
183
184pub fn create_error_response(status: StatusCode, message: impl Into<Bytes>) -> Response<HttpBody> {
185 Response::builder()
186 .status(status)
187 .body(Full::new(message.into()).map_err(|e| e.into()).boxed())
188 .unwrap_or_else(|_| {
189 Response::new(
190 Full::new(Bytes::from("Internal Error"))
191 .map_err(|e| e.into())
192 .boxed(),
193 )
194 })
195}
196
197pub fn mock_to_response(mock: HttpResponse) -> Response<HttpBody> {
198 let mut builder =
199 Response::builder().status(StatusCode::from_u16(mock.status).unwrap_or(StatusCode::OK));
200
201 for (k, v) in mock.headers {
202 if let (Ok(name), Ok(val)) = (
203 HeaderName::from_bytes(k.as_bytes()),
204 HeaderValue::from_str(&v),
205 ) {
206 builder = builder.header(name, val);
207 }
208 }
209
210 let body = if let Some(b) = mock.body {
211 Bytes::from(b.content)
212 } else {
213 Bytes::new()
214 };
215
216 builder
217 .body(Full::new(body).map_err(|e| e.into()).boxed())
218 .unwrap_or_else(|_| {
219 create_error_response(
220 StatusCode::INTERNAL_SERVER_ERROR,
221 "Failed to build mock response",
222 )
223 })
224}
225
226#[allow(clippy::result_large_err)]
227pub fn build_forward_request(
228 flow: &mut Flow,
229 body: HttpBody,
230 target_addr: Option<SocketAddr>,
231 policy: &ProxyPolicy,
232 loop_detector: &LoopDetector,
233) -> Result<Request<HttpBody>, Response<HttpBody>> {
234 let current_req = if let Layer::Http(http) = &flow.layer {
235 &http.request
236 } else {
237 return Err(create_error_response(
238 StatusCode::INTERNAL_SERVER_ERROR,
239 "Invalid Flow Layer State",
240 ));
241 };
242
243 let mut forward_req_builder = Request::builder().method(current_req.method.as_str());
244
245 let mut target_url = current_req.url.clone();
247
248 if policy.transparent_enabled
250 && let Some(addr) = target_addr
251 {
252 flow.tags.push("transparent".to_string());
253
254 flow.network.server_ip = addr.ip().to_string();
256 flow.network.server_port = addr.port();
257
258 if loop_detector.would_loop(addr) {
260 if let Layer::Http(http) = &mut flow.layer {
261 http.error = Some("Loop Detected".to_string());
262 }
263 return Err(create_error_response(
264 StatusCode::LOOP_DETECTED,
265 "Loop Detected",
266 ));
267 }
268
269 if target_url.set_ip_host(addr.ip()).is_ok() {
271 target_url.set_port(Some(addr.port())).ok();
272 }
273
274 if flow.network.tls && target_url.scheme() == "http" {
276 target_url.set_scheme("https").ok();
277 }
278 }
279
280 forward_req_builder = forward_req_builder.uri(target_url.as_str());
281
282 for (k, v) in ¤t_req.headers {
283 if is_hop_by_hop(k) {
285 continue;
286 }
287
288 if let (Ok(name), Ok(val)) = (
289 HeaderName::from_bytes(k.as_bytes()),
290 HeaderValue::from_str(v),
291 ) {
292 forward_req_builder = forward_req_builder.header(name, val);
293 }
294 }
295
296 match forward_req_builder.body(body) {
297 Ok(req) => Ok(req),
298 Err(e) => Err(create_error_response(
299 StatusCode::INTERNAL_SERVER_ERROR,
300 format!("Failed to build forward request: {}", e),
301 )),
302 }
303}
304
305pub fn update_flow_with_response_headers(
306 flow: &mut Flow,
307 status: StatusCode,
308 version: hyper::Version,
309 headers: &hyper::HeaderMap,
310) {
311 let mut response_cookies = Vec::new();
312 for (k, v) in headers.iter() {
313 if k == hyper::header::SET_COOKIE
314 && let Ok(v_str) = v.to_str()
315 && let Ok(c) = CookieCrate::parse(v_str)
316 {
317 response_cookies.push(Cookie {
318 name: c.name().to_string(),
319 value: c.value().to_string(),
320 path: c.path().map(|s| s.to_string()),
321 domain: c.domain().map(|s| s.to_string()),
322 expires: c.expires().map(|e| format!("{:?}", e)),
323 http_only: c.http_only(),
324 secure: c.secure(),
325 });
326 }
327 }
328
329 let resp_headers_vec: Vec<(String, String)> = headers
330 .iter()
331 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
332 .collect();
333
334 let http_response = HttpResponse {
335 status: status.as_u16(),
336 status_text: status.to_string(),
337 version: format!("{:?}", version),
338 headers: resp_headers_vec,
339 cookies: response_cookies,
340 body: None,
341 timing: relay_core_api::flow::ResponseTiming {
342 time_to_first_byte: None,
343 time_to_last_byte: None,
344 connect_time_ms: None,
345 ssl_time_ms: None,
346 },
347 };
348
349 match &mut flow.layer {
350 Layer::Http(http) => {
351 http.response = Some(http_response);
352 }
353 Layer::WebSocket(ws) => {
354 ws.handshake_response = http_response;
355 }
356 _ => {}
357 }
358}
359
360pub fn update_flow_with_response_body(flow: &mut Flow, body_bytes: Bytes) {
361 let headers = match &flow.layer {
362 Layer::Http(http) => http
363 .response
364 .as_ref()
365 .map(|r| r.headers.clone())
366 .unwrap_or_default(),
367 Layer::WebSocket(ws) => ws.handshake_response.headers.clone(),
368 _ => Vec::new(),
369 };
370
371 let (resp_encoding, resp_content) = process_body(&body_bytes, &headers);
372
373 let body_data = BodyData {
374 encoding: resp_encoding,
375 content: resp_content,
376 size: body_bytes.len() as u64,
377 };
378
379 match &mut flow.layer {
380 Layer::Http(http) => {
381 if let Some(resp) = &mut http.response {
382 resp.body = Some(body_data);
383 }
384 }
385 Layer::WebSocket(ws) => {
386 ws.handshake_response.body = Some(body_data);
387 }
388 _ => {}
389 }
390}
391
/// Records a complete upstream response on `flow`: first the status line and
/// headers (via [`update_flow_with_response_headers`]), then the body (via
/// [`update_flow_with_response_body`]).
///
/// The order matters. The body step reads the recorded headers back from the
/// flow and passes them to `process_body`. For `Layer::Http` it only stores
/// the body if a response has already been recorded.
pub fn update_flow_with_response(
    flow: &mut Flow,
    status: StatusCode,
    version: hyper::Version,
    headers: &hyper::HeaderMap,
    body_bytes: Bytes,
) {
    update_flow_with_response_headers(flow, status, version, headers);
    update_flow_with_response_body(flow, body_bytes);
}
402
403pub fn build_client_response_from_flow(
404 flow: &Flow,
405 default_version: hyper::Version,
406 strict_mode: bool,
407) -> Result<Response<Full<Bytes>>, String> {
408 if let Layer::Http(http) = &flow.layer {
409 if let Some(response) = &http.response {
410 let status = match StatusCode::from_u16(response.status) {
411 Ok(s) => s,
412 Err(_) => {
413 if strict_mode {
414 return Err(format!("Invalid status code: {}", response.status));
415 }
416 StatusCode::OK
417 }
418 };
419
420 let mut builder = Response::builder().status(status).version(default_version); for (k, v) in &response.headers {
423 if k.eq_ignore_ascii_case("content-length")
425 || k.eq_ignore_ascii_case("transfer-encoding")
426 || k.eq_ignore_ascii_case("connection")
427 {
428 continue;
429 }
430
431 if let (Ok(name), Ok(val)) = (
432 HeaderName::from_bytes(k.as_bytes()),
433 HeaderValue::from_str(v),
434 ) {
435 builder = builder.header(name, val);
436 } else if strict_mode {
437 return Err(format!("Invalid header: {}: {}", k, v));
438 }
439 }
440
441 let body_bytes = if let Some(b) = &response.body {
442 if b.encoding == "base64" {
443 match BASE64.decode(b.content.as_bytes()) {
444 Ok(bytes) => Bytes::from(bytes),
445 Err(_e) => {
446 Bytes::from(b.content.clone())
448 }
449 }
450 } else {
451 Bytes::from(b.content.clone())
452 }
453 } else {
454 Bytes::new()
455 };
456
457 builder
458 .body(Full::new(body_bytes))
459 .map_err(|e| format!("Failed to build response: {}", e))
460 } else {
461 Err("No response in flow".to_string())
462 }
463 } else {
464 Err("Not HTTP layer".to_string())
465 }
466}
467
// Unit tests for `parse_request_meta` and `build_client_response_from_flow`.
#[cfg(test)]
mod tests {
    use super::{build_client_response_from_flow, parse_request_meta};
    use chrono::Utc;
    use http_body_util::BodyExt;
    use hyper::{Request, StatusCode, Version};
    use relay_core_api::flow::{
        Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
        TransportProtocol,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Builds a plain-HTTP (`tls: false`) GET flow whose recorded response has
    /// the given `status` and no body. The response has three headers:
    /// `X-Test: 1`, which should be passed through, plus `content-length` and
    /// `connection`, which `build_client_response_from_flow` should strip.
    /// The recorded version is `HTTP/2.0` so tests can check it is not the one used.
    fn sample_flow_with_response(status: u16) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com/a").expect("url"),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: Some(HttpResponse {
                    status,
                    status_text: "X".to_string(),
                    version: "HTTP/2.0".to_string(),
                    headers: vec![
                        ("X-Test".to_string(), "1".to_string()),
                        ("content-length".to_string(), "999".to_string()),
                        ("connection".to_string(), "keep-alive".to_string()),
                    ],
                    cookies: vec![],
                    body: None,
                    timing: ResponseTiming {
                        time_to_first_byte: None,
                        time_to_last_byte: None,
                        connect_time_ms: None,
                        ssl_time_ms: None,
                    },
                }),
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    // Origin-form URI + Host header on a non-MITM connection yields an absolute
    // `http://` URL (port kept), and the query string is parsed.
    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_http() {
        let req = Request::builder()
            .uri("/api/v1?q=1")
            .header("Host", "example.com:8080")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, false);
        assert_eq!(meta.url_str, "http://example.com:8080/api/v1?q=1");
        assert_eq!(meta.query, vec![("q".to_string(), "1".to_string())]);
    }

    // Same reconstruction, but MITM connections get the `https://` scheme.
    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_https_in_mitm() {
        let req = Request::builder()
            .uri("/secure")
            .header("Host", "secure.example.com")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, true);
        assert_eq!(meta.url_str, "https://secure.example.com/secure");
    }

    // The caller's `default_version` wins over the recorded `HTTP/2.0`. Status
    // and ordinary headers are copied; `content-length` and `connection` are
    // dropped.
    #[test]
    fn test_build_client_response_from_flow_uses_default_version_currently() {
        let flow = sample_flow_with_response(201);
        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect("response should build");
        assert_eq!(resp.version(), Version::HTTP_11);
        assert_eq!(resp.status(), StatusCode::CREATED);
        assert_eq!(
            resp.headers().get("x-test").and_then(|v| v.to_str().ok()),
            Some("1")
        );
        assert!(
            resp.headers().get("content-length").is_none(),
            "content-length should be stripped from forwarded mock response"
        );
        assert!(resp.headers().get("connection").is_none());
    }

    // `StatusCode::from_u16` rejects 1000, so strict mode must return an error.
    #[test]
    fn test_build_client_response_from_flow_invalid_status_strict_fails() {
        let flow = sample_flow_with_response(1000);
        let err = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect_err("strict mode should reject invalid status");
        assert!(err.contains("Invalid status code"));
    }

    // Non-strict mode falls back to 200 OK and still sends the stored
    // (non-base64) body verbatim.
    #[tokio::test]
    async fn test_build_client_response_from_flow_invalid_status_non_strict_fallback_ok() {
        let mut flow = sample_flow_with_response(1000);
        if let Layer::Http(http) = &mut flow.layer {
            if let Some(res) = &mut http.response {
                res.body = Some(relay_core_api::flow::BodyData {
                    encoding: "utf-8".to_string(),
                    content: "hello".to_string(),
                    size: 5,
                });
            }
        }

        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, false)
            .expect("non-strict should fallback");
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp
            .into_body()
            .collect()
            .await
            .expect("collect body")
            .to_bytes();
        assert_eq!(body.as_ref(), b"hello");
    }
}