Skip to main content

relay_core_lib/proxy/
http_utils.rs

1use std::net::SocketAddr;
2use hyper::{Response, Request, StatusCode};
3use hyper::body::Bytes;
4use hyper::header::{HeaderName, HeaderValue};
5use http_body_util::{Full, BodyExt};
6use hyper_util::client::legacy::Client;
7use hyper_util::client::legacy::connect::HttpConnector;
8use hyper_rustls::HttpsConnector;
9use relay_core_api::flow::{
10    BodyData, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, TransportProtocol,
11    WebSocketLayer, Cookie,
12};
13use relay_core_api::policy::ProxyPolicy;
14use crate::capture::loop_detection::LoopDetector;
15use uuid::Uuid;
16use chrono::Utc;
17use url::Url;
18use cookie::Cookie as CookieCrate;
19use data_encoding::BASE64;
20use crate::proxy::body_codec::process_body;
21use crate::interceptor::HttpBody;
22
/// Upstream client type: a hyper "legacy" client whose connector is
/// `hyper_rustls::HttpsConnector` wrapped around `HttpConnector`, sending [`HttpBody`] bodies.
pub type HttpsClient = Client<HttpsConnector<HttpConnector>, HttpBody>;

/// Request metadata extracted by [`parse_request_meta`] before a [`Flow`] is created.
///
/// All values are owned, so the struct is independent of the hyper request it came from.
#[derive(Clone, Debug)]
pub struct RequestMeta {
    /// Request method as text (e.g. `GET`).
    pub method: String,
    /// Request URL: absolute when it could be rebuilt from the `Host` header, otherwise the
    /// request-target exactly as received.
    pub url_str: String,
    /// `Debug` rendering of the hyper `Version`.
    pub version: String,
    /// Header `(name, value)` pairs in received order; repeated names are kept.
    pub headers: Vec<(String, String)>,
    /// Decoded query-string pairs of `url_str` (empty when it is not an absolute URL).
    pub query: Vec<(String, String)>,
    /// Cookies parsed from the request's `Cookie` header(s); only name and value are set.
    pub cookies: Vec<Cookie>,
}
34
35pub fn parse_request_meta<B>(req: &Request<B>, is_mitm: bool) -> RequestMeta {
36    let method = req.method().to_string();
37    let mut url_str = req.uri().to_string();
38    
39    // Attempt to construct absolute URL if relative
40    if Url::parse(&url_str).is_err()
41        && let Some(host) = req.headers().get("Host").and_then(|v| v.to_str().ok()) {
42             let scheme = if is_mitm { "https" } else { "http" };
43             let new_url = format!("{}://{}{}", scheme, host, url_str);
44             if Url::parse(&new_url).is_ok() {
45                 url_str = new_url;
46             }
47        }
48
49    let version = format!("{:?}", req.version());
50    
51    let headers: Vec<(String, String)> = req.headers()
52        .iter()
53        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
54        .collect();
55
56    let query: Vec<(String, String)> = if let Ok(parsed_url) = Url::parse(&url_str) {
57        parsed_url.query_pairs().into_owned().collect()
58    } else {
59        vec![]
60    };
61
62    let mut cookies = Vec::new();
63    if let Some(cookie_header) = req.headers().get(hyper::header::COOKIE)
64        && let Ok(cookie_str) = cookie_header.to_str() {
65             for c in CookieCrate::split_parse(cookie_str).flatten() {
66                 cookies.push(Cookie {
67                     name: c.name().to_string(),
68                     value: c.value().to_string(),
69                     path: None,
70                     domain: None,
71                     expires: None,
72                     http_only: None,
73                     secure: None,
74                 });
75             }
76        }
77
78    RequestMeta {
79        method,
80        url_str,
81        version,
82        headers,
83        query,
84        cookies,
85    }
86}
87
/// Header names that are never forwarded verbatim to the next hop (compared ASCII
/// case-insensitively).
///
/// This is the RFC 2616 §13.5.1 list, plus three additions:
/// * `trailer` – the real header name. The RFC's "Trailers" is a known erratum; that
///   spelling is kept for backward compatibility.
/// * `proxy-connection` – non-standard, but commonly sent by clients to proxies.
/// * `content-length` – not hop-by-hop per the RFC. It is stripped here, presumably so
///   that framing is derived from the body actually sent (NOTE(review): confirm intent).
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "proxy-connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "content-length",
];

/// Returns `true` if a header called `name` must not be copied onto a forwarded message.
///
/// This only covers the fixed list above. Headers named dynamically by a `Connection`
/// header must be filtered by the caller.
pub fn is_hop_by_hop(name: &str) -> bool {
    HOP_BY_HOP_HEADERS
        .iter()
        .any(|hop| name.eq_ignore_ascii_case(hop))
}
99
100pub fn create_initial_flow(
101    meta: RequestMeta,
102    req_body: Option<BodyData>,
103    client_addr: SocketAddr,
104    is_mitm: bool,
105    is_websocket: bool,
106) -> Flow {
107    let flow_id = Uuid::new_v4();
108    let start_time = Utc::now();
109    
110    let network_info = NetworkInfo {
111        client_ip: client_addr.ip().to_string(),
112        client_port: client_addr.port(),
113        server_ip: "0.0.0.0".to_string(), // Placeholder
114        server_port: 0,                   // Placeholder
115        protocol: TransportProtocol::TCP,
116        tls: is_mitm,
117        tls_version: None,
118        sni: None,
119    };
120
121    let http_request = HttpRequest {
122        method: meta.method,
123        url: Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown").unwrap()),
124        version: meta.version,
125        headers: meta.headers,
126        cookies: meta.cookies,
127        query: meta.query,
128        body: req_body, 
129    };
130
131    let mut flow = if is_websocket {
132        Flow {
133            id: flow_id,
134            start_time,
135            end_time: None,
136            network: network_info,
137            layer: Layer::WebSocket(WebSocketLayer {
138                handshake_request: http_request,
139                handshake_response: HttpResponse {
140                    status: 0,
141                    status_text: "".to_string(),
142                    version: "".to_string(),
143                    headers: vec![],
144                    cookies: vec![],
145                    body: None,
146                    timing: relay_core_api::flow::ResponseTiming { time_to_first_byte: None, time_to_last_byte: None,
147                connect_time_ms: None,
148                ssl_time_ms: None,
149            },
150                },
151                messages: vec![],
152                closed: false,
153            }),
154            tags: vec!["websocket".to_string()],
155            meta: std::collections::HashMap::new(),
156        }
157    } else {
158        Flow {
159            id: flow_id,
160            start_time,
161            end_time: None,
162            network: network_info,
163            layer: Layer::Http(HttpLayer {
164                request: http_request,
165                response: None,
166                error: None,
167            }),
168            tags: vec!["proxy".to_string()],
169            meta: std::collections::HashMap::new(),
170        }
171    };
172
173    if is_mitm {
174        flow.tags.push("mitm".to_string());
175    }
176
177    flow
178}
179
180pub fn create_error_response(status: StatusCode, message: impl Into<Bytes>) -> Response<HttpBody> {
181    Response::builder()
182        .status(status)
183        .body(Full::new(message.into()).map_err(|e| e.into()).boxed())
184        .unwrap_or_else(|_| Response::new(Full::new(Bytes::from("Internal Error")).map_err(|e| e.into()).boxed()))
185}
186
187pub fn mock_to_response(mock: HttpResponse) -> Response<HttpBody> {
188     let mut builder = Response::builder()
189        .status(StatusCode::from_u16(mock.status).unwrap_or(StatusCode::OK));
190        
191    for (k, v) in mock.headers {
192        if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(&v)) {
193            builder = builder.header(name, val);
194        }
195    }
196    
197    let body = if let Some(b) = mock.body {
198        Bytes::from(b.content)
199    } else {
200        Bytes::new()
201    };
202    
203    builder.body(Full::new(body).map_err(|e| e.into()).boxed()).unwrap_or_else(|_| create_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to build mock response"))
204}
205
206#[allow(clippy::result_large_err)]
207pub fn build_forward_request(
208    flow: &mut Flow,
209    body: HttpBody,
210    target_addr: Option<SocketAddr>,
211    policy: &ProxyPolicy,
212    loop_detector: &LoopDetector,
213) -> Result<Request<HttpBody>, Response<HttpBody>> {
214    let current_req = if let Layer::Http(http) = &flow.layer {
215        &http.request
216    } else {
217        return Err(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Invalid Flow Layer State"));
218    };
219
220    let mut forward_req_builder = Request::builder()
221        .method(current_req.method.as_str());
222
223    // Determine upstream URI
224    let mut target_url = current_req.url.clone();
225    
226    // Transparent Proxy Routing Logic
227    if policy.transparent_enabled
228        && let Some(addr) = target_addr {
229            flow.tags.push("transparent".to_string());
230            
231            // Update Flow Network Info
232            flow.network.server_ip = addr.ip().to_string();
233            flow.network.server_port = addr.port();
234
235            // Loop Detection
236            if loop_detector.would_loop(addr) {
237                if let Layer::Http(http) = &mut flow.layer {
238                    http.error = Some("Loop Detected".to_string());
239                }
240                return Err(create_error_response(StatusCode::LOOP_DETECTED, "Loop Detected"));
241            }
242
243            // Rewrite URI to use target IP
244            if target_url.set_ip_host(addr.ip()).is_ok() {
245                 target_url.set_port(Some(addr.port())).ok();
246            }
247
248            // Update scheme if MITM
249            if flow.network.tls && target_url.scheme() == "http" {
250                target_url.set_scheme("https").ok();
251            }
252        }
253
254    forward_req_builder = forward_req_builder.uri(target_url.as_str());
255
256    for (k, v) in &current_req.headers {
257        // Filter out hop-by-hop headers to allow connection pooling
258        if is_hop_by_hop(k) {
259            continue;
260        }
261
262        if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(v)) {
263            forward_req_builder = forward_req_builder.header(name, val);
264        }
265    }
266
267    match forward_req_builder.body(body) {
268        Ok(req) => Ok(req),
269        Err(e) => Err(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to build forward request: {}", e))),
270    }
271}
272
273pub fn update_flow_with_response_headers(
274    flow: &mut Flow,
275    status: StatusCode,
276    version: hyper::Version,
277    headers: &hyper::HeaderMap,
278) {
279    let mut response_cookies = Vec::new();
280    for (k, v) in headers.iter() {
281        if k == hyper::header::SET_COOKIE
282            && let Ok(v_str) = v.to_str()
283                && let Ok(c) = CookieCrate::parse(v_str) {
284                    response_cookies.push(Cookie {
285                        name: c.name().to_string(),
286                        value: c.value().to_string(),
287                        path: c.path().map(|s| s.to_string()),
288                        domain: c.domain().map(|s| s.to_string()),
289                        expires: c.expires().map(|e| format!("{:?}", e)),
290                        http_only: c.http_only(),
291                        secure: c.secure(),
292                    });
293                }
294    }
295
296    let resp_headers_vec: Vec<(String, String)> = headers.iter()
297        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
298        .collect();
299
300    let http_response = HttpResponse {
301        status: status.as_u16(),
302        status_text: status.to_string(),
303        version: format!("{:?}", version),
304        headers: resp_headers_vec,
305        cookies: response_cookies,
306        body: None,
307        timing: relay_core_api::flow::ResponseTiming {
308            time_to_first_byte: None,
309            time_to_last_byte: None,
310            connect_time_ms: None,
311            ssl_time_ms: None,
312        },
313    };
314
315    match &mut flow.layer {
316        Layer::Http(http) => {
317            http.response = Some(http_response);
318        },
319        Layer::WebSocket(ws) => {
320            ws.handshake_response = http_response;
321        },
322        _ => {}
323    }
324}
325
326pub fn update_flow_with_response_body(
327    flow: &mut Flow,
328    body_bytes: Bytes,
329) {
330    let headers = match &flow.layer {
331        Layer::Http(http) => http.response.as_ref().map(|r| r.headers.clone()).unwrap_or_default(),
332        Layer::WebSocket(ws) => ws.handshake_response.headers.clone(),
333        _ => Vec::new(),
334    };
335
336    let (resp_encoding, resp_content) = process_body(&body_bytes, &headers);
337    
338    let body_data = BodyData {
339        encoding: resp_encoding,
340        content: resp_content,
341        size: body_bytes.len() as u64,
342    };
343
344    match &mut flow.layer {
345        Layer::Http(http) => {
346            if let Some(resp) = &mut http.response {
347                resp.body = Some(body_data);
348            }
349        },
350        Layer::WebSocket(ws) => {
351            ws.handshake_response.body = Some(body_data);
352        },
353        _ => {}
354    }
355}
356
/// Records a complete upstream response on `flow` in one call.
///
/// First the status line and headers are stored via [`update_flow_with_response_headers`].
/// Then the body is attached via [`update_flow_with_response_body`].
pub fn update_flow_with_response(
    flow: &mut Flow,
    status: StatusCode,
    version: hyper::Version,
    headers: &hyper::HeaderMap,
    body_bytes: Bytes,
) {
    // Order matters: the body step reads the headers recorded by the first call, and on an
    // HTTP layer it only attaches the body if a response has been recorded.
    update_flow_with_response_headers(flow, status, version, headers);
    update_flow_with_response_body(flow, body_bytes);
}
367
368pub fn build_client_response_from_flow(flow: &Flow, default_version: hyper::Version, strict_mode: bool) -> Result<Response<Full<Bytes>>, String> {
369    if let Layer::Http(http) = &flow.layer {
370        if let Some(response) = &http.response {
371            let status = match StatusCode::from_u16(response.status) {
372                Ok(s) => s,
373                Err(_) => {
374                    if strict_mode {
375                        return Err(format!("Invalid status code: {}", response.status));
376                    }
377                    StatusCode::OK
378                }
379            };
380
381            let mut builder = Response::builder()
382                .status(status)
383                .version(default_version); // TODO: Parse version from flow string if needed
384
385            for (k, v) in &response.headers {
386                // Filter out transport-level headers that might conflict with the new body
387                if k.eq_ignore_ascii_case("content-length") 
388                    || k.eq_ignore_ascii_case("transfer-encoding") 
389                    || k.eq_ignore_ascii_case("connection") {
390                    continue;
391                }
392
393                if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(v)) {
394                    builder = builder.header(name, val);
395                } else if strict_mode {
396                    return Err(format!("Invalid header: {}: {}", k, v));
397                }
398            }
399
400            let body_bytes = if let Some(b) = &response.body {
401                 if b.encoding == "base64" {
402                    match BASE64.decode(b.content.as_bytes()) {
403                        Ok(bytes) => Bytes::from(bytes),
404                        Err(_e) => {
405                            // Fallback
406                            Bytes::from(b.content.clone()) 
407                        }
408                    }
409                 } else {
410                    Bytes::from(b.content.clone())
411                 }
412            } else {
413                Bytes::new()
414            };
415
416            builder.body(Full::new(body_bytes))
417                .map_err(|e| format!("Failed to build response: {}", e))
418        } else {
419             Err("No response in flow".to_string())
420        }
421    } else {
422        Err("Not HTTP layer".to_string())
423    }
424}
425
#[cfg(test)]
mod tests {
    //! Unit tests for request-metadata parsing and for rebuilding a client response from a
    //! recorded flow.

    use super::{build_client_response_from_flow, parse_request_meta};
    use chrono::Utc;
    use http_body_util::BodyExt;
    use hyper::{Request, StatusCode, Version};
    use relay_core_api::flow::{
        Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
        TransportProtocol,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// Builds a minimal HTTP flow whose recorded response has the given `status` and no body.
    ///
    /// The response headers are an `X-Test` header plus `content-length` and `connection`
    /// headers that the response builder is expected to strip. The recorded version is
    /// "HTTP/2.0".
    fn sample_flow_with_response(status: u16) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com/a").expect("url"),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: Some(HttpResponse {
                    status,
                    status_text: "X".to_string(),
                    version: "HTTP/2.0".to_string(),
                    headers: vec![
                        ("X-Test".to_string(), "1".to_string()),
                        ("content-length".to_string(), "999".to_string()),
                        ("connection".to_string(), "keep-alive".to_string()),
                    ],
                    cookies: vec![],
                    body: None,
                    timing: ResponseTiming {
                        time_to_first_byte: None,
                        time_to_last_byte: None,
            connect_time_ms: None,
            ssl_time_ms: None,
                    },
                }),
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// Origin-form URI plus `Host` on a plain connection yields an `http://` URL, and the
    /// query string is parsed from the rebuilt URL.
    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_http() {
        let req = Request::builder()
            .uri("/api/v1?q=1")
            .header("Host", "example.com:8080")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, false);
        assert_eq!(meta.url_str, "http://example.com:8080/api/v1?q=1");
        assert_eq!(meta.query, vec![("q".to_string(), "1".to_string())]);
    }

    /// The same rebuild on an intercepted (MITM) connection uses `https://`.
    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_https_in_mitm() {
        let req = Request::builder()
            .uri("/secure")
            .header("Host", "secure.example.com")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, true);
        assert_eq!(meta.url_str, "https://secure.example.com/secure");
    }

    /// Pins current behaviour: the caller's version wins over the flow's "HTTP/2.0", and
    /// framing headers (`content-length`, `connection`) are stripped while others are kept.
    #[test]
    fn test_build_client_response_from_flow_uses_default_version_currently() {
        let flow = sample_flow_with_response(201);
        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect("response should build");
        assert_eq!(resp.version(), Version::HTTP_11);
        assert_eq!(resp.status(), StatusCode::CREATED);
        assert_eq!(resp.headers().get("x-test").and_then(|v| v.to_str().ok()), Some("1"));
        assert!(
            resp.headers().get("content-length").is_none(),
            "content-length should be stripped from forwarded mock response"
        );
        assert!(resp.headers().get("connection").is_none());
    }

    /// In strict mode an out-of-range status (1000) is rejected with an error.
    #[test]
    fn test_build_client_response_from_flow_invalid_status_strict_fails() {
        let flow = sample_flow_with_response(1000);
        let err = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect_err("strict mode should reject invalid status");
        assert!(err.contains("Invalid status code"));
    }

    /// In lenient mode an out-of-range status falls back to 200, and a non-base64 body is
    /// sent as its stored text.
    #[tokio::test]
    async fn test_build_client_response_from_flow_invalid_status_non_strict_fallback_ok() {
        let mut flow = sample_flow_with_response(1000);
        if let Layer::Http(http) = &mut flow.layer {
            if let Some(res) = &mut http.response {
                res.body = Some(relay_core_api::flow::BodyData {
                    encoding: "utf-8".to_string(),
                    content: "hello".to_string(),
                    size: 5,
                });
            }
        }

        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, false)
            .expect("non-strict should fallback");
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp
            .into_body()
            .collect()
            .await
            .expect("collect body")
            .to_bytes();
        assert_eq!(body.as_ref(), b"hello");
    }
}