1use std::collections::HashMap;
13use std::sync::Arc;
14
15use super::config::{IsolationStrategy, TenantConfig, TenantId};
16
/// Where and how a tenant's queries should be executed.
///
/// Produced by an [`IsolationHandler`]; fields a strategy does not use stay
/// `None` / empty / `false`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RoutingDecision {
    /// Database to connect to, when the strategy selects one.
    pub database: Option<String>,

    /// Schema name selected via `search_path` (schema-level isolation).
    pub search_path: Option<String>,

    /// Branch to route to (branch-level isolation).
    pub branch: Option<String>,

    /// SQL statements to execute on the connection before the tenant's query.
    pub pre_query_commands: Vec<String>,

    /// Set by [`RoutingDecision::row_level`]; presumably the caller rewrites
    /// queries (e.g. adds a tenant predicate) when this is `true` — TODO confirm
    /// against the query-transform code.
    pub requires_transform: bool,
}

impl RoutingDecision {
    /// Routes to a dedicated database named `name`.
    pub fn database(name: impl Into<String>) -> Self {
        Self {
            database: Some(name.into()),
            ..Self::default()
        }
    }

    /// Routes to `schema` inside `database` and emits a `SET search_path`
    /// command selecting that schema.
    ///
    /// The schema is written unquoted when it is a plain identifier, so
    /// PostgreSQL's usual case folding still applies exactly as before. Any
    /// other name (e.g. `tenant-a`) is emitted as a quoted, escaped identifier.
    /// That keeps the command valid and prevents the name from injecting
    /// additional SQL.
    pub fn schema(database: impl Into<String>, schema: impl Into<String>) -> Self {
        let schema_name = schema.into();
        let set_search_path = format!(
            "SET search_path TO {}",
            Self::quote_identifier(&schema_name)
        );
        Self {
            database: Some(database.into()),
            search_path: Some(schema_name),
            pre_query_commands: vec![set_search_path],
            ..Self::default()
        }
    }

    /// Routes to the branch `name`.
    pub fn branch(name: impl Into<String>) -> Self {
        Self {
            branch: Some(name.into()),
            ..Self::default()
        }
    }

    /// Routes to the shared `database` and marks the query as requiring a
    /// row-level transform.
    pub fn row_level(database: impl Into<String>) -> Self {
        Self {
            database: Some(database.into()),
            requires_transform: true,
            ..Self::default()
        }
    }

    /// Renders `name` as a PostgreSQL identifier.
    ///
    /// Plain identifiers (`[A-Za-z_][A-Za-z0-9_$]*`) are returned unchanged.
    /// Anything else is wrapped in double quotes, with embedded `"` doubled.
    fn quote_identifier(name: &str) -> String {
        let mut chars = name.chars();
        let is_plain = matches!(chars.next(), Some(c) if c.is_ascii_alphabetic() || c == '_')
            && chars.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '$');
        if is_plain {
            name.to_string()
        } else {
            format!("\"{}\"", name.replace('"', "\"\""))
        }
    }
}
85
86pub trait IsolationHandler: Send + Sync {
88 fn get_routing(&self, tenant: &TenantId, config: &TenantConfig) -> RoutingDecision;
90
91 fn can_access_table(&self, tenant: &TenantId, table: &str, config: &TenantConfig) -> bool;
93
94 fn strategy_name(&self) -> &'static str;
96}
97
/// Isolation handler for tenants that each own a dedicated database.
#[derive(Debug, Clone, Default)]
pub struct DatabaseIsolationHandler;

impl DatabaseIsolationHandler {
    /// Returns the (stateless) handler.
    pub fn new() -> Self {
        DatabaseIsolationHandler
    }
}
110
111impl IsolationHandler for DatabaseIsolationHandler {
112 fn get_routing(&self, _tenant: &TenantId, config: &TenantConfig) -> RoutingDecision {
113 if let IsolationStrategy::Database { database_name } = &config.isolation {
114 RoutingDecision::database(database_name)
115 } else {
116 RoutingDecision::default()
117 }
118 }
119
120 fn can_access_table(&self, _tenant: &TenantId, _table: &str, config: &TenantConfig) -> bool {
121 config.permissions.is_table_allowed(_table)
123 }
124
125 fn strategy_name(&self) -> &'static str {
126 "database"
127 }
128}
129
/// Isolation handler for tenants that share a database but each own a schema.
#[derive(Debug, Clone, Default)]
pub struct SchemaIsolationHandler;

impl SchemaIsolationHandler {
    /// Returns the (stateless) handler.
    pub fn new() -> Self {
        SchemaIsolationHandler
    }
}
142
143impl IsolationHandler for SchemaIsolationHandler {
144 fn get_routing(&self, _tenant: &TenantId, config: &TenantConfig) -> RoutingDecision {
145 if let IsolationStrategy::Schema {
146 database_name,
147 schema_name,
148 } = &config.isolation
149 {
150 RoutingDecision::schema(database_name, schema_name)
151 } else {
152 RoutingDecision::default()
153 }
154 }
155
156 fn can_access_table(&self, _tenant: &TenantId, table: &str, config: &TenantConfig) -> bool {
157 if let IsolationStrategy::Schema { schema_name, .. } = &config.isolation {
159 if let Some((schema, _)) = table.split_once('.') {
161 return schema.eq_ignore_ascii_case(schema_name)
162 && config.permissions.is_table_allowed(table);
163 }
164 }
165 config.permissions.is_table_allowed(table)
167 }
168
169 fn strategy_name(&self) -> &'static str {
170 "schema"
171 }
172}
173
/// Isolation handler for tenants that share tables, separated by a tenant
/// column on each registered table.
#[derive(Debug, Clone, Default)]
pub struct RowIsolationHandler {
    /// Maps a table name to the column that holds the tenant id on that table.
    tenant_tables: HashMap<String, String>,
}

impl RowIsolationHandler {
    /// Creates a handler with no registered tables.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `table` as tenant-scoped via `column`.
    ///
    /// Registering the same table again replaces its column. Consumes and
    /// returns `self` for chaining.
    #[must_use]
    pub fn register_table(mut self, table: impl Into<String>, column: impl Into<String>) -> Self {
        self.tenant_tables.insert(table.into(), column.into());
        self
    }

    /// Returns the tenant column registered for `table`, if any.
    pub fn tenant_column(&self, table: &str) -> Option<&str> {
        self.tenant_tables.get(table).map(String::as_str)
    }
}
195
196impl IsolationHandler for RowIsolationHandler {
197 fn get_routing(&self, _tenant: &TenantId, config: &TenantConfig) -> RoutingDecision {
198 if let IsolationStrategy::Row { database_name, .. } = &config.isolation {
199 RoutingDecision::row_level(database_name)
200 } else {
201 RoutingDecision::default()
202 }
203 }
204
205 fn can_access_table(&self, _tenant: &TenantId, table: &str, config: &TenantConfig) -> bool {
206 config.permissions.is_table_allowed(table)
207 }
208
209 fn strategy_name(&self) -> &'static str {
210 "row"
211 }
212}
213
/// Isolation handler for tenants that each live on their own branch.
#[derive(Debug, Clone, Default)]
pub struct BranchIsolationHandler;

impl BranchIsolationHandler {
    /// Returns the (stateless) handler.
    pub fn new() -> Self {
        BranchIsolationHandler
    }
}
226
227impl IsolationHandler for BranchIsolationHandler {
228 fn get_routing(&self, _tenant: &TenantId, config: &TenantConfig) -> RoutingDecision {
229 if let IsolationStrategy::Branch { branch_name } = &config.isolation {
230 RoutingDecision::branch(branch_name)
231 } else {
232 RoutingDecision::default()
233 }
234 }
235
236 fn can_access_table(&self, _tenant: &TenantId, _table: &str, config: &TenantConfig) -> bool {
237 config.permissions.is_table_allowed(_table)
239 }
240
241 fn strategy_name(&self) -> &'static str {
242 "branch"
243 }
244}
245
246pub fn create_handler(strategy: &IsolationStrategy) -> Arc<dyn IsolationHandler> {
248 match strategy {
249 IsolationStrategy::Database { .. } => Arc::new(DatabaseIsolationHandler::new()),
250 IsolationStrategy::Schema { .. } => Arc::new(SchemaIsolationHandler::new()),
251 IsolationStrategy::Row { .. } => Arc::new(RowIsolationHandler::new()),
252 IsolationStrategy::Branch { .. } => Arc::new(BranchIsolationHandler::new()),
253 }
254}
255
256pub struct IsolationRouter {
258 default_handler: Arc<dyn IsolationHandler>,
260
261 handlers: parking_lot::RwLock<HashMap<TenantId, Arc<dyn IsolationHandler>>>,
263}
264
265impl IsolationRouter {
266 pub fn new() -> Self {
268 Self {
269 default_handler: Arc::new(SchemaIsolationHandler::new()),
270 handlers: parking_lot::RwLock::new(HashMap::new()),
271 }
272 }
273
274 pub fn with_default_handler(mut self, handler: Arc<dyn IsolationHandler>) -> Self {
276 self.default_handler = handler;
277 self
278 }
279
280 pub fn register_tenant(&self, tenant: TenantId, handler: Arc<dyn IsolationHandler>) {
282 self.handlers.write().insert(tenant, handler);
283 }
284
285 pub fn register_from_config(&self, config: &TenantConfig) {
287 let handler = create_handler(&config.isolation);
288 self.handlers.write().insert(config.id.clone(), handler);
289 }
290
291 pub fn get_routing(&self, tenant: &TenantId, config: &TenantConfig) -> RoutingDecision {
293 let handlers = self.handlers.read();
294 handlers
295 .get(tenant)
296 .unwrap_or(&self.default_handler)
297 .get_routing(tenant, config)
298 }
299
300 pub fn can_access_table(&self, tenant: &TenantId, table: &str, config: &TenantConfig) -> bool {
302 let handlers = self.handlers.read();
303 handlers
304 .get(tenant)
305 .unwrap_or(&self.default_handler)
306 .can_access_table(tenant, table, config)
307 }
308}
309
310impl Default for IsolationRouter {
311 fn default() -> Self {
312 Self::new()
313 }
314}
315
/// Generates per-tenant resource names and the SQL that provisions them.
///
/// Each template substitutes every `{id}` occurrence with the tenant id.
#[derive(Debug, Clone)]
pub struct TenantProvisioner {
    /// Template for dedicated database names.
    database_template: String,

    /// Template for per-tenant schema names.
    schema_template: String,

    /// Template for per-tenant branch names.
    branch_template: String,
}
327
328impl Default for TenantProvisioner {
329 fn default() -> Self {
330 Self {
331 database_template: "tenant_{id}_db".to_string(),
332 schema_template: "tenant_{id}".to_string(),
333 branch_template: "tenant_{id}".to_string(),
334 }
335 }
336}
337
338impl TenantProvisioner {
339 pub fn new() -> Self {
341 Self::default()
342 }
343
344 pub fn database_template(mut self, template: impl Into<String>) -> Self {
346 self.database_template = template.into();
347 self
348 }
349
350 pub fn schema_template(mut self, template: impl Into<String>) -> Self {
352 self.schema_template = template.into();
353 self
354 }
355
356 pub fn branch_template(mut self, template: impl Into<String>) -> Self {
358 self.branch_template = template.into();
359 self
360 }
361
362 pub fn generate_database_name(&self, tenant: &TenantId) -> String {
364 self.database_template.replace("{id}", &tenant.0)
365 }
366
367 pub fn generate_schema_name(&self, tenant: &TenantId) -> String {
369 self.schema_template.replace("{id}", &tenant.0)
370 }
371
372 pub fn generate_branch_name(&self, tenant: &TenantId) -> String {
374 self.branch_template.replace("{id}", &tenant.0)
375 }
376
377 pub fn generate_isolation(
379 &self,
380 tenant: &TenantId,
381 strategy_type: &str,
382 shared_database: Option<&str>,
383 ) -> IsolationStrategy {
384 match strategy_type {
385 "database" => IsolationStrategy::database(self.generate_database_name(tenant)),
386 "schema" => IsolationStrategy::schema(
387 shared_database.unwrap_or("shared"),
388 self.generate_schema_name(tenant),
389 ),
390 "row" => IsolationStrategy::row(shared_database.unwrap_or("shared"), "tenant_id"),
391 "branch" => IsolationStrategy::branch(self.generate_branch_name(tenant)),
392 _ => IsolationStrategy::schema("public", self.generate_schema_name(tenant)),
393 }
394 }
395
396 pub fn sql_create_database(&self, tenant: &TenantId) -> Vec<String> {
398 let db_name = self.generate_database_name(tenant);
399 vec![
400 format!("CREATE DATABASE {} WITH OWNER = postgres", db_name),
401 format!("GRANT ALL PRIVILEGES ON DATABASE {} TO postgres", db_name),
402 ]
403 }
404
405 pub fn sql_create_schema(&self, tenant: &TenantId, database: &str) -> Vec<String> {
407 let schema_name = self.generate_schema_name(tenant);
408 vec![
409 format!("-- Connect to database: {}", database),
410 format!("CREATE SCHEMA IF NOT EXISTS {}", schema_name),
411 format!("GRANT ALL ON SCHEMA {} TO postgres", schema_name),
412 ]
413 }
414
415 pub fn sql_create_rls_policy(
417 &self,
418 tenant: &TenantId,
419 table: &str,
420 tenant_column: &str,
421 ) -> Vec<String> {
422 let policy_name = format!("tenant_{}_policy", tenant.0);
423 vec![
424 format!("ALTER TABLE {} ENABLE ROW LEVEL SECURITY", table),
425 format!(
426 "CREATE POLICY {} ON {} FOR ALL USING ({} = '{}')",
427 policy_name, table, tenant_column, tenant.0
428 ),
429 ]
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multi_tenancy::config::TenantConfig;

    /// Builds a minimal tenant config with the given id and isolation strategy.
    fn create_test_config(id: &str, isolation: IsolationStrategy) -> TenantConfig {
        TenantConfig::builder()
            .id(id)
            .name(format!("Test {}", id))
            .isolation(isolation)
            .build()
    }

    #[test]
    fn test_routing_decision() {
        let db = RoutingDecision::database("mydb");
        assert_eq!(db.database, Some("mydb".to_string()));
        assert!(!db.requires_transform);

        let schema = RoutingDecision::schema("mydb", "myschema");
        assert_eq!(schema.database, Some("mydb".to_string()));
        assert_eq!(schema.search_path, Some("myschema".to_string()));
        assert!(!schema.pre_query_commands.is_empty());

        let branch = RoutingDecision::branch("mybranch");
        assert_eq!(branch.branch, Some("mybranch".to_string()));

        let row = RoutingDecision::row_level("mydb");
        assert!(row.requires_transform);
    }

    #[test]
    fn test_database_isolation_handler() {
        let handler = DatabaseIsolationHandler::new();
        let config = create_test_config("tenant_a", IsolationStrategy::database("tenant_a_db"));

        let routing = handler.get_routing(&TenantId::new("tenant_a"), &config);
        assert_eq!(routing.database, Some("tenant_a_db".to_string()));
        assert_eq!(handler.strategy_name(), "database");
    }

    #[test]
    fn test_schema_isolation_handler() {
        let handler = SchemaIsolationHandler::new();
        let config = create_test_config(
            "tenant_a",
            IsolationStrategy::schema("shared_db", "tenant_a"),
        );

        let routing = handler.get_routing(&TenantId::new("tenant_a"), &config);
        assert_eq!(routing.database, Some("shared_db".to_string()));
        assert_eq!(routing.search_path, Some("tenant_a".to_string()));
        assert_eq!(handler.strategy_name(), "schema");

        let tenant = TenantId::new("tenant_a");
        assert!(handler.can_access_table(&tenant, "users", &config));
        assert!(handler.can_access_table(&tenant, "tenant_a.users", &config));
        assert!(!handler.can_access_table(&tenant, "tenant_b.users", &config));
    }

    #[test]
    fn test_row_isolation_handler() {
        let handler = RowIsolationHandler::new()
            .register_table("users", "tenant_id")
            .register_table("orders", "tenant_id");

        let config =
            create_test_config("tenant_a", IsolationStrategy::row("shared_db", "tenant_id"));

        let routing = handler.get_routing(&TenantId::new("tenant_a"), &config);
        assert_eq!(routing.database, Some("shared_db".to_string()));
        assert!(routing.requires_transform);
        assert_eq!(handler.strategy_name(), "row");
    }

    #[test]
    fn test_branch_isolation_handler() {
        let handler = BranchIsolationHandler::new();
        let config = create_test_config("tenant_a", IsolationStrategy::branch("tenant_a_branch"));

        let routing = handler.get_routing(&TenantId::new("tenant_a"), &config);
        assert_eq!(routing.branch, Some("tenant_a_branch".to_string()));
        assert_eq!(handler.strategy_name(), "branch");
    }

    #[test]
    fn test_isolation_router() {
        let router = IsolationRouter::new();

        let config_a = create_test_config("tenant_a", IsolationStrategy::database("tenant_a_db"));
        let config_b =
            create_test_config("tenant_b", IsolationStrategy::schema("shared", "tenant_b"));

        router.register_from_config(&config_a);
        router.register_from_config(&config_b);

        let routing_a = router.get_routing(&TenantId::new("tenant_a"), &config_a);
        assert_eq!(routing_a.database, Some("tenant_a_db".to_string()));

        let routing_b = router.get_routing(&TenantId::new("tenant_b"), &config_b);
        assert_eq!(routing_b.database, Some("shared".to_string()));
        assert_eq!(routing_b.search_path, Some("tenant_b".to_string()));
    }

    #[test]
    fn test_router_falls_back_to_default_handler() {
        let router = IsolationRouter::new();
        let config =
            create_test_config("tenant_c", IsolationStrategy::schema("shared", "tenant_c"));

        // Nothing is registered for tenant_c, so the default (schema) handler answers.
        let routing = router.get_routing(&TenantId::new("tenant_c"), &config);
        assert_eq!(routing.database, Some("shared".to_string()));
        assert_eq!(routing.search_path, Some("tenant_c".to_string()));
    }

    #[test]
    fn test_tenant_provisioner() {
        let provisioner = TenantProvisioner::new();
        let tenant = TenantId::new("acme");

        assert_eq!(
            provisioner.generate_database_name(&tenant),
            "tenant_acme_db"
        );
        assert_eq!(provisioner.generate_schema_name(&tenant), "tenant_acme");
        assert_eq!(provisioner.generate_branch_name(&tenant), "tenant_acme");

        let isolation = provisioner.generate_isolation(&tenant, "schema", Some("shared_db"));
        assert!(matches!(
            isolation,
            IsolationStrategy::Schema { database_name, schema_name }
            if database_name == "shared_db" && schema_name == "tenant_acme"
        ));
    }

    #[test]
    fn test_provisioner_generate_isolation_variants() {
        let provisioner = TenantProvisioner::new();
        let tenant = TenantId::new("acme");

        assert!(matches!(
            provisioner.generate_isolation(&tenant, "database", None),
            IsolationStrategy::Database { database_name } if database_name == "tenant_acme_db"
        ));
        assert!(matches!(
            provisioner.generate_isolation(&tenant, "branch", None),
            IsolationStrategy::Branch { branch_name } if branch_name == "tenant_acme"
        ));
        assert!(matches!(
            provisioner.generate_isolation(&tenant, "row", None),
            IsolationStrategy::Row { database_name, .. } if database_name == "shared"
        ));
    }

    #[test]
    fn test_provisioner_sql_generation() {
        let provisioner = TenantProvisioner::new();
        let tenant = TenantId::new("acme");

        let db_sql = provisioner.sql_create_database(&tenant);
        assert!(!db_sql.is_empty());
        assert!(db_sql[0].contains("CREATE DATABASE"));

        let schema_sql = provisioner.sql_create_schema(&tenant, "shared");
        assert!(schema_sql.iter().any(|s| s.contains("CREATE SCHEMA")));

        let rls_sql = provisioner.sql_create_rls_policy(&tenant, "users", "tenant_id");
        assert!(rls_sql.iter().any(|s| s.contains("ROW LEVEL SECURITY")));
    }
}