1use std::collections::HashMap;
13use std::sync::Arc;
14
15use super::config::{IsolationStrategy, TenantConfig, TenantId};
16
/// Describes where and how a tenant's queries should be executed.
///
/// Every field is optional; `RoutingDecision::default()` leaves the query untouched
/// (no database, no search path, no branch, no pre-query commands, no transform).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RoutingDecision {
    /// Database the query must be sent to, if any.
    pub database: Option<String>,

    /// Schema to put on the `search_path`, stored as the raw (unquoted) name.
    pub search_path: Option<String>,

    /// Branch the query must be routed to, if any.
    pub branch: Option<String>,

    /// SQL statements to execute before the tenant's query.
    pub pre_query_commands: Vec<String>,

    /// Set by [`RoutingDecision::row_level`]: the query must be rewritten before
    /// execution (presumably to add a tenant predicate — confirm with the consumer).
    pub requires_transform: bool,
}

impl RoutingDecision {
    /// Routes to the dedicated database `name`.
    #[must_use]
    pub fn database(name: impl Into<String>) -> Self {
        Self {
            database: Some(name.into()),
            ..Self::default()
        }
    }

    /// Routes to `database` and selects `schema` via a `SET search_path` pre-query command.
    ///
    /// `search_path` keeps the raw schema name. The `SET` command emits the name unchanged
    /// when it is a plain identifier (so PostgreSQL's usual case folding still applies).
    /// Otherwise it is double-quoted, so a hostile name cannot inject extra SQL.
    #[must_use]
    pub fn schema(database: impl Into<String>, schema: impl Into<String>) -> Self {
        let schema_name = schema.into();
        let set_search_path = format!("SET search_path TO {}", Self::quote_ident(&schema_name));
        Self {
            database: Some(database.into()),
            search_path: Some(schema_name),
            pre_query_commands: vec![set_search_path],
            ..Self::default()
        }
    }

    /// Routes to the branch `name`.
    #[must_use]
    pub fn branch(name: impl Into<String>) -> Self {
        Self {
            branch: Some(name.into()),
            ..Self::default()
        }
    }

    /// Routes to the shared `database` and flags the query as requiring a rewrite.
    #[must_use]
    pub fn row_level(database: impl Into<String>) -> Self {
        Self {
            database: Some(database.into()),
            requires_transform: true,
            ..Self::default()
        }
    }

    /// Renders `name` as a PostgreSQL identifier.
    ///
    /// Plain identifiers (a letter or `_`, then letters, digits, `_` or `$`) are returned
    /// unchanged. Anything else is wrapped in double quotes with embedded `"` doubled.
    /// Reserved keywords are not detected.
    fn quote_ident(name: &str) -> String {
        let mut chars = name.chars();
        let plain = chars
            .next()
            .map_or(false, |first| first.is_alphabetic() || first == '_')
            && chars.all(|c| c.is_alphanumeric() || c == '_' || c == '$');
        if plain {
            name.to_owned()
        } else {
            format!("\"{}\"", name.replace('"', "\"\""))
        }
    }
}
84
85pub trait IsolationHandler: Send + Sync {
87 fn get_routing(&self, tenant: &TenantId, config: &TenantConfig) -> RoutingDecision;
89
90 fn can_access_table(&self, tenant: &TenantId, table: &str, config: &TenantConfig) -> bool;
92
93 fn strategy_name(&self) -> &'static str;
95}
96
/// Isolation handler for tenants that each own a dedicated database. Carries no state.
#[derive(Debug, Clone, Default)]
pub struct DatabaseIsolationHandler;

impl DatabaseIsolationHandler {
    /// Returns the (stateless) handler.
    pub fn new() -> Self {
        DatabaseIsolationHandler
    }
}
109
110impl IsolationHandler for DatabaseIsolationHandler {
111 fn get_routing(&self, _tenant: &TenantId, config: &TenantConfig) -> RoutingDecision {
112 if let IsolationStrategy::Database { database_name } = &config.isolation {
113 RoutingDecision::database(database_name)
114 } else {
115 RoutingDecision::default()
116 }
117 }
118
119 fn can_access_table(&self, _tenant: &TenantId, _table: &str, config: &TenantConfig) -> bool {
120 config.permissions.is_table_allowed(_table)
122 }
123
124 fn strategy_name(&self) -> &'static str {
125 "database"
126 }
127}
128
/// Isolation handler for tenants that each own a schema inside a shared database.
/// Carries no state.
#[derive(Debug, Clone, Default)]
pub struct SchemaIsolationHandler;

impl SchemaIsolationHandler {
    /// Returns the (stateless) handler.
    pub fn new() -> Self {
        SchemaIsolationHandler
    }
}
141
142impl IsolationHandler for SchemaIsolationHandler {
143 fn get_routing(&self, _tenant: &TenantId, config: &TenantConfig) -> RoutingDecision {
144 if let IsolationStrategy::Schema {
145 database_name,
146 schema_name,
147 } = &config.isolation
148 {
149 RoutingDecision::schema(database_name, schema_name)
150 } else {
151 RoutingDecision::default()
152 }
153 }
154
155 fn can_access_table(&self, _tenant: &TenantId, table: &str, config: &TenantConfig) -> bool {
156 if let IsolationStrategy::Schema { schema_name, .. } = &config.isolation {
158 if let Some((schema, _)) = table.split_once('.') {
160 return schema.eq_ignore_ascii_case(schema_name)
161 && config.permissions.is_table_allowed(table);
162 }
163 }
164 config.permissions.is_table_allowed(table)
166 }
167
168 fn strategy_name(&self) -> &'static str {
169 "schema"
170 }
171}
172
/// Isolation handler for tenants that share tables, separated by a tenant column.
///
/// Keeps a registry mapping table names to the column holding the tenant id; the
/// registry is readable through [`RowIsolationHandler::tenant_column`] and
/// [`RowIsolationHandler::tenant_tables`].
#[derive(Debug, Clone, Default)]
pub struct RowIsolationHandler {
    /// Table name -> name of the column holding the tenant id.
    tenant_tables: HashMap<String, String>,
}

impl RowIsolationHandler {
    /// Creates a handler with no registered tables.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `table` as tenant-scoped through `column`.
    ///
    /// Re-registering a table replaces its column.
    #[must_use]
    pub fn register_table(mut self, table: impl Into<String>, column: impl Into<String>) -> Self {
        self.tenant_tables.insert(table.into(), column.into());
        self
    }

    /// Returns the tenant column registered for `table`, if any.
    pub fn tenant_column(&self, table: &str) -> Option<&str> {
        self.tenant_tables.get(table).map(String::as_str)
    }

    /// Iterates over all registered `(table, tenant_column)` pairs in unspecified order.
    pub fn tenant_tables(&self) -> impl Iterator<Item = (&str, &str)> {
        self.tenant_tables
            .iter()
            .map(|(table, column)| (table.as_str(), column.as_str()))
    }
}
194
195impl IsolationHandler for RowIsolationHandler {
196 fn get_routing(&self, _tenant: &TenantId, config: &TenantConfig) -> RoutingDecision {
197 if let IsolationStrategy::Row { database_name, .. } = &config.isolation {
198 RoutingDecision::row_level(database_name)
199 } else {
200 RoutingDecision::default()
201 }
202 }
203
204 fn can_access_table(&self, _tenant: &TenantId, table: &str, config: &TenantConfig) -> bool {
205 config.permissions.is_table_allowed(table)
206 }
207
208 fn strategy_name(&self) -> &'static str {
209 "row"
210 }
211}
212
/// Isolation handler for tenants that each live on their own branch. Carries no state.
#[derive(Debug, Clone, Default)]
pub struct BranchIsolationHandler;

impl BranchIsolationHandler {
    /// Returns the (stateless) handler.
    pub fn new() -> Self {
        BranchIsolationHandler
    }
}
225
226impl IsolationHandler for BranchIsolationHandler {
227 fn get_routing(&self, _tenant: &TenantId, config: &TenantConfig) -> RoutingDecision {
228 if let IsolationStrategy::Branch { branch_name } = &config.isolation {
229 RoutingDecision::branch(branch_name)
230 } else {
231 RoutingDecision::default()
232 }
233 }
234
235 fn can_access_table(&self, _tenant: &TenantId, _table: &str, config: &TenantConfig) -> bool {
236 config.permissions.is_table_allowed(_table)
238 }
239
240 fn strategy_name(&self) -> &'static str {
241 "branch"
242 }
243}
244
245pub fn create_handler(strategy: &IsolationStrategy) -> Arc<dyn IsolationHandler> {
247 match strategy {
248 IsolationStrategy::Database { .. } => Arc::new(DatabaseIsolationHandler::new()),
249 IsolationStrategy::Schema { .. } => Arc::new(SchemaIsolationHandler::new()),
250 IsolationStrategy::Row { .. } => Arc::new(RowIsolationHandler::new()),
251 IsolationStrategy::Branch { .. } => Arc::new(BranchIsolationHandler::new()),
252 }
253}
254
255pub struct IsolationRouter {
257 default_handler: Arc<dyn IsolationHandler>,
259
260 handlers: parking_lot::RwLock<HashMap<TenantId, Arc<dyn IsolationHandler>>>,
262}
263
264impl IsolationRouter {
265 pub fn new() -> Self {
267 Self {
268 default_handler: Arc::new(SchemaIsolationHandler::new()),
269 handlers: parking_lot::RwLock::new(HashMap::new()),
270 }
271 }
272
273 pub fn with_default_handler(mut self, handler: Arc<dyn IsolationHandler>) -> Self {
275 self.default_handler = handler;
276 self
277 }
278
279 pub fn register_tenant(&self, tenant: TenantId, handler: Arc<dyn IsolationHandler>) {
281 self.handlers.write().insert(tenant, handler);
282 }
283
284 pub fn register_from_config(&self, config: &TenantConfig) {
286 let handler = create_handler(&config.isolation);
287 self.handlers.write().insert(config.id.clone(), handler);
288 }
289
290 pub fn get_routing(&self, tenant: &TenantId, config: &TenantConfig) -> RoutingDecision {
292 let handlers = self.handlers.read();
293 handlers
294 .get(tenant)
295 .unwrap_or(&self.default_handler)
296 .get_routing(tenant, config)
297 }
298
299 pub fn can_access_table(&self, tenant: &TenantId, table: &str, config: &TenantConfig) -> bool {
301 let handlers = self.handlers.read();
302 handlers
303 .get(tenant)
304 .unwrap_or(&self.default_handler)
305 .can_access_table(tenant, table, config)
306 }
307}
308
309impl Default for IsolationRouter {
310 fn default() -> Self {
311 Self::new()
312 }
313}
314
/// Generates per-tenant resource names and provisioning SQL from naming templates.
///
/// Each template may contain the placeholder `{id}`, which is replaced by the tenant id.
#[derive(Debug, Clone)]
pub struct TenantProvisioner {
    /// Template for dedicated database names.
    database_template: String,

    /// Template for per-tenant schema names.
    schema_template: String,

    /// Template for per-tenant branch names.
    branch_template: String,
}
326
327impl Default for TenantProvisioner {
328 fn default() -> Self {
329 Self {
330 database_template: "tenant_{id}_db".to_string(),
331 schema_template: "tenant_{id}".to_string(),
332 branch_template: "tenant_{id}".to_string(),
333 }
334 }
335}
336
337impl TenantProvisioner {
338 pub fn new() -> Self {
340 Self::default()
341 }
342
343 pub fn database_template(mut self, template: impl Into<String>) -> Self {
345 self.database_template = template.into();
346 self
347 }
348
349 pub fn schema_template(mut self, template: impl Into<String>) -> Self {
351 self.schema_template = template.into();
352 self
353 }
354
355 pub fn branch_template(mut self, template: impl Into<String>) -> Self {
357 self.branch_template = template.into();
358 self
359 }
360
361 pub fn generate_database_name(&self, tenant: &TenantId) -> String {
363 self.database_template.replace("{id}", &tenant.0)
364 }
365
366 pub fn generate_schema_name(&self, tenant: &TenantId) -> String {
368 self.schema_template.replace("{id}", &tenant.0)
369 }
370
371 pub fn generate_branch_name(&self, tenant: &TenantId) -> String {
373 self.branch_template.replace("{id}", &tenant.0)
374 }
375
376 pub fn generate_isolation(
378 &self,
379 tenant: &TenantId,
380 strategy_type: &str,
381 shared_database: Option<&str>,
382 ) -> IsolationStrategy {
383 match strategy_type {
384 "database" => IsolationStrategy::database(self.generate_database_name(tenant)),
385 "schema" => IsolationStrategy::schema(
386 shared_database.unwrap_or("shared"),
387 self.generate_schema_name(tenant),
388 ),
389 "row" => IsolationStrategy::row(
390 shared_database.unwrap_or("shared"),
391 "tenant_id",
392 ),
393 "branch" => IsolationStrategy::branch(self.generate_branch_name(tenant)),
394 _ => IsolationStrategy::schema("public", self.generate_schema_name(tenant)),
395 }
396 }
397
398 pub fn sql_create_database(&self, tenant: &TenantId) -> Vec<String> {
400 let db_name = self.generate_database_name(tenant);
401 vec![
402 format!("CREATE DATABASE {} WITH OWNER = postgres", db_name),
403 format!(
404 "GRANT ALL PRIVILEGES ON DATABASE {} TO postgres",
405 db_name
406 ),
407 ]
408 }
409
410 pub fn sql_create_schema(&self, tenant: &TenantId, database: &str) -> Vec<String> {
412 let schema_name = self.generate_schema_name(tenant);
413 vec![
414 format!("-- Connect to database: {}", database),
415 format!("CREATE SCHEMA IF NOT EXISTS {}", schema_name),
416 format!("GRANT ALL ON SCHEMA {} TO postgres", schema_name),
417 ]
418 }
419
420 pub fn sql_create_rls_policy(
422 &self,
423 tenant: &TenantId,
424 table: &str,
425 tenant_column: &str,
426 ) -> Vec<String> {
427 let policy_name = format!("tenant_{}_policy", tenant.0);
428 vec![
429 format!("ALTER TABLE {} ENABLE ROW LEVEL SECURITY", table),
430 format!(
431 "CREATE POLICY {} ON {} FOR ALL USING ({} = '{}')",
432 policy_name, table, tenant_column, tenant.0
433 ),
434 ]
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multi_tenancy::config::TenantConfig;

    /// Builds a minimal tenant config with the given id and isolation strategy.
    fn create_test_config(id: &str, isolation: IsolationStrategy) -> TenantConfig {
        TenantConfig::builder()
            .id(id)
            .name(format!("Test {}", id))
            .isolation(isolation)
            .build()
    }

    #[test]
    fn test_routing_decision() {
        let db = RoutingDecision::database("mydb");
        assert_eq!(db.database, Some("mydb".to_string()));
        assert!(!db.requires_transform);

        let schema = RoutingDecision::schema("mydb", "myschema");
        assert_eq!(schema.database, Some("mydb".to_string()));
        assert_eq!(schema.search_path, Some("myschema".to_string()));
        assert!(!schema.pre_query_commands.is_empty());

        let branch = RoutingDecision::branch("mybranch");
        assert_eq!(branch.branch, Some("mybranch".to_string()));

        let row = RoutingDecision::row_level("mydb");
        assert!(row.requires_transform);
    }

    #[test]
    fn test_database_isolation_handler() {
        let handler = DatabaseIsolationHandler::new();
        let config = create_test_config("tenant_a", IsolationStrategy::database("tenant_a_db"));

        let routing = handler.get_routing(&TenantId::new("tenant_a"), &config);
        assert_eq!(routing.database, Some("tenant_a_db".to_string()));
        assert_eq!(handler.strategy_name(), "database");
    }

    #[test]
    fn test_schema_isolation_handler() {
        let handler = SchemaIsolationHandler::new();
        let config = create_test_config(
            "tenant_a",
            IsolationStrategy::schema("shared_db", "tenant_a"),
        );

        let routing = handler.get_routing(&TenantId::new("tenant_a"), &config);
        assert_eq!(routing.database, Some("shared_db".to_string()));
        assert_eq!(routing.search_path, Some("tenant_a".to_string()));
        assert_eq!(handler.strategy_name(), "schema");

        let tenant = TenantId::new("tenant_a");
        assert!(handler.can_access_table(&tenant, "users", &config));
        assert!(handler.can_access_table(&tenant, "tenant_a.users", &config));
        assert!(!handler.can_access_table(&tenant, "tenant_b.users", &config));
    }

    #[test]
    fn test_row_isolation_handler() {
        let handler = RowIsolationHandler::new()
            .register_table("users", "tenant_id")
            .register_table("orders", "tenant_id");

        let config = create_test_config(
            "tenant_a",
            IsolationStrategy::row("shared_db", "tenant_id"),
        );

        let routing = handler.get_routing(&TenantId::new("tenant_a"), &config);
        assert_eq!(routing.database, Some("shared_db".to_string()));
        assert!(routing.requires_transform);
        assert_eq!(handler.strategy_name(), "row");
    }

    #[test]
    fn test_branch_isolation_handler() {
        let handler = BranchIsolationHandler::new();
        let config = create_test_config(
            "tenant_a",
            IsolationStrategy::branch("tenant_a_branch"),
        );

        let routing = handler.get_routing(&TenantId::new("tenant_a"), &config);
        assert_eq!(routing.branch, Some("tenant_a_branch".to_string()));
        assert_eq!(handler.strategy_name(), "branch");
    }

    #[test]
    fn test_handlers_ignore_mismatched_strategy() {
        // A handler asked to route a tenant configured for a different strategy
        // returns an empty decision.
        let config = create_test_config(
            "tenant_a",
            IsolationStrategy::branch("tenant_a_branch"),
        );
        let tenant = TenantId::new("tenant_a");

        let routing = DatabaseIsolationHandler::new().get_routing(&tenant, &config);
        assert_eq!(routing.database, None);
        assert_eq!(routing.branch, None);
        assert!(routing.pre_query_commands.is_empty());
        assert!(!routing.requires_transform);
    }

    #[test]
    fn test_create_handler_matches_strategy() {
        let cases = [
            (IsolationStrategy::database("db"), "database"),
            (IsolationStrategy::schema("shared", "s"), "schema"),
            (IsolationStrategy::row("shared", "tenant_id"), "row"),
            (IsolationStrategy::branch("b"), "branch"),
        ];
        for (strategy, expected) in cases.iter() {
            assert_eq!(create_handler(strategy).strategy_name(), *expected);
        }
    }

    #[test]
    fn test_isolation_router() {
        let router = IsolationRouter::new();

        let config_a = create_test_config("tenant_a", IsolationStrategy::database("tenant_a_db"));
        let config_b = create_test_config(
            "tenant_b",
            IsolationStrategy::schema("shared", "tenant_b"),
        );

        router.register_from_config(&config_a);
        router.register_from_config(&config_b);

        let routing_a = router.get_routing(&TenantId::new("tenant_a"), &config_a);
        assert_eq!(routing_a.database, Some("tenant_a_db".to_string()));

        let routing_b = router.get_routing(&TenantId::new("tenant_b"), &config_b);
        assert_eq!(routing_b.database, Some("shared".to_string()));
        assert_eq!(routing_b.search_path, Some("tenant_b".to_string()));
    }

    #[test]
    fn test_router_falls_back_to_default_handler() {
        // No handler is registered for tenant_c, so the default (schema) handler is used.
        let router = IsolationRouter::new();
        let config = create_test_config(
            "tenant_c",
            IsolationStrategy::schema("shared", "tenant_c"),
        );

        let routing = router.get_routing(&TenantId::new("tenant_c"), &config);
        assert_eq!(routing.database, Some("shared".to_string()));
        assert_eq!(routing.search_path, Some("tenant_c".to_string()));
    }

    #[test]
    fn test_tenant_provisioner() {
        let provisioner = TenantProvisioner::new();
        let tenant = TenantId::new("acme");

        assert_eq!(provisioner.generate_database_name(&tenant), "tenant_acme_db");
        assert_eq!(provisioner.generate_schema_name(&tenant), "tenant_acme");
        assert_eq!(provisioner.generate_branch_name(&tenant), "tenant_acme");

        let isolation = provisioner.generate_isolation(&tenant, "schema", Some("shared_db"));
        assert!(matches!(
            isolation,
            IsolationStrategy::Schema { database_name, schema_name }
            if database_name == "shared_db" && schema_name == "tenant_acme"
        ));

        let isolation = provisioner.generate_isolation(&tenant, "database", None);
        assert!(matches!(
            isolation,
            IsolationStrategy::Database { database_name }
            if database_name == "tenant_acme_db"
        ));
    }

    #[test]
    fn test_provisioner_sql_generation() {
        let provisioner = TenantProvisioner::new();
        let tenant = TenantId::new("acme");

        let db_sql = provisioner.sql_create_database(&tenant);
        assert!(!db_sql.is_empty());
        assert!(db_sql[0].contains("CREATE DATABASE"));

        let schema_sql = provisioner.sql_create_schema(&tenant, "shared");
        assert!(schema_sql.iter().any(|s| s.contains("CREATE SCHEMA")));

        let rls_sql = provisioner.sql_create_rls_policy(&tenant, "users", "tenant_id");
        assert!(rls_sql.iter().any(|s| s.contains("ROW LEVEL SECURITY")));
    }

    #[test]
    fn test_rls_policy_sql_exact() {
        let provisioner = TenantProvisioner::new();
        let tenant = TenantId::new("acme");

        let rls_sql = provisioner.sql_create_rls_policy(&tenant, "users", "tenant_id");
        assert_eq!(
            rls_sql,
            vec![
                "ALTER TABLE users ENABLE ROW LEVEL SECURITY".to_string(),
                "CREATE POLICY tenant_acme_policy ON users FOR ALL USING (tenant_id = 'acme')"
                    .to_string(),
            ]
        );
    }
}