use std::collections::HashSet;
use std::sync::Arc;
use axum::extract::Request;
use axum::http::{StatusCode, header::HOST};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use serde_json::json;
use crate::config::RoutingConfig;
/// The set of hostnames configured across the routing surfaces, plus the
/// strictness flag that decides what happens to requests for any other host.
///
/// Cloning is cheap: the host set lives behind an [`Arc`], so every clone
/// shares the same allocation.
#[derive(Debug, Clone)]
pub struct HostMap {
    /// Configured hosts, stored lower-cased so lookups are case-insensitive.
    hosts: Arc<HashSet<String>>,
    /// Copied from `RoutingConfig::subdomain_mode_strict`.
    strict: bool,
}

impl HostMap {
    /// Collects the `host` of the `api`, `admin_ui` and `website` surfaces
    /// (skipping any that are unset) into a lower-cased lookup set.
    pub fn from_routing(routing: &RoutingConfig) -> Self {
        let hosts: HashSet<String> = [&routing.api, &routing.admin_ui, &routing.website]
            .iter()
            .filter_map(|surface| surface.host.as_deref())
            .map(str::to_ascii_lowercase)
            .collect();

        Self {
            hosts: Arc::new(hosts),
            strict: routing.subdomain_mode_strict,
        }
    }

    /// Returns `true` when no surface has a host configured, in which case
    /// every host is accepted.
    pub fn is_path_mode(&self) -> bool {
        self.hosts.is_empty()
    }

    /// Reports whether `host_header` names one of the configured hosts.
    ///
    /// The comparison is ASCII case-insensitive. The value is first compared
    /// verbatim, so a host that was configured *with* a port keeps matching
    /// exactly; if that fails, a trailing `:port` and a trailing root dot are
    /// removed and the lookup is retried, because clients legitimately send
    /// `Host: name:8443` or `Host: name.` for the same site.
    ///
    /// Always returns `true` in path mode.
    fn matches(&self, host_header: &str) -> bool {
        if self.is_path_mode() {
            return true;
        }

        let candidate = host_header.to_ascii_lowercase();
        if self.hosts.contains(&candidate) {
            return true;
        }

        let bare = Self::strip_port(&candidate);
        let bare = bare.strip_suffix('.').unwrap_or(bare);
        // Skip the second lookup when normalisation changed nothing.
        bare.len() < candidate.len() && self.hosts.contains(bare)
    }

    /// Returns `authority` without a trailing `:port`, or unchanged when it
    /// carries no well-formed (all-digit) port.
    fn strip_port(authority: &str) -> &str {
        let is_port = |text: &str| text.bytes().all(|b| b.is_ascii_digit());

        if authority.starts_with('[') {
            // Bracketed IPv6 literal: the port, if any, follows the `]`.
            return match authority.find(']') {
                Some(close) => {
                    let (host, rest) = authority.split_at(close + 1);
                    match rest.strip_prefix(':') {
                        Some(port) if is_port(port) => host,
                        _ => authority,
                    }
                }
                None => authority,
            };
        }

        match authority.rsplit_once(':') {
            // A second colon means this is not a plain `host:port` pair.
            Some((host, port)) if !host.contains(':') && is_port(port) => host,
            _ => authority,
        }
    }
}
/// Middleware that rejects requests addressed to a host that is not
/// configured in the [`HostMap`].
///
/// * In path mode (no hosts configured) every request is forwarded untouched.
/// * Otherwise the request host is taken from the `Host` header, or, when that
///   header is absent, from the request URI. HTTP/2 clients send the
///   authority as `:authority`, which surfaces in the URI rather than as a
///   `Host` header, so without this fallback they would always be rejected.
/// * A recognised host, or any host when `strict` is off, is forwarded to
///   `next`.
/// * Anything else gets `404 Not Found` with a JSON body of the form
///   `{"error": "HostNotRecognised", "host": <host or null>}`.
pub async fn enforce(
    axum::extract::State(map): axum::extract::State<HostMap>,
    request: Request,
    next: Next,
) -> Response {
    if map.is_path_mode() {
        return next.run(request).await;
    }

    // Owned copy so the value can still be reported after `request` has been
    // handed to `next`. A present-but-unreadable `Host` header deliberately
    // does NOT fall back to the URI: it stays `None` and is treated as
    // unrecognised.
    let host: Option<String> = match request.headers().get(HOST) {
        Some(value) => value.to_str().ok().map(str::to_owned),
        None => request.uri().host().map(str::to_owned),
    };

    let recognised = matches!(host.as_deref(), Some(name) if map.matches(name));
    if recognised || !map.strict {
        return next.run(request).await;
    }

    let body = json!({
        "error": "HostNotRecognised",
        "host": host,
    });
    (StatusCode::NOT_FOUND, axum::Json(body)).into_response()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::MountConfig;

    /// Builds one routing surface mounted at `mount` with an optional host.
    fn surface(mount: &'static str, host: Option<&str>) -> MountConfig {
        MountConfig {
            mount: mount.into(),
            host: host.map(String::from),
        }
    }

    /// Strict-mode routing config with the given per-surface hosts.
    fn cfg(
        api_host: Option<&str>,
        admin_host: Option<&str>,
        web_host: Option<&str>,
    ) -> RoutingConfig {
        RoutingConfig {
            api: surface("/v1", api_host),
            admin_ui: surface("/admin", admin_host),
            website: surface("/", web_host),
            subdomain_mode_strict: true,
        }
    }

    #[test]
    fn path_mode_when_no_hosts_set() {
        let routing = cfg(None, None, None);
        let map = HostMap::from_routing(&routing);

        assert!(map.is_path_mode());
        // With nothing configured, any host is accepted.
        assert!(map.matches("anything.example.com"));
    }

    #[test]
    fn matches_recognises_configured_host() {
        let routing = cfg(Some("api.example.com"), Some("admin.example.com"), None);
        let map = HostMap::from_routing(&routing);

        assert!(!map.is_path_mode());
        assert!(map.matches("api.example.com"));
        // Lookup is case-insensitive.
        assert!(map.matches("ADMIN.example.com"));
        assert!(!map.matches("other.example.com"));
    }
}