use std::sync::Arc;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Extension,
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::OnceCell;
use crate::jwt_verifier::{JwtVerifier, JwtVerifierConfig};
tokio::task_local! {
    /// Tenant id bound for the task currently executing a request.
    ///
    /// Set with `CURRENT_TENANT_ID.scope(..)` by `tenant_extractor_middleware`;
    /// read through [`current_tenant_id`].
    static CURRENT_TENANT_ID: String;
}

/// Returns the tenant id bound by the innermost enclosing `CURRENT_TENANT_ID` scope.
///
/// Outside of any scope (e.g. code not running under the tenant middleware) this
/// falls back to `"default"`. The value is cloned on every call.
pub fn current_tenant_id() -> String {
    match CURRENT_TENANT_ID.try_with(String::clone) {
        Ok(tenant) => tenant,
        Err(_) => "default".to_string(),
    }
}
/// Subscription tier attached to a tenant.
///
/// Serde (de)serializes the variants in lowercase (`"starter"`, `"pro"`,
/// `"enterprise"`), the same spelling used by [`TenantPlan::from_str`] and `Display`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum TenantPlan {
    Starter,
    Pro,
    Enterprise,
}

impl TenantPlan {
    /// Parses a plan name.
    ///
    /// Only the exact (case-sensitive) strings `"pro"` and `"enterprise"` are
    /// recognised; every other input, including `"starter"`, yields
    /// [`TenantPlan::Starter`].
    pub fn from_str(s: &str) -> Self {
        if s == "enterprise" {
            Self::Enterprise
        } else if s == "pro" {
            Self::Pro
        } else {
            Self::Starter
        }
    }
}

impl std::fmt::Display for TenantPlan {
    /// Writes the lowercase plan name; width/fill format flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Starter => "starter",
            Self::Pro => "pro",
            Self::Enterprise => "enterprise",
        };
        f.write_str(name)
    }
}
/// Tenant identity and plan resolved for a single request.
///
/// `tenant_extractor_middleware` inserts it into the request extensions;
/// handlers retrieve it with `require_tenant`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TenantContext {
    pub tenant_id: String,
    pub plan: TenantPlan,
}

impl TenantContext {
    /// Tenant id used when a request carries no tenant information.
    const DEFAULT_ID: &'static str = "default";

    /// Builds a context for `tenant_id` on the given `plan`.
    pub fn new(tenant_id: impl Into<String>, plan: TenantPlan) -> Self {
        let tenant_id = tenant_id.into();
        Self { tenant_id, plan }
    }

    /// The fallback context: tenant `"default"` on the Enterprise plan.
    pub fn default_tenant() -> Self {
        Self::new(Self::DEFAULT_ID, TenantPlan::Enterprise)
    }

    /// `true` when this context refers to the `"default"` tenant.
    pub fn is_default(&self) -> bool {
        self.tenant_id == Self::DEFAULT_ID
    }
}
/// Process-wide JWT verifier, built lazily on first use.
///
/// Holds `None` when `JwtVerifierConfig::from_env()` returns `None`, in which case
/// the middleware only uses legacy tenant resolution.
static JWT_VERIFIER: OnceCell<Option<Arc<JwtVerifier>>> = OnceCell::const_new();

/// Returns a shared handle to the verifier, initialising it on the first call.
async fn jwt_verifier() -> Option<Arc<JwtVerifier>> {
    let cached = JWT_VERIFIER
        .get_or_init(|| async {
            let config = JwtVerifierConfig::from_env()?;
            Some(Arc::new(JwtVerifier::new(config)))
        })
        .await;
    cached.as_ref().cloned()
}
fn tenant_id_from_jwt(token: &str) -> Option<String> {
let parts: Vec<&str> = token.splitn(3, '.').collect();
if parts.len() < 2 {
return None;
}
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).ok()?;
let claims: Value = serde_json::from_slice(&payload_bytes).ok()?;
claims.get("tenant_id")?.as_str().map(|s| s.to_string())
}
fn tenant_id_from_bearer_unverified(headers: &HeaderMap) -> Option<String> {
let auth = headers.get("authorization")?.to_str().ok()?;
let token = auth.strip_prefix("Bearer ")?;
tenant_id_from_jwt(token)
}
/// Returns the credentials of an `Authorization: Bearer <token>` header.
///
/// The authentication scheme is compared case-insensitively, as HTTP requires
/// (RFC 9110 §11.1), so `bearer <token>` and `BEARER <token>` are accepted too.
/// Whitespace around the token is trimmed. An empty token (`"Bearer "`) is returned
/// as `Some("")` so callers can still reject it as invalid.
///
/// Returns `None` when the header is missing, is not representable as `&str`,
/// has no space after the scheme, or uses a different scheme.
fn bearer_token(headers: &HeaderMap) -> Option<&str> {
    let value = headers.get("authorization")?.to_str().ok()?;
    let (scheme, token) = value.split_once(' ')?;
    scheme.eq_ignore_ascii_case("Bearer").then_some(token.trim())
}
/// Maps an optional `plan` claim to a [`TenantPlan`].
///
/// A missing claim yields `Enterprise`. A present claim is parsed with
/// [`TenantPlan::from_str`], so unrecognised values yield `Starter`.
/// NOTE(review): confirm that this absent-vs-unknown asymmetry is intended.
fn plan_from_claim(raw: Option<&str>) -> TenantPlan {
    match raw {
        Some(name) => TenantPlan::from_str(name),
        None => TenantPlan::Enterprise,
    }
}
pub async fn tenant_extractor_middleware(
mut req: Request<Body>,
next: Next,
) -> Response {
let headers = req.headers().clone();
let verifier = jwt_verifier().await;
if let Some(v) = verifier.clone() {
if let Some(token) = bearer_token(&headers) {
match v.verify(token).await {
Ok(claims) => {
let ctx = TenantContext::new(
claims.tenant_id.clone(),
plan_from_claim(claims.plan.as_deref()),
);
tracing::debug!(
tenant_id = %ctx.tenant_id,
plan = %ctx.plan,
sub = claims.sub.as_deref().unwrap_or(""),
"tenant resolved via verified JWT"
);
let tenant_id = ctx.tenant_id.clone();
req.extensions_mut().insert(ctx);
return CURRENT_TENANT_ID.scope(tenant_id, next.run(req)).await;
}
Err(err) => {
if v.config().enforce {
tracing::warn!(
error = %err,
"rejecting request: JWT verification failed"
);
return (
StatusCode::UNAUTHORIZED,
"invalid bearer token",
)
.into_response();
}
tracing::warn!(
error = %err,
"JWT verification failed — falling back to legacy path"
);
}
}
} else if v.config().enforce {
let has_xtenant = headers
.get("x-tenant-id")
.and_then(|v| v.to_str().ok())
.map(|s| !s.is_empty())
.unwrap_or(false);
if !has_xtenant {
return (
StatusCode::UNAUTHORIZED,
"authorization required",
)
.into_response();
}
}
}
let ctx = if let Some(tid) = headers
.get("x-tenant-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.filter(|s| !s.is_empty())
{
TenantContext::new(tid, TenantPlan::Enterprise)
} else if let Some(tid) = tenant_id_from_bearer_unverified(&headers) {
TenantContext::new(tid, TenantPlan::Enterprise)
} else {
TenantContext::default_tenant()
};
tracing::debug!(
tenant_id = %ctx.tenant_id,
plan = %ctx.plan,
"tenant resolved (legacy path)"
);
let tenant_id = ctx.tenant_id.clone();
req.extensions_mut().insert(ctx);
CURRENT_TENANT_ID.scope(tenant_id, next.run(req)).await
}
/// Handler-side accessor for the [`TenantContext`] set by `tenant_extractor_middleware`.
///
/// # Errors
///
/// Returns a ready-made `500 Internal Server Error` response when the extension is
/// absent. This indicates the middleware was not installed on the route.
pub fn require_tenant(
    ext: Option<Extension<TenantContext>>,
) -> Result<TenantContext, Response> {
    match ext {
        Some(Extension(ctx)) => Ok(ctx),
        None => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            "TenantContext missing — tenant_extractor_middleware not wired",
        )
            .into_response()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};

    /// Builds a three-segment token whose payload segment encodes `payload_json`.
    /// The signature is a placeholder, so only unverified decoding can be exercised.
    fn make_jwt(payload_json: &str) -> String {
        [
            URL_SAFE_NO_PAD.encode(r#"{"alg":"HS256","typ":"JWT"}"#),
            URL_SAFE_NO_PAD.encode(payload_json),
            String::from("fakesig"),
        ]
        .join(".")
    }

    /// Returns a header map containing `Authorization: Bearer <token>`.
    fn bearer_headers(token: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", format!("Bearer {token}").parse().unwrap());
        headers
    }

    #[test]
    fn test_tenant_plan_from_str() {
        let cases = [
            ("starter", TenantPlan::Starter),
            ("pro", TenantPlan::Pro),
            ("enterprise", TenantPlan::Enterprise),
            ("unknown", TenantPlan::Starter),
        ];
        for (raw, expected) in cases {
            assert_eq!(TenantPlan::from_str(raw), expected, "parsing {raw:?}");
        }
    }

    #[test]
    fn test_tenant_plan_display() {
        let cases = [
            (TenantPlan::Starter, "starter"),
            (TenantPlan::Pro, "pro"),
            (TenantPlan::Enterprise, "enterprise"),
        ];
        for (plan, expected) in cases {
            assert_eq!(plan.to_string(), expected);
        }
    }

    #[test]
    fn test_default_tenant() {
        let ctx = TenantContext::default_tenant();
        assert!(ctx.is_default());
        assert_eq!(ctx.tenant_id, "default");
        assert_eq!(ctx.plan, TenantPlan::Enterprise);
    }

    #[test]
    fn test_tenant_id_from_jwt_valid() {
        let token = make_jwt(r#"{"sub":"user123","tenant_id":"acme-corp"}"#);
        assert_eq!(tenant_id_from_jwt(&token).as_deref(), Some("acme-corp"));
    }

    #[test]
    fn test_tenant_id_from_jwt_missing_claim() {
        let token = make_jwt(r#"{"sub":"user123"}"#);
        assert!(tenant_id_from_jwt(&token).is_none());
    }

    #[test]
    fn test_tenant_id_from_jwt_malformed() {
        for garbage in ["not.a.jwt.at.all", "onlyone"] {
            assert!(tenant_id_from_jwt(garbage).is_none(), "{garbage:?} should not parse");
        }
    }

    #[test]
    fn test_tenant_id_from_bearer_valid() {
        let headers = bearer_headers(&make_jwt(r#"{"tenant_id":"example-tenant"}"#));
        assert_eq!(
            tenant_id_from_bearer_unverified(&headers).as_deref(),
            Some("example-tenant")
        );
    }

    #[test]
    fn test_tenant_id_from_bearer_missing() {
        assert!(tenant_id_from_bearer_unverified(&HeaderMap::new()).is_none());
    }

    #[test]
    fn test_tenant_context_new() {
        let ctx = TenantContext::new("acme", TenantPlan::Pro);
        assert!(!ctx.is_default());
        assert_eq!(ctx.tenant_id, "acme");
        assert_eq!(ctx.plan, TenantPlan::Pro);
    }

    #[test]
    fn test_current_tenant_id_default_outside_scope() {
        assert_eq!(current_tenant_id(), "default");
    }

    #[tokio::test]
    async fn test_current_tenant_id_inside_scope() {
        let seen = CURRENT_TENANT_ID
            .scope(String::from("example-tenant"), async { current_tenant_id() })
            .await;
        assert_eq!(seen, "example-tenant");
    }

    #[tokio::test]
    async fn test_current_tenant_id_nested_scope() {
        let (outer, inner) = CURRENT_TENANT_ID
            .scope(String::from("tenant-a"), async {
                let inner = CURRENT_TENANT_ID
                    .scope(String::from("tenant-b"), async { current_tenant_id() })
                    .await;
                (current_tenant_id(), inner)
            })
            .await;
        assert_eq!(outer, "tenant-a");
        assert_eq!(inner, "tenant-b");
    }
}