use std::collections::HashMap;
use std::sync::Arc;
use super::config::{IdentificationMethod, TenantId};
/// Request/connection metadata that tenant identifiers inspect.
///
/// Every field is optional; build one with [`RequestContext::new`] and the `with_*` methods,
/// supplying only what the caller's transport provides.
#[derive(Debug, Clone, Default)]
pub struct RequestContext {
    /// Header name/value pairs. Prefer [`RequestContext::get_header`] for lookups, which
    /// compares names case-insensitively.
    pub headers: HashMap<String, String>,
    /// Username supplied with the request, if any.
    pub username: Option<String>,
    /// Database name supplied with the request, if any.
    pub database: Option<String>,
    /// Raw auth token, if any (the JWT claim identifier decodes it as a compact JWT).
    pub auth_token: Option<String>,
    /// Named context variables such as `helios.tenant_id`.
    /// NOTE(review): presumably SQL session settings; confirm against the code that fills them.
    pub sql_context: HashMap<String, String>,
    /// Client address as text, if known; its format is not validated here.
    pub client_ip: Option<String>,
    /// Connection identifier, if the caller tracks one.
    pub connection_id: Option<u64>,
}

impl RequestContext {
    /// Creates an empty context.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds (or replaces) the header `name` with `value`.
    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(name.into(), value.into());
        self
    }

    /// Sets the username.
    pub fn with_username(mut self, username: impl Into<String>) -> Self {
        self.username = Some(username.into());
        self
    }

    /// Sets the database name.
    pub fn with_database(mut self, database: impl Into<String>) -> Self {
        self.database = Some(database.into());
        self
    }

    /// Sets the raw auth token.
    pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
        self.auth_token = Some(token.into());
        self
    }

    /// Adds (or replaces) the context variable `name` with `value`.
    pub fn with_sql_context(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.sql_context.insert(name.into(), value.into());
        self
    }

    /// Sets the client address.
    pub fn with_client_ip(mut self, ip: impl Into<String>) -> Self {
        self.client_ip = Some(ip.into());
        self
    }

    /// Sets the connection identifier (the only field that previously lacked a builder).
    pub fn with_connection_id(mut self, connection_id: u64) -> Self {
        self.connection_id = Some(connection_id);
        self
    }

    /// Looks up a header value by name.
    ///
    /// An exact-case match is preferred. Otherwise names are compared ASCII case-insensitively,
    /// because header field names are case-insensitive and transports commonly normalize them
    /// (e.g. to lowercase). If several stored names differ only by case and none matches
    /// exactly, which of them is returned is unspecified.
    pub fn get_header(&self, name: &str) -> Option<&str> {
        self.headers
            .get(name)
            .or_else(|| {
                self.headers
                    .iter()
                    .find(|(key, _)| key.eq_ignore_ascii_case(name))
                    .map(|(_, value)| value)
            })
            .map(String::as_str)
    }

    /// Looks up a context variable by its exact name.
    pub fn get_sql_context(&self, name: &str) -> Option<&str> {
        self.sql_context.get(name).map(|s| s.as_str())
    }
}
/// A strategy for working out which tenant a request belongs to.
///
/// The `Send + Sync` bound lets implementations be shared across threads, e.g. behind the
/// `Arc`s held by `CompositeIdentifier`.
pub trait TenantIdentifier: Send + Sync {
    /// Short static name of this strategy (for example `"header"`).
    fn strategy_name(&self) -> &'static str;

    /// Resolves the tenant for `request`, or returns `None` when this strategy finds none.
    fn identify(&self, request: &RequestContext) -> Option<TenantId>;
}
/// Identifies the tenant from the value of a request header.
#[derive(Debug, Clone)]
pub struct HeaderTenantIdentifier {
    /// Name of the header that carries the tenant id.
    header_name: String,
    /// Whether the header value is lowercased before use (on by default).
    lowercase: bool,
}

impl HeaderTenantIdentifier {
    /// Creates an identifier that reads `header_name`; values are lowercased by default.
    pub fn new(header_name: impl Into<String>) -> Self {
        let header_name = header_name.into();
        Self {
            lowercase: true,
            header_name,
        }
    }

    /// Shorthand for an identifier reading the `X-Tenant-Id` header.
    pub fn default_header() -> Self {
        const DEFAULT_HEADER: &str = "X-Tenant-Id";
        Self::new(DEFAULT_HEADER)
    }

    /// Keeps the header value's original casing instead of lowercasing it.
    pub fn case_sensitive(self) -> Self {
        Self {
            lowercase: false,
            ..self
        }
    }
}
impl TenantIdentifier for HeaderTenantIdentifier {
    /// Uses the configured header (looked up via `RequestContext::get_header`) as the tenant
    /// id, lowercased unless case-sensitive. A missing or empty header yields `None`.
    fn identify(&self, request: &RequestContext) -> Option<TenantId> {
        match request.get_header(&self.header_name) {
            None | Some("") => None,
            Some(value) if self.lowercase => Some(TenantId::new(value.to_lowercase())),
            Some(value) => Some(TenantId::new(value)),
        }
    }

    fn strategy_name(&self) -> &'static str {
        "header"
    }
}
/// Identifies the tenant from the leading part of the username, up to `separator`
/// (for example `acme.admin` split on `'.'` yields `acme`).
#[derive(Debug, Clone)]
pub struct UsernamePrefixIdentifier {
    /// Character that terminates the tenant prefix.
    separator: char,
    /// Whether the extracted prefix is lowercased (on by default).
    lowercase: bool,
}

impl UsernamePrefixIdentifier {
    /// Creates an identifier splitting on `separator`; results are lowercased by default.
    pub fn new(separator: char) -> Self {
        let lowercase = true;
        Self {
            separator,
            lowercase,
        }
    }

    /// Splits on `'.'`.
    pub fn with_dot() -> Self {
        Self::new('.')
    }

    /// Splits on `'_'`.
    pub fn with_underscore() -> Self {
        Self::new('_')
    }

    /// Keeps the prefix's original casing instead of lowercasing it.
    pub fn case_sensitive(self) -> Self {
        Self {
            lowercase: false,
            ..self
        }
    }
}
impl TenantIdentifier for UsernamePrefixIdentifier {
    /// Takes everything before the first separator in the username as the tenant id, or the
    /// whole username when the separator is absent. A missing username or an empty prefix
    /// yields `None`.
    fn identify(&self, request: &RequestContext) -> Option<TenantId> {
        let username = request.username.as_deref()?;
        let prefix = username
            .find(self.separator)
            .map_or(username, |end| &username[..end]);
        if prefix.is_empty() {
            return None;
        }
        let tenant = if self.lowercase {
            TenantId::new(prefix.to_lowercase())
        } else {
            TenantId::new(prefix)
        };
        Some(tenant)
    }

    fn strategy_name(&self) -> &'static str {
        "username_prefix"
    }
}
/// Identifies the tenant from the requested database name, optionally stripping a fixed
/// prefix and/or suffix (e.g. `tenant_acme_db` -> `acme`).
///
/// Like `HeaderTenantIdentifier` and `UsernamePrefixIdentifier`, the result is lowercased
/// unless [`DatabaseNameIdentifier::case_sensitive`] is called.
#[derive(Debug, Clone)]
pub struct DatabaseNameIdentifier {
    /// Prefix removed from the database name when present.
    prefix: Option<String>,
    /// Suffix removed from the database name when present.
    suffix: Option<String>,
    /// Whether the resulting name is lowercased (on by default).
    lowercase: bool,
}

// Hand-written rather than derived: a derived `Default` would set `lowercase` to `false`,
// which made `case_sensitive()` a no-op and meant names were never lowercased.
impl Default for DatabaseNameIdentifier {
    fn default() -> Self {
        Self {
            prefix: None,
            suffix: None,
            lowercase: true,
        }
    }
}

impl DatabaseNameIdentifier {
    /// Creates an identifier with no prefix/suffix stripping; results are lowercased.
    pub fn new() -> Self {
        Self::default()
    }

    /// Removes `prefix` from the database name when the name starts with it.
    pub fn strip_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.prefix = Some(prefix.into());
        self
    }

    /// Removes `suffix` from the database name when the name ends with it.
    pub fn strip_suffix(mut self, suffix: impl Into<String>) -> Self {
        self.suffix = Some(suffix.into());
        self
    }

    /// Keeps the database name's original casing instead of lowercasing it.
    pub fn case_sensitive(mut self) -> Self {
        self.lowercase = false;
        self
    }
}
impl TenantIdentifier for DatabaseNameIdentifier {
    /// Derives the tenant id from `request.database`.
    ///
    /// The configured prefix and suffix are each removed only when present, then the name is
    /// lowercased if enabled. Returns `None` when no database was supplied or when nothing is
    /// left after stripping (an empty name, or one made up only of the prefix and/or suffix).
    fn identify(&self, request: &RequestContext) -> Option<TenantId> {
        let mut name = request.database.as_deref()?;
        if let Some(prefix) = self.prefix.as_deref() {
            name = name.strip_prefix(prefix).unwrap_or(name);
        }
        if let Some(suffix) = self.suffix.as_deref() {
            name = name.strip_suffix(suffix).unwrap_or(name);
        }
        // The other identifiers reject empty values; an empty tenant id would otherwise be
        // returned here as if it were a real tenant.
        if name.is_empty() {
            return None;
        }
        let tenant = if self.lowercase {
            TenantId::new(name.to_lowercase())
        } else {
            TenantId::new(name)
        };
        Some(tenant)
    }

    fn strategy_name(&self) -> &'static str {
        "database_name"
    }
}
/// Identifies the tenant from a named variable in the request's `sql_context` map.
#[derive(Debug, Clone)]
pub struct SqlContextIdentifier {
    /// Key looked up in the request's SQL context.
    variable_name: String,
}

impl SqlContextIdentifier {
    /// Creates an identifier that reads `variable_name`.
    pub fn new(variable_name: impl Into<String>) -> Self {
        let variable_name = variable_name.into();
        Self { variable_name }
    }

    /// Creates an identifier that reads the `helios.tenant_id` variable.
    pub fn default_variable() -> Self {
        const DEFAULT_VARIABLE: &str = "helios.tenant_id";
        Self::new(DEFAULT_VARIABLE)
    }
}
impl TenantIdentifier for SqlContextIdentifier {
    /// Uses the configured context variable, always lowercased, as the tenant id. A missing or
    /// empty variable yields `None`.
    fn identify(&self, request: &RequestContext) -> Option<TenantId> {
        match request.get_sql_context(&self.variable_name) {
            Some(value) if !value.is_empty() => Some(TenantId::new(value.to_lowercase())),
            _ => None,
        }
    }

    fn strategy_name(&self) -> &'static str {
        "sql_context"
    }
}
/// Identifies the tenant from a string claim in a JWT's payload segment.
///
/// SECURITY NOTE(review): this type never verifies the token signature (`_verification_key`
/// is stored but unused), so the claim is only trustworthy if the token was authenticated
/// before it reached this identifier. Confirm this at the call site.
#[derive(Debug, Clone)]
pub struct JwtClaimIdentifier {
    /// Payload claim whose string value is used as the tenant id.
    claim_name: String,
    /// Expected `iss` claim; when set, tokens with a different or missing `iss` are ignored.
    issuer: Option<String>,
    /// Reserved for signature verification; not read by this type's methods.
    _verification_key: Option<String>,
}

impl JwtClaimIdentifier {
    /// Creates an identifier reading `claim_name`, with no issuer restriction.
    pub fn new(claim_name: impl Into<String>) -> Self {
        Self {
            claim_name: claim_name.into(),
            issuer: None,
            _verification_key: None,
        }
    }

    /// Only accepts tokens whose `iss` claim equals `issuer`.
    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
        self.issuer = Some(issuer.into());
        self
    }

    /// Decodes the payload of a compact JWT (`header.payload.signature`) and returns the
    /// configured claim's string value.
    ///
    /// Returns `None` in any of these cases:
    /// - the token does not have exactly three segments;
    /// - the payload is not unpadded base64url-encoded UTF-8;
    /// - the claim is missing or is not a JSON string;
    /// - an issuer is configured and the payload's `iss` claim does not equal it.
    fn extract_claim(&self, token: &str) -> Option<String> {
        use base64::Engine;
        let segments: Vec<&str> = token.split('.').collect();
        let encoded_payload = match segments.as_slice() {
            [_header, payload, _signature] => *payload,
            _ => return None,
        };
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(encoded_payload)
            .ok()?;
        let payload_str = String::from_utf8(payload).ok()?;
        // Enforce the configured issuer; without this check `with_issuer` would have no effect.
        if let Some(expected) = &self.issuer {
            if self.extract_json_string(&payload_str, "iss").as_deref() != Some(expected.as_str())
            {
                return None;
            }
        }
        self.extract_json_string(&payload_str, &self.claim_name)
    }

    /// Returns the string value stored under `key` in `json`.
    ///
    /// This is a lightweight scanner, not a full JSON parser. It takes the first occurrence of
    /// `"key"` that is followed (after optional whitespace) by `:`, without tracking object
    /// nesting. Occurrences not followed by `:` are skipped, for example the same text used as
    /// another claim's value.
    ///
    /// Returns `None` if no such key exists, its value is not a string, or the string is
    /// unterminated or contains an invalid escape.
    fn extract_json_string(&self, json: &str, key: &str) -> Option<String> {
        let pattern = format!("\"{}\"", key);
        let value = json
            .match_indices(pattern.as_str())
            .find_map(|(pos, _)| json[pos + pattern.len()..].trim_start().strip_prefix(':'))?
            .trim_start();
        Self::decode_json_string(value.strip_prefix('"')?)
    }

    /// Decodes the body of a JSON string literal (the text right after its opening quote) up to
    /// the closing quote, resolving escape sequences.
    ///
    /// Returns `None` for an unterminated literal or an invalid escape. `\uXXXX` escapes that
    /// are UTF-16 surrogates are rejected rather than combined into a single character.
    fn decode_json_string(body: &str) -> Option<String> {
        let mut decoded = String::new();
        let mut chars = body.chars();
        loop {
            match chars.next()? {
                '"' => return Some(decoded),
                '\\' => {
                    let unescaped = match chars.next()? {
                        '"' => '"',
                        '\\' => '\\',
                        '/' => '/',
                        'b' => '\u{8}',
                        'f' => '\u{c}',
                        'n' => '\n',
                        'r' => '\r',
                        't' => '\t',
                        'u' => {
                            let hex: String = chars.by_ref().take(4).collect();
                            // Checked explicitly: `from_str_radix` would also accept a sign.
                            if hex.len() != 4 || !hex.chars().all(|c| c.is_ascii_hexdigit()) {
                                return None;
                            }
                            char::from_u32(u32::from_str_radix(&hex, 16).ok()?)?
                        }
                        _ => return None,
                    };
                    decoded.push(unescaped);
                }
                other => decoded.push(other),
            }
        }
    }
}
impl TenantIdentifier for JwtClaimIdentifier {
    /// Uses the configured JWT claim from `request.auth_token`, lowercased, as the tenant id.
    /// A missing token, an undecodable token, or an empty claim yields `None`.
    fn identify(&self, request: &RequestContext) -> Option<TenantId> {
        let token = request.auth_token.as_deref()?;
        match self.extract_claim(token) {
            Some(claim) if !claim.is_empty() => Some(TenantId::new(claim.to_lowercase())),
            _ => None,
        }
    }

    fn strategy_name(&self) -> &'static str {
        "jwt_claim"
    }
}
/// Chains several identifiers and reports the tenant from the first one that finds it.
///
/// Identifiers are consulted in the order they were added. Cloning shares the child
/// identifiers (they are held in `Arc`s) rather than copying them.
#[derive(Clone)]
pub struct CompositeIdentifier {
    /// Child strategies in priority order: first added, first consulted.
    identifiers: Vec<Arc<dyn TenantIdentifier>>,
}

impl CompositeIdentifier {
    /// Creates a composite with no child identifiers.
    pub fn new() -> Self {
        Self {
            identifiers: Vec::new(),
        }
    }

    /// Appends `identifier` after all previously added ones (lowest priority so far).
    pub fn add<I: TenantIdentifier + 'static>(mut self, identifier: I) -> Self {
        self.identifiers.push(Arc::new(identifier));
        self
    }

    /// Like [`CompositeIdentifier::add`], for an identifier already shared behind an `Arc`.
    pub fn add_arc(mut self, identifier: Arc<dyn TenantIdentifier>) -> Self {
        self.identifiers.push(identifier);
        self
    }
}

impl Default for CompositeIdentifier {
    fn default() -> Self {
        Self::new()
    }
}

// `Debug` cannot be derived because `dyn TenantIdentifier` is not `Debug`; show the child
// strategies by name so this public type can still appear in diagnostics.
impl std::fmt::Debug for CompositeIdentifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let strategies: Vec<&'static str> = self
            .identifiers
            .iter()
            .map(|identifier| identifier.strategy_name())
            .collect();
        f.debug_struct("CompositeIdentifier")
            .field("identifiers", &strategies)
            .finish()
    }
}
impl TenantIdentifier for CompositeIdentifier {
    /// Asks each child identifier in insertion order and returns the first tenant found, or
    /// `None` if no child (or no child at all) resolves one.
    fn identify(&self, request: &RequestContext) -> Option<TenantId> {
        self.identifiers
            .iter()
            .find_map(|identifier| identifier.identify(request))
    }

    fn strategy_name(&self) -> &'static str {
        "composite"
    }
}
/// Builds the boxed identifier for a configured `method`.
///
/// Each identifier uses its default settings, apart from the values carried by the method
/// itself (header/variable name, separator, claim name, optional issuer).
pub fn create_identifier(method: &IdentificationMethod) -> Box<dyn TenantIdentifier> {
    match method {
        IdentificationMethod::DatabaseName => Box::new(DatabaseNameIdentifier::new()),
        IdentificationMethod::Header { header_name } => {
            let by_header = HeaderTenantIdentifier::new(header_name);
            Box::new(by_header)
        }
        IdentificationMethod::SqlContext { variable_name } => {
            let by_variable = SqlContextIdentifier::new(variable_name);
            Box::new(by_variable)
        }
        IdentificationMethod::UsernamePrefix { separator } => {
            let by_username = UsernamePrefixIdentifier::new(*separator);
            Box::new(by_username)
        }
        IdentificationMethod::JwtClaim { claim_name, issuer } => {
            let base = JwtClaimIdentifier::new(claim_name);
            let configured = match issuer {
                Some(iss) => base.with_issuer(iss),
                None => base,
            };
            Box::new(configured)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `identifier` on `ctx` and unwraps the resulting tenant id to its inner string.
    fn tenant_of(identifier: &dyn TenantIdentifier, ctx: &RequestContext) -> Option<String> {
        identifier.identify(ctx).map(|t| t.0)
    }

    /// Builds the owned expected value for a resolved tenant.
    fn tenant(id: &str) -> Option<String> {
        Some(id.to_string())
    }

    #[test]
    fn test_header_identifier() {
        let by_header = HeaderTenantIdentifier::new("X-Tenant-Id");

        let with_value = RequestContext::new().with_header("X-Tenant-Id", "TenantA");
        assert_eq!(tenant_of(&by_header, &with_value), tenant("tenanta"));

        assert_eq!(tenant_of(&by_header, &RequestContext::new()), None);

        let with_blank = RequestContext::new().with_header("X-Tenant-Id", "");
        assert_eq!(tenant_of(&by_header, &with_blank), None);
    }

    #[test]
    fn test_header_identifier_case_sensitive() {
        let by_header = HeaderTenantIdentifier::new("X-Tenant-Id").case_sensitive();
        let ctx = RequestContext::new().with_header("X-Tenant-Id", "TenantA");
        assert_eq!(tenant_of(&by_header, &ctx), tenant("TenantA"));
    }

    #[test]
    fn test_username_prefix_identifier() {
        let by_username = UsernamePrefixIdentifier::with_dot();

        let prefixed = RequestContext::new().with_username("tenant_a.admin");
        assert_eq!(tenant_of(&by_username, &prefixed), tenant("tenant_a"));

        // Without a separator the whole username is taken as the tenant.
        let bare = RequestContext::new().with_username("admin");
        assert_eq!(tenant_of(&by_username, &bare), tenant("admin"));

        assert_eq!(tenant_of(&by_username, &RequestContext::new()), None);
    }

    #[test]
    fn test_database_name_identifier() {
        let by_database = DatabaseNameIdentifier::new()
            .strip_prefix("tenant_")
            .strip_suffix("_db");

        let affixed = RequestContext::new().with_database("tenant_acme_db");
        assert_eq!(tenant_of(&by_database, &affixed), tenant("acme"));

        // Prefix and suffix are only removed when present.
        let plain = RequestContext::new().with_database("mydb");
        assert_eq!(tenant_of(&by_database, &plain), tenant("mydb"));
    }

    #[test]
    fn test_sql_context_identifier() {
        let by_variable = SqlContextIdentifier::default_variable();

        let ctx = RequestContext::new().with_sql_context("helios.tenant_id", "tenant_x");
        assert_eq!(tenant_of(&by_variable, &ctx), tenant("tenant_x"));

        assert_eq!(tenant_of(&by_variable, &RequestContext::new()), None);
    }

    #[test]
    fn test_jwt_claim_identifier() {
        use base64::Engine;

        let by_claim = JwtClaimIdentifier::new("tenant_id");
        let claims = r#"{"tenant_id":"acme","sub":"user1"}"#;
        let token = format!(
            "header.{}.signature",
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims)
        );
        let ctx = RequestContext::new().with_auth_token(&token);
        assert_eq!(tenant_of(&by_claim, &ctx), tenant("acme"));
    }

    #[test]
    fn test_composite_identifier() {
        let chained = CompositeIdentifier::new()
            .add(HeaderTenantIdentifier::new("X-Tenant-Id"))
            .add(UsernamePrefixIdentifier::with_dot());

        // The header identifier was added first, so it wins when both could answer.
        let both = RequestContext::new()
            .with_header("X-Tenant-Id", "header_tenant")
            .with_username("user_tenant.admin");
        assert_eq!(tenant_of(&chained, &both), tenant("header_tenant"));

        let username_only = RequestContext::new().with_username("user_tenant.admin");
        assert_eq!(tenant_of(&chained, &username_only), tenant("user_tenant"));

        assert_eq!(tenant_of(&chained, &RequestContext::new()), None);
    }

    #[test]
    fn test_create_identifier() {
        let cases = [
            (IdentificationMethod::header("X-Org-Id"), "header"),
            (IdentificationMethod::username_prefix('_'), "username_prefix"),
        ];
        for (method, expected) in cases.iter() {
            assert_eq!(create_identifier(method).strategy_name(), *expected);
        }
    }
}