axum_extra/extract/
host.rs1use super::rejection::{FailedToResolveHost, HostRejection};
2use axum::{
3 extract::{FromRequestParts, OptionalFromRequestParts},
4 RequestPartsExt,
5};
6use http::{
7 header::{HeaderMap, FORWARDED},
8 request::Parts,
9 uri::Authority,
10};
11use std::convert::Infallible;
12
/// Name of the `X-Forwarded-Host` header. When resolving the request host it is consulted
/// after the `Forwarded` header and before the `Host` header.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
14
/// Extractor that resolves the host of the request.
///
/// The host is taken from the first of these sources that provides one:
///
/// 1. the `host` parameter of the first element of the first `Forwarded` header
/// 2. the `X-Forwarded-Host` header
/// 3. the `Host` header
/// 4. the authority of the request URI, with any `userinfo@` prefix removed
///
/// Header values are used as sent by the client (the `Forwarded` value is only trimmed and
/// unquoted), so a client can make this extractor return an arbitrary host. Validate the
/// result before relying on it, e.g. when building redirects or absolute links.
///
/// Extracting `Host` rejects with [`HostRejection::FailedToResolveHost`] when no source
/// yields a host; extract `Option<Host>` to receive `None` instead.
///
/// See <https://www.rfc-editor.org/rfc/rfc9110.html#name-host-and-authority> for the
/// definition of host.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Host(pub String);
30
31impl<S> FromRequestParts<S> for Host
32where
33 S: Send + Sync,
34{
35 type Rejection = HostRejection;
36
37 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
38 parts
39 .extract::<Option<Host>>()
40 .await
41 .ok()
42 .flatten()
43 .ok_or(HostRejection::FailedToResolveHost(FailedToResolveHost))
44 }
45}
46
47impl<S> OptionalFromRequestParts<S> for Host
48where
49 S: Send + Sync,
50{
51 type Rejection = Infallible;
52
53 async fn from_request_parts(
54 parts: &mut Parts,
55 _state: &S,
56 ) -> Result<Option<Self>, Self::Rejection> {
57 if let Some(host) = parse_forwarded(&parts.headers) {
58 return Ok(Some(Host(host.to_owned())));
59 }
60
61 if let Some(host) = parts
62 .headers
63 .get(X_FORWARDED_HOST_HEADER_KEY)
64 .and_then(|host| host.to_str().ok())
65 {
66 return Ok(Some(Host(host.to_owned())));
67 }
68
69 if let Some(host) = parts
70 .headers
71 .get(http::header::HOST)
72 .and_then(|host| host.to_str().ok())
73 {
74 return Ok(Some(Host(host.to_owned())));
75 }
76
77 if let Some(authority) = parts.uri.authority() {
78 return Ok(Some(Host(parse_authority(authority).to_owned())));
79 }
80
81 Ok(None)
82 }
83}
84
85#[allow(warnings)]
86fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
87 let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;
89
90 let first_value = forwarded_values.split(',').nth(0)?;
92
93 first_value.split(';').find_map(|pair| {
95 let (key, value) = pair.split_once('=')?;
96 key.trim()
97 .eq_ignore_ascii_case("host")
98 .then(|| value.trim().trim_matches('"'))
99 })
100}
101
102fn parse_authority(auth: &Authority) -> &str {
103 auth.as_str()
104 .rsplit('@')
105 .next()
106 .expect("split always has at least 1 item")
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestClient;
    use axum::{routing::get, Router};
    use http::{header::HeaderName, Request};

    /// Client for a router whose `/` handler responds with the extracted host.
    fn test_client() -> TestClient {
        async fn host_as_body(Host(host): Host) -> String {
            host
        }

        TestClient::new(Router::new().route("/", get(host_as_body)))
    }

    #[crate::test]
    async fn host_header() {
        let original_host = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(http::header::HOST, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_header() {
        let original_host = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        let x_forwarded_host_header = "some-domain:456";
        let host_header = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .header(http::header::HOST, host_header)
            .await
            .text()
            .await;
        assert_eq!(host, x_forwarded_host_header);
    }

    #[crate::test]
    async fn forwarded_precedence_over_x_forwarded_host_header() {
        let host = test_client()
            .get("/")
            .header(FORWARDED, "host=\"some-domain:789\";proto=http")
            .header(X_FORWARDED_HOST_HEADER_KEY, "some-domain:456")
            .header(http::header::HOST, "some-domain:123")
            .await
            .text()
            .await;
        assert_eq!(host, "some-domain:789");
    }

    #[crate::test]
    async fn uri_host() {
        let client = test_client();
        let port = client.server_port();
        let host = client.get("/").await.text().await;
        assert_eq!(host, format!("127.0.0.1:{port}"));
    }

    #[crate::test]
    async fn ip4_uri_host() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
        let host = parts.extract::<Host>().await.unwrap();
        assert_eq!(host.0, "127.0.0.1:1234");
    }

    #[crate::test]
    async fn ip6_uri_host() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap();
        let host = parts.extract::<Host>().await.unwrap();
        assert_eq!(host.0, "[::1]:456");
    }

    #[crate::test]
    async fn missing_host() {
        let mut parts = Request::new(()).into_parts().0;
        let host = parts.extract::<Host>().await.unwrap_err();
        assert!(matches!(host, HostRejection::FailedToResolveHost(_)));
    }

    #[crate::test]
    async fn optional_extractor() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
        let host = parts.extract::<Option<Host>>().await.unwrap();
        assert!(host.is_some());
    }

    #[crate::test]
    async fn optional_extractor_none() {
        let mut parts = Request::new(()).into_parts().0;
        let host = parts.extract::<Option<Host>>().await.unwrap();
        assert!(host.is_none());
    }

    #[test]
    fn forwarded_parsing() {
        // the basic case
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // the parameter name is matched case-insensitively
        let headers = header_map(&[(FORWARDED, "HOST=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // `host` need not be the first pair, and whitespace around pairs is ignored
        let headers = header_map(&[(FORWARDED, "proto=http; host=192.0.2.60")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // quoted values are unquoted
        let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "[2001:db8:cafe::17]:4711");

        // only the first element of a comma-separated list is used
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // only the first `Forwarded` header is used
        let headers = header_map(&[
            (FORWARDED, "host=192.0.2.60"),
            (FORWARDED, "host=127.0.0.1"),
        ]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // a first element without `host` yields nothing, even if a later element has one
        let headers = header_map(&[(FORWARDED, "proto=http;by=203.0.113.43, host=127.0.0.1")]);
        assert_eq!(parse_forwarded(&headers), None);

        // no `Forwarded` header at all
        assert_eq!(parse_forwarded(&HeaderMap::new()), None);
    }

    /// Builds a `HeaderMap` from `(name, value)` pairs, appending repeated names in order.
    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for (key, value) in values {
            headers.append(key, value.parse().unwrap());
        }
        headers
    }
}