rama_net/http/
request_context.rs

1use crate::forwarded::Forwarded;
2use crate::transport::{TransportContext, TransportProtocol, TryRefIntoTransportContext};
3use crate::{
4    Protocol,
5    address::{Authority, Host},
6};
7use rama_core::Context;
8use rama_core::error::OpaqueError;
9use rama_http_types::Method;
10use rama_http_types::{Request, Uri, Version, dep::http::request::Parts};
11
12#[cfg(feature = "tls")]
13use crate::tls::SecureTransport;
14
/// Extracts the host advertised through the SNI (`server_name`) extension
/// of the TLS client hello recorded for this transport, if any.
#[cfg(feature = "tls")]
fn try_get_host_from_secure_transport(t: &SecureTransport) -> Option<Host> {
    use crate::tls::client::ClientHelloExtension;

    let hello = t.client_hello()?;
    hello.extensions().iter().find_map(|extension| {
        if let ClientHelloExtension::ServerName(server_name) = extension {
            server_name.clone()
        } else {
            None
        }
    })
}
26
27#[cfg(not(feature = "tls"))]
28#[derive(Debug, Clone)]
29#[non_exhaustive]
30struct SecureTransport;
31
32#[cfg(not(feature = "tls"))]
33fn try_get_host_from_secure_transport(_: &SecureTransport) -> Option<Host> {
34    None
35}
36
37#[derive(Debug, Clone, PartialEq, Eq)]
38/// The context of the [`Request`].
39pub struct RequestContext {
40    /// The HTTP Version.
41    pub http_version: Version,
42    /// The [`Protocol`] of the [`Request`].
43    pub protocol: Protocol,
44    /// The authority of the [`Request`].
45    ///
46    /// In http/1.1 this is typically defined by the `Host` header,
47    /// whereas for h2 and h3 this is found in the pseudo `:authority` header.
48    ///
49    /// This can be also manually set in case there is support for
50    /// forward headers (e.g. `Forwarded`, or `X-Forwarded-Host`)
51    /// or forward protocols (e.g. `HaProxy`).
52    pub authority: Authority,
53}
54
55impl RequestContext {
56    /// Check if [`Authority`] is using the default port for the [`Protocol`] set in this [`RequestContext`]
57    pub fn authority_has_default_port(&self) -> bool {
58        self.protocol.default_port() == Some(self.authority.port())
59    }
60}
61
62impl<Body, State> TryFrom<(&Context<State>, &Request<Body>)> for RequestContext {
63    type Error = OpaqueError;
64
65    fn try_from((ctx, req): (&Context<State>, &Request<Body>)) -> Result<Self, Self::Error> {
66        let uri = req.uri();
67
68        let protocol = protocol_from_uri_or_context(ctx, uri, req.method());
69        tracing::trace!(
70            uri = %uri, "request context: detected protocol: {protocol} (scheme: {:?})",
71            uri.scheme()
72        );
73
74        let default_port = uri
75            .port_u16()
76            .unwrap_or_else(|| protocol.default_port().unwrap_or(80));
77        tracing::trace!(uri = %uri, "request context: detected default port: {default_port}");
78
79        let authority = match ctx.get().and_then(try_get_host_from_secure_transport) {
80            Some(h) => {
81                tracing::trace!(uri = %uri, host = %h, "request context: detected host from SNI");
82                (h, default_port).into()
83            },
84            None => uri
85                .host()
86                .and_then(|h| Host::try_from(h).ok().map(|h| {
87                    tracing::trace!(uri = %uri, host = %h, "request context: detected host from (abs) uri");
88                    (h, default_port).into()
89                }))
90                .or_else(|| {
91                    ctx.get::<Forwarded>().and_then(|f| {
92                        f.client_host().map(|fauth| {
93                            let (host, port) = fauth.clone().into_parts();
94                            let port = port.unwrap_or(default_port);
95                            tracing::trace!(uri = %uri, host = %host, "request context: detected host from forwarded info");
96                            (host, port).into()
97                        })
98                    })
99                })
100                .or_else(|| {
101                    req.headers()
102                        .get(rama_http_types::header::HOST)
103                        .and_then(|host| {
104                            host.try_into() // try to consume as Authority, otherwise as Host
105                                .or_else(|_| Host::try_from(host).map(|h| {
106                                    tracing::trace!(uri = %uri, host = %h, "request context: detected host from host header");
107                                    (h, default_port).into()
108                                }))
109                                .ok()
110                        })
111                })
112                .ok_or_else(|| {
113                    OpaqueError::from_display("RequestContext: no authourity found in http::Request")
114                })?
115        };
116
117        tracing::trace!(uri = %uri, "request context: detected authority: {authority}");
118
119        let http_version = ctx
120            .get::<Forwarded>()
121            .and_then(|f| {
122                f.client_version().map(|v| match v {
123                    crate::forwarded::ForwardedVersion::HTTP_09 => Version::HTTP_09,
124                    crate::forwarded::ForwardedVersion::HTTP_10 => Version::HTTP_10,
125                    crate::forwarded::ForwardedVersion::HTTP_11 => Version::HTTP_11,
126                    crate::forwarded::ForwardedVersion::HTTP_2 => Version::HTTP_2,
127                    crate::forwarded::ForwardedVersion::HTTP_3 => Version::HTTP_3,
128                })
129            })
130            .unwrap_or_else(|| req.version());
131        tracing::trace!(uri = %uri, "request context: maybe detected http version: {http_version:?}");
132
133        Ok(RequestContext {
134            http_version,
135            protocol,
136            authority,
137        })
138    }
139}
140
141impl<State> TryFrom<(&Context<State>, &Parts)> for RequestContext {
142    type Error = OpaqueError;
143
144    fn try_from((ctx, parts): (&Context<State>, &Parts)) -> Result<Self, Self::Error> {
145        let uri = &parts.uri;
146
147        let protocol = protocol_from_uri_or_context(ctx, uri, &parts.method);
148        tracing::trace!(
149            uri = %uri, "request context: detected protocol: {protocol} (scheme: {:?})",
150            uri.scheme()
151        );
152
153        let default_port = uri
154            .port_u16()
155            .unwrap_or_else(|| protocol.default_port().unwrap_or(80));
156        tracing::trace!(uri = %uri, "request context: detected default port: {default_port}");
157
158        let authority = match ctx.get().and_then(try_get_host_from_secure_transport) {
159            Some(h) => {
160                tracing::trace!(uri = %uri, host = %h, "request context: detected host from SNI");
161                (h, default_port).into()
162            }
163            None => {
164                uri
165                    .host()
166                    .and_then(|h| Host::try_from(h).ok().map(|h| {
167                        tracing::trace!(uri = %uri, host = %h, "request context: detected host from (abs) uri");
168                        (h, default_port).into()
169                    }))
170                    .or_else(|| {
171                        ctx.get::<Forwarded>().and_then(|f| {
172                            f.client_host().map(|fauth| {
173                                let (host, port) = fauth.clone().into_parts();
174                                let port = port.unwrap_or(default_port);
175                                tracing::trace!(uri = %uri, host = %host, "request context: detected host from forwarded info");
176                                (host, port).into()
177                            })
178                        })
179                    })
180                    .or_else(|| {
181                        parts
182                            .headers
183                            .get(rama_http_types::header::HOST)
184                            .and_then(|host| {
185                                host.try_into() // try to consume as Authority, otherwise as Host
186                                    .or_else(|_| Host::try_from(host).map(|h| {
187                                        tracing::trace!(uri = %uri, host = %h, "request context: detected host from host header");
188                                        (h, default_port).into()
189                                    }))
190                                    .ok()
191                            })
192                    })
193                    .ok_or_else(|| {
194                        OpaqueError::from_display(
195                            "RequestContext: no authourity found in http::request::Parts",
196                        )
197                    })?
198            }
199        };
200
201        tracing::trace!(uri = %uri, "request context: detected authority: {authority}");
202
203        let http_version = ctx
204            .get::<Forwarded>()
205            .and_then(|f| {
206                f.client_version().map(|v| match v {
207                    crate::forwarded::ForwardedVersion::HTTP_09 => Version::HTTP_09,
208                    crate::forwarded::ForwardedVersion::HTTP_10 => Version::HTTP_10,
209                    crate::forwarded::ForwardedVersion::HTTP_11 => Version::HTTP_11,
210                    crate::forwarded::ForwardedVersion::HTTP_2 => Version::HTTP_2,
211                    crate::forwarded::ForwardedVersion::HTTP_3 => Version::HTTP_3,
212                })
213            })
214            .unwrap_or(parts.version);
215        tracing::trace!(uri = %uri, "request context: maybe detected http version: {http_version:?}");
216
217        Ok(RequestContext {
218            http_version,
219            protocol,
220            authority,
221        })
222    }
223}
224
225#[allow(clippy::unnecessary_lazy_evaluations)]
226fn protocol_from_uri_or_context<State>(
227    ctx: &Context<State>,
228    uri: &Uri,
229    method: &Method,
230) -> Protocol {
231    Protocol::maybe_from_uri_scheme_str_and_method(uri.scheme(), Some(method)).or_else(|| ctx.get::<Forwarded>()
232        .and_then(|f| f.client_proto().map(|p| {
233            tracing::trace!(uri = %uri, "request context: detected protocol from forwarded client proto");
234            p.into()
235        })))
236        .unwrap_or_else(|| {
237            if method == Method::CONNECT {
238                tracing::trace!(uri = %uri, method = %method, "request context: CONNECT: defaulting protocol to HTTPS");
239                Protocol::HTTPS
240            } else {
241                tracing::trace!(uri = %uri, method = %method, "request context: defaulting protocol to HTTP");
242                Protocol::HTTP
243            }
244        })
245}
246
247impl From<RequestContext> for TransportContext {
248    fn from(value: RequestContext) -> Self {
249        Self {
250            protocol: if value.http_version == Version::HTTP_3 {
251                TransportProtocol::Udp
252            } else {
253                TransportProtocol::Tcp
254            },
255            app_protocol: Some(value.protocol),
256            http_version: Some(value.http_version),
257            authority: value.authority,
258        }
259    }
260}
261
262impl From<&RequestContext> for TransportContext {
263    fn from(value: &RequestContext) -> Self {
264        Self {
265            protocol: if value.http_version == Version::HTTP_3 {
266                TransportProtocol::Udp
267            } else {
268                TransportProtocol::Tcp
269            },
270            app_protocol: Some(value.protocol.clone()),
271            http_version: Some(value.http_version),
272            authority: value.authority.clone(),
273        }
274    }
275}
276
277impl<State, Body> TryRefIntoTransportContext<State> for rama_http_types::Request<Body> {
278    type Error = OpaqueError;
279
280    fn try_ref_into_transport_ctx(
281        &self,
282        ctx: &Context<State>,
283    ) -> Result<TransportContext, Self::Error> {
284        (ctx, self).try_into()
285    }
286}
287
288impl<State> TryRefIntoTransportContext<State> for rama_http_types::dep::http::request::Parts {
289    type Error = OpaqueError;
290
291    fn try_ref_into_transport_ctx(
292        &self,
293        ctx: &Context<State>,
294    ) -> Result<TransportContext, Self::Error> {
295        (ctx, self).try_into()
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::forwarded::{Forwarded, ForwardedElement, NodeId};
    use rama_http_types::header::FORWARDED;
    use rama_http_types::headers::HeaderMapExt;

    #[test]
    fn test_request_context_from_request() {
        let req = Request::builder()
            .uri("http://example.com:8080")
            .version(Version::HTTP_11)
            .body(())
            .unwrap();

        let ctx = Context::default();

        let req_ctx = RequestContext::try_from((&ctx, &req)).unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, Protocol::HTTP);
        assert_eq!(req_ctx.authority.to_string(), "example.com:8080");
    }

    #[test]
    fn test_request_context_from_parts() {
        let req = Request::builder()
            .uri("http://example.com:8080")
            .version(Version::HTTP_11)
            .body(())
            .unwrap();

        let (parts, _) = req.into_parts();

        let ctx = Context::default();
        let req_ctx = RequestContext::try_from((&ctx, &parts)).unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, Protocol::HTTP);
        assert_eq!(
            req_ctx.authority,
            Authority::try_from("example.com:8080").unwrap()
        );
    }

    #[test]
    fn test_request_context_authority() {
        let ctx = RequestContext {
            http_version: Version::HTTP_11,
            protocol: Protocol::HTTP,
            authority: "example.com:8080".try_into().unwrap(),
        };

        assert_eq!(ctx.authority.to_string(), "example.com:8080");
    }

    #[test]
    fn test_request_context_host_header_with_port() {
        // An explicit port in the Host header must win over the protocol default.
        let req = Request::builder()
            .uri("/")
            .version(Version::HTTP_11)
            .header("host", "example.com:8443")
            .body(())
            .unwrap();

        let ctx = Context::default();
        let req_ctx = RequestContext::try_from((&ctx, &req)).unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, Protocol::HTTP);
        assert_eq!(req_ctx.authority.to_string(), "example.com:8443");
    }

    #[test]
    fn test_request_context_without_authority_is_error() {
        // No SNI, no URI host, no Forwarded info and no Host header.
        let req = Request::builder()
            .uri("/")
            .version(Version::HTTP_11)
            .body(())
            .unwrap();

        let ctx = Context::default();
        assert!(RequestContext::try_from((&ctx, &req)).is_err());

        let (parts, _) = req.into_parts();
        assert!(RequestContext::try_from((&ctx, &parts)).is_err());
    }

    #[test]
    fn forwarded_parsing() {
        for (forwarded_str_vec, expected) in [
            // base
            (
                vec!["host=192.0.2.60;proto=http;by=203.0.113.43"],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "192.0.2.60:80".parse().unwrap(),
                },
            ),
            // ipv6
            (
                vec!["host=\"[2001:db8:cafe::17]:4711\""],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "[2001:db8:cafe::17]:4711".parse().unwrap(),
                },
            ),
            // multiple values in one header
            (
                vec!["host=192.0.2.60, host=127.0.0.1"],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "192.0.2.60:80".parse().unwrap(),
                },
            ),
            // multiple header values
            (
                vec!["host=192.0.2.60", "host=127.0.0.1"],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "192.0.2.60:80".parse().unwrap(),
                },
            ),
        ] {
            let mut req_builder = Request::builder();
            for &header in &forwarded_str_vec {
                req_builder = req_builder.header(FORWARDED, header);
            }

            let req = req_builder.body(()).unwrap();
            let mut ctx = Context::default();

            let forwarded = req.headers().typed_get::<Forwarded>().unwrap();
            ctx.insert(forwarded);

            let req_ctx = ctx
                .get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into())
                .unwrap()
                .clone();

            assert_eq!(req_ctx, expected, "Failed for {forwarded_str_vec:?}");
        }
    }

    #[test]
    fn test_request_ctx_https_request_behind_haproxy_plain() {
        let req = Request::builder()
            .uri("/en/reservation/roomdetails")
            .version(Version::HTTP_11)
            .header("host", "echo.ramaproxy.org")
            .header("user-agent", "curl/8.6.0")
            .header("accept", "*/*")
            .body(())
            .unwrap();

        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            NodeId::try_from("127.0.0.1:61234").unwrap(),
        )));

        let req_ctx: &mut RequestContext = ctx
            .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
            .unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, "http");
        assert_eq!(req_ctx.authority.to_string(), "echo.ramaproxy.org:80");
    }

    #[test]
    fn test_request_ctx_connect_req_no_scheme() {
        let test_cases = [
            (80, Protocol::HTTPS),
            (443, Protocol::HTTPS),
            (8080, Protocol::HTTPS),
        ];
        for (port, expected_protocol) in test_cases {
            let req = Request::builder()
                .uri(format!("www.example.com:{port}"))
                .version(Version::HTTP_11)
                .method(Method::CONNECT)
                .header("host", "www.example.com")
                .header("user-agent", "test/42")
                .body(())
                .unwrap();

            let mut ctx = Context::default();
            let req_ctx: &mut RequestContext = ctx
                .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
                .unwrap();

            assert_eq!(req_ctx.http_version, Version::HTTP_11);
            assert_eq!(req_ctx.protocol, expected_protocol);
            assert_eq!(
                req_ctx.authority.to_string(),
                format!("www.example.com:{port}")
            );
        }
    }

    #[test]
    fn test_request_ctx_connect_req() {
        let test_cases = [
            ("http", Protocol::HTTPS),
            ("https", Protocol::HTTPS),
            ("ws", Protocol::WSS),
            ("wss", Protocol::WSS),
            ("ftp", Protocol::from_static("ftp")),
        ];
        for (scheme, expected_protocol) in test_cases {
            let req = Request::builder()
                .uri(format!("{scheme}://www.example.com"))
                .version(Version::HTTP_11)
                .method(Method::CONNECT)
                .header("host", "www.example.com")
                .header("user-agent", "test/42")
                .body(())
                .unwrap();

            let mut ctx = Context::default();
            let req_ctx: &mut RequestContext = ctx
                .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
                .unwrap();

            assert_eq!(req_ctx.http_version, Version::HTTP_11);
            assert_eq!(req_ctx.protocol, expected_protocol);
            assert_eq!(
                req_ctx.authority.to_string(),
                format!(
                    "www.example.com:{}",
                    expected_protocol.default_port().unwrap_or(80)
                )
            );
        }
    }
}