// axum_extra/extract/host.rs

1use super::rejection::{FailedToResolveHost, HostRejection};
2use axum::{
3    extract::{FromRequestParts, OptionalFromRequestParts},
4    RequestPartsExt,
5};
6use http::{
7    header::{HeaderMap, FORWARDED},
8    request::Parts,
9    uri::Authority,
10};
11use std::convert::Infallible;
12
/// Name of the `X-Forwarded-Host` header consulted by [`Host`].
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Extractor that resolves the host of the request.
///
/// The host is taken from the first of these sources that yields a value:
///
/// 1. the `host` parameter of the `Forwarded` header
/// 2. the `X-Forwarded-Host` header
/// 3. the `Host` header
/// 4. the authority of the request URI (with any `userinfo@` prefix removed)
///
/// See <https://www.rfc-editor.org/rfc/rfc9110.html#name-host-and-authority> for the definition of
/// host.
///
/// Note that user agents can set the `Forwarded`, `X-Forwarded-Host` and `Host` headers to
/// arbitrary values, so make sure to validate the resolved host to avoid security issues.
#[derive(Clone, Debug)]
pub struct Host(pub String);
30
31impl<S> FromRequestParts<S> for Host
32where
33    S: Send + Sync,
34{
35    type Rejection = HostRejection;
36
37    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
38        parts
39            .extract::<Option<Host>>()
40            .await
41            .ok()
42            .flatten()
43            .ok_or(HostRejection::FailedToResolveHost(FailedToResolveHost))
44    }
45}
46
47impl<S> OptionalFromRequestParts<S> for Host
48where
49    S: Send + Sync,
50{
51    type Rejection = Infallible;
52
53    async fn from_request_parts(
54        parts: &mut Parts,
55        _state: &S,
56    ) -> Result<Option<Self>, Self::Rejection> {
57        if let Some(host) = parse_forwarded(&parts.headers) {
58            return Ok(Some(Host(host.to_owned())));
59        }
60
61        if let Some(host) = parts
62            .headers
63            .get(X_FORWARDED_HOST_HEADER_KEY)
64            .and_then(|host| host.to_str().ok())
65        {
66            return Ok(Some(Host(host.to_owned())));
67        }
68
69        if let Some(host) = parts
70            .headers
71            .get(http::header::HOST)
72            .and_then(|host| host.to_str().ok())
73        {
74            return Ok(Some(Host(host.to_owned())));
75        }
76
77        if let Some(authority) = parts.uri.authority() {
78            return Ok(Some(Host(parse_authority(authority).to_owned())));
79        }
80
81        Ok(None)
82    }
83}
84
85#[allow(warnings)]
86fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
87    // if there are multiple `Forwarded` `HeaderMap::get` will return the first one
88    let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;
89
90    // get the first set of values
91    let first_value = forwarded_values.split(',').nth(0)?;
92
93    // find the value of the `host` field
94    first_value.split(';').find_map(|pair| {
95        let (key, value) = pair.split_once('=')?;
96        key.trim()
97            .eq_ignore_ascii_case("host")
98            .then(|| value.trim().trim_matches('"'))
99    })
100}
101
102fn parse_authority(auth: &Authority) -> &str {
103    auth.as_str()
104        .rsplit('@')
105        .next()
106        .expect("split always has at least 1 item")
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestClient;
    use axum::{routing::get, Router};
    use http::{header::HeaderName, Request};

    /// Builds a client for a router whose single route echoes the resolved host as the body.
    fn test_client() -> TestClient {
        async fn host_as_body(Host(host): Host) -> String {
            host
        }

        TestClient::new(Router::new().route("/", get(host_as_body)))
    }

    #[crate::test]
    async fn host_header() {
        let original_host = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(http::header::HOST, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_header() {
        let original_host = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        let x_forwarded_host_header = "some-domain:456";
        let host_header = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .header(http::header::HOST, host_header)
            .await
            .text()
            .await;
        assert_eq!(host, x_forwarded_host_header);
    }

    #[crate::test]
    async fn forwarded_precedence_over_x_forwarded_host() {
        let host = test_client()
            .get("/")
            .header(FORWARDED, "host=some-domain:789")
            .header(X_FORWARDED_HOST_HEADER_KEY, "some-domain:456")
            .await
            .text()
            .await;
        assert_eq!(host, "some-domain:789");
    }

    #[crate::test]
    async fn uri_host() {
        let client = test_client();
        let port = client.server_port();
        let host = client.get("/").await.text().await;
        assert_eq!(host, format!("127.0.0.1:{port}"));
    }

    #[crate::test]
    async fn ip4_uri_host() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
        let host = parts.extract::<Host>().await.unwrap();
        assert_eq!(host.0, "127.0.0.1:1234");
    }

    #[crate::test]
    async fn ip6_uri_host() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap();
        let host = parts.extract::<Host>().await.unwrap();
        assert_eq!(host.0, "[::1]:456");
    }

    #[crate::test]
    async fn missing_host() {
        let mut parts = Request::new(()).into_parts().0;
        let host = parts.extract::<Host>().await.unwrap_err();
        assert!(matches!(host, HostRejection::FailedToResolveHost(_)));
    }

    #[crate::test]
    async fn optional_extractor() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
        let host = parts.extract::<Option<Host>>().await.unwrap();
        assert!(host.is_some());
    }

    #[crate::test]
    async fn optional_extractor_none() {
        let mut parts = Request::new(()).into_parts().0;
        let host = parts.extract::<Option<Host>>().await.unwrap();
        assert!(host.is_none());
    }

    #[test]
    fn forwarded_parsing() {
        // the basic case
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // the parameter name is matched case-insensitively
        let headers = header_map(&[(FORWARDED, "Host=192.0.2.60;Proto=http;By=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // ipv6
        let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "[2001:db8:cafe::17]:4711");

        // multiple values in one header
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // multiple header values
        let headers = header_map(&[
            (FORWARDED, "host=192.0.2.60"),
            (FORWARDED, "host=127.0.0.1"),
        ]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // no `host` parameter in the first element
        let headers = header_map(&[(FORWARDED, "proto=http;by=203.0.113.43")]);
        assert!(parse_forwarded(&headers).is_none());

        // no `Forwarded` header at all
        assert!(parse_forwarded(&HeaderMap::new()).is_none());
    }

    /// Builds a `HeaderMap` by appending each `(name, value)` pair in order.
    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for (key, value) in values {
            headers.append(key, value.parse().unwrap());
        }
        headers
    }
}