1use std::net::SocketAddr;
2use hyper::{Response, Request, StatusCode};
3use hyper::body::Bytes;
4use hyper::header::{HeaderName, HeaderValue};
5use http_body_util::{Full, BodyExt};
6use hyper_util::client::legacy::Client;
7use hyper_util::client::legacy::connect::HttpConnector;
8use hyper_rustls::HttpsConnector;
9use relay_core_api::flow::{
10 BodyData, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, TransportProtocol,
11 WebSocketLayer, Cookie,
12};
13use relay_core_api::policy::ProxyPolicy;
14use crate::capture::loop_detection::LoopDetector;
15use uuid::Uuid;
16use chrono::Utc;
17use url::Url;
18use cookie::Cookie as CookieCrate;
19use data_encoding::BASE64;
20use crate::proxy::body_codec::process_body;
21use crate::interceptor::HttpBody;
22
23pub type HttpsClient = Client<HttpsConnector<HttpConnector>, HttpBody>;
24
25#[derive(Clone, Debug)]
26pub struct RequestMeta {
27 pub method: String,
28 pub url_str: String,
29 pub version: String,
30 pub headers: Vec<(String, String)>,
31 pub query: Vec<(String, String)>,
32 pub cookies: Vec<Cookie>,
33}
34
35pub fn parse_request_meta<B>(req: &Request<B>, is_mitm: bool) -> RequestMeta {
36 let method = req.method().to_string();
37 let mut url_str = req.uri().to_string();
38
39 if Url::parse(&url_str).is_err() {
41 if let Some(host) = req.headers().get("Host").and_then(|v| v.to_str().ok()) {
42 let scheme = if is_mitm { "https" } else { "http" };
43 let new_url = format!("{}://{}{}", scheme, host, url_str);
44 if Url::parse(&new_url).is_ok() {
45 url_str = new_url;
46 }
47 }
48 }
49
50 let version = format!("{:?}", req.version());
51
52 let headers: Vec<(String, String)> = req.headers()
53 .iter()
54 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
55 .collect();
56
57 let query: Vec<(String, String)> = if let Ok(parsed_url) = Url::parse(&url_str) {
58 parsed_url.query_pairs().into_owned().collect()
59 } else {
60 vec![]
61 };
62
63 let mut cookies = Vec::new();
64 if let Some(cookie_header) = req.headers().get(hyper::header::COOKIE) {
65 if let Ok(cookie_str) = cookie_header.to_str() {
66 for c in CookieCrate::split_parse(cookie_str).flatten() {
67 cookies.push(Cookie {
68 name: c.name().to_string(),
69 value: c.value().to_string(),
70 path: None,
71 domain: None,
72 expires: None,
73 http_only: None,
74 secure: None,
75 });
76 }
77 }
78 }
79
80 RequestMeta {
81 method,
82 url_str,
83 version,
84 headers,
85 query,
86 cookies,
87 }
88}
89
/// Returns `true` if the header `name` (compared ASCII case-insensitively) is stripped by this
/// proxy instead of being forwarded.
///
/// The set contains:
/// - RFC 2616 §13.5.1's hop-by-hop list. `Trailer` is the real field name; RFC 2616 misspelled it
///   `Trailers` (erratum 4522), and that spelling is kept for compatibility.
/// - `Proxy-Connection`, a non-standard field that RFC 9110 §7.6.1 names as connection-specific.
/// - `Content-Length`. It is not hop-by-hop per the RFCs but is deliberately stripped here.
///   NOTE(review): presumably so framing is recomputed for a possibly modified body; confirm.
pub fn is_hop_by_hop(name: &str) -> bool {
    const STRIPPED: &[&str] = &[
        "connection",
        "proxy-connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-length",
    ];
    STRIPPED.iter().any(|h| name.eq_ignore_ascii_case(h))
}
101
102pub fn create_initial_flow(
103 meta: RequestMeta,
104 req_body: Option<BodyData>,
105 client_addr: SocketAddr,
106 is_mitm: bool,
107 is_websocket: bool,
108) -> Flow {
109 let flow_id = Uuid::new_v4();
110 let start_time = Utc::now();
111
112 let network_info = NetworkInfo {
113 client_ip: client_addr.ip().to_string(),
114 client_port: client_addr.port(),
115 server_ip: "0.0.0.0".to_string(), server_port: 0, protocol: TransportProtocol::TCP,
118 tls: is_mitm,
119 tls_version: None,
120 sni: None,
121 };
122
123 let http_request = HttpRequest {
124 method: meta.method,
125 url: Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown").unwrap()),
126 version: meta.version,
127 headers: meta.headers,
128 cookies: meta.cookies,
129 query: meta.query,
130 body: req_body,
131 };
132
133 let mut flow = if is_websocket {
134 Flow {
135 id: flow_id,
136 start_time,
137 end_time: None,
138 network: network_info,
139 layer: Layer::WebSocket(WebSocketLayer {
140 handshake_request: http_request,
141 handshake_response: HttpResponse {
142 status: 0,
143 status_text: "".to_string(),
144 version: "".to_string(),
145 headers: vec![],
146 cookies: vec![],
147 body: None,
148 timing: relay_core_api::flow::ResponseTiming { time_to_first_byte: None, time_to_last_byte: None },
149 },
150 messages: vec![],
151 closed: false,
152 }),
153 tags: vec!["websocket".to_string()],
154 meta: std::collections::HashMap::new(),
155 }
156 } else {
157 Flow {
158 id: flow_id,
159 start_time,
160 end_time: None,
161 network: network_info,
162 layer: Layer::Http(HttpLayer {
163 request: http_request,
164 response: None,
165 error: None,
166 }),
167 tags: vec!["proxy".to_string()],
168 meta: std::collections::HashMap::new(),
169 }
170 };
171
172 if is_mitm {
173 flow.tags.push("mitm".to_string());
174 }
175
176 flow
177}
178
179pub fn create_error_response(status: StatusCode, message: impl Into<Bytes>) -> Response<HttpBody> {
180 Response::builder()
181 .status(status)
182 .body(Full::new(message.into()).map_err(|e| e.into()).boxed())
183 .unwrap_or_else(|_| Response::new(Full::new(Bytes::from("Internal Error")).map_err(|e| e.into()).boxed()))
184}
185
186pub fn mock_to_response(mock: HttpResponse) -> Response<HttpBody> {
187 let mut builder = Response::builder()
188 .status(StatusCode::from_u16(mock.status).unwrap_or(StatusCode::OK));
189
190 for (k, v) in mock.headers {
191 if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(&v)) {
192 builder = builder.header(name, val);
193 }
194 }
195
196 let body = if let Some(b) = mock.body {
197 Bytes::from(b.content)
198 } else {
199 Bytes::new()
200 };
201
202 builder.body(Full::new(body).map_err(|e| e.into()).boxed()).unwrap_or_else(|_| create_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to build mock response"))
203}
204
205#[allow(clippy::result_large_err)]
206pub fn build_forward_request(
207 flow: &mut Flow,
208 body: HttpBody,
209 req_parts: &http::request::Parts,
210 target_addr: Option<SocketAddr>,
211 policy: &ProxyPolicy,
212 loop_detector: &LoopDetector,
213) -> Result<Request<HttpBody>, Response<HttpBody>> {
214 let current_req = if let Layer::Http(http) = &flow.layer {
215 &http.request
216 } else {
217 return Err(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Invalid Flow Layer State"));
218 };
219
220 let mut forward_req_builder = Request::builder()
221 .method(current_req.method.as_str())
222 .version(req_parts.version);
223
224 let mut target_url = current_req.url.clone();
226
227 if policy.transparent_enabled {
229 if let Some(addr) = target_addr {
230 flow.tags.push("transparent".to_string());
231
232 flow.network.server_ip = addr.ip().to_string();
234 flow.network.server_port = addr.port();
235
236 if loop_detector.would_loop(addr) {
238 if let Layer::Http(http) = &mut flow.layer {
239 http.error = Some("Loop Detected".to_string());
240 }
241 return Err(create_error_response(StatusCode::LOOP_DETECTED, "Loop Detected"));
242 }
243
244 if target_url.set_ip_host(addr.ip()).is_ok() {
246 target_url.set_port(Some(addr.port())).ok();
247 }
248
249 if flow.network.tls && target_url.scheme() == "http" {
251 target_url.set_scheme("https").ok();
252 }
253 }
254 }
255
256 forward_req_builder = forward_req_builder.uri(target_url.as_str());
257
258 for (k, v) in ¤t_req.headers {
259 if is_hop_by_hop(k) {
261 continue;
262 }
263
264 if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(v)) {
265 forward_req_builder = forward_req_builder.header(name, val);
266 }
267 }
268
269 match forward_req_builder.body(body) {
270 Ok(req) => Ok(req),
271 Err(e) => Err(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to build forward request: {}", e))),
272 }
273}
274
275pub fn update_flow_with_response_headers(
276 flow: &mut Flow,
277 status: StatusCode,
278 version: hyper::Version,
279 headers: &hyper::HeaderMap,
280) {
281 let mut response_cookies = Vec::new();
282 for (k, v) in headers.iter() {
283 if k == hyper::header::SET_COOKIE {
284 if let Ok(v_str) = v.to_str() {
285 if let Ok(c) = CookieCrate::parse(v_str) {
286 response_cookies.push(Cookie {
287 name: c.name().to_string(),
288 value: c.value().to_string(),
289 path: c.path().map(|s| s.to_string()),
290 domain: c.domain().map(|s| s.to_string()),
291 expires: c.expires().map(|e| format!("{:?}", e)),
292 http_only: c.http_only(),
293 secure: c.secure(),
294 });
295 }
296 }
297 }
298 }
299
300 let resp_headers_vec: Vec<(String, String)> = headers.iter()
301 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
302 .collect();
303
304 let http_response = HttpResponse {
305 status: status.as_u16(),
306 status_text: status.to_string(),
307 version: format!("{:?}", version),
308 headers: resp_headers_vec,
309 cookies: response_cookies,
310 body: None,
311 timing: relay_core_api::flow::ResponseTiming {
312 time_to_first_byte: None,
313 time_to_last_byte: None,
314 },
315 };
316
317 match &mut flow.layer {
318 Layer::Http(http) => {
319 http.response = Some(http_response);
320 },
321 Layer::WebSocket(ws) => {
322 ws.handshake_response = http_response;
323 },
324 _ => {}
325 }
326}
327
328pub fn update_flow_with_response_body(
329 flow: &mut Flow,
330 body_bytes: Bytes,
331) {
332 let headers = match &flow.layer {
333 Layer::Http(http) => http.response.as_ref().map(|r| r.headers.clone()).unwrap_or_default(),
334 Layer::WebSocket(ws) => ws.handshake_response.headers.clone(),
335 _ => Vec::new(),
336 };
337
338 let (resp_encoding, resp_content) = process_body(&body_bytes, &headers);
339
340 let body_data = BodyData {
341 encoding: resp_encoding,
342 content: resp_content,
343 size: body_bytes.len() as u64,
344 };
345
346 match &mut flow.layer {
347 Layer::Http(http) => {
348 if let Some(resp) = &mut http.response {
349 resp.body = Some(body_data);
350 }
351 },
352 Layer::WebSocket(ws) => {
353 ws.handshake_response.body = Some(body_data);
354 },
355 _ => {}
356 }
357}
358
/// Records a complete upstream response (status, version, headers and body) on `flow`.
///
/// The headers must be stored first: `update_flow_with_response_body` reads the response headers
/// already recorded on the flow when processing `body_bytes`.
pub fn update_flow_with_response(
    flow: &mut Flow,
    status: StatusCode,
    version: hyper::Version,
    headers: &hyper::HeaderMap,
    body_bytes: Bytes,
) {
    update_flow_with_response_headers(flow, status, version, headers);
    update_flow_with_response_body(flow, body_bytes);
}
369
370pub fn build_client_response_from_flow(flow: &Flow, default_version: hyper::Version, strict_mode: bool) -> Result<Response<Full<Bytes>>, String> {
371 if let Layer::Http(http) = &flow.layer {
372 if let Some(response) = &http.response {
373 let status = match StatusCode::from_u16(response.status) {
374 Ok(s) => s,
375 Err(_) => {
376 if strict_mode {
377 return Err(format!("Invalid status code: {}", response.status));
378 }
379 StatusCode::OK
380 }
381 };
382
383 let mut builder = Response::builder()
384 .status(status)
385 .version(default_version); for (k, v) in &response.headers {
388 if k.eq_ignore_ascii_case("content-length")
390 || k.eq_ignore_ascii_case("transfer-encoding")
391 || k.eq_ignore_ascii_case("connection") {
392 continue;
393 }
394
395 if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(v)) {
396 builder = builder.header(name, val);
397 } else if strict_mode {
398 return Err(format!("Invalid header: {}: {}", k, v));
399 }
400 }
401
402 let body_bytes = if let Some(b) = &response.body {
403 if b.encoding == "base64" {
404 match BASE64.decode(b.content.as_bytes()) {
405 Ok(bytes) => Bytes::from(bytes),
406 Err(_e) => {
407 Bytes::from(b.content.clone())
409 }
410 }
411 } else {
412 Bytes::from(b.content.clone())
413 }
414 } else {
415 Bytes::new()
416 };
417
418 builder.body(Full::new(body_bytes))
419 .map_err(|e| format!("Failed to build response: {}", e))
420 } else {
421 Err("No response in flow".to_string())
422 }
423 } else {
424 Err("Not HTTP layer".to_string())
425 }
426}
427
#[cfg(test)]
mod tests {
    // Unit tests for request-metadata parsing and for rebuilding client responses from flows.
    use super::{build_client_response_from_flow, parse_request_meta};
    use chrono::Utc;
    use http_body_util::BodyExt;
    use hyper::{Request, StatusCode, Version};
    use relay_core_api::flow::{
        Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
        TransportProtocol,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Builds an HTTP flow whose recorded response has the given `status`.
    ///
    /// The response is recorded as HTTP/2.0 with no body. It has three headers: a normal `X-Test`
    /// header, a stale `content-length: 999` and `connection: keep-alive`. The last two are there
    /// to check that they are stripped.
    fn sample_flow_with_response(status: u16) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com/a").expect("url"),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: Some(HttpResponse {
                    status,
                    status_text: "X".to_string(),
                    version: "HTTP/2.0".to_string(),
                    headers: vec![
                        ("X-Test".to_string(), "1".to_string()),
                        ("content-length".to_string(), "999".to_string()),
                        ("connection".to_string(), "keep-alive".to_string()),
                    ],
                    cookies: vec![],
                    body: None,
                    timing: ResponseTiming {
                        time_to_first_byte: None,
                        time_to_last_byte: None,
                    },
                }),
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// An origin-form URI plus a `Host` header becomes an absolute `http` URL when not MITM.
    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_http() {
        let req = Request::builder()
            .uri("/api/v1?q=1")
            .header("Host", "example.com:8080")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, false);
        assert_eq!(meta.url_str, "http://example.com:8080/api/v1?q=1");
        assert_eq!(meta.query, vec![("q".to_string(), "1".to_string())]);
    }

    /// In MITM mode the rebuilt URL uses the `https` scheme.
    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_https_in_mitm() {
        let req = Request::builder()
            .uri("/secure")
            .header("Host", "secure.example.com")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, true);
        assert_eq!(meta.url_str, "https://secure.example.com/secure");
    }

    /// The response takes `default_version` (not the recorded HTTP/2.0) and keeps ordinary
    /// headers. The recorded `content-length` and `connection` headers are stripped.
    #[test]
    fn test_build_client_response_from_flow_uses_default_version_currently() {
        let flow = sample_flow_with_response(201);
        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect("response should build");
        assert_eq!(resp.version(), Version::HTTP_11);
        assert_eq!(resp.status(), StatusCode::CREATED);
        assert_eq!(resp.headers().get("x-test").and_then(|v| v.to_str().ok()), Some("1"));
        assert!(
            resp.headers().get("content-length").is_none(),
            "content-length should be stripped from forwarded mock response"
        );
        assert!(resp.headers().get("connection").is_none());
    }

    /// Strict mode rejects a status code outside hyper's valid range (1000 here).
    #[test]
    fn test_build_client_response_from_flow_invalid_status_strict_fails() {
        let flow = sample_flow_with_response(1000);
        let err = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect_err("strict mode should reject invalid status");
        assert!(err.contains("Invalid status code"));
    }

    /// Non-strict mode falls back to `200 OK` for an invalid status and passes a non-base64
    /// body through unchanged.
    #[tokio::test]
    async fn test_build_client_response_from_flow_invalid_status_non_strict_fallback_ok() {
        let mut flow = sample_flow_with_response(1000);
        if let Layer::Http(http) = &mut flow.layer {
            if let Some(res) = &mut http.response {
                res.body = Some(relay_core_api::flow::BodyData {
                    encoding: "utf-8".to_string(),
                    content: "hello".to_string(),
                    size: 5,
                });
            }
        }

        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, false)
            .expect("non-strict should fallback");
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp
            .into_body()
            .collect()
            .await
            .expect("collect body")
            .to_bytes();
        assert_eq!(body.as_ref(), b"hello");
    }
}