axum_extra/extract/
host.rs1#![allow(deprecated)]
2
3use super::rejection::{FailedToResolveHost, HostRejection};
4use axum_core::{
5 extract::{FromRequestParts, OptionalFromRequestParts},
6 RequestPartsExt,
7};
8use http::{
9 header::{HeaderMap, FORWARDED},
10 request::Parts,
11 uri::Authority,
12};
13use std::convert::Infallible;
14
/// Name of the `X-Forwarded-Host` header, consulted after `Forwarded` and before `Host`.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

#[deprecated = "will be removed in the next version; see https://github.com/tokio-rs/axum/issues/3442"]
/// Extractor that resolves the host of the request.
///
/// The host is taken from the first of these sources that yields a value:
///
/// 1. the `host` directive of the first element of the first `Forwarded` header,
/// 2. the `X-Forwarded-Host` header,
/// 3. the `Host` header,
/// 4. the authority of the request URI, with any `userinfo@` prefix removed.
///
/// Header values that cannot be read as a string are skipped. The values are
/// used exactly as they appear in the request; nothing here checks that the
/// forwarding headers were set by a proxy you trust.
///
/// Extracting `Host` directly rejects with `HostRejection::FailedToResolveHost`
/// when no source yields a host, while extracting `Option<Host>` yields `None`
/// instead.
#[derive(Debug, Clone)]
pub struct Host(pub String);

34impl<S> FromRequestParts<S> for Host
35where
36 S: Send + Sync,
37{
38 type Rejection = HostRejection;
39
40 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
41 parts
42 .extract::<Option<Host>>()
43 .await
44 .ok()
45 .flatten()
46 .ok_or(HostRejection::FailedToResolveHost(FailedToResolveHost))
47 }
48}
49
50impl<S> OptionalFromRequestParts<S> for Host
51where
52 S: Send + Sync,
53{
54 type Rejection = Infallible;
55
56 async fn from_request_parts(
57 parts: &mut Parts,
58 _state: &S,
59 ) -> Result<Option<Self>, Self::Rejection> {
60 if let Some(host) = parse_forwarded(&parts.headers) {
61 return Ok(Some(Host(host.to_owned())));
62 }
63
64 if let Some(host) = parts
65 .headers
66 .get(X_FORWARDED_HOST_HEADER_KEY)
67 .and_then(|host| host.to_str().ok())
68 {
69 return Ok(Some(Host(host.to_owned())));
70 }
71
72 if let Some(host) = parts
73 .headers
74 .get(http::header::HOST)
75 .and_then(|host| host.to_str().ok())
76 {
77 return Ok(Some(Host(host.to_owned())));
78 }
79
80 if let Some(authority) = parts.uri.authority() {
81 return Ok(Some(Host(parse_authority(authority).to_owned())));
82 }
83
84 Ok(None)
85 }
86}
87
88#[allow(warnings)]
89fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
90 let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;
92
93 let first_value = forwarded_values.split(',').nth(0)?;
95
96 first_value.split(';').find_map(|pair| {
98 let (key, value) = pair.split_once('=')?;
99 key.trim()
100 .eq_ignore_ascii_case("host")
101 .then(|| value.trim().trim_matches('"'))
102 })
103}
104
105fn parse_authority(auth: &Authority) -> &str {
106 auth.as_str()
107 .rsplit('@')
108 .next()
109 .expect("split always has at least 1 item")
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestClient;
    use axum::{routing::get, Router};
    use http::{header::HeaderName, Request};

    /// Builds a client for a router whose only handler echoes the extracted host.
    fn test_client() -> TestClient {
        async fn host_as_body(Host(host): Host) -> String {
            host
        }

        TestClient::new(Router::new().route("/", get(host_as_body)))
    }

    #[crate::test]
    async fn host_header() {
        let original_host = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(http::header::HOST, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_header() {
        let original_host = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        let x_forwarded_host_header = "some-domain:456";
        let host_header = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .header(http::header::HOST, host_header)
            .await
            .text()
            .await;
        assert_eq!(host, x_forwarded_host_header);
    }

    #[crate::test]
    async fn forwarded_precedence_over_x_forwarded_host() {
        let host = test_client()
            .get("/")
            .header(FORWARDED, "host=forwarded-domain:789")
            .header(X_FORWARDED_HOST_HEADER_KEY, "some-domain:456")
            .header(http::header::HOST, "some-domain:123")
            .await
            .text()
            .await;
        assert_eq!(host, "forwarded-domain:789");
    }

    #[crate::test]
    async fn forwarded_without_host_falls_back_to_x_forwarded_host() {
        let x_forwarded_host_header = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(FORWARDED, "for=192.0.2.43;proto=https")
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .await
            .text()
            .await;
        assert_eq!(host, x_forwarded_host_header);
    }

    #[crate::test]
    async fn uri_host() {
        let client = test_client();
        let port = client.server_port();
        let host = client.get("/").await.text().await;
        assert_eq!(host, format!("127.0.0.1:{port}"));
    }

    #[crate::test]
    async fn ip4_uri_host() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
        let host = parts.extract::<Host>().await.unwrap();
        assert_eq!(host.0, "127.0.0.1:1234");
    }

    #[crate::test]
    async fn ip6_uri_host() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap();
        let host = parts.extract::<Host>().await.unwrap();
        assert_eq!(host.0, "[::1]:456");
    }

    #[crate::test]
    async fn missing_host() {
        let mut parts = Request::new(()).into_parts().0;
        let host = parts.extract::<Host>().await.unwrap_err();
        assert!(matches!(host, HostRejection::FailedToResolveHost(_)));
    }

    #[crate::test]
    async fn optional_extractor() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
        let host = parts.extract::<Option<Host>>().await.unwrap();
        assert!(host.is_some());
    }

    #[crate::test]
    async fn optional_extractor_none() {
        let mut parts = Request::new(()).into_parts().0;
        let host = parts.extract::<Option<Host>>().await.unwrap();
        assert!(host.is_none());
    }

    #[test]
    fn forwarded_parsing() {
        // the basic case
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // directive names are matched case-insensitively
        let headers = header_map(&[(FORWARDED, "HOST=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // quoted values have their quotes stripped
        let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "[2001:db8:cafe::17]:4711");

        // only the first comma-separated element is considered
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // only the first `Forwarded` header is considered
        let headers = header_map(&[
            (FORWARDED, "host=192.0.2.60"),
            (FORWARDED, "host=127.0.0.1"),
        ]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // no `host` directive in the first element
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43;proto=https")]);
        assert_eq!(parse_forwarded(&headers), None);

        // no `Forwarded` header at all
        assert_eq!(parse_forwarded(&HeaderMap::new()), None);
    }

    /// Builds a `HeaderMap`, appending (not replacing) repeated header names.
    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for (key, value) in values {
            headers.append(key, value.parse().unwrap());
        }
        headers
    }
}