rama_http/service/web/endpoint/extract/
authority.rs

//! Module providing the [`Authority`] extractor.
2
3use super::FromRequestContextRefPair;
4use crate::utils::macros::define_http_rejection;
5use rama_core::Context;
6use rama_http_types::dep::http::request::Parts;
7use rama_net::address;
8use rama_net::http::RequestContext;
9use rama_utils::macros::impl_deref;
10
/// Extractor that resolves the (HTTP) authority of the request.
///
/// The authority is taken from the [`RequestContext`] stored in the request's
/// [`Context`] when one is present; otherwise a [`RequestContext`] is derived
/// from the context and the request [`Parts`]. If that derivation fails,
/// extraction is rejected with [`MissingAuthority`].
#[derive(Debug, Clone)]
pub struct Authority(pub address::Authority);

// `impl_deref!` is expected to implement `Deref<Target = address::Authority>`,
// so the extractor can be used wherever the inner authority is borrowed.
impl_deref!(Authority: address::Authority);
16
define_http_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Failed to detect the Http Authority"]
    /// Rejection returned by the [`Authority`] extractor when it is unable to
    /// determine the (HTTP) authority of the request.
    ///
    /// Configured (via the attributes above) with status `BAD_REQUEST` and a
    /// fixed text body.
    pub struct MissingAuthority;
}
24
25impl<S> FromRequestContextRefPair<S> for Authority
26where
27    S: Clone + Send + Sync + 'static,
28{
29    type Rejection = MissingAuthority;
30
31    async fn from_request_context_ref_pair(
32        ctx: &Context<S>,
33        parts: &Parts,
34    ) -> Result<Self, Self::Rejection> {
35        Ok(Authority(match ctx.get::<RequestContext>() {
36            Some(ctx) => ctx.authority.clone(),
37            None => RequestContext::try_from((ctx, parts))
38                .map_err(|_| MissingAuthority)?
39                .authority
40                .clone(),
41        }))
42    }
43}
44
#[cfg(test)]
mod tests {
    use super::*;

    use crate::StatusCode;
    use crate::dep::http_body_util::BodyExt as _;
    use crate::header::X_FORWARDED_HOST;
    use crate::layer::forwarded::GetForwardedHeaderService;
    use crate::service::web::WebService;
    use crate::{Body, HeaderName, Request};
    use rama_core::Service;
    use rama_http_types::header::HOST;

    /// Sends a `GET` request for `uri` with the given `headers` through a
    /// [`WebService`] (wrapped in the `X-Forwarded-Host` header service) whose
    /// `/` route echoes the extracted [`Authority`], then asserts the response
    /// is `200 OK` with `expected_authority` as its body.
    async fn test_authority_from_request(
        uri: &str,
        expected_authority: &str,
        headers: &[(&HeaderName, &str)],
    ) {
        let svc = GetForwardedHeaderService::x_forwarded_host(
            WebService::default().get("/", async |Authority(authority): Authority| {
                authority.to_string()
            }),
        );

        let mut builder = Request::builder().method("GET").uri(uri);
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        let req = builder.body(Body::empty()).unwrap();

        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, expected_authority);
    }

    #[tokio::test]
    async fn host_header() {
        test_authority_from_request("/", "some-domain:123", &[(&HOST, "some-domain:123")]).await;
    }

    #[tokio::test]
    async fn x_forwarded_host_header() {
        test_authority_from_request(
            "/",
            "some-domain:456",
            &[(&X_FORWARDED_HOST, "some-domain:456")],
        )
        .await;
    }

    #[tokio::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        test_authority_from_request(
            "/",
            "some-domain:456",
            &[
                (&X_FORWARDED_HOST, "some-domain:456"),
                (&HOST, "some-domain:123"),
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn uri_host() {
        test_authority_from_request("http://example.com", "example.com:80", &[]).await;
    }

    /// Without a host in the URI or any host-like header, extraction must be
    /// rejected with [`MissingAuthority`], i.e. `400 Bad Request`.
    #[tokio::test]
    async fn missing_authority_is_rejected() {
        let svc = WebService::default().get("/", async |Authority(authority): Authority| {
            authority.to_string()
        });

        let req = Request::builder()
            .method("GET")
            .uri("/")
            .body(Body::empty())
            .unwrap();

        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }
}