1use std::collections::HashSet;
31use std::fmt::Write;
32
/// Settings that control how row-level-security SQL is generated for a
/// multi-tenant schema.
#[derive(Debug, Clone)]
pub struct RlsConfig {
    /// Column that stores the tenant identifier on each protected table.
    pub tenant_column: String,
    /// Session setting compared against `tenant_column` in generated policies.
    pub session_variable: String,
    /// Role the generated policies are granted to; `None` targets `PUBLIC`.
    pub application_role: Option<String>,
    /// Tables covered by the generated setup script.
    pub tables: HashSet<String>,
    /// Tables skipped by the setup script even if they appear in `tables`.
    pub excluded_tables: HashSet<String>,
    /// When `true` (and a role is configured), the setup script also creates
    /// an admin role `WITH BYPASSRLS`.
    pub allow_bypass: bool,
    /// Prefix of generated policy names, which look like `<prefix>_<table>`.
    pub policy_prefix: String,
}

impl Default for RlsConfig {
    /// Defaults: column `tenant_id`, setting `app.current_tenant`, no role,
    /// no tables, bypass allowed, policy prefix `tenant_isolation`.
    fn default() -> Self {
        RlsConfig {
            tenant_column: String::from("tenant_id"),
            session_variable: String::from("app.current_tenant"),
            application_role: None,
            tables: HashSet::default(),
            excluded_tables: HashSet::default(),
            allow_bypass: true,
            policy_prefix: String::from("tenant_isolation"),
        }
    }
}

impl RlsConfig {
    /// Creates a configuration for the given tenant column; every other
    /// field keeps its default value.
    pub fn new(tenant_column: impl Into<String>) -> Self {
        let tenant_column = tenant_column.into();
        RlsConfig {
            tenant_column,
            ..Self::default()
        }
    }

    /// Replaces the session variable that carries the current tenant.
    pub fn with_session_variable(self, var: impl Into<String>) -> Self {
        RlsConfig {
            session_variable: var.into(),
            ..self
        }
    }

    /// Sets the role the generated policies apply to.
    pub fn with_role(self, role: impl Into<String>) -> Self {
        RlsConfig {
            application_role: Some(role.into()),
            ..self
        }
    }

    /// Registers a single table.
    pub fn add_table(self, table: impl Into<String>) -> Self {
        self.add_tables(std::iter::once(table))
    }

    /// Registers every table yielded by `tables`; duplicates collapse
    /// because the tables are kept in a set.
    pub fn add_tables<I, S>(mut self, tables: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        for table in tables {
            self.tables.insert(table.into());
        }
        self
    }

    /// Marks a table as excluded from the setup script.
    pub fn exclude_table(self, table: impl Into<String>) -> Self {
        let mut config = self;
        config.excluded_tables.insert(table.into());
        config
    }

    /// Turns off generation of the `BYPASSRLS` admin role.
    pub fn without_bypass(self) -> Self {
        RlsConfig {
            allow_bypass: false,
            ..self
        }
    }

    /// Replaces the prefix used for generated policy names.
    pub fn with_policy_prefix(self, prefix: impl Into<String>) -> Self {
        RlsConfig {
            policy_prefix: prefix.into(),
            ..self
        }
    }
}
121
122#[derive(Debug, Clone)]
124pub struct RlsManager {
125 config: RlsConfig,
126}
127
128impl RlsManager {
129 pub fn new(config: RlsConfig) -> Self {
131 Self { config }
132 }
133
134 pub fn simple(tenant_column: impl Into<String>, session_var: impl Into<String>) -> Self {
136 Self::new(
137 RlsConfig::new(tenant_column)
138 .with_session_variable(session_var),
139 )
140 }
141
142 pub fn config(&self) -> &RlsConfig {
144 &self.config
145 }
146
147 pub fn enable_rls_sql(&self, table: &str) -> String {
149 format!("ALTER TABLE {} ENABLE ROW LEVEL SECURITY;", quote_ident(table))
150 }
151
152 pub fn force_rls_sql(&self, table: &str) -> String {
154 format!(
155 "ALTER TABLE {} FORCE ROW LEVEL SECURITY;",
156 quote_ident(table)
157 )
158 }
159
160 pub fn create_policy_sql(&self, table: &str) -> String {
162 let policy_name = format!("{}_{}", self.config.policy_prefix, table);
163 let role = self
164 .config
165 .application_role
166 .as_deref()
167 .unwrap_or("PUBLIC");
168
169 format!(
171 r#"CREATE POLICY {} ON {}
172 AS PERMISSIVE
173 FOR ALL
174 TO {}
175 USING ({} = current_setting('{}')::text)
176 WITH CHECK ({} = current_setting('{}')::text);"#,
177 quote_ident(&policy_name),
178 quote_ident(table),
179 role,
180 quote_ident(&self.config.tenant_column),
181 self.config.session_variable,
182 quote_ident(&self.config.tenant_column),
183 self.config.session_variable,
184 )
185 }
186
187 pub fn create_uuid_policy_sql(&self, table: &str) -> String {
189 let policy_name = format!("{}_{}", self.config.policy_prefix, table);
190 let role = self
191 .config
192 .application_role
193 .as_deref()
194 .unwrap_or("PUBLIC");
195
196 format!(
197 r#"CREATE POLICY {} ON {}
198 AS PERMISSIVE
199 FOR ALL
200 TO {}
201 USING ({} = current_setting('{}')::uuid)
202 WITH CHECK ({} = current_setting('{}')::uuid);"#,
203 quote_ident(&policy_name),
204 quote_ident(table),
205 role,
206 quote_ident(&self.config.tenant_column),
207 self.config.session_variable,
208 quote_ident(&self.config.tenant_column),
209 self.config.session_variable,
210 )
211 }
212
213 pub fn drop_policy_sql(&self, table: &str) -> String {
215 let policy_name = format!("{}_{}", self.config.policy_prefix, table);
216 format!(
217 "DROP POLICY IF EXISTS {} ON {};",
218 quote_ident(&policy_name),
219 quote_ident(table)
220 )
221 }
222
223 pub fn set_tenant_sql(&self, tenant_id: &str) -> String {
225 format!(
226 "SET {} = '{}';",
227 self.config.session_variable,
228 tenant_id.replace('\'', "''")
229 )
230 }
231
232 pub fn set_tenant_local_sql(&self, tenant_id: &str) -> String {
234 format!(
235 "SET LOCAL {} = '{}';",
236 self.config.session_variable,
237 tenant_id.replace('\'', "''")
238 )
239 }
240
241 pub fn reset_tenant_sql(&self) -> String {
243 format!("RESET {};", self.config.session_variable)
244 }
245
246 pub fn current_tenant_sql(&self) -> String {
248 format!(
249 "SELECT current_setting('{}', true);",
250 self.config.session_variable
251 )
252 }
253
254 pub fn setup_sql(&self) -> String {
256 let mut sql = String::with_capacity(4096);
257
258 writeln!(sql, "-- Prax Multi-Tenant RLS Setup").unwrap();
260 writeln!(sql, "-- Generated for column: {}", self.config.tenant_column).unwrap();
261 writeln!(sql, "-- Session variable: {}", self.config.session_variable).unwrap();
262 writeln!(sql).unwrap();
263
264 if self.config.allow_bypass {
266 if let Some(ref role) = self.config.application_role {
267 writeln!(sql, "-- Admin role with BYPASSRLS").unwrap();
268 writeln!(sql, "DO $$").unwrap();
269 writeln!(sql, "BEGIN").unwrap();
270 writeln!(
271 sql,
272 " CREATE ROLE {}_admin WITH BYPASSRLS;",
273 role
274 )
275 .unwrap();
276 writeln!(sql, "EXCEPTION WHEN duplicate_object THEN NULL;").unwrap();
277 writeln!(sql, "END $$;").unwrap();
278 writeln!(sql).unwrap();
279 }
280 }
281
282 for table in &self.config.tables {
284 if self.config.excluded_tables.contains(table) {
285 continue;
286 }
287
288 writeln!(sql, "-- Table: {}", table).unwrap();
289 writeln!(sql, "{}", self.enable_rls_sql(table)).unwrap();
290 writeln!(sql, "{}", self.force_rls_sql(table)).unwrap();
291 writeln!(sql, "{}", self.drop_policy_sql(table)).unwrap();
292 writeln!(sql, "{}", self.create_policy_sql(table)).unwrap();
293 writeln!(sql).unwrap();
294 }
295
296 sql
297 }
298
299 pub fn migration_up_sql(&self, table: &str) -> String {
301 let mut sql = String::with_capacity(512);
302
303 writeln!(sql, "-- Enable RLS on {}", table).unwrap();
304 writeln!(sql, "{}", self.enable_rls_sql(table)).unwrap();
305 writeln!(sql, "{}", self.force_rls_sql(table)).unwrap();
306 writeln!(sql, "{}", self.create_policy_sql(table)).unwrap();
307
308 sql
309 }
310
311 pub fn migration_down_sql(&self, table: &str) -> String {
313 let mut sql = String::with_capacity(256);
314
315 writeln!(sql, "-- Disable RLS on {}", table).unwrap();
316 writeln!(sql, "{}", self.drop_policy_sql(table)).unwrap();
317 writeln!(
318 sql,
319 "ALTER TABLE {} DISABLE ROW LEVEL SECURITY;",
320 quote_ident(table)
321 )
322 .unwrap();
323
324 sql
325 }
326}
327
328#[derive(Default)]
330pub struct RlsManagerBuilder {
331 config: RlsConfig,
332}
333
334impl RlsManagerBuilder {
335 pub fn new() -> Self {
337 Self::default()
338 }
339
340 pub fn tenant_column(mut self, column: impl Into<String>) -> Self {
342 self.config.tenant_column = column.into();
343 self
344 }
345
346 pub fn session_variable(mut self, var: impl Into<String>) -> Self {
348 self.config.session_variable = var.into();
349 self
350 }
351
352 pub fn application_role(mut self, role: impl Into<String>) -> Self {
354 self.config.application_role = Some(role.into());
355 self
356 }
357
358 pub fn tables<I, S>(mut self, tables: I) -> Self
360 where
361 I: IntoIterator<Item = S>,
362 S: Into<String>,
363 {
364 self.config.tables.extend(tables.into_iter().map(Into::into));
365 self
366 }
367
368 pub fn exclude<I, S>(mut self, tables: I) -> Self
370 where
371 I: IntoIterator<Item = S>,
372 S: Into<String>,
373 {
374 self.config
375 .excluded_tables
376 .extend(tables.into_iter().map(Into::into));
377 self
378 }
379
380 pub fn policy_prefix(mut self, prefix: impl Into<String>) -> Self {
382 self.config.policy_prefix = prefix.into();
383 self
384 }
385
386 pub fn build(self) -> RlsManager {
388 RlsManager::new(self.config)
389 }
390}
391
392#[derive(Debug, Clone)]
394pub struct RlsPolicy {
395 pub name: String,
397 pub table: String,
399 pub command: PolicyCommand,
401 pub role: Option<String>,
403 pub using_expr: Option<String>,
405 pub with_check_expr: Option<String>,
407 pub permissive: bool,
409}
410
/// Command class a row-level-security policy applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolicyCommand {
    /// Rendered as `ALL`.
    All,
    /// Rendered as `SELECT`.
    Select,
    /// Rendered as `INSERT`.
    Insert,
    /// Rendered as `UPDATE`.
    Update,
    /// Rendered as `DELETE`.
    Delete,
}

impl PolicyCommand {
    /// Keyword written after `FOR` in a generated `CREATE POLICY` statement.
    fn as_str(&self) -> &'static str {
        match *self {
            PolicyCommand::All => "ALL",
            PolicyCommand::Select => "SELECT",
            PolicyCommand::Insert => "INSERT",
            PolicyCommand::Update => "UPDATE",
            PolicyCommand::Delete => "DELETE",
        }
    }
}
432
433impl RlsPolicy {
434 pub fn new(name: impl Into<String>, table: impl Into<String>) -> Self {
436 Self {
437 name: name.into(),
438 table: table.into(),
439 command: PolicyCommand::All,
440 role: None,
441 using_expr: None,
442 with_check_expr: None,
443 permissive: true,
444 }
445 }
446
447 pub fn command(mut self, cmd: PolicyCommand) -> Self {
449 self.command = cmd;
450 self
451 }
452
453 pub fn role(mut self, role: impl Into<String>) -> Self {
455 self.role = Some(role.into());
456 self
457 }
458
459 pub fn using(mut self, expr: impl Into<String>) -> Self {
461 self.using_expr = Some(expr.into());
462 self
463 }
464
465 pub fn with_check(mut self, expr: impl Into<String>) -> Self {
467 self.with_check_expr = Some(expr.into());
468 self
469 }
470
471 pub fn restrictive(mut self) -> Self {
473 self.permissive = false;
474 self
475 }
476
477 pub fn to_sql(&self) -> String {
479 let mut sql = String::with_capacity(256);
480
481 let policy_type = if self.permissive {
482 "PERMISSIVE"
483 } else {
484 "RESTRICTIVE"
485 };
486
487 write!(
488 sql,
489 "CREATE POLICY {} ON {}\n AS {}\n FOR {}\n TO {}",
490 quote_ident(&self.name),
491 quote_ident(&self.table),
492 policy_type,
493 self.command.as_str(),
494 self.role.as_deref().unwrap_or("PUBLIC"),
495 )
496 .unwrap();
497
498 if let Some(ref using) = self.using_expr {
499 write!(sql, "\n USING ({})", using).unwrap();
500 }
501
502 if let Some(ref check) = self.with_check_expr {
503 write!(sql, "\n WITH CHECK ({})", check).unwrap();
504 }
505
506 sql.push(';');
507 sql
508 }
509}
510
/// Renders `name` as a PostgreSQL identifier, double-quoting it when needed.
///
/// The name is returned bare only when it is non-empty, starts with an ASCII
/// lowercase letter or `_`, continues with ASCII lowercase letters, digits
/// or `_`, and is not a reserved keyword. Anything else is wrapped in double
/// quotes with embedded `"` doubled. An empty name yields `""`.
fn quote_ident(name: &str) -> String {
    // PostgreSQL reserved keywords, including those usable only as function
    // or type names. A bare `user` or `order` is a syntax error where a table
    // name is expected. Listing a word here unnecessarily is harmless:
    // quoting an all-lowercase name still names the same object.
    const RESERVED: &[&str] = &[
        "all", "analyse", "analyze", "and", "any", "array", "as", "asc",
        "asymmetric", "authorization", "binary", "both", "case", "cast",
        "check", "collate", "collation", "column", "concurrently",
        "constraint", "create", "cross", "current_catalog", "current_date",
        "current_role", "current_schema", "current_time",
        "current_timestamp", "current_user", "default", "deferrable", "desc",
        "distinct", "do", "else", "end", "except", "false", "fetch", "for",
        "foreign", "freeze", "from", "full", "grant", "group", "having",
        "ilike", "in", "initially", "inner", "intersect", "into", "is",
        "isnull", "join", "lateral", "leading", "left", "like", "limit",
        "localtime", "localtimestamp", "natural", "not", "notnull", "null",
        "offset", "on", "only", "or", "order", "outer", "overlaps",
        "placing", "primary", "references", "returning", "right", "select",
        "session_user", "similar", "some", "symmetric", "system_user",
        "table", "tablesample", "then", "to", "trailing", "true", "union",
        "unique", "user", "using", "variadic", "verbose", "when", "where",
        "window", "with",
    ];

    let mut chars = name.chars();
    let is_plain = match chars.next() {
        // A leading digit (or any other character) forces quoting.
        Some(first) if first.is_ascii_lowercase() || first == '_' => {
            chars.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_')
        }
        _ => false,
    };

    if is_plain && !RESERVED.contains(&name) {
        name.to_string()
    } else {
        format!("\"{}\"", name.replace('"', "\"\""))
    }
}
525
/// Keeps the statement that undoes a tenant assignment produced by
/// [`TenantGuard::new`].
///
/// The guard only stores SQL text; running the statements is left to the
/// caller.
pub struct TenantGuard {
    reset_sql: String,
}

impl TenantGuard {
    /// Returns the guard together with the `SET LOCAL` statement that assigns
    /// `tenant_id` to `session_var`.
    ///
    /// Single quotes in `tenant_id` are doubled; `session_var` is inserted as
    /// given.
    pub fn new(session_var: &str, tenant_id: &str) -> (Self, String) {
        let escaped_tenant = tenant_id.replace('\'', "''");

        let guard = TenantGuard {
            reset_sql: ["RESET ", session_var, ";"].concat(),
        };
        let assign_sql = [
            "SET LOCAL ",
            session_var,
            " = '",
            escaped_tenant.as_str(),
            "';",
        ]
        .concat();

        (guard, assign_sql)
    }

    /// The `RESET` statement for the session variable given to `new`.
    pub fn reset_sql(&self) -> &str {
        self.reset_sql.as_str()
    }
}
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods on `RlsConfig` store what they are given.
    #[test]
    fn test_rls_config() {
        let built = RlsConfig::new("org_id")
            .with_session_variable("app.org")
            .with_role("app_user")
            .add_tables(["users", "orders", "products"]);

        assert_eq!(built.tenant_column, "org_id");
        assert_eq!(built.session_variable, "app.org");
        for expected in ["users", "orders"] {
            assert!(built.tables.contains(expected));
        }
    }

    /// `set_tenant_sql` embeds the tenant id as an escaped literal.
    #[test]
    fn test_set_tenant_sql() {
        let rls = RlsManager::simple("tenant_id", "app.tenant");

        let plain = rls.set_tenant_sql("tenant-123");
        assert_eq!(plain, "SET app.tenant = 'tenant-123';");

        // An embedded quote must be doubled so it stays inside the literal.
        let hostile = rls.set_tenant_sql("'; DROP TABLE users; --");
        assert_eq!(hostile, "SET app.tenant = '''; DROP TABLE users; --';");
    }

    /// The generated policy compares the tenant column with the setting.
    #[test]
    fn test_create_policy_sql() {
        let rls = RlsManager::simple("tenant_id", "app.current_tenant");
        let statement = rls.create_policy_sql("users");

        assert!(statement.contains("CREATE POLICY"));
        assert!(statement.contains("tenant_id = current_setting('app.current_tenant')"));
    }

    /// The setup script contains the enable, force and policy statements.
    #[test]
    fn test_setup_sql() {
        let rls = RlsManager::new(
            RlsConfig::new("tenant_id")
                .with_session_variable("app.tenant")
                .add_tables(["users", "orders"]),
        );
        let script = rls.setup_sql();

        assert!(script.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(script.contains("FORCE ROW LEVEL SECURITY"));
        assert!(script.contains("CREATE POLICY"));
    }

    /// A hand-built policy renders its name, command and expression.
    #[test]
    fn test_custom_policy() {
        let rendered = RlsPolicy::new("owner_access", "documents")
            .command(PolicyCommand::All)
            .role("app_user")
            .using("owner_id = current_user_id()")
            .with_check("owner_id = current_user_id()")
            .to_sql();

        assert!(rendered.contains("CREATE POLICY owner_access"));
        assert!(rendered.contains("FOR ALL"));
        assert!(rendered.contains("USING (owner_id = current_user_id())"));
    }

    /// Up and down migrations contain the expected statements.
    #[test]
    fn test_migration_sql() {
        let rls = RlsManager::simple("tenant_id", "app.tenant");

        let up_script = rls.migration_up_sql("invoices");
        assert!(up_script.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(up_script.contains("CREATE POLICY"));

        let down_script = rls.migration_down_sql("invoices");
        assert!(down_script.contains("DROP POLICY"));
        assert!(down_script.contains("DISABLE ROW LEVEL SECURITY"));
    }

    /// Identifier quoting: bare when simple, quoted and escaped otherwise.
    #[test]
    fn test_quote_ident() {
        let cases = [
            ("users", "users"),
            ("user-data", "\"user-data\""),
            ("User", "\"User\""),
            ("table\"name", "\"table\"\"name\""),
        ];
        for (input, expected) in cases {
            assert_eq!(quote_ident(input), expected);
        }
    }
}
646
647