1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
//! Host-header allowlist middleware — Django parity for the
//! `ALLOWED_HOSTS` setting + the host-validation step
//! `SecurityMiddleware` runs implicitly before every view.
//!
//! Django gate: when `ALLOWED_HOSTS` is set, any request whose
//! `Host:` header isn't in the list is rejected with a 400. The
//! list supports exact-host entries (`example.com`), dot-prefix
//! subdomain wildcards (`.example.com` matches `api.example.com`
//! AND `example.com` itself), and the lone catch-all `*`.
//!
//! ## Quick start
//!
//! ```ignore
//! use rustango::host_validation::{AllowedHostsLayer, AllowedHostsRouterExt};
//!
//! let app = Router::new()
//! .route("/", get(home))
//! .allowed_hosts(AllowedHostsLayer::new([
//! "example.com",
//! ".example.com", // any subdomain
//! "localhost",
//! ]));
//! ```
//!
//! Requests with a missing or non-matching `Host` header receive a
//! `400 Bad Request` body that mirrors Django's `DisallowedHost`
//! message; the exact host echoes back to ease ops debugging without
//! leaking which hosts are allowed.
//!
//! ## Settings wiring
//!
//! `Settings.security.allowed_hosts: Vec<String>` (already parsed by
//! `env::list("ALLOWED_HOSTS")`) feeds the layer via
//! [`AllowedHostsLayer::from_settings_list`]. An empty list disables
//! validation entirely, so every host passes. Note that this is more
//! permissive than Django, which with `DEBUG=True` and an empty
//! `ALLOWED_HOSTS` still only accepts localhost-style hosts. rustango
//! has no DEBUG flag, so the operator must populate the list
//! explicitly to turn enforcement on.
use std::sync::Arc;
use axum::body::Body;
use axum::extract::Request;
use axum::http::{Response, StatusCode};
use axum::middleware::Next;
use axum::Router;
/// A single parsed entry from the allowed-hosts list.
///
/// Entries are lowercased once, at parse time, so checking a request
/// host is a plain string comparison that performs no allocation.
#[derive(Clone, Debug)]
enum Pattern {
    /// The lone `*` entry: accepts any host whatsoever. Use sparingly.
    Wildcard,
    /// A literal hostname; the host must equal it exactly.
    Exact(String),
    /// A `.example.com` entry, stored WITHOUT its leading dot. Accepts
    /// the base domain itself (`example.com`) as well as every host
    /// ending in `.<base>` (`api.example.com`, `a.b.example.com`).
    Subdomain(String),
}

impl Pattern {
    /// Parses one allowlist entry.
    ///
    /// Surrounding whitespace is ignored. Returns `None` for entries
    /// that carry no hostname at all (empty / whitespace-only, or a
    /// lone `.`); the caller is expected to drop those.
    fn parse(entry: &str) -> Option<Self> {
        match entry.trim() {
            // Nothing usable: blank, or a dot with no domain after it.
            "" | "." => None,
            "*" => Some(Self::Wildcard),
            trimmed => Some(match trimmed.strip_prefix('.') {
                Some(base) => Self::Subdomain(base.to_ascii_lowercase()),
                None => Self::Exact(trimmed.to_ascii_lowercase()),
            }),
        }
    }

    /// Reports whether `host` satisfies this entry.
    ///
    /// The comparison is verbatim: `host` is not lowercased or
    /// otherwise normalized here, so the caller must do that first.
    fn matches(&self, host: &str) -> bool {
        match self {
            Self::Wildcard => true,
            Self::Exact(expected) => expected.as_str() == host,
            Self::Subdomain(base) => match host.strip_suffix(base.as_str()) {
                // Nothing left over: the host IS the base domain.
                Some("") => true,
                // Something precedes the base domain; it only counts
                // as a subdomain when the boundary character is a dot
                // (so `eviltail.com` does not match `.tail.com`).
                Some(leading) => leading.ends_with('.'),
                None => false,
            },
        }
    }
}
/// Configuration for Host-header allowlisting: the parsed list of
/// allowed-host patterns. Attach it to a router with
/// [`AllowedHostsRouterExt::allowed_hosts`].
///
/// Cloning is cheap: the pattern list sits behind an [`Arc`].
#[derive(Clone)]
pub struct AllowedHostsLayer {
    patterns: Arc<Vec<Pattern>>,
}

impl AllowedHostsLayer {
    /// Creates a layer from allowed-host entries.
    ///
    /// Each entry may be an exact hostname, a `.example.com`
    /// subdomain wildcard, or the lone catch-all `*`. Entries that
    /// parse to nothing (empty / whitespace-only) are silently
    /// skipped.
    #[must_use]
    pub fn new<I, S>(entries: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let mut parsed = Vec::new();
        for entry in entries {
            if let Some(pattern) = Pattern::parse(entry.as_ref()) {
                parsed.push(pattern);
            }
        }
        Self {
            patterns: Arc::new(parsed),
        }
    }

    /// Convenience constructor for wiring from
    /// `Settings.security.allowed_hosts`; behaves exactly like
    /// [`AllowedHostsLayer::new`].
    ///
    /// An empty list yields a layer that lets every host through,
    /// i.e. "nothing configured" means "no enforcement".
    #[must_use]
    pub fn from_settings_list<I, S>(entries: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self::new(entries)
    }

    /// Returns `true` when the configured list allows `host`.
    ///
    /// `host` is the raw `Host` header value; any port is removed and
    /// the remainder lowercased before it is compared against each
    /// pattern. With no patterns configured every host is allowed
    /// (the operator opted out by leaving the setting empty).
    #[must_use]
    pub fn permits(&self, host: &str) -> bool {
        let patterns = self.patterns.as_slice();
        if patterns.is_empty() {
            return true;
        }
        let normalized = strip_port(host).to_ascii_lowercase();
        patterns
            .iter()
            .any(|pattern| pattern.matches(normalized.as_str()))
    }
}
/// Normalize a raw `Host` header value for allowlist comparison:
/// validate its shape and strip a trailing `:<port>`.
///
/// Accepted shapes:
///
/// - `name` or `name:port`, where `name` is non-empty and made up
///   only of ASCII letters, digits, `.`, `-` and `_`;
/// - `[v6]` or `[v6]:port`, where `v6` contains at least one `:` and
///   only hex digits, `:` and `.`. The brackets are kept in the
///   returned slice.
///
/// `port` must consist of ASCII digits; an empty port (`example.com:`)
/// is tolerated.
///
/// Returns the host portion (borrowed from `host`, unchanged when no
/// port is present). Returns `""` when the value is malformed, so that
/// it can never equal or suffix-match a real hostname entry.
///
/// Rejecting malformed values matters because the header is attacker
/// controlled: treating everything after a colon as a port would
/// reduce `example.com:80@evil.com` to `example.com`, and an
/// unvalidated `evil.com/.example.com` would satisfy a `.example.com`
/// suffix check. Either value, once approved, could be used to build
/// links that point at `evil.com`.
fn strip_port(host: &str) -> &str {
    // `true` when `s` is empty or is `:` followed only by ASCII digits.
    fn is_port_suffix(s: &str) -> bool {
        s.strip_prefix(':')
            .map_or(s.is_empty(), |port| port.bytes().all(|b| b.is_ascii_digit()))
    }

    if let Some(rest) = host.strip_prefix('[') {
        // Bracketed IPv6 literal: `[::1]` or `[::1]:8080`.
        let Some(end) = rest.find(']') else {
            return "";
        };
        let (addr, after) = (&rest[..end], &rest[end + 1..]);
        let addr_ok = addr.contains(':')
            && addr
                .bytes()
                .all(|b| b.is_ascii_hexdigit() || matches!(b, b':' | b'.'));
        if addr_ok && is_port_suffix(after) {
            // `end` indexes into `rest`, which starts one byte into
            // `host`; +2 therefore covers both `[` and `]`.
            return &host[..end + 2];
        }
        return "";
    }

    // Split at the FIRST colon so that any further colon lands in the
    // port candidate and fails the digits-only check.
    let (name, after) = host.split_at(host.find(':').unwrap_or(host.len()));
    // `_` is tolerated (beyond the strict hostname alphabet) because
    // it cannot act as a URL delimiter.
    let name_ok = !name.is_empty()
        && name
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'-' | b'_'));
    if name_ok && is_port_suffix(after) {
        name
    } else {
        ""
    }
}
/// Router extension trait — `.allowed_hosts(layer)`.
///
/// Lets Host-header allowlisting be attached fluently while building
/// a router. This module implements it for `axum::Router`.
pub trait AllowedHostsRouterExt {
    /// Consumes `self` and returns it with the allowlist described by
    /// `layer` attached. Marked `#[must_use]` because the configured
    /// value is the return value, not `self` mutated in place.
    #[must_use]
    fn allowed_hosts(self, layer: AllowedHostsLayer) -> Self;
}
impl<S: Clone + Send + Sync + 'static> AllowedHostsRouterExt for Router<S> {
    /// Installs a function middleware that passes every request
    /// through [`handle`] together with the shared allowlist.
    fn allowed_hosts(self, layer: AllowedHostsLayer) -> Self {
        // The closure is invoked once per request, so the
        // configuration lives behind an `Arc` and each invocation
        // takes its own handle to it.
        let shared = Arc::new(layer);
        self.layer(axum::middleware::from_fn(
            move |request: Request<Body>, next: Next| {
                let allowlist = Arc::clone(&shared);
                async move { handle(allowlist, request, next).await }
            },
        ))
    }
}
/// Middleware body: forwards the request to `next` when its host is
/// permitted by `cfg`, otherwise answers `400 Bad Request`.
///
/// The host is taken from the `Host` header. When that header is
/// absent altogether the host of the request URI is used instead:
/// HTTP/2 clients convey the authority through the `:authority`
/// pseudo-header, which surfaces in the URI rather than as a `Host`
/// header, and such requests would otherwise always be rejected. A
/// `Host` header that is present but not valid visible ASCII is
/// treated as an empty host and never falls back to the URI.
///
/// The rejection body echoes the offending host (debug-escaped) to
/// ease ops debugging, but never lists which hosts are allowed.
async fn handle(cfg: Arc<AllowedHostsLayer>, req: Request<Body>, next: Next) -> Response<Body> {
    let host = match req.headers().get(axum::http::header::HOST) {
        Some(value) => value.to_str().unwrap_or(""),
        None => req.uri().host().unwrap_or(""),
    };
    if cfg.permits(host) {
        next.run(req).await
    } else {
        let msg = format!(
            "DisallowedHost: rejected Host header {host:?} — \
             add it to Settings.security.allowed_hosts to allow"
        );
        let mut resp = Response::new(Body::from(msg));
        *resp.status_mut() = StatusCode::BAD_REQUEST;
        // The body reflects client-supplied text and contains a
        // non-ASCII dash: declare it as UTF-8 plain text so clients
        // neither sniff it as HTML nor mis-decode it.
        resp.headers_mut().insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("text/plain; charset=utf-8"),
        );
        resp
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `layer` lets every host in `hosts` through.
    fn assert_permitted(layer: &AllowedHostsLayer, hosts: &[&str]) {
        for host in hosts {
            assert!(layer.permits(host));
        }
    }

    /// Asserts that `layer` turns every host in `hosts` away.
    fn assert_rejected(layer: &AllowedHostsLayer, hosts: &[&str]) {
        for host in hosts {
            assert!(!layer.permits(host));
        }
    }

    #[test]
    fn empty_list_passes_every_host() {
        let layer = AllowedHostsLayer::new(Vec::<String>::new());
        assert_permitted(&layer, &["anywhere.example.com", ""]);
    }

    #[test]
    fn exact_match_is_case_insensitive() {
        let layer = AllowedHostsLayer::new(["Example.COM"]);
        assert_permitted(&layer, &["example.com", "EXAMPLE.com"]);
        assert_rejected(&layer, &["api.example.com"]);
    }

    #[test]
    fn dot_prefix_wildcard_matches_subdomains_plus_base() {
        let layer = AllowedHostsLayer::new([".example.com"]);
        assert_permitted(
            &layer,
            &["example.com", "api.example.com", "a.b.example.com"],
        );
        // The tricky case Django gets right: `evilexample.com` ends
        // with `example.com`, but the character at the boundary is
        // not a dot, so it is an unrelated host rather than a
        // subdomain and must be refused.
        assert_rejected(&layer, &["evilexample.com"]);
    }

    #[test]
    fn star_is_catchall() {
        let layer = AllowedHostsLayer::new(["*"]);
        assert_permitted(&layer, &["anything", "attacker.com"]);
    }

    #[test]
    fn port_is_stripped_before_comparison() {
        let layer = AllowedHostsLayer::new(["example.com"]);
        assert_permitted(&layer, &["example.com:8080", "example.com:443"]);
    }

    #[test]
    fn ipv6_with_port_is_handled() {
        let layer = AllowedHostsLayer::new(["[::1]"]);
        assert_permitted(&layer, &["[::1]:8080"]);
    }

    #[test]
    fn whitespace_entries_are_ignored() {
        let layer = AllowedHostsLayer::new(["", " ", "example.com"]);
        // Blank entries must not count: the list still holds one real
        // entry, so enforcement stays on and other hosts are refused.
        assert_permitted(&layer, &["example.com"]);
        assert_rejected(&layer, &["attacker.com"]);
    }

    #[test]
    fn rejected_host_does_not_match_other_patterns_in_list() {
        let layer = AllowedHostsLayer::new(["a.com", ".b.com", "c.com"]);
        assert_permitted(&layer, &["a.com", "foo.b.com", "c.com"]);
        assert_rejected(&layer, &["d.com", "malicious.a.com"]);
    }
}