use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use arc_swap::ArcSwap;
use axum::http::HeaderMap;
use smol_str::SmolStr;
use crate::auth::guards::Guard;
use crate::web::{Error, RequestContext};
/// Opaque identifier naming a single tenant.
///
/// A thin newtype over [`SmolStr`]; equality and hashing follow the wrapped string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TenantId(pub SmolStr);

impl TenantId {
    /// Builds a tenant id from anything that can be viewed as a `&str`.
    pub fn new(id: impl AsRef<str>) -> Self {
        let raw: &str = id.as_ref();
        TenantId(SmolStr::new(raw))
    }

    /// Borrows the id as a plain string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// Rule for extracting a raw tenant id from incoming request headers.
///
/// All variants hold only `&'static str`, so the enum is cheap to copy and
/// compare; it derives the standard traits so it can be logged and tested.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TenantStrategy {
    /// Read the id from the named request header (trimmed; blank values ignored).
    Header(&'static str),
    /// Take the part of the `Host` header in front of `base_domain`
    /// (e.g. `acme` for `acme.example.com` when `base_domain` is `example.com`).
    Subdomain { base_domain: &'static str },
    /// Try `header` first; if it is absent or unusable, apply the subdomain rule
    /// with `base_domain`.
    HeaderThenSubdomain {
        header: &'static str,
        base_domain: &'static str,
    },
}
pub struct TenantConfig {
pub id: TenantId,
pub display_name: String,
pub datasource: String,
}
/// Directory of tenants that can be changed at runtime without a restart.
///
/// The mutable state lives in an immutable [`TenantSnapshot`] held behind an
/// `ArcSwap`: lookups load the current snapshot, and every change publishes a
/// whole new one.
pub struct TenantRegistry {
    /// Tenant to serve when a request does not name a known tenant.
    fallback: Option<TenantId>,
    /// Rule for pulling a raw tenant id out of request headers.
    strategy: TenantStrategy,
    /// Current point-in-time view of known and suspended tenants.
    snapshot: ArcSwap<TenantSnapshot>,
}
/// Immutable view of the registry's tenants; replaced wholesale on every change.
struct TenantSnapshot {
    /// Ids currently suspended. An id may appear here without being in `known`.
    suspended: HashSet<TenantId>,
    /// Every registered tenant, keyed by its id.
    known: HashMap<TenantId, Arc<TenantConfig>>,
}
impl TenantRegistry {
pub fn new(
strategy: TenantStrategy,
tenants: Vec<TenantConfig>,
fallback: Option<TenantId>,
) -> Self {
let known = tenants
.into_iter()
.map(|t| (t.id.clone(), Arc::new(t)))
.collect();
Self {
strategy,
snapshot: ArcSwap::from_pointee(TenantSnapshot {
known,
suspended: HashSet::new(),
}),
fallback,
}
}
pub fn upsert(&self, cfg: TenantConfig) {
self.snapshot.rcu(|cur| {
let mut known = cur.known.clone();
known.insert(
cfg.id.clone(),
Arc::new(TenantConfig {
id: cfg.id.clone(),
display_name: cfg.display_name.clone(),
datasource: cfg.datasource.clone(),
}),
);
TenantSnapshot {
known,
suspended: cur.suspended.clone(),
}
});
tracing::info!("tenant upserted (live, no restart)");
}
pub fn suspend(&self, id: &TenantId) {
self.snapshot.rcu(|cur| {
let mut suspended = cur.suspended.clone();
suspended.insert(id.clone());
TenantSnapshot {
known: cur.known.clone(),
suspended,
}
});
}
pub fn resume(&self, id: &TenantId) {
self.snapshot.rcu(|cur| {
let mut suspended = cur.suspended.clone();
suspended.remove(id);
TenantSnapshot {
known: cur.known.clone(),
suspended,
}
});
}
fn raw_id(&self, headers: &HeaderMap) -> Option<SmolStr> {
fn from_header(headers: &HeaderMap, name: &str) -> Option<SmolStr> {
headers
.get(name)
.and_then(|v| v.to_str().ok())
.map(str::trim)
.filter(|s| !s.is_empty())
.map(SmolStr::new)
}
fn from_subdomain(headers: &HeaderMap, base: &str) -> Option<SmolStr> {
let host = headers.get("host")?.to_str().ok()?;
let host = host.rsplit_once(':').map_or(host, |(h, _)| h);
host.strip_suffix(base)?
.strip_suffix('.')
.filter(|s| !s.is_empty())
.map(SmolStr::new)
}
match &self.strategy {
TenantStrategy::Header(h) => from_header(headers, h),
TenantStrategy::Subdomain { base_domain } => from_subdomain(headers, base_domain),
TenantStrategy::HeaderThenSubdomain {
header,
base_domain,
} => from_header(headers, header).or_else(|| from_subdomain(headers, base_domain)),
}
}
pub fn resolve(&self, headers: &HeaderMap) -> Option<Arc<TenantConfig>> {
let snap = self.snapshot.load();
if let Some(id) = self.raw_id(headers).map(TenantId) {
if snap.suspended.contains(&id) {
return None; }
if let Some(cfg) = snap.known.get(&id) {
return Some(cfg.clone()); }
}
self.fallback
.as_ref()
.filter(|id| !snap.suspended.contains(id))
.and_then(|id| snap.known.get(id).cloned())
}
pub fn get(&self, id: &TenantId) -> Option<Arc<TenantConfig>> {
self.snapshot.load().known.get(id).cloned()
}
pub fn resolve_by_id(&self, id: &str) -> Option<Arc<TenantConfig>> {
let snap = self.snapshot.load();
let id = TenantId::new(id);
if snap.suspended.contains(&id) {
return None;
}
snap.known.get(&id).cloned()
}
}
/// Guard that checks a request's resolved tenant against the caller's claims.
///
/// Stateless unit type; derives the standard traits so it can be logged,
/// copied, and default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct TenantGuard;

/// Shared, zero-sized instance of [`TenantGuard`].
pub static TENANT: TenantGuard = TenantGuard;
impl Guard for TenantGuard {
fn check(&self, ctx: &RequestContext) -> Result<(), Error> {
let tenant = ctx.tenant().ok_or(Error::Unauthorized)?;
if let Some(claim) = ctx
.claims()
.and_then(|c| c.get("tenant"))
.and_then(|v| v.as_str())
{
if claim != tenant.id.as_str() {
return Err(Error::Forbidden);
}
}
Ok(())
}
}