Skip to main content

axum_subdomain_routing/
lib.rs

1use axum::{
2    Router,
3    extract::Request,
4    http::{HeaderMap, StatusCode},
5    response::Response,
6};
7use std::{
8    collections::HashMap,
9    convert::Infallible,
10    future::Future,
11    pin::Pin,
12    sync::Arc,
13    task::{Context, Poll},
14};
15use tower::util::ServiceExt;
16use tower::{Layer, Service};
17use tracing::{debug, trace};
18
/// Top-level domains recognized during automatic subdomain detection.
///
/// Only a single trailing label of the host is ever compared against this list, so
/// multi-label suffixes such as `co.uk` are not recognized as a unit; configure known
/// hosts for such domains.
const KNOWN_TLDS: &[&str] = &[
    "com", "net", "org", "tr", "edu", "gov", "io", "dev", "co", "uk",
    "info", "biz", "mil", "int", "arpa", "name", "pro", "aero", "coop", "museum",
    "mobi", "asia", "tel", "cat", "jobs", "travel", "us", "ca", "de", "fr",
    "au", "jp", "cn", "ru", "br", "it", "es", "nl", "se", "no",
    "fi", "dk", "pl", "ch", "be", "at",
];

/// Header that supplies the request host in every `HostSource` mode.
const HOST_HEADER: &str = "host";
/// Header consulted (first comma-separated entry only) when `Host` is missing or malformed
/// and the `XForwardedHostFallback` host source is selected.
const X_FORWARDED_HOST_HEADER: &str = "x-forwarded-host";
28
/// Where the effective request host is read from.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum HostSource {
    /// Use only the `Host` header. This is the default.
    #[default]
    HostOnly,
    /// Prefer the `Host` header; when it is missing or malformed, use the first
    /// comma-separated entry of the `X-Forwarded-Host` header instead.
    XForwardedHostFallback,
}
35
/// Outcome of determining the effective request host.
enum ResolvedHost {
    /// A well-formed host, with any port removed and normalized (trimmed, trailing dots
    /// dropped, ASCII-lowercased).
    Present(String),
    /// None of the consulted headers was present.
    Missing,
    /// A consulted header was present but unparsable, and no other header supplied a
    /// usable host.
    Malformed,
}
41
/// Normalize a DNS-style name for comparison.
///
/// Surrounding whitespace and every trailing `.` (fully-qualified form) are removed, and the
/// remainder is ASCII-lowercased. Returns `None` when nothing is left.
fn normalize_dns_like_value(value: &str) -> Option<String> {
    let stripped = value.trim().trim_end_matches('.');
    if stripped.is_empty() {
        None
    } else {
        Some(stripped.to_ascii_lowercase())
    }
}
49
/// Strip an optional `:port` suffix from a `Host`-style value.
///
/// Accepts `name`, `name:port`, `[ipv6]` and `[ipv6]:port`, where `port` is one or more ASCII
/// digits; the brackets around an IPv6 literal are removed. The result borrows from `host`.
///
/// Returns `None` for empty input, an empty name or address, an empty or non-numeric port, an
/// unterminated `[`, anything other than `:port` after `]`, or an unbracketed value containing
/// more than one `:` (such as a bare IPv6 address).
fn extract_host_without_port(host: &str) -> Option<&str> {
    let is_port = |candidate: &str| -> bool {
        !candidate.is_empty() && candidate.bytes().all(|b| b.is_ascii_digit())
    };

    // Bracketed IPv6 literal, optionally followed by `:port`.
    if let Some(bracketed) = host.strip_prefix('[') {
        let (address, after) = bracketed.split_once(']')?;
        if address.is_empty() {
            return None;
        }
        if after.is_empty() {
            return Some(address);
        }
        return after
            .strip_prefix(':')
            .filter(|&port| is_port(port))
            .map(|_| address);
    }

    match host.rsplit_once(':') {
        // No colon at all: the whole value is the name, unless it is empty.
        None => (!host.is_empty()).then_some(host),
        // Exactly one colon, separating a non-empty name from a numeric port.
        Some((name, port)) => {
            (!name.is_empty() && !name.contains(':') && is_port(port)).then_some(name)
        }
    }
}
95
96fn parse_effective_host(value: &str, first_comma_only: bool) -> Option<String> {
97    let candidate = if first_comma_only {
98        value.split(',').next()?.trim()
99    } else {
100        value.trim()
101    };
102
103    let host = extract_host_without_port(candidate)?;
104    normalize_dns_like_value(host)
105}
106
107fn parse_host_header(
108    headers: &HeaderMap,
109    name: &str,
110    first_comma_only: bool,
111) -> Option<Result<String, ()>> {
112    let value = headers.get(name)?;
113    let text = match value.to_str() {
114        Ok(text) => text,
115        Err(_) => return Some(Err(())),
116    };
117    Some(parse_effective_host(text, first_comma_only).ok_or(()))
118}
119
120fn resolve_host(headers: &HeaderMap, source: HostSource) -> ResolvedHost {
121    let primary = parse_host_header(headers, HOST_HEADER, false);
122    match source {
123        HostSource::HostOnly => match primary {
124            Some(Ok(host)) => ResolvedHost::Present(host),
125            Some(Err(_)) => ResolvedHost::Malformed,
126            None => ResolvedHost::Missing,
127        },
128        HostSource::XForwardedHostFallback => {
129            if let Some(Ok(host)) = primary {
130                return ResolvedHost::Present(host);
131            }
132
133            let fallback = parse_host_header(headers, X_FORWARDED_HOST_HEADER, true);
134            match fallback {
135                Some(Ok(host)) => ResolvedHost::Present(host),
136                Some(Err(_)) => ResolvedHost::Malformed,
137                None => {
138                    if matches!(primary, Some(Err(_))) {
139                        ResolvedHost::Malformed
140                    } else {
141                        ResolvedHost::Missing
142                    }
143                }
144            }
145        }
146    }
147}
148
/// Canonicalize a subdomain before it is stored as a routes-map key.
///
/// Surrounding whitespace and dots at either end are removed, then the key is
/// ASCII-lowercased, so `" .API. "` becomes `"api"`.
fn normalize_subdomain_key(subdomain: &str) -> String {
    let without_space = subdomain.trim();
    let without_dots = without_space
        .trim_start_matches('.')
        .trim_end_matches('.');
    without_dots.to_ascii_lowercase()
}
152
/// Rewrite a trailing dotted-quad IPv4 address so it reads as a single DNS label.
///
/// When the last four `.`-separated labels of `host` are each a decimal octet (only ASCII
/// digits, value at most 255), the dots between them become `_`. For example,
/// `api.127.0.0.1` becomes `api.127_0_0_1` and `10.0.0.1` becomes `10_0_0_1`. Automatic
/// subdomain detection then treats the address as one label instead of mistaking its octets
/// for subdomains. Any other host is returned unchanged.
fn collapse_trailing_ipv4(host: &str) -> String {
    // `u8::from_str` alone would also accept a leading `+` (e.g. `+1`), which is never part of
    // an IPv4 octet, so require plain digits before parsing for the range check.
    fn is_octet(label: &str) -> bool {
        !label.is_empty()
            && label.bytes().all(|b| b.is_ascii_digit())
            && label.parse::<u8>().is_ok()
    }

    let labels: Vec<&str> = host.split('.').collect();
    let Some(split_at) = labels.len().checked_sub(4) else {
        return host.to_string();
    };
    let (prefix, octets) = labels.split_at(split_at);
    if !octets.iter().all(|label| is_octet(label)) {
        return host.to_string();
    }

    let collapsed = octets.join("_");
    if prefix.is_empty() {
        collapsed
    } else {
        format!("{}.{collapsed}", prefix.join("."))
    }
}
175
176fn not_found_response() -> Response {
177    Response::builder()
178        .status(StatusCode::NOT_FOUND)
179        .body(axum::body::Body::empty())
180        .unwrap()
181}
182
183/// A layer that routes requests based on the `Host` header (subdomain).
184#[derive(Clone)]
185pub struct SubdomainLayer {
186    routes: Arc<HashMap<String, Router>>,
187    strict: bool,
188    known_hosts: Arc<Vec<String>>,
189    auto_detect_domain: bool,
190    host_source: HostSource,
191}
192
193impl SubdomainLayer {
194    /// Create a new `SubdomainLayer`.
195    pub fn new() -> Self {
196        Self {
197            routes: Arc::new(HashMap::new()),
198            strict: false,
199            known_hosts: Arc::new(Vec::new()),
200            auto_detect_domain: true,
201            host_source: HostSource::default(),
202        }
203    }
204
205    /// Register a router for a specific subdomain.
206    ///
207    /// The `subdomain` argument is matched against the extracted subdomain from the `Host` header.
208    pub fn register<S: ToString>(mut self, subdomain: S, router: Router) -> Self {
209        let subdomain_str = normalize_subdomain_key(&subdomain.to_string());
210        debug!(subdomain = %subdomain_str, "Registering router for subdomain");
211        let mut routes = (*self.routes).clone();
212        routes.insert(subdomain_str, router);
213        self.routes = Arc::new(routes);
214        self
215    }
216
217    /// Enable or disable strict subdomain checking.
218    ///
219    /// When strict checking is enabled, requests to unknown subdomains will return a 404 response
220    /// instead of falling back to the main router.
221    pub fn strict(mut self, strict: bool) -> Self {
222        debug!(strict = %strict, "Setting strict mode");
223        self.strict = strict;
224        self
225    }
226
227    /// Set a list of known hosts.
228    ///
229    /// If the host ends with one of these known hosts, the suffix is removed to extract the subdomain.
230    pub fn known_hosts(mut self, hosts: Vec<String>) -> Self {
231        let normalized_hosts: Vec<String> = hosts
232            .into_iter()
233            .filter_map(|host| normalize_dns_like_value(&host))
234            .collect();
235        debug!(hosts = ?normalized_hosts, "Setting known hosts");
236        self.known_hosts = Arc::new(normalized_hosts);
237        self
238    }
239
240    /// Enable or disable automatic domain detection.
241    ///
242    /// When enabled, the layer will attempt to automatically detect and strip known TLDs.
243    pub fn auto_detect_domain(mut self, enable: bool) -> Self {
244        debug!(auto_detect = %enable, "Setting auto-detect domain mode");
245        self.auto_detect_domain = enable;
246        self
247    }
248
249    /// Configure how host should be resolved.
250    ///
251    /// By default, only the `Host` header is used.
252    pub fn host_source(mut self, host_source: HostSource) -> Self {
253        debug!(?host_source, "Setting host source mode");
254        self.host_source = host_source;
255        self
256    }
257}
258
259impl Default for SubdomainLayer {
260    fn default() -> Self {
261        Self::new()
262    }
263}
264
265impl<S> Layer<S> for SubdomainLayer {
266    type Service = SubdomainService<S>;
267
268    fn layer(&self, inner: S) -> Self::Service {
269        SubdomainService {
270            inner,
271            routes: self.routes.clone(),
272            strict: self.strict,
273            auto_detect_domain: self.auto_detect_domain,
274            known_hosts: self.known_hosts.clone(),
275            host_source: self.host_source,
276        }
277    }
278}
279
280/// Service that handles subdomain routing.
281#[derive(Clone)]
282pub struct SubdomainService<S> {
283    inner: S,
284    routes: Arc<HashMap<String, Router>>,
285    strict: bool,
286    auto_detect_domain: bool,
287    known_hosts: Arc<Vec<String>>,
288    host_source: HostSource,
289}
290
291impl<S> Service<Request> for SubdomainService<S>
292where
293    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
294    S::Future: Send + 'static,
295{
296    type Response = Response;
297    type Error = Infallible;
298    type Future = Pin<Box<dyn Future<Output = Result<Response, Infallible>> + Send + 'static>>;
299
300    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
301        self.inner.poll_ready(cx)
302    }
303
304    fn call(&mut self, req: Request) -> Self::Future {
305        let inner = self.inner.clone();
306        let routes = self.routes.clone();
307        let strict = self.strict;
308        let auto_detect_domain = self.auto_detect_domain;
309        let known_hosts = self.known_hosts.clone();
310        let host_source = self.host_source;
311
312        let resolved_host = resolve_host(req.headers(), host_source);
313
314        Box::pin(async move {
315            match resolved_host {
316                ResolvedHost::Malformed => {
317                    trace!("Malformed effective host detected");
318                    if strict {
319                        return Ok(not_found_response());
320                    }
321                    trace!(
322                        "Falling back to inner service due to malformed host in non-strict mode"
323                    );
324                    return inner.oneshot(req).await;
325                }
326                ResolvedHost::Missing => {
327                    trace!("No effective host header found");
328                    if strict {
329                        return Ok(not_found_response());
330                    }
331                    trace!("Falling back to inner service due to missing host in non-strict mode");
332                    return inner.oneshot(req).await;
333                }
334                ResolvedHost::Present(host) => {
335                    debug!(host = %host, "Processing request for host");
336                    let mut target_subdomain = None;
337
338                    // Try known hosts
339                    trace!(known_hosts = ?known_hosts.as_ref(), "Checking against known hosts");
340                    for known in known_hosts.iter() {
341                        if host.ends_with(known) {
342                            trace!(known_host = %known, "Host matches known host");
343                            let remainder_len = host.len() - known.len();
344                            if remainder_len > 0 && host.as_bytes()[remainder_len - 1] == b'.' {
345                                let subdomain = host[..remainder_len - 1].to_string();
346                                debug!(subdomain = %subdomain, known_host = %known, "Extracted subdomain from known host");
347                                target_subdomain = Some(subdomain);
348                                break;
349                            }
350                        }
351                    }
352
353                    if target_subdomain.is_none() && auto_detect_domain {
354                        trace!("Attempting auto-detection of subdomain");
355                        let collapsed_host = collapse_trailing_ipv4(&host);
356                        let mut parts: Vec<&str> = collapsed_host.split('.').collect();
357                        trace!(parts = ?parts, "Split host into parts");
358                        if !parts.is_empty() {
359                            let last = *parts.last().unwrap();
360                            if KNOWN_TLDS.contains(&last) {
361                                trace!(tld = %last, "Detected known TLD");
362                                parts.pop();
363                            }
364                            if parts.len() > 1 {
365                                let subdomain = parts[..parts.len() - 1].to_vec().join(".");
366                                debug!(subdomain = %subdomain, "Auto-detected subdomain");
367                                target_subdomain = Some(subdomain);
368                            }
369                        }
370                    }
371
372                    if let Some(sub) = target_subdomain {
373                        if let Some(router) = routes.get(&sub) {
374                            debug!(subdomain = %sub, "Routing to registered subdomain router");
375                            return router.clone().oneshot(req).await;
376                        } else if strict {
377                            debug!(subdomain = %sub, "Subdomain not found, returning 404 (strict mode)");
378                            return Ok(not_found_response());
379                        } else {
380                            debug!(subdomain = %sub, "Subdomain not found, falling back to inner service");
381                        }
382                    } else {
383                        trace!("No subdomain detected");
384                    }
385                }
386            }
387            // Fallback to inner service
388            trace!("Falling back to inner service");
389            inner.oneshot(req).await
390        })
391    }
392}