// Source path: heliosdb_proxy/multi_tenancy/identifier.rs

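//! Tenant Identification Strategies
//!
//! This module provides strategies for determining which tenant an incoming request belongs to.
//!
//! # Strategies
//!
//! - **Header** ([`HeaderTenantIdentifier`]): read the tenant ID from a request header (e.g. `X-Tenant-Id`)
//! - **UsernamePrefix** ([`UsernamePrefixIdentifier`]): take the part of the username before a separator (e.g. `tenant.user` -> `tenant`)
//! - **JWT** ([`JwtClaimIdentifier`]): read a string claim from the token payload (the signature is not verified)
//! - **DatabaseName** ([`DatabaseNameIdentifier`]): use the database name, optionally stripped of a prefix/suffix
//! - **SqlContext** ([`SqlContextIdentifier`]): read a SQL session variable (e.g. `helios.tenant_id`)
//!
//! [`CompositeIdentifier`] tries several strategies in order and returns the first match.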
12
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use super::config::{IdentificationMethod, TenantId};
17
/// Request context for tenant identification.
///
/// Collects the connection- and request-level attributes that the tenant
/// identification strategies in this module inspect. Build one with
/// [`RequestContext::new`] and the `with_*` builder methods.
#[derive(Debug, Clone, Default)]
pub struct RequestContext {
    /// HTTP headers (or similar protocol headers), keyed by name as received
    pub headers: HashMap<String, String>,

    /// Username from authentication
    pub username: Option<String>,

    /// Database name from connection
    pub database: Option<String>,

    /// Authentication token (e.g., JWT)
    pub auth_token: Option<String>,

    /// SQL context (session) variables, keyed by variable name
    pub sql_context: HashMap<String, String>,

    /// Client IP address
    pub client_ip: Option<String>,

    /// Connection ID
    pub connection_id: Option<u64>,
}

impl RequestContext {
    /// Create a new empty request context
    pub fn new() -> Self {
        Self::default()
    }

    /// Set a header (replacing any previous value stored under the same exact name)
    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(name.into(), value.into());
        self
    }

    /// Set username
    pub fn with_username(mut self, username: impl Into<String>) -> Self {
        self.username = Some(username.into());
        self
    }

    /// Set database
    pub fn with_database(mut self, database: impl Into<String>) -> Self {
        self.database = Some(database.into());
        self
    }

    /// Set auth token
    pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
        self.auth_token = Some(token.into());
        self
    }

    /// Set SQL context variable
    pub fn with_sql_context(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.sql_context.insert(name.into(), value.into());
        self
    }

    /// Set client IP
    pub fn with_client_ip(mut self, ip: impl Into<String>) -> Self {
        self.client_ip = Some(ip.into());
        self
    }

    /// Set connection ID
    pub fn with_connection_id(mut self, connection_id: u64) -> Self {
        self.connection_id = Some(connection_id);
        self
    }

    /// Get a header value by name.
    ///
    /// An entry stored under exactly `name` takes precedence. Otherwise the name
    /// is matched ASCII case-insensitively, because HTTP field names are
    /// case-insensitive and HTTP/2 transmits them lowercased. If several
    /// differently-cased entries match with conflicting values, the header is
    /// ambiguous and `None` is returned rather than picking one by hash order.
    pub fn get_header(&self, name: &str) -> Option<&str> {
        if let Some(value) = self.headers.get(name) {
            return Some(value.as_str());
        }

        let mut matches = self
            .headers
            .iter()
            .filter(|(key, _)| key.eq_ignore_ascii_case(name))
            .map(|(_, value)| value.as_str());

        let first = matches.next()?;
        if matches.all(|value| value == first) {
            Some(first)
        } else {
            None
        }
    }

    /// Get SQL context variable (exact name match)
    pub fn get_sql_context(&self, name: &str) -> Option<&str> {
        self.sql_context.get(name).map(|s| s.as_str())
    }
}
95
96/// Trait for tenant identification strategies
97pub trait TenantIdentifier: Send + Sync {
98    /// Identify tenant from request context
99    fn identify(&self, request: &RequestContext) -> Option<TenantId>;
100
101    /// Get the name of this identification strategy
102    fn strategy_name(&self) -> &'static str;
103}
104
105/// Header-based tenant identification
106///
107/// Extracts tenant ID from a specific HTTP header.
108#[derive(Debug, Clone)]
109pub struct HeaderTenantIdentifier {
110    /// Header name to extract tenant ID from
111    header_name: String,
112
113    /// Whether to lowercase the tenant ID
114    lowercase: bool,
115}
116
117impl HeaderTenantIdentifier {
118    /// Create a new header identifier
119    pub fn new(header_name: impl Into<String>) -> Self {
120        Self {
121            header_name: header_name.into(),
122            lowercase: true,
123        }
124    }
125
126    /// Create with X-Tenant-Id header
127    pub fn default_header() -> Self {
128        Self::new("X-Tenant-Id")
129    }
130
131    /// Don't lowercase the tenant ID
132    pub fn case_sensitive(mut self) -> Self {
133        self.lowercase = false;
134        self
135    }
136}
137
138impl TenantIdentifier for HeaderTenantIdentifier {
139    fn identify(&self, request: &RequestContext) -> Option<TenantId> {
140        request
141            .get_header(&self.header_name)
142            .filter(|v| !v.is_empty())
143            .map(|v| {
144                if self.lowercase {
145                    TenantId::new(v.to_lowercase())
146                } else {
147                    TenantId::new(v)
148                }
149            })
150    }
151
152    fn strategy_name(&self) -> &'static str {
153        "header"
154    }
155}
156
157/// Username prefix-based tenant identification
158///
159/// Extracts tenant ID from username prefix (e.g., "tenant_a.user" -> "tenant_a")
160#[derive(Debug, Clone)]
161pub struct UsernamePrefixIdentifier {
162    /// Separator character between tenant and username
163    separator: char,
164
165    /// Whether to lowercase the tenant ID
166    lowercase: bool,
167}
168
169impl UsernamePrefixIdentifier {
170    /// Create a new username prefix identifier
171    pub fn new(separator: char) -> Self {
172        Self {
173            separator,
174            lowercase: true,
175        }
176    }
177
178    /// Create with dot separator
179    pub fn with_dot() -> Self {
180        Self::new('.')
181    }
182
183    /// Create with underscore separator
184    pub fn with_underscore() -> Self {
185        Self::new('_')
186    }
187
188    /// Don't lowercase the tenant ID
189    pub fn case_sensitive(mut self) -> Self {
190        self.lowercase = false;
191        self
192    }
193}
194
195impl TenantIdentifier for UsernamePrefixIdentifier {
196    fn identify(&self, request: &RequestContext) -> Option<TenantId> {
197        request
198            .username
199            .as_ref()
200            .and_then(|username| username.split(self.separator).next())
201            .filter(|prefix| !prefix.is_empty())
202            .map(|prefix| {
203                if self.lowercase {
204                    TenantId::new(prefix.to_lowercase())
205                } else {
206                    TenantId::new(prefix)
207                }
208            })
209    }
210
211    fn strategy_name(&self) -> &'static str {
212        "username_prefix"
213    }
214}
215
216/// Database name-based tenant identification
217///
218/// Uses the database name as the tenant ID.
219#[derive(Debug, Clone, Default)]
220pub struct DatabaseNameIdentifier {
221    /// Prefix to strip from database name (e.g., "tenant_")
222    prefix: Option<String>,
223
224    /// Suffix to strip from database name (e.g., "_db")
225    suffix: Option<String>,
226
227    /// Whether to lowercase the tenant ID
228    lowercase: bool,
229}
230
231impl DatabaseNameIdentifier {
232    /// Create a new database name identifier
233    pub fn new() -> Self {
234        Self::default()
235    }
236
237    /// Strip prefix from database name
238    pub fn strip_prefix(mut self, prefix: impl Into<String>) -> Self {
239        self.prefix = Some(prefix.into());
240        self
241    }
242
243    /// Strip suffix from database name
244    pub fn strip_suffix(mut self, suffix: impl Into<String>) -> Self {
245        self.suffix = Some(suffix.into());
246        self
247    }
248
249    /// Don't lowercase the tenant ID
250    pub fn case_sensitive(mut self) -> Self {
251        self.lowercase = false;
252        self
253    }
254}
255
256impl TenantIdentifier for DatabaseNameIdentifier {
257    fn identify(&self, request: &RequestContext) -> Option<TenantId> {
258        request.database.as_ref().map(|db| {
259            let mut name = db.as_str();
260
261            if let Some(prefix) = &self.prefix {
262                name = name.strip_prefix(prefix.as_str()).unwrap_or(name);
263            }
264
265            if let Some(suffix) = &self.suffix {
266                name = name.strip_suffix(suffix.as_str()).unwrap_or(name);
267            }
268
269            if self.lowercase {
270                TenantId::new(name.to_lowercase())
271            } else {
272                TenantId::new(name)
273            }
274        })
275    }
276
277    fn strategy_name(&self) -> &'static str {
278        "database_name"
279    }
280}
281
282/// SQL context variable-based tenant identification
283///
284/// Extracts tenant ID from a SQL session variable (e.g., SET helios.tenant_id = 'tenant_a')
285#[derive(Debug, Clone)]
286pub struct SqlContextIdentifier {
287    /// Variable name to look for
288    variable_name: String,
289}
290
291impl SqlContextIdentifier {
292    /// Create a new SQL context identifier
293    pub fn new(variable_name: impl Into<String>) -> Self {
294        Self {
295            variable_name: variable_name.into(),
296        }
297    }
298
299    /// Create with default variable name
300    pub fn default_variable() -> Self {
301        Self::new("helios.tenant_id")
302    }
303}
304
305impl TenantIdentifier for SqlContextIdentifier {
306    fn identify(&self, request: &RequestContext) -> Option<TenantId> {
307        request
308            .get_sql_context(&self.variable_name)
309            .filter(|v| !v.is_empty())
310            .map(|v| TenantId::new(v.to_lowercase()))
311    }
312
313    fn strategy_name(&self) -> &'static str {
314        "sql_context"
315    }
316}
317
318/// JWT claim-based tenant identification
319///
320/// Extracts tenant ID from a JWT token claim.
321#[derive(Debug, Clone)]
322pub struct JwtClaimIdentifier {
323    /// JWT claim name
324    claim_name: String,
325
326    /// Expected issuer (optional validation)
327    issuer: Option<String>,
328
329    /// JWT verification key (simplified - in real impl would be more complex)
330    /// In production, this would integrate with a proper JWT library
331    _verification_key: Option<String>,
332}
333
334impl JwtClaimIdentifier {
335    /// Create a new JWT claim identifier
336    pub fn new(claim_name: impl Into<String>) -> Self {
337        Self {
338            claim_name: claim_name.into(),
339            issuer: None,
340            _verification_key: None,
341        }
342    }
343
344    /// Set expected issuer
345    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
346        self.issuer = Some(issuer.into());
347        self
348    }
349
350    /// Simple JWT payload extraction (base64 decode middle part)
351    /// In production, proper signature verification would be required
352    fn extract_claim(&self, token: &str) -> Option<String> {
353        use base64::Engine;
354
355        let parts: Vec<&str> = token.split('.').collect();
356        if parts.len() != 3 {
357            return None;
358        }
359
360        // Decode payload (middle part)
361        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
362            .decode(parts[1])
363            .ok()?;
364
365        let payload_str = String::from_utf8(payload).ok()?;
366
367        // Simple JSON parsing for claim extraction
368        // In production, use serde_json
369        self.extract_json_string(&payload_str, &self.claim_name)
370    }
371
372    /// Simple JSON string extraction (for demo purposes)
373    fn extract_json_string(&self, json: &str, key: &str) -> Option<String> {
374        // Look for "key":"value" or "key": "value"
375        let pattern = format!("\"{}\"", key);
376        let pos = json.find(&pattern)?;
377        let after_key = &json[pos + pattern.len()..];
378
379        // Skip whitespace and colon
380        let after_colon = after_key.trim_start().strip_prefix(':')?;
381        let after_colon = after_colon.trim_start();
382
383        // Extract quoted value
384        if after_colon.starts_with('"') {
385            let value_start = 1;
386            let value_end = after_colon[1..].find('"')? + 1;
387            Some(after_colon[value_start..value_end].to_string())
388        } else {
389            None
390        }
391    }
392}
393
394impl TenantIdentifier for JwtClaimIdentifier {
395    fn identify(&self, request: &RequestContext) -> Option<TenantId> {
396        request
397            .auth_token
398            .as_ref()
399            .and_then(|token| self.extract_claim(token))
400            .filter(|claim| !claim.is_empty())
401            .map(|claim| TenantId::new(claim.to_lowercase()))
402    }
403
404    fn strategy_name(&self) -> &'static str {
405        "jwt_claim"
406    }
407}
408
409/// Composite identifier that tries multiple strategies in order
410#[derive(Clone)]
411pub struct CompositeIdentifier {
412    /// Identifiers to try in order
413    identifiers: Vec<Arc<dyn TenantIdentifier>>,
414}
415
416impl CompositeIdentifier {
417    /// Create a new composite identifier
418    pub fn new() -> Self {
419        Self {
420            identifiers: Vec::new(),
421        }
422    }
423
424    /// Add an identifier to try
425    pub fn add<I: TenantIdentifier + 'static>(mut self, identifier: I) -> Self {
426        self.identifiers.push(Arc::new(identifier));
427        self
428    }
429
430    /// Add an identifier wrapped in Arc
431    pub fn add_arc(mut self, identifier: Arc<dyn TenantIdentifier>) -> Self {
432        self.identifiers.push(identifier);
433        self
434    }
435}
436
437impl Default for CompositeIdentifier {
438    fn default() -> Self {
439        Self::new()
440    }
441}
442
443impl TenantIdentifier for CompositeIdentifier {
444    fn identify(&self, request: &RequestContext) -> Option<TenantId> {
445        for identifier in &self.identifiers {
446            if let Some(tenant) = identifier.identify(request) {
447                return Some(tenant);
448            }
449        }
450        None
451    }
452
453    fn strategy_name(&self) -> &'static str {
454        "composite"
455    }
456}
457
458/// Create a tenant identifier from identification method
459pub fn create_identifier(method: &IdentificationMethod) -> Box<dyn TenantIdentifier> {
460    match method {
461        IdentificationMethod::Header { header_name } => {
462            Box::new(HeaderTenantIdentifier::new(header_name))
463        }
464        IdentificationMethod::UsernamePrefix { separator } => {
465            Box::new(UsernamePrefixIdentifier::new(*separator))
466        }
467        IdentificationMethod::JwtClaim { claim_name, issuer } => {
468            let mut identifier = JwtClaimIdentifier::new(claim_name);
469            if let Some(iss) = issuer {
470                identifier = identifier.with_issuer(iss);
471            }
472            Box::new(identifier)
473        }
474        IdentificationMethod::DatabaseName => {
475            Box::new(DatabaseNameIdentifier::new())
476        }
477        IdentificationMethod::SqlContext { variable_name } => {
478            Box::new(SqlContextIdentifier::new(variable_name))
479        }
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `identifier` against `ctx` and returns the raw tenant ID string.
    fn tenant_of(identifier: &dyn TenantIdentifier, ctx: &RequestContext) -> Option<String> {
        identifier.identify(ctx).map(|tenant| tenant.0)
    }

    #[test]
    fn test_header_identifier() {
        let identifier = HeaderTenantIdentifier::new("X-Tenant-Id");

        // Lowercased by default.
        let with_header = RequestContext::new().with_header("X-Tenant-Id", "TenantA");
        assert_eq!(tenant_of(&identifier, &with_header).as_deref(), Some("tenanta"));

        // Missing and empty headers identify nothing.
        assert_eq!(tenant_of(&identifier, &RequestContext::new()), None);
        let empty = RequestContext::new().with_header("X-Tenant-Id", "");
        assert_eq!(tenant_of(&identifier, &empty), None);
    }

    #[test]
    fn test_header_identifier_case_sensitive() {
        let identifier = HeaderTenantIdentifier::new("X-Tenant-Id").case_sensitive();
        let ctx = RequestContext::new().with_header("X-Tenant-Id", "TenantA");
        assert_eq!(tenant_of(&identifier, &ctx).as_deref(), Some("TenantA"));
    }

    #[test]
    fn test_username_prefix_identifier() {
        let identifier = UsernamePrefixIdentifier::with_dot();

        // With a separator the prefix is used; without one, the whole username.
        let cases = [("tenant_a.admin", "tenant_a"), ("admin", "admin")];
        for &(username, expected) in &cases {
            let ctx = RequestContext::new().with_username(username);
            assert_eq!(tenant_of(&identifier, &ctx).as_deref(), Some(expected));
        }

        assert_eq!(tenant_of(&identifier, &RequestContext::new()), None);
    }

    #[test]
    fn test_database_name_identifier() {
        let identifier = DatabaseNameIdentifier::new()
            .strip_prefix("tenant_")
            .strip_suffix("_db");

        // Prefix/suffix are stripped when present and left alone otherwise.
        let cases = [("tenant_acme_db", "acme"), ("mydb", "mydb")];
        for &(database, expected) in &cases {
            let ctx = RequestContext::new().with_database(database);
            assert_eq!(tenant_of(&identifier, &ctx).as_deref(), Some(expected));
        }
    }

    #[test]
    fn test_sql_context_identifier() {
        let identifier = SqlContextIdentifier::default_variable();

        let ctx = RequestContext::new().with_sql_context("helios.tenant_id", "tenant_x");
        assert_eq!(tenant_of(&identifier, &ctx).as_deref(), Some("tenant_x"));

        assert_eq!(tenant_of(&identifier, &RequestContext::new()), None);
    }

    #[test]
    fn test_jwt_claim_identifier() {
        use base64::Engine;

        let identifier = JwtClaimIdentifier::new("tenant_id");

        // JWT-shaped token: header.payload.signature, payload base64url without padding.
        let claims = r#"{"tenant_id":"acme","sub":"user1"}"#;
        let encoded_claims = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims);
        let token = format!("header.{}.signature", encoded_claims);

        let ctx = RequestContext::new().with_auth_token(&token);
        assert_eq!(tenant_of(&identifier, &ctx).as_deref(), Some("acme"));
    }

    #[test]
    fn test_composite_identifier() {
        let identifier = CompositeIdentifier::new()
            .add(HeaderTenantIdentifier::new("X-Tenant-Id"))
            .add(UsernamePrefixIdentifier::with_dot());

        // The header strategy was added first, so it wins when both sources exist.
        let both = RequestContext::new()
            .with_header("X-Tenant-Id", "header_tenant")
            .with_username("user_tenant.admin");
        assert_eq!(tenant_of(&identifier, &both).as_deref(), Some("header_tenant"));

        // Falls back to the username prefix.
        let username_only = RequestContext::new().with_username("user_tenant.admin");
        assert_eq!(tenant_of(&identifier, &username_only).as_deref(), Some("user_tenant"));

        // Nothing to go on.
        assert_eq!(tenant_of(&identifier, &RequestContext::new()), None);
    }

    #[test]
    fn test_create_identifier() {
        let cases = [
            (IdentificationMethod::header("X-Org-Id"), "header"),
            (IdentificationMethod::username_prefix('_'), "username_prefix"),
        ];
        for (method, expected) in &cases {
            assert_eq!(create_identifier(method).strategy_name(), *expected);
        }
    }
}