heliosdb_proxy/multi_tenancy/
identifier.rs1use std::collections::HashMap;
14use std::sync::Arc;
15
16use super::config::{IdentificationMethod, TenantId};
17
/// Connection- and request-level data that tenant identifiers inspect.
///
/// Build one with [`RequestContext::new`] and the `with_*` builder methods;
/// every field starts out empty / `None`.
#[derive(Debug, Clone, Default)]
pub struct RequestContext {
    /// Request headers keyed by name (e.g. `X-Tenant-Id`).
    pub headers: HashMap<String, String>,

    /// Username the client connected with.
    pub username: Option<String>,

    /// Database name requested by the client.
    pub database: Option<String>,

    /// Raw authentication token; `JwtClaimIdentifier` decodes it as a JWT.
    pub auth_token: Option<String>,

    /// Session/SQL context variables keyed by name (e.g. `helios.tenant_id`).
    pub sql_context: HashMap<String, String>,

    /// Client IP address as a string (not parsed or validated here).
    pub client_ip: Option<String>,

    /// Identifier of the client connection, if known.
    pub connection_id: Option<u64>,
}
42
43impl RequestContext {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
51 self.headers.insert(name.into(), value.into());
52 self
53 }
54
55 pub fn with_username(mut self, username: impl Into<String>) -> Self {
57 self.username = Some(username.into());
58 self
59 }
60
61 pub fn with_database(mut self, database: impl Into<String>) -> Self {
63 self.database = Some(database.into());
64 self
65 }
66
67 pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
69 self.auth_token = Some(token.into());
70 self
71 }
72
73 pub fn with_sql_context(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
75 self.sql_context.insert(name.into(), value.into());
76 self
77 }
78
79 pub fn with_client_ip(mut self, ip: impl Into<String>) -> Self {
81 self.client_ip = Some(ip.into());
82 self
83 }
84
85 pub fn get_header(&self, name: &str) -> Option<&str> {
87 self.headers.get(name).map(|s| s.as_str())
88 }
89
90 pub fn get_sql_context(&self, name: &str) -> Option<&str> {
92 self.sql_context.get(name).map(|s| s.as_str())
93 }
94}
95
96pub trait TenantIdentifier: Send + Sync {
98 fn identify(&self, request: &RequestContext) -> Option<TenantId>;
100
101 fn strategy_name(&self) -> &'static str;
103}
104
/// Identifies the tenant from the value of a request header.
///
/// The value is lowercased unless [`case_sensitive`](Self::case_sensitive)
/// is called.
#[derive(Debug, Clone)]
pub struct HeaderTenantIdentifier {
    /// Name of the header carrying the tenant id (looked up via `RequestContext::get_header`).
    header_name: String,

    /// Whether the header value is lowercased before becoming a tenant id.
    lowercase: bool,
}

impl HeaderTenantIdentifier {
    /// Builds an identifier that reads `header_name`, lowercasing values by default.
    pub fn new(header_name: impl Into<String>) -> Self {
        let header_name = header_name.into();
        Self {
            header_name,
            lowercase: true,
        }
    }

    /// Builds an identifier for the conventional `X-Tenant-Id` header.
    pub fn default_header() -> Self {
        Self::new("X-Tenant-Id")
    }

    /// Keeps the header value's original letter case.
    pub fn case_sensitive(self) -> Self {
        Self {
            lowercase: false,
            ..self
        }
    }
}
137
138impl TenantIdentifier for HeaderTenantIdentifier {
139 fn identify(&self, request: &RequestContext) -> Option<TenantId> {
140 request
141 .get_header(&self.header_name)
142 .filter(|v| !v.is_empty())
143 .map(|v| {
144 if self.lowercase {
145 TenantId::new(v.to_lowercase())
146 } else {
147 TenantId::new(v)
148 }
149 })
150 }
151
152 fn strategy_name(&self) -> &'static str {
153 "header"
154 }
155}
156
/// Identifies the tenant from the part of the username before a separator.
///
/// With separator `.`, `tenant_a.admin` maps to `tenant_a`. The prefix is
/// lowercased unless [`case_sensitive`](Self::case_sensitive) is called.
#[derive(Debug, Clone)]
pub struct UsernamePrefixIdentifier {
    /// Character that separates the tenant prefix from the rest of the username.
    separator: char,

    /// Whether the extracted prefix is lowercased.
    lowercase: bool,
}

impl UsernamePrefixIdentifier {
    /// Builds an identifier splitting on `separator`, lowercasing by default.
    pub fn new(separator: char) -> Self {
        Self {
            lowercase: true,
            separator,
        }
    }

    /// Shorthand for a `.` separator (`tenant.user`).
    pub fn with_dot() -> Self {
        Self::new('.')
    }

    /// Shorthand for a `_` separator (`tenant_user`).
    pub fn with_underscore() -> Self {
        Self::new('_')
    }

    /// Keeps the prefix's original letter case.
    pub fn case_sensitive(self) -> Self {
        Self {
            lowercase: false,
            ..self
        }
    }
}
194
195impl TenantIdentifier for UsernamePrefixIdentifier {
196 fn identify(&self, request: &RequestContext) -> Option<TenantId> {
197 request
198 .username
199 .as_ref()
200 .and_then(|username| username.split(self.separator).next())
201 .filter(|prefix| !prefix.is_empty())
202 .map(|prefix| {
203 if self.lowercase {
204 TenantId::new(prefix.to_lowercase())
205 } else {
206 TenantId::new(prefix)
207 }
208 })
209 }
210
211 fn strategy_name(&self) -> &'static str {
212 "username_prefix"
213 }
214}
215
/// Identifies the tenant from the name of the database the client connects to.
///
/// A configured prefix and/or suffix is stripped when present. For example,
/// `tenant_acme_db` becomes `acme` with prefix `tenant_` and suffix `_db`.
/// Like the other built-in identifiers, the result is lowercased unless
/// [`case_sensitive`](Self::case_sensitive) is called.
#[derive(Debug, Clone)]
pub struct DatabaseNameIdentifier {
    /// Prefix removed from the database name, if configured.
    prefix: Option<String>,

    /// Suffix removed from the database name, if configured.
    suffix: Option<String>,

    /// Whether the extracted tenant id is lowercased.
    lowercase: bool,
}

impl Default for DatabaseNameIdentifier {
    /// No prefix/suffix stripping, and lowercased (case-insensitive) output.
    ///
    /// Written by hand because `#[derive(Default)]` would set `lowercase` to
    /// `false`. That would make the default case-sensitive, unlike every other
    /// identifier, and turn `case_sensitive()` into a no-op.
    fn default() -> Self {
        Self {
            prefix: None,
            suffix: None,
            lowercase: true,
        }
    }
}

impl DatabaseNameIdentifier {
    /// Creates an identifier with no stripping that lowercases the result.
    pub fn new() -> Self {
        Self::default()
    }

    /// Strips `prefix` from the start of the database name when it is present.
    pub fn strip_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.prefix = Some(prefix.into());
        self
    }

    /// Strips `suffix` from the end of the database name when it is present.
    pub fn strip_suffix(mut self, suffix: impl Into<String>) -> Self {
        self.suffix = Some(suffix.into());
        self
    }

    /// Keeps the database name's original letter case.
    pub fn case_sensitive(mut self) -> Self {
        self.lowercase = false;
        self
    }
}
255
256impl TenantIdentifier for DatabaseNameIdentifier {
257 fn identify(&self, request: &RequestContext) -> Option<TenantId> {
258 request.database.as_ref().map(|db| {
259 let mut name = db.as_str();
260
261 if let Some(prefix) = &self.prefix {
262 name = name.strip_prefix(prefix.as_str()).unwrap_or(name);
263 }
264
265 if let Some(suffix) = &self.suffix {
266 name = name.strip_suffix(suffix.as_str()).unwrap_or(name);
267 }
268
269 if self.lowercase {
270 TenantId::new(name.to_lowercase())
271 } else {
272 TenantId::new(name)
273 }
274 })
275 }
276
277 fn strategy_name(&self) -> &'static str {
278 "database_name"
279 }
280}
281
/// Identifies the tenant from a session/SQL context variable.
///
/// The default variable is `helios.tenant_id`.
#[derive(Debug, Clone)]
pub struct SqlContextIdentifier {
    /// Name of the context variable holding the tenant id.
    variable_name: String,
}

impl SqlContextIdentifier {
    /// Builds an identifier that reads the context variable `variable_name`.
    pub fn new(variable_name: impl Into<String>) -> Self {
        let variable_name = variable_name.into();
        Self { variable_name }
    }

    /// Builds an identifier for the `helios.tenant_id` variable.
    pub fn default_variable() -> Self {
        Self::new("helios.tenant_id")
    }
}
304
305impl TenantIdentifier for SqlContextIdentifier {
306 fn identify(&self, request: &RequestContext) -> Option<TenantId> {
307 request
308 .get_sql_context(&self.variable_name)
309 .filter(|v| !v.is_empty())
310 .map(|v| TenantId::new(v.to_lowercase()))
311 }
312
313 fn strategy_name(&self) -> &'static str {
314 "sql_context"
315 }
316}
317
/// Identifies the tenant from a claim in the request's JWT auth token.
#[derive(Debug, Clone)]
pub struct JwtClaimIdentifier {
    /// Name of the payload claim holding the tenant id.
    claim_name: String,

    /// Expected token issuer, if one was configured via `with_issuer`.
    issuer: Option<String>,

    /// Reserved for signature verification. It is not read anywhere in this
    /// module and `new` always sets it to `None`, so tokens are currently
    /// not signature-checked.
    _verification_key: Option<String>,
}
333
334impl JwtClaimIdentifier {
335 pub fn new(claim_name: impl Into<String>) -> Self {
337 Self {
338 claim_name: claim_name.into(),
339 issuer: None,
340 _verification_key: None,
341 }
342 }
343
344 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
346 self.issuer = Some(issuer.into());
347 self
348 }
349
350 fn extract_claim(&self, token: &str) -> Option<String> {
353 use base64::Engine;
354
355 let parts: Vec<&str> = token.split('.').collect();
356 if parts.len() != 3 {
357 return None;
358 }
359
360 let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
362 .decode(parts[1])
363 .ok()?;
364
365 let payload_str = String::from_utf8(payload).ok()?;
366
367 self.extract_json_string(&payload_str, &self.claim_name)
370 }
371
372 fn extract_json_string(&self, json: &str, key: &str) -> Option<String> {
374 let pattern = format!("\"{}\"", key);
376 let pos = json.find(&pattern)?;
377 let after_key = &json[pos + pattern.len()..];
378
379 let after_colon = after_key.trim_start().strip_prefix(':')?;
381 let after_colon = after_colon.trim_start();
382
383 if let Some(inner) = after_colon.strip_prefix('"') {
385 let value_end = inner.find('"')?;
386 Some(inner[..value_end].to_string())
387 } else {
388 None
389 }
390 }
391}
392
393impl TenantIdentifier for JwtClaimIdentifier {
394 fn identify(&self, request: &RequestContext) -> Option<TenantId> {
395 request
396 .auth_token
397 .as_ref()
398 .and_then(|token| self.extract_claim(token))
399 .filter(|claim| !claim.is_empty())
400 .map(|claim| TenantId::new(claim.to_lowercase()))
401 }
402
403 fn strategy_name(&self) -> &'static str {
404 "jwt_claim"
405 }
406}
407
408#[derive(Clone)]
410pub struct CompositeIdentifier {
411 identifiers: Vec<Arc<dyn TenantIdentifier>>,
413}
414
415impl CompositeIdentifier {
416 pub fn new() -> Self {
418 Self {
419 identifiers: Vec::new(),
420 }
421 }
422
423 #[allow(clippy::should_implement_trait)]
425 pub fn add<I: TenantIdentifier + 'static>(mut self, identifier: I) -> Self {
426 self.identifiers.push(Arc::new(identifier));
427 self
428 }
429
430 pub fn add_arc(mut self, identifier: Arc<dyn TenantIdentifier>) -> Self {
432 self.identifiers.push(identifier);
433 self
434 }
435}
436
437impl Default for CompositeIdentifier {
438 fn default() -> Self {
439 Self::new()
440 }
441}
442
443impl TenantIdentifier for CompositeIdentifier {
444 fn identify(&self, request: &RequestContext) -> Option<TenantId> {
445 for identifier in &self.identifiers {
446 if let Some(tenant) = identifier.identify(request) {
447 return Some(tenant);
448 }
449 }
450 None
451 }
452
453 fn strategy_name(&self) -> &'static str {
454 "composite"
455 }
456}
457
458pub fn create_identifier(method: &IdentificationMethod) -> Box<dyn TenantIdentifier> {
460 match method {
461 IdentificationMethod::Header { header_name } => {
462 Box::new(HeaderTenantIdentifier::new(header_name))
463 }
464 IdentificationMethod::UsernamePrefix { separator } => {
465 Box::new(UsernamePrefixIdentifier::new(*separator))
466 }
467 IdentificationMethod::JwtClaim { claim_name, issuer } => {
468 let mut identifier = JwtClaimIdentifier::new(claim_name);
469 if let Some(iss) = issuer {
470 identifier = identifier.with_issuer(iss);
471 }
472 Box::new(identifier)
473 }
474 IdentificationMethod::DatabaseName => Box::new(DatabaseNameIdentifier::new()),
475 IdentificationMethod::SqlContext { variable_name } => {
476 Box::new(SqlContextIdentifier::new(variable_name))
477 }
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `identifier` against `ctx` and returns the raw tenant string, if any.
    fn resolve(identifier: &dyn TenantIdentifier, ctx: &RequestContext) -> Option<String> {
        identifier.identify(ctx).map(|tenant| tenant.0)
    }

    #[test]
    fn test_header_identifier() {
        let by_header = HeaderTenantIdentifier::new("X-Tenant-Id");

        let with_value = RequestContext::new().with_header("X-Tenant-Id", "TenantA");
        assert_eq!(resolve(&by_header, &with_value), Some("tenanta".to_string()));

        assert_eq!(resolve(&by_header, &RequestContext::new()), None);

        let with_blank = RequestContext::new().with_header("X-Tenant-Id", "");
        assert_eq!(resolve(&by_header, &with_blank), None);
    }

    #[test]
    fn test_header_identifier_case_sensitive() {
        let by_header = HeaderTenantIdentifier::new("X-Tenant-Id").case_sensitive();

        let ctx = RequestContext::new().with_header("X-Tenant-Id", "TenantA");
        assert_eq!(resolve(&by_header, &ctx), Some("TenantA".to_string()));
    }

    #[test]
    fn test_username_prefix_identifier() {
        let by_prefix = UsernamePrefixIdentifier::with_dot();

        let dotted = RequestContext::new().with_username("tenant_a.admin");
        assert_eq!(resolve(&by_prefix, &dotted), Some("tenant_a".to_string()));

        // Without a separator the whole username is used as the prefix.
        let plain = RequestContext::new().with_username("admin");
        assert_eq!(resolve(&by_prefix, &plain), Some("admin".to_string()));

        assert_eq!(resolve(&by_prefix, &RequestContext::new()), None);
    }

    #[test]
    fn test_database_name_identifier() {
        let by_db = DatabaseNameIdentifier::new()
            .strip_prefix("tenant_")
            .strip_suffix("_db");

        let decorated = RequestContext::new().with_database("tenant_acme_db");
        assert_eq!(resolve(&by_db, &decorated), Some("acme".to_string()));

        // Names lacking the prefix/suffix pass through unchanged.
        let bare = RequestContext::new().with_database("mydb");
        assert_eq!(resolve(&by_db, &bare), Some("mydb".to_string()));
    }

    #[test]
    fn test_sql_context_identifier() {
        let by_sql = SqlContextIdentifier::default_variable();

        let ctx = RequestContext::new().with_sql_context("helios.tenant_id", "tenant_x");
        assert_eq!(resolve(&by_sql, &ctx), Some("tenant_x".to_string()));

        assert_eq!(resolve(&by_sql, &RequestContext::new()), None);
    }

    #[test]
    fn test_jwt_claim_identifier() {
        use base64::Engine;

        let by_claim = JwtClaimIdentifier::new("tenant_id");

        // Only the payload segment is decoded, so header and signature can be placeholders.
        let claims = r#"{"tenant_id":"acme","sub":"user1"}"#;
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims);
        let token = format!("header.{}.signature", payload);

        let ctx = RequestContext::new().with_auth_token(&token);
        assert_eq!(resolve(&by_claim, &ctx), Some("acme".to_string()));
    }

    #[test]
    fn test_composite_identifier() {
        let chained = CompositeIdentifier::new()
            .add(HeaderTenantIdentifier::new("X-Tenant-Id"))
            .add(UsernamePrefixIdentifier::with_dot());

        // The header strategy was added first, so it wins when both apply.
        let both = RequestContext::new()
            .with_header("X-Tenant-Id", "header_tenant")
            .with_username("user_tenant.admin");
        assert_eq!(resolve(&chained, &both), Some("header_tenant".to_string()));

        // Without the header, resolution falls through to the username prefix.
        let username_only = RequestContext::new().with_username("user_tenant.admin");
        assert_eq!(resolve(&chained, &username_only), Some("user_tenant".to_string()));

        assert_eq!(resolve(&chained, &RequestContext::new()), None);
    }

    #[test]
    fn test_create_identifier() {
        let cases = [
            (IdentificationMethod::header("X-Org-Id"), "header"),
            (IdentificationMethod::username_prefix('_'), "username_prefix"),
        ];
        for (method, expected) in cases.iter() {
            assert_eq!(create_identifier(method).strategy_name(), *expected);
        }
    }
}