1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::fmt;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use axum::Router;
9use axum::body::Body;
10use axum::extract::{FromRequestParts, OptionalFromRequestParts};
11use axum::response::IntoResponse;
12use http::Request;
13use http::request::Parts;
14use tower::Service;
15
16use crate::Error;
17
/// Dispatches requests to different [`Router`]s based on the host the request was sent to.
///
/// Build it with [`HostRouter::new`], [`HostRouter::host`] and [`HostRouter::fallback`],
/// then convert it into an [`axum::Router`] via `From`/`Into`.
#[derive(Clone)]
pub struct HostRouter {
    // Shared routing table; the builder methods require it to be uniquely owned
    // (`Arc::get_mut`), so they panic once the `HostRouter` has been cloned.
    inner: Arc<HostRouterInner>,
}
59
/// Routing table shared (through an `Arc`) by the builder and the dispatching service.
#[derive(Clone)]
struct HostRouterInner {
    /// Exact hosts, keyed by the trimmed, port-stripped, lowercased pattern.
    exact: HashMap<String, Router>,
    /// Wildcard entries keyed by suffix (`"acme.com"` for `"*.acme.com"`); the value holds
    /// the full pattern string and its router.
    wildcard: HashMap<String, (String, Router)>,
    /// Router used when neither an exact nor a wildcard entry matches.
    fallback: Option<Router>,
}
69
/// Outcome of [`HostRouterInner::match_host`].
enum Match<'a> {
    /// The host equals a registered exact host.
    Exact(&'a Router),
    /// The host is one leading label followed by a registered wildcard suffix.
    Wildcard {
        router: &'a Router,
        /// The leading label that the `*` stood for (e.g. `"tenant1"`).
        subdomain: String,
        /// The registered wildcard pattern (e.g. `"*.acme.com"`).
        pattern: String,
    },
    /// Nothing matched, and a fallback router is configured.
    Fallback(&'a Router),
    /// Nothing matched and no fallback is configured.
    NotFound,
}
80
81impl fmt::Debug for HostRouter {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 f.debug_struct("HostRouter")
84 .field("exact_hosts", &self.inner.exact.keys().collect::<Vec<_>>())
85 .field(
86 "wildcard_hosts",
87 &self
88 .inner
89 .wildcard
90 .keys()
91 .map(|k| format!("*.{k}"))
92 .collect::<Vec<_>>(),
93 )
94 .field("has_fallback", &self.inner.fallback.is_some())
95 .finish()
96 }
97}
98
99impl Default for HostRouter {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl HostRouter {
106 pub fn new() -> Self {
108 Self {
109 inner: Arc::new(HostRouterInner {
110 exact: HashMap::new(),
111 wildcard: HashMap::new(),
112 fallback: None,
113 }),
114 }
115 }
116
117 pub fn host(mut self, pattern: &str, router: Router) -> Self {
133 let inner = Arc::get_mut(&mut self.inner).expect("HostRouter::host called after clone");
134 let pattern = strip_port(pattern.trim()).to_lowercase();
135
136 if let Some(suffix) = pattern.strip_prefix("*.") {
137 assert!(
139 !suffix.is_empty(),
140 "invalid wildcard pattern \"{pattern}\": empty suffix"
141 );
142 assert!(
143 suffix.contains('.'),
144 "invalid wildcard pattern \"{pattern}\": suffix must contain at least one dot (e.g. \"*.example.com\")"
145 );
146 let full_pattern = format!("*.{suffix}");
147 let prev = inner
148 .wildcard
149 .insert(suffix.to_owned(), (full_pattern, router));
150 assert!(
151 prev.is_none(),
152 "duplicate wildcard suffix: \"*.{suffix}\" registered twice"
153 );
154 } else {
155 assert!(!pattern.is_empty(), "host pattern must not be empty");
156 assert!(
157 !pattern.starts_with('*'),
158 "invalid wildcard pattern \"{pattern}\": use \"*.domain.com\" format"
159 );
160 let prev = inner.exact.insert(pattern.clone(), router);
161 assert!(
162 prev.is_none(),
163 "duplicate exact host: \"{pattern}\" registered twice"
164 );
165 }
166
167 self
168 }
169
170 pub fn fallback(mut self, router: Router) -> Self {
179 let inner = Arc::get_mut(&mut self.inner).expect("HostRouter::fallback called after clone");
180 inner.fallback = Some(router);
181 self
182 }
183}
184
185impl HostRouterInner {
186 fn match_host(&self, host: &str) -> Match<'_> {
187 if let Some(router) = self.exact.get(host) {
188 return Match::Exact(router);
189 }
190
191 if let Some(dot) = host.find('.') {
192 let subdomain = &host[..dot];
193 let suffix = &host[dot + 1..];
194 if let Some((pattern, router)) = self.wildcard.get(suffix) {
195 return Match::Wildcard {
196 router,
197 subdomain: subdomain.to_owned(),
198 pattern: pattern.clone(),
199 };
200 }
201 }
202
203 match &self.fallback {
204 Some(router) => Match::Fallback(router),
205 None => Match::NotFound,
206 }
207 }
208}
209
/// Host information captured by a wildcard match.
///
/// For example, a request to `tenant1.acme.com` routed through `"*.acme.com"` yields
/// `subdomain = "tenant1"` and `pattern = "*.acme.com"`.
///
/// The host router inserts it into the request extensions only for wildcard matches. Extract
/// it as `MatchedHost` (rejects when absent) or `Option<MatchedHost>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MatchedHost {
    /// The single leading label matched by the `*` (lowercased, without port).
    pub subdomain: String,
    /// The wildcard pattern that matched, as registered (e.g. `"*.acme.com"`).
    pub pattern: String,
}
234
235impl<S> FromRequestParts<S> for MatchedHost
236where
237 S: Send + Sync,
238{
239 type Rejection = Error;
240
241 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
242 parts
243 .extensions
244 .get::<MatchedHost>()
245 .cloned()
246 .ok_or_else(|| Error::internal("internal routing error"))
247 }
248}
249
250impl<S> OptionalFromRequestParts<S> for MatchedHost
251where
252 S: Send + Sync,
253{
254 type Rejection = Error;
255
256 async fn from_request_parts(
257 parts: &mut Parts,
258 _state: &S,
259 ) -> Result<Option<Self>, Self::Rejection> {
260 Ok(parts.extensions.get::<MatchedHost>().cloned())
261 }
262}
263
/// Tower service that resolves each request's host and forwards it to the matching router.
///
/// Installed as the fallback service of the `axum::Router` built from a [`HostRouter`].
#[derive(Clone)]
struct HostRouterService(Arc<HostRouterInner>);
269
270impl Service<Request<Body>> for HostRouterService {
271 type Response = http::Response<Body>;
272 type Error = Infallible;
273 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
274
275 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
276 Poll::Ready(Ok(()))
277 }
278
279 fn call(&mut self, req: Request<Body>) -> Self::Future {
280 let inner = Arc::clone(&self.0);
281
282 Box::pin(async move {
283 let (mut parts, body) = req.into_parts();
284
285 let host = match resolve_host(&parts) {
286 Ok(h) => h,
287 Err(e) => return Ok(e.into_response()),
288 };
289
290 let router = match inner.match_host(&host) {
291 Match::Exact(router) | Match::Fallback(router) => router,
292 Match::Wildcard {
293 router,
294 subdomain,
295 pattern,
296 } => {
297 parts.extensions.insert(MatchedHost { subdomain, pattern });
298 router
299 }
300 Match::NotFound => {
301 return Ok(Error::not_found("no route for host").into_response());
302 }
303 };
304
305 let req = Request::from_parts(parts, body);
306 Ok(router.clone().call(req).await.into_response())
307 })
308 }
309}
310
311impl From<HostRouter> for axum::Router {
312 fn from(host_router: HostRouter) -> axum::Router {
313 axum::Router::new().fallback_service(HostRouterService(host_router.inner))
314 }
315}
316
317fn resolve_host(parts: &Parts) -> Result<String, Error> {
326 const HOST_DIRECTIVE: &str = "host=";
327
328 if let Some(fwd) = parts.headers.get("forwarded")
329 && let Ok(fwd_str) = fwd.to_str()
330 {
331 let first_element = fwd_str.split(',').next().unwrap_or(fwd_str);
333 for directive in first_element.split(';') {
334 let directive = directive.trim();
335 if directive.len() >= HOST_DIRECTIVE.len()
337 && directive[..HOST_DIRECTIVE.len()].eq_ignore_ascii_case(HOST_DIRECTIVE)
338 {
339 let host = directive[HOST_DIRECTIVE.len()..].trim();
340 if !host.is_empty() {
341 return Ok(strip_port(host).to_lowercase());
342 }
343 }
344 }
345 }
346
347 if let Some(xfh) = parts.headers.get("x-forwarded-host")
348 && let Ok(host) = xfh.to_str()
349 {
350 let host = host.trim();
351 if !host.is_empty() {
352 return Ok(strip_port(host).to_lowercase());
353 }
354 }
355
356 if let Some(h) = parts.headers.get(http::header::HOST)
357 && let Ok(host) = h.to_str()
358 {
359 let host = host.trim();
360 if !host.is_empty() {
361 return Ok(strip_port(host).to_lowercase());
362 }
363 }
364
365 Err(Error::bad_request("missing or invalid Host header"))
366}
367
/// Removes a trailing `:<digits>` port from a host string, if present.
///
/// `"acme.com:8080"` becomes `"acme.com"` and `"[::1]:8080"` becomes `"[::1]"`. An empty
/// port (`"acme.com:"`) is stripped too. A bare IPv6 literal such as `"::1"` is returned
/// unchanged: without brackets, its last group is not a port.
fn strip_port(host: &str) -> &str {
    let Some((name, port)) = host.rsplit_once(':') else {
        return host;
    };
    if !port.bytes().all(|b| b.is_ascii_digit()) {
        return host;
    }
    // A name that still contains ':' is only host-plus-port when it is a bracketed IPv6
    // literal; otherwise the whole string is an (unbracketed) IPv6 address.
    let bracketed = name.starts_with('[') && name.ends_with(']');
    if name.contains(':') && !bracketed {
        return host;
    }
    name
}
379
#[cfg(test)]
mod tests {
    // Unit tests for host resolution, host matching, registration panics, the
    // `MatchedHost` extractors, and end-to-end dispatch through `axum::Router`.

    use super::*;

    /// Builds request `Parts` carrying the given `(name, value)` headers.
    fn parts_with_headers(headers: &[(&str, &str)]) -> Parts {
        let mut builder = http::Request::builder();
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        let (parts, _) = builder.body(()).unwrap().into_parts();
        parts
    }

    // ---- resolve_host: header precedence and normalization ----

    #[test]
    fn resolve_from_host_header() {
        let parts = parts_with_headers(&[("host", "acme.com")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_strips_port() {
        let parts = parts_with_headers(&[("host", "acme.com:8080")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_lowercases() {
        let parts = parts_with_headers(&[("host", "ACME.COM")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_x_forwarded_host_over_host() {
        let parts =
            parts_with_headers(&[("host", "proxy.internal"), ("x-forwarded-host", "acme.com")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_forwarded_over_x_forwarded_host() {
        let parts = parts_with_headers(&[
            ("host", "proxy.internal"),
            ("x-forwarded-host", "xfh.com"),
            ("forwarded", "for=1.2.3.4; host=acme.com; proto=https"),
        ]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_forwarded_strips_port() {
        let parts = parts_with_headers(&[("forwarded", "host=acme.com:443")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_x_forwarded_host_strips_port() {
        let parts = parts_with_headers(&[("x-forwarded-host", "acme.com:8080")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_missing_all_headers_returns_400() {
        let parts = parts_with_headers(&[]);
        let err = resolve_host(&parts).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn resolve_forwarded_case_insensitive_host_directive() {
        let parts = parts_with_headers(&[("forwarded", "for=1.2.3.4; Host=acme.com")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_forwarded_without_host_falls_through() {
        let parts = parts_with_headers(&[
            ("forwarded", "for=1.2.3.4; proto=https"),
            ("host", "fallback.com"),
        ]);
        assert_eq!(resolve_host(&parts).unwrap(), "fallback.com");
    }

    /// Router answering `GET /` with the fixed `body`, so tests can tell routers apart.
    fn router_with_body(body: &'static str) -> Router {
        Router::new().route("/", axum::routing::get(move || async move { body }))
    }

    // ---- match_host: precedence between exact, wildcard and fallback ----

    #[test]
    fn match_exact() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));
        assert!(matches!(hr.inner.match_host("acme.com"), Match::Exact(_)));
    }

    #[test]
    fn match_wildcard() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));
        match hr.inner.match_host("tenant1.acme.com") {
            Match::Wildcard {
                subdomain, pattern, ..
            } => {
                assert_eq!(subdomain, "tenant1");
                assert_eq!(pattern, "*.acme.com");
            }
            other => panic!("expected Wildcard, got {}", match_name(&other)),
        }
    }

    #[test]
    fn exact_wins_over_wildcard() {
        let hr = HostRouter::new()
            .host("app.acme.com", router_with_body("admin"))
            .host("*.acme.com", router_with_body("tenant"));
        assert!(matches!(
            hr.inner.match_host("app.acme.com"),
            Match::Exact(_)
        ));
    }

    #[test]
    fn bare_domain_does_not_match_wildcard() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));
        assert!(matches!(hr.inner.match_host("acme.com"), Match::NotFound));
    }

    #[test]
    fn multi_level_subdomain_does_not_match_wildcard() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));
        // Only one leading label is peeled off; the suffix "b.acme.com" is not registered.
        assert!(matches!(
            hr.inner.match_host("a.b.acme.com"),
            Match::NotFound
        ));
    }

    #[test]
    fn fallback_when_no_match() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .fallback(router_with_body("fallback"));
        assert!(matches!(
            hr.inner.match_host("other.com"),
            Match::Fallback(_)
        ));
    }

    #[test]
    fn not_found_when_no_match_and_no_fallback() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));
        assert!(matches!(hr.inner.match_host("other.com"), Match::NotFound));
    }

    /// Variant name for assertion messages (`Match` holds `&Router`, so it is not `Debug`).
    fn match_name(m: &Match<'_>) -> &'static str {
        match m {
            Match::Exact(_) => "Exact",
            Match::Wildcard { .. } => "Wildcard",
            Match::Fallback(_) => "Fallback",
            Match::NotFound => "NotFound",
        }
    }

    // ---- registration-time validation panics ----

    #[test]
    #[should_panic(expected = "duplicate exact host")]
    fn panic_on_duplicate_exact() {
        HostRouter::new()
            .host("acme.com", router_with_body("a"))
            .host("acme.com", router_with_body("b"));
    }

    #[test]
    #[should_panic(expected = "duplicate wildcard suffix")]
    fn panic_on_duplicate_wildcard() {
        HostRouter::new()
            .host("*.acme.com", router_with_body("a"))
            .host("*.acme.com", router_with_body("b"));
    }

    #[test]
    #[should_panic(expected = "suffix must contain at least one dot")]
    fn panic_on_tld_wildcard() {
        HostRouter::new().host("*.com", router_with_body("a"));
    }

    #[test]
    #[should_panic(expected = "invalid wildcard pattern")]
    fn panic_on_bare_star() {
        HostRouter::new().host("*", router_with_body("a"));
    }

    #[test]
    #[should_panic(expected = "empty suffix")]
    fn panic_on_star_dot_only() {
        HostRouter::new().host("*.", router_with_body("a"));
    }

    // ---- MatchedHost extractors (required and optional forms) ----

    #[tokio::test]
    async fn extract_matched_host_present() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();
        parts.extensions.insert(MatchedHost {
            subdomain: "tenant1".into(),
            pattern: "*.acme.com".into(),
        });

        let result =
            <MatchedHost as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        let matched = result.unwrap();
        assert_eq!(matched.subdomain, "tenant1");
        assert_eq!(matched.pattern, "*.acme.com");
    }

    #[tokio::test]
    async fn extract_matched_host_missing_returns_500() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();

        let result =
            <MatchedHost as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        let err = result.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(err.to_string(), "internal routing error");
    }

    #[tokio::test]
    async fn optional_matched_host_none_when_missing() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();

        let result =
            <MatchedHost as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn optional_matched_host_some_when_present() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();
        parts.extensions.insert(MatchedHost {
            subdomain: "t1".into(),
            pattern: "*.acme.com".into(),
        });

        let result =
            <MatchedHost as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await;
        let matched = result.unwrap().unwrap();
        assert_eq!(matched.subdomain, "t1");
        assert_eq!(matched.pattern, "*.acme.com");
    }

    // ---- end-to-end dispatch through the converted axum::Router ----

    use axum::body::Body;
    use http::{Request, StatusCode};
    use tower::ServiceExt;

    /// Collects a response body into a UTF-8 string.
    async fn response_body(resp: http::Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn dispatch_exact_match() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .host("app.acme.com", router_with_body("admin"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "landing");
    }

    #[tokio::test]
    async fn dispatch_wildcard_match() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "tenant1.acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "tenant");
    }

    #[tokio::test]
    async fn dispatch_wildcard_injects_matched_host() {
        let tenant_router = Router::new().route(
            "/",
            axum::routing::get(|matched: MatchedHost| async move {
                format!("{}:{}", matched.subdomain, matched.pattern)
            }),
        );

        let hr = HostRouter::new().host("*.acme.com", tenant_router);

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "tenant1.acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "tenant1:*.acme.com");
    }

    #[tokio::test]
    async fn dispatch_exact_wins_over_wildcard() {
        let hr = HostRouter::new()
            .host("app.acme.com", router_with_body("admin"))
            .host("*.acme.com", router_with_body("tenant"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "app.acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "admin");
    }

    #[tokio::test]
    async fn dispatch_fallback() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .fallback(router_with_body("fallback"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "unknown.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "fallback");
    }

    #[tokio::test]
    async fn dispatch_404_no_match_no_fallback() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "unknown.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn dispatch_400_missing_host() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let router: axum::Router = hr.into();
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn dispatch_via_x_forwarded_host() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "proxy.internal")
            .header("x-forwarded-host", "acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "landing");
    }
}