use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use super::{Tenant, TenantManager, TenantError, TenantStatus};
thread_local! {
    // Per-thread "current tenant" slot read and written by `set_tenant_context`,
    // `get_tenant_context`, `clear_tenant_context` and `with_tenant`.
    // NOTE(review): thread-local state does not follow async tasks that migrate
    // between worker threads across `.await` points — confirm callers are synchronous.
    static CURRENT_TENANT: RefCell<Option<TenantContext>> = const { RefCell::new(None) };
}
/// Everything known about the tenant and caller for the current request.
#[derive(Debug, Clone)]
pub struct TenantContext {
    /// The resolved tenant record.
    pub tenant: Tenant,
    /// Caller identity; filled from the JWT `sub` claim when resolved from a JWT.
    pub user_id: Option<String>,
    /// Caller email; filled from the JWT `email` claim when present.
    pub user_email: Option<String>,
    /// Access level used by `can_write` / `is_admin`.
    pub role: TenantRole,
    /// Optional correlation id set via `with_request_id`.
    pub request_id: Option<String>,
    /// Free-form per-request key/value data; starts empty.
    pub metadata: HashMap<String, String>,
    /// Verified claims when the context was resolved from a JWT.
    pub jwt_claims: Option<JwtClaims>,
}
/// Claims carried by tenant JWTs.
///
/// Holds the registered claims this module reads, plus `tenant_id`, `email`
/// and `role`. Every other claim lands in `custom`.
///
/// Absent optional claims are omitted when serializing instead of being
/// written as `null`. RFC 7519 §4.1 requires e.g. `iss` to be a string and
/// `exp`/`nbf`/`iat` to be numbers, so `null` values make tokens produced by
/// this module non-conforming for other verifiers. Deserialization is
/// unchanged: a missing field still becomes `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Subject (user id).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sub: Option<String>,
    /// Tenant id claim, checked before any custom tenant claim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
    /// Issuer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iss: Option<String>,
    /// Audience: a single string or an array of strings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aud: Option<OneOrMany<String>>,
    /// Expiry, seconds since the Unix epoch.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exp: Option<u64>,
    /// Issued-at, seconds since the Unix epoch.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iat: Option<u64>,
    /// Not-before, seconds since the Unix epoch.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nbf: Option<u64>,
    /// Unique token id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jti: Option<String>,
    /// User email.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,
    /// Role name, parsed with `TenantRole::from(&str)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// All claims not captured by the named fields above.
    #[serde(flatten)]
    pub custom: HashMap<String, serde_json::Value>,
}
/// A JSON value that is either a single `T` or an array of `T`.
///
/// Used for the JWT `aud` claim. `untagged` lets serde accept either form.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum OneOrMany<T> {
    /// A bare value, e.g. `"aud": "api"`.
    One(T),
    /// An array, e.g. `"aud": ["api", "web"]`.
    Many(Vec<T>),
}
impl<T> OneOrMany<T> {
pub fn contains(&self, value: &T) -> bool
where
T: PartialEq,
{
match self {
OneOrMany::One(v) => v == value,
OneOrMany::Many(vs) => vs.contains(value),
}
}
}
/// Access level attached to a [`TenantContext`].
///
/// Serialized in lowercase (`"admin"`, `"member"`, ...). `Admin`, `Member`
/// and `Service` are the roles for which `TenantContext::can_write` is true.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TenantRole {
    /// Administrator; the only role for which `TenantContext::is_admin` is true.
    Admin,
    /// Regular user; may write.
    Member,
    /// May not write.
    Viewer,
    /// Assigned by API-key and `X-Tenant-ID` header resolution; may write.
    Service,
    /// Assigned by subdomain and path resolution, and the `Default`; may not write.
    Anonymous,
}
impl Default for TenantRole {
fn default() -> Self {
Self::Anonymous
}
}
impl From<&str> for TenantRole {
    /// Parses a role name case-insensitively, accepting common aliases.
    ///
    /// Any unrecognized name maps to [`TenantRole::Anonymous`].
    fn from(s: &str) -> Self {
        const ALIASES: [(&str, TenantRole); 12] = [
            ("admin", TenantRole::Admin),
            ("administrator", TenantRole::Admin),
            ("owner", TenantRole::Admin),
            ("member", TenantRole::Member),
            ("user", TenantRole::Member),
            ("write", TenantRole::Member),
            ("viewer", TenantRole::Viewer),
            ("reader", TenantRole::Viewer),
            ("read", TenantRole::Viewer),
            ("service", TenantRole::Service),
            ("service_account", TenantRole::Service),
            ("api", TenantRole::Service),
        ];
        let wanted = s.to_lowercase();
        ALIASES
            .iter()
            .find(|(alias, _)| *alias == wanted)
            .map(|(_, role)| role.clone())
            .unwrap_or(TenantRole::Anonymous)
    }
}
impl TenantContext {
pub fn new(tenant: Tenant, user_id: Option<String>, role: TenantRole) -> Self {
Self {
tenant,
user_id,
user_email: None,
role,
request_id: None,
metadata: HashMap::new(),
jwt_claims: None,
}
}
pub fn with_jwt_claims(mut self, claims: JwtClaims) -> Self {
self.jwt_claims = Some(claims);
self
}
pub fn with_email(mut self, email: String) -> Self {
self.user_email = Some(email);
self
}
pub fn with_request_id(mut self, request_id: String) -> Self {
self.request_id = Some(request_id);
self
}
pub fn can_write(&self) -> bool {
matches!(self.role, TenantRole::Admin | TenantRole::Member | TenantRole::Service)
}
pub fn is_admin(&self) -> bool {
matches!(self.role, TenantRole::Admin)
}
pub fn schema(&self) -> &str {
&self.tenant.schema_name
}
pub fn has_feature(&self, feature: &str) -> bool {
self.tenant.quotas.features.get(feature).copied().unwrap_or(false)
}
pub fn get_claim(&self, key: &str) -> Option<&serde_json::Value> {
self.jwt_claims.as_ref().and_then(|c| c.custom.get(key))
}
}
/// Installs `ctx` as this thread's current tenant context.
///
/// Any previously installed context is replaced.
pub fn set_tenant_context(ctx: TenantContext) {
    CURRENT_TENANT.with(|slot| {
        let _previous = slot.replace(Some(ctx));
    });
}
/// Returns a clone of this thread's current tenant context, or `None` if none is set.
pub fn get_tenant_context() -> Option<TenantContext> {
    CURRENT_TENANT.with(|slot| slot.borrow().as_ref().cloned())
}
pub fn clear_tenant_context() {
CURRENT_TENANT.with(|c| {
*c.borrow_mut() = None;
});
}
pub fn with_tenant<F, R>(ctx: TenantContext, f: F) -> R
where
F: FnOnce() -> R,
{
set_tenant_context(ctx);
let result = f();
clear_tenant_context();
result
}
/// Settings for verifying and issuing tenant JWTs.
#[derive(Debug, Clone)]
pub struct JwtConfig {
    /// HMAC shared secret. Used to verify `HS*` tokens and by `TenantResolver::create_jwt` to sign.
    pub secret: Option<String>,
    /// PEM public key for `RS*`, `ES*` and `EdDSA` verification (RSA keys may also be base64 DER).
    pub public_key: Option<String>,
    /// Configured JWT algorithm; `TenantResolver::create_jwt` signs with it.
    pub algorithm: Algorithm,
    /// When set, tokens must carry this `iss` claim.
    pub issuer: Option<String>,
    /// When set, tokens must carry this `aud` claim.
    pub audience: Option<String>,
    /// Custom claim consulted for the tenant id when the token has no `tenant_id` claim.
    pub tenant_id_claim: String,
    /// Custom claim consulted for the role when the token has no `role` claim.
    pub role_claim: String,
    /// When `true`, expiry (`exp`) is not validated; accepting expired tokens weakens authentication.
    pub allow_expired: bool,
}
impl Default for JwtConfig {
fn default() -> Self {
Self {
secret: None,
public_key: None,
algorithm: Algorithm::HS256,
issuer: None,
audience: None,
tenant_id_claim: "tenant_id".to_string(),
role_claim: "role".to_string(),
allow_expired: false,
}
}
}
impl JwtConfig {
    /// HMAC configuration with the shared `secret` and the default algorithm (HS256).
    ///
    /// The secret verifies tokens and signs those issued by `TenantResolver::create_jwt`.
    pub fn with_secret(secret: impl Into<String>) -> Self {
        Self {
            secret: Some(secret.into()),
            ..Default::default()
        }
    }

    /// Public-key configuration for tokens signed with `algorithm` (e.g. RS256, ES256, EdDSA).
    ///
    /// `key` is PEM; RSA keys may also be given as standard-base64 DER.
    pub fn with_public_key(key: impl Into<String>, algorithm: Algorithm) -> Self {
        Self {
            public_key: Some(key.into()),
            algorithm,
            ..Default::default()
        }
    }

    /// Overrides the configured algorithm.
    ///
    /// Example: `HS384`/`HS512` for a secret set via [`Self::with_secret`].
    pub fn with_algorithm(mut self, algorithm: Algorithm) -> Self {
        self.algorithm = algorithm;
        self
    }

    /// Requires tokens to carry this `iss` claim.
    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
        self.issuer = Some(issuer.into());
        self
    }

    /// Requires tokens to carry this `aud` claim.
    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
        self.audience = Some(audience.into());
        self
    }

    /// Sets the custom claim consulted for the tenant id when the token has no `tenant_id` claim.
    pub fn with_tenant_claim(mut self, claim: impl Into<String>) -> Self {
        self.tenant_id_claim = claim.into();
        self
    }

    /// Sets the custom claim consulted for the role when the token has no `role` claim.
    pub fn with_role_claim(mut self, claim: impl Into<String>) -> Self {
        self.role_claim = claim.into();
        self
    }
}
/// Turns request credentials (API key, JWT, header, subdomain, path) into a
/// [`TenantContext`], using tenants looked up in a [`TenantManager`].
pub struct TenantResolver {
    /// Source of tenant records.
    manager: Arc<TenantManager>,
    /// Settings for JWT verification and issuance.
    jwt_config: JwtConfig,
}
impl TenantResolver {
pub fn new(manager: Arc<TenantManager>) -> Self {
Self {
manager,
jwt_config: JwtConfig::default(),
}
}
pub fn with_jwt_config(mut self, config: JwtConfig) -> Self {
self.jwt_config = config;
self
}
pub fn resolve_from_api_key(&self, api_key: &str) -> Result<TenantContext, TenantError> {
let parts: Vec<&str> = api_key.splitn(2, ':').collect();
if parts.len() != 2 {
return Err(TenantError::InvalidId("Invalid API key format".to_string()));
}
let tenant_id = parts.first().ok_or_else(|| TenantError::InvalidId("Invalid API key format".to_string()))?;
let tenant = self.manager.get_tenant(tenant_id)
.ok_or_else(|| TenantError::NotFound(tenant_id.to_string()))?;
if tenant.status != TenantStatus::Active && tenant.status != TenantStatus::Trial {
return Err(TenantError::NotActive(tenant_id.to_string()));
}
self.manager.touch(tenant_id);
Ok(TenantContext::new(tenant, None, TenantRole::Service))
}
pub fn resolve_from_jwt(&self, token: &str) -> Result<TenantContext, TenantError> {
let header = decode_header(token)
.map_err(|e| TenantError::InvalidId(format!("Invalid JWT header: {}", e)))?;
let mut validation = Validation::new(header.alg);
if let Some(ref issuer) = self.jwt_config.issuer {
validation.set_issuer(&[issuer]);
}
if let Some(ref audience) = self.jwt_config.audience {
validation.set_audience(&[audience]);
}
if self.jwt_config.allow_expired {
validation.validate_exp = false;
}
let decoding_key = self.get_decoding_key(&header.alg)?;
let token_data = decode::<JwtClaims>(token, &decoding_key, &validation)
.map_err(|e| TenantError::InvalidId(format!("JWT validation failed: {}", e)))?;
let claims = token_data.claims;
let tenant_id = claims.tenant_id.clone()
.or_else(|| {
claims.custom.get(&self.jwt_config.tenant_id_claim)
.and_then(|v| v.as_str())
.map(|s| s.to_string())
})
.or_else(|| {
claims.aud.as_ref().and_then(|aud| match aud {
OneOrMany::One(s) => Some(s.clone()),
OneOrMany::Many(arr) => arr.first().cloned(),
})
})
.ok_or_else(|| TenantError::InvalidId("No tenant ID in JWT claims".to_string()))?;
let tenant = self.manager.get_tenant(&tenant_id)
.ok_or_else(|| TenantError::NotFound(tenant_id.clone()))?;
if tenant.status != TenantStatus::Active && tenant.status != TenantStatus::Trial {
return Err(TenantError::NotActive(tenant_id));
}
let user_id = claims.sub.clone();
let role = claims.role.as_ref()
.map(|r| TenantRole::from(r.as_str()))
.or_else(|| {
claims.custom.get(&self.jwt_config.role_claim)
.and_then(|v| v.as_str())
.map(|s| TenantRole::from(s))
})
.unwrap_or(TenantRole::Member);
self.manager.touch(&tenant_id);
let mut ctx = TenantContext::new(tenant, user_id, role)
.with_jwt_claims(claims.clone());
if let Some(ref email) = claims.email {
ctx = ctx.with_email(email.clone());
}
Ok(ctx)
}
fn get_decoding_key(&self, algorithm: &Algorithm) -> Result<DecodingKey, TenantError> {
match algorithm {
Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
let secret = self.jwt_config.secret.as_ref()
.ok_or_else(|| TenantError::InvalidId("JWT secret not configured".to_string()))?;
Ok(DecodingKey::from_secret(secret.as_bytes()))
}
Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
let key = self.jwt_config.public_key.as_ref()
.ok_or_else(|| TenantError::InvalidId("RSA public key not configured".to_string()))?;
DecodingKey::from_rsa_pem(key.as_bytes())
.or_else(|_| {
let der_bytes = base64::Engine::decode(
&base64::engine::general_purpose::STANDARD,
key
).map_err(|_| TenantError::InvalidId("Invalid RSA key format".to_string()))?;
Ok(DecodingKey::from_rsa_der(&der_bytes))
})
}
Algorithm::ES256 | Algorithm::ES384 => {
let key = self.jwt_config.public_key.as_ref()
.ok_or_else(|| TenantError::InvalidId("EC public key not configured".to_string()))?;
DecodingKey::from_ec_pem(key.as_bytes())
.map_err(|e| TenantError::InvalidId(format!("Invalid EC key: {}", e)))
}
Algorithm::EdDSA => {
let key = self.jwt_config.public_key.as_ref()
.ok_or_else(|| TenantError::InvalidId("EdDSA public key not configured".to_string()))?;
DecodingKey::from_ed_pem(key.as_bytes())
.map_err(|e| TenantError::InvalidId(format!("Invalid EdDSA key: {}", e)))
}
_ => Err(TenantError::InvalidId(format!("Unsupported algorithm: {:?}", algorithm))),
}
}
pub fn resolve_from_subdomain(&self, host: &str) -> Result<TenantContext, TenantError> {
let parts: Vec<&str> = host.split('.').collect();
if parts.len() < 3 {
return Err(TenantError::InvalidId("Cannot determine tenant from host".to_string()));
}
let tenant_id = parts.first().ok_or_else(|| TenantError::InvalidId("Cannot determine tenant from host".to_string()))?;
let tenant = self.manager.get_tenant(tenant_id)
.ok_or_else(|| TenantError::NotFound(tenant_id.to_string()))?;
Ok(TenantContext::new(tenant, None, TenantRole::Anonymous))
}
pub fn resolve_from_path(&self, path: &str) -> Result<TenantContext, TenantError> {
let parts: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
if parts.len() < 2 || parts.first() != Some(&"tenants") {
return Err(TenantError::InvalidId("Invalid path format".to_string()));
}
let tenant_id = parts.get(1).ok_or_else(|| TenantError::InvalidId("Invalid path format".to_string()))?;
let tenant = self.manager.get_tenant(tenant_id)
.ok_or_else(|| TenantError::NotFound(tenant_id.to_string()))?;
Ok(TenantContext::new(tenant, None, TenantRole::Anonymous))
}
pub fn resolve_from_header(&self, tenant_id: &str) -> Result<TenantContext, TenantError> {
let tenant = self.manager.get_tenant(tenant_id)
.ok_or_else(|| TenantError::NotFound(tenant_id.to_string()))?;
if tenant.status != TenantStatus::Active && tenant.status != TenantStatus::Trial {
return Err(TenantError::NotActive(tenant_id.to_string()));
}
self.manager.touch(tenant_id);
Ok(TenantContext::new(tenant, None, TenantRole::Service))
}
pub fn create_jwt(&self, tenant_id: &str, user_id: &str, role: TenantRole, expires_in_secs: u64) -> Result<String, TenantError> {
use jsonwebtoken::{encode, EncodingKey, Header};
let secret = self.jwt_config.secret.as_ref()
.ok_or_else(|| TenantError::InvalidId("JWT secret not configured".to_string()))?;
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or(std::time::Duration::ZERO)
.as_secs();
let claims = JwtClaims {
sub: Some(user_id.to_string()),
tenant_id: Some(tenant_id.to_string()),
iss: self.jwt_config.issuer.clone(),
aud: self.jwt_config.audience.as_ref().map(|a| OneOrMany::One(a.clone())),
exp: Some(now + expires_in_secs),
iat: Some(now),
nbf: Some(now),
jti: Some(uuid::Uuid::new_v4().to_string()),
email: None,
role: Some(format!("{:?}", role).to_lowercase()),
custom: HashMap::new(),
};
let header = Header::new(self.jwt_config.algorithm);
let key = EncodingKey::from_secret(secret.as_bytes());
encode(&header, &claims, &key)
.map_err(|e| TenantError::InvalidId(format!("Failed to create JWT: {}", e)))
}
}
/// Resolves the tenant for a request by trying several [`ResolutionStrategy`]s in order.
pub struct TenantMiddleware {
    /// Resolver that performs the individual lookups.
    resolver: Arc<TenantResolver>,
    /// Strategies tried in order; the first one that succeeds wins.
    strategies: Vec<ResolutionStrategy>,
}
/// A request attribute from which the tenant may be resolved.
///
/// `Eq` and `Hash` are derived so strategies can be used in sets and as map
/// keys (e.g. for de-duplicating a configured strategy list).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ResolutionStrategy {
    /// The `X-Tenant-ID` header.
    Header,
    /// The `X-API-Key` header, formatted `tenant_id:secret`.
    ApiKey,
    /// A bearer JWT in the `Authorization` header.
    Jwt,
    /// The first label of the `Host` header (`{tenant}.example.com`).
    Subdomain,
    /// A request path of the form `/tenants/{tenant}/...`.
    Path,
}
impl TenantMiddleware {
    /// Creates a middleware that tries strategies in this order:
    /// JWT, `X-Tenant-ID` header, API key, path.
    ///
    /// Subdomain resolution is not enabled by default; use [`Self::with_strategies`].
    pub fn new(resolver: Arc<TenantResolver>) -> Self {
        Self {
            resolver,
            strategies: vec![
                ResolutionStrategy::Jwt,
                ResolutionStrategy::Header,
                ResolutionStrategy::ApiKey,
                ResolutionStrategy::Path,
            ],
        }
    }

    /// Replaces the ordered list of strategies tried by [`Self::resolve`].
    pub fn with_strategies(mut self, strategies: Vec<ResolutionStrategy>) -> Self {
        self.strategies = strategies;
        self
    }

    /// Resolves the tenant for `request` using the first strategy that succeeds.
    ///
    /// Strategies whose input is absent from the request are skipped. If none
    /// succeeds, the error from the earliest strategy that ran is returned.
    /// That error comes from the highest-priority source, and otherwise the
    /// `Path` strategy (which always runs) would mask it — e.g. an expired JWT
    /// or an inactive tenant. When no strategy ran at all, a
    /// `TenantError::NotFound` is returned.
    pub fn resolve(&self, request: &RequestInfo) -> Result<TenantContext, TenantError> {
        let mut first_error = None;
        for &strategy in &self.strategies {
            match self.try_strategy(strategy, request) {
                None => continue,
                Some(Ok(ctx)) => return Ok(ctx),
                Some(Err(e)) => first_error = first_error.or(Some(e)),
            }
        }
        Err(first_error
            .unwrap_or_else(|| TenantError::NotFound("Could not resolve tenant".to_string())))
    }

    /// Runs a single strategy; `None` means the request lacks the input it needs.
    fn try_strategy(
        &self,
        strategy: ResolutionStrategy,
        request: &RequestInfo,
    ) -> Option<Result<TenantContext, TenantError>> {
        let resolver = &self.resolver;
        match strategy {
            ResolutionStrategy::Header => request
                .tenant_header
                .as_deref()
                .map(|tenant_id| resolver.resolve_from_header(tenant_id)),
            ResolutionStrategy::ApiKey => request
                .api_key
                .as_deref()
                .map(|api_key| resolver.resolve_from_api_key(api_key)),
            ResolutionStrategy::Jwt => request
                .jwt_token
                .as_deref()
                .map(|token| resolver.resolve_from_jwt(token)),
            ResolutionStrategy::Subdomain => request
                .host
                .as_deref()
                .map(|host| resolver.resolve_from_subdomain(host)),
            ResolutionStrategy::Path => Some(resolver.resolve_from_path(&request.path)),
        }
    }
}
/// Transport-agnostic view of the request data used for tenant resolution.
#[derive(Debug, Clone, Default)]
pub struct RequestInfo {
    /// Request path, e.g. `/tenants/{id}/...`.
    pub path: String,
    /// Value of the `Host` header.
    pub host: Option<String>,
    /// Value of the `X-Tenant-ID` header.
    pub tenant_header: Option<String>,
    /// Value of the `X-API-Key` header.
    pub api_key: Option<String>,
    /// Bearer token from `Authorization`, if it looks like a compact JWT.
    pub jwt_token: Option<String>,
}

impl RequestInfo {
    /// Extracts the resolution inputs from a header map and the request `path`.
    ///
    /// Header names are matched case-insensitively, because HTTP field names
    /// are case-insensitive (RFC 9110 §5.1).
    pub fn from_headers(headers: &HashMap<String, String>, path: &str) -> Self {
        Self {
            path: path.to_string(),
            host: Self::header(headers, "Host").cloned(),
            tenant_header: Self::header(headers, "X-Tenant-ID").cloned(),
            api_key: Self::header(headers, "X-API-Key").cloned(),
            jwt_token: Self::extract_jwt_token(headers),
        }
    }

    /// Case-insensitive header lookup.
    ///
    /// The canonical spelling `name` and its all-lowercase form are tried
    /// first, so those keep priority when several spellings coexist. Only
    /// then is any key equal to `name` ignoring ASCII case accepted.
    fn header<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a String> {
        headers
            .get(name)
            .or_else(|| headers.get(&name.to_ascii_lowercase()))
            .or_else(|| {
                headers
                    .iter()
                    .find(|(key, _)| key.eq_ignore_ascii_case(name))
                    .map(|(_, value)| value)
            })
    }

    /// Returns the bearer token from `Authorization` when it has the
    /// three-part (two-dot) shape of a compact JWS.
    ///
    /// The auth scheme is matched case-insensitively (RFC 9110 §11.1), and
    /// surrounding whitespace is trimmed from the token.
    fn extract_jwt_token(headers: &HashMap<String, String>) -> Option<String> {
        const SCHEME: &str = "Bearer ";
        let value = Self::header(headers, "Authorization")?;
        // `get` returns None instead of panicking if the prefix would split a UTF-8 char.
        let prefix = value.get(..SCHEME.len())?;
        if !prefix.eq_ignore_ascii_case(SCHEME) {
            return None;
        }
        let token = value[SCHEME.len()..].trim();
        (token.matches('.').count() == 2).then(|| token.to_string())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multi_tenant::{TenantManager, TenantPlan};

    /// Builds a shared manager holding one tenant `id`, promoted to `Active`.
    fn active_manager(id: &str, name: &str, plan: TenantPlan) -> Arc<TenantManager> {
        let manager = Arc::new(TenantManager::new("test"));
        manager.create_tenant(id, name, plan).unwrap();
        manager.update_status(id, TenantStatus::Active).unwrap();
        manager
    }

    #[test]
    fn test_tenant_context() {
        let mgr = TenantManager::new("test");
        let pro_tenant = mgr.create_tenant("t1", "Test", TenantPlan::Pro).unwrap();
        let admin_ctx = TenantContext::new(pro_tenant, Some("user1".to_string()), TenantRole::Admin);
        assert!(admin_ctx.can_write());
        assert!(admin_ctx.is_admin());
    }

    #[test]
    fn test_thread_local_context() {
        let mgr = TenantManager::new("test");
        let free_tenant = mgr.create_tenant("t1", "Test", TenantPlan::Free).unwrap();
        let member_ctx = TenantContext::new(free_tenant, None, TenantRole::Member);
        let seen_id = with_tenant(member_ctx.clone(), || {
            get_tenant_context().unwrap().tenant.id.clone()
        });
        assert_eq!(seen_id, "t1");
        assert!(get_tenant_context().is_none());
    }

    #[test]
    fn test_jwt_resolution() {
        let mgr = active_manager("test-tenant", "Test", TenantPlan::Pro);
        let config = JwtConfig::with_secret("test-secret-key-that-is-long-enough")
            .with_issuer("heliosdb")
            .with_audience("heliosdb-api");
        let resolver = Arc::new(TenantResolver::new(mgr.clone()).with_jwt_config(config));
        let jwt = resolver
            .create_jwt("test-tenant", "user123", TenantRole::Admin, 3600)
            .unwrap();
        let resolved = resolver.resolve_from_jwt(&jwt).unwrap();
        assert_eq!(resolved.tenant.id, "test-tenant");
        assert_eq!(resolved.user_id.as_deref(), Some("user123"));
        assert_eq!(resolved.role, TenantRole::Admin);
    }

    #[test]
    fn test_role_parsing() {
        let cases = [
            ("admin", TenantRole::Admin),
            ("ADMIN", TenantRole::Admin),
            ("member", TenantRole::Member),
            ("viewer", TenantRole::Viewer),
            ("service", TenantRole::Service),
            ("unknown", TenantRole::Anonymous),
        ];
        for (raw, expected) in &cases {
            assert_eq!(TenantRole::from(*raw), *expected);
        }
    }

    #[test]
    fn test_middleware_resolution() {
        let mgr = active_manager("tenant-abc", "ABC Corp", TenantPlan::Starter);
        let middleware = TenantMiddleware::new(Arc::new(TenantResolver::new(mgr)))
            .with_strategies(vec![ResolutionStrategy::Header]);
        let request = RequestInfo {
            path: "/api/data".to_string(),
            tenant_header: Some("tenant-abc".to_string()),
            host: None,
            api_key: None,
            jwt_token: None,
        };
        let resolved = middleware.resolve(&request).unwrap();
        assert_eq!(resolved.tenant.id, "tenant-abc");
    }

    #[test]
    fn test_api_key_resolution() {
        let mgr = active_manager("api-tenant", "API Tenant", TenantPlan::Pro);
        let resolver = TenantResolver::new(mgr);
        let resolved = resolver.resolve_from_api_key("api-tenant:secret-key").unwrap();
        assert_eq!(resolved.tenant.id, "api-tenant");
        assert_eq!(resolved.role, TenantRole::Service);
    }
}