use std::collections::HashSet;
use std::sync::Arc;
use axum::extract::Request;
use axum::http::{StatusCode, header::HOST};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use serde_json::json;
use crate::config::RoutingConfig;
/// Infrastructure endpoints exempt from surface isolation: in strict host mode they are
/// allowed on every recognised host, whichever surface would otherwise own the path.
/// Matching is exact (no prefix matching).
const INFRA_PATHS: &[&str] = &[
    "/health",
    "/openapi.json",
    "/.well-known/did.jsonl",
];
/// One routable surface: the path prefix it is mounted at and, optionally, the host it is
/// pinned to.
#[derive(Debug, Clone)]
struct Surface {
    /// Path prefix the surface is mounted at; `/` mounts it at the root.
    mount: String,
    /// Host the surface is pinned to, or `None` if it may be served on any host.
    host: Option<String>,
    /// Tie-break rank between equally specific mounts; lower values win.
    priority: u8,
}
impl Surface {
    /// Returns the specificity of this surface's mount for `path`, or `None` if the mount
    /// does not cover `path`.
    ///
    /// A mount covers `path` when `path` equals it or continues it at a `/` boundary. So `/v1`
    /// covers `/v1`, `/v1/` and `/v1/acl`, but not `/v10`. Trailing slashes on the mount are
    /// ignored, so `/admin/` behaves like `/admin`. A root mount (`/`, or empty) covers every
    /// path and reports 1, so a more specific mount such as `/v1` outranks it.
    fn match_len(&self, path: &str) -> Option<usize> {
        // Normalise `/admin/` to `/admin`. Without this a trailing-slash mount would only
        // match paths containing `//`.
        let mount = self.mount.trim_end_matches('/');
        if mount.is_empty() {
            return Some(1);
        }
        // `strip_prefix` avoids allocating `format!("{mount}/")` on every request.
        let rest = path.strip_prefix(mount)?;
        (rest.is_empty() || rest.starts_with('/')).then_some(mount.len())
    }
}
/// Host-to-surface routing table consulted by the `enforce` middleware.
///
/// The collections sit behind `Arc`, so cloning a `HostMap` (e.g. as middleware state) only
/// bumps reference counts.
#[derive(Debug, Clone)]
pub struct HostMap {
    /// Every configured surface host, lowercased. Empty means path mode.
    hosts: Arc<HashSet<String>>,
    /// Surfaces in priority order: api (0), admin UI (1), website (2).
    surfaces: Arc<Vec<Surface>>,
    /// Copied from `RoutingConfig::subdomain_mode_strict`.
    strict: bool,
}
impl HostMap {
    /// Builds the routing table from `routing`.
    ///
    /// Surface hosts are trimmed and ASCII-lowercased. A blank host (`""` or whitespace only)
    /// is treated as "no host". Otherwise it would switch the map into host mode and pin the
    /// surface to a host that no request could ever carry.
    pub fn from_routing(routing: &RoutingConfig) -> Self {
        let mut hosts = HashSet::new();
        let mut surfaces = Vec::with_capacity(3);
        let configured = [&routing.api, &routing.admin_ui, &routing.website];
        for (priority, surface) in (0u8..).zip(configured) {
            let host = surface
                .host
                .as_deref()
                .map(str::trim)
                .filter(|h| !h.is_empty())
                .map(str::to_ascii_lowercase);
            if let Some(h) = &host {
                hosts.insert(h.clone());
            }
            surfaces.push(Surface {
                mount: surface.mount.clone(),
                host,
                priority,
            });
        }
        Self {
            hosts: Arc::new(hosts),
            surfaces: Arc::new(surfaces),
            strict: routing.subdomain_mode_strict,
        }
    }
    /// `true` when no surface is pinned to a host, i.e. routing is purely by path.
    pub fn is_path_mode(&self) -> bool {
        self.hosts.is_empty()
    }
    /// Whether `host_header` is one of the configured hosts (ASCII case-insensitive).
    /// Always `true` in path mode.
    fn matches(&self, host_header: &str) -> bool {
        if self.is_path_mode() {
            return true;
        }
        self.hosts.contains(&host_header.to_ascii_lowercase())
    }
    /// Whether `surface` is pinned to `host`, compared ASCII case-insensitively like `matches`.
    fn pinned_to(surface: &Surface, host: &str) -> bool {
        matches!(surface.host.as_deref(), Some(h) if h.eq_ignore_ascii_case(host))
    }
    /// Picks the surface that owns `path` when requested on `req_host`.
    ///
    /// The longest matching mount wins. Among equally long mounts, a surface pinned to
    /// `req_host` wins. Remaining ties go to the lower `priority` value.
    fn target_surface(&self, path: &str, req_host: &str) -> Option<&Surface> {
        self.surfaces
            .iter()
            .filter_map(|s| s.match_len(path).map(|len| (s, len)))
            .max_by(|(a, alen), (b, blen)| {
                alen.cmp(blen)
                    .then_with(|| {
                        Self::pinned_to(a, req_host).cmp(&Self::pinned_to(b, req_host))
                    })
                    // Reversed: the smaller priority number must compare as "greater".
                    .then_with(|| b.priority.cmp(&a.priority))
            })
            .map(|(s, _)| s)
    }
    /// Whether `path` may be served on `req_host`.
    ///
    /// Infrastructure paths are allowed everywhere. Otherwise the owning surface decides: a
    /// surface pinned to a host is only served on that host. Paths owned by no surface, or by
    /// an unpinned surface, are allowed on every host.
    fn surface_allowed(&self, path: &str, req_host: &str) -> bool {
        if INFRA_PATHS.contains(&path) {
            return true;
        }
        match self.target_surface(path, req_host) {
            Some(s) if s.host.is_some() => Self::pinned_to(s, req_host),
            _ => true,
        }
    }
}
/// Axum middleware enforcing host-based surface isolation.
///
/// - Path mode (no surface pinned to a host) and non-strict host mode: every request passes
///   through untouched.
/// - Strict host mode:
///   - a request whose host is missing or not configured gets `404` with a
///     `HostNotRecognised` JSON body;
///   - a request whose path belongs to a surface pinned to a different host gets `404` with a
///     `SurfaceNotOnHost` JSON body;
///   - everything else passes through.
///
/// The request host is read from the `Host` header. If that header is absent, it falls back to
/// the URI authority, which is where HTTP/2's `:authority` ends up. A host is matched
/// verbatim first, then with any trailing `:port` removed. So `api.example.com:8443` is served
/// as `api.example.com` unless the configuration itself names the port.
pub async fn enforce(
    axum::extract::State(map): axum::extract::State<HostMap>,
    request: Request,
    next: Next,
) -> Response {
    // Non-strict host mode never rejects anything, so it short-circuits like path mode.
    if map.is_path_mode() || !map.strict {
        return next.run(request).await;
    }
    let host = request_host(&request);
    let Some(req_host) = host.as_deref().and_then(|h| recognised_host(&map, h)) else {
        return not_recognised(host);
    };
    let path = request.uri().path();
    if map.surface_allowed(path, &req_host) {
        return next.run(request).await;
    }
    let body = json!({
        "error": "SurfaceNotOnHost",
        "host": host,
        "path": request.uri().path(),
    });
    (StatusCode::NOT_FOUND, axum::Json(body)).into_response()
}
/// Raw host the client addressed.
///
/// Uses the `Host` header when present; `None` if it is not visible ASCII. Without a `Host`
/// header it falls back to the URI authority, because HTTP/2 requests carry the host in the
/// `:authority` pseudo-header, which is exposed on the URI rather than as a header.
fn request_host(request: &Request) -> Option<String> {
    match request.headers().get(HOST) {
        Some(value) => value.to_str().ok().map(str::to_owned),
        None => request.uri().authority().map(|a| a.as_str().to_owned()),
    }
}
/// Returns the lowercased form of `raw` under which `map` recognises it, or `None`.
///
/// The host is tried verbatim first, so configured hosts that include a port keep working.
/// It is then tried with a trailing all-digit `:port` removed.
fn recognised_host(map: &HostMap, raw: &str) -> Option<String> {
    let host = raw.to_ascii_lowercase();
    if map.matches(&host) {
        return Some(host);
    }
    let (name, port) = host.rsplit_once(':')?;
    let is_port = !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit());
    (is_port && map.matches(name)).then(|| name.to_owned())
}
/// Builds the `404 Not Found` response for a request whose host is not configured.
///
/// The JSON body echoes the host the request carried, or `null` when it carried none.
fn not_recognised(host: Option<String>) -> Response {
    let mut response = axum::Json(json!({
        "error": "HostNotRecognised",
        "host": host,
    }))
    .into_response();
    *response.status_mut() = StatusCode::NOT_FOUND;
    response
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::MountConfig;

    /// Strict routing with the standard mounts (`/v1`, `/admin`, `/`) and the given hosts.
    fn cfg(
        api_host: Option<&str>,
        admin_host: Option<&str>,
        web_host: Option<&str>,
    ) -> RoutingConfig {
        let mount_at = |mount: &str, host: Option<&str>| MountConfig {
            mount: mount.into(),
            host: host.map(String::from),
        };
        RoutingConfig {
            api: mount_at("/v1", api_host),
            admin_ui: mount_at("/admin", admin_host),
            website: mount_at("/", web_host),
            subdomain_mode_strict: true,
        }
    }

    /// Host map with every surface pinned to its own host.
    fn fully_hosted() -> HostMap {
        HostMap::from_routing(&cfg(
            Some("api.example.com"),
            Some("admin.example.com"),
            Some("example.com"),
        ))
    }

    #[test]
    fn path_mode_when_no_hosts_set() {
        let map = HostMap::from_routing(&cfg(None, None, None));
        assert!(map.is_path_mode());
        assert!(map.matches("anything.example.com"));
    }

    #[test]
    fn matches_recognises_configured_host() {
        let map = HostMap::from_routing(&cfg(
            Some("api.example.com"),
            Some("admin.example.com"),
            None,
        ));
        assert!(!map.is_path_mode());
        assert!(map.matches("api.example.com"));
        // Host comparison is ASCII case-insensitive.
        assert!(map.matches("ADMIN.example.com"));
        assert!(!map.matches("other.example.com"));
    }

    #[test]
    fn surface_isolation_routes_path_to_its_own_host() {
        let map = fully_hosted();
        let cases = [
            ("/v1/acl", "api.example.com", true),
            ("/admin/users", "admin.example.com", true),
            ("/index.html", "example.com", true),
            ("/v1/acl", "admin.example.com", false),
            ("/v1/acl", "example.com", false),
            ("/admin/users", "api.example.com", false),
            ("/index.html", "api.example.com", false),
        ];
        for (path, host, expected) in cases {
            assert_eq!(map.surface_allowed(path, host), expected, "{path} on {host}");
        }
    }

    #[test]
    fn infra_paths_answer_on_every_recognised_host() {
        let map = fully_hosted();
        for host in ["api.example.com", "admin.example.com", "example.com"] {
            for path in ["/health", "/.well-known/did.jsonl", "/openapi.json"] {
                assert!(map.surface_allowed(path, host), "{path} on host {host}");
            }
        }
    }

    #[test]
    fn admin_at_root_in_host_mode_resolves_by_request_host() {
        // Same as the fully hosted config, but the admin UI shares the root mount with the
        // website, so only the request host can tell them apart.
        let mut routing = cfg(
            Some("api.example.com"),
            Some("admin.example.com"),
            Some("example.com"),
        );
        routing.admin_ui.mount = "/".into();
        let map = HostMap::from_routing(&routing);
        assert!(map.surface_allowed("/dashboard", "admin.example.com"));
        assert!(map.surface_allowed("/dashboard", "example.com"));
        assert!(!map.surface_allowed("/v1/acl", "example.com"));
    }
}