Skip to main content

modo/server/
host_router.rs

1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::fmt;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use axum::Router;
9use axum::body::Body;
10use axum::extract::{FromRequestParts, OptionalFromRequestParts};
11use axum::response::IntoResponse;
12use http::Request;
13use http::request::Parts;
14use tower::Service;
15
16use crate::Error;
17
18/// Routes requests to different axum [`Router`]s based on the `Host` header.
19///
20/// Supports exact host matches (`acme.com`, `app.acme.com`) and single-level
21/// wildcard subdomains (`*.acme.com`). Both use `HashMap` lookups for O(1)
22/// matching. The effective host is resolved from the `Forwarded` (RFC 7239),
23/// `X-Forwarded-Host`, or `Host` header, in that order; the value is
24/// lowercased and any trailing `:port` is stripped before matching.
25///
26/// `HostRouter` implements `Into<axum::Router>` and can therefore be passed
27/// directly to [`http()`](super::http()).
28///
29/// # Panics
30///
31/// The [`host`](Self::host) and [`fallback`](Self::fallback) methods panic if
32/// called after the `HostRouter` has been cloned or converted. Complete all
33/// route registration before passing the router to [`server::http()`](crate::server::http())
34/// or cloning it.
35///
36/// The [`host`](Self::host) method also panics on:
37/// - Empty host patterns
38/// - Duplicate exact host patterns
39/// - Duplicate wildcard suffixes
40/// - Malformed wildcard patterns (must be `*.suffix` where the suffix
41///   contains at least one dot; a bare `*` or `*.com` is rejected)
42///
43/// # Example
44///
45/// ```rust,no_run
46/// use modo::server::HostRouter;
47/// use axum::Router;
48///
49/// let app = HostRouter::new()
50///     .host("acme.com", Router::new())
51///     .host("app.acme.com", Router::new())
52///     .host("*.acme.com", Router::new())
53///     .fallback(Router::new());
54/// ```
55#[derive(Clone)]
56pub struct HostRouter {
57    inner: Arc<HostRouterInner>,
58}
59
60#[derive(Clone)]
61struct HostRouterInner {
62    exact: HashMap<String, Router>,
63    /// Key is the suffix (e.g. `"acme.com"`), value is `(pattern, router)` where
64    /// pattern is the full wildcard string (e.g. `"*.acme.com"`), preformatted at
65    /// registration time to avoid per-request allocation.
66    wildcard: HashMap<String, (String, Router)>,
67    fallback: Option<Router>,
68}
69
70enum Match<'a> {
71    Exact(&'a Router),
72    Wildcard {
73        router: &'a Router,
74        subdomain: String,
75        pattern: String,
76    },
77    Fallback(&'a Router),
78    NotFound,
79}
80
81impl fmt::Debug for HostRouter {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        f.debug_struct("HostRouter")
84            .field("exact_hosts", &self.inner.exact.keys().collect::<Vec<_>>())
85            .field(
86                "wildcard_hosts",
87                &self
88                    .inner
89                    .wildcard
90                    .keys()
91                    .map(|k| format!("*.{k}"))
92                    .collect::<Vec<_>>(),
93            )
94            .field("has_fallback", &self.inner.fallback.is_some())
95            .finish()
96    }
97}
98
99impl Default for HostRouter {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl HostRouter {
106    /// Create a new empty `HostRouter` with no registered hosts and no fallback.
107    pub fn new() -> Self {
108        Self {
109            inner: Arc::new(HostRouterInner {
110                exact: HashMap::new(),
111                wildcard: HashMap::new(),
112                fallback: None,
113            }),
114        }
115    }
116
117    /// Register a host pattern with a router.
118    ///
119    /// Exact patterns (e.g. `"acme.com"`, `"app.acme.com"`) match the host
120    /// literally. Wildcard patterns (e.g. `"*.acme.com"`) match any single
121    /// subdomain level. The pattern is trimmed, lowercased, and stripped of
122    /// any `:port` suffix before registration.
123    ///
124    /// # Panics
125    ///
126    /// - If `self` has already been cloned or converted to an [`axum::Router`].
127    /// - If the pattern is empty after trimming.
128    /// - If a bare `*` or a leading `*` not followed by `.` is supplied.
129    /// - If an exact host is registered twice.
130    /// - If a wildcard suffix is registered twice.
131    /// - If a wildcard suffix is empty or contains no dot (e.g. `"*.com"`).
132    pub fn host(mut self, pattern: &str, router: Router) -> Self {
133        let inner = Arc::get_mut(&mut self.inner).expect("HostRouter::host called after clone");
134        let pattern = strip_port(pattern.trim()).to_lowercase();
135
136        if let Some(suffix) = pattern.strip_prefix("*.") {
137            // Wildcard validation: suffix must be non-empty and contain at least one dot
138            assert!(
139                !suffix.is_empty(),
140                "invalid wildcard pattern \"{pattern}\": empty suffix"
141            );
142            assert!(
143                suffix.contains('.'),
144                "invalid wildcard pattern \"{pattern}\": suffix must contain at least one dot (e.g. \"*.example.com\")"
145            );
146            let full_pattern = format!("*.{suffix}");
147            let prev = inner
148                .wildcard
149                .insert(suffix.to_owned(), (full_pattern, router));
150            assert!(
151                prev.is_none(),
152                "duplicate wildcard suffix: \"*.{suffix}\" registered twice"
153            );
154        } else {
155            assert!(!pattern.is_empty(), "host pattern must not be empty");
156            assert!(
157                !pattern.starts_with('*'),
158                "invalid wildcard pattern \"{pattern}\": use \"*.domain.com\" format"
159            );
160            let prev = inner.exact.insert(pattern.clone(), router);
161            assert!(
162                prev.is_none(),
163                "duplicate exact host: \"{pattern}\" registered twice"
164            );
165        }
166
167        self
168    }
169
170    /// Set a fallback router for requests whose host doesn't match any pattern.
171    ///
172    /// If no fallback is set, unmatched hosts receive a 404 response.
173    ///
174    /// # Panics
175    ///
176    /// Panics if `self` has already been cloned or converted to an
177    /// [`axum::Router`].
178    pub fn fallback(mut self, router: Router) -> Self {
179        let inner = Arc::get_mut(&mut self.inner).expect("HostRouter::fallback called after clone");
180        inner.fallback = Some(router);
181        self
182    }
183}
184
185impl HostRouterInner {
186    fn match_host(&self, host: &str) -> Match<'_> {
187        if let Some(router) = self.exact.get(host) {
188            return Match::Exact(router);
189        }
190
191        if let Some(dot) = host.find('.') {
192            let subdomain = &host[..dot];
193            let suffix = &host[dot + 1..];
194            if let Some((pattern, router)) = self.wildcard.get(suffix) {
195                return Match::Wildcard {
196                    router,
197                    subdomain: subdomain.to_owned(),
198                    pattern: pattern.clone(),
199                };
200            }
201        }
202
203        match &self.fallback {
204            Some(router) => Match::Fallback(router),
205            None => Match::NotFound,
206        }
207    }
208}
209
/// Information about a wildcard host match.
///
/// Inserted into request extensions when a request matches a wildcard
/// pattern (e.g. `*.acme.com`). Not present for exact or fallback matches.
///
/// Use `Option<MatchedHost>` for handlers that serve both exact and wildcard
/// routes; the required [`OptionalFromRequestParts`] impl is provided.
///
/// Equality and hashing compare both fields, so values can be asserted on in
/// tests or used as map keys.
///
/// # Example
///
/// ```rust,no_run
/// use modo::server::MatchedHost;
///
/// async fn handler(matched: MatchedHost) -> String {
///     format!("subdomain: {}", matched.subdomain)
/// }
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MatchedHost {
    /// The subdomain that matched (e.g. `"tenant1"` from `"tenant1.acme.com"`).
    /// Lowercase, because hosts are lowercased before matching.
    pub subdomain: String,
    /// The wildcard pattern that matched (e.g. `"*.acme.com"`).
    pub pattern: String,
}
234
235impl<S> FromRequestParts<S> for MatchedHost
236where
237    S: Send + Sync,
238{
239    type Rejection = Error;
240
241    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
242        parts
243            .extensions
244            .get::<MatchedHost>()
245            .cloned()
246            .ok_or_else(|| Error::internal("internal routing error"))
247    }
248}
249
250impl<S> OptionalFromRequestParts<S> for MatchedHost
251where
252    S: Send + Sync,
253{
254    type Rejection = Error;
255
256    async fn from_request_parts(
257        parts: &mut Parts,
258        _state: &S,
259    ) -> Result<Option<Self>, Self::Rejection> {
260        Ok(parts.extensions.get::<MatchedHost>().cloned())
261    }
262}
263
264/// Newtype around `Arc<HostRouterInner>` so we can implement `tower::Service`
265/// without hitting orphan rules. Each `call()` does a cheap `Arc::clone`
266/// instead of cloning all `HashMap`s.
267#[derive(Clone)]
268struct HostRouterService(Arc<HostRouterInner>);
269
270impl Service<Request<Body>> for HostRouterService {
271    type Response = http::Response<Body>;
272    type Error = Infallible;
273    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
274
275    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
276        Poll::Ready(Ok(()))
277    }
278
279    fn call(&mut self, req: Request<Body>) -> Self::Future {
280        let inner = Arc::clone(&self.0);
281
282        Box::pin(async move {
283            let (mut parts, body) = req.into_parts();
284
285            let host = match resolve_host(&parts) {
286                Ok(h) => h,
287                Err(e) => return Ok(e.into_response()),
288            };
289
290            let router = match inner.match_host(&host) {
291                Match::Exact(router) | Match::Fallback(router) => router,
292                Match::Wildcard {
293                    router,
294                    subdomain,
295                    pattern,
296                } => {
297                    parts.extensions.insert(MatchedHost { subdomain, pattern });
298                    router
299                }
300                Match::NotFound => {
301                    return Ok(Error::not_found("no route for host").into_response());
302                }
303            };
304
305            let req = Request::from_parts(parts, body);
306            Ok(router.clone().call(req).await.into_response())
307        })
308    }
309}
310
311impl From<HostRouter> for axum::Router {
312    fn from(host_router: HostRouter) -> axum::Router {
313        axum::Router::new().fallback_service(HostRouterService(host_router.inner))
314    }
315}
316
317/// Resolve the effective host from a request, checking proxy headers first.
318///
319/// Checks in order:
320/// 1. `Forwarded` header (RFC 7239) — `host=` directive
321/// 2. `X-Forwarded-Host` header
322/// 3. `Host` header
323///
324/// After extraction the value is lowercased and any trailing port is stripped.
325fn resolve_host(parts: &Parts) -> Result<String, Error> {
326    const HOST_DIRECTIVE: &str = "host=";
327
328    if let Some(fwd) = parts.headers.get("forwarded")
329        && let Ok(fwd_str) = fwd.to_str()
330    {
331        // Comma-separated entries represent multiple hops; only the first is relevant.
332        let first_element = fwd_str.split(',').next().unwrap_or(fwd_str);
333        for directive in first_element.split(';') {
334            let directive = directive.trim();
335            // RFC 7239: directive names are case-insensitive
336            if directive.len() >= HOST_DIRECTIVE.len()
337                && directive[..HOST_DIRECTIVE.len()].eq_ignore_ascii_case(HOST_DIRECTIVE)
338            {
339                let host = directive[HOST_DIRECTIVE.len()..].trim();
340                if !host.is_empty() {
341                    return Ok(strip_port(host).to_lowercase());
342                }
343            }
344        }
345    }
346
347    if let Some(xfh) = parts.headers.get("x-forwarded-host")
348        && let Ok(host) = xfh.to_str()
349    {
350        let host = host.trim();
351        if !host.is_empty() {
352            return Ok(strip_port(host).to_lowercase());
353        }
354    }
355
356    if let Some(h) = parts.headers.get(http::header::HOST)
357        && let Ok(host) = h.to_str()
358    {
359        let host = host.trim();
360        if !host.is_empty() {
361            return Ok(strip_port(host).to_lowercase());
362        }
363    }
364
365    Err(Error::bad_request("missing or invalid Host header"))
366}
367
/// Strip an optional trailing `:port` suffix from a host string.
///
/// The suffix is removed only when everything after the last `:` is ASCII
/// digits; an empty port, as in `"acme.com:"`, is stripped too. Bracketed
/// IPv6 literals keep their brackets (`"[::1]:8080"` → `"[::1]"`).
///
/// If the part before the last `:` still contains a `:` and is not a
/// bracketed literal, the value is treated as a bare IPv6 address and
/// returned unchanged (`"::1"` stays `"::1"`). A port cannot be told apart
/// from the final address group there. RFC 7230 requires bracketed IPv6 in
/// `Host`, so valid headers never hit this case.
fn strip_port(host: &str) -> &str {
    let Some(colon) = host.rfind(':') else {
        return host;
    };
    // `:` is a single byte, so `colon + 1` is always a char boundary.
    let (name, port) = (&host[..colon], &host[colon + 1..]);
    let numeric_port = port.bytes().all(|b| b.is_ascii_digit());
    let bare_ipv6 = name.contains(':') && !name.ends_with(']');
    if numeric_port && !bare_ipv6 {
        name
    } else {
        host
    }
}
379
#[cfg(test)]
mod tests {
    // Sections: host resolution, matching, construction panics, the
    // `MatchedHost` extractor, and end-to-end dispatch through `axum::Router`.
    use super::*;

    fn parts_with_headers(headers: &[(&str, &str)]) -> Parts {
        let mut builder = http::Request::builder();
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        let (parts, _) = builder.body(()).unwrap().into_parts();
        parts
    }

    #[test]
    fn resolve_from_host_header() {
        let parts = parts_with_headers(&[("host", "acme.com")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_strips_port() {
        let parts = parts_with_headers(&[("host", "acme.com:8080")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_lowercases() {
        let parts = parts_with_headers(&[("host", "ACME.COM")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_x_forwarded_host_over_host() {
        let parts =
            parts_with_headers(&[("host", "proxy.internal"), ("x-forwarded-host", "acme.com")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_forwarded_over_x_forwarded_host() {
        let parts = parts_with_headers(&[
            ("host", "proxy.internal"),
            ("x-forwarded-host", "xfh.com"),
            ("forwarded", "for=1.2.3.4; host=acme.com; proto=https"),
        ]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_forwarded_strips_port() {
        let parts = parts_with_headers(&[("forwarded", "host=acme.com:443")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_x_forwarded_host_strips_port() {
        let parts = parts_with_headers(&[("x-forwarded-host", "acme.com:8080")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_missing_all_headers_returns_400() {
        let parts = parts_with_headers(&[]);
        let err = resolve_host(&parts).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn resolve_forwarded_case_insensitive_host_directive() {
        let parts = parts_with_headers(&[("forwarded", "for=1.2.3.4; Host=acme.com")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn resolve_forwarded_without_host_falls_through() {
        let parts = parts_with_headers(&[
            ("forwarded", "for=1.2.3.4; proto=https"),
            ("host", "fallback.com"),
        ]);
        assert_eq!(resolve_host(&parts).unwrap(), "fallback.com");
    }

    #[test]
    fn resolve_forwarded_uses_first_element() {
        let parts = parts_with_headers(&[("forwarded", "host=acme.com, host=proxy.internal")]);
        assert_eq!(resolve_host(&parts).unwrap(), "acme.com");
    }

    #[test]
    fn strip_port_keeps_bracketed_ipv6_and_non_numeric_suffix() {
        assert_eq!(strip_port("[::1]:8080"), "[::1]");
        assert_eq!(strip_port("[::1]"), "[::1]");
        assert_eq!(strip_port("acme.com"), "acme.com");
        assert_eq!(strip_port("acme.com:https"), "acme.com:https");
    }

    // ── Matching ──────────────────────────────────────────────

    fn router_with_body(body: &'static str) -> Router {
        Router::new().route("/", axum::routing::get(move || async move { body }))
    }

    #[test]
    fn match_exact() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));
        assert!(matches!(hr.inner.match_host("acme.com"), Match::Exact(_)));
    }

    #[test]
    fn host_pattern_is_trimmed_lowercased_and_port_stripped() {
        let hr = HostRouter::new().host("  ACME.com:8080 ", router_with_body("landing"));
        assert!(matches!(hr.inner.match_host("acme.com"), Match::Exact(_)));
    }

    #[test]
    fn match_wildcard() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));
        match hr.inner.match_host("tenant1.acme.com") {
            Match::Wildcard {
                subdomain, pattern, ..
            } => {
                assert_eq!(subdomain, "tenant1");
                assert_eq!(pattern, "*.acme.com");
            }
            other => panic!("expected Wildcard, got {}", match_name(&other)),
        }
    }

    #[test]
    fn exact_wins_over_wildcard() {
        let hr = HostRouter::new()
            .host("app.acme.com", router_with_body("admin"))
            .host("*.acme.com", router_with_body("tenant"));
        assert!(matches!(
            hr.inner.match_host("app.acme.com"),
            Match::Exact(_)
        ));
    }

    #[test]
    fn bare_domain_does_not_match_wildcard() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));
        assert!(matches!(hr.inner.match_host("acme.com"), Match::NotFound));
    }

    #[test]
    fn multi_level_subdomain_does_not_match_wildcard() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));
        // "a.b.acme.com" splits to subdomain="a", suffix="b.acme.com" — not in wildcard map
        assert!(matches!(
            hr.inner.match_host("a.b.acme.com"),
            Match::NotFound
        ));
    }

    #[test]
    fn fallback_when_no_match() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .fallback(router_with_body("fallback"));
        assert!(matches!(
            hr.inner.match_host("other.com"),
            Match::Fallback(_)
        ));
    }

    #[test]
    fn not_found_when_no_match_and_no_fallback() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));
        assert!(matches!(hr.inner.match_host("other.com"), Match::NotFound));
    }

    fn match_name(m: &Match<'_>) -> &'static str {
        match m {
            Match::Exact(_) => "Exact",
            Match::Wildcard { .. } => "Wildcard",
            Match::Fallback(_) => "Fallback",
            Match::NotFound => "NotFound",
        }
    }

    #[test]
    fn debug_lists_registered_hosts() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .host("*.acme.com", router_with_body("tenant"));
        let debug = format!("{hr:?}");
        assert!(debug.contains("\"acme.com\""));
        assert!(debug.contains("\"*.acme.com\""));
        assert!(debug.contains("has_fallback: false"));
    }

    // ── Construction panics ───────────────────────────────────

    #[test]
    #[should_panic(expected = "duplicate exact host")]
    fn panic_on_duplicate_exact() {
        HostRouter::new()
            .host("acme.com", router_with_body("a"))
            .host("acme.com", router_with_body("b"));
    }

    #[test]
    #[should_panic(expected = "duplicate wildcard suffix")]
    fn panic_on_duplicate_wildcard() {
        HostRouter::new()
            .host("*.acme.com", router_with_body("a"))
            .host("*.acme.com", router_with_body("b"));
    }

    #[test]
    #[should_panic(expected = "suffix must contain at least one dot")]
    fn panic_on_tld_wildcard() {
        HostRouter::new().host("*.com", router_with_body("a"));
    }

    #[test]
    #[should_panic(expected = "invalid wildcard pattern")]
    fn panic_on_bare_star() {
        HostRouter::new().host("*", router_with_body("a"));
    }

    #[test]
    #[should_panic(expected = "empty suffix")]
    fn panic_on_star_dot_only() {
        HostRouter::new().host("*.", router_with_body("a"));
    }

    // ── MatchedHost extractor ─────────────────────────────────

    #[tokio::test]
    async fn extract_matched_host_present() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();
        parts.extensions.insert(MatchedHost {
            subdomain: "tenant1".into(),
            pattern: "*.acme.com".into(),
        });

        let result =
            <MatchedHost as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        let matched = result.unwrap();
        assert_eq!(matched.subdomain, "tenant1");
        assert_eq!(matched.pattern, "*.acme.com");
    }

    #[tokio::test]
    async fn extract_matched_host_missing_returns_500() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();

        let result =
            <MatchedHost as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        let err = result.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(err.to_string(), "internal routing error");
    }

    #[tokio::test]
    async fn optional_matched_host_none_when_missing() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();

        let result =
            <MatchedHost as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn optional_matched_host_some_when_present() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();
        parts.extensions.insert(MatchedHost {
            subdomain: "t1".into(),
            pattern: "*.acme.com".into(),
        });

        let result =
            <MatchedHost as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await;
        let matched = result.unwrap().unwrap();
        assert_eq!(matched.subdomain, "t1");
        assert_eq!(matched.pattern, "*.acme.com");
    }

    // ── Full dispatch tests ──────────────────────────────────

    use axum::body::Body;
    use http::{Request, StatusCode};
    use tower::ServiceExt;

    async fn response_body(resp: http::Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn dispatch_exact_match() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .host("app.acme.com", router_with_body("admin"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "landing");
    }

    #[tokio::test]
    async fn dispatch_wildcard_match() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "tenant1.acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "tenant");
    }

    #[tokio::test]
    async fn dispatch_wildcard_injects_matched_host() {
        let tenant_router = Router::new().route(
            "/",
            axum::routing::get(|matched: MatchedHost| async move {
                format!("{}:{}", matched.subdomain, matched.pattern)
            }),
        );

        let hr = HostRouter::new().host("*.acme.com", tenant_router);

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "tenant1.acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "tenant1:*.acme.com");
    }

    #[tokio::test]
    async fn dispatch_exact_has_no_matched_host() {
        let landing = Router::new().route(
            "/",
            axum::routing::get(|matched: Option<MatchedHost>| async move {
                matched.map_or_else(|| "none".to_owned(), |m| m.subdomain)
            }),
        );

        let hr = HostRouter::new().host("acme.com", landing);

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "none");
    }

    #[tokio::test]
    async fn dispatch_exact_wins_over_wildcard() {
        let hr = HostRouter::new()
            .host("app.acme.com", router_with_body("admin"))
            .host("*.acme.com", router_with_body("tenant"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "app.acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "admin");
    }

    #[tokio::test]
    async fn dispatch_fallback() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .fallback(router_with_body("fallback"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "unknown.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "fallback");
    }

    #[tokio::test]
    async fn dispatch_404_no_match_no_fallback() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "unknown.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn dispatch_400_missing_host() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let router: axum::Router = hr.into();
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn dispatch_via_x_forwarded_host() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "proxy.internal")
            .header("x-forwarded-host", "acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "landing");
    }

    #[tokio::test]
    async fn dispatch_via_forwarded() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let router: axum::Router = hr.into();
        let req = Request::builder()
            .uri("/")
            .header("host", "proxy.internal")
            .header("forwarded", "for=1.2.3.4;host=acme.com")
            .body(Body::empty())
            .unwrap();

        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "landing");
    }
}