rama_http/service/web/endpoint/extract/
host.rs1use super::FromRequestContextRefPair;
4use crate::utils::macros::define_http_rejection;
5use rama_core::Context;
6use rama_http_types::dep::http::request::Parts;
7use rama_net::address;
8use rama_net::http::RequestContext;
9use rama_utils::macros::impl_deref;
10
11#[derive(Debug, Clone)]
13pub struct Host(pub address::Host);
14
15impl_deref!(Host: address::Host);
16
define_http_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Failed to detect the Http host"]
    /// Rejection returned by the [`Host`] extractor when no host can be
    /// determined for the request (declared above with status `BAD_REQUEST`).
    pub struct MissingHost;
}
24
25impl<S> FromRequestContextRefPair<S> for Host
26where
27 S: Clone + Send + Sync + 'static,
28{
29 type Rejection = MissingHost;
30
31 async fn from_request_context_ref_pair(
32 ctx: &Context<S>,
33 parts: &Parts,
34 ) -> Result<Self, Self::Rejection> {
35 Ok(Host(match ctx.get::<RequestContext>() {
36 Some(ctx) => ctx.authority.host().clone(),
37 None => RequestContext::try_from((ctx, parts))
38 .map_err(|_| MissingHost)?
39 .authority
40 .host()
41 .clone(),
42 }))
43 }
44}
45
#[cfg(test)]
mod tests {
    use super::*;

    use crate::StatusCode;
    use crate::dep::http_body_util::BodyExt as _;
    use crate::header::X_FORWARDED_HOST;
    use crate::layer::forwarded::GetForwardedHeaderService;
    use crate::service::web::WebService;
    use crate::{Body, HeaderName, Request};
    use rama_core::Service;

    /// Sends a `GET` request for `uri`, carrying the given `headers`, through
    /// a `WebService` that echoes the extracted [`Host`]. The service is
    /// wrapped in `GetForwardedHeaderService::x_forwarded_host`. Asserts a
    /// `200 OK` response whose body equals `expected_host`.
    async fn test_host_from_request(
        uri: &str,
        expected_host: &str,
        headers: Vec<(&HeaderName, &str)>,
    ) {
        let echo_host =
            WebService::default().get("/", async |Host(host): Host| host.to_string());
        let svc = GetForwardedHeaderService::x_forwarded_host(echo_host);

        let req = headers
            .into_iter()
            .fold(
                Request::builder().method("GET").uri(uri),
                |builder, (name, value)| builder.header(name, value),
            )
            .body(Body::empty())
            .unwrap();

        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, expected_host);
    }

    /// The port in the `Host` header is stripped from the extracted host.
    #[tokio::test]
    async fn host_header() {
        test_host_from_request(
            "/",
            "some-domain",
            vec![(&rama_http_types::header::HOST, "some-domain:123")],
        )
        .await;
    }

    /// `X-Forwarded-Host` alone is enough to resolve the host.
    #[tokio::test]
    async fn x_forwarded_host_header() {
        test_host_from_request(
            "/",
            "some-domain",
            vec![(&X_FORWARDED_HOST, "some-domain:456")],
        )
        .await;
    }

    /// Both headers are present; the expected host is the same either way.
    #[tokio::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        test_host_from_request(
            "/",
            "some-domain",
            vec![
                (&X_FORWARDED_HOST, "some-domain:456"),
                (&rama_http_types::header::HOST, "some-domain:123"),
            ],
        )
        .await;
    }

    /// An absolute request URI supplies the host without any headers.
    #[tokio::test]
    async fn uri_host() {
        test_host_from_request("http://example.com", "example.com", vec![]).await;
    }
}