use crate::tenant::TenantContext;
use crate::Request;
use async_trait::async_trait;
use std::sync::Arc;
use super::lookup::TenantLookup;
/// Strategy for working out which tenant an incoming request belongs to.
///
/// Every implementation in this module returns `None` when the request carries
/// no usable tenant identifier, or when the tenant lookup finds no match.
/// The `Send + Sync` supertraits allow a resolver to be shared across threads.
/// The trait is used as `dyn TenantResolver`, so it must stay object-safe.
#[async_trait]
pub trait TenantResolver: Send + Sync {
/// Attempts to resolve the tenant for `req`; `None` means "not resolved".
async fn resolve(&self, req: &Request) -> Option<TenantContext>;
}
/// Resolves the tenant from the leftmost label of the request's `Host` header.
///
/// `base_domain_parts` is the number of dot-separated labels that make up the
/// application's own domain (e.g. `2` for `yourapp.com`). A host must have more
/// labels than that for a tenant slug to be extracted.
pub struct SubdomainResolver {
    base_domain_parts: usize,
    tenant_lookup: Arc<dyn TenantLookup>,
}

impl SubdomainResolver {
    /// Creates a resolver for a base domain made of `base_domain_parts` labels.
    ///
    /// The extracted slug is resolved through `tenant_lookup.find_by_slug`.
    pub fn new(base_domain_parts: usize, tenant_lookup: Arc<dyn TenantLookup>) -> Self {
        Self {
            tenant_lookup,
            base_domain_parts,
        }
    }
}
#[async_trait]
impl TenantResolver for SubdomainResolver {
    /// Resolves the tenant from the leftmost label of the `Host` header.
    ///
    /// With `base_domain_parts == 2`, `acme.yourapp.com` looks up slug `acme`.
    /// Returns `None` in any of these cases:
    /// - the header is missing;
    /// - the host has no label beyond the base domain;
    /// - the leftmost label is empty;
    /// - no tenant matches the slug.
    async fn resolve(&self, req: &Request) -> Option<TenantContext> {
        let host = req.header("host")?;
        let slug = subdomain_slug(host, self.base_domain_parts)?;
        self.tenant_lookup.find_by_slug(slug).await
    }
}

/// Extracts the tenant slug (the leftmost label) from a `Host` header value.
///
/// The `:port` suffix is discarded. A single trailing dot (the fully-qualified
/// form, e.g. `yourapp.com.`) is ignored rather than counted as an extra empty
/// label. Otherwise the apex domain would be misread as having a subdomain.
///
/// Returns `None` when the host has at most `base_domain_parts` labels, or
/// when the leftmost label is empty (e.g. `.yourapp.com`).
fn subdomain_slug(host: &str, base_domain_parts: usize) -> Option<&str> {
    let hostname = host.split(':').next().unwrap_or(host);
    let hostname = hostname.strip_suffix('.').unwrap_or(hostname);
    let mut labels = hostname.split('.');
    // `split` always yields at least one item; `?` is purely defensive.
    let slug = labels.next()?;
    // The remaining labels must cover the whole base domain.
    if slug.is_empty() || labels.count() < base_domain_parts {
        return None;
    }
    Some(slug)
}
/// Resolves the tenant by treating the value of a configurable request header
/// as the tenant slug.
pub struct HeaderResolver {
    header_name: String,
    tenant_lookup: Arc<dyn TenantLookup>,
}

impl HeaderResolver {
    /// Creates a resolver that reads the header named `header_name`.
    ///
    /// The header value is resolved through `tenant_lookup.find_by_slug`.
    pub fn new(header_name: impl Into<String>, tenant_lookup: Arc<dyn TenantLookup>) -> Self {
        let header_name = header_name.into();
        Self {
            tenant_lookup,
            header_name,
        }
    }
}
#[async_trait]
impl TenantResolver for HeaderResolver {
    /// Looks up the tenant whose slug equals the configured header's value.
    ///
    /// Returns `None` if the header is absent or no tenant matches.
    async fn resolve(&self, req: &Request) -> Option<TenantContext> {
        let slug = match req.header(&self.header_name) {
            Some(found) => found,
            None => return None,
        };
        self.tenant_lookup.find_by_slug(slug).await
    }
}
/// Resolves the tenant from a named route parameter whose value is the tenant
/// slug.
pub struct PathResolver {
    param_name: String,
    tenant_lookup: Arc<dyn TenantLookup>,
}

impl PathResolver {
    /// Creates a resolver that reads the route parameter named `param_name`.
    ///
    /// The parameter value is resolved through `tenant_lookup.find_by_slug`.
    pub fn new(param_name: impl Into<String>, tenant_lookup: Arc<dyn TenantLookup>) -> Self {
        let param_name = param_name.into();
        Self {
            tenant_lookup,
            param_name,
        }
    }
}
#[async_trait]
impl TenantResolver for PathResolver {
    /// Looks up the tenant whose slug equals the configured route parameter.
    ///
    /// Returns `None` if `req.param` reports an error (e.g. the parameter is
    /// absent) or no tenant matches.
    async fn resolve(&self, req: &Request) -> Option<TenantContext> {
        // The match completes before the `.await`, so the error value is
        // dropped here rather than being held inside the future.
        let slug = match req.param(&self.param_name) {
            Ok(found) => found,
            Err(_) => return None,
        };
        self.tenant_lookup.find_by_slug(slug).await
    }
}
/// Resolves the tenant by id, using a claim read from the JSON claims stored
/// on the request (fetched via `req.get::<serde_json::Value>()`).
pub struct JwtClaimResolver {
    claim_field: String,
    tenant_lookup: Arc<dyn TenantLookup>,
}

impl JwtClaimResolver {
    /// Creates a resolver that reads the claim named `claim_field`.
    ///
    /// The claim value is resolved through `tenant_lookup.find_by_id`.
    pub fn new(claim_field: impl Into<String>, tenant_lookup: Arc<dyn TenantLookup>) -> Self {
        let claim_field = claim_field.into();
        Self {
            tenant_lookup,
            claim_field,
        }
    }
}
#[async_trait]
impl TenantResolver for JwtClaimResolver {
    /// Resolves the tenant from the configured claim in the request's JSON claims.
    ///
    /// The claims are expected to have been stored on the request as a
    /// `serde_json::Value`, presumably by an upstream auth layer (confirm).
    ///
    /// Returns `None` in any of these cases:
    /// - no claims are present;
    /// - the claims value is not an object or lacks the field;
    /// - the field is not an integer id;
    /// - no tenant has that id.
    async fn resolve(&self, req: &Request) -> Option<TenantContext> {
        let claims = req.get::<serde_json::Value>()?;
        let claim = claims.get(self.claim_field.as_str())?;
        let id = claim_as_tenant_id(claim)?;
        self.tenant_lookup.find_by_id(id).await
    }
}

/// Interprets a claim value as a tenant id.
///
/// Accepts a JSON integer (`42`) or a string holding a base-10 integer (`"42"`).
/// Some identity providers serialize custom numeric claims as strings, and
/// rejecting those would silently leave the request without a tenant.
/// Floats, out-of-range numbers and non-numeric strings yield `None`.
fn claim_as_tenant_id(claim: &serde_json::Value) -> Option<i64> {
    match claim {
        serde_json::Value::String(text) => text.parse().ok(),
        other => other.as_i64(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tenant::{TenantContext, TenantLookup};
    use async_trait::async_trait;
    use bytes::Bytes;
    use http_body_util::Empty;
    use hyper_util::rt::TokioIo;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::sync::Mutex;
    use tokio::sync::oneshot;

    /// Builds a fixed test tenant (id 42) with the given slug.
    fn make_tenant(slug: &str) -> TenantContext {
        TenantContext {
            id: 42,
            slug: slug.to_string(),
            name: "Test Corp".to_string(),
            plan: None,
            #[cfg(feature = "stripe")]
            subscription: None,
        }
    }

    /// In-memory lookup: slugs `acme` and `beta` exist; id 42 maps to `acme`.
    struct MockLookup;

    #[async_trait]
    impl TenantLookup for MockLookup {
        async fn find_by_slug(&self, slug: &str) -> Option<TenantContext> {
            match slug {
                "acme" | "beta" => Some(make_tenant(slug)),
                _ => None,
            }
        }

        async fn find_by_id(&self, id: i64) -> Option<TenantContext> {
            if id == 42 {
                Some(make_tenant("acme"))
            } else {
                None
            }
        }
    }

    /// Builds a real crate `Request` by sending an HTTP/1.1 request through a
    /// throwaway loopback hyper server.
    ///
    /// The server wraps the first request it receives with
    /// `Request::new(..).with_params(params)` and hands it back over a oneshot
    /// channel. Route params are attached directly because no router runs here.
    async fn make_request_opts(
        host: Option<&str>,
        extra_header: Option<(&str, &str)>,
        params: HashMap<String, String>,
    ) -> Request {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (tx, rx) = oneshot::channel();
        // `service_fn` may be invoked repeatedly, so the single-use sender lives
        // behind a mutex and is taken on the first request.
        let tx_holder = Arc::new(Mutex::new(Some(tx)));
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let io = TokioIo::new(stream);
            let service =
                hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
                    let tx_holder = Arc::clone(&tx_holder);
                    let params = params.clone();
                    async move {
                        if let Some(tx) = tx_holder.lock().unwrap().take() {
                            let _ = tx.send(Request::new(req).with_params(params));
                        }
                        Ok::<_, hyper::Error>(hyper::Response::new(Empty::<Bytes>::new()))
                    }
                });
            hyper::server::conn::http1::Builder::new()
                .serve_connection(io, service)
                .await
                .ok();
        });

        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        tokio::spawn(async move {
            conn.await.ok();
        });

        let mut builder = hyper::Request::builder().uri("/test");
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        if let Some((name, value)) = extra_header {
            builder = builder.header(name, value);
        }
        let req = builder.body(Empty::<Bytes>::new()).unwrap();
        let _ = sender.send_request(req).await;
        rx.await.unwrap()
    }

    #[tokio::test]
    async fn subdomain_resolver_extracts_slug_from_host() {
        let resolver = SubdomainResolver::new(2, Arc::new(MockLookup));
        let req = make_request_opts(Some("acme.yourapp.com"), None, Default::default()).await;
        let tenant = resolver.resolve(&req).await.expect("tenant should resolve");
        assert_eq!(tenant.slug, "acme");
    }

    #[tokio::test]
    async fn subdomain_resolver_returns_none_for_no_subdomain() {
        let resolver = SubdomainResolver::new(2, Arc::new(MockLookup));
        let req = make_request_opts(Some("yourapp.com"), None, Default::default()).await;
        assert!(resolver.resolve(&req).await.is_none());
    }

    #[tokio::test]
    async fn subdomain_resolver_strips_port() {
        let resolver = SubdomainResolver::new(2, Arc::new(MockLookup));
        let req = make_request_opts(Some("acme.yourapp.com:8080"), None, Default::default()).await;
        let tenant = resolver.resolve(&req).await.expect("tenant should resolve");
        assert_eq!(tenant.slug, "acme");
    }

    #[tokio::test]
    async fn subdomain_resolver_returns_none_for_unknown_tenant() {
        let resolver = SubdomainResolver::new(2, Arc::new(MockLookup));
        let req = make_request_opts(Some("unknown.yourapp.com"), None, Default::default()).await;
        assert!(resolver.resolve(&req).await.is_none());
    }

    #[tokio::test]
    async fn header_resolver_extracts_from_header() {
        let resolver = HeaderResolver::new("x-tenant-id", Arc::new(MockLookup));
        let req = make_request_opts(None, Some(("x-tenant-id", "acme")), Default::default()).await;
        let tenant = resolver.resolve(&req).await.expect("tenant should resolve");
        assert_eq!(tenant.slug, "acme");
    }

    #[tokio::test]
    async fn header_resolver_returns_none_when_absent() {
        let resolver = HeaderResolver::new("x-tenant-id", Arc::new(MockLookup));
        let req = make_request_opts(None, None, Default::default()).await;
        assert!(resolver.resolve(&req).await.is_none());
    }

    #[tokio::test]
    async fn path_resolver_extracts_from_param() {
        let resolver = PathResolver::new("tenant_slug", Arc::new(MockLookup));
        let mut params = HashMap::new();
        params.insert("tenant_slug".to_string(), "beta".to_string());
        let req = make_request_opts(None, None, params).await;
        let tenant = resolver.resolve(&req).await.expect("tenant should resolve");
        assert_eq!(tenant.slug, "beta");
    }

    #[tokio::test]
    async fn path_resolver_returns_none_when_param_absent() {
        let resolver = PathResolver::new("tenant_slug", Arc::new(MockLookup));
        let req = make_request_opts(None, None, Default::default()).await;
        assert!(resolver.resolve(&req).await.is_none());
    }

    #[tokio::test]
    async fn jwt_claim_resolver_extracts_from_extensions() {
        let resolver = JwtClaimResolver::new("tenant_id", Arc::new(MockLookup));
        let mut req = make_request_opts(None, None, Default::default()).await;
        req.insert::<serde_json::Value>(serde_json::json!({"tenant_id": 42, "sub": "user1"}));
        let tenant = resolver.resolve(&req).await.expect("tenant should resolve");
        assert_eq!(tenant.slug, "acme");
    }

    #[tokio::test]
    async fn jwt_claim_resolver_returns_none_without_claims() {
        let resolver = JwtClaimResolver::new("tenant_id", Arc::new(MockLookup));
        let req = make_request_opts(None, None, Default::default()).await;
        assert!(resolver.resolve(&req).await.is_none());
    }

    /// Compile-time check that `TenantResolver` can be used as a trait object.
    #[test]
    fn tenant_resolver_is_object_safe() {
        let _: Box<dyn TenantResolver>;
    }
}