Skip to main content

relay_core_lib/proxy/
http_utils.rs

1use crate::capture::loop_detection::LoopDetector;
2use crate::interceptor::HttpBody;
3use crate::proxy::body_codec::process_body;
4use chrono::Utc;
5use cookie::Cookie as CookieCrate;
6use data_encoding::BASE64;
7use http_body_util::{BodyExt, Full};
8use hyper::body::Bytes;
9use hyper::header::{HeaderName, HeaderValue};
10use hyper::{Request, Response, StatusCode};
11use hyper_rustls::HttpsConnector;
12use hyper_util::client::legacy::Client;
13use hyper_util::client::legacy::connect::HttpConnector;
14use relay_core_api::flow::{
15    BodyData, Cookie, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo,
16    TransportProtocol, WebSocketLayer,
17};
18use relay_core_api::policy::ProxyPolicy;
19use std::net::SocketAddr;
20use url::Url;
21use uuid::Uuid;
22
/// HTTP(S) client type: hyper-util's legacy [`Client`] over a rustls [`HttpsConnector`]
/// wrapping a plain [`HttpConnector`], sending [`HttpBody`] request bodies.
pub type HttpsClient = Client<HttpsConnector<HttpConnector>, HttpBody>;
24
/// Request metadata extracted by [`parse_request_meta`] and consumed by
/// [`create_initial_flow`].
#[derive(Clone, Debug)]
pub struct RequestMeta {
    /// HTTP method as text (e.g. `GET`).
    pub method: String,
    /// Request URL. For a relative request target it is rebuilt from the `Host` header when
    /// the result parses; otherwise the target is kept verbatim.
    pub url_str: String,
    /// HTTP version rendered with `hyper::Version`'s `Debug` formatting.
    pub version: String,
    /// All request headers in header-map iteration order; values decoded lossily as UTF-8.
    pub headers: Vec<(String, String)>,
    /// Decoded query-string pairs; empty when `url_str` is not a parsable absolute URL.
    pub query: Vec<(String, String)>,
    /// Cookies parsed from the request's `Cookie` header (name and value only; the
    /// attribute fields are always `None`).
    pub cookies: Vec<Cookie>,
}
34
35pub fn parse_request_meta<B>(req: &Request<B>, is_mitm: bool) -> RequestMeta {
36    let method = req.method().to_string();
37    let mut url_str = req.uri().to_string();
38
39    // Attempt to construct absolute URL if relative
40    if Url::parse(&url_str).is_err()
41        && let Some(host) = req.headers().get("Host").and_then(|v| v.to_str().ok())
42    {
43        let scheme = if is_mitm { "https" } else { "http" };
44        let new_url = format!("{}://{}{}", scheme, host, url_str);
45        if Url::parse(&new_url).is_ok() {
46            url_str = new_url;
47        }
48    }
49
50    let version = format!("{:?}", req.version());
51
52    let headers: Vec<(String, String)> = req
53        .headers()
54        .iter()
55        .map(|(k, v)| {
56            (
57                k.to_string(),
58                String::from_utf8_lossy(v.as_bytes()).to_string(),
59            )
60        })
61        .collect();
62
63    let query: Vec<(String, String)> = if let Ok(parsed_url) = Url::parse(&url_str) {
64        parsed_url.query_pairs().into_owned().collect()
65    } else {
66        vec![]
67    };
68
69    let mut cookies = Vec::new();
70    if let Some(cookie_header) = req.headers().get(hyper::header::COOKIE)
71        && let Ok(cookie_str) = cookie_header.to_str()
72    {
73        for c in CookieCrate::split_parse(cookie_str).flatten() {
74            cookies.push(Cookie {
75                name: c.name().to_string(),
76                value: c.value().to_string(),
77                path: None,
78                domain: None,
79                expires: None,
80                http_only: None,
81                secure: None,
82            });
83        }
84    }
85
86    RequestMeta {
87        method,
88        url_str,
89        version,
90        headers,
91        query,
92        cookies,
93    }
94}
95
/// Returns `true` if `name` is a hop-by-hop header that a proxy must not forward.
///
/// The comparison is ASCII case-insensitive and covers the fixed hop-by-hop set plus
/// `Proxy-Connection`. `Proxy-Connection` is a non-standard header that many clients still
/// send, and HTTP/2 forbids it as connection-specific (RFC 9113 §8.2.2).
///
/// The trailer header is spelled `Trailer` (RFC 9110 §6.6.2). RFC 2616's hop-by-hop list
/// misspelled it as `Trailers`, so both spellings are accepted.
///
/// Headers nominated by a `Connection` header value are not covered here; callers must
/// handle those separately.
pub fn is_hop_by_hop(name: &str) -> bool {
    const HOP_BY_HOP: [&str; 10] = [
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    ];
    HOP_BY_HOP.iter().any(|h| name.eq_ignore_ascii_case(h))
}
106
107pub fn create_initial_flow(
108    meta: RequestMeta,
109    req_body: Option<BodyData>,
110    client_addr: SocketAddr,
111    is_mitm: bool,
112    is_websocket: bool,
113) -> Flow {
114    let flow_id = Uuid::new_v4();
115    let start_time = Utc::now();
116
117    let network_info = NetworkInfo {
118        client_ip: client_addr.ip().to_string(),
119        client_port: client_addr.port(),
120        server_ip: "0.0.0.0".to_string(), // Placeholder
121        server_port: 0,                   // Placeholder
122        protocol: TransportProtocol::TCP,
123        tls: is_mitm,
124        tls_version: None,
125        sni: None,
126    };
127
128    let http_request = HttpRequest {
129        method: meta.method,
130        url: Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown").unwrap()),
131        version: meta.version,
132        headers: meta.headers,
133        cookies: meta.cookies,
134        query: meta.query,
135        body: req_body,
136    };
137
138    let mut flow = if is_websocket {
139        Flow {
140            id: flow_id,
141            start_time,
142            end_time: None,
143            network: network_info,
144            layer: Layer::WebSocket(WebSocketLayer {
145                handshake_request: http_request,
146                handshake_response: HttpResponse {
147                    status: 0,
148                    status_text: "".to_string(),
149                    version: "".to_string(),
150                    headers: vec![],
151                    cookies: vec![],
152                    body: None,
153                    timing: relay_core_api::flow::ResponseTiming {
154                        time_to_first_byte: None,
155                        time_to_last_byte: None,
156                        connect_time_ms: None,
157                        ssl_time_ms: None,
158                    },
159                },
160                messages: vec![],
161                closed: false,
162            }),
163            tags: vec!["websocket".to_string()],
164            meta: std::collections::HashMap::new(),
165            resilience_trace: None,
166            rule_variables: std::collections::HashMap::new(),
167            matched_rules: vec![],
168        }
169    } else {
170        Flow {
171            id: flow_id,
172            start_time,
173            end_time: None,
174            network: network_info,
175            layer: Layer::Http(HttpLayer {
176                request: http_request,
177                response: None,
178                error: None,
179            }),
180            tags: vec!["proxy".to_string()],
181            meta: std::collections::HashMap::new(),
182            resilience_trace: None,
183            rule_variables: std::collections::HashMap::new(),
184            matched_rules: vec![],
185        }
186    };
187
188    if is_mitm {
189        flow.tags.push("mitm".to_string());
190    }
191
192    flow
193}
194
195pub fn create_error_response(status: StatusCode, message: impl Into<Bytes>) -> Response<HttpBody> {
196    Response::builder()
197        .status(status)
198        .body(Full::new(message.into()).map_err(|e| e.into()).boxed())
199        .unwrap_or_else(|_| {
200            Response::new(
201                Full::new(Bytes::from("Internal Error"))
202                    .map_err(|e| e.into())
203                    .boxed(),
204            )
205        })
206}
207
208pub fn mock_to_response(mock: HttpResponse) -> Response<HttpBody> {
209    let mut builder =
210        Response::builder().status(StatusCode::from_u16(mock.status).unwrap_or(StatusCode::OK));
211
212    for (k, v) in mock.headers {
213        if let (Ok(name), Ok(val)) = (
214            HeaderName::from_bytes(k.as_bytes()),
215            HeaderValue::from_str(&v),
216        ) {
217            builder = builder.header(name, val);
218        }
219    }
220
221    let body = if let Some(b) = mock.body {
222        Bytes::from(b.content)
223    } else {
224        Bytes::new()
225    };
226
227    builder
228        .body(Full::new(body).map_err(|e| e.into()).boxed())
229        .unwrap_or_else(|_| {
230            create_error_response(
231                StatusCode::INTERNAL_SERVER_ERROR,
232                "Failed to build mock response",
233            )
234        })
235}
236
237#[allow(clippy::result_large_err)]
238pub fn build_forward_request(
239    flow: &mut Flow,
240    body: HttpBody,
241    target_addr: Option<SocketAddr>,
242    policy: &ProxyPolicy,
243    loop_detector: &LoopDetector,
244) -> Result<Request<HttpBody>, Response<HttpBody>> {
245    let current_req = if let Layer::Http(http) = &flow.layer {
246        &http.request
247    } else {
248        return Err(create_error_response(
249            StatusCode::INTERNAL_SERVER_ERROR,
250            "Invalid Flow Layer State",
251        ));
252    };
253
254    let mut forward_req_builder = Request::builder().method(current_req.method.as_str());
255
256    // Determine upstream URI
257    let mut target_url = current_req.url.clone();
258
259    // Transparent Proxy Routing Logic
260    if policy.transparent_enabled
261        && let Some(addr) = target_addr
262    {
263        flow.tags.push("transparent".to_string());
264
265        // Update Flow Network Info
266        flow.network.server_ip = addr.ip().to_string();
267        flow.network.server_port = addr.port();
268
269        // Loop Detection
270        if loop_detector.would_loop(addr) {
271            if let Layer::Http(http) = &mut flow.layer {
272                http.error = Some("Loop Detected".to_string());
273            }
274            return Err(create_error_response(
275                StatusCode::LOOP_DETECTED,
276                "Loop Detected",
277            ));
278        }
279
280        // Rewrite URI to use target IP
281        if target_url.set_ip_host(addr.ip()).is_ok() {
282            target_url.set_port(Some(addr.port())).ok();
283        }
284
285        // Update scheme if MITM
286        if flow.network.tls && target_url.scheme() == "http" {
287            target_url.set_scheme("https").ok();
288        }
289    }
290
291    forward_req_builder = forward_req_builder.uri(target_url.as_str());
292
293    for (k, v) in &current_req.headers {
294        // Filter out hop-by-hop headers to allow connection pooling
295        if is_hop_by_hop(k) {
296            continue;
297        }
298
299        if let (Ok(name), Ok(val)) = (
300            HeaderName::from_bytes(k.as_bytes()),
301            HeaderValue::from_str(v),
302        ) {
303            forward_req_builder = forward_req_builder.header(name, val);
304        }
305    }
306
307    match forward_req_builder.body(body) {
308        Ok(req) => Ok(req),
309        Err(e) => Err(create_error_response(
310            StatusCode::INTERNAL_SERVER_ERROR,
311            format!("Failed to build forward request: {}", e),
312        )),
313    }
314}
315
316pub fn update_flow_with_response_headers(
317    flow: &mut Flow,
318    status: StatusCode,
319    version: hyper::Version,
320    headers: &hyper::HeaderMap,
321) {
322    let mut response_cookies = Vec::new();
323    for (k, v) in headers.iter() {
324        if k == hyper::header::SET_COOKIE
325            && let Ok(v_str) = v.to_str()
326            && let Ok(c) = CookieCrate::parse(v_str)
327        {
328            response_cookies.push(Cookie {
329                name: c.name().to_string(),
330                value: c.value().to_string(),
331                path: c.path().map(|s| s.to_string()),
332                domain: c.domain().map(|s| s.to_string()),
333                expires: c.expires().map(|e| format!("{:?}", e)),
334                http_only: c.http_only(),
335                secure: c.secure(),
336            });
337        }
338    }
339
340    let resp_headers_vec: Vec<(String, String)> = headers
341        .iter()
342        .map(|(k, v)| {
343            (
344                k.to_string(),
345                String::from_utf8_lossy(v.as_bytes()).to_string(),
346            )
347        })
348        .collect();
349
350    let http_response = HttpResponse {
351        status: status.as_u16(),
352        status_text: status.to_string(),
353        version: format!("{:?}", version),
354        headers: resp_headers_vec,
355        cookies: response_cookies,
356        body: None,
357        timing: relay_core_api::flow::ResponseTiming {
358            time_to_first_byte: None,
359            time_to_last_byte: None,
360            connect_time_ms: None,
361            ssl_time_ms: None,
362        },
363    };
364
365    match &mut flow.layer {
366        Layer::Http(http) => {
367            http.response = Some(http_response);
368        }
369        Layer::WebSocket(ws) => {
370            ws.handshake_response = http_response;
371        }
372        _ => {}
373    }
374}
375
376pub fn update_flow_with_response_body(flow: &mut Flow, body_bytes: Bytes) {
377    let headers = match &flow.layer {
378        Layer::Http(http) => http
379            .response
380            .as_ref()
381            .map(|r| r.headers.clone())
382            .unwrap_or_default(),
383        Layer::WebSocket(ws) => ws.handshake_response.headers.clone(),
384        _ => Vec::new(),
385    };
386
387    let (resp_encoding, resp_content) = process_body(&body_bytes, &headers);
388
389    let body_data = BodyData {
390        encoding: resp_encoding,
391        content: resp_content,
392        size: body_bytes.len() as u64,
393    };
394
395    match &mut flow.layer {
396        Layer::Http(http) => {
397            if let Some(resp) = &mut http.response {
398                resp.body = Some(body_data);
399            }
400        }
401        Layer::WebSocket(ws) => {
402            ws.handshake_response.body = Some(body_data);
403        }
404        _ => {}
405    }
406}
407
/// Records a complete upstream response on `flow`.
///
/// First stores the status line and headers via [`update_flow_with_response_headers`], then
/// the body via [`update_flow_with_response_body`]. The order matters for two reasons:
/// * the body step reads the just-recorded headers back from the flow and passes them to
///   `process_body`;
/// * for HTTP flows, the body step only attaches a body when a response already exists.
pub fn update_flow_with_response(
    flow: &mut Flow,
    status: StatusCode,
    version: hyper::Version,
    headers: &hyper::HeaderMap,
    body_bytes: Bytes,
) {
    update_flow_with_response_headers(flow, status, version, headers);
    update_flow_with_response_body(flow, body_bytes);
}
418
419pub fn build_client_response_from_flow(
420    flow: &Flow,
421    default_version: hyper::Version,
422    strict_mode: bool,
423) -> Result<Response<Full<Bytes>>, String> {
424    if let Layer::Http(http) = &flow.layer {
425        if let Some(response) = &http.response {
426            let status = match StatusCode::from_u16(response.status) {
427                Ok(s) => s,
428                Err(_) => {
429                    if strict_mode {
430                        crate::metrics::inc_proxy_invalid_status();
431                        return Err(format!("Invalid status code: {}", response.status));
432                    }
433                    StatusCode::OK
434                }
435            };
436
437            let mut builder = Response::builder().status(status).version(default_version); // TODO: Parse version from flow string if needed
438
439            for (k, v) in &response.headers {
440                // Filter out transport-level headers that might conflict with the new body
441                if k.eq_ignore_ascii_case("content-length")
442                    || k.eq_ignore_ascii_case("transfer-encoding")
443                    || k.eq_ignore_ascii_case("connection")
444                {
445                    continue;
446                }
447
448                if let (Ok(name), Ok(val)) = (
449                    HeaderName::from_bytes(k.as_bytes()),
450                    HeaderValue::from_str(v),
451                ) {
452                    builder = builder.header(name, val);
453                } else if strict_mode {
454                    return Err(format!("Invalid header: {}: {}", k, v));
455                }
456            }
457
458            let body_bytes = if let Some(b) = &response.body {
459                if b.encoding == "base64" {
460                    match BASE64.decode(b.content.as_bytes()) {
461                        Ok(bytes) => Bytes::from(bytes),
462                        Err(_e) => {
463                            // Fallback
464                            Bytes::from(b.content.clone())
465                        }
466                    }
467                } else {
468                    Bytes::from(b.content.clone())
469                }
470            } else {
471                Bytes::new()
472            };
473
474            builder
475                .body(Full::new(body_bytes))
476                .map_err(|e| format!("Failed to build response: {}", e))
477        } else {
478            Err("No response in flow".to_string())
479        }
480    } else {
481        Err("Not HTTP layer".to_string())
482    }
483}
484
/// Unit tests for request-metadata parsing and client-response reconstruction.
#[cfg(test)]
mod tests {
    use super::{build_client_response_from_flow, parse_request_meta};
    use chrono::Utc;
    use http_body_util::BodyExt;
    use hyper::{Request, StatusCode, Version};
    use relay_core_api::flow::{
        Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
        TransportProtocol,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Builds an HTTP flow for `GET http://example.com/a` whose recorded response has:
    /// * the given `status` and version `HTTP/2.0`;
    /// * no body;
    /// * an `X-Test: 1` header, which should be kept;
    /// * `content-length` and `connection` headers, which should be stripped when the client
    ///   response is rebuilt.
    fn sample_flow_with_response(status: u16) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com/a").expect("url"),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: Some(HttpResponse {
                    status,
                    status_text: "X".to_string(),
                    version: "HTTP/2.0".to_string(),
                    headers: vec![
                        ("X-Test".to_string(), "1".to_string()),
                        ("content-length".to_string(), "999".to_string()),
                        ("connection".to_string(), "keep-alive".to_string()),
                    ],
                    cookies: vec![],
                    body: None,
                    timing: ResponseTiming {
                        time_to_first_byte: None,
                        time_to_last_byte: None,
                        connect_time_ms: None,
                        ssl_time_ms: None,
                    },
                }),
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
            resilience_trace: None,
            rule_variables: HashMap::new(),
            matched_rules: vec![],
        }
    }

    /// A relative target in a non-MITM request is absolutized with `http` and the Host
    /// header (including its port); query pairs come from the rebuilt URL.
    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_http() {
        let req = Request::builder()
            .uri("/api/v1?q=1")
            .header("Host", "example.com:8080")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, false);
        assert_eq!(meta.url_str, "http://example.com:8080/api/v1?q=1");
        assert_eq!(meta.query, vec![("q".to_string(), "1".to_string())]);
    }

    /// With `is_mitm` set, the rebuilt URL uses the `https` scheme.
    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_https_in_mitm() {
        let req = Request::builder()
            .uri("/secure")
            .header("Host", "secure.example.com")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, true);
        assert_eq!(meta.url_str, "https://secure.example.com/secure");
    }

    /// The caller's version wins over the recorded `HTTP/2.0`. Status and ordinary headers
    /// pass through, while `content-length` and `connection` are stripped.
    #[test]
    fn test_build_client_response_from_flow_uses_default_version_currently() {
        let flow = sample_flow_with_response(201);
        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect("response should build");
        assert_eq!(resp.version(), Version::HTTP_11);
        assert_eq!(resp.status(), StatusCode::CREATED);
        assert_eq!(
            resp.headers().get("x-test").and_then(|v| v.to_str().ok()),
            Some("1")
        );
        assert!(
            resp.headers().get("content-length").is_none(),
            "content-length should be stripped from forwarded mock response"
        );
        assert!(resp.headers().get("connection").is_none());
    }

    /// Status 1000 is outside the valid range, so strict mode must reject it.
    #[test]
    fn test_build_client_response_from_flow_invalid_status_strict_fails() {
        let flow = sample_flow_with_response(1000);
        let err = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect_err("strict mode should reject invalid status");
        assert!(err.contains("Invalid status code"));
    }

    /// Non-strict mode falls back to `200 OK` for an invalid status and still delivers the
    /// recorded (non-base64) body verbatim.
    #[tokio::test]
    async fn test_build_client_response_from_flow_invalid_status_non_strict_fallback_ok() {
        let mut flow = sample_flow_with_response(1000);
        if let Layer::Http(http) = &mut flow.layer {
            if let Some(res) = &mut http.response {
                res.body = Some(relay_core_api::flow::BodyData {
                    encoding: "utf-8".to_string(),
                    content: "hello".to_string(),
                    size: 5,
                });
            }
        }

        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, false)
            .expect("non-strict should fallback");
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp
            .into_body()
            .collect()
            .await
            .expect("collect body")
            .to_bytes();
        assert_eq!(body.as_ref(), b"hello");
    }
}