// modo/tenant/strategy.rs

1use std::future::Future;
2use std::pin::pin;
3use std::task::{Context, Poll, Wake};
4
5use crate::{Error, Result};
6
7use super::{TenantId, traits::TenantStrategy};
8
9/// Extract Host header value, strip port if present.
10fn host_from_parts(parts: &http::request::Parts) -> Result<String> {
11    let host = parts
12        .headers
13        .get(http::header::HOST)
14        .ok_or_else(|| Error::bad_request("missing Host header"))?
15        .to_str()
16        .map_err(|_| Error::bad_request("invalid Host header"))?;
17
18    // Strip port
19    let host = match host.rfind(':') {
20        Some(pos) if host[pos + 1..].bytes().all(|b| b.is_ascii_digit()) => &host[..pos],
21        _ => host,
22    };
23
24    Ok(host.to_lowercase())
25}
26
27// ---------------------------------------------------------------------------
28// Strategy 1: Subdomain
29// ---------------------------------------------------------------------------
30
/// Extracts tenant slug from a single-level subdomain relative to a base domain.
///
/// Created by [`subdomain()`]. Produces [`TenantId::Slug`].
///
/// Multi-level subdomains (e.g., `a.b.base.com`) and bare base-domain
/// requests are rejected with 400 Bad Request.
pub struct SubdomainStrategy {
    /// Base domain the tenant label must sit directly under; stored lowercase.
    base_domain: String,
}

impl SubdomainStrategy {
    /// Builds the strategy, lowercasing `base_domain` once up front so that
    /// per-request comparisons are against a normalized value.
    fn new(base_domain: &str) -> Self {
        let base_domain = base_domain.to_lowercase();
        Self { base_domain }
    }
}
48
49impl TenantStrategy for SubdomainStrategy {
50    fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
51        let host = host_from_parts(parts)?;
52        let suffix = format!(".{}", self.base_domain);
53
54        if !host.ends_with(&suffix) {
55            return Err(Error::bad_request("host is not a subdomain of base domain"));
56        }
57
58        let subdomain = &host[..host.len() - suffix.len()];
59
60        if subdomain.is_empty() {
61            return Err(Error::bad_request("no subdomain in host"));
62        }
63
64        // Only one level allowed
65        if subdomain.contains('.') {
66            return Err(Error::bad_request("multi-level subdomains not allowed"));
67        }
68
69        Ok(TenantId::Slug(subdomain.to_string()))
70    }
71}
72
73/// Returns a strategy that extracts the tenant slug from a subdomain of `base_domain`.
74pub fn subdomain(base_domain: &str) -> SubdomainStrategy {
75    SubdomainStrategy::new(base_domain)
76}
77
78// ---------------------------------------------------------------------------
79// Strategy 2: Domain
80// ---------------------------------------------------------------------------
81
82/// Extracts tenant identifier from the full domain name in the `Host` header.
83///
84/// Created by [`domain()`]. Produces [`TenantId::Domain`].
85pub struct DomainStrategy;
86
87impl TenantStrategy for DomainStrategy {
88    fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
89        let host = host_from_parts(parts)?;
90        Ok(TenantId::Domain(host))
91    }
92}
93
94/// Returns a strategy that uses the full domain as the tenant identifier.
95pub fn domain() -> DomainStrategy {
96    DomainStrategy
97}
98
99// ---------------------------------------------------------------------------
100// Strategy 3: Subdomain or Domain
101// ---------------------------------------------------------------------------
102
/// Extracts tenant from subdomain (as slug) or falls back to the full domain (as custom domain).
///
/// Created by [`subdomain_or_domain()`]. Produces [`TenantId::Slug`] or [`TenantId::Domain`].
///
/// - Single-level subdomain of base -> [`TenantId::Slug`]
/// - Unrelated host -> [`TenantId::Domain`] (custom domain)
/// - Base domain exactly -> 400 Bad Request
/// - Multi-level subdomain -> 400 Bad Request
pub struct SubdomainOrDomainStrategy {
    /// Platform base domain; stored lowercase.
    base_domain: String,
}

impl SubdomainOrDomainStrategy {
    /// Builds the strategy, lowercasing `base_domain` once at construction.
    fn new(base_domain: &str) -> Self {
        let base_domain = base_domain.to_lowercase();
        Self { base_domain }
    }
}
122
123impl TenantStrategy for SubdomainOrDomainStrategy {
124    fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
125        let host = host_from_parts(parts)?;
126        let suffix = format!(".{}", self.base_domain);
127
128        if host == self.base_domain {
129            return Err(Error::bad_request(
130                "base domain is not a valid tenant identifier",
131            ));
132        }
133
134        if host.ends_with(&suffix) {
135            let subdomain = &host[..host.len() - suffix.len()];
136            if subdomain.is_empty() {
137                return Err(Error::bad_request("no subdomain in host"));
138            }
139            if subdomain.contains('.') {
140                return Err(Error::bad_request("multi-level subdomains not allowed"));
141            }
142            Ok(TenantId::Slug(subdomain.to_string()))
143        } else {
144            Ok(TenantId::Domain(host))
145        }
146    }
147}
148
149/// Returns a strategy that extracts from a subdomain, falling back to the full domain.
150pub fn subdomain_or_domain(base_domain: &str) -> SubdomainOrDomainStrategy {
151    SubdomainOrDomainStrategy::new(base_domain)
152}
153
154// ---------------------------------------------------------------------------
155// Strategy 4: Header
156// ---------------------------------------------------------------------------
157
158/// Extracts tenant identifier from a named request header.
159///
160/// Created by [`header()`]. Produces [`TenantId::Id`].
161pub struct HeaderStrategy {
162    header_name: http::HeaderName,
163}
164
165impl HeaderStrategy {
166    fn new(name: &str) -> Self {
167        Self {
168            header_name: http::HeaderName::from_bytes(name.as_bytes())
169                .expect("invalid header name"),
170        }
171    }
172}
173
174impl TenantStrategy for HeaderStrategy {
175    fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
176        let value = parts
177            .headers
178            .get(&self.header_name)
179            .ok_or_else(|| Error::bad_request(format!("missing {} header", self.header_name)))?
180            .to_str()
181            .map_err(|_| {
182                Error::bad_request(format!("invalid {} header value", self.header_name))
183            })?;
184        Ok(TenantId::Id(value.to_string()))
185    }
186}
187
188/// Returns a strategy that reads the tenant identifier from the given request header.
189///
190/// # Panics
191///
192/// Panics if `name` is not a valid HTTP header name.
193pub fn header(name: &str) -> HeaderStrategy {
194    HeaderStrategy::new(name)
195}
196
197// ---------------------------------------------------------------------------
198// Strategy 5: API Key Header
199// ---------------------------------------------------------------------------
200
201/// Extracts tenant API key from a named request header.
202///
203/// Created by [`api_key_header()`]. Produces [`TenantId::ApiKey`], which is
204/// **redacted** in `Display` and `Debug` output.
205pub struct ApiKeyHeaderStrategy {
206    header_name: http::HeaderName,
207}
208
209impl ApiKeyHeaderStrategy {
210    fn new(name: &str) -> Self {
211        Self {
212            header_name: http::HeaderName::from_bytes(name.as_bytes())
213                .expect("invalid header name"),
214        }
215    }
216}
217
218impl TenantStrategy for ApiKeyHeaderStrategy {
219    fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
220        let value = parts
221            .headers
222            .get(&self.header_name)
223            .ok_or_else(|| Error::bad_request(format!("missing {} header", self.header_name)))?
224            .to_str()
225            .map_err(|_| {
226                Error::bad_request(format!("invalid {} header value", self.header_name))
227            })?;
228        Ok(TenantId::ApiKey(value.to_string()))
229    }
230}
231
232/// Returns a strategy that reads an API key from the given request header.
233///
234/// # Panics
235///
236/// Panics if `name` is not a valid HTTP header name.
237pub fn api_key_header(name: &str) -> ApiKeyHeaderStrategy {
238    ApiKeyHeaderStrategy::new(name)
239}
240
241// ---------------------------------------------------------------------------
242// Strategy 6: Path Prefix
243// ---------------------------------------------------------------------------
244
/// Extracts tenant slug from a path prefix and rewrites the URI.
///
/// Created by [`path_prefix()`]. Produces [`TenantId::Slug`].
///
/// Strips the prefix and tenant slug from the URI before the request reaches
/// handlers, preserving the query string. For example, with prefix `/org`,
/// a request to `/org/acme/settings?tab=billing` becomes `/settings?tab=billing`
/// and the slug `acme` is extracted.
pub struct PathPrefixStrategy {
    /// Path prefix preceding the tenant segment, stored exactly as given.
    prefix: String,
}

impl PathPrefixStrategy {
    /// Stores `prefix` verbatim.
    fn new(prefix: &str) -> Self {
        let prefix = prefix.to_owned();
        Self { prefix }
    }
}
264
265impl TenantStrategy for PathPrefixStrategy {
266    fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
267        let path = parts.uri.path();
268
269        if !path.starts_with(&self.prefix) {
270            return Err(Error::bad_request(format!(
271                "path does not start with prefix '{}'",
272                self.prefix
273            )));
274        }
275
276        let after_prefix = &path[self.prefix.len()..];
277
278        // Must have /slug after prefix
279        let after_prefix = after_prefix
280            .strip_prefix('/')
281            .ok_or_else(|| Error::bad_request("no tenant segment after prefix"))?;
282
283        if after_prefix.is_empty() {
284            return Err(Error::bad_request("no tenant segment after prefix"));
285        }
286
287        // Split slug from remaining path
288        let (slug, remaining) = match after_prefix.find('/') {
289            Some(pos) => (&after_prefix[..pos], &after_prefix[pos..]),
290            None => (after_prefix, "/"),
291        };
292
293        if slug.is_empty() {
294            return Err(Error::bad_request("empty tenant slug in path"));
295        }
296
297        // Collect into owned values before reassigning parts.uri
298        let slug = slug.to_string();
299        let remaining = remaining.to_string();
300
301        // Rewrite URI -- preserve query string
302        let new_path_and_query = match parts.uri.query() {
303            Some(q) => format!("{remaining}?{q}"),
304            None => remaining,
305        };
306        let new_uri = http::Uri::builder()
307            .path_and_query(new_path_and_query)
308            .build()
309            .map_err(|e| Error::internal(format!("failed to rewrite URI: {e}")))?;
310        parts.uri = new_uri;
311
312        Ok(TenantId::Slug(slug))
313    }
314}
315
316/// Returns a strategy that extracts a tenant slug from a path prefix and rewrites the URI.
317pub fn path_prefix(prefix: &str) -> PathPrefixStrategy {
318    PathPrefixStrategy::new(prefix)
319}
320
321// ---------------------------------------------------------------------------
322// Strategy 7: Path Parameter
323// ---------------------------------------------------------------------------
324
/// Extracts tenant slug from a named axum path parameter.
///
/// Created by [`path_param()`]. Produces [`TenantId::Slug`].
///
/// This strategy requires `.route_layer()` instead of `.layer()` because
/// axum path parameters are only available after route matching.
pub struct PathParamStrategy {
    /// Route parameter holding the tenant slug (e.g. `tenant` for `/{tenant}/...`).
    param_name: String,
}

impl PathParamStrategy {
    /// Remembers which path parameter to read.
    fn new(name: &str) -> Self {
        let param_name = name.to_owned();
        Self { param_name }
    }
}
342
/// Waker that ignores every wake-up request.
///
/// Used to poll, on the current thread, a future that is expected to complete
/// on its first poll, so there is never a task to reschedule.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: std::sync::Arc<Self>) {}

    // Overridden so waking by reference does not clone the `Arc` just to
    // discard it; the observable effect (none) is unchanged.
    fn wake_by_ref(self: &std::sync::Arc<Self>) {}
}
349
350impl TenantStrategy for PathParamStrategy {
351    fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
352        // `RawPathParams::from_request_parts` is async in signature but performs
353        // no actual I/O -- it reads from extensions synchronously. We poll it
354        // once with a noop waker; it is always immediately ready.
355        use axum::extract::FromRequestParts;
356        use axum::extract::RawPathParams;
357
358        let waker = std::sync::Arc::new(NoopWaker).into();
359        let mut cx = Context::from_waker(&waker);
360
361        let mut fut = pin!(RawPathParams::from_request_parts(parts, &()));
362
363        let raw_params = match fut.as_mut().poll(&mut cx) {
364            Poll::Ready(Ok(params)) => params,
365            Poll::Ready(Err(_)) => {
366                return Err(Error::internal(
367                    "path parameters not available (use route_layer instead of layer)",
368                ));
369            }
370            Poll::Pending => {
371                return Err(Error::internal(
372                    "unexpected pending state extracting path params",
373                ));
374            }
375        };
376
377        for (key, value) in &raw_params {
378            if key == self.param_name {
379                return Ok(TenantId::Slug(value.to_string()));
380            }
381        }
382
383        Err(Error::internal(format!(
384            "path parameter '{}' not found in route",
385            self.param_name
386        )))
387    }
388}
389
390/// Returns a strategy that reads the tenant slug from a named path parameter.
391pub fn path_param(name: &str) -> PathParamStrategy {
392    PathParamStrategy::new(name)
393}
394
395// ===========================================================================
396// Tests
397// ===========================================================================
398
// Unit tests for every tenant-extraction strategy in this module, plus one
// end-to-end axum router test for `PathParamStrategy`.
#[cfg(test)]
mod tests {
    use http::StatusCode;

    use super::*;

    /// Builds request parts for `uri`, adding a `host` header only when `host`
    /// is `Some` (so tests can exercise the missing-Host path).
    fn make_parts(host: Option<&str>, uri: &str) -> http::request::Parts {
        let mut builder = http::Request::builder().uri(uri);
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        let (parts, _) = builder.body(()).unwrap().into_parts();
        parts
    }

    // -- host_from_parts ----------------------------------------------------

    #[test]
    fn host_strips_port() {
        let parts = make_parts(Some("acme.com:8080"), "/");
        let host = host_from_parts(&parts).unwrap();
        assert_eq!(host, "acme.com");
    }

    #[test]
    fn host_missing_returns_error() {
        let parts = make_parts(None, "/");
        let err = host_from_parts(&parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("missing Host header"));
    }

    // -- SubdomainStrategy --------------------------------------------------

    #[test]
    fn subdomain_valid() {
        let s = subdomain("acme.com");
        let mut parts = make_parts(Some("tenant1.acme.com"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("tenant1".into()));
    }

    #[test]
    fn subdomain_case_insensitive() {
        let s = subdomain("acme.com");
        let mut parts = make_parts(Some("TENANT1.ACME.COM"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("tenant1".into()));
    }

    #[test]
    fn subdomain_bare_base_domain_error() {
        let s = subdomain("acme.com");
        let mut parts = make_parts(Some("acme.com"), "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn subdomain_multi_level_error() {
        let s = subdomain("acme.com");
        let mut parts = make_parts(Some("a.b.acme.com"), "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("multi-level"));
    }

    // A base domain that itself has several labels must still accept one
    // extra label in front of it.
    #[test]
    fn subdomain_multi_level_base_domain() {
        let s = subdomain("app.acme.com");
        let mut parts = make_parts(Some("tenant1.app.acme.com"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("tenant1".into()));
    }

    #[test]
    fn subdomain_port_stripped() {
        let s = subdomain("acme.com");
        let mut parts = make_parts(Some("tenant1.acme.com:3000"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("tenant1".into()));
    }

    #[test]
    fn subdomain_missing_host() {
        let s = subdomain("acme.com");
        let mut parts = make_parts(None, "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    // -- DomainStrategy -----------------------------------------------------

    #[test]
    fn domain_valid() {
        let s = domain();
        let mut parts = make_parts(Some("custom.example.com"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Domain("custom.example.com".into()));
    }

    #[test]
    fn domain_strips_port() {
        let s = domain();
        let mut parts = make_parts(Some("custom.example.com:443"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Domain("custom.example.com".into()));
    }

    #[test]
    fn domain_missing_host() {
        let s = domain();
        let mut parts = make_parts(None, "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    // -- SubdomainOrDomainStrategy ------------------------------------------

    #[test]
    fn subdomain_or_domain_subdomain_branch() {
        let s = subdomain_or_domain("acme.com");
        let mut parts = make_parts(Some("tenant1.acme.com"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("tenant1".into()));
    }

    #[test]
    fn subdomain_or_domain_custom_domain_branch() {
        let s = subdomain_or_domain("acme.com");
        let mut parts = make_parts(Some("custom.example.org"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Domain("custom.example.org".into()));
    }

    #[test]
    fn subdomain_or_domain_base_domain_error() {
        let s = subdomain_or_domain("acme.com");
        let mut parts = make_parts(Some("acme.com"), "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("base domain"));
    }

    #[test]
    fn subdomain_or_domain_multi_level_error() {
        let s = subdomain_or_domain("acme.com");
        let mut parts = make_parts(Some("a.b.acme.com"), "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("multi-level"));
    }

    #[test]
    fn subdomain_or_domain_missing_host() {
        let s = subdomain_or_domain("acme.com");
        let mut parts = make_parts(None, "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    // -- HeaderStrategy -----------------------------------------------------

    #[test]
    fn header_valid() {
        let s = header("x-tenant-id");
        let mut parts = make_parts(Some("localhost"), "/");
        parts
            .headers
            .insert("x-tenant-id", "abc123".parse().unwrap());
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Id("abc123".into()));
    }

    #[test]
    fn header_missing_error() {
        let s = header("x-tenant-id");
        let mut parts = make_parts(Some("localhost"), "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("missing"));
    }

    // Bytes 0x80/0x81 are a legal `HeaderValue` but fail `to_str()`.
    #[test]
    fn header_non_utf8_error() {
        let s = header("x-tenant-id");
        let mut parts = make_parts(Some("localhost"), "/");
        parts.headers.insert(
            "x-tenant-id",
            http::HeaderValue::from_bytes(&[0x80, 0x81]).unwrap(),
        );
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("invalid"));
    }

    // -- ApiKeyHeaderStrategy -----------------------------------------------

    #[test]
    fn api_key_header_valid() {
        let s = api_key_header("x-api-key");
        let mut parts = make_parts(Some("localhost"), "/");
        parts
            .headers
            .insert("x-api-key", "sk_live_abc".parse().unwrap());
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::ApiKey("sk_live_abc".into()));
    }

    #[test]
    fn api_key_header_missing_error() {
        let s = api_key_header("x-api-key");
        let mut parts = make_parts(Some("localhost"), "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("missing"));
    }

    // Bytes 0x80/0x81 are a legal `HeaderValue` but fail `to_str()`.
    #[test]
    fn api_key_header_non_utf8_error() {
        let s = api_key_header("x-api-key");
        let mut parts = make_parts(Some("localhost"), "/");
        parts.headers.insert(
            "x-api-key",
            http::HeaderValue::from_bytes(&[0x80, 0x81]).unwrap(),
        );
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("invalid"));
    }

    // -- PathPrefixStrategy -------------------------------------------------

    // Successful extraction also rewrites `parts.uri`; each test checks both.
    #[test]
    fn path_prefix_valid() {
        let s = path_prefix("/org");
        let mut parts = make_parts(Some("localhost"), "/org/acme/dashboard/settings");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("acme".into()));
        assert_eq!(parts.uri.path(), "/dashboard/settings");
    }

    #[test]
    fn path_prefix_only_slug() {
        let s = path_prefix("/org");
        let mut parts = make_parts(Some("localhost"), "/org/acme");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("acme".into()));
        assert_eq!(parts.uri.path(), "/");
    }

    #[test]
    fn path_prefix_wrong_prefix_error() {
        let s = path_prefix("/org");
        let mut parts = make_parts(Some("localhost"), "/api/v1");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("prefix"));
    }

    #[test]
    fn path_prefix_no_segment_error() {
        let s = path_prefix("/org");
        let mut parts = make_parts(Some("localhost"), "/org");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn path_prefix_no_segment_trailing_slash_error() {
        let s = path_prefix("/org");
        let mut parts = make_parts(Some("localhost"), "/org/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn path_prefix_preserves_query_string() {
        let s = path_prefix("/org");
        let mut parts = make_parts(Some("localhost"), "/org/acme/page?foo=bar&baz=1");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("acme".into()));
        assert_eq!(parts.uri.path(), "/page");
        assert_eq!(parts.uri.query(), Some("foo=bar&baz=1"));
    }

    // An empty prefix means the tenant slug is the first path segment.
    #[test]
    fn path_prefix_empty_prefix() {
        let s = path_prefix("");
        let mut parts = make_parts(Some("localhost"), "/acme/page");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("acme".into()));
        assert_eq!(parts.uri.path(), "/page");
    }

    // -- PathParamStrategy --------------------------------------------------

    // End-to-end: path params only exist after axum route matching, so this
    // drives a real `Router` with the middleware installed via `route_layer`.
    #[tokio::test]
    async fn path_param_extracts_from_route() {
        use axum::Router;
        use axum::routing::get;
        use tower::ServiceExt as _;

        use super::super::middleware as tenant_middleware;
        use super::super::traits::{HasTenantId, TenantResolver};

        #[derive(Clone, Debug)]
        struct TestTenant {
            slug: String,
        }

        impl HasTenantId for TestTenant {
            fn tenant_id(&self) -> &str {
                &self.slug
            }
        }

        // Resolver that echoes the extracted identifier back as the tenant.
        struct SlugResolver;
        impl TenantResolver for SlugResolver {
            type Tenant = TestTenant;
            async fn resolve(&self, id: &TenantId) -> crate::Result<TestTenant> {
                Ok(TestTenant {
                    slug: id.as_str().to_string(),
                })
            }
        }

        // Named `async fn` handler (nested inside this test, not at module
        // level) used with axum's `get` routing.
        async fn handler(tenant: super::super::Tenant<TestTenant>) -> String {
            format!("tenant:{}", tenant.slug)
        }

        let layer = tenant_middleware(path_param("tenant"), SlugResolver);
        let app = Router::new()
            .route("/{tenant}/action", get(handler))
            .route_layer(layer);

        let req = http::Request::builder()
            .uri("/acme/action")
            .body(axum::body::Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), http::StatusCode::OK);

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&body[..], b"tenant:acme");
    }

    #[test]
    fn path_param_missing_returns_error() {
        let s = path_param("tenant");
        let mut parts = make_parts(Some("localhost"), "/whatever");
        // No path params in extensions — should return 500
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}