use http::header::HOST;
use crate::error::{Error, Result};
use crate::middleware::{DuplicatePolicy, Middleware, Next, Request};
use crate::response::Response;
use crate::router::BoxFuture;
/// One entry of the trusted-host allow-list.
///
/// Both variants hold ASCII-lowercased text so that matching can be done
/// case-insensitively without normalising the pattern again.
enum HostPattern {
    /// Matches a host name that is equal to the stored name.
    Exact(String),
    /// Matches any host that ends with the stored suffix. The suffix keeps its
    /// leading `.` (for example `.example.com`), so only real subdomains match.
    Suffix(String),
}

/// Allow-list of host names a request may be addressed to.
pub struct TrustedHost {
    patterns: Vec<HostPattern>,
}

impl TrustedHost {
    /// Builds the allow-list from `hosts`.
    ///
    /// An entry starting with `*.` (for example `*.example.com`) becomes a
    /// wildcard that matches every subdomain of the remainder, but not the
    /// remainder itself. Any other entry must match the host exactly. Entries
    /// are ASCII-lowercased, so matching is case-insensitive.
    pub fn new<I, S>(hosts: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let patterns = hosts
            .into_iter()
            .map(|host| {
                let host = host.as_ref();
                match host.strip_prefix("*.") {
                    Some(rest) => HostPattern::Suffix(format!(".{}", rest.to_ascii_lowercase())),
                    None => HostPattern::Exact(host.to_ascii_lowercase()),
                }
            })
            .collect();
        Self { patterns }
    }

    /// Returns `true` when `host` (already without a port) matches at least
    /// one configured pattern, ignoring ASCII case.
    fn allows(&self, host: &str) -> bool {
        // Compare raw bytes: this needs no per-call lowercase allocation and
        // cannot panic on a non-ASCII char boundary when slicing the tail.
        let host = host.as_bytes();
        self.patterns.iter().any(|pattern| match pattern {
            HostPattern::Exact(exact) => host.eq_ignore_ascii_case(exact.as_bytes()),
            HostPattern::Suffix(suffix) => {
                let suffix = suffix.as_bytes();
                // Strictly longer: a wildcard needs a non-empty subdomain label,
                // so the bare suffix (".example.com") must not be accepted.
                host.len() > suffix.len()
                    && host[host.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
            }
        })
    }
}
impl Middleware for TrustedHost {
    /// Forwards the request to `next` only when it carries exactly one `Host`
    /// header whose host name (port removed) is on the allow-list.
    ///
    /// Every other request is answered with
    /// `Error::bad_request("invalid host header")`: a missing header, several
    /// `Host` header lines, a value that cannot be read as a string, an empty
    /// host name, or a host that no pattern matches.
    fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
        // Look at *all* Host values. Validating only the first one would let a
        // request through whose later Host line names an untrusted host, and
        // RFC 9112 section 3.2 requires such requests to be rejected anyway.
        let mut values = request.headers().get_all(HOST).iter();
        let host = match (values.next(), values.next()) {
            (Some(value), None) => value.to_str().ok().map(strip_port),
            _ => None,
        };
        // An empty host name is never trusted, whatever the allow-list holds.
        let allowed = matches!(host, Some(host) if !host.is_empty() && self.allows(host));
        if !allowed {
            return Box::pin(async { Err(Error::bad_request("invalid host header")) });
        }
        next.run(request)
    }

    /// Name reported for this middleware.
    fn name(&self) -> &'static str {
        "TrustedHost"
    }

    /// Always `DuplicatePolicy::Reject`.
    // NOTE(review): presumably this stops the middleware from being registered
    // twice; confirm against the `DuplicatePolicy` definition.
    fn duplicate_policy(&self) -> DuplicatePolicy {
        DuplicatePolicy::Reject
    }
}
/// Returns the host part of a `Host` header value, without the optional port.
///
/// Accepted shapes are `host`, `host:port`, `[ipv6]` and `[ipv6]:port`, where
/// `port` is a possibly empty run of ASCII digits. The bracketed form is
/// returned with its brackets.
///
/// Anything else yields the empty string: an unterminated `[`, text after `]`
/// that is not a port, or a non-numeric "port". This keeps values such as
/// `example.com:@evil.com` or `[::1]evil.com` from being reduced to a
/// trusted-looking host. The empty string cannot match a wildcard pattern,
/// because those always contain at least a `.`.
fn strip_port(value: &str) -> &str {
    let (host, rest) = if value.starts_with('[') {
        match value.find(']') {
            // `]` is a single byte, so `end + 1` is always a char boundary.
            Some(end) => value.split_at(end + 1),
            None => return "",
        }
    } else {
        match value.find(':') {
            Some(colon) => value.split_at(colon),
            None => (value, ""),
        }
    };

    // Whatever follows the host must be nothing at all, or `:` plus digits.
    let well_formed = match rest.strip_prefix(':') {
        Some(port) => port.bytes().all(|byte| byte.is_ascii_digit()),
        None => rest.is_empty(),
    };

    if well_formed {
        host
    } else {
        ""
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact entries and `*.` wildcard entries both match, ignoring ASCII
    /// case, while unrelated or merely similar hosts are refused.
    #[test]
    fn allows_exact_and_wildcard_suffix_matches() {
        let trusted = TrustedHost::new(["example.com", "*.example.com"]);

        for accepted in ["example.com", "api.example.com", "API.EXAMPLE.COM"] {
            assert!(trusted.allows(accepted), "{} should be allowed", accepted);
        }
        for refused in ["evil.com", "example.co"] {
            assert!(!trusted.allows(refused), "{} should be refused", refused);
        }
    }

    /// The port is removed from names and IPv4 addresses, and bracketed IPv6
    /// literals keep their brackets.
    #[test]
    fn strip_port_handles_names_ipv4_and_bracketed_ipv6() {
        let cases = [
            ("example.com", "example.com"),
            ("example.com:8080", "example.com"),
            ("127.0.0.1:8080", "127.0.0.1"),
            ("[::1]:8080", "[::1]"),
            ("[2001:db8::1]", "[2001:db8::1]"),
        ];

        for (input, expected) in cases {
            assert_eq!(strip_port(input), expected);
        }
    }
}