rama_http/service/web/endpoint/extract/
host.rs

//! Module providing the [`Host`] extractor.
2
3use super::FromRequestContextRefPair;
4use crate::utils::macros::define_http_rejection;
5use rama_core::Context;
6use rama_http_types::dep::http::request::Parts;
7use rama_net::address;
8use rama_net::http::RequestContext;
9use rama_utils::macros::impl_deref;
10
11/// Extractor that resolves the hostname of the request.
12#[derive(Debug, Clone)]
13pub struct Host(pub address::Host);
14
15impl_deref!(Host: address::Host);
16
define_http_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Failed to detect the Http host"]
    /// Rejection type used if the [`Host`] extractor is unable to
    /// determine the (HTTP) host of the request.
    ///
    /// Responds with `400 Bad Request`. It is produced when no
    /// [`RequestContext`] is stored in the [`Context`] and none can be
    /// derived from the request parts.
    pub struct MissingHost;
}
24
25impl<S> FromRequestContextRefPair<S> for Host
26where
27    S: Clone + Send + Sync + 'static,
28{
29    type Rejection = MissingHost;
30
31    async fn from_request_context_ref_pair(
32        ctx: &Context<S>,
33        parts: &Parts,
34    ) -> Result<Self, Self::Rejection> {
35        Ok(Host(match ctx.get::<RequestContext>() {
36            Some(ctx) => ctx.authority.host().clone(),
37            None => RequestContext::try_from((ctx, parts))
38                .map_err(|_| MissingHost)?
39                .authority
40                .host()
41                .clone(),
42        }))
43    }
44}
45
#[cfg(test)]
mod tests {
    use super::*;

    use crate::StatusCode;
    use crate::dep::http_body_util::BodyExt as _;
    use crate::header::X_FORWARDED_HOST;
    use crate::layer::forwarded::GetForwardedHeaderService;
    use crate::service::web::WebService;
    use crate::{Body, HeaderName, Request};
    use rama_core::Service;

    /// Serves a `GET` request for `uri`, carrying `headers`, through a [`WebService`]
    /// whose `/` endpoint echoes the extracted [`Host`].
    ///
    /// The service is wrapped in the `X-Forwarded-Host` layer so that forwarded
    /// hosts are visible to the extractor.
    ///
    /// Returns the response status together with the (lossily UTF-8 decoded) body.
    async fn serve_host_request(
        uri: &str,
        headers: &[(&HeaderName, &str)],
    ) -> (StatusCode, String) {
        let svc = GetForwardedHeaderService::x_forwarded_host(
            WebService::default().get("/", async |Host(extracted): Host| extracted.to_string()),
        );

        let mut builder = Request::builder().method("GET").uri(uri);
        for &(header, value) in headers {
            builder = builder.header(header, value);
        }
        let req = builder.body(Body::empty()).unwrap();

        let res = svc.serve(Context::default(), req).await.unwrap();
        let status = res.status();
        let body = res.into_body().collect().await.unwrap().to_bytes();
        (status, String::from_utf8_lossy(&body).into_owned())
    }

    /// Asserts that the [`Host`] extractor succeeds and resolves `expected_host`
    /// for the given request.
    async fn test_host_from_request(
        uri: &str,
        expected_host: &str,
        headers: &[(&HeaderName, &str)],
    ) {
        let (status, body) = serve_host_request(uri, headers).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, expected_host);
    }

    #[tokio::test]
    async fn host_header() {
        test_host_from_request(
            "/",
            "some-domain",
            &[(&rama_http_types::header::HOST, "some-domain:123")],
        )
        .await;
    }

    #[tokio::test]
    async fn x_forwarded_host_header() {
        test_host_from_request(
            "/",
            "some-domain",
            &[(&X_FORWARDED_HOST, "some-domain:456")],
        )
        .await;
    }

    #[tokio::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        // Distinct hosts are required, otherwise the assertion cannot tell which
        // header was actually used.
        test_host_from_request(
            "/",
            "forwarded-domain",
            &[
                (&X_FORWARDED_HOST, "forwarded-domain:456"),
                (&rama_http_types::header::HOST, "some-domain:123"),
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn uri_host() {
        test_host_from_request("http://example.com", "example.com", &[]).await;
    }

    #[tokio::test]
    async fn missing_host_is_rejected() {
        // None of the URI, a `Host` header, or an `X-Forwarded-Host` header carries
        // a host, so the extractor must reject the request.
        let (status, _) = serve_host_request("/", &[]).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }
}