1use std::future::Future;
2use std::pin::pin;
3use std::task::{Context, Poll, Wake};
4
5use crate::{Error, Result};
6
7use super::{TenantId, traits::TenantStrategy};
8
9fn host_from_parts(parts: &http::request::Parts) -> Result<String> {
11 let host = parts
12 .headers
13 .get(http::header::HOST)
14 .ok_or_else(|| Error::bad_request("missing Host header"))?
15 .to_str()
16 .map_err(|_| Error::bad_request("invalid Host header"))?;
17
18 let host = match host.rfind(':') {
20 Some(pos) if host[pos + 1..].bytes().all(|b| b.is_ascii_digit()) => &host[..pos],
21 _ => host,
22 };
23
24 Ok(host.to_lowercase())
25}
26
/// Tenant strategy that reads the tenant slug from a single-level subdomain
/// of a configured base domain.
pub struct SubdomainStrategy {
    /// Base domain the request host must be a subdomain of, stored lowercased.
    base_domain: String,
}

impl SubdomainStrategy {
    /// Creates a strategy for `base_domain`, lowercasing it once up front.
    fn new(base_domain: &str) -> Self {
        let base_domain = base_domain.to_lowercase();
        Self { base_domain }
    }
}
48
49impl TenantStrategy for SubdomainStrategy {
50 fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
51 let host = host_from_parts(parts)?;
52 let suffix = format!(".{}", self.base_domain);
53
54 if !host.ends_with(&suffix) {
55 return Err(Error::bad_request("host is not a subdomain of base domain"));
56 }
57
58 let subdomain = &host[..host.len() - suffix.len()];
59
60 if subdomain.is_empty() {
61 return Err(Error::bad_request("no subdomain in host"));
62 }
63
64 if subdomain.contains('.') {
66 return Err(Error::bad_request("multi-level subdomains not allowed"));
67 }
68
69 Ok(TenantId::Slug(subdomain.to_string()))
70 }
71}
72
/// Creates a [`SubdomainStrategy`] that resolves tenants from subdomains of
/// `base_domain`.
///
/// The base domain is lowercased by the strategy's constructor.
pub fn subdomain(base_domain: &str) -> SubdomainStrategy {
    SubdomainStrategy::new(base_domain)
}
77
/// Tenant strategy that uses the whole request host as the tenant identifier.
pub struct DomainStrategy;
86
87impl TenantStrategy for DomainStrategy {
88 fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
89 let host = host_from_parts(parts)?;
90 Ok(TenantId::Domain(host))
91 }
92}
93
/// Creates a [`DomainStrategy`], which identifies tenants by the full request
/// host.
pub fn domain() -> DomainStrategy {
    DomainStrategy
}
98
/// Tenant strategy that accepts either a single-level subdomain of a base
/// domain (yielding a slug) or any other host (yielding a domain).
pub struct SubdomainOrDomainStrategy {
    /// Base domain whose subdomains are treated as tenant slugs, stored
    /// lowercased.
    base_domain: String,
}

impl SubdomainOrDomainStrategy {
    /// Creates a strategy for `base_domain`, lowercasing it once up front.
    fn new(base_domain: &str) -> Self {
        let base_domain = base_domain.to_lowercase();
        Self { base_domain }
    }
}
122
123impl TenantStrategy for SubdomainOrDomainStrategy {
124 fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
125 let host = host_from_parts(parts)?;
126 let suffix = format!(".{}", self.base_domain);
127
128 if host == self.base_domain {
129 return Err(Error::bad_request(
130 "base domain is not a valid tenant identifier",
131 ));
132 }
133
134 if host.ends_with(&suffix) {
135 let subdomain = &host[..host.len() - suffix.len()];
136 if subdomain.is_empty() {
137 return Err(Error::bad_request("no subdomain in host"));
138 }
139 if subdomain.contains('.') {
140 return Err(Error::bad_request("multi-level subdomains not allowed"));
141 }
142 Ok(TenantId::Slug(subdomain.to_string()))
143 } else {
144 Ok(TenantId::Domain(host))
145 }
146 }
147}
148
/// Creates a [`SubdomainOrDomainStrategy`] for `base_domain`.
///
/// The base domain is lowercased by the strategy's constructor.
pub fn subdomain_or_domain(base_domain: &str) -> SubdomainOrDomainStrategy {
    SubdomainOrDomainStrategy::new(base_domain)
}
153
154pub struct HeaderStrategy {
162 header_name: http::HeaderName,
163}
164
165impl HeaderStrategy {
166 fn new(name: &str) -> Self {
167 Self {
168 header_name: http::HeaderName::from_bytes(name.as_bytes())
169 .expect("invalid header name"),
170 }
171 }
172}
173
174impl TenantStrategy for HeaderStrategy {
175 fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
176 let value = parts
177 .headers
178 .get(&self.header_name)
179 .ok_or_else(|| Error::bad_request(format!("missing {} header", self.header_name)))?
180 .to_str()
181 .map_err(|_| {
182 Error::bad_request(format!("invalid {} header value", self.header_name))
183 })?;
184 Ok(TenantId::Id(value.to_string()))
185 }
186}
187
/// Creates a [`HeaderStrategy`] that reads the tenant identifier from the
/// header called `name`.
///
/// # Panics
///
/// Panics if `name` is not a valid HTTP header name.
pub fn header(name: &str) -> HeaderStrategy {
    HeaderStrategy::new(name)
}
196
197pub struct ApiKeyHeaderStrategy {
206 header_name: http::HeaderName,
207}
208
209impl ApiKeyHeaderStrategy {
210 fn new(name: &str) -> Self {
211 Self {
212 header_name: http::HeaderName::from_bytes(name.as_bytes())
213 .expect("invalid header name"),
214 }
215 }
216}
217
218impl TenantStrategy for ApiKeyHeaderStrategy {
219 fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
220 let value = parts
221 .headers
222 .get(&self.header_name)
223 .ok_or_else(|| Error::bad_request(format!("missing {} header", self.header_name)))?
224 .to_str()
225 .map_err(|_| {
226 Error::bad_request(format!("invalid {} header value", self.header_name))
227 })?;
228 Ok(TenantId::ApiKey(value.to_string()))
229 }
230}
231
/// Creates an [`ApiKeyHeaderStrategy`] that reads the API key from the header
/// called `name`.
///
/// # Panics
///
/// Panics if `name` is not a valid HTTP header name.
pub fn api_key_header(name: &str) -> ApiKeyHeaderStrategy {
    ApiKeyHeaderStrategy::new(name)
}
240
/// Tenant strategy that reads the tenant slug from the path segment that
/// follows a fixed prefix (for example `/org/<slug>/...`).
pub struct PathPrefixStrategy {
    /// Path prefix preceding the tenant segment, stored without trailing
    /// slashes.
    prefix: String,
}

impl PathPrefixStrategy {
    /// Creates a strategy for `prefix`.
    ///
    /// Trailing slashes are trimmed. Extraction expects a `/` immediately
    /// after the prefix, so a prefix such as `"/org/"` or `"/"` would
    /// otherwise never match a path like `/org/<slug>`. After trimming,
    /// `"/org/"` behaves like `"/org"` and `"/"` like `""`.
    fn new(prefix: &str) -> Self {
        Self {
            prefix: prefix.trim_end_matches('/').to_string(),
        }
    }
}
264
265impl TenantStrategy for PathPrefixStrategy {
266 fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
267 let path = parts.uri.path();
268
269 if !path.starts_with(&self.prefix) {
270 return Err(Error::bad_request(format!(
271 "path does not start with prefix '{}'",
272 self.prefix
273 )));
274 }
275
276 let after_prefix = &path[self.prefix.len()..];
277
278 let after_prefix = after_prefix
280 .strip_prefix('/')
281 .ok_or_else(|| Error::bad_request("no tenant segment after prefix"))?;
282
283 if after_prefix.is_empty() {
284 return Err(Error::bad_request("no tenant segment after prefix"));
285 }
286
287 let (slug, remaining) = match after_prefix.find('/') {
289 Some(pos) => (&after_prefix[..pos], &after_prefix[pos..]),
290 None => (after_prefix, "/"),
291 };
292
293 if slug.is_empty() {
294 return Err(Error::bad_request("empty tenant slug in path"));
295 }
296
297 let slug = slug.to_string();
299 let remaining = remaining.to_string();
300
301 let new_path_and_query = match parts.uri.query() {
303 Some(q) => format!("{remaining}?{q}"),
304 None => remaining,
305 };
306 let new_uri = http::Uri::builder()
307 .path_and_query(new_path_and_query)
308 .build()
309 .map_err(|e| Error::internal(format!("failed to rewrite URI: {e}")))?;
310 parts.uri = new_uri;
311
312 Ok(TenantId::Slug(slug))
313 }
314}
315
/// Creates a [`PathPrefixStrategy`] that reads the tenant slug from the path
/// segment following `prefix`.
pub fn path_prefix(prefix: &str) -> PathPrefixStrategy {
    PathPrefixStrategy::new(prefix)
}
320
/// Tenant strategy that reads the tenant slug from a named route path
/// parameter.
pub struct PathParamStrategy {
    /// Name of the path parameter holding the tenant slug.
    param_name: String,
}

impl PathParamStrategy {
    /// Creates a strategy that looks up the path parameter called `name`.
    fn new(name: &str) -> Self {
        let param_name = String::from(name);
        Self { param_name }
    }
}
342
/// Waker that ignores wake-ups.
///
/// Used to poll a future exactly once from synchronous code, where nothing
/// would react to a wake-up anyway.
struct NoopWaker;

impl Wake for NoopWaker {
    /// Does nothing.
    fn wake(self: std::sync::Arc<Self>) {
        // Intentionally empty: there is no task to reschedule.
    }
}
349
350impl TenantStrategy for PathParamStrategy {
351 fn extract(&self, parts: &mut http::request::Parts) -> Result<TenantId> {
352 use axum::extract::FromRequestParts;
356 use axum::extract::RawPathParams;
357
358 let waker = std::sync::Arc::new(NoopWaker).into();
359 let mut cx = Context::from_waker(&waker);
360
361 let mut fut = pin!(RawPathParams::from_request_parts(parts, &()));
362
363 let raw_params = match fut.as_mut().poll(&mut cx) {
364 Poll::Ready(Ok(params)) => params,
365 Poll::Ready(Err(_)) => {
366 return Err(Error::internal(
367 "path parameters not available (use route_layer instead of layer)",
368 ));
369 }
370 Poll::Pending => {
371 return Err(Error::internal(
372 "unexpected pending state extracting path params",
373 ));
374 }
375 };
376
377 for (key, value) in &raw_params {
378 if key == self.param_name {
379 return Ok(TenantId::Slug(value.to_string()));
380 }
381 }
382
383 Err(Error::internal(format!(
384 "path parameter '{}' not found in route",
385 self.param_name
386 )))
387 }
388}
389
/// Creates a [`PathParamStrategy`] that reads the tenant slug from the route
/// path parameter called `name`.
pub fn path_param(name: &str) -> PathParamStrategy {
    PathParamStrategy::new(name)
}
394
#[cfg(test)]
mod tests {
    use http::StatusCode;

    use super::*;

    /// Builds request `Parts` for `uri`, adding a `host` header only when
    /// `host` is `Some`.
    fn make_parts(host: Option<&str>, uri: &str) -> http::request::Parts {
        let mut builder = http::Request::builder().uri(uri);
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        let (parts, _) = builder.body(()).unwrap().into_parts();
        parts
    }

    // --- host_from_parts ---

    #[test]
    fn host_strips_port() {
        let parts = make_parts(Some("acme.com:8080"), "/");
        let host = host_from_parts(&parts).unwrap();
        assert_eq!(host, "acme.com");
    }

    #[test]
    fn host_missing_returns_error() {
        let parts = make_parts(None, "/");
        let err = host_from_parts(&parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("missing Host header"));
    }

    // --- SubdomainStrategy ---

    #[test]
    fn subdomain_valid() {
        let s = subdomain("acme.com");
        let mut parts = make_parts(Some("tenant1.acme.com"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("tenant1".into()));
    }

    #[test]
    fn subdomain_case_insensitive() {
        let s = subdomain("acme.com");
        let mut parts = make_parts(Some("TENANT1.ACME.COM"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("tenant1".into()));
    }

    #[test]
    fn subdomain_bare_base_domain_error() {
        let s = subdomain("acme.com");
        let mut parts = make_parts(Some("acme.com"), "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn subdomain_multi_level_error() {
        let s = subdomain("acme.com");
        let mut parts = make_parts(Some("a.b.acme.com"), "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("multi-level"));
    }

    // The base domain itself may contain dots; only the part before it must
    // be a single label.
    #[test]
    fn subdomain_multi_level_base_domain() {
        let s = subdomain("app.acme.com");
        let mut parts = make_parts(Some("tenant1.app.acme.com"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("tenant1".into()));
    }

    #[test]
    fn subdomain_port_stripped() {
        let s = subdomain("acme.com");
        let mut parts = make_parts(Some("tenant1.acme.com:3000"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("tenant1".into()));
    }

    #[test]
    fn subdomain_missing_host() {
        let s = subdomain("acme.com");
        let mut parts = make_parts(None, "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    // --- DomainStrategy ---

    #[test]
    fn domain_valid() {
        let s = domain();
        let mut parts = make_parts(Some("custom.example.com"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Domain("custom.example.com".into()));
    }

    #[test]
    fn domain_strips_port() {
        let s = domain();
        let mut parts = make_parts(Some("custom.example.com:443"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Domain("custom.example.com".into()));
    }

    #[test]
    fn domain_missing_host() {
        let s = domain();
        let mut parts = make_parts(None, "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    // --- SubdomainOrDomainStrategy ---

    #[test]
    fn subdomain_or_domain_subdomain_branch() {
        let s = subdomain_or_domain("acme.com");
        let mut parts = make_parts(Some("tenant1.acme.com"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("tenant1".into()));
    }

    #[test]
    fn subdomain_or_domain_custom_domain_branch() {
        let s = subdomain_or_domain("acme.com");
        let mut parts = make_parts(Some("custom.example.org"), "/");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Domain("custom.example.org".into()));
    }

    #[test]
    fn subdomain_or_domain_base_domain_error() {
        let s = subdomain_or_domain("acme.com");
        let mut parts = make_parts(Some("acme.com"), "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("base domain"));
    }

    #[test]
    fn subdomain_or_domain_multi_level_error() {
        let s = subdomain_or_domain("acme.com");
        let mut parts = make_parts(Some("a.b.acme.com"), "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("multi-level"));
    }

    #[test]
    fn subdomain_or_domain_missing_host() {
        let s = subdomain_or_domain("acme.com");
        let mut parts = make_parts(None, "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    // --- HeaderStrategy ---

    #[test]
    fn header_valid() {
        let s = header("x-tenant-id");
        let mut parts = make_parts(Some("localhost"), "/");
        parts
            .headers
            .insert("x-tenant-id", "abc123".parse().unwrap());
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Id("abc123".into()));
    }

    #[test]
    fn header_missing_error() {
        let s = header("x-tenant-id");
        let mut parts = make_parts(Some("localhost"), "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("missing"));
    }

    // The bytes 0x80 0x81 are accepted as a header value but cannot be read
    // back with `to_str`, which exercises the "invalid" error path.
    #[test]
    fn header_non_utf8_error() {
        let s = header("x-tenant-id");
        let mut parts = make_parts(Some("localhost"), "/");
        parts.headers.insert(
            "x-tenant-id",
            http::HeaderValue::from_bytes(&[0x80, 0x81]).unwrap(),
        );
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("invalid"));
    }

    // --- ApiKeyHeaderStrategy ---

    #[test]
    fn api_key_header_valid() {
        let s = api_key_header("x-api-key");
        let mut parts = make_parts(Some("localhost"), "/");
        parts
            .headers
            .insert("x-api-key", "sk_live_abc".parse().unwrap());
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::ApiKey("sk_live_abc".into()));
    }

    #[test]
    fn api_key_header_missing_error() {
        let s = api_key_header("x-api-key");
        let mut parts = make_parts(Some("localhost"), "/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("missing"));
    }

    #[test]
    fn api_key_header_non_utf8_error() {
        let s = api_key_header("x-api-key");
        let mut parts = make_parts(Some("localhost"), "/");
        parts.headers.insert(
            "x-api-key",
            http::HeaderValue::from_bytes(&[0x80, 0x81]).unwrap(),
        );
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("invalid"));
    }

    // --- PathPrefixStrategy ---

    // Besides returning the slug, extraction rewrites the URI so the prefix
    // and slug are gone.
    #[test]
    fn path_prefix_valid() {
        let s = path_prefix("/org");
        let mut parts = make_parts(Some("localhost"), "/org/acme/dashboard/settings");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("acme".into()));
        assert_eq!(parts.uri.path(), "/dashboard/settings");
    }

    // With nothing after the slug, the rewritten path is "/".
    #[test]
    fn path_prefix_only_slug() {
        let s = path_prefix("/org");
        let mut parts = make_parts(Some("localhost"), "/org/acme");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("acme".into()));
        assert_eq!(parts.uri.path(), "/");
    }

    #[test]
    fn path_prefix_wrong_prefix_error() {
        let s = path_prefix("/org");
        let mut parts = make_parts(Some("localhost"), "/api/v1");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("prefix"));
    }

    #[test]
    fn path_prefix_no_segment_error() {
        let s = path_prefix("/org");
        let mut parts = make_parts(Some("localhost"), "/org");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn path_prefix_no_segment_trailing_slash_error() {
        let s = path_prefix("/org");
        let mut parts = make_parts(Some("localhost"), "/org/");
        let err = s.extract(&mut parts).unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn path_prefix_preserves_query_string() {
        let s = path_prefix("/org");
        let mut parts = make_parts(Some("localhost"), "/org/acme/page?foo=bar&baz=1");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("acme".into()));
        assert_eq!(parts.uri.path(), "/page");
        assert_eq!(parts.uri.query(), Some("foo=bar&baz=1"));
    }

    // An empty prefix makes the first path segment the tenant slug.
    #[test]
    fn path_prefix_empty_prefix() {
        let s = path_prefix("");
        let mut parts = make_parts(Some("localhost"), "/acme/page");
        let id = s.extract(&mut parts).unwrap();
        assert_eq!(id, TenantId::Slug("acme".into()));
        assert_eq!(parts.uri.path(), "/page");
    }

    // --- PathParamStrategy ---

    // End-to-end check through an axum router: the tenant middleware is
    // attached with `route_layer`, and the handler echoes the resolved slug.
    #[tokio::test]
    async fn path_param_extracts_from_route() {
        use axum::Router;
        use axum::routing::get;
        use tower::ServiceExt as _;

        use super::super::middleware as tenant_middleware;
        use super::super::traits::{HasTenantId, TenantResolver};

        #[derive(Clone, Debug)]
        struct TestTenant {
            slug: String,
        }

        impl HasTenantId for TestTenant {
            fn tenant_id(&self) -> &str {
                &self.slug
            }
        }

        // Resolver that turns the extracted identifier straight into a tenant.
        struct SlugResolver;
        impl TenantResolver for SlugResolver {
            type Tenant = TestTenant;
            async fn resolve(&self, id: &TenantId) -> crate::Result<TestTenant> {
                Ok(TestTenant {
                    slug: id.as_str().to_string(),
                })
            }
        }

        async fn handler(tenant: super::super::Tenant<TestTenant>) -> String {
            // Echo the slug so the test can assert on the response body.
            format!("tenant:{}", tenant.slug)
        }

        let layer = tenant_middleware(path_param("tenant"), SlugResolver);
        let app = Router::new()
            .route("/{tenant}/action", get(handler))
            .route_layer(layer);

        let req = http::Request::builder()
            .uri("/acme/action")
            .body(axum::body::Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), http::StatusCode::OK);

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&body[..], b"tenant:acme");
    }

    #[test]
    fn path_param_missing_returns_error() {
        let s = path_param("tenant");
        let mut parts = make_parts(Some("localhost"), "/whatever");
        let err = s.extract(&mut parts).unwrap_err();
        // These parts were built by hand rather than routed, so presumably no
        // path parameters are available; the strategy reports an internal error.
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}