// relay_core_lib/proxy/http_utils.rs

1use std::net::SocketAddr;
2use hyper::{Response, Request, StatusCode};
3use hyper::body::Bytes;
4use hyper::header::{HeaderName, HeaderValue};
5use http_body_util::{Full, BodyExt};
6use hyper_util::client::legacy::Client;
7use hyper_util::client::legacy::connect::HttpConnector;
8use hyper_rustls::HttpsConnector;
9use relay_core_api::flow::{
10    BodyData, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, TransportProtocol,
11    WebSocketLayer, Cookie,
12};
13use relay_core_api::policy::ProxyPolicy;
14use crate::capture::loop_detection::LoopDetector;
15use uuid::Uuid;
16use chrono::Utc;
17use url::Url;
18use cookie::Cookie as CookieCrate;
19use data_encoding::BASE64;
20use crate::proxy::body_codec::process_body;
21use crate::interceptor::HttpBody;
22
23pub type HttpsClient = Client<HttpsConnector<HttpConnector>, HttpBody>;
24
25#[derive(Clone, Debug)]
26pub struct RequestMeta {
27    pub method: String,
28    pub url_str: String,
29    pub version: String,
30    pub headers: Vec<(String, String)>,
31    pub query: Vec<(String, String)>,
32    pub cookies: Vec<Cookie>,
33}
34
35pub fn parse_request_meta<B>(req: &Request<B>, is_mitm: bool) -> RequestMeta {
36    let method = req.method().to_string();
37    let mut url_str = req.uri().to_string();
38    
39    // Attempt to construct absolute URL if relative
40    if Url::parse(&url_str).is_err() {
41        if let Some(host) = req.headers().get("Host").and_then(|v| v.to_str().ok()) {
42             let scheme = if is_mitm { "https" } else { "http" };
43             let new_url = format!("{}://{}{}", scheme, host, url_str);
44             if Url::parse(&new_url).is_ok() {
45                 url_str = new_url;
46             }
47        }
48    }
49
50    let version = format!("{:?}", req.version());
51    
52    let headers: Vec<(String, String)> = req.headers()
53        .iter()
54        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
55        .collect();
56
57    let query: Vec<(String, String)> = if let Ok(parsed_url) = Url::parse(&url_str) {
58        parsed_url.query_pairs().into_owned().collect()
59    } else {
60        vec![]
61    };
62
63    let mut cookies = Vec::new();
64    if let Some(cookie_header) = req.headers().get(hyper::header::COOKIE) {
65        if let Ok(cookie_str) = cookie_header.to_str() {
66             for c in CookieCrate::split_parse(cookie_str).flatten() {
67                 cookies.push(Cookie {
68                     name: c.name().to_string(),
69                     value: c.value().to_string(),
70                     path: None,
71                     domain: None,
72                     expires: None,
73                     http_only: None,
74                     secure: None,
75                 });
76             }
77        }
78    }
79
80    RequestMeta {
81        method,
82        url_str,
83        version,
84        headers,
85        query,
86        cookies,
87    }
88}
89
/// Returns `true` for header names that must not be copied onto a forwarded request.
///
/// Matching is ASCII case-insensitive. The set is the classic hop-by-hop list
/// (RFC 2616 §13.5.1, RFC 9110 §7.6.1) plus:
/// - `proxy-connection`: non-standard but commonly sent by clients to proxies, and
///   listed as connection-specific by RFC 9113 §8.2.2, so the origin must not see it;
/// - `trailer`: the real header name; RFC 2616 misspelled it `trailers` (see its
///   errata), and that spelling is kept for compatibility;
/// - `content-length`: not hop-by-hop per the RFCs. It is presumably stripped because
///   the forwarded body may differ from the original, so the client recomputes framing.
///
/// Headers nominated by a request's `Connection` field are not covered here; callers
/// must filter those separately.
pub fn is_hop_by_hop(name: &str) -> bool {
    const HOP_BY_HOP: &[&str] = &[
        "connection",
        "proxy-connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-length",
    ];
    HOP_BY_HOP.iter().any(|h| name.eq_ignore_ascii_case(h))
}
101
102pub fn create_initial_flow(
103    meta: RequestMeta,
104    req_body: Option<BodyData>,
105    client_addr: SocketAddr,
106    is_mitm: bool,
107    is_websocket: bool,
108) -> Flow {
109    let flow_id = Uuid::new_v4();
110    let start_time = Utc::now();
111    
112    let network_info = NetworkInfo {
113        client_ip: client_addr.ip().to_string(),
114        client_port: client_addr.port(),
115        server_ip: "0.0.0.0".to_string(), // Placeholder
116        server_port: 0,                   // Placeholder
117        protocol: TransportProtocol::TCP,
118        tls: is_mitm,
119        tls_version: None,
120        sni: None,
121    };
122
123    let http_request = HttpRequest {
124        method: meta.method,
125        url: Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown").unwrap()),
126        version: meta.version,
127        headers: meta.headers,
128        cookies: meta.cookies,
129        query: meta.query,
130        body: req_body, 
131    };
132
133    let mut flow = if is_websocket {
134        Flow {
135            id: flow_id,
136            start_time,
137            end_time: None,
138            network: network_info,
139            layer: Layer::WebSocket(WebSocketLayer {
140                handshake_request: http_request,
141                handshake_response: HttpResponse {
142                    status: 0,
143                    status_text: "".to_string(),
144                    version: "".to_string(),
145                    headers: vec![],
146                    cookies: vec![],
147                    body: None,
148                    timing: relay_core_api::flow::ResponseTiming { time_to_first_byte: None, time_to_last_byte: None },
149                },
150                messages: vec![],
151                closed: false,
152            }),
153            tags: vec!["websocket".to_string()],
154            meta: std::collections::HashMap::new(),
155        }
156    } else {
157        Flow {
158            id: flow_id,
159            start_time,
160            end_time: None,
161            network: network_info,
162            layer: Layer::Http(HttpLayer {
163                request: http_request,
164                response: None,
165                error: None,
166            }),
167            tags: vec!["proxy".to_string()],
168            meta: std::collections::HashMap::new(),
169        }
170    };
171
172    if is_mitm {
173        flow.tags.push("mitm".to_string());
174    }
175
176    flow
177}
178
179pub fn create_error_response(status: StatusCode, message: impl Into<Bytes>) -> Response<HttpBody> {
180    Response::builder()
181        .status(status)
182        .body(Full::new(message.into()).map_err(|e| e.into()).boxed())
183        .unwrap_or_else(|_| Response::new(Full::new(Bytes::from("Internal Error")).map_err(|e| e.into()).boxed()))
184}
185
186pub fn mock_to_response(mock: HttpResponse) -> Response<HttpBody> {
187     let mut builder = Response::builder()
188        .status(StatusCode::from_u16(mock.status).unwrap_or(StatusCode::OK));
189        
190    for (k, v) in mock.headers {
191        if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(&v)) {
192            builder = builder.header(name, val);
193        }
194    }
195    
196    let body = if let Some(b) = mock.body {
197        Bytes::from(b.content)
198    } else {
199        Bytes::new()
200    };
201    
202    builder.body(Full::new(body).map_err(|e| e.into()).boxed()).unwrap_or_else(|_| create_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to build mock response"))
203}
204
205#[allow(clippy::result_large_err)]
206pub fn build_forward_request(
207    flow: &mut Flow,
208    body: HttpBody,
209    req_parts: &http::request::Parts,
210    target_addr: Option<SocketAddr>,
211    policy: &ProxyPolicy,
212    loop_detector: &LoopDetector,
213) -> Result<Request<HttpBody>, Response<HttpBody>> {
214    let current_req = if let Layer::Http(http) = &flow.layer {
215        &http.request
216    } else {
217        return Err(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Invalid Flow Layer State"));
218    };
219
220    let mut forward_req_builder = Request::builder()
221        .method(current_req.method.as_str())
222        .version(req_parts.version);
223
224    // Determine upstream URI
225    let mut target_url = current_req.url.clone();
226    
227    // Transparent Proxy Routing Logic
228    if policy.transparent_enabled {
229        if let Some(addr) = target_addr {
230            flow.tags.push("transparent".to_string());
231            
232            // Update Flow Network Info
233            flow.network.server_ip = addr.ip().to_string();
234            flow.network.server_port = addr.port();
235
236            // Loop Detection
237            if loop_detector.would_loop(addr) {
238                if let Layer::Http(http) = &mut flow.layer {
239                    http.error = Some("Loop Detected".to_string());
240                }
241                return Err(create_error_response(StatusCode::LOOP_DETECTED, "Loop Detected"));
242            }
243
244            // Rewrite URI to use target IP
245            if target_url.set_ip_host(addr.ip()).is_ok() {
246                 target_url.set_port(Some(addr.port())).ok();
247            }
248
249            // Update scheme if MITM
250            if flow.network.tls && target_url.scheme() == "http" {
251                target_url.set_scheme("https").ok();
252            }
253        }
254    }
255
256    forward_req_builder = forward_req_builder.uri(target_url.as_str());
257
258    for (k, v) in &current_req.headers {
259        // Filter out hop-by-hop headers to allow connection pooling
260        if is_hop_by_hop(k) {
261            continue;
262        }
263
264        if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(v)) {
265            forward_req_builder = forward_req_builder.header(name, val);
266        }
267    }
268
269    match forward_req_builder.body(body) {
270        Ok(req) => Ok(req),
271        Err(e) => Err(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to build forward request: {}", e))),
272    }
273}
274
275pub fn update_flow_with_response_headers(
276    flow: &mut Flow,
277    status: StatusCode,
278    version: hyper::Version,
279    headers: &hyper::HeaderMap,
280) {
281    let mut response_cookies = Vec::new();
282    for (k, v) in headers.iter() {
283        if k == hyper::header::SET_COOKIE {
284            if let Ok(v_str) = v.to_str() {
285                if let Ok(c) = CookieCrate::parse(v_str) {
286                    response_cookies.push(Cookie {
287                        name: c.name().to_string(),
288                        value: c.value().to_string(),
289                        path: c.path().map(|s| s.to_string()),
290                        domain: c.domain().map(|s| s.to_string()),
291                        expires: c.expires().map(|e| format!("{:?}", e)),
292                        http_only: c.http_only(),
293                        secure: c.secure(),
294                    });
295                }
296            }
297        }
298    }
299
300    let resp_headers_vec: Vec<(String, String)> = headers.iter()
301        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
302        .collect();
303
304    let http_response = HttpResponse {
305        status: status.as_u16(),
306        status_text: status.to_string(),
307        version: format!("{:?}", version),
308        headers: resp_headers_vec,
309        cookies: response_cookies,
310        body: None,
311        timing: relay_core_api::flow::ResponseTiming {
312            time_to_first_byte: None,
313            time_to_last_byte: None,
314        },
315    };
316
317    match &mut flow.layer {
318        Layer::Http(http) => {
319            http.response = Some(http_response);
320        },
321        Layer::WebSocket(ws) => {
322            ws.handshake_response = http_response;
323        },
324        _ => {}
325    }
326}
327
328pub fn update_flow_with_response_body(
329    flow: &mut Flow,
330    body_bytes: Bytes,
331) {
332    let headers = match &flow.layer {
333        Layer::Http(http) => http.response.as_ref().map(|r| r.headers.clone()).unwrap_or_default(),
334        Layer::WebSocket(ws) => ws.handshake_response.headers.clone(),
335        _ => Vec::new(),
336    };
337
338    let (resp_encoding, resp_content) = process_body(&body_bytes, &headers);
339    
340    let body_data = BodyData {
341        encoding: resp_encoding,
342        content: resp_content,
343        size: body_bytes.len() as u64,
344    };
345
346    match &mut flow.layer {
347        Layer::Http(http) => {
348            if let Some(resp) = &mut http.response {
349                resp.body = Some(body_data);
350            }
351        },
352        Layer::WebSocket(ws) => {
353            ws.handshake_response.body = Some(body_data);
354        },
355        _ => {}
356    }
357}
358
359pub fn update_flow_with_response(
360    flow: &mut Flow,
361    status: StatusCode,
362    version: hyper::Version,
363    headers: &hyper::HeaderMap,
364    body_bytes: Bytes,
365) {
366    update_flow_with_response_headers(flow, status, version, headers);
367    update_flow_with_response_body(flow, body_bytes);
368}
369
370pub fn build_client_response_from_flow(flow: &Flow, default_version: hyper::Version, strict_mode: bool) -> Result<Response<Full<Bytes>>, String> {
371    if let Layer::Http(http) = &flow.layer {
372        if let Some(response) = &http.response {
373            let status = match StatusCode::from_u16(response.status) {
374                Ok(s) => s,
375                Err(_) => {
376                    if strict_mode {
377                        return Err(format!("Invalid status code: {}", response.status));
378                    }
379                    StatusCode::OK
380                }
381            };
382
383            let mut builder = Response::builder()
384                .status(status)
385                .version(default_version); // TODO: Parse version from flow string if needed
386
387            for (k, v) in &response.headers {
388                // Filter out transport-level headers that might conflict with the new body
389                if k.eq_ignore_ascii_case("content-length") 
390                    || k.eq_ignore_ascii_case("transfer-encoding") 
391                    || k.eq_ignore_ascii_case("connection") {
392                    continue;
393                }
394
395                if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(v)) {
396                    builder = builder.header(name, val);
397                } else if strict_mode {
398                    return Err(format!("Invalid header: {}: {}", k, v));
399                }
400            }
401
402            let body_bytes = if let Some(b) = &response.body {
403                 if b.encoding == "base64" {
404                    match BASE64.decode(b.content.as_bytes()) {
405                        Ok(bytes) => Bytes::from(bytes),
406                        Err(_e) => {
407                            // Fallback
408                            Bytes::from(b.content.clone()) 
409                        }
410                    }
411                 } else {
412                    Bytes::from(b.content.clone())
413                 }
414            } else {
415                Bytes::new()
416            };
417
418            builder.body(Full::new(body_bytes))
419                .map_err(|e| format!("Failed to build response: {}", e))
420        } else {
421             Err("No response in flow".to_string())
422        }
423    } else {
424        Err("Not HTTP layer".to_string())
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::{build_client_response_from_flow, is_hop_by_hop, parse_request_meta};
    use chrono::Utc;
    use http_body_util::BodyExt;
    use hyper::{Request, StatusCode, Version};
    use relay_core_api::flow::{
        BodyData, Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
        TransportProtocol,
    };
    use std::collections::HashMap;
    use url::Url;
    use uuid::Uuid;

    /// HTTP flow with a recorded response carrying `status`, an `X-Test` header and
    /// the transport headers `content-length`/`connection`.
    fn sample_flow_with_response(status: u16) -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 80,
                protocol: TransportProtocol::TCP,
                tls: false,
                tls_version: None,
                sni: None,
            },
            layer: Layer::Http(HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("http://example.com/a").expect("url"),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: Some(HttpResponse {
                    status,
                    status_text: "X".to_string(),
                    version: "HTTP/2.0".to_string(),
                    headers: vec![
                        ("X-Test".to_string(), "1".to_string()),
                        ("content-length".to_string(), "999".to_string()),
                        ("connection".to_string(), "keep-alive".to_string()),
                    ],
                    cookies: vec![],
                    body: None,
                    timing: ResponseTiming {
                        time_to_first_byte: None,
                        time_to_last_byte: None,
                    },
                }),
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    /// Sets the body of the response recorded by `sample_flow_with_response`.
    fn set_response_body(flow: &mut Flow, encoding: &str, content: &str, size: u64) {
        if let Layer::Http(http) = &mut flow.layer {
            if let Some(res) = &mut http.response {
                res.body = Some(BodyData {
                    encoding: encoding.to_string(),
                    content: content.to_string(),
                    size,
                });
            }
        }
    }

    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_http() {
        let req = Request::builder()
            .uri("/api/v1?q=1")
            .header("Host", "example.com:8080")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, false);
        assert_eq!(meta.url_str, "http://example.com:8080/api/v1?q=1");
        assert_eq!(meta.query, vec![("q".to_string(), "1".to_string())]);
    }

    #[test]
    fn test_parse_request_meta_relative_uri_uses_host_https_in_mitm() {
        let req = Request::builder()
            .uri("/secure")
            .header("Host", "secure.example.com")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, true);
        assert_eq!(meta.url_str, "https://secure.example.com/secure");
    }

    #[test]
    fn test_parse_request_meta_parses_cookie_header() {
        let req = Request::builder()
            .uri("http://example.com/")
            .header("Cookie", "a=1; b=2")
            .body(())
            .expect("request");
        let meta = parse_request_meta(&req, false);
        let pairs: Vec<(String, String)> = meta
            .cookies
            .iter()
            .map(|c| (c.name.clone(), c.value.clone()))
            .collect();
        assert_eq!(
            pairs,
            vec![
                ("a".to_string(), "1".to_string()),
                ("b".to_string(), "2".to_string()),
            ]
        );
    }

    #[test]
    fn test_is_hop_by_hop_is_case_insensitive() {
        assert!(is_hop_by_hop("Connection"));
        assert!(is_hop_by_hop("TRANSFER-ENCODING"));
        assert!(is_hop_by_hop("content-length"));
        assert!(!is_hop_by_hop("x-test"));
        assert!(!is_hop_by_hop("host"));
    }

    #[test]
    fn test_build_client_response_from_flow_uses_default_version_currently() {
        let flow = sample_flow_with_response(201);
        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect("response should build");
        assert_eq!(resp.version(), Version::HTTP_11);
        assert_eq!(resp.status(), StatusCode::CREATED);
        assert_eq!(resp.headers().get("x-test").and_then(|v| v.to_str().ok()), Some("1"));
        assert!(
            resp.headers().get("content-length").is_none(),
            "content-length should be stripped from forwarded mock response"
        );
        assert!(resp.headers().get("connection").is_none());
    }

    #[test]
    fn test_build_client_response_from_flow_invalid_status_strict_fails() {
        let flow = sample_flow_with_response(1000);
        let err = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect_err("strict mode should reject invalid status");
        assert!(err.contains("Invalid status code"));
    }

    #[test]
    fn test_build_client_response_from_flow_without_response_errors() {
        let mut flow = sample_flow_with_response(200);
        if let Layer::Http(http) = &mut flow.layer {
            http.response = None;
        }
        let err = build_client_response_from_flow(&flow, Version::HTTP_11, false)
            .expect_err("a flow without a response cannot be converted");
        assert_eq!(err, "No response in flow");
    }

    #[tokio::test]
    async fn test_build_client_response_from_flow_invalid_status_non_strict_fallback_ok() {
        let mut flow = sample_flow_with_response(1000);
        set_response_body(&mut flow, "utf-8", "hello", 5);

        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, false)
            .expect("non-strict should fallback");
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp
            .into_body()
            .collect()
            .await
            .expect("collect body")
            .to_bytes();
        assert_eq!(body.as_ref(), b"hello");
    }

    #[tokio::test]
    async fn test_build_client_response_from_flow_decodes_base64_body() {
        let mut flow = sample_flow_with_response(200);
        set_response_body(&mut flow, "base64", "aGVsbG8=", 5);

        let resp = build_client_response_from_flow(&flow, Version::HTTP_11, true)
            .expect("response should build");
        let body = resp
            .into_body()
            .collect()
            .await
            .expect("collect body")
            .to_bytes();
        assert_eq!(body.as_ref(), b"hello");
    }
}