Skip to main content

relay_core_lib/proxy/
http_utils.rs

1use crate::capture::loop_detection::LoopDetector;
2use crate::interceptor::HttpBody;
3use crate::proxy::body_codec::process_body;
4use chrono::Utc;
5use cookie::Cookie as CookieCrate;
6use data_encoding::BASE64;
7use http_body_util::{BodyExt, Full};
8use hyper::body::Bytes;
9use hyper::header::{HeaderName, HeaderValue};
10use hyper::{Request, Response, StatusCode};
11use hyper_rustls::HttpsConnector;
12use hyper_util::client::legacy::Client;
13use hyper_util::client::legacy::connect::HttpConnector;
14use relay_core_api::flow::{
15    BodyData, Cookie, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
16    TransportProtocol, WebSocketLayer,
17};
18use relay_core_api::policy::ProxyPolicy;
19use std::net::SocketAddr;
20use url::Url;
21use uuid::Uuid;
22
/// Hyper legacy [`Client`] over an [`HttpsConnector`] wrapping an
/// [`HttpConnector`], sending request bodies of type [`HttpBody`].
pub type HttpsClient = Client<HttpsConnector<HttpConnector>, HttpBody>;
24
/// Request metadata extracted from an incoming request by
/// [`parse_request_meta`] and consumed by [`create_initial_flow`].
#[derive(Clone, Debug)]
pub struct RequestMeta {
    /// Request method rendered as text.
    pub method: String,
    /// Request URI as a string; rewritten to an absolute URL using the `Host`
    /// header when the raw URI does not parse as a URL on its own.
    pub url_str: String,
    /// `Debug` rendering of the request's HTTP version.
    pub version: String,
    /// Header name/value pairs in the order yielded by the request's header map.
    pub headers: Vec<(String, String)>,
    /// Decoded query pairs of `url_str`; empty when `url_str` is not a valid URL.
    pub query: Vec<(String, String)>,
    /// Cookies parsed from the request's `Cookie` header.
    pub cookies: Vec<Cookie>,
}
34
35pub fn parse_request_meta<B>(req: &Request<B>, is_mitm: bool) -> RequestMeta {
36    let method = req.method().to_string();
37    let mut url_str = req.uri().to_string();
38
39    // Attempt to construct absolute URL if relative
40    if Url::parse(&url_str).is_err()
41        && let Some(host) = req.headers().get("Host").and_then(|v| v.to_str().ok())
42    {
43        let scheme = if is_mitm { "https" } else { "http" };
44        let new_url = format!("{}://{}{}", scheme, host, url_str);
45        if Url::parse(&new_url).is_ok() {
46            url_str = new_url;
47        }
48    }
49
50    let version = format!("{:?}", req.version());
51
52    let headers: Vec<(String, String)> = req
53        .headers()
54        .iter()
55        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
56        .collect();
57
58    let query: Vec<(String, String)> = if let Ok(parsed_url) = Url::parse(&url_str) {
59        parsed_url.query_pairs().into_owned().collect()
60    } else {
61        vec![]
62    };
63
64    let mut cookies = Vec::new();
65    if let Some(cookie_header) = req.headers().get(hyper::header::COOKIE)
66        && let Ok(cookie_str) = cookie_header.to_str()
67    {
68        for c in CookieCrate::split_parse(cookie_str).flatten() {
69            cookies.push(Cookie {
70                name: c.name().to_string(),
71                value: c.value().to_string(),
72                path: None,
73                domain: None,
74                expires: None,
75                http_only: None,
76                secure: None,
77            });
78        }
79    }
80
81    RequestMeta {
82        method,
83        url_str,
84        version,
85        headers,
86        query,
87        cookies,
88    }
89}
90
/// Returns `true` when `name` is a header that must not be copied verbatim
/// onto a forwarded message.
///
/// The comparison is ASCII case-insensitive. Besides the classic hop-by-hop
/// fields, the list contains:
/// * `trailer` - the actual field name; `trailers` (the spelling from the
///   RFC 2616 list) is kept for backward compatibility,
/// * `proxy-connection` - the legacy, non-standard variant of `connection`,
/// * `content-length` - not hop-by-hop, but dropped as well;
///   NOTE(review): presumably so the outgoing framing is derived from the
///   re-sent body rather than the recorded header - confirm.
pub fn is_hop_by_hop(name: &str) -> bool {
    const HOP_BY_HOP: &[&str] = &[
        "connection",
        "proxy-connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-length",
    ];

    HOP_BY_HOP
        .iter()
        .any(|candidate| name.eq_ignore_ascii_case(candidate))
}
102
103pub fn create_initial_flow(
104    meta: RequestMeta,
105    req_body: Option<BodyData>,
106    client_addr: SocketAddr,
107    is_mitm: bool,
108    is_websocket: bool,
109) -> Flow {
110    let flow_id = Uuid::new_v4();
111    let start_time = Utc::now();
112
113    let network_info = NetworkInfo {
114        client_ip: client_addr.ip().to_string(),
115        client_port: client_addr.port(),
116        server_ip: "0.0.0.0".to_string(), // Placeholder
117        server_port: 0,                   // Placeholder
118        protocol: TransportProtocol::TCP,
119        tls: is_mitm,
120        tls_version: None,
121        sni: None,
122    };
123
124    let http_request = HttpRequest {
125        method: meta.method,
126        url: Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown").unwrap()),
127        version: meta.version,
128        headers: meta.headers,
129        cookies: meta.cookies,
130        query: meta.query,
131        body: req_body,
132    };
133
134    let mut flow = if is_websocket {
135        Flow {
136            id: flow_id,
137            start_time,
138            end_time: None,
139            network: network_info,
140            layer: Layer::WebSocket(WebSocketLayer {
141                handshake_request: http_request,
142                handshake_response: HttpResponse {
143                    status: 0,
144                    status_text: "".to_string(),
145                    version: "".to_string(),
146                    headers: vec![],
147                    cookies: vec![],
148                    body: None,
149                    timing: relay_core_api::flow::ResponseTiming {
150                        time_to_first_byte: None,
151                        time_to_last_byte: None,
152                        connect_time_ms: None,
153                        ssl_time_ms: None,
154                    },
155                },
156                messages: vec![],
157                closed: false,
158            }),
159            tags: vec!["websocket".to_string()],
160            meta: std::collections::HashMap::new(),
161        }
162    } else {
163        Flow {
164            id: flow_id,
165            start_time,
166            end_time: None,
167            network: network_info,
168            layer: Layer::Http(HttpLayer {
169                request: http_request,
170                response: None,
171                error: None,
172            }),
173            tags: vec!["proxy".to_string()],
174            meta: std::collections::HashMap::new(),
175        }
176    };
177
178    if is_mitm {
179        flow.tags.push("mitm".to_string());
180    }
181
182    flow
183}
184
185pub fn create_error_response(status: StatusCode, message: impl Into<Bytes>) -> Response<HttpBody> {
186    Response::builder()
187        .status(status)
188        .body(Full::new(message.into()).map_err(|e| e.into()).boxed())
189        .unwrap_or_else(|_| {
190            Response::new(
191                Full::new(Bytes::from("Internal Error"))
192                    .map_err(|e| e.into())
193                    .boxed(),
194            )
195        })
196}
197
198pub fn mock_to_response(mock: HttpResponse) -> Response<HttpBody> {
199    let mut builder =
200        Response::builder().status(StatusCode::from_u16(mock.status).unwrap_or(StatusCode::OK));
201
202    for (k, v) in mock.headers {
203        if let (Ok(name), Ok(val)) = (
204            HeaderName::from_bytes(k.as_bytes()),
205            HeaderValue::from_str(&v),
206        ) {
207            builder = builder.header(name, val);
208        }
209    }
210
211    let body = if let Some(b) = mock.body {
212        Bytes::from(b.content)
213    } else {
214        Bytes::new()
215    };
216
217    builder
218        .body(Full::new(body).map_err(|e| e.into()).boxed())
219        .unwrap_or_else(|_| {
220            create_error_response(
221                StatusCode::INTERNAL_SERVER_ERROR,
222                "Failed to build mock response",
223            )
224        })
225}
226
227#[allow(clippy::result_large_err)]
228pub fn build_forward_request(
229    flow: &mut Flow,
230    body: HttpBody,
231    target_addr: Option<SocketAddr>,
232    policy: &ProxyPolicy,
233    loop_detector: &LoopDetector,
234) -> Result<Request<HttpBody>, Response<HttpBody>> {
235    let current_req = if let Layer::Http(http) = &flow.layer {
236        &http.request
237    } else {
238        return Err(create_error_response(
239            StatusCode::INTERNAL_SERVER_ERROR,
240            "Invalid Flow Layer State",
241        ));
242    };
243
244    let mut forward_req_builder = Request::builder().method(current_req.method.as_str());
245
246    // Determine upstream URI
247    let mut target_url = current_req.url.clone();
248
249    // Transparent Proxy Routing Logic
250    if policy.transparent_enabled
251        && let Some(addr) = target_addr
252    {
253        flow.tags.push("transparent".to_string());
254
255        // Update Flow Network Info
256        flow.network.server_ip = addr.ip().to_string();
257        flow.network.server_port = addr.port();
258
259        // Loop Detection
260        if loop_detector.would_loop(addr) {
261            if let Layer::Http(http) = &mut flow.layer {
262                http.error = Some("Loop Detected".to_string());
263            }
264            return Err(create_error_response(
265                StatusCode::LOOP_DETECTED,
266                "Loop Detected",
267            ));
268        }
269
270        // Rewrite URI to use target IP
271        if target_url.set_ip_host(addr.ip()).is_ok() {
272            target_url.set_port(Some(addr.port())).ok();
273        }
274
275        // Update scheme if MITM
276        if flow.network.tls && target_url.scheme() == "http" {
277            target_url.set_scheme("https").ok();
278        }
279    }
280
281    forward_req_builder = forward_req_builder.uri(target_url.as_str());
282
283    for (k, v) in &current_req.headers {
284        // Filter out hop-by-hop headers to allow connection pooling
285        if is_hop_by_hop(k) {
286            continue;
287        }
288
289        if let (Ok(name), Ok(val)) = (
290            HeaderName::from_bytes(k.as_bytes()),
291            HeaderValue::from_str(v),
292        ) {
293            forward_req_builder = forward_req_builder.header(name, val);
294        }
295    }
296
297    match forward_req_builder.body(body) {
298        Ok(req) => Ok(req),
299        Err(e) => Err(create_error_response(
300            StatusCode::INTERNAL_SERVER_ERROR,
301            format!("Failed to build forward request: {}", e),
302        )),
303    }
304}
305
306pub fn update_flow_with_response_headers(
307    flow: &mut Flow,
308    status: StatusCode,
309    version: hyper::Version,
310    headers: &hyper::HeaderMap,
311) {
312    let mut response_cookies = Vec::new();
313    for (k, v) in headers.iter() {
314        if k == hyper::header::SET_COOKIE
315            && let Ok(v_str) = v.to_str()
316            && let Ok(c) = CookieCrate::parse(v_str)
317        {
318            response_cookies.push(Cookie {
319                name: c.name().to_string(),
320                value: c.value().to_string(),
321                path: c.path().map(|s| s.to_string()),
322                domain: c.domain().map(|s| s.to_string()),
323                expires: c.expires().map(|e| format!("{:?}", e)),
324                http_only: c.http_only(),
325                secure: c.secure(),
326            });
327        }
328    }
329
330    let resp_headers_vec: Vec<(String, String)> = headers
331        .iter()
332        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
333        .collect();
334
335    let http_response = HttpResponse {
336        status: status.as_u16(),
337        status_text: status.to_string(),
338        version: format!("{:?}", version),
339        headers: resp_headers_vec,
340        cookies: response_cookies,
341        body: None,
342        timing: relay_core_api::flow::ResponseTiming {
343            time_to_first_byte: None,
344            time_to_last_byte: None,
345            connect_time_ms: None,
346            ssl_time_ms: None,
347        },
348    };
349
350    match &mut flow.layer {
351        Layer::Http(http) => {
352            http.response = Some(http_response);
353        }
354        Layer::WebSocket(ws) => {
355            ws.handshake_response = http_response;
356        }
357        _ => {}
358    }
359}
360
361pub fn update_flow_with_response_body(flow: &mut Flow, body_bytes: Bytes) {
362    let headers = match &flow.layer {
363        Layer::Http(http) => http
364            .response
365            .as_ref()
366            .map(|r| r.headers.clone())
367            .unwrap_or_default(),
368        Layer::WebSocket(ws) => ws.handshake_response.headers.clone(),
369        _ => Vec::new(),
370    };
371
372    let (resp_encoding, resp_content) = process_body(&body_bytes, &headers);
373
374    let body_data = BodyData {
375        encoding: resp_encoding,
376        content: resp_content,
377        size: body_bytes.len() as u64,
378    };
379
380    match &mut flow.layer {
381        Layer::Http(http) => {
382            if let Some(resp) = &mut http.response {
383                resp.body = Some(body_data);
384            }
385        }
386        Layer::WebSocket(ws) => {
387            ws.handshake_response.body = Some(body_data);
388        }
389        _ => {}
390    }
391}
392
/// Records a complete response in `flow`: first the status line and headers
/// via [`update_flow_with_response_headers`], then the body via
/// [`update_flow_with_response_body`].
pub fn update_flow_with_response(
    flow: &mut Flow,
    status: StatusCode,
    version: hyper::Version,
    headers: &hyper::HeaderMap,
    body_bytes: Bytes,
) {
    // Headers go first: the body step reads the headers stored on the flow.
    update_flow_with_response_headers(flow, status, version, headers);
    update_flow_with_response_body(flow, body_bytes);
}
403
404pub fn build_client_response_from_flow(
405    flow: &Flow,
406    default_version: hyper::Version,
407    strict_mode: bool,
408) -> Result<Response<Full<Bytes>>, String> {
409    if let Layer::Http(http) = &flow.layer {
410        if let Some(response) = &http.response {
411            let status = match StatusCode::from_u16(response.status) {
412                Ok(s) => s,
413                Err(_) => {
414                    if strict_mode {
415                        return Err(format!("Invalid status code: {}", response.status));
416                    }
417                    StatusCode::OK
418                }
419            };
420
421            let mut builder = Response::builder().status(status).version(default_version); // TODO: Parse version from flow string if needed
422
423            for (k, v) in &response.headers {
424                // Filter out transport-level headers that might conflict with the new body
425                if k.eq_ignore_ascii_case("content-length")
426                    || k.eq_ignore_ascii_case("transfer-encoding")
427                    || k.eq_ignore_ascii_case("connection")
428                {
429                    continue;
430                }
431
432                if let (Ok(name), Ok(val)) = (
433                    HeaderName::from_bytes(k.as_bytes()),
434                    HeaderValue::from_str(v),
435                ) {
436                    builder = builder.header(name, val);
437                } else if strict_mode {
438                    return Err(format!("Invalid header: {}: {}", k, v));
439                }
440            }
441
442            let body_bytes = if let Some(b) = &response.body {
443                if b.encoding == "base64" {
444                    match BASE64.decode(b.content.as_bytes()) {
445                        Ok(bytes) => Bytes::from(bytes),
446                        Err(_e) => {
447                            // Fallback
448                            Bytes::from(b.content.clone())
449                        }
450                    }
451                } else {
452                    Bytes::from(b.content.clone())
453                }
454            } else {
455                Bytes::new()
456            };
457
458            builder
459                .body(Full::new(body_bytes))
460                .map_err(|e| format!("Failed to build response: {}", e))
461        } else {
462            Err("No response in flow".to_string())
463        }
464    } else {
465        Err("Not HTTP layer".to_string())
466    }
467}
468
469#[cfg(test)]
470mod tests {
471    use super::{build_client_response_from_flow, parse_request_meta};
472    use chrono::Utc;
473    use http_body_util::BodyExt;
474    use hyper::{Request, StatusCode, Version};
475    use relay_core_api::flow::{
476        Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
477        TransportProtocol,
478    };
479    use std::collections::HashMap;
480    use url::Url;
481    use uuid::Uuid;
482
483    fn sample_flow_with_response(status: u16) -> Flow {
484        Flow {
485            id: Uuid::new_v4(),
486            start_time: Utc::now(),
487            end_time: None,
488            network: NetworkInfo {
489                client_ip: "127.0.0.1".to_string(),
490                client_port: 12345,
491                server_ip: "1.1.1.1".to_string(),
492                server_port: 80,
493                protocol: TransportProtocol::TCP,
494                tls: false,
495                tls_version: None,
496                sni: None,
497            },
498            layer: Layer::Http(HttpLayer {
499                request: HttpRequest {
500                    method: "GET".to_string(),
501                    url: Url::parse("http://example.com/a").expect("url"),
502                    version: "HTTP/1.1".to_string(),
503                    headers: vec![],
504                    cookies: vec![],
505                    query: vec![],
506                    body: None,
507                },
508                response: Some(HttpResponse {
509                    status,
510                    status_text: "X".to_string(),
511                    version: "HTTP/2.0".to_string(),
512                    headers: vec![
513                        ("X-Test".to_string(), "1".to_string()),
514                        ("content-length".to_string(), "999".to_string()),
515                        ("connection".to_string(), "keep-alive".to_string()),
516                    ],
517                    cookies: vec![],
518                    body: None,
519                    timing: ResponseTiming {
520                        time_to_first_byte: None,
521                        time_to_last_byte: None,
522                        connect_time_ms: None,
523                        ssl_time_ms: None,
524                    },
525                }),
526                error: None,
527            }),
528            tags: vec![],
529            meta: HashMap::new(),
530        }
531    }
532
533    #[test]
534    fn test_parse_request_meta_relative_uri_uses_host_http() {
535        let req = Request::builder()
536            .uri("/api/v1?q=1")
537            .header("Host", "example.com:8080")
538            .body(())
539            .expect("request");
540        let meta = parse_request_meta(&req, false);
541        assert_eq!(meta.url_str, "http://example.com:8080/api/v1?q=1");
542        assert_eq!(meta.query, vec![("q".to_string(), "1".to_string())]);
543    }
544
545    #[test]
546    fn test_parse_request_meta_relative_uri_uses_host_https_in_mitm() {
547        let req = Request::builder()
548            .uri("/secure")
549            .header("Host", "secure.example.com")
550            .body(())
551            .expect("request");
552        let meta = parse_request_meta(&req, true);
553        assert_eq!(meta.url_str, "https://secure.example.com/secure");
554    }
555
556    #[test]
557    fn test_build_client_response_from_flow_uses_default_version_currently() {
558        let flow = sample_flow_with_response(201);
559        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, true)
560            .expect("response should build");
561        assert_eq!(resp.version(), Version::HTTP_11);
562        assert_eq!(resp.status(), StatusCode::CREATED);
563        assert_eq!(
564            resp.headers().get("x-test").and_then(|v| v.to_str().ok()),
565            Some("1")
566        );
567        assert!(
568            resp.headers().get("content-length").is_none(),
569            "content-length should be stripped from forwarded mock response"
570        );
571        assert!(resp.headers().get("connection").is_none());
572    }
573
574    #[test]
575    fn test_build_client_response_from_flow_invalid_status_strict_fails() {
576        let flow = sample_flow_with_response(1000);
577        let err = build_client_response_from_flow(&flow, Version::HTTP_11, true)
578            .expect_err("strict mode should reject invalid status");
579        assert!(err.contains("Invalid status code"));
580    }
581
582    #[tokio::test]
583    async fn test_build_client_response_from_flow_invalid_status_non_strict_fallback_ok() {
584        let mut flow = sample_flow_with_response(1000);
585        if let Layer::Http(http) = &mut flow.layer {
586            if let Some(res) = &mut http.response {
587                res.body = Some(relay_core_api::flow::BodyData {
588                    encoding: "utf-8".to_string(),
589                    content: "hello".to_string(),
590                    size: 5,
591                });
592            }
593        }
594
595        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, false)
596            .expect("non-strict should fallback");
597        assert_eq!(resp.status(), StatusCode::OK);
598        let body = resp
599            .into_body()
600            .collect()
601            .await
602            .expect("collect body")
603            .to_bytes();
604        assert_eq!(body.as_ref(), b"hello");
605    }
606}