axum_extra/extract/
host.rs

1#![allow(deprecated)]
2
3use super::rejection::{FailedToResolveHost, HostRejection};
4use axum_core::{
5    extract::{FromRequestParts, OptionalFromRequestParts},
6    RequestPartsExt,
7};
8use http::{
9    header::{HeaderMap, FORWARDED},
10    request::Parts,
11    uri::Authority,
12};
13use std::convert::Infallible;
14
/// Name of the `X-Forwarded-Host` header, consulted after `Forwarded` and before `Host` when
/// resolving the request host.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
16
/// Extractor that resolves the host of the request.
///
/// Host is resolved through the following, in order:
/// - `host` parameter of the first element of the first `Forwarded` header
/// - `X-Forwarded-Host` header
/// - `Host` header
/// - Authority of the request URI (with any `userinfo@` prefix removed)
///
/// Header values that are not visible ASCII are skipped and the next source is tried.
///
/// If no source yields a host, extraction fails with
/// [`HostRejection::FailedToResolveHost`]. Extract `Option<Host>` instead to get `None` in that
/// case rather than a rejection.
///
/// See <https://www.rfc-editor.org/rfc/rfc9110.html#name-host-and-authority> for the definition of
/// host.
///
/// Note that user agents can set the `Forwarded`, `X-Forwarded-Host` and `Host` headers to
/// arbitrary values so make sure to validate them to avoid security issues.
#[deprecated = "will be removed in the next version; see https://github.com/tokio-rs/axum/issues/3442"]
#[derive(Debug, Clone)]
pub struct Host(pub String);
33
34impl<S> FromRequestParts<S> for Host
35where
36    S: Send + Sync,
37{
38    type Rejection = HostRejection;
39
40    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
41        parts
42            .extract::<Option<Host>>()
43            .await
44            .ok()
45            .flatten()
46            .ok_or(HostRejection::FailedToResolveHost(FailedToResolveHost))
47    }
48}
49
50impl<S> OptionalFromRequestParts<S> for Host
51where
52    S: Send + Sync,
53{
54    type Rejection = Infallible;
55
56    async fn from_request_parts(
57        parts: &mut Parts,
58        _state: &S,
59    ) -> Result<Option<Self>, Self::Rejection> {
60        if let Some(host) = parse_forwarded(&parts.headers) {
61            return Ok(Some(Host(host.to_owned())));
62        }
63
64        if let Some(host) = parts
65            .headers
66            .get(X_FORWARDED_HOST_HEADER_KEY)
67            .and_then(|host| host.to_str().ok())
68        {
69            return Ok(Some(Host(host.to_owned())));
70        }
71
72        if let Some(host) = parts
73            .headers
74            .get(http::header::HOST)
75            .and_then(|host| host.to_str().ok())
76        {
77            return Ok(Some(Host(host.to_owned())));
78        }
79
80        if let Some(authority) = parts.uri.authority() {
81            return Ok(Some(Host(parse_authority(authority).to_owned())));
82        }
83
84        Ok(None)
85    }
86}
87
88#[allow(warnings)]
89fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
90    // if there are multiple `Forwarded` `HeaderMap::get` will return the first one
91    let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;
92
93    // get the first set of values
94    let first_value = forwarded_values.split(',').nth(0)?;
95
96    // find the value of the `host` field
97    first_value.split(';').find_map(|pair| {
98        let (key, value) = pair.split_once('=')?;
99        key.trim()
100            .eq_ignore_ascii_case("host")
101            .then(|| value.trim().trim_matches('"'))
102    })
103}
104
105fn parse_authority(auth: &Authority) -> &str {
106    auth.as_str()
107        .rsplit('@')
108        .next()
109        .expect("split always has at least 1 item")
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestClient;
    use axum::{routing::get, Router};
    use http::{header::HeaderName, Request};

    fn test_client() -> TestClient {
        async fn host_as_body(Host(host): Host) -> String {
            host
        }

        TestClient::new(Router::new().route("/", get(host_as_body)))
    }

    #[crate::test]
    async fn host_header() {
        let original_host = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(http::header::HOST, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_header() {
        let original_host = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        let x_forwarded_host_header = "some-domain:456";
        let host_header = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .header(http::header::HOST, host_header)
            .await
            .text()
            .await;
        assert_eq!(host, x_forwarded_host_header);
    }

    #[crate::test]
    async fn forwarded_precedence_over_x_forwarded_host() {
        let host = test_client()
            .get("/")
            .header(FORWARDED, "host=some-domain:789;proto=http")
            .header(X_FORWARDED_HOST_HEADER_KEY, "some-domain:456")
            .await
            .text()
            .await;
        assert_eq!(host, "some-domain:789");
    }

    #[crate::test]
    async fn uri_host() {
        let client = test_client();
        let port = client.server_port();
        let host = client.get("/").await.text().await;
        assert_eq!(host, format!("127.0.0.1:{port}"));
    }

    #[crate::test]
    async fn ip4_uri_host() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
        let host = parts.extract::<Host>().await.unwrap();
        assert_eq!(host.0, "127.0.0.1:1234");
    }

    #[crate::test]
    async fn ip6_uri_host() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap();
        let host = parts.extract::<Host>().await.unwrap();
        assert_eq!(host.0, "[::1]:456");
    }

    #[crate::test]
    async fn missing_host() {
        let mut parts = Request::new(()).into_parts().0;
        let host = parts.extract::<Host>().await.unwrap_err();
        assert!(matches!(host, HostRejection::FailedToResolveHost(_)));
    }

    #[crate::test]
    async fn optional_extractor() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
        let host = parts.extract::<Option<Host>>().await.unwrap();
        assert!(host.is_some());
    }

    #[crate::test]
    async fn optional_extractor_none() {
        let mut parts = Request::new(()).into_parts().0;
        let host = parts.extract::<Option<Host>>().await.unwrap();
        assert!(host.is_none());
    }

    #[test]
    fn forwarded_parsing() {
        // the basic case
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // is case insensitive (the key is upper case here on purpose)
        let headers = header_map(&[(FORWARDED, "HOST=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // `host` is not the first pair and has surrounding whitespace
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43; host=example.com ;proto=https")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "example.com");

        // ipv6
        let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "[2001:db8:cafe::17]:4711");

        // multiple values in one header
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // multiple header values
        let headers = header_map(&[
            (FORWARDED, "host=192.0.2.60"),
            (FORWARDED, "host=127.0.0.1"),
        ]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // first element has no `host` parameter
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43;proto=https")]);
        assert!(parse_forwarded(&headers).is_none());

        // no `Forwarded` header at all
        assert!(parse_forwarded(&HeaderMap::new()).is_none());
    }

    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for (key, value) in values {
            headers.append(key, value.parse().unwrap());
        }
        headers
    }
}