plane_dynamic_proxy/
request.rs

1use crate::body::{to_simple_body, SimpleBody};
2use bytes::Bytes;
3use http::{
4    request::Parts,
5    uri::{Authority, Scheme},
6    HeaderMap, HeaderName, HeaderValue, Request, Uri,
7};
8use http_body::Body;
9use std::{net::SocketAddr, str::FromStr};
10
11/// Represents an HTTP request (from hyper) with helpers for mutating it.
12pub struct MutableRequest<T>
13where
14    T: Body<Data = Bytes> + Send + 'static,
15    T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
16{
17    pub parts: Parts,
18    pub body: T,
19}
20
21impl<T> MutableRequest<T>
22where
23    T: Body<Data = Bytes> + Send + 'static,
24    T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
25{
26    pub fn from_request(request: Request<T>) -> Self {
27        let (parts, body) = request.into_parts();
28        Self { parts, body }
29    }
30
31    pub fn into_request(self) -> Request<T> {
32        Request::from_parts(self.parts, self.body)
33    }
34
35    pub fn into_request_with_simple_body(self) -> Request<SimpleBody> {
36        Request::from_parts(self.parts, to_simple_body(self.body))
37    }
38
39    /// Rewrite the request so that it points to the given upstream address.
40    pub fn set_upstream_address(&mut self, address: SocketAddr) {
41        let uri = std::mem::take(&mut self.parts.uri);
42        let mut uri_parts = uri.into_parts();
43        uri_parts.scheme = Some(Scheme::HTTP);
44        uri_parts.authority = Some(
45            Authority::try_from(address.to_string())
46                .expect("SocketAddr should always be a valid authority."),
47        );
48        self.parts.uri = Uri::from_parts(uri_parts).expect("URI should always be valid.");
49    }
50
51    /// Add a header to the request.
52    ///
53    /// If the header is invalid, it will be ignored and logged.
54    pub fn add_header(&mut self, key: &str, value: &str) {
55        let Ok(key) = HeaderName::from_str(key) else {
56            tracing::error!("Attempted to set invalid header name: {}", key);
57            return;
58        };
59        let Ok(value) = HeaderValue::from_str(value) else {
60            // Not logging the value, which could be sensitive.
61            tracing::error!("Attempted to set invalid header value with key: {}", key);
62            return;
63        };
64        self.parts.headers.append(key, value);
65    }
66
67    pub fn headers_mut(&mut self) -> &mut HeaderMap {
68        &mut self.parts.headers
69    }
70}
71
72pub fn should_upgrade<T>(request: &Request<T>) -> bool {
73    let Some(conn_header) = request.headers().get("connection") else {
74        return false;
75    };
76
77    let Ok(conn_header) = conn_header.to_str() else {
78        return false;
79    };
80
81    conn_header
82        .to_lowercase()
83        .split(',')
84        .any(|s| s.trim() == "upgrade")
85}