1use std::net::SocketAddr;
2use hyper::{Response, Request, StatusCode};
3use hyper::body::Bytes;
4use hyper::header::{HeaderName, HeaderValue};
5use http_body_util::{Full, BodyExt};
6use hyper_util::client::legacy::Client;
7use hyper_util::client::legacy::connect::HttpConnector;
8use hyper_rustls::HttpsConnector;
9use relay_core_api::flow::{
10 BodyData, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, TransportProtocol,
11 WebSocketLayer, Cookie,
12};
13use relay_core_api::policy::ProxyPolicy;
14use crate::capture::loop_detection::LoopDetector;
15use uuid::Uuid;
16use chrono::Utc;
17use url::Url;
18use cookie::Cookie as CookieCrate;
19use data_encoding::BASE64;
20use crate::proxy::body_codec::process_body;
21use crate::interceptor::HttpBody;
22
23pub type HttpsClient = Client<HttpsConnector<HttpConnector>, HttpBody>;
24
25#[derive(Clone, Debug)]
26pub struct RequestMeta {
27 pub method: String,
28 pub url_str: String,
29 pub version: String,
30 pub headers: Vec<(String, String)>,
31 pub query: Vec<(String, String)>,
32 pub cookies: Vec<Cookie>,
33}
34
35pub fn parse_request_meta<B>(req: &Request<B>, is_mitm: bool) -> RequestMeta {
36 let method = req.method().to_string();
37 let mut url_str = req.uri().to_string();
38
39 if Url::parse(&url_str).is_err()
41 && let Some(host) = req.headers().get("Host").and_then(|v| v.to_str().ok()) {
42 let scheme = if is_mitm { "https" } else { "http" };
43 let new_url = format!("{}://{}{}", scheme, host, url_str);
44 if Url::parse(&new_url).is_ok() {
45 url_str = new_url;
46 }
47 }
48
49 let version = format!("{:?}", req.version());
50
51 let headers: Vec<(String, String)> = req.headers()
52 .iter()
53 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
54 .collect();
55
56 let query: Vec<(String, String)> = if let Ok(parsed_url) = Url::parse(&url_str) {
57 parsed_url.query_pairs().into_owned().collect()
58 } else {
59 vec![]
60 };
61
62 let mut cookies = Vec::new();
63 if let Some(cookie_header) = req.headers().get(hyper::header::COOKIE)
64 && let Ok(cookie_str) = cookie_header.to_str() {
65 for c in CookieCrate::split_parse(cookie_str).flatten() {
66 cookies.push(Cookie {
67 name: c.name().to_string(),
68 value: c.value().to_string(),
69 path: None,
70 domain: None,
71 expires: None,
72 http_only: None,
73 secure: None,
74 });
75 }
76 }
77
78 RequestMeta {
79 method,
80 url_str,
81 version,
82 headers,
83 query,
84 cookies,
85 }
86}
87
/// Reports whether a header named `name` must not be copied onto a forwarded message.
///
/// The comparison is ASCII case-insensitive. The list covers the hop-by-hop headers of
/// RFC 9110 §7.6.1 / RFC 2616 §13.5.1, plus `Proxy-Connection` (non-standard, but still sent
/// by some clients to proxies). The real header is `Trailer`; `trailers` is RFC 2616's
/// erroneous spelling (erratum 4522) and is kept for backward compatibility.
///
/// `Content-Length` is not hop-by-hop, but it is included here as well, so the forwarded
/// body's length is left for the HTTP client to determine.
pub fn is_hop_by_hop(name: &str) -> bool {
    const STRIPPED: &[&str] = &[
        "connection",
        "proxy-connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-length",
    ];
    STRIPPED.iter().any(|header| name.eq_ignore_ascii_case(header))
}
99
100pub fn create_initial_flow(
101 meta: RequestMeta,
102 req_body: Option<BodyData>,
103 client_addr: SocketAddr,
104 is_mitm: bool,
105 is_websocket: bool,
106) -> Flow {
107 let flow_id = Uuid::new_v4();
108 let start_time = Utc::now();
109
110 let network_info = NetworkInfo {
111 client_ip: client_addr.ip().to_string(),
112 client_port: client_addr.port(),
113 server_ip: "0.0.0.0".to_string(), server_port: 0, protocol: TransportProtocol::TCP,
116 tls: is_mitm,
117 tls_version: None,
118 sni: None,
119 };
120
121 let http_request = HttpRequest {
122 method: meta.method,
123 url: Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown").unwrap()),
124 version: meta.version,
125 headers: meta.headers,
126 cookies: meta.cookies,
127 query: meta.query,
128 body: req_body,
129 };
130
131 let mut flow = if is_websocket {
132 Flow {
133 id: flow_id,
134 start_time,
135 end_time: None,
136 network: network_info,
137 layer: Layer::WebSocket(WebSocketLayer {
138 handshake_request: http_request,
139 handshake_response: HttpResponse {
140 status: 0,
141 status_text: "".to_string(),
142 version: "".to_string(),
143 headers: vec![],
144 cookies: vec![],
145 body: None,
146 timing: relay_core_api::flow::ResponseTiming { time_to_first_byte: None, time_to_last_byte: None,
147 connect_time_ms: None,
148 ssl_time_ms: None,
149 },
150 },
151 messages: vec![],
152 closed: false,
153 }),
154 tags: vec!["websocket".to_string()],
155 meta: std::collections::HashMap::new(),
156 }
157 } else {
158 Flow {
159 id: flow_id,
160 start_time,
161 end_time: None,
162 network: network_info,
163 layer: Layer::Http(HttpLayer {
164 request: http_request,
165 response: None,
166 error: None,
167 }),
168 tags: vec!["proxy".to_string()],
169 meta: std::collections::HashMap::new(),
170 }
171 };
172
173 if is_mitm {
174 flow.tags.push("mitm".to_string());
175 }
176
177 flow
178}
179
180pub fn create_error_response(status: StatusCode, message: impl Into<Bytes>) -> Response<HttpBody> {
181 Response::builder()
182 .status(status)
183 .body(Full::new(message.into()).map_err(|e| e.into()).boxed())
184 .unwrap_or_else(|_| Response::new(Full::new(Bytes::from("Internal Error")).map_err(|e| e.into()).boxed()))
185}
186
187pub fn mock_to_response(mock: HttpResponse) -> Response<HttpBody> {
188 let mut builder = Response::builder()
189 .status(StatusCode::from_u16(mock.status).unwrap_or(StatusCode::OK));
190
191 for (k, v) in mock.headers {
192 if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(&v)) {
193 builder = builder.header(name, val);
194 }
195 }
196
197 let body = if let Some(b) = mock.body {
198 Bytes::from(b.content)
199 } else {
200 Bytes::new()
201 };
202
203 builder.body(Full::new(body).map_err(|e| e.into()).boxed()).unwrap_or_else(|_| create_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to build mock response"))
204}
205
206#[allow(clippy::result_large_err)]
207pub fn build_forward_request(
208 flow: &mut Flow,
209 body: HttpBody,
210 target_addr: Option<SocketAddr>,
211 policy: &ProxyPolicy,
212 loop_detector: &LoopDetector,
213) -> Result<Request<HttpBody>, Response<HttpBody>> {
214 let current_req = if let Layer::Http(http) = &flow.layer {
215 &http.request
216 } else {
217 return Err(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Invalid Flow Layer State"));
218 };
219
220 let mut forward_req_builder = Request::builder()
221 .method(current_req.method.as_str());
222
223 let mut target_url = current_req.url.clone();
225
226 if policy.transparent_enabled
228 && let Some(addr) = target_addr {
229 flow.tags.push("transparent".to_string());
230
231 flow.network.server_ip = addr.ip().to_string();
233 flow.network.server_port = addr.port();
234
235 if loop_detector.would_loop(addr) {
237 if let Layer::Http(http) = &mut flow.layer {
238 http.error = Some("Loop Detected".to_string());
239 }
240 return Err(create_error_response(StatusCode::LOOP_DETECTED, "Loop Detected"));
241 }
242
243 if target_url.set_ip_host(addr.ip()).is_ok() {
245 target_url.set_port(Some(addr.port())).ok();
246 }
247
248 if flow.network.tls && target_url.scheme() == "http" {
250 target_url.set_scheme("https").ok();
251 }
252 }
253
254 forward_req_builder = forward_req_builder.uri(target_url.as_str());
255
256 for (k, v) in ¤t_req.headers {
257 if is_hop_by_hop(k) {
259 continue;
260 }
261
262 if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(v)) {
263 forward_req_builder = forward_req_builder.header(name, val);
264 }
265 }
266
267 match forward_req_builder.body(body) {
268 Ok(req) => Ok(req),
269 Err(e) => Err(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to build forward request: {}", e))),
270 }
271}
272
273pub fn update_flow_with_response_headers(
274 flow: &mut Flow,
275 status: StatusCode,
276 version: hyper::Version,
277 headers: &hyper::HeaderMap,
278) {
279 let mut response_cookies = Vec::new();
280 for (k, v) in headers.iter() {
281 if k == hyper::header::SET_COOKIE
282 && let Ok(v_str) = v.to_str()
283 && let Ok(c) = CookieCrate::parse(v_str) {
284 response_cookies.push(Cookie {
285 name: c.name().to_string(),
286 value: c.value().to_string(),
287 path: c.path().map(|s| s.to_string()),
288 domain: c.domain().map(|s| s.to_string()),
289 expires: c.expires().map(|e| format!("{:?}", e)),
290 http_only: c.http_only(),
291 secure: c.secure(),
292 });
293 }
294 }
295
296 let resp_headers_vec: Vec<(String, String)> = headers.iter()
297 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
298 .collect();
299
300 let http_response = HttpResponse {
301 status: status.as_u16(),
302 status_text: status.to_string(),
303 version: format!("{:?}", version),
304 headers: resp_headers_vec,
305 cookies: response_cookies,
306 body: None,
307 timing: relay_core_api::flow::ResponseTiming {
308 time_to_first_byte: None,
309 time_to_last_byte: None,
310 connect_time_ms: None,
311 ssl_time_ms: None,
312 },
313 };
314
315 match &mut flow.layer {
316 Layer::Http(http) => {
317 http.response = Some(http_response);
318 },
319 Layer::WebSocket(ws) => {
320 ws.handshake_response = http_response;
321 },
322 _ => {}
323 }
324}
325
326pub fn update_flow_with_response_body(
327 flow: &mut Flow,
328 body_bytes: Bytes,
329) {
330 let headers = match &flow.layer {
331 Layer::Http(http) => http.response.as_ref().map(|r| r.headers.clone()).unwrap_or_default(),
332 Layer::WebSocket(ws) => ws.handshake_response.headers.clone(),
333 _ => Vec::new(),
334 };
335
336 let (resp_encoding, resp_content) = process_body(&body_bytes, &headers);
337
338 let body_data = BodyData {
339 encoding: resp_encoding,
340 content: resp_content,
341 size: body_bytes.len() as u64,
342 };
343
344 match &mut flow.layer {
345 Layer::Http(http) => {
346 if let Some(resp) = &mut http.response {
347 resp.body = Some(body_data);
348 }
349 },
350 Layer::WebSocket(ws) => {
351 ws.handshake_response.body = Some(body_data);
352 },
353 _ => {}
354 }
355}
356
357pub fn update_flow_with_response(
358 flow: &mut Flow,
359 status: StatusCode,
360 version: hyper::Version,
361 headers: &hyper::HeaderMap,
362 body_bytes: Bytes,
363) {
364 update_flow_with_response_headers(flow, status, version, headers);
365 update_flow_with_response_body(flow, body_bytes);
366}
367
368pub fn build_client_response_from_flow(flow: &Flow, default_version: hyper::Version, strict_mode: bool) -> Result<Response<Full<Bytes>>, String> {
369 if let Layer::Http(http) = &flow.layer {
370 if let Some(response) = &http.response {
371 let status = match StatusCode::from_u16(response.status) {
372 Ok(s) => s,
373 Err(_) => {
374 if strict_mode {
375 return Err(format!("Invalid status code: {}", response.status));
376 }
377 StatusCode::OK
378 }
379 };
380
381 let mut builder = Response::builder()
382 .status(status)
383 .version(default_version); for (k, v) in &response.headers {
386 if k.eq_ignore_ascii_case("content-length")
388 || k.eq_ignore_ascii_case("transfer-encoding")
389 || k.eq_ignore_ascii_case("connection") {
390 continue;
391 }
392
393 if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(v)) {
394 builder = builder.header(name, val);
395 } else if strict_mode {
396 return Err(format!("Invalid header: {}: {}", k, v));
397 }
398 }
399
400 let body_bytes = if let Some(b) = &response.body {
401 if b.encoding == "base64" {
402 match BASE64.decode(b.content.as_bytes()) {
403 Ok(bytes) => Bytes::from(bytes),
404 Err(_e) => {
405 Bytes::from(b.content.clone())
407 }
408 }
409 } else {
410 Bytes::from(b.content.clone())
411 }
412 } else {
413 Bytes::new()
414 };
415
416 builder.body(Full::new(body_bytes))
417 .map_err(|e| format!("Failed to build response: {}", e))
418 } else {
419 Err("No response in flow".to_string())
420 }
421 } else {
422 Err("Not HTTP layer".to_string())
423 }
424}
425
426#[cfg(test)]
427mod tests {
428 use super::{build_client_response_from_flow, parse_request_meta};
429 use chrono::Utc;
430 use http_body_util::BodyExt;
431 use hyper::{Request, StatusCode, Version};
432 use relay_core_api::flow::{
433 Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
434 TransportProtocol,
435 };
436 use std::collections::HashMap;
437 use url::Url;
438 use uuid::Uuid;
439
440 fn sample_flow_with_response(status: u16) -> Flow {
441 Flow {
442 id: Uuid::new_v4(),
443 start_time: Utc::now(),
444 end_time: None,
445 network: NetworkInfo {
446 client_ip: "127.0.0.1".to_string(),
447 client_port: 12345,
448 server_ip: "1.1.1.1".to_string(),
449 server_port: 80,
450 protocol: TransportProtocol::TCP,
451 tls: false,
452 tls_version: None,
453 sni: None,
454 },
455 layer: Layer::Http(HttpLayer {
456 request: HttpRequest {
457 method: "GET".to_string(),
458 url: Url::parse("http://example.com/a").expect("url"),
459 version: "HTTP/1.1".to_string(),
460 headers: vec![],
461 cookies: vec![],
462 query: vec![],
463 body: None,
464 },
465 response: Some(HttpResponse {
466 status,
467 status_text: "X".to_string(),
468 version: "HTTP/2.0".to_string(),
469 headers: vec![
470 ("X-Test".to_string(), "1".to_string()),
471 ("content-length".to_string(), "999".to_string()),
472 ("connection".to_string(), "keep-alive".to_string()),
473 ],
474 cookies: vec![],
475 body: None,
476 timing: ResponseTiming {
477 time_to_first_byte: None,
478 time_to_last_byte: None,
479 connect_time_ms: None,
480 ssl_time_ms: None,
481 },
482 }),
483 error: None,
484 }),
485 tags: vec![],
486 meta: HashMap::new(),
487 }
488 }
489
490 #[test]
491 fn test_parse_request_meta_relative_uri_uses_host_http() {
492 let req = Request::builder()
493 .uri("/api/v1?q=1")
494 .header("Host", "example.com:8080")
495 .body(())
496 .expect("request");
497 let meta = parse_request_meta(&req, false);
498 assert_eq!(meta.url_str, "http://example.com:8080/api/v1?q=1");
499 assert_eq!(meta.query, vec![("q".to_string(), "1".to_string())]);
500 }
501
502 #[test]
503 fn test_parse_request_meta_relative_uri_uses_host_https_in_mitm() {
504 let req = Request::builder()
505 .uri("/secure")
506 .header("Host", "secure.example.com")
507 .body(())
508 .expect("request");
509 let meta = parse_request_meta(&req, true);
510 assert_eq!(meta.url_str, "https://secure.example.com/secure");
511 }
512
513 #[test]
514 fn test_build_client_response_from_flow_uses_default_version_currently() {
515 let flow = sample_flow_with_response(201);
516 let resp = build_client_response_from_flow(&flow, Version::HTTP_11, true)
517 .expect("response should build");
518 assert_eq!(resp.version(), Version::HTTP_11);
519 assert_eq!(resp.status(), StatusCode::CREATED);
520 assert_eq!(resp.headers().get("x-test").and_then(|v| v.to_str().ok()), Some("1"));
521 assert!(
522 resp.headers().get("content-length").is_none(),
523 "content-length should be stripped from forwarded mock response"
524 );
525 assert!(resp.headers().get("connection").is_none());
526 }
527
528 #[test]
529 fn test_build_client_response_from_flow_invalid_status_strict_fails() {
530 let flow = sample_flow_with_response(1000);
531 let err = build_client_response_from_flow(&flow, Version::HTTP_11, true)
532 .expect_err("strict mode should reject invalid status");
533 assert!(err.contains("Invalid status code"));
534 }
535
536 #[tokio::test]
537 async fn test_build_client_response_from_flow_invalid_status_non_strict_fallback_ok() {
538 let mut flow = sample_flow_with_response(1000);
539 if let Layer::Http(http) = &mut flow.layer {
540 if let Some(res) = &mut http.response {
541 res.body = Some(relay_core_api::flow::BodyData {
542 encoding: "utf-8".to_string(),
543 content: "hello".to_string(),
544 size: 5,
545 });
546 }
547 }
548
549 let resp = build_client_response_from_flow(&flow, Version::HTTP_11, false)
550 .expect("non-strict should fallback");
551 assert_eq!(resp.status(), StatusCode::OK);
552 let body = resp
553 .into_body()
554 .collect()
555 .await
556 .expect("collect body")
557 .to_bytes();
558 assert_eq!(body.as_ref(), b"hello");
559 }
560}