use std::collections::HashSet;
use axum::http::HeaderMap;
use dashmap::DashMap;
use fraiseql_core::security::SecurityContext;
use fraiseql_error::{FraiseQLError, Result};
use tracing::warn;
/// Upper bound on the length of a tenant key taken from the `X-Tenant-ID` header.
///
/// `validate_tenant_key` compares this against `str::len`, i.e. the length in bytes.
pub(crate) const MAX_TENANT_KEY_LEN: usize = 128;
/// Stateless helper that works out which tenant a request belongs to.
///
/// All inputs are passed explicitly to [`TenantKeyResolver::resolve`].
pub struct TenantKeyResolver;

impl TenantKeyResolver {
    /// Resolves the tenant key for a request.
    ///
    /// Up to three sources are consulted, in priority order; the first one that yields a
    /// value determines the result:
    ///
    /// 1. `tenant_id` of `security_context` (reported as `JWT` in diagnostics),
    /// 2. the `X-Tenant-ID` request header, checked with `validate_tenant_key`,
    /// 3. the `Host` request header, mapped to a tenant key through `domain_registry`.
    ///
    /// Every source that yields a value is recorded. When the recorded values are not all
    /// equal, a warning is logged; in `strict` mode the call additionally fails instead of
    /// returning the highest-priority value.
    ///
    /// Returns `Ok(None)` when no source yields a tenant key.
    ///
    /// # Errors
    ///
    /// Returns a validation error when
    /// - the `X-Tenant-ID` header is present but cannot be read as text or is rejected by
    ///   `validate_tenant_key`, or
    /// - `strict` is `true` and the sources disagree.
    #[doc(hidden)]
    pub fn resolve(
        security_context: Option<&SecurityContext>,
        headers: &HeaderMap,
        domain_registry: Option<&DomainRegistry>,
        strict: bool,
    ) -> Result<Option<String>> {
        // (source label, tenant key) pairs, pushed in priority order, so the first entry
        // is always the resolved value.
        let mut sources: Vec<(&'static str, String)> = Vec::with_capacity(3);

        if let Some(tid) = security_context.and_then(|ctx| ctx.tenant_id.as_ref()) {
            sources.push(("JWT", tid.0.clone()));
        }

        if let Some(raw) = headers.get("X-Tenant-ID") {
            // A header that cannot be decoded is malformed input just like one with
            // disallowed characters: reject it rather than silently falling through to a
            // lower-priority source.
            let key = raw.to_str().map_err(|_| {
                FraiseQLError::validation(
                    "X-Tenant-ID contains invalid characters (allowed: a-zA-Z0-9_-)",
                )
            })?;
            validate_tenant_key(key)?;
            sources.push(("X-Tenant-ID", key.to_owned()));
        }

        // An unreadable or unregistered Host simply contributes no value.
        let host_key = domain_registry.and_then(|registry| {
            let host = headers.get("Host")?.to_str().ok()?;
            registry.lookup(host)
        });
        if let Some(key) = host_key {
            sources.push(("Host", key));
        }

        let distinct: HashSet<&str> = sources.iter().map(|(_, key)| key.as_str()).collect();
        if distinct.len() > 1 {
            let conflicts = sources
                .iter()
                .map(|(source, key)| format!("{source}: {key}"))
                .collect::<Vec<_>>()
                .join(", ");
            warn!("Tenant source conflict detected: {}", conflicts);
            if strict {
                return Err(FraiseQLError::Validation {
                    message: format!("Conflicting tenant values from sources: {conflicts}"),
                    path: None,
                });
            }
        }

        Ok(sources.into_iter().next().map(|(_, key)| key))
    }
}
/// Checks that an `X-Tenant-ID` header value is usable as a tenant key.
///
/// A key is accepted when it is non-empty, at most [`MAX_TENANT_KEY_LEN`] bytes long and
/// made up exclusively of ASCII letters, ASCII digits, `-` and `_`.
///
/// # Errors
///
/// Returns a validation error describing the first rule the key violates.
fn validate_tenant_key(key: &str) -> Result<()> {
    // Without this check an empty header would pass (`all` is vacuously true on an empty
    // iterator) and resolve to the tenant key "".
    if key.is_empty() {
        return Err(FraiseQLError::validation("X-Tenant-ID must not be empty"));
    }
    if key.len() > MAX_TENANT_KEY_LEN {
        return Err(FraiseQLError::validation(format!(
            "X-Tenant-ID exceeds maximum length of {MAX_TENANT_KEY_LEN} characters"
        )));
    }
    let is_allowed = |b: u8| b.is_ascii_alphanumeric() || b == b'-' || b == b'_';
    if !key.bytes().all(is_allowed) {
        return Err(FraiseQLError::validation(
            "X-Tenant-ID contains invalid characters (allowed: a-zA-Z0-9_-)",
        ));
    }
    Ok(())
}
/// Mapping from domain name to tenant key.
///
/// Backed by a [`DashMap`], so entries can be registered and removed through a shared
/// reference.
pub struct DomainRegistry {
    /// Domain name (as registered) -> tenant key.
    domains: DashMap<String, String>,
}

impl DomainRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            domains: DashMap::new(),
        }
    }

    /// Maps `domain` to `tenant_key`, replacing any existing mapping for that domain.
    ///
    /// The domain is stored exactly as given.
    pub fn register(&self, domain: impl Into<String>, tenant_key: impl Into<String>) {
        self.domains.insert(domain.into(), tenant_key.into());
    }

    /// Removes the mapping for `domain`.
    ///
    /// Returns `true` if a mapping existed and was removed.
    #[must_use]
    pub fn remove(&self, domain: &str) -> bool {
        self.domains.remove(domain).is_some()
    }

    /// Looks up the tenant key for a `Host` header value.
    ///
    /// A trailing `:port` is ignored. An exact match on the remaining host is preferred;
    /// if there is none and the host contains uppercase ASCII letters, its lowercased form
    /// is tried as well, because host names are case-insensitive.
    #[must_use]
    pub fn lookup(&self, host: &str) -> Option<String> {
        let domain = Self::host_without_port(host);
        if let Some(entry) = self.domains.get(domain) {
            return Some(entry.value().clone());
        }
        // Only allocate a lowercased copy when it can actually differ from `domain`.
        if domain.bytes().any(|b| b.is_ascii_uppercase()) {
            let lowered = domain.to_ascii_lowercase();
            return self
                .domains
                .get(lowered.as_str())
                .map(|entry| entry.value().clone());
        }
        None
    }

    /// Returns a snapshot of all `(domain, tenant_key)` pairs, in no particular order.
    #[must_use]
    pub fn domains(&self) -> Vec<(String, String)> {
        self.domains
            .iter()
            .map(|entry| (entry.key().clone(), entry.value().clone()))
            .collect()
    }

    /// Returns the number of registered domains.
    #[must_use]
    pub fn len(&self) -> usize {
        self.domains.len()
    }

    /// Returns `true` when no domain is registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.domains.is_empty()
    }

    /// Strips an optional `:port` suffix from a `Host` header value.
    ///
    /// Bracketed IPv6 literals such as `[::1]:8080` keep their brackets and inner colons;
    /// everything else is cut at the first `:`.
    fn host_without_port(host: &str) -> &str {
        if host.starts_with('[') {
            if let Some(end) = host.find(']') {
                // `]` is a single ASCII byte, so the inclusive slice ends on a char boundary.
                return &host[..=end];
            }
        }
        host.split(':').next().unwrap_or(host)
    }
}
impl Default for DomainRegistry {
fn default() -> Self {
Self::new()
}
}