rama_net/http/
request_context.rs

1use crate::forwarded::Forwarded;
2use crate::transport::{TransportContext, TransportProtocol, TryRefIntoTransportContext};
3use crate::{
4    Protocol,
5    address::{Authority, Host},
6};
7use rama_core::Context;
8use rama_core::error::OpaqueError;
9use rama_http_types::Method;
10use rama_http_types::{Request, Uri, Version, dep::http::request::Parts};
11
12#[cfg(feature = "tls")]
13use crate::tls::SecureTransport;
14
/// Look up the server name advertised in the TLS client hello, if any.
///
/// Returns `None` when the transport carries no client hello, or when none of
/// its extensions is a `ServerName` extension holding a host.
#[cfg(feature = "tls")]
fn try_get_host_from_secure_transport(t: &SecureTransport) -> Option<Host> {
    use crate::tls::client::ClientHelloExtension;

    let client_hello = t.client_hello()?;
    client_hello.extensions().iter().find_map(|extension| {
        // Only the server-name extension can provide a host; skip all others.
        if let ClientHelloExtension::ServerName(maybe_host) = extension {
            maybe_host.clone()
        } else {
            None
        }
    })
}
26
/// Placeholder used when the `tls` feature is disabled.
///
/// It only exists so that the context lookup for a secure transport still
/// type-checks in builds without `tls`.
#[cfg(not(feature = "tls"))]
#[non_exhaustive]
#[derive(Clone, Debug)]
struct SecureTransport;
31
32#[cfg(not(feature = "tls"))]
33fn try_get_host_from_secure_transport(_: &SecureTransport) -> Option<Host> {
34    None
35}
36
37#[derive(Debug, Clone, PartialEq, Eq)]
38/// The context of the [`Request`].
39pub struct RequestContext {
40    /// The HTTP Version.
41    pub http_version: Version,
42    /// The [`Protocol`] of the [`Request`].
43    pub protocol: Protocol,
44    /// The authority of the [`Request`].
45    ///
46    /// In http/1.1 this is typically defined by the `Host` header,
47    /// whereas for h2 and h3 this is found in the pseudo `:authority` header.
48    ///
49    /// This can be also manually set in case there is support for
50    /// forward headers (e.g. `Forwarded`, or `X-Forwarded-Host`)
51    /// or forward protocols (e.g. `HaProxy`).
52    pub authority: Authority,
53}
54
55impl RequestContext {
56    /// Check if [`Authority`] is using the default port for the [`Protocol`] set in this [`RequestContext`]
57    pub fn authority_has_default_port(&self) -> bool {
58        self.protocol.default_port() == Some(self.authority.port())
59    }
60}
61
62impl<Body, State> TryFrom<(&Context<State>, &Request<Body>)> for RequestContext {
63    type Error = OpaqueError;
64
65    fn try_from((ctx, req): (&Context<State>, &Request<Body>)) -> Result<Self, Self::Error> {
66        let uri = req.uri();
67
68        let protocol = protocol_from_uri_or_context(ctx, uri, req.method());
69        tracing::trace!(
70            uri = %uri, "request context: detected protocol: {protocol} (scheme: {:?})",
71            uri.scheme()
72        );
73
74        let default_port = uri
75            .port_u16()
76            .unwrap_or_else(|| protocol.default_port().unwrap_or(80));
77        tracing::trace!(uri = %uri, "request context: detected default port: {default_port}");
78
79        let authority = match ctx.get().and_then(try_get_host_from_secure_transport) {
80            Some(h) => {
81                tracing::trace!(uri = %uri, host = %h, "request context: detected host from SNI");
82                (h, default_port).into()
83            },
84            None => uri
85                .host()
86                .and_then(|h| Host::try_from(h).ok().map(|h| {
87                    tracing::trace!(uri = %uri, host = %h, "request context: detected host from (abs) uri");
88                    (h, default_port).into()
89                }))
90                .or_else(|| {
91                    ctx.get::<Forwarded>().and_then(|f| {
92                        f.client_host().map(|fauth| {
93                            let (host, port) = fauth.clone().into_parts();
94                            let port = port.unwrap_or(default_port);
95                            tracing::trace!(uri = %uri, host = %host, "request context: detected host from forwarded info");
96                            (host, port).into()
97                        })
98                    })
99                })
100                .or_else(|| {
101                    req.headers()
102                        .get(rama_http_types::header::HOST)
103                        .and_then(|host| {
104                            host.try_into() // try to consume as Authority, otherwise as Host
105                                .or_else(|_| Host::try_from(host).map(|h| {
106                                    tracing::trace!(uri = %uri, host = %h, "request context: detected host from host header");
107                                    (h, default_port).into()
108                                }))
109                                .ok()
110                        })
111                })
112                .ok_or_else(|| {
113                    OpaqueError::from_display("RequestContext: no authourity found in http::Request")
114                })?
115        };
116
117        tracing::trace!(uri = %uri, "request context: detected authority: {authority}");
118
119        let http_version = ctx
120            .get::<Forwarded>()
121            .and_then(|f| {
122                f.client_version().map(|v| match v {
123                    crate::forwarded::ForwardedVersion::HTTP_09 => Version::HTTP_09,
124                    crate::forwarded::ForwardedVersion::HTTP_10 => Version::HTTP_10,
125                    crate::forwarded::ForwardedVersion::HTTP_11 => Version::HTTP_11,
126                    crate::forwarded::ForwardedVersion::HTTP_2 => Version::HTTP_2,
127                    crate::forwarded::ForwardedVersion::HTTP_3 => Version::HTTP_3,
128                })
129            })
130            .unwrap_or_else(|| req.version());
131        tracing::trace!(uri = %uri, "request context: maybe detected http version: {http_version:?}");
132
133        Ok(RequestContext {
134            http_version,
135            protocol,
136            authority,
137        })
138    }
139}
140
141impl<State> TryFrom<(&Context<State>, &Parts)> for RequestContext {
142    type Error = OpaqueError;
143
144    fn try_from((ctx, parts): (&Context<State>, &Parts)) -> Result<Self, Self::Error> {
145        let uri = &parts.uri;
146
147        let protocol = protocol_from_uri_or_context(ctx, uri, &parts.method);
148        tracing::trace!(
149            uri = %uri, "request context: detected protocol: {protocol} (scheme: {:?})",
150            uri.scheme()
151        );
152
153        let default_port = uri
154            .port_u16()
155            .unwrap_or_else(|| protocol.default_port().unwrap_or(80));
156        tracing::trace!(uri = %uri, "request context: detected default port: {default_port}");
157
158        let authority = match ctx.get().and_then(try_get_host_from_secure_transport) {
159            Some(h) => {
160                tracing::trace!(uri = %uri, host = %h, "request context: detected host from SNI");
161                (h, default_port).into()
162            }
163            None => {
164                uri
165                    .host()
166                    .and_then(|h| Host::try_from(h).ok().map(|h| {
167                        tracing::trace!(uri = %uri, host = %h, "request context: detected host from (abs) uri");
168                        (h, default_port).into()
169                    }))
170                    .or_else(|| {
171                        ctx.get::<Forwarded>().and_then(|f| {
172                            f.client_host().map(|fauth| {
173                                let (host, port) = fauth.clone().into_parts();
174                                let port = port.unwrap_or(default_port);
175                                tracing::trace!(uri = %uri, host = %host, "request context: detected host from forwarded info");
176                                (host, port).into()
177                            })
178                        })
179                    })
180                    .or_else(|| {
181                        parts
182                            .headers
183                            .get(rama_http_types::header::HOST)
184                            .and_then(|host| {
185                                host.try_into() // try to consume as Authority, otherwise as Host
186                                    .or_else(|_| Host::try_from(host).map(|h| {
187                                        tracing::trace!(uri = %uri, host = %h, "request context: detected host from host header");
188                                        (h, default_port).into()
189                                    }))
190                                    .ok()
191                            })
192                    })
193                    .ok_or_else(|| {
194                        OpaqueError::from_display(
195                            "RequestContext: no authourity found in http::request::Parts",
196                        )
197                    })?
198            }
199        };
200
201        tracing::trace!(uri = %uri, "request context: detected authority: {authority}");
202
203        let http_version = ctx
204            .get::<Forwarded>()
205            .and_then(|f| {
206                f.client_version().map(|v| match v {
207                    crate::forwarded::ForwardedVersion::HTTP_09 => Version::HTTP_09,
208                    crate::forwarded::ForwardedVersion::HTTP_10 => Version::HTTP_10,
209                    crate::forwarded::ForwardedVersion::HTTP_11 => Version::HTTP_11,
210                    crate::forwarded::ForwardedVersion::HTTP_2 => Version::HTTP_2,
211                    crate::forwarded::ForwardedVersion::HTTP_3 => Version::HTTP_3,
212                })
213            })
214            .unwrap_or(parts.version);
215        tracing::trace!(uri = %uri, "request context: maybe detected http version: {http_version:?}");
216
217        Ok(RequestContext {
218            http_version,
219            protocol,
220            authority,
221        })
222    }
223}
224
225#[allow(clippy::unnecessary_lazy_evaluations)]
226fn protocol_from_uri_or_context<State>(
227    ctx: &Context<State>,
228    uri: &Uri,
229    method: &Method,
230) -> Protocol {
231    Protocol::maybe_from_uri_scheme_str_and_method(uri.scheme(), Some(method)).or_else(|| ctx.get::<Forwarded>()
232        .and_then(|f| f.client_proto().map(|p| {
233            tracing::trace!(uri = %uri, "request context: detected protocol from forwarded client proto");
234            p.into()
235        })))
236        .unwrap_or_else(|| {
237            if method == Method::CONNECT {
238                tracing::trace!(uri = %uri, method = %method, "request context: CONNECT: defaulting protocol to HTTPS");
239                Protocol::HTTPS
240            } else {
241                tracing::trace!(uri = %uri, method = %method, "request context: defaulting protocol to HTTP");
242                Protocol::HTTP
243            }
244        })
245}
246
247impl From<RequestContext> for TransportContext {
248    fn from(value: RequestContext) -> Self {
249        Self {
250            protocol: if value.http_version == Version::HTTP_3 {
251                TransportProtocol::Udp
252            } else {
253                TransportProtocol::Tcp
254            },
255            app_protocol: Some(value.protocol),
256            http_version: Some(value.http_version),
257            authority: value.authority,
258        }
259    }
260}
261
262impl From<&RequestContext> for TransportContext {
263    fn from(value: &RequestContext) -> Self {
264        Self {
265            protocol: if value.http_version == Version::HTTP_3 {
266                TransportProtocol::Udp
267            } else {
268                TransportProtocol::Tcp
269            },
270            app_protocol: Some(value.protocol.clone()),
271            http_version: Some(value.http_version),
272            authority: value.authority.clone(),
273        }
274    }
275}
276
277impl<State, Body> TryRefIntoTransportContext<State> for rama_http_types::Request<Body> {
278    type Error = OpaqueError;
279
280    fn try_ref_into_transport_ctx(
281        &self,
282        ctx: &Context<State>,
283    ) -> Result<TransportContext, Self::Error> {
284        (ctx, self).try_into()
285    }
286}
287
288impl<State> TryRefIntoTransportContext<State> for rama_http_types::dep::http::request::Parts {
289    type Error = OpaqueError;
290
291    fn try_ref_into_transport_ctx(
292        &self,
293        ctx: &Context<State>,
294    ) -> Result<TransportContext, Self::Error> {
295        (ctx, self).try_into()
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::forwarded::{Forwarded, ForwardedElement, NodeId};
    use rama_http_types::header::FORWARDED;

    /// An absolute URI provides protocol, host and port for a borrowed `Request`.
    #[test]
    fn test_request_context_from_request() {
        let req = Request::builder()
            .uri("http://example.com:8080")
            .version(Version::HTTP_11)
            .body(())
            .unwrap();

        let ctx = Context::default();

        let req_ctx = RequestContext::try_from((&ctx, &req)).unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, Protocol::HTTP);
        assert_eq!(req_ctx.authority.to_string(), "example.com:8080");
    }

    /// Same as above, but going through the `Parts` conversion.
    #[test]
    fn test_request_context_from_parts() {
        let req = Request::builder()
            .uri("http://example.com:8080")
            .version(Version::HTTP_11)
            .body(())
            .unwrap();

        let (parts, _) = req.into_parts();

        let ctx = Context::default();
        let req_ctx = RequestContext::try_from((&ctx, &parts)).unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, Protocol::HTTP);
        assert_eq!(
            req_ctx.authority,
            Authority::try_from("example.com:8080").unwrap()
        );
    }

    /// A manually constructed context renders its authority as `host:port`.
    #[test]
    fn test_request_context_authority() {
        let ctx = RequestContext {
            http_version: Version::HTTP_11,
            protocol: Protocol::HTTP,
            authority: "example.com:8080".try_into().unwrap(),
        };

        assert_eq!(ctx.authority.to_string(), "example.com:8080");
    }

    /// The authority falls back to the `Forwarded` info stored in the context
    /// when the request URI carries no host.
    #[test]
    fn forwarded_parsing() {
        for (forwarded_str_vec, expected) in [
            // base
            (
                vec!["host=192.0.2.60;proto=http;by=203.0.113.43"],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "192.0.2.60:80".parse().unwrap(),
                },
            ),
            // ipv6
            (
                vec!["host=\"[2001:db8:cafe::17]:4711\""],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "[2001:db8:cafe::17]:4711".parse().unwrap(),
                },
            ),
            // multiple values in one header
            (
                vec!["host=192.0.2.60, host=127.0.0.1"],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "192.0.2.60:80".parse().unwrap(),
                },
            ),
            // multiple header values
            (
                vec!["host=192.0.2.60", "host=127.0.0.1"],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "192.0.2.60:80".parse().unwrap(),
                },
            ),
        ] {
            let mut req_builder = Request::builder();
            for header in forwarded_str_vec.iter().copied() {
                req_builder = req_builder.header(FORWARDED, header);
            }

            let req = req_builder.body(()).unwrap();
            let mut ctx = Context::default();

            // NOTE(review): only the header value returned by `get` is parsed here,
            // so additional `Forwarded` header values are not part of the parsed info.
            let forwarded: Forwarded = req.headers().get(FORWARDED).unwrap().try_into().unwrap();
            ctx.insert(forwarded);

            let req_ctx = ctx
                .get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into())
                .unwrap()
                .clone();

            assert_eq!(req_ctx, expected, "Failed for {forwarded_str_vec:?}");
        }
    }

    /// With only a `for` node in the `Forwarded` info, the authority comes from
    /// the `Host` header and the protocol defaults to plain HTTP.
    #[test]
    fn test_request_ctx_https_request_behind_haproxy_plain() {
        let req = Request::builder()
            .uri("/en/reservation/roomdetails")
            .version(Version::HTTP_11)
            .header("host", "echo.ramaproxy.org")
            .header("user-agent", "curl/8.6.0")
            .header("accept", "*/*")
            .body(())
            .unwrap();

        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            NodeId::try_from("127.0.0.1:61234").unwrap(),
        )));

        let req_ctx: &mut RequestContext = ctx
            .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
            .unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, "http");
        assert_eq!(req_ctx.authority.to_string(), "echo.ramaproxy.org:80");
    }

    /// A scheme-less CONNECT request defaults to HTTPS regardless of the port,
    /// while the port of the request target is kept in the authority.
    #[test]
    fn test_request_ctx_connect_req_no_scheme() {
        let test_cases = [
            (80, Protocol::HTTPS),
            // 443 is the HTTPS default port (previously misspelled as 433)
            (443, Protocol::HTTPS),
            (8080, Protocol::HTTPS),
        ];
        for (port, expected_protocol) in test_cases {
            let req = Request::builder()
                .uri(format!("www.example.com:{port}"))
                .version(Version::HTTP_11)
                .method(Method::CONNECT)
                .header("host", "www.example.com")
                .header("user-agent", "test/42")
                .body(())
                .unwrap();

            let mut ctx = Context::default();
            let req_ctx: &mut RequestContext = ctx
                .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
                .unwrap();

            assert_eq!(req_ctx.http_version, Version::HTTP_11);
            assert_eq!(req_ctx.protocol, expected_protocol);
            assert_eq!(
                req_ctx.authority.to_string(),
                format!("www.example.com:{port}")
            );
        }
    }

    /// A CONNECT request with an explicit scheme maps to the expected protocol,
    /// and the authority port falls back to that protocol's default port (or 80).
    #[test]
    fn test_request_ctx_connect_req() {
        let test_cases = [
            ("http", Protocol::HTTPS),
            ("https", Protocol::HTTPS),
            ("ws", Protocol::WSS),
            ("wss", Protocol::WSS),
            ("ftp", Protocol::from_static("ftp")),
        ];
        for (scheme, expected_protocol) in test_cases {
            let req = Request::builder()
                .uri(format!("{scheme}://www.example.com"))
                .version(Version::HTTP_11)
                .method(Method::CONNECT)
                .header("host", "www.example.com")
                .header("user-agent", "test/42")
                .body(())
                .unwrap();

            let mut ctx = Context::default();
            let req_ctx: &mut RequestContext = ctx
                .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
                .unwrap();

            assert_eq!(req_ctx.http_version, Version::HTTP_11);
            assert_eq!(req_ctx.protocol, expected_protocol);
            assert_eq!(
                req_ctx.authority.to_string(),
                format!(
                    "www.example.com:{}",
                    expected_protocol.default_port().unwrap_or(80)
                )
            );
        }
    }
}
503        }
504    }
505}