use crate::auth::AuthContext;
use crate::tenant::types::{Tenant, TenantError};
use async_trait::async_trait;
/// Strategy for deriving the [`Tenant`] that a request belongs to from its
/// [`AuthContext`].
///
/// Implementors must be `Send + Sync + 'static`. The trait is used behind
/// `dyn` elsewhere in this file (e.g. `Arc<dyn TenantResolver>`).
#[async_trait]
pub trait TenantResolver: Send + Sync + 'static {
/// Resolves the tenant for `auth`.
///
/// # Errors
///
/// Returns a [`TenantError`] when the implementation cannot produce a
/// tenant for the given context.
async fn resolve(&self, auth: &AuthContext) -> Result<Tenant, TenantError>;
}
/// Crate-internal helper that builds a [`Tenant`] from any string-like value.
///
/// The input is converted to an owned `String` and handed to
/// [`Tenant::try_new`], which performs all validation.
///
/// # Errors
///
/// Propagates whatever [`TenantError`] `Tenant::try_new` reports for the
/// given value.
pub(crate) fn mint_tenant_from_str(s: impl Into<String>) -> Result<Tenant, TenantError> {
    let raw: String = s.into();
    Tenant::try_new(raw)
}
/// Resolver that reads the tenant from a metadata entry on the
/// [`AuthContext`], optionally falling back to the user id.
#[derive(Debug, Clone)]
pub struct ClaimTenantResolver {
/// Metadata key whose string value is used as the tenant identifier.
pub claim_key: String,
/// When `true`, an authenticated context without the claim resolves to a
/// tenant built from its `user_id`.
pub single_user_fallback: bool,
}
impl ClaimTenantResolver {
    /// Creates a resolver that looks up the `"tenant_id"` metadata key and
    /// has the single-user fallback enabled.
    pub fn new() -> Self {
        Self::with_claim_key("tenant_id")
    }

    /// Creates a resolver that looks up `claim_key` instead of the default
    /// key; the single-user fallback is enabled.
    pub fn with_claim_key(claim_key: impl Into<String>) -> Self {
        let claim_key = claim_key.into();
        Self {
            claim_key,
            single_user_fallback: true,
        }
    }
}
impl Default for ClaimTenantResolver {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl TenantResolver for ClaimTenantResolver {
async fn resolve(&self, auth: &AuthContext) -> Result<Tenant, TenantError> {
if let Some(claim) = auth.get_metadata_string(&self.claim_key) {
return mint_tenant_from_str(claim);
}
if self.single_user_fallback && auth.is_authenticated() {
return mint_tenant_from_str(auth.user_id.clone());
}
Err(TenantError::UnresolvedFromAuthContext)
}
}
/// Resolver that always yields one pre-validated [`Tenant`], regardless of
/// the [`AuthContext`] it is given.
#[derive(Debug, Clone)]
pub struct SingleTenantResolver {
// The tenant handed out (cloned) on every `resolve` call.
fixed: Tenant,
}
impl SingleTenantResolver {
    /// Creates a resolver pinned to the tenant `"default"`.
    ///
    /// # Panics
    ///
    /// Panics if `Tenant::try_new("default")` fails, which would mean the
    /// validation rules no longer accept that literal.
    pub fn new() -> Self {
        let fixed = Tenant::try_new("default")
            .expect("the literal \"default\" satisfies Tenant::try_new validation");
        Self { fixed }
    }

    /// Creates a resolver pinned to the tenant built from `value`.
    ///
    /// # Errors
    ///
    /// Returns the [`TenantError`] from `Tenant::try_new` when `value` is
    /// not a valid tenant identifier.
    pub fn with_fixed(value: impl Into<String>) -> Result<Self, TenantError> {
        Tenant::try_new(value).map(|fixed| Self { fixed })
    }
}
impl Default for SingleTenantResolver {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl TenantResolver for SingleTenantResolver {
    /// Returns a clone of the fixed tenant; the auth context is ignored and
    /// this implementation never returns `Err`.
    async fn resolve(&self, _auth: &AuthContext) -> Result<Tenant, TenantError> {
        let tenant = self.fixed.clone();
        Ok(tenant)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an `AuthContext` for `user_id` carrying `metadata`, using a
    /// fixed session id and an empty third argument.
    fn ctx_with_metadata(user_id: &str, metadata: serde_json::Value) -> AuthContext {
        AuthContext::new(user_id.into(), "sess-1".into(), vec![], metadata)
    }

    /// Builds a `ClaimTenantResolver` on the `"tenant_id"` key with the
    /// fallback explicitly switched on or off.
    fn tenant_id_resolver(single_user_fallback: bool) -> ClaimTenantResolver {
        ClaimTenantResolver {
            claim_key: "tenant_id".into(),
            single_user_fallback,
        }
    }

    #[tokio::test]
    async fn claim_resolver_pulls_tenant_id_by_default() {
        let resolver = ClaimTenantResolver::new();
        let auth = ctx_with_metadata("alice", json!({"tenant_id": "acme-corp"}));
        let tenant = resolver.resolve(&auth).await.unwrap();
        assert_eq!(tenant.as_str(), "acme-corp");
    }

    #[tokio::test]
    async fn claim_resolver_honors_custom_claim_key() {
        let resolver = ClaimTenantResolver::with_claim_key("org_id");
        let auth = ctx_with_metadata("alice", json!({"org_id": "neon-9"}));
        let tenant = resolver.resolve(&auth).await.unwrap();
        assert_eq!(tenant.as_str(), "neon-9");
    }

    #[tokio::test]
    async fn claim_resolver_ignores_irrelevant_claims() {
        let resolver = tenant_id_resolver(false);
        let auth = ctx_with_metadata("alice", json!({"realm": "prod"}));
        let failure = resolver.resolve(&auth).await.unwrap_err();
        assert_eq!(failure, TenantError::UnresolvedFromAuthContext);
    }

    #[tokio::test]
    async fn claim_resolver_falls_back_to_user_id_when_enabled() {
        let resolver = tenant_id_resolver(true);
        let auth = ctx_with_metadata("alice-uuid", json!({}));
        let tenant = resolver.resolve(&auth).await.unwrap();
        assert_eq!(tenant.as_str(), "alice-uuid");
    }

    #[tokio::test]
    async fn claim_resolver_does_not_fall_back_when_disabled() {
        let resolver = tenant_id_resolver(false);
        let auth = ctx_with_metadata("alice", json!({}));
        let failure = resolver.resolve(&auth).await.unwrap_err();
        assert_eq!(failure, TenantError::UnresolvedFromAuthContext);
    }

    #[tokio::test]
    async fn claim_resolver_rejects_anonymous_even_with_fallback() {
        let resolver = tenant_id_resolver(true);
        let auth = AuthContext::anonymous();
        let failure = resolver.resolve(&auth).await.unwrap_err();
        assert_eq!(failure, TenantError::UnresolvedFromAuthContext);
    }

    #[tokio::test]
    async fn claim_resolver_surfaces_invalid_shape_when_claim_malformed() {
        let resolver = ClaimTenantResolver::new();
        // A NUL byte inside the claim must be rejected, not silently accepted.
        let auth = ctx_with_metadata("alice", json!({"tenant_id": "evil\0tenant"}));
        let failure = resolver.resolve(&auth).await.unwrap_err();
        assert_eq!(failure, TenantError::InvalidShape);
    }

    #[tokio::test]
    async fn claim_resolver_prefers_claim_over_fallback() {
        let resolver = tenant_id_resolver(true);
        let auth = ctx_with_metadata("alice", json!({"tenant_id": "acme-corp"}));
        let tenant = resolver.resolve(&auth).await.unwrap();
        assert_eq!(tenant.as_str(), "acme-corp");
    }

    #[tokio::test]
    async fn single_tenant_resolver_returns_default() {
        let resolver = SingleTenantResolver::new();
        let auth = ctx_with_metadata("alice", json!({"tenant_id": "ignored"}));
        let tenant = resolver.resolve(&auth).await.unwrap();
        assert_eq!(tenant.as_str(), "default");
    }

    #[tokio::test]
    async fn single_tenant_resolver_ignores_metadata() {
        let resolver = SingleTenantResolver::new();
        let contexts = [
            ctx_with_metadata("alice", json!({"tenant_id": "acme"})),
            ctx_with_metadata("bob", json!({"tenant_id": "neon"})),
            AuthContext::anonymous(),
        ];
        for auth in &contexts {
            assert_eq!(resolver.resolve(auth).await.unwrap().as_str(), "default");
        }
    }

    #[tokio::test]
    async fn single_tenant_resolver_with_custom_fixed_value() {
        let resolver = SingleTenantResolver::with_fixed("dev-tenant").unwrap();
        let auth = AuthContext::anonymous();
        let tenant = resolver.resolve(&auth).await.unwrap();
        assert_eq!(tenant.as_str(), "dev-tenant");
    }

    #[tokio::test]
    async fn single_tenant_resolver_rejects_invalid_fixed_value() {
        let empty = SingleTenantResolver::with_fixed("").unwrap_err();
        assert_eq!(empty, TenantError::InvalidShape);
        let with_nul = SingleTenantResolver::with_fixed("evil\0").unwrap_err();
        assert_eq!(with_nul, TenantError::InvalidShape);
    }

    #[tokio::test]
    async fn single_tenant_resolver_never_fails_after_construction() {
        let resolver = SingleTenantResolver::new();
        let contexts = [
            AuthContext::anonymous(),
            ctx_with_metadata("alice", json!({})),
            ctx_with_metadata("bob", json!({"tenant_id": "weird"})),
        ];
        for auth in &contexts {
            assert!(resolver.resolve(auth).await.is_ok());
        }
    }

    #[tokio::test]
    async fn resolver_is_object_safe() {
        use std::sync::Arc;

        // Compiles only if the trait can be used as a trait object.
        let resolver: Arc<dyn TenantResolver> = Arc::new(SingleTenantResolver::new());
        let auth = AuthContext::anonymous();
        let tenant = resolver.resolve(&auth).await.unwrap();
        assert_eq!(tenant.as_str(), "default");
    }
}