1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::fmt;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use axum::Router;
9use axum::body::Body;
10use axum::extract::{FromRequestParts, OptionalFromRequestParts};
11use axum::response::IntoResponse;
12use http::Request;
13use http::request::Parts;
14use tower::Service;
15
16use crate::Error;
17
/// Dispatches requests to different axum [`Router`]s based on the host the
/// request was addressed to.
///
/// Register routers with [`HostRouter::host`] (exact hosts or `*.suffix`
/// wildcards) and optionally [`HostRouter::fallback`], then convert the
/// builder into an [`axum::Router`] with `.into()`.
#[derive(Clone)]
pub struct HostRouter {
    // Routing tables, shared with the service created by `From<HostRouter>`.
    inner: Arc<HostRouterInner>,
}
52
/// Routing tables behind a [`HostRouter`], shared through an `Arc` with the
/// request-handling `HostRouterService`.
#[derive(Clone)]
struct HostRouterInner {
    /// Exact hosts (trimmed, port-stripped, lowercased) mapped to their router.
    exact: HashMap<String, Router>,
    /// Wildcard routes keyed by the suffix after `*.` (e.g. `acme.com`); the
    /// value holds the full pattern (e.g. `*.acme.com`), which is reported
    /// through `MatchedHost::pattern`, together with the router.
    wildcard: HashMap<String, (String, Router)>,
    /// Router used when neither table matches; `None` makes unmatched hosts
    /// produce a "no route for host" not-found response.
    fallback: Option<Router>,
}
62
/// Outcome of `HostRouterInner::match_host` for a single normalized host.
enum Match<'a> {
    /// The host was registered verbatim.
    Exact(&'a Router),
    /// The host matched a `*.suffix` pattern.
    Wildcard {
        router: &'a Router,
        /// The leftmost label of the host (the part `*` stood for).
        subdomain: String,
        /// The registered pattern, e.g. `*.acme.com`.
        pattern: String,
    },
    /// Nothing matched, but a fallback router is configured.
    Fallback(&'a Router),
    /// Nothing matched and there is no fallback.
    NotFound,
}
73
74impl fmt::Debug for HostRouter {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 f.debug_struct("HostRouter")
77 .field("exact_hosts", &self.inner.exact.keys().collect::<Vec<_>>())
78 .field(
79 "wildcard_hosts",
80 &self
81 .inner
82 .wildcard
83 .keys()
84 .map(|k| format!("*.{k}"))
85 .collect::<Vec<_>>(),
86 )
87 .field("has_fallback", &self.inner.fallback.is_some())
88 .finish()
89 }
90}
91
92impl Default for HostRouter {
93 fn default() -> Self {
94 Self::new()
95 }
96}
97
98impl HostRouter {
99 pub fn new() -> Self {
101 Self {
102 inner: Arc::new(HostRouterInner {
103 exact: HashMap::new(),
104 wildcard: HashMap::new(),
105 fallback: None,
106 }),
107 }
108 }
109
110 pub fn host(mut self, pattern: &str, router: Router) -> Self {
122 let inner = Arc::get_mut(&mut self.inner).expect("HostRouter::host called after clone");
123 let pattern = strip_port(pattern.trim()).to_lowercase();
124
125 if let Some(suffix) = pattern.strip_prefix("*.") {
126 assert!(
128 !suffix.is_empty(),
129 "invalid wildcard pattern \"{pattern}\": empty suffix"
130 );
131 assert!(
132 suffix.contains('.'),
133 "invalid wildcard pattern \"{pattern}\": suffix must contain at least one dot (e.g. \"*.example.com\")"
134 );
135 let full_pattern = format!("*.{suffix}");
136 let prev = inner
137 .wildcard
138 .insert(suffix.to_owned(), (full_pattern, router));
139 assert!(
140 prev.is_none(),
141 "duplicate wildcard suffix: \"*.{suffix}\" registered twice"
142 );
143 } else {
144 assert!(!pattern.is_empty(), "host pattern must not be empty");
145 assert!(
146 !pattern.starts_with('*'),
147 "invalid wildcard pattern \"{pattern}\": use \"*.domain.com\" format"
148 );
149 let prev = inner.exact.insert(pattern.clone(), router);
150 assert!(
151 prev.is_none(),
152 "duplicate exact host: \"{pattern}\" registered twice"
153 );
154 }
155
156 self
157 }
158
159 pub fn fallback(mut self, router: Router) -> Self {
163 let inner = Arc::get_mut(&mut self.inner).expect("HostRouter::fallback called after clone");
164 inner.fallback = Some(router);
165 self
166 }
167}
168
169impl HostRouterInner {
170 fn match_host(&self, host: &str) -> Match<'_> {
171 if let Some(router) = self.exact.get(host) {
172 return Match::Exact(router);
173 }
174
175 if let Some(dot) = host.find('.') {
176 let subdomain = &host[..dot];
177 let suffix = &host[dot + 1..];
178 if let Some((pattern, router)) = self.wildcard.get(suffix) {
179 return Match::Wildcard {
180 router,
181 subdomain: subdomain.to_owned(),
182 pattern: pattern.clone(),
183 };
184 }
185 }
186
187 match &self.fallback {
188 Some(router) => Match::Fallback(router),
189 None => Match::NotFound,
190 }
191 }
192}
193
/// Details of the wildcard pattern that routed the current request.
///
/// The host router inserts this into the request extensions only when a
/// `*.suffix` pattern matched; exact-host and fallback dispatch do not insert
/// it. Extract it as `MatchedHost`, which rejects with an internal error when
/// it is absent, or as `Option<MatchedHost>`.
#[derive(Debug, Clone)]
pub struct MatchedHost {
    /// The leftmost label of the (lowercased) host that `*` matched, e.g.
    /// `tenant1` for `tenant1.acme.com`.
    pub subdomain: String,
    /// The registered wildcard pattern after normalization, e.g. `*.acme.com`.
    pub pattern: String,
}
216
217impl<S> FromRequestParts<S> for MatchedHost
218where
219 S: Send + Sync,
220{
221 type Rejection = Error;
222
223 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
224 parts
225 .extensions
226 .get::<MatchedHost>()
227 .cloned()
228 .ok_or_else(|| Error::internal("internal routing error"))
229 }
230}
231
232impl<S> OptionalFromRequestParts<S> for MatchedHost
233where
234 S: Send + Sync,
235{
236 type Rejection = Error;
237
238 async fn from_request_parts(
239 parts: &mut Parts,
240 _state: &S,
241 ) -> Result<Option<Self>, Self::Rejection> {
242 Ok(parts.extensions.get::<MatchedHost>().cloned())
243 }
244}
245
/// Tower service that performs host-based dispatch. It is installed as the
/// fallback service of the `axum::Router` produced by `From<HostRouter>`, and
/// shares the routing tables with the builder through an `Arc`.
#[derive(Clone)]
struct HostRouterService(Arc<HostRouterInner>);
251
252impl Service<Request<Body>> for HostRouterService {
253 type Response = http::Response<Body>;
254 type Error = Infallible;
255 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
256
257 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
258 Poll::Ready(Ok(()))
259 }
260
261 fn call(&mut self, req: Request<Body>) -> Self::Future {
262 let inner = Arc::clone(&self.0);
263
264 Box::pin(async move {
265 let (mut parts, body) = req.into_parts();
266
267 let host = match resolve_host(&parts) {
268 Ok(h) => h,
269 Err(e) => return Ok(e.into_response()),
270 };
271
272 let router = match inner.match_host(&host) {
273 Match::Exact(router) | Match::Fallback(router) => router,
274 Match::Wildcard {
275 router,
276 subdomain,
277 pattern,
278 } => {
279 parts.extensions.insert(MatchedHost { subdomain, pattern });
280 router
281 }
282 Match::NotFound => {
283 return Ok(Error::not_found("no route for host").into_response());
284 }
285 };
286
287 let req = Request::from_parts(parts, body);
288 Ok(router.clone().call(req).await.into_response())
289 })
290 }
291}
292
293impl From<HostRouter> for axum::Router {
294 fn from(host_router: HostRouter) -> axum::Router {
295 axum::Router::new().fallback_service(HostRouterService(host_router.inner))
296 }
297}
298
299fn resolve_host(parts: &Parts) -> Result<String, Error> {
308 const HOST_DIRECTIVE: &str = "host=";
309
310 if let Some(fwd) = parts.headers.get("forwarded")
311 && let Ok(fwd_str) = fwd.to_str()
312 {
313 let first_element = fwd_str.split(',').next().unwrap_or(fwd_str);
315 for directive in first_element.split(';') {
316 let directive = directive.trim();
317 if directive.len() >= HOST_DIRECTIVE.len()
319 && directive[..HOST_DIRECTIVE.len()].eq_ignore_ascii_case(HOST_DIRECTIVE)
320 {
321 let host = directive[HOST_DIRECTIVE.len()..].trim();
322 if !host.is_empty() {
323 return Ok(strip_port(host).to_lowercase());
324 }
325 }
326 }
327 }
328
329 if let Some(xfh) = parts.headers.get("x-forwarded-host")
330 && let Ok(host) = xfh.to_str()
331 {
332 let host = host.trim();
333 if !host.is_empty() {
334 return Ok(strip_port(host).to_lowercase());
335 }
336 }
337
338 if let Some(h) = parts.headers.get(http::header::HOST)
339 && let Ok(host) = h.to_str()
340 {
341 let host = host.trim();
342 if !host.is_empty() {
343 return Ok(strip_port(host).to_lowercase());
344 }
345 }
346
347 Err(Error::bad_request("missing or invalid Host header"))
348}
349
/// Removes a trailing `:port` (all ASCII digits, possibly empty) from a host.
///
/// IPv6 literals are handled:
///
/// * a bracketed literal only loses a port written after the closing bracket
///   (`"[::1]:8080"` → `"[::1]"`, while `"[::1]"` is unchanged);
/// * an unbracketed literal such as `"::1"` or `"2001:db8::1"` has no port and
///   is returned unchanged. Previously its last group was mistaken for a port.
fn strip_port(host: &str) -> &str {
    let (name, port) = match host.rsplit_once(':') {
        Some(split) => split,
        None => return host,
    };

    if !port.bytes().all(|b| b.is_ascii_digit()) {
        return host;
    }

    if host.starts_with('[') {
        // Only a colon right after `]` separates a port from the literal.
        return if name.ends_with(']') { name } else { host };
    }

    if name.contains(':') {
        // More than one colon without brackets: a bare IPv6 address.
        return host;
    }

    name
}
361
/// Tests for host resolution, host matching, the `MatchedHost` extractors and
/// end-to-end request dispatch through the converted `axum::Router`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds request `Parts` carrying the given `(name, value)` headers.
    fn parts_with_headers(headers: &[(&str, &str)]) -> Parts {
        let mut builder = http::Request::builder();
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        let (parts, _) = builder.body(()).unwrap().into_parts();
        parts
    }

    // ---- resolve_host: header precedence and normalization ----

    #[test]
    fn resolve_from_host_header() {
        let parts = parts_with_headers(&[("host", "acme.com")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_strips_port() {
        let parts = parts_with_headers(&[("host", "acme.com:8080")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_lowercases() {
        let parts = parts_with_headers(&[("host", "ACME.COM")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_x_forwarded_host_over_host() {
        let parts =
            parts_with_headers(&[("host", "proxy.internal"), ("x-forwarded-host", "acme.com")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_forwarded_over_x_forwarded_host() {
        let parts = parts_with_headers(&[
            ("host", "proxy.internal"),
            ("x-forwarded-host", "xfh.com"),
            ("forwarded", "for=1.2.3.4; host=acme.com; proto=https"),
        ]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_forwarded_strips_port() {
        let parts = parts_with_headers(&[("forwarded", "host=acme.com:443")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_x_forwarded_host_strips_port() {
        let parts = parts_with_headers(&[("x-forwarded-host", "acme.com:8080")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_missing_all_headers_returns_400() {
        let parts = parts_with_headers(&[]);
        let err = resolve_host(&parts).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn resolve_forwarded_case_insensitive_host_directive() {
        let parts = parts_with_headers(&[("forwarded", "for=1.2.3.4; Host=acme.com")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_forwarded_without_host_falls_through() {
        let parts = parts_with_headers(&[
            ("forwarded", "for=1.2.3.4; proto=https"),
            ("host", "fallback.com"),
        ]);
        assert_eq!(resolve_host(&parts).unwrap(), "fallback.com");
    }

    /// A router whose `GET /` responds with the fixed `body`, so dispatch tests
    /// can tell which registered router handled a request.
    fn router_with_body(body: &'static str) -> Router {
        Router::new().route("/", axum::routing::get(move || async move { body }))
    }

    // ---- match_host: exact / wildcard / fallback selection ----

    #[test]
    fn match_exact() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));
        assert!(matches!(hr.inner.match_host("acme.com"), Match::Exact(_)));
    }

    #[test]
    fn match_wildcard() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));
        match hr.inner.match_host("tenant1.acme.com") {
            Match::Wildcard {
                subdomain, pattern, ..
            } => {
                assert_eq!(subdomain, "tenant1");
                assert_eq!(pattern, "*.acme.com");
            }
            other => panic!("expected Wildcard, got {}", match_name(&other)),
        }
    }

    #[test]
    fn exact_wins_over_wildcard() {
        let hr = HostRouter::new()
            .host("app.acme.com", router_with_body("admin"))
            .host("*.acme.com", router_with_body("tenant"));
        assert!(matches!(
            hr.inner.match_host("app.acme.com"),
            Match::Exact(_)
        ));
    }

    #[test]
    fn bare_domain_does_not_match_wildcard() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));
        assert!(matches!(hr.inner.match_host("acme.com"), Match::NotFound));
    }

    #[test]
    fn multi_level_subdomain_does_not_match_wildcard() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));
        // `*` covers one label only: "a.b.acme.com" splits into "a" + "b.acme.com",
        // and "b.acme.com" is not a registered suffix.
        assert!(matches!(
            hr.inner.match_host("a.b.acme.com"),
            Match::NotFound
        ));
    }

    #[test]
    fn fallback_when_no_match() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .fallback(router_with_body("fallback"));
        assert!(matches!(
            hr.inner.match_host("other.com"),
            Match::Fallback(_)
        ));
    }

    #[test]
    fn not_found_when_no_match_and_no_fallback() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));
        assert!(matches!(hr.inner.match_host("other.com"), Match::NotFound));
    }

    /// Variant name of a `Match`, used in failure messages because `Match`
    /// does not implement `Debug`.
    fn match_name(m: &Match<'_>) -> &'static str {
        match m {
            Match::Exact(_) => "Exact",
            Match::Wildcard { .. } => "Wildcard",
            Match::Fallback(_) => "Fallback",
            Match::NotFound => "NotFound",
        }
    }

    // ---- registration-time validation (HostRouter::host panics) ----

    #[test]
    #[should_panic(expected = "duplicate exact host")]
    fn panic_on_duplicate_exact() {
        HostRouter::new()
            .host("acme.com", router_with_body("a"))
            .host("acme.com", router_with_body("b"));
    }

    #[test]
    #[should_panic(expected = "duplicate wildcard suffix")]
    fn panic_on_duplicate_wildcard() {
        HostRouter::new()
            .host("*.acme.com", router_with_body("a"))
            .host("*.acme.com", router_with_body("b"));
    }

    #[test]
    #[should_panic(expected = "suffix must contain at least one dot")]
    fn panic_on_tld_wildcard() {
        HostRouter::new().host("*.com", router_with_body("a"));
    }

    #[test]
    #[should_panic(expected = "invalid wildcard pattern")]
    fn panic_on_bare_star() {
        HostRouter::new().host("*", router_with_body("a"));
    }

    #[test]
    #[should_panic(expected = "empty suffix")]
    fn panic_on_star_dot_only() {
        HostRouter::new().host("*.", router_with_body("a"));
    }

    // ---- MatchedHost extractors, driven directly on Parts ----

    #[tokio::test]
    async fn extract_matched_host_present() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();
        parts.extensions.insert(MatchedHost {
            subdomain: "tenant1".into(),
            pattern: "*.acme.com".into(),
        });

        let result =
            <MatchedHost as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        let matched = result.unwrap();
        assert_eq!(matched.subdomain, "tenant1");
        assert_eq!(matched.pattern, "*.acme.com");
    }

    #[tokio::test]
    async fn extract_matched_host_missing_returns_500() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();

        let result =
            <MatchedHost as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        let err = result.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(err.to_string(), "internal routing error");
    }

    #[tokio::test]
    async fn optional_matched_host_none_when_missing() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();

        let result =
            <MatchedHost as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn optional_matched_host_some_when_present() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();
        parts.extensions.insert(MatchedHost {
            subdomain: "t1".into(),
            pattern: "*.acme.com".into(),
        });

        let result =
            <MatchedHost as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await;
        let matched = result.unwrap().unwrap();
        assert_eq!(matched.subdomain, "t1");
        assert_eq!(matched.pattern, "*.acme.com");
    }

    // ---- end-to-end dispatch through the converted axum::Router ----

    use axum::body::Body;
    use http::{Request, StatusCode};
    use tower::ServiceExt;

    /// Collects a response body into a UTF-8 string.
    async fn response_body(resp: http::Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn dispatch_exact_match() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .host("app.acme.com", router_with_body("admin"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "landing");
    }

    #[tokio::test]
    async fn dispatch_wildcard_match() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "tenant1.acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "tenant");
    }

    #[tokio::test]
    async fn dispatch_wildcard_injects_matched_host() {
        // The handler echoes the extracted MatchedHost so the test can check
        // what the dispatcher inserted into the request extensions.
        let tenant_router = Router::new().route(
            "/",
            axum::routing::get(|matched: MatchedHost| async move {
                format!("{}:{}", matched.subdomain, matched.pattern)
            }),
        );

        let hr = HostRouter::new().host("*.acme.com", tenant_router);

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "tenant1.acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "tenant1:*.acme.com");
    }

    #[tokio::test]
    async fn dispatch_exact_wins_over_wildcard() {
        let hr = HostRouter::new()
            .host("app.acme.com", router_with_body("admin"))
            .host("*.acme.com", router_with_body("tenant"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "app.acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "admin");
    }

    #[tokio::test]
    async fn dispatch_fallback() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .fallback(router_with_body("fallback"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "unknown.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "fallback");
    }

    #[tokio::test]
    async fn dispatch_404_no_match_no_fallback() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "unknown.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn dispatch_400_missing_host() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let router: axum::Router = hr.into();
        // No Host / X-Forwarded-Host / Forwarded header at all.
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn dispatch_via_x_forwarded_host() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "proxy.internal")
            .header("x-forwarded-host", "acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "landing");
    }
}