Skip to main content

modo/server/
host_router.rs

1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::fmt;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use axum::Router;
9use axum::body::Body;
10use axum::extract::{FromRequestParts, OptionalFromRequestParts};
11use axum::response::IntoResponse;
12use http::Request;
13use http::request::Parts;
14use tower::Service;
15
16use crate::Error;
17
/// Routes requests to different axum [`Router`]s based on the request host.
///
/// Supports exact host matches (`acme.com`, `app.acme.com`) and single-level
/// wildcard subdomains (`*.acme.com`). Both use `HashMap` lookups, so matching
/// cost does not grow with the number of registered hosts. An exact match
/// always wins over a wildcard match. A wildcard only matches one extra label:
/// `*.acme.com` matches `tenant.acme.com` but neither `acme.com` nor
/// `a.b.acme.com`.
///
/// # Host resolution
///
/// The host is taken from the first available source:
/// 1. the `host=` parameter of the `Forwarded` header,
/// 2. `X-Forwarded-Host`,
/// 3. `Host`.
///
/// The value is lowercased and stripped of its port. Because proxy headers are
/// trusted as-is, deploy behind a proxy that sets or overwrites them. Requests
/// without any usable host get a 400 response. Unmatched hosts go to the
/// fallback router, or get a 404 response if none is set.
///
/// # Panics
///
/// The [`host`](Self::host) and [`fallback`](Self::fallback) methods panic if
/// called after the `HostRouter` has been cloned, because they need exclusive
/// ownership of the shared routing tables. Complete all route registration
/// before cloning the router or passing it to
/// [`server::http()`](crate::server::http()).
///
/// The [`host`](Self::host) method also panics on:
/// - Empty host patterns
/// - Duplicate exact host patterns
/// - Duplicate wildcard suffixes
/// - Invalid wildcard patterns (e.g. a bare `*`, an empty suffix, or a suffix
///   without at least one dot)
///
/// # Example
///
/// ```rust,no_run
/// use modo::server::HostRouter;
/// use axum::Router;
///
/// let app = HostRouter::new()
///     .host("acme.com", Router::new())
///     .host("app.acme.com", Router::new())
///     .host("*.acme.com", Router::new())
///     .fallback(Router::new());
/// ```
#[derive(Clone)]
pub struct HostRouter {
    // Shared routing tables. Cloning a `HostRouter` only bumps this refcount.
    inner: Arc<HostRouterInner>,
}
52
/// Routing tables behind [`HostRouter`].
///
/// They are mutated only through `Arc::get_mut` in the builder methods, and are
/// read-only once the router has been shared.
#[derive(Clone)]
struct HostRouterInner {
    /// Exact host -> router. Keys are trimmed, port-stripped and lowercased at
    /// registration time.
    exact: HashMap<String, Router>,
    /// Key is the suffix (e.g. `"acme.com"`), value is `(pattern, router)` where
    /// pattern is the full wildcard string (e.g. `"*.acme.com"`), preformatted at
    /// registration time to avoid per-request allocation.
    wildcard: HashMap<String, (String, Router)>,
    /// Router for hosts matching neither table; `None` means respond with 404.
    fallback: Option<Router>,
}
62
/// Result of looking a host up in [`HostRouterInner::match_host`].
enum Match<'a> {
    /// The host equals a registered exact pattern.
    Exact(&'a Router),
    /// The host is `<subdomain>.<suffix>` for a registered `*.<suffix>` pattern.
    Wildcard {
        router: &'a Router,
        /// First label of the host (e.g. `"tenant1"` for `"tenant1.acme.com"`).
        subdomain: String,
        /// The registered wildcard pattern (e.g. `"*.acme.com"`).
        pattern: String,
    },
    /// Nothing matched and a fallback router is configured.
    Fallback(&'a Router),
    /// Nothing matched and there is no fallback.
    NotFound,
}
73
74impl fmt::Debug for HostRouter {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        f.debug_struct("HostRouter")
77            .field("exact_hosts", &self.inner.exact.keys().collect::<Vec<_>>())
78            .field(
79                "wildcard_hosts",
80                &self
81                    .inner
82                    .wildcard
83                    .keys()
84                    .map(|k| format!("*.{k}"))
85                    .collect::<Vec<_>>(),
86            )
87            .field("has_fallback", &self.inner.fallback.is_some())
88            .finish()
89    }
90}
91
92impl Default for HostRouter {
93    fn default() -> Self {
94        Self::new()
95    }
96}
97
98impl HostRouter {
99    /// Create a new empty `HostRouter` with no registered hosts and no fallback.
100    pub fn new() -> Self {
101        Self {
102            inner: Arc::new(HostRouterInner {
103                exact: HashMap::new(),
104                wildcard: HashMap::new(),
105                fallback: None,
106            }),
107        }
108    }
109
110    /// Register a host pattern with a router.
111    ///
112    /// Exact patterns (e.g. `"acme.com"`, `"app.acme.com"`) match the host
113    /// literally. Wildcard patterns (e.g. `"*.acme.com"`) match any single
114    /// subdomain level.
115    ///
116    /// # Panics
117    ///
118    /// - If an exact host is registered twice.
119    /// - If a wildcard suffix is registered twice.
120    /// - If a wildcard suffix contains no dot (e.g. `"*.com"`).
121    pub fn host(mut self, pattern: &str, router: Router) -> Self {
122        let inner = Arc::get_mut(&mut self.inner).expect("HostRouter::host called after clone");
123        let pattern = strip_port(pattern.trim()).to_lowercase();
124
125        if let Some(suffix) = pattern.strip_prefix("*.") {
126            // Wildcard validation: suffix must be non-empty and contain at least one dot
127            assert!(
128                !suffix.is_empty(),
129                "invalid wildcard pattern \"{pattern}\": empty suffix"
130            );
131            assert!(
132                suffix.contains('.'),
133                "invalid wildcard pattern \"{pattern}\": suffix must contain at least one dot (e.g. \"*.example.com\")"
134            );
135            let full_pattern = format!("*.{suffix}");
136            let prev = inner
137                .wildcard
138                .insert(suffix.to_owned(), (full_pattern, router));
139            assert!(
140                prev.is_none(),
141                "duplicate wildcard suffix: \"*.{suffix}\" registered twice"
142            );
143        } else {
144            assert!(!pattern.is_empty(), "host pattern must not be empty");
145            assert!(
146                !pattern.starts_with('*'),
147                "invalid wildcard pattern \"{pattern}\": use \"*.domain.com\" format"
148            );
149            let prev = inner.exact.insert(pattern.clone(), router);
150            assert!(
151                prev.is_none(),
152                "duplicate exact host: \"{pattern}\" registered twice"
153            );
154        }
155
156        self
157    }
158
159    /// Set a fallback router for requests whose host doesn't match any pattern.
160    ///
161    /// If no fallback is set, unmatched hosts receive a 404 response.
162    pub fn fallback(mut self, router: Router) -> Self {
163        let inner = Arc::get_mut(&mut self.inner).expect("HostRouter::fallback called after clone");
164        inner.fallback = Some(router);
165        self
166    }
167}
168
169impl HostRouterInner {
170    fn match_host(&self, host: &str) -> Match<'_> {
171        if let Some(router) = self.exact.get(host) {
172            return Match::Exact(router);
173        }
174
175        if let Some(dot) = host.find('.') {
176            let subdomain = &host[..dot];
177            let suffix = &host[dot + 1..];
178            if let Some((pattern, router)) = self.wildcard.get(suffix) {
179                return Match::Wildcard {
180                    router,
181                    subdomain: subdomain.to_owned(),
182                    pattern: pattern.clone(),
183                };
184            }
185        }
186
187        match &self.fallback {
188            Some(router) => Match::Fallback(router),
189            None => Match::NotFound,
190        }
191    }
192}
193
/// Information about a wildcard host match.
///
/// The host router inserts this into the request extensions when a request
/// matches a wildcard pattern (e.g. `*.acme.com`). It is not present for exact
/// or fallback matches.
///
/// Use `Option<MatchedHost>` for handlers that serve both exact and wildcard
/// routes. Extracting a plain `MatchedHost` when it is absent produces a 500
/// response.
///
/// # Example
///
/// ```rust,ignore
/// async fn handler(matched: MatchedHost) -> impl IntoResponse {
///     format!("subdomain: {}", matched.subdomain)
/// }
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MatchedHost {
    /// The subdomain that matched (e.g. `"tenant1"` from `"tenant1.acme.com"`).
    pub subdomain: String,
    /// The wildcard pattern that matched (e.g. `"*.acme.com"`).
    pub pattern: String,
}
216
217impl<S> FromRequestParts<S> for MatchedHost
218where
219    S: Send + Sync,
220{
221    type Rejection = Error;
222
223    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
224        parts
225            .extensions
226            .get::<MatchedHost>()
227            .cloned()
228            .ok_or_else(|| Error::internal("internal routing error"))
229    }
230}
231
232impl<S> OptionalFromRequestParts<S> for MatchedHost
233where
234    S: Send + Sync,
235{
236    type Rejection = Error;
237
238    async fn from_request_parts(
239        parts: &mut Parts,
240        _state: &S,
241    ) -> Result<Option<Self>, Self::Rejection> {
242        Ok(parts.extensions.get::<MatchedHost>().cloned())
243    }
244}
245
/// Newtype around `Arc<HostRouterInner>` so we can implement `tower::Service`
/// without hitting orphan rules. Each `call()` does a cheap `Arc::clone`
/// instead of cloning all `HashMap`s.
///
/// Deriving `Clone` likewise only clones the `Arc`.
#[derive(Clone)]
struct HostRouterService(Arc<HostRouterInner>);
251
252impl Service<Request<Body>> for HostRouterService {
253    type Response = http::Response<Body>;
254    type Error = Infallible;
255    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
256
257    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
258        Poll::Ready(Ok(()))
259    }
260
261    fn call(&mut self, req: Request<Body>) -> Self::Future {
262        let inner = Arc::clone(&self.0);
263
264        Box::pin(async move {
265            let (mut parts, body) = req.into_parts();
266
267            let host = match resolve_host(&parts) {
268                Ok(h) => h,
269                Err(e) => return Ok(e.into_response()),
270            };
271
272            let router = match inner.match_host(&host) {
273                Match::Exact(router) | Match::Fallback(router) => router,
274                Match::Wildcard {
275                    router,
276                    subdomain,
277                    pattern,
278                } => {
279                    parts.extensions.insert(MatchedHost { subdomain, pattern });
280                    router
281                }
282                Match::NotFound => {
283                    return Ok(Error::not_found("no route for host").into_response());
284                }
285            };
286
287            let req = Request::from_parts(parts, body);
288            Ok(router.clone().call(req).await.into_response())
289        })
290    }
291}
292
293impl From<HostRouter> for axum::Router {
294    fn from(host_router: HostRouter) -> axum::Router {
295        axum::Router::new().fallback_service(HostRouterService(host_router.inner))
296    }
297}
298
299/// Resolve the effective host from a request, checking proxy headers first.
300///
301/// Checks in order:
302/// 1. `Forwarded` header (RFC 7239) — `host=` directive
303/// 2. `X-Forwarded-Host` header
304/// 3. `Host` header
305///
306/// After extraction the value is lowercased and any trailing port is stripped.
307fn resolve_host(parts: &Parts) -> Result<String, Error> {
308    const HOST_DIRECTIVE: &str = "host=";
309
310    if let Some(fwd) = parts.headers.get("forwarded")
311        && let Ok(fwd_str) = fwd.to_str()
312    {
313        // Comma-separated entries represent multiple hops; only the first is relevant.
314        let first_element = fwd_str.split(',').next().unwrap_or(fwd_str);
315        for directive in first_element.split(';') {
316            let directive = directive.trim();
317            // RFC 7239: directive names are case-insensitive
318            if directive.len() >= HOST_DIRECTIVE.len()
319                && directive[..HOST_DIRECTIVE.len()].eq_ignore_ascii_case(HOST_DIRECTIVE)
320            {
321                let host = directive[HOST_DIRECTIVE.len()..].trim();
322                if !host.is_empty() {
323                    return Ok(strip_port(host).to_lowercase());
324                }
325            }
326        }
327    }
328
329    if let Some(xfh) = parts.headers.get("x-forwarded-host")
330        && let Ok(host) = xfh.to_str()
331    {
332        let host = host.trim();
333        if !host.is_empty() {
334            return Ok(strip_port(host).to_lowercase());
335        }
336    }
337
338    if let Some(h) = parts.headers.get(http::header::HOST)
339        && let Ok(host) = h.to_str()
340    {
341        let host = host.trim();
342        if !host.is_empty() {
343            return Ok(strip_port(host).to_lowercase());
344        }
345    }
346
347    Err(Error::bad_request("missing or invalid Host header"))
348}
349
/// Strip an optional `:port` suffix from a host string.
///
/// Bracketed IPv6 literals (RFC 7230 form) are handled: `[::1]:8080` becomes
/// `[::1]`. A bare, unbracketed IPv6 address such as `::1` or `fe80::1` is
/// returned unchanged instead of being truncated. Its last segment cannot be
/// told apart from a port. Such values are invalid in `Host`, but may appear
/// in misconfigured proxy headers.
///
/// Only an all-digit (possibly empty) port is stripped. For example,
/// `acme.com:` becomes `acme.com`, while `acme.com:http` is left as is.
fn strip_port(host: &str) -> &str {
    match host.rsplit_once(':') {
        Some((name, port)) if port.bytes().all(|b| b.is_ascii_digit()) => {
            // A colon left in `name` outside brackets means `host` is a bare IPv6
            // address, so the trailing digits belong to the address.
            if name.contains(':') && !name.ends_with(']') {
                host
            } else {
                name
            }
        }
        _ => host,
    }
}
361
#[cfg(test)]
mod tests {
    use super::*;

    use axum::body::Body;
    use http::{Request, StatusCode};
    use tower::ServiceExt;

    // ── Helpers ───────────────────────────────────────────────

    /// Request parts carrying exactly `headers`, built with an empty `()` body.
    fn parts_with_headers(headers: &[(&str, &str)]) -> Parts {
        let builder = headers
            .iter()
            .fold(http::Request::builder(), |b, &(name, value)| b.header(name, value));
        builder.body(()).unwrap().into_parts().0
    }

    /// Resolve the host for `headers`, panicking if resolution fails.
    fn resolved(headers: &[(&str, &str)]) -> String {
        resolve_host(&parts_with_headers(headers)).unwrap()
    }

    /// A router answering `GET /` with `body`, so responses identify their router.
    fn router_with_body(body: &'static str) -> Router {
        Router::new().route("/", axum::routing::get(move || async move { body }))
    }

    /// Variant name of a [`Match`], for readable assertion failures.
    fn match_name(m: &Match<'_>) -> &'static str {
        match m {
            Match::Exact(_) => "Exact",
            Match::Wildcard { .. } => "Wildcard",
            Match::Fallback(_) => "Fallback",
            Match::NotFound => "NotFound",
        }
    }

    /// A `MatchedHost` for `subdomain` under the `*.acme.com` pattern.
    fn acme_match(subdomain: &str) -> MatchedHost {
        MatchedHost {
            subdomain: subdomain.into(),
            pattern: "*.acme.com".into(),
        }
    }

    /// Convert `hr` into an axum router and send it `GET /` with `headers`.
    async fn send(hr: HostRouter, headers: &[(&str, &str)]) -> http::Response<Body> {
        let request = headers
            .iter()
            .fold(Request::builder().uri("/"), |b, &(name, value)| b.header(name, value))
            .body(Body::empty())
            .unwrap();
        let router: axum::Router = hr.into();
        router.oneshot(request).await.unwrap()
    }

    /// Collect a response body into a UTF-8 string.
    async fn response_body(resp: http::Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Send `GET /` with the given `Host` and return the status and body text.
    async fn get_with_host(hr: HostRouter, host: &str) -> (StatusCode, String) {
        let resp = send(hr, &[("host", host)]).await;
        let status = resp.status();
        (status, response_body(resp).await)
    }

    // ── Host resolution ───────────────────────────────────────

    #[test]
    fn resolve_from_host_header() {
        assert_eq!(resolved(&[("host", "acme.com")]), "acme.com");
    }

    #[test]
    fn resolve_strips_port() {
        assert_eq!(resolved(&[("host", "acme.com:8080")]), "acme.com");
    }

    #[test]
    fn resolve_lowercases() {
        assert_eq!(resolved(&[("host", "ACME.COM")]), "acme.com");
    }

    #[test]
    fn resolve_x_forwarded_host_over_host() {
        let headers = [("host", "proxy.internal"), ("x-forwarded-host", "acme.com")];
        assert_eq!(resolved(&headers), "acme.com");
    }

    #[test]
    fn resolve_forwarded_over_x_forwarded_host() {
        let headers = [
            ("host", "proxy.internal"),
            ("x-forwarded-host", "xfh.com"),
            ("forwarded", "for=1.2.3.4; host=acme.com; proto=https"),
        ];
        assert_eq!(resolved(&headers), "acme.com");
    }

    #[test]
    fn resolve_forwarded_strips_port() {
        assert_eq!(resolved(&[("forwarded", "host=acme.com:443")]), "acme.com");
    }

    #[test]
    fn resolve_x_forwarded_host_strips_port() {
        assert_eq!(resolved(&[("x-forwarded-host", "acme.com:8080")]), "acme.com");
    }

    #[test]
    fn resolve_missing_all_headers_returns_400() {
        let err = resolve_host(&parts_with_headers(&[])).unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn resolve_forwarded_case_insensitive_host_directive() {
        assert_eq!(
            resolved(&[("forwarded", "for=1.2.3.4; Host=acme.com")]),
            "acme.com"
        );
    }

    #[test]
    fn resolve_forwarded_without_host_falls_through() {
        let headers = [
            ("forwarded", "for=1.2.3.4; proto=https"),
            ("host", "fallback.com"),
        ];
        assert_eq!(resolved(&headers), "fallback.com");
    }

    // ── Matching ──────────────────────────────────────────────

    #[test]
    fn match_exact() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));
        let m = hr.inner.match_host("acme.com");
        assert!(matches!(m, Match::Exact(_)), "expected Exact, got {}", match_name(&m));
    }

    #[test]
    fn match_wildcard() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));
        let Match::Wildcard {
            subdomain, pattern, ..
        } = hr.inner.match_host("tenant1.acme.com")
        else {
            panic!("expected Wildcard");
        };
        assert_eq!(subdomain, "tenant1");
        assert_eq!(pattern, "*.acme.com");
    }

    #[test]
    fn exact_wins_over_wildcard() {
        let hr = HostRouter::new()
            .host("app.acme.com", router_with_body("admin"))
            .host("*.acme.com", router_with_body("tenant"));
        let m = hr.inner.match_host("app.acme.com");
        assert!(matches!(m, Match::Exact(_)), "expected Exact, got {}", match_name(&m));
    }

    #[test]
    fn bare_domain_does_not_match_wildcard() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));
        let m = hr.inner.match_host("acme.com");
        assert!(matches!(m, Match::NotFound), "expected NotFound, got {}", match_name(&m));
    }

    #[test]
    fn multi_level_subdomain_does_not_match_wildcard() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));
        // "a.b.acme.com" splits to subdomain="a", suffix="b.acme.com" — not in wildcard map
        let m = hr.inner.match_host("a.b.acme.com");
        assert!(matches!(m, Match::NotFound), "expected NotFound, got {}", match_name(&m));
    }

    #[test]
    fn fallback_when_no_match() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .fallback(router_with_body("fallback"));
        let m = hr.inner.match_host("other.com");
        assert!(matches!(m, Match::Fallback(_)), "expected Fallback, got {}", match_name(&m));
    }

    #[test]
    fn not_found_when_no_match_and_no_fallback() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));
        let m = hr.inner.match_host("other.com");
        assert!(matches!(m, Match::NotFound), "expected NotFound, got {}", match_name(&m));
    }

    // ── Construction panics ───────────────────────────────────

    #[test]
    #[should_panic(expected = "duplicate exact host")]
    fn panic_on_duplicate_exact() {
        let _ = HostRouter::new()
            .host("acme.com", router_with_body("a"))
            .host("acme.com", router_with_body("b"));
    }

    #[test]
    #[should_panic(expected = "duplicate wildcard suffix")]
    fn panic_on_duplicate_wildcard() {
        let _ = HostRouter::new()
            .host("*.acme.com", router_with_body("a"))
            .host("*.acme.com", router_with_body("b"));
    }

    #[test]
    #[should_panic(expected = "suffix must contain at least one dot")]
    fn panic_on_tld_wildcard() {
        let _ = HostRouter::new().host("*.com", router_with_body("a"));
    }

    #[test]
    #[should_panic(expected = "invalid wildcard pattern")]
    fn panic_on_bare_star() {
        let _ = HostRouter::new().host("*", router_with_body("a"));
    }

    #[test]
    #[should_panic(expected = "empty suffix")]
    fn panic_on_star_dot_only() {
        let _ = HostRouter::new().host("*.", router_with_body("a"));
    }

    // ── MatchedHost extractor ─────────────────────────────────

    #[tokio::test]
    async fn extract_matched_host_present() {
        let mut parts = parts_with_headers(&[]);
        parts.extensions.insert(acme_match("tenant1"));

        let matched = <MatchedHost as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(matched.subdomain, "tenant1");
        assert_eq!(matched.pattern, "*.acme.com");
    }

    #[tokio::test]
    async fn extract_matched_host_missing_returns_500() {
        let mut parts = parts_with_headers(&[]);

        let err = <MatchedHost as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .unwrap_err();
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(err.to_string(), "internal routing error");
    }

    #[tokio::test]
    async fn optional_matched_host_none_when_missing() {
        let mut parts = parts_with_headers(&[]);

        let outcome =
            <MatchedHost as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await;
        assert!(outcome.unwrap().is_none());
    }

    #[tokio::test]
    async fn optional_matched_host_some_when_present() {
        let mut parts = parts_with_headers(&[]);
        parts.extensions.insert(acme_match("t1"));

        let matched =
            <MatchedHost as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .unwrap()
                .expect("extension was inserted");
        assert_eq!(matched.subdomain, "t1");
        assert_eq!(matched.pattern, "*.acme.com");
    }

    // ── Full dispatch tests ──────────────────────────────────

    #[tokio::test]
    async fn dispatch_exact_match() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .host("app.acme.com", router_with_body("admin"));

        let (status, body) = get_with_host(hr, "acme.com").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "landing");
    }

    #[tokio::test]
    async fn dispatch_wildcard_match() {
        let hr = HostRouter::new().host("*.acme.com", router_with_body("tenant"));

        let (status, body) = get_with_host(hr, "tenant1.acme.com").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "tenant");
    }

    #[tokio::test]
    async fn dispatch_wildcard_injects_matched_host() {
        let echo_match = axum::routing::get(|matched: MatchedHost| async move {
            format!("{}:{}", matched.subdomain, matched.pattern)
        });
        let hr = HostRouter::new().host("*.acme.com", Router::new().route("/", echo_match));

        let (status, body) = get_with_host(hr, "tenant1.acme.com").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "tenant1:*.acme.com");
    }

    #[tokio::test]
    async fn dispatch_exact_wins_over_wildcard() {
        let hr = HostRouter::new()
            .host("app.acme.com", router_with_body("admin"))
            .host("*.acme.com", router_with_body("tenant"));

        let (status, body) = get_with_host(hr, "app.acme.com").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "admin");
    }

    #[tokio::test]
    async fn dispatch_fallback() {
        let hr = HostRouter::new()
            .host("acme.com", router_with_body("landing"))
            .fallback(router_with_body("fallback"));

        let (status, body) = get_with_host(hr, "unknown.com").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "fallback");
    }

    #[tokio::test]
    async fn dispatch_404_no_match_no_fallback() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let resp = send(hr, &[("host", "unknown.com")]).await;
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn dispatch_400_missing_host() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let resp = send(hr, &[]).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn dispatch_via_x_forwarded_host() {
        let hr = HostRouter::new().host("acme.com", router_with_body("landing"));

        let resp = send(
            hr,
            &[("host", "proxy.internal"), ("x-forwarded-host", "acme.com")],
        )
        .await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(response_body(resp).await, "landing");
    }
}