//! Tenant identification strategies (`heliosdb_proxy/multi_tenancy/identifier.rs`).

use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

use super::config::{IdentificationMethod, TenantId};

/// Per-request/connection metadata that [`TenantIdentifier`] strategies inspect.
///
/// Build it with [`RequestContext::new`] plus the `with_*` builder methods; all fields are
/// public, so callers may also populate them directly.
#[derive(Debug, Clone, Default)]
pub struct RequestContext {
    /// Request headers keyed by name. Names are stored exactly as supplied;
    /// [`RequestContext::get_header`] matches them case-insensitively.
    pub headers: HashMap<String, String>,

    /// User name associated with the request, if known.
    pub username: Option<String>,

    /// Target database name, if known.
    pub database: Option<String>,

    /// Raw auth token (e.g. the JWT read by `JwtClaimIdentifier`), if any.
    pub auth_token: Option<String>,

    /// Name/value pairs from the SQL session context, looked up by exact name.
    pub sql_context: HashMap<String, String>,

    /// Client IP address as a string, if known.
    pub client_ip: Option<String>,

    /// Connection identifier, if known.
    pub connection_id: Option<u64>,
}

impl RequestContext {
    /// Creates an empty context (all fields unset/empty).
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds (or replaces) a header; the name is stored exactly as given.
    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(name.into(), value.into());
        self
    }

    /// Sets the user name.
    pub fn with_username(mut self, username: impl Into<String>) -> Self {
        self.username = Some(username.into());
        self
    }

    /// Sets the target database name.
    pub fn with_database(mut self, database: impl Into<String>) -> Self {
        self.database = Some(database.into());
        self
    }

    /// Sets the raw auth token.
    pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
        self.auth_token = Some(token.into());
        self
    }

    /// Adds (or replaces) a SQL-context variable.
    pub fn with_sql_context(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.sql_context.insert(name.into(), value.into());
        self
    }

    /// Sets the client IP address.
    pub fn with_client_ip(mut self, ip: impl Into<String>) -> Self {
        self.client_ip = Some(ip.into());
        self
    }

    /// Sets the connection identifier.
    pub fn with_connection_id(mut self, id: u64) -> Self {
        self.connection_id = Some(id);
        self
    }

    /// Looks up a header value by name.
    ///
    /// HTTP header field names are case-insensitive, so an exact-case match is tried
    /// first (one hash lookup) and, failing that, the headers are scanned for a name equal
    /// ignoring ASCII case. If several stored names differ only in case, which one the
    /// fallback returns is unspecified (`HashMap` iteration order).
    pub fn get_header(&self, name: &str) -> Option<&str> {
        self.headers
            .get(name)
            .or_else(|| {
                self.headers
                    .iter()
                    .find(|(key, _)| key.eq_ignore_ascii_case(name))
                    .map(|(_, value)| value)
            })
            .map(String::as_str)
    }

    /// Looks up a SQL-context variable by its exact name.
    pub fn get_sql_context(&self, name: &str) -> Option<&str> {
        self.sql_context.get(name).map(|s| s.as_str())
    }
}

96pub trait TenantIdentifier: Send + Sync {
98 fn identify(&self, request: &RequestContext) -> Option<TenantId>;
100
101 fn strategy_name(&self) -> &'static str;
103}
104
105#[derive(Debug, Clone)]
109pub struct HeaderTenantIdentifier {
110 header_name: String,
112
113 lowercase: bool,
115}
116
117impl HeaderTenantIdentifier {
118 pub fn new(header_name: impl Into<String>) -> Self {
120 Self {
121 header_name: header_name.into(),
122 lowercase: true,
123 }
124 }
125
126 pub fn default_header() -> Self {
128 Self::new("X-Tenant-Id")
129 }
130
131 pub fn case_sensitive(mut self) -> Self {
133 self.lowercase = false;
134 self
135 }
136}
137
138impl TenantIdentifier for HeaderTenantIdentifier {
139 fn identify(&self, request: &RequestContext) -> Option<TenantId> {
140 request
141 .get_header(&self.header_name)
142 .filter(|v| !v.is_empty())
143 .map(|v| {
144 if self.lowercase {
145 TenantId::new(v.to_lowercase())
146 } else {
147 TenantId::new(v)
148 }
149 })
150 }
151
152 fn strategy_name(&self) -> &'static str {
153 "header"
154 }
155}
156
157#[derive(Debug, Clone)]
161pub struct UsernamePrefixIdentifier {
162 separator: char,
164
165 lowercase: bool,
167}
168
169impl UsernamePrefixIdentifier {
170 pub fn new(separator: char) -> Self {
172 Self {
173 separator,
174 lowercase: true,
175 }
176 }
177
178 pub fn with_dot() -> Self {
180 Self::new('.')
181 }
182
183 pub fn with_underscore() -> Self {
185 Self::new('_')
186 }
187
188 pub fn case_sensitive(mut self) -> Self {
190 self.lowercase = false;
191 self
192 }
193}
194
195impl TenantIdentifier for UsernamePrefixIdentifier {
196 fn identify(&self, request: &RequestContext) -> Option<TenantId> {
197 request
198 .username
199 .as_ref()
200 .and_then(|username| username.split(self.separator).next())
201 .filter(|prefix| !prefix.is_empty())
202 .map(|prefix| {
203 if self.lowercase {
204 TenantId::new(prefix.to_lowercase())
205 } else {
206 TenantId::new(prefix)
207 }
208 })
209 }
210
211 fn strategy_name(&self) -> &'static str {
212 "username_prefix"
213 }
214}
215
216#[derive(Debug, Clone, Default)]
220pub struct DatabaseNameIdentifier {
221 prefix: Option<String>,
223
224 suffix: Option<String>,
226
227 lowercase: bool,
229}
230
231impl DatabaseNameIdentifier {
232 pub fn new() -> Self {
234 Self::default()
235 }
236
237 pub fn strip_prefix(mut self, prefix: impl Into<String>) -> Self {
239 self.prefix = Some(prefix.into());
240 self
241 }
242
243 pub fn strip_suffix(mut self, suffix: impl Into<String>) -> Self {
245 self.suffix = Some(suffix.into());
246 self
247 }
248
249 pub fn case_sensitive(mut self) -> Self {
251 self.lowercase = false;
252 self
253 }
254}
255
256impl TenantIdentifier for DatabaseNameIdentifier {
257 fn identify(&self, request: &RequestContext) -> Option<TenantId> {
258 request.database.as_ref().map(|db| {
259 let mut name = db.as_str();
260
261 if let Some(prefix) = &self.prefix {
262 name = name.strip_prefix(prefix.as_str()).unwrap_or(name);
263 }
264
265 if let Some(suffix) = &self.suffix {
266 name = name.strip_suffix(suffix.as_str()).unwrap_or(name);
267 }
268
269 if self.lowercase {
270 TenantId::new(name.to_lowercase())
271 } else {
272 TenantId::new(name)
273 }
274 })
275 }
276
277 fn strategy_name(&self) -> &'static str {
278 "database_name"
279 }
280}
281
282#[derive(Debug, Clone)]
286pub struct SqlContextIdentifier {
287 variable_name: String,
289}
290
291impl SqlContextIdentifier {
292 pub fn new(variable_name: impl Into<String>) -> Self {
294 Self {
295 variable_name: variable_name.into(),
296 }
297 }
298
299 pub fn default_variable() -> Self {
301 Self::new("helios.tenant_id")
302 }
303}
304
305impl TenantIdentifier for SqlContextIdentifier {
306 fn identify(&self, request: &RequestContext) -> Option<TenantId> {
307 request
308 .get_sql_context(&self.variable_name)
309 .filter(|v| !v.is_empty())
310 .map(|v| TenantId::new(v.to_lowercase()))
311 }
312
313 fn strategy_name(&self) -> &'static str {
314 "sql_context"
315 }
316}
317
318#[derive(Debug, Clone)]
322pub struct JwtClaimIdentifier {
323 claim_name: String,
325
326 issuer: Option<String>,
328
329 _verification_key: Option<String>,
332}
333
334impl JwtClaimIdentifier {
335 pub fn new(claim_name: impl Into<String>) -> Self {
337 Self {
338 claim_name: claim_name.into(),
339 issuer: None,
340 _verification_key: None,
341 }
342 }
343
344 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
346 self.issuer = Some(issuer.into());
347 self
348 }
349
350 fn extract_claim(&self, token: &str) -> Option<String> {
353 use base64::Engine;
354
355 let parts: Vec<&str> = token.split('.').collect();
356 if parts.len() != 3 {
357 return None;
358 }
359
360 let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
362 .decode(parts[1])
363 .ok()?;
364
365 let payload_str = String::from_utf8(payload).ok()?;
366
367 self.extract_json_string(&payload_str, &self.claim_name)
370 }
371
372 fn extract_json_string(&self, json: &str, key: &str) -> Option<String> {
374 let pattern = format!("\"{}\"", key);
376 let pos = json.find(&pattern)?;
377 let after_key = &json[pos + pattern.len()..];
378
379 let after_colon = after_key.trim_start().strip_prefix(':')?;
381 let after_colon = after_colon.trim_start();
382
383 if after_colon.starts_with('"') {
385 let value_start = 1;
386 let value_end = after_colon[1..].find('"')? + 1;
387 Some(after_colon[value_start..value_end].to_string())
388 } else {
389 None
390 }
391 }
392}
393
394impl TenantIdentifier for JwtClaimIdentifier {
395 fn identify(&self, request: &RequestContext) -> Option<TenantId> {
396 request
397 .auth_token
398 .as_ref()
399 .and_then(|token| self.extract_claim(token))
400 .filter(|claim| !claim.is_empty())
401 .map(|claim| TenantId::new(claim.to_lowercase()))
402 }
403
404 fn strategy_name(&self) -> &'static str {
405 "jwt_claim"
406 }
407}
408
409#[derive(Clone)]
411pub struct CompositeIdentifier {
412 identifiers: Vec<Arc<dyn TenantIdentifier>>,
414}
415
416impl CompositeIdentifier {
417 pub fn new() -> Self {
419 Self {
420 identifiers: Vec::new(),
421 }
422 }
423
424 pub fn add<I: TenantIdentifier + 'static>(mut self, identifier: I) -> Self {
426 self.identifiers.push(Arc::new(identifier));
427 self
428 }
429
430 pub fn add_arc(mut self, identifier: Arc<dyn TenantIdentifier>) -> Self {
432 self.identifiers.push(identifier);
433 self
434 }
435}
436
437impl Default for CompositeIdentifier {
438 fn default() -> Self {
439 Self::new()
440 }
441}
442
443impl TenantIdentifier for CompositeIdentifier {
444 fn identify(&self, request: &RequestContext) -> Option<TenantId> {
445 for identifier in &self.identifiers {
446 if let Some(tenant) = identifier.identify(request) {
447 return Some(tenant);
448 }
449 }
450 None
451 }
452
453 fn strategy_name(&self) -> &'static str {
454 "composite"
455 }
456}
457
458pub fn create_identifier(method: &IdentificationMethod) -> Box<dyn TenantIdentifier> {
460 match method {
461 IdentificationMethod::Header { header_name } => {
462 Box::new(HeaderTenantIdentifier::new(header_name))
463 }
464 IdentificationMethod::UsernamePrefix { separator } => {
465 Box::new(UsernamePrefixIdentifier::new(*separator))
466 }
467 IdentificationMethod::JwtClaim { claim_name, issuer } => {
468 let mut identifier = JwtClaimIdentifier::new(claim_name);
469 if let Some(iss) = issuer {
470 identifier = identifier.with_issuer(iss);
471 }
472 Box::new(identifier)
473 }
474 IdentificationMethod::DatabaseName => {
475 Box::new(DatabaseNameIdentifier::new())
476 }
477 IdentificationMethod::SqlContext { variable_name } => {
478 Box::new(SqlContextIdentifier::new(variable_name))
479 }
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `identifier` against `ctx` and unwraps the tenant id into a plain `String`.
    fn tenant_of(identifier: &dyn TenantIdentifier, ctx: &RequestContext) -> Option<String> {
        identifier.identify(ctx).map(|tenant| tenant.0)
    }

    #[test]
    fn test_header_identifier() {
        let identifier = HeaderTenantIdentifier::new("X-Tenant-Id");

        let present = RequestContext::new().with_header("X-Tenant-Id", "TenantA");
        assert_eq!(tenant_of(&identifier, &present).as_deref(), Some("tenanta"));

        // Missing and empty headers both yield no tenant.
        assert_eq!(tenant_of(&identifier, &RequestContext::new()), None);
        let empty = RequestContext::new().with_header("X-Tenant-Id", "");
        assert_eq!(tenant_of(&identifier, &empty), None);
    }

    #[test]
    fn test_header_identifier_case_sensitive() {
        let identifier = HeaderTenantIdentifier::new("X-Tenant-Id").case_sensitive();

        let ctx = RequestContext::new().with_header("X-Tenant-Id", "TenantA");
        assert_eq!(tenant_of(&identifier, &ctx).as_deref(), Some("TenantA"));
    }

    #[test]
    fn test_username_prefix_identifier() {
        let identifier = UsernamePrefixIdentifier::with_dot();

        let prefixed = RequestContext::new().with_username("tenant_a.admin");
        assert_eq!(tenant_of(&identifier, &prefixed).as_deref(), Some("tenant_a"));

        // Without a separator the whole user name is used as the tenant.
        let bare = RequestContext::new().with_username("admin");
        assert_eq!(tenant_of(&identifier, &bare).as_deref(), Some("admin"));

        assert_eq!(tenant_of(&identifier, &RequestContext::new()), None);
    }

    #[test]
    fn test_database_name_identifier() {
        let identifier = DatabaseNameIdentifier::new()
            .strip_prefix("tenant_")
            .strip_suffix("_db");

        let decorated = RequestContext::new().with_database("tenant_acme_db");
        assert_eq!(tenant_of(&identifier, &decorated).as_deref(), Some("acme"));

        // Names without the prefix/suffix pass through unchanged.
        let plain = RequestContext::new().with_database("mydb");
        assert_eq!(tenant_of(&identifier, &plain).as_deref(), Some("mydb"));
    }

    #[test]
    fn test_sql_context_identifier() {
        let identifier = SqlContextIdentifier::default_variable();

        let ctx = RequestContext::new().with_sql_context("helios.tenant_id", "tenant_x");
        assert_eq!(tenant_of(&identifier, &ctx).as_deref(), Some("tenant_x"));

        assert_eq!(tenant_of(&identifier, &RequestContext::new()), None);
    }

    #[test]
    fn test_jwt_claim_identifier() {
        use base64::Engine;

        let identifier = JwtClaimIdentifier::new("tenant_id");

        // Only the payload segment is decoded; header and signature are placeholders.
        let claims = r#"{"tenant_id":"acme","sub":"user1"}"#;
        let token = format!(
            "header.{}.signature",
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims)
        );

        let ctx = RequestContext::new().with_auth_token(&token);
        assert_eq!(tenant_of(&identifier, &ctx).as_deref(), Some("acme"));
    }

    #[test]
    fn test_composite_identifier() {
        let identifier = CompositeIdentifier::new()
            .add(HeaderTenantIdentifier::new("X-Tenant-Id"))
            .add(UsernamePrefixIdentifier::with_dot());

        // The header strategy was added first, so it wins when both could match.
        let both = RequestContext::new()
            .with_header("X-Tenant-Id", "header_tenant")
            .with_username("user_tenant.admin");
        assert_eq!(tenant_of(&identifier, &both).as_deref(), Some("header_tenant"));

        // Falls back to the username strategy when the header is absent.
        let username_only = RequestContext::new().with_username("user_tenant.admin");
        assert_eq!(
            tenant_of(&identifier, &username_only).as_deref(),
            Some("user_tenant")
        );

        assert_eq!(tenant_of(&identifier, &RequestContext::new()), None);
    }

    #[test]
    fn test_create_identifier() {
        let header = create_identifier(&IdentificationMethod::header("X-Org-Id"));
        assert_eq!(header.strategy_name(), "header");

        let username = create_identifier(&IdentificationMethod::username_prefix('_'));
        assert_eq!(username.strategy_name(), "username_prefix");
    }
}