use std::cell::Cell;
use std::future::Future;
use super::context::{TenantContext, TenantId};
// Async (per-task) tenant context. Installed for the lifetime of a future via
// `TENANT_CONTEXT.scope(...)` (used by `with_tenant`, `with_context` and
// `TenantScope::run`) and read with `try_with`, which fails outside a scope.
tokio::task_local! {
static TENANT_CONTEXT: TenantContext;
}
// Per-thread tenant id for synchronous code paths. Written by `set_sync_tenant`
// (and restored by `SyncTenantGuard` on drop), read by `sync_tenant_id`.
// This is separate storage from `TENANT_CONTEXT`: the functions in this file
// that write one never write the other.
thread_local! {
static SYNC_TENANT_ID: Cell<Option<TenantId>> = const { Cell::new(None) };
}
pub async fn with_tenant<F, T>(tenant_id: impl Into<TenantId>, f: F) -> T
where
F: Future<Output = T>,
{
let ctx = TenantContext::new(tenant_id);
TENANT_CONTEXT.scope(ctx, f).await
}
/// Runs `f` with an already constructed [`TenantContext`] installed as the
/// task-local tenant, and returns whatever `f` resolves to.
pub async fn with_context<F, T>(ctx: TenantContext, f: F) -> T
where
    F: Future<Output = T>,
{
    let scoped = TENANT_CONTEXT.scope(ctx, f);
    scoped.await
}
#[inline]
pub fn current_tenant() -> Option<TenantContext> {
TENANT_CONTEXT.try_with(|ctx| ctx.clone()).ok()
}
#[inline]
pub fn current_tenant_id() -> Option<TenantId> {
TENANT_CONTEXT.try_with(|ctx| ctx.id.clone()).ok()
}
/// Placeholder accessor: always returns the empty string.
///
/// NOTE(review): this does not consult the task-local context at all. A real
/// tenant id cannot be returned as `&'static str` without leaking or interning
/// it, so callers that need the actual id should use [`current_tenant_id`].
#[inline]
pub fn current_tenant_id_str() -> &'static str {
    const NO_TENANT_STR: &str = "";
    NO_TENANT_STR
}
#[inline]
pub fn has_tenant() -> bool {
TENANT_CONTEXT.try_with(|_| ()).is_ok()
}
/// Calls `f` with a borrow of the current task-local tenant context, without
/// cloning it.
///
/// Returns `Some(f(ctx))` inside a tenant scope and `None` (without calling
/// `f`) outside one.
#[inline]
pub fn with_current_tenant<F, T>(f: F) -> Option<T>
where
    F: FnOnce(&TenantContext) -> T,
{
    let outcome = TENANT_CONTEXT.try_with(f);
    outcome.ok()
}
/// Returns the current task-local tenant context.
///
/// # Errors
///
/// Returns [`TenantNotSetError`] when no tenant context is installed for the
/// current task.
#[inline]
pub fn require_tenant() -> Result<TenantContext, TenantNotSetError> {
    let Some(active) = current_tenant() else {
        return Err(TenantNotSetError);
    };
    Ok(active)
}
/// Error returned by [`require_tenant`] when no tenant context is installed.
#[derive(Debug, Clone, Copy)]
pub struct TenantNotSetError;

impl std::fmt::Display for TenantNotSetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tenant context not set")
    }
}

impl std::error::Error for TenantNotSetError {}
/// Installs `tenant_id` as this thread's synchronous tenant id.
///
/// Returns a guard that restores whatever id (if any) was installed before
/// this call when it is dropped, so bind it to a named variable for as long
/// as the tenant should remain active.
pub fn set_sync_tenant(tenant_id: impl Into<TenantId>) -> SyncTenantGuard {
    let replacement = Some(tenant_id.into());
    let previous = SYNC_TENANT_ID.with(move |slot| slot.replace(replacement));
    SyncTenantGuard { previous }
}
/// Returns a clone of this thread's synchronous tenant id, or `None` if no
/// [`set_sync_tenant`] guard is currently active on this thread.
#[inline]
pub fn sync_tenant_id() -> Option<TenantId> {
    SYNC_TENANT_ID.with(|cell| {
        // `Cell` never hands out references to its contents, so instead of
        // dereferencing `as_ptr()` in `unsafe` code, move the value out,
        // clone it, and put the original back. No other code can observe the
        // temporary `None` because the cell is thread-local and we do not
        // yield between `take` and `set`.
        let current = cell.take();
        let snapshot = current.clone();
        cell.set(current);
        snapshot
    })
}
/// Guard returned by [`set_sync_tenant`]. On drop it restores the thread-local
/// tenant id that was active when the guard was created.
///
/// Guards should be dropped in reverse order of creation, which is what
/// scoped `let` bindings do naturally. Each guard only remembers the value
/// that was current when it was created, so dropping an outer guard before an
/// inner one leaves a stale id installed once the inner guard is dropped.
#[must_use = "dropping the guard immediately restores the previous tenant id"]
pub struct SyncTenantGuard {
    previous: Option<TenantId>,
}

impl Drop for SyncTenantGuard {
    fn drop(&mut self) {
        let previous = self.previous.take();
        // `try_with` rather than `with`: if this guard is dropped while the
        // thread's locals are being torn down (for example, it was stored in
        // another thread-local), `with` would panic inside `drop`. There is
        // nothing meaningful to restore at that point, so the error is
        // deliberately ignored.
        let _ = SYNC_TENANT_ID.try_with(move |cell| cell.set(previous));
    }
}
#[derive(Debug, Clone)]
pub struct TenantScope {
context: TenantContext,
}
impl TenantScope {
pub fn new(tenant_id: impl Into<TenantId>) -> Self {
Self {
context: TenantContext::new(tenant_id),
}
}
pub fn from_context(context: TenantContext) -> Self {
Self { context }
}
pub fn tenant_id(&self) -> &TenantId {
&self.context.id
}
pub fn context(&self) -> &TenantContext {
&self.context
}
pub async fn run<F, T>(&self, f: F) -> T
where
F: Future<Output = T>,
{
TENANT_CONTEXT.scope(self.context.clone(), f).await
}
pub fn run_sync<F, T>(&self, f: F) -> T
where
F: FnOnce() -> T,
{
let _guard = set_sync_tenant(self.context.id.clone());
f()
}
}
/// Strategy for deriving a tenant id from request headers.
///
/// Headers are supplied as `(name, value)` pairs. The `Send + Sync`
/// supertraits let implementors, including boxed ones inside
/// `CompositeExtractor`, be used from multiple threads.
pub trait TenantExtractor: Send + Sync {
/// Returns the tenant id found in `headers`, or `None` if this extractor
/// cannot determine one.
fn extract(&self, headers: &[(String, String)]) -> Option<TenantId>;
}
/// Extracts the tenant id from a single, configurable request header.
///
/// Header names are matched ASCII case-insensitively, as HTTP field names are
/// case-insensitive. Only the first header with a matching name is consulted.
#[derive(Debug, Clone)]
pub struct HeaderExtractor {
    header_name: String,
}

impl HeaderExtractor {
    /// Creates an extractor that reads the header named `header_name`.
    pub fn new(header_name: impl Into<String>) -> Self {
        Self {
            header_name: header_name.into(),
        }
    }

    /// Creates an extractor for the `X-Tenant-ID` header.
    pub fn default_header() -> Self {
        Self::new("X-Tenant-ID")
    }

    /// Returns the header name this extractor looks for.
    pub fn header_name(&self) -> &str {
        &self.header_name
    }
}

impl TenantExtractor for HeaderExtractor {
    /// Returns the value of the first header whose name matches, with
    /// surrounding whitespace removed.
    ///
    /// A present-but-blank header yields `None` rather than an empty tenant
    /// id, so a [`CompositeExtractor`] can fall through to its next extractor.
    fn extract(&self, headers: &[(String, String)]) -> Option<TenantId> {
        let (_, raw_value) = headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(&self.header_name))?;
        // Leading/trailing whitespace is never part of an HTTP field value
        // (RFC 9110 §5.5), so it must not leak into the tenant id.
        let value = raw_value.trim();
        if value.is_empty() {
            None
        } else {
            Some(TenantId::new(value.to_string()))
        }
    }
}
/// Extractor intended to read the tenant id from a JWT claim.
///
/// NOTE(review): `extract` is currently a stub. It ignores the headers and
/// always returns `None`, so this extractor never yields a tenant on its own.
/// The configured claim name is stored and exposed via [`Self::claim_name`]
/// but is not yet used by `extract`.
#[derive(Debug, Clone)]
pub struct JwtClaimExtractor {
    claim_name: String,
}

impl JwtClaimExtractor {
    /// Creates an extractor for the claim named `claim_name`.
    pub fn new(claim_name: impl Into<String>) -> Self {
        let claim_name: String = claim_name.into();
        JwtClaimExtractor { claim_name }
    }

    /// Creates an extractor for the `tenant_id` claim.
    pub fn default_claim() -> Self {
        JwtClaimExtractor::new("tenant_id")
    }

    /// Returns the configured claim name.
    pub fn claim_name(&self) -> &str {
        self.claim_name.as_str()
    }
}

impl TenantExtractor for JwtClaimExtractor {
    /// Stub: always returns `None` (see the type-level note).
    fn extract(&self, _: &[(String, String)]) -> Option<TenantId> {
        None
    }
}
/// Chains several extractors, trying each in insertion order and returning
/// the first tenant id any of them produces.
///
/// An empty composite always returns `None`.
#[derive(Default)]
pub struct CompositeExtractor {
    extractors: Vec<Box<dyn TenantExtractor>>,
}

impl CompositeExtractor {
    /// Creates a composite with no extractors.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `extractor` to the end of the chain (builder style).
    pub fn add<E>(self, extractor: E) -> Self
    where
        E: TenantExtractor + 'static,
    {
        let mut extractors = self.extractors;
        extractors.push(Box::new(extractor));
        Self { extractors }
    }
}

impl TenantExtractor for CompositeExtractor {
    fn extract(&self, headers: &[(String, String)]) -> Option<TenantId> {
        // `find_map` stops at the first extractor that yields `Some`.
        self.extractors
            .iter()
            .find_map(|candidate| candidate.extract(headers))
    }
}
// Unit tests for the async (task-local) and sync (thread-local) tenant
// plumbing and for the header extractors.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_with_tenant() {
        let result = with_tenant("test-tenant", async { current_tenant_id() }).await;
        assert_eq!(result.unwrap().as_str(), "test-tenant");
    }

    #[tokio::test]
    async fn test_no_tenant() {
        assert!(current_tenant().is_none());
        assert!(!has_tenant());
    }

    #[tokio::test]
    async fn test_has_tenant_inside_scope() {
        assert!(with_tenant("present", async { has_tenant() }).await);
    }

    #[tokio::test]
    async fn test_nested_tenant() {
        with_tenant("outer", async {
            assert_eq!(current_tenant_id().unwrap().as_str(), "outer");
            with_tenant("inner", async {
                assert_eq!(current_tenant_id().unwrap().as_str(), "inner");
            })
            .await;
            assert_eq!(current_tenant_id().unwrap().as_str(), "outer");
        })
        .await;
    }

    #[tokio::test]
    async fn test_require_tenant() {
        assert!(require_tenant().is_err());
        let id = with_tenant("required", async {
            require_tenant().map(|ctx| ctx.id.as_str().to_string())
        })
        .await;
        assert_eq!(id.ok(), Some("required".to_string()));
    }

    #[tokio::test]
    async fn test_with_current_tenant() {
        assert!(with_current_tenant(|_| ()).is_none());
        let seen = with_tenant("visited", async {
            with_current_tenant(|ctx| ctx.id.as_str().to_string())
        })
        .await;
        assert_eq!(seen, Some("visited".to_string()));
    }

    #[tokio::test]
    async fn test_tenant_scope() {
        let scope = TenantScope::new("scoped-tenant");
        assert_eq!(scope.tenant_id().as_str(), "scoped-tenant");
        let result = scope
            .run(async { current_tenant_id().map(|id| id.as_str().to_string()) })
            .await;
        assert_eq!(result, Some("scoped-tenant".to_string()));
    }

    #[test]
    fn test_tenant_not_set_error_display() {
        assert_eq!(TenantNotSetError.to_string(), "tenant context not set");
    }

    #[test]
    fn test_sync_tenant() {
        {
            let _guard = set_sync_tenant("sync-tenant");
            assert_eq!(sync_tenant_id().unwrap().as_str(), "sync-tenant");
        }
        assert!(sync_tenant_id().is_none());
    }

    #[test]
    fn test_nested_sync_tenant_restores_previous() {
        let outer = set_sync_tenant("outer");
        {
            let _inner = set_sync_tenant("inner");
            assert_eq!(sync_tenant_id().unwrap().as_str(), "inner");
        }
        assert_eq!(sync_tenant_id().unwrap().as_str(), "outer");
        drop(outer);
        assert!(sync_tenant_id().is_none());
    }

    #[test]
    fn test_run_sync_sets_and_restores_sync_tenant() {
        let scope = TenantScope::new("sync-scope");
        let seen = scope.run_sync(|| sync_tenant_id().map(|id| id.as_str().to_string()));
        assert_eq!(seen, Some("sync-scope".to_string()));
        assert!(sync_tenant_id().is_none());
    }

    #[test]
    fn test_header_extractor() {
        let extractor = HeaderExtractor::new("X-Tenant-ID");
        let headers = vec![
            ("Content-Type".to_string(), "application/json".to_string()),
            ("X-Tenant-ID".to_string(), "tenant-from-header".to_string()),
        ];
        let id = extractor.extract(&headers);
        assert_eq!(id.unwrap().as_str(), "tenant-from-header");
    }

    #[test]
    fn test_header_extractor_case_insensitive_and_missing() {
        let extractor = HeaderExtractor::default_header();
        let headers = vec![("x-tenant-id".to_string(), "lowercase".to_string())];
        assert_eq!(extractor.extract(&headers).unwrap().as_str(), "lowercase");
        assert!(extractor.extract(&[]).is_none());
    }

    #[test]
    fn test_composite_extractor() {
        let extractor = CompositeExtractor::new()
            .add(HeaderExtractor::new("X-Organization-ID"))
            .add(HeaderExtractor::new("X-Tenant-ID"));
        let headers = vec![("X-Tenant-ID".to_string(), "fallback-tenant".to_string())];
        let id = extractor.extract(&headers);
        assert_eq!(id.unwrap().as_str(), "fallback-tenant");
    }

    #[test]
    fn test_empty_composite_extractor() {
        let headers = vec![("X-Tenant-ID".to_string(), "ignored".to_string())];
        assert!(CompositeExtractor::default().extract(&headers).is_none());
    }
}