1use axum::{
2 Router,
3 extract::Request,
4 http::{HeaderMap, StatusCode},
5 response::Response,
6};
7use std::{
8 collections::HashMap,
9 convert::Infallible,
10 future::Future,
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14};
15use tower::util::ServiceExt;
16use tower::{Layer, Service};
17use tracing::{debug, trace};
18
/// Top-level domains that subdomain auto-detection strips from the end of a host
/// before deciding which labels form the subdomain.
///
/// Only a single trailing label is ever stripped, so multi-label public suffixes such
/// as `co.uk` are not fully recognized (`api.example.co.uk` yields `api.example`).
const KNOWN_TLDS: &[&str] = &[
    "com", "net", "org", "tr", "edu", "gov", "io", "dev",
    "co", "uk", "info", "biz", "mil", "int", "arpa", "name",
    "pro", "aero", "coop", "museum", "mobi", "asia", "tel", "cat",
    "jobs", "travel", "us", "ca", "de", "fr", "au", "jp",
    "cn", "ru", "br", "it", "es", "nl", "se", "no",
    "fi", "dk", "pl", "ch", "be", "at",
];
25
/// Name of the header carrying the request's target host; always consulted first.
const HOST_HEADER: &str = "host";
/// Name of the proxy-supplied original-host header, consulted only in
/// `HostSource::XForwardedHostFallback` mode when `Host` is missing or malformed.
const X_FORWARDED_HOST_HEADER: &str = "x-forwarded-host";
28
/// Selects which request headers supply the host used for subdomain routing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum HostSource {
    /// Use only the `Host` header. This is the default.
    #[default]
    HostOnly,
    /// Prefer `Host`; when it is missing or malformed, fall back to the first entry of
    /// `X-Forwarded-Host`. Clients can send that header themselves, so this mode is
    /// presumably meant for deployments behind a proxy that sets or sanitizes it.
    XForwardedHostFallback,
}
35
/// Classification of the effective request host after reading the configured headers.
enum ResolvedHost {
    /// A usable host with any port removed, trailing dots trimmed, and ASCII lowercased.
    Present(String),
    /// None of the consulted host headers were sent.
    Missing,
    /// A host header was sent, but its value could not be parsed into a host.
    Malformed,
}
41
/// Canonicalizes a DNS-style name for comparison.
///
/// The name is trimmed of surrounding whitespace, stripped of trailing dots (so the
/// fully-qualified form `example.com.` equals `example.com`), and ASCII-lowercased.
/// Returns `None` when nothing remains.
fn normalize_dns_like_value(value: &str) -> Option<String> {
    let stripped = value.trim().trim_end_matches('.');
    (!stripped.is_empty()).then(|| stripped.to_ascii_lowercase())
}
49
/// Returns `true` when `port` is a non-empty run of ASCII digits.
///
/// The numeric range is deliberately not checked.
fn is_port_digits(port: &str) -> bool {
    !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit())
}

/// Returns `true` when every byte of `name` may appear in an RFC 3986 `reg-name`.
///
/// Allowed bytes are unreserved characters, sub-delimiters, and `%` (for
/// percent-encoding). Dotted IPv4 literals also pass this check.
fn is_reg_name(name: &str) -> bool {
    name.bytes().all(|b| {
        b.is_ascii_alphanumeric()
            || matches!(
                b,
                b'-' | b'.' | b'_' | b'~' | b'%' | b'!' | b'$' | b'&' | b'\''
                    | b'(' | b')' | b'*' | b'+' | b',' | b';' | b'='
            )
    })
}

/// Strips an optional `:port` from a `Host`-style authority and validates the rest.
///
/// Accepts `name`, `name:port`, `[ipv6]`, and `[ipv6]:port`, where `port` is one or
/// more ASCII digits. Returns the host part, with the brackets removed for IPv6.
///
/// Returns `None` when:
/// - the value is empty;
/// - the port is empty or not numeric;
/// - an unbracketed host contains another `:`;
/// - the text contains characters that cannot appear in a host (whitespace, `/`, `@`,
///   `?`, `#`, ...). Without this check a value such as `a.com, b.com` or
///   `evil.com/x` would be treated as a valid host.
fn extract_host_without_port(host: &str) -> Option<&str> {
    if let Some(rest) = host.strip_prefix('[') {
        let (literal, trailing) = rest.split_once(']')?;
        // Lenient IPv6 check: hex digits, colons, and dots (IPv4-mapped forms).
        let literal_ok = !literal.is_empty()
            && literal
                .bytes()
                .all(|b| b.is_ascii_hexdigit() || b == b':' || b == b'.');
        let trailing_ok =
            trailing.is_empty() || trailing.strip_prefix(':').is_some_and(is_port_digits);
        return (literal_ok && trailing_ok).then_some(literal);
    }

    // Splitting at the first ':' means any further ':' ends up in `port` and fails
    // the digit check, which rejects unbracketed IPv6 such as `::1`.
    let name = match host.split_once(':') {
        Some((name, port)) if is_port_digits(port) => name,
        Some(_) => return None,
        None => host,
    };
    (!name.is_empty() && is_reg_name(name)).then_some(name)
}
95
96fn parse_effective_host(value: &str, first_comma_only: bool) -> Option<String> {
97 let candidate = if first_comma_only {
98 value.split(',').next()?.trim()
99 } else {
100 value.trim()
101 };
102
103 let host = extract_host_without_port(candidate)?;
104 normalize_dns_like_value(host)
105}
106
107fn parse_host_header(
108 headers: &HeaderMap,
109 name: &str,
110 first_comma_only: bool,
111) -> Option<Result<String, ()>> {
112 let value = headers.get(name)?;
113 let text = match value.to_str() {
114 Ok(text) => text,
115 Err(_) => return Some(Err(())),
116 };
117 Some(parse_effective_host(text, first_comma_only).ok_or(()))
118}
119
120fn resolve_host(headers: &HeaderMap, source: HostSource) -> ResolvedHost {
121 let primary = parse_host_header(headers, HOST_HEADER, false);
122 match source {
123 HostSource::HostOnly => match primary {
124 Some(Ok(host)) => ResolvedHost::Present(host),
125 Some(Err(_)) => ResolvedHost::Malformed,
126 None => ResolvedHost::Missing,
127 },
128 HostSource::XForwardedHostFallback => {
129 if let Some(Ok(host)) = primary {
130 return ResolvedHost::Present(host);
131 }
132
133 let fallback = parse_host_header(headers, X_FORWARDED_HOST_HEADER, true);
134 match fallback {
135 Some(Ok(host)) => ResolvedHost::Present(host),
136 Some(Err(_)) => ResolvedHost::Malformed,
137 None => {
138 if matches!(primary, Some(Err(_))) {
139 ResolvedHost::Malformed
140 } else {
141 ResolvedHost::Missing
142 }
143 }
144 }
145 }
146 }
147}
148
/// Canonical form of a subdomain used as a route key.
///
/// Surrounding whitespace and leading/trailing dots are removed, and the rest is
/// ASCII-lowercased.
fn normalize_subdomain_key(subdomain: &str) -> String {
    let key = subdomain.trim();
    let key = key.trim_matches('.');
    key.to_ascii_lowercase()
}
152
/// Joins a trailing dotted IPv4 address into a single `_`-separated label.
///
/// For example, `api.192.168.1.1` becomes `api.192_168_1_1`. Label-based subdomain
/// detection then treats the address as one label instead of four. Hosts whose last
/// four labels are not all `u8` values are returned unchanged.
fn collapse_trailing_ipv4(host: &str) -> String {
    let labels: Vec<&str> = host.split('.').collect();
    let Some(split_at) = labels.len().checked_sub(4) else {
        return host.to_string();
    };

    let (prefix, tail) = labels.split_at(split_at);
    let looks_like_ipv4 = tail
        .iter()
        .all(|label| !label.is_empty() && label.parse::<u8>().is_ok());
    if !looks_like_ipv4 {
        return host.to_string();
    }

    let collapsed = tail.join("_");
    if prefix.is_empty() {
        collapsed
    } else {
        format!("{}.{}", prefix.join("."), collapsed)
    }
}
175
176fn not_found_response() -> Response {
177 Response::builder()
178 .status(StatusCode::NOT_FOUND)
179 .body(axum::body::Body::empty())
180 .unwrap()
181}
182
183#[derive(Clone)]
185pub struct SubdomainLayer {
186 routes: Arc<HashMap<String, Router>>,
187 strict: bool,
188 known_hosts: Arc<Vec<String>>,
189 auto_detect_domain: bool,
190 host_source: HostSource,
191}
192
193impl SubdomainLayer {
194 pub fn new() -> Self {
196 Self {
197 routes: Arc::new(HashMap::new()),
198 strict: false,
199 known_hosts: Arc::new(Vec::new()),
200 auto_detect_domain: true,
201 host_source: HostSource::default(),
202 }
203 }
204
205 pub fn register<S: ToString>(mut self, subdomain: S, router: Router) -> Self {
209 let subdomain_str = normalize_subdomain_key(&subdomain.to_string());
210 debug!(subdomain = %subdomain_str, "Registering router for subdomain");
211 let mut routes = (*self.routes).clone();
212 routes.insert(subdomain_str, router);
213 self.routes = Arc::new(routes);
214 self
215 }
216
217 pub fn strict(mut self, strict: bool) -> Self {
222 debug!(strict = %strict, "Setting strict mode");
223 self.strict = strict;
224 self
225 }
226
227 pub fn known_hosts(mut self, hosts: Vec<String>) -> Self {
231 let normalized_hosts: Vec<String> = hosts
232 .into_iter()
233 .filter_map(|host| normalize_dns_like_value(&host))
234 .collect();
235 debug!(hosts = ?normalized_hosts, "Setting known hosts");
236 self.known_hosts = Arc::new(normalized_hosts);
237 self
238 }
239
240 pub fn auto_detect_domain(mut self, enable: bool) -> Self {
244 debug!(auto_detect = %enable, "Setting auto-detect domain mode");
245 self.auto_detect_domain = enable;
246 self
247 }
248
249 pub fn host_source(mut self, host_source: HostSource) -> Self {
253 debug!(?host_source, "Setting host source mode");
254 self.host_source = host_source;
255 self
256 }
257}
258
259impl Default for SubdomainLayer {
260 fn default() -> Self {
261 Self::new()
262 }
263}
264
265impl<S> Layer<S> for SubdomainLayer {
266 type Service = SubdomainService<S>;
267
268 fn layer(&self, inner: S) -> Self::Service {
269 SubdomainService {
270 inner,
271 routes: self.routes.clone(),
272 strict: self.strict,
273 auto_detect_domain: self.auto_detect_domain,
274 known_hosts: self.known_hosts.clone(),
275 host_source: self.host_source,
276 }
277 }
278}
279
280#[derive(Clone)]
282pub struct SubdomainService<S> {
283 inner: S,
284 routes: Arc<HashMap<String, Router>>,
285 strict: bool,
286 auto_detect_domain: bool,
287 known_hosts: Arc<Vec<String>>,
288 host_source: HostSource,
289}
290
291impl<S> Service<Request> for SubdomainService<S>
292where
293 S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
294 S::Future: Send + 'static,
295{
296 type Response = Response;
297 type Error = Infallible;
298 type Future = Pin<Box<dyn Future<Output = Result<Response, Infallible>> + Send + 'static>>;
299
300 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
301 self.inner.poll_ready(cx)
302 }
303
304 fn call(&mut self, req: Request) -> Self::Future {
305 let inner = self.inner.clone();
306 let routes = self.routes.clone();
307 let strict = self.strict;
308 let auto_detect_domain = self.auto_detect_domain;
309 let known_hosts = self.known_hosts.clone();
310 let host_source = self.host_source;
311
312 let resolved_host = resolve_host(req.headers(), host_source);
313
314 Box::pin(async move {
315 match resolved_host {
316 ResolvedHost::Malformed => {
317 trace!("Malformed effective host detected");
318 if strict {
319 return Ok(not_found_response());
320 }
321 trace!(
322 "Falling back to inner service due to malformed host in non-strict mode"
323 );
324 return inner.oneshot(req).await;
325 }
326 ResolvedHost::Missing => {
327 trace!("No effective host header found");
328 if strict {
329 return Ok(not_found_response());
330 }
331 trace!("Falling back to inner service due to missing host in non-strict mode");
332 return inner.oneshot(req).await;
333 }
334 ResolvedHost::Present(host) => {
335 debug!(host = %host, "Processing request for host");
336 let mut target_subdomain = None;
337
338 trace!(known_hosts = ?known_hosts.as_ref(), "Checking against known hosts");
340 for known in known_hosts.iter() {
341 if host.ends_with(known) {
342 trace!(known_host = %known, "Host matches known host");
343 let remainder_len = host.len() - known.len();
344 if remainder_len > 0 && host.as_bytes()[remainder_len - 1] == b'.' {
345 let subdomain = host[..remainder_len - 1].to_string();
346 debug!(subdomain = %subdomain, known_host = %known, "Extracted subdomain from known host");
347 target_subdomain = Some(subdomain);
348 break;
349 }
350 }
351 }
352
353 if target_subdomain.is_none() && auto_detect_domain {
354 trace!("Attempting auto-detection of subdomain");
355 let collapsed_host = collapse_trailing_ipv4(&host);
356 let mut parts: Vec<&str> = collapsed_host.split('.').collect();
357 trace!(parts = ?parts, "Split host into parts");
358 if !parts.is_empty() {
359 let last = *parts.last().unwrap();
360 if KNOWN_TLDS.contains(&last) {
361 trace!(tld = %last, "Detected known TLD");
362 parts.pop();
363 }
364 if parts.len() > 1 {
365 let subdomain = parts[..parts.len() - 1].to_vec().join(".");
366 debug!(subdomain = %subdomain, "Auto-detected subdomain");
367 target_subdomain = Some(subdomain);
368 }
369 }
370 }
371
372 if let Some(sub) = target_subdomain {
373 if let Some(router) = routes.get(&sub) {
374 debug!(subdomain = %sub, "Routing to registered subdomain router");
375 return router.clone().oneshot(req).await;
376 } else if strict {
377 debug!(subdomain = %sub, "Subdomain not found, returning 404 (strict mode)");
378 return Ok(not_found_response());
379 } else {
380 debug!(subdomain = %sub, "Subdomain not found, falling back to inner service");
381 }
382 } else {
383 trace!("No subdomain detected");
384 }
385 }
386 }
387 trace!("Falling back to inner service");
389 inner.oneshot(req).await
390 })
391 }
392}