1use std::collections::HashSet;
31use std::fmt::Write;
32
/// Configuration for PostgreSQL row-level-security (RLS) tenant isolation.
///
/// Build it with [`RlsConfig::new`] and the chained `with_*` / `add_*` methods,
/// or start from [`RlsConfig::default`] and set fields directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RlsConfig {
    /// Column on every protected table that stores the tenant identifier.
    pub tenant_column: String,
    /// Session setting (e.g. `app.current_tenant`) holding the active tenant id.
    pub session_variable: String,
    /// Role the generated policies are granted to; `None` means `PUBLIC`.
    pub application_role: Option<String>,
    /// Tables that should receive tenant-isolation policies.
    pub tables: HashSet<String>,
    /// Tables skipped when generating bulk setup SQL, even if listed in `tables`.
    pub excluded_tables: HashSet<String>,
    /// When `true` and an application role is set, setup SQL also creates a
    /// `<role>_admin` role with `BYPASSRLS`.
    pub allow_bypass: bool,
    /// Prefix for generated policy names (`<prefix>_<table>`).
    pub policy_prefix: String,
}

impl Default for RlsConfig {
    /// Column `tenant_id`, session variable `app.current_tenant`, no role
    /// (policies target `PUBLIC`), no tables, bypass role allowed, and policy
    /// prefix `tenant_isolation`.
    fn default() -> Self {
        Self {
            tenant_column: "tenant_id".to_string(),
            session_variable: "app.current_tenant".to_string(),
            application_role: None,
            tables: HashSet::new(),
            excluded_tables: HashSet::new(),
            allow_bypass: true,
            policy_prefix: "tenant_isolation".to_string(),
        }
    }
}

impl RlsConfig {
    /// Creates a config for `tenant_column`; every other field takes its
    /// [`Default`] value.
    #[must_use]
    pub fn new(tenant_column: impl Into<String>) -> Self {
        Self {
            tenant_column: tenant_column.into(),
            ..Default::default()
        }
    }

    /// Sets the session variable the policies compare the tenant column against.
    #[must_use]
    pub fn with_session_variable(mut self, var: impl Into<String>) -> Self {
        self.session_variable = var.into();
        self
    }

    /// Grants the generated policies to `role` instead of `PUBLIC`.
    #[must_use]
    pub fn with_role(mut self, role: impl Into<String>) -> Self {
        self.application_role = Some(role.into());
        self
    }

    /// Adds one table to the set of protected tables.
    #[must_use]
    pub fn add_table(mut self, table: impl Into<String>) -> Self {
        self.tables.insert(table.into());
        self
    }

    /// Adds every table yielded by `tables` to the set of protected tables.
    #[must_use]
    pub fn add_tables<I, S>(mut self, tables: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.tables.extend(tables.into_iter().map(Into::into));
        self
    }

    /// Marks `table` as excluded from bulk setup SQL.
    #[must_use]
    pub fn exclude_table(mut self, table: impl Into<String>) -> Self {
        self.excluded_tables.insert(table.into());
        self
    }

    /// Disables creation of the `<role>_admin` `BYPASSRLS` role in setup SQL.
    #[must_use]
    pub fn without_bypass(mut self) -> Self {
        self.allow_bypass = false;
        self
    }

    /// Sets the prefix used to build policy names.
    #[must_use]
    pub fn with_policy_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.policy_prefix = prefix.into();
        self
    }
}
121
122#[derive(Debug, Clone)]
124pub struct RlsManager {
125 config: RlsConfig,
126}
127
128impl RlsManager {
129 pub fn new(config: RlsConfig) -> Self {
131 Self { config }
132 }
133
134 pub fn simple(tenant_column: impl Into<String>, session_var: impl Into<String>) -> Self {
136 Self::new(RlsConfig::new(tenant_column).with_session_variable(session_var))
137 }
138
139 pub fn config(&self) -> &RlsConfig {
141 &self.config
142 }
143
144 pub fn enable_rls_sql(&self, table: &str) -> String {
146 format!(
147 "ALTER TABLE {} ENABLE ROW LEVEL SECURITY;",
148 quote_ident(table)
149 )
150 }
151
152 pub fn force_rls_sql(&self, table: &str) -> String {
154 format!(
155 "ALTER TABLE {} FORCE ROW LEVEL SECURITY;",
156 quote_ident(table)
157 )
158 }
159
160 pub fn create_policy_sql(&self, table: &str) -> String {
162 let policy_name = format!("{}_{}", self.config.policy_prefix, table);
163 let role = self.config.application_role.as_deref().unwrap_or("PUBLIC");
164
165 format!(
167 r#"CREATE POLICY {} ON {}
168 AS PERMISSIVE
169 FOR ALL
170 TO {}
171 USING ({} = current_setting('{}')::text)
172 WITH CHECK ({} = current_setting('{}')::text);"#,
173 quote_ident(&policy_name),
174 quote_ident(table),
175 role,
176 quote_ident(&self.config.tenant_column),
177 self.config.session_variable,
178 quote_ident(&self.config.tenant_column),
179 self.config.session_variable,
180 )
181 }
182
183 pub fn create_uuid_policy_sql(&self, table: &str) -> String {
185 let policy_name = format!("{}_{}", self.config.policy_prefix, table);
186 let role = self.config.application_role.as_deref().unwrap_or("PUBLIC");
187
188 format!(
189 r#"CREATE POLICY {} ON {}
190 AS PERMISSIVE
191 FOR ALL
192 TO {}
193 USING ({} = current_setting('{}')::uuid)
194 WITH CHECK ({} = current_setting('{}')::uuid);"#,
195 quote_ident(&policy_name),
196 quote_ident(table),
197 role,
198 quote_ident(&self.config.tenant_column),
199 self.config.session_variable,
200 quote_ident(&self.config.tenant_column),
201 self.config.session_variable,
202 )
203 }
204
205 pub fn drop_policy_sql(&self, table: &str) -> String {
207 let policy_name = format!("{}_{}", self.config.policy_prefix, table);
208 format!(
209 "DROP POLICY IF EXISTS {} ON {};",
210 quote_ident(&policy_name),
211 quote_ident(table)
212 )
213 }
214
215 pub fn set_tenant_sql(&self, tenant_id: &str) -> String {
217 format!(
218 "SET {} = '{}';",
219 self.config.session_variable,
220 tenant_id.replace('\'', "''")
221 )
222 }
223
224 pub fn set_tenant_local_sql(&self, tenant_id: &str) -> String {
226 format!(
227 "SET LOCAL {} = '{}';",
228 self.config.session_variable,
229 tenant_id.replace('\'', "''")
230 )
231 }
232
233 pub fn reset_tenant_sql(&self) -> String {
235 format!("RESET {};", self.config.session_variable)
236 }
237
238 pub fn current_tenant_sql(&self) -> String {
240 format!(
241 "SELECT current_setting('{}', true);",
242 self.config.session_variable
243 )
244 }
245
246 pub fn setup_sql(&self) -> String {
248 let mut sql = String::with_capacity(4096);
249
250 writeln!(sql, "-- Prax Multi-Tenant RLS Setup").unwrap();
252 writeln!(
253 sql,
254 "-- Generated for column: {}",
255 self.config.tenant_column
256 )
257 .unwrap();
258 writeln!(sql, "-- Session variable: {}", self.config.session_variable).unwrap();
259 writeln!(sql).unwrap();
260
261 if self.config.allow_bypass {
263 if let Some(ref role) = self.config.application_role {
264 writeln!(sql, "-- Admin role with BYPASSRLS").unwrap();
265 writeln!(sql, "DO $$").unwrap();
266 writeln!(sql, "BEGIN").unwrap();
267 writeln!(sql, " CREATE ROLE {}_admin WITH BYPASSRLS;", role).unwrap();
268 writeln!(sql, "EXCEPTION WHEN duplicate_object THEN NULL;").unwrap();
269 writeln!(sql, "END $$;").unwrap();
270 writeln!(sql).unwrap();
271 }
272 }
273
274 for table in &self.config.tables {
276 if self.config.excluded_tables.contains(table) {
277 continue;
278 }
279
280 writeln!(sql, "-- Table: {}", table).unwrap();
281 writeln!(sql, "{}", self.enable_rls_sql(table)).unwrap();
282 writeln!(sql, "{}", self.force_rls_sql(table)).unwrap();
283 writeln!(sql, "{}", self.drop_policy_sql(table)).unwrap();
284 writeln!(sql, "{}", self.create_policy_sql(table)).unwrap();
285 writeln!(sql).unwrap();
286 }
287
288 sql
289 }
290
291 pub fn migration_up_sql(&self, table: &str) -> String {
293 let mut sql = String::with_capacity(512);
294
295 writeln!(sql, "-- Enable RLS on {}", table).unwrap();
296 writeln!(sql, "{}", self.enable_rls_sql(table)).unwrap();
297 writeln!(sql, "{}", self.force_rls_sql(table)).unwrap();
298 writeln!(sql, "{}", self.create_policy_sql(table)).unwrap();
299
300 sql
301 }
302
303 pub fn migration_down_sql(&self, table: &str) -> String {
305 let mut sql = String::with_capacity(256);
306
307 writeln!(sql, "-- Disable RLS on {}", table).unwrap();
308 writeln!(sql, "{}", self.drop_policy_sql(table)).unwrap();
309 writeln!(
310 sql,
311 "ALTER TABLE {} DISABLE ROW LEVEL SECURITY;",
312 quote_ident(table)
313 )
314 .unwrap();
315
316 sql
317 }
318}
319
320#[derive(Default)]
322pub struct RlsManagerBuilder {
323 config: RlsConfig,
324}
325
326impl RlsManagerBuilder {
327 pub fn new() -> Self {
329 Self::default()
330 }
331
332 pub fn tenant_column(mut self, column: impl Into<String>) -> Self {
334 self.config.tenant_column = column.into();
335 self
336 }
337
338 pub fn session_variable(mut self, var: impl Into<String>) -> Self {
340 self.config.session_variable = var.into();
341 self
342 }
343
344 pub fn application_role(mut self, role: impl Into<String>) -> Self {
346 self.config.application_role = Some(role.into());
347 self
348 }
349
350 pub fn tables<I, S>(mut self, tables: I) -> Self
352 where
353 I: IntoIterator<Item = S>,
354 S: Into<String>,
355 {
356 self.config
357 .tables
358 .extend(tables.into_iter().map(Into::into));
359 self
360 }
361
362 pub fn exclude<I, S>(mut self, tables: I) -> Self
364 where
365 I: IntoIterator<Item = S>,
366 S: Into<String>,
367 {
368 self.config
369 .excluded_tables
370 .extend(tables.into_iter().map(Into::into));
371 self
372 }
373
374 pub fn policy_prefix(mut self, prefix: impl Into<String>) -> Self {
376 self.config.policy_prefix = prefix.into();
377 self
378 }
379
380 pub fn build(self) -> RlsManager {
382 RlsManager::new(self.config)
383 }
384}
385
/// A hand-specified row-level-security policy, rendered with `RlsPolicy::to_sql`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RlsPolicy {
    /// Policy name.
    pub name: String,
    /// Table the policy is attached to.
    pub table: String,
    /// Command the policy applies to (`FOR ...`).
    pub command: PolicyCommand,
    /// Role in the `TO` clause; `None` renders as `PUBLIC`.
    pub role: Option<String>,
    /// Raw SQL for the `USING` clause, emitted verbatim when present.
    pub using_expr: Option<String>,
    /// Raw SQL for the `WITH CHECK` clause, emitted verbatim when present.
    pub with_check_expr: Option<String>,
    /// `true` renders `AS PERMISSIVE`, `false` renders `AS RESTRICTIVE`.
    pub permissive: bool,
}

/// SQL command a policy applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PolicyCommand {
    All,
    Select,
    Insert,
    Update,
    Delete,
}

impl PolicyCommand {
    /// SQL keyword for this command, as used in `FOR <command>`.
    fn as_str(&self) -> &'static str {
        match self {
            Self::All => "ALL",
            Self::Select => "SELECT",
            Self::Insert => "INSERT",
            Self::Update => "UPDATE",
            Self::Delete => "DELETE",
        }
    }
}

impl std::fmt::Display for PolicyCommand {
    /// Writes the SQL keyword for the command (same text as `as_str`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
426
427impl RlsPolicy {
428 pub fn new(name: impl Into<String>, table: impl Into<String>) -> Self {
430 Self {
431 name: name.into(),
432 table: table.into(),
433 command: PolicyCommand::All,
434 role: None,
435 using_expr: None,
436 with_check_expr: None,
437 permissive: true,
438 }
439 }
440
441 pub fn command(mut self, cmd: PolicyCommand) -> Self {
443 self.command = cmd;
444 self
445 }
446
447 pub fn role(mut self, role: impl Into<String>) -> Self {
449 self.role = Some(role.into());
450 self
451 }
452
453 pub fn using(mut self, expr: impl Into<String>) -> Self {
455 self.using_expr = Some(expr.into());
456 self
457 }
458
459 pub fn with_check(mut self, expr: impl Into<String>) -> Self {
461 self.with_check_expr = Some(expr.into());
462 self
463 }
464
465 pub fn restrictive(mut self) -> Self {
467 self.permissive = false;
468 self
469 }
470
471 pub fn to_sql(&self) -> String {
473 let mut sql = String::with_capacity(256);
474
475 let policy_type = if self.permissive {
476 "PERMISSIVE"
477 } else {
478 "RESTRICTIVE"
479 };
480
481 write!(
482 sql,
483 "CREATE POLICY {} ON {}\n AS {}\n FOR {}\n TO {}",
484 quote_ident(&self.name),
485 quote_ident(&self.table),
486 policy_type,
487 self.command.as_str(),
488 self.role.as_deref().unwrap_or("PUBLIC"),
489 )
490 .unwrap();
491
492 if let Some(ref using) = self.using_expr {
493 write!(sql, "\n USING ({})", using).unwrap();
494 }
495
496 if let Some(ref check) = self.with_check_expr {
497 write!(sql, "\n WITH CHECK ({})", check).unwrap();
498 }
499
500 sql.push(';');
501 sql
502 }
503}
504
/// Quotes `name` as a PostgreSQL identifier when it cannot be used bare.
///
/// A name is left unquoted only if it meets all of these conditions:
/// - it is non-empty;
/// - it starts with a lowercase ASCII letter or `_`;
/// - it contains only lowercase ASCII letters, digits and `_`;
/// - it is not a reserved keyword.
///
/// Otherwise it is wrapped in double quotes with embedded `"` doubled. Quoting
/// preserves case, so `User` becomes `"User"`.
fn quote_ident(name: &str) -> String {
    // PostgreSQL keywords classified "reserved" or "reserved (can be function
    // or type)". These are not valid as bare table/column/policy names, e.g.
    // `ALTER TABLE order ...` is a syntax error. Quoting a lowercase name never
    // changes its meaning, so over-quoting is harmless.
    const RESERVED: &[&str] = &[
        "all", "analyse", "analyze", "and", "any", "array", "as", "asc",
        "asymmetric", "authorization", "binary", "both", "case", "cast", "check",
        "collate", "collation", "column", "concurrently", "constraint", "create",
        "cross", "current_catalog", "current_date", "current_role",
        "current_schema", "current_time", "current_timestamp", "current_user",
        "default", "deferrable", "desc", "distinct", "do", "else", "end",
        "except", "false", "fetch", "for", "foreign", "freeze", "from", "full",
        "grant", "group", "having", "ilike", "in", "initially", "inner",
        "intersect", "into", "is", "isnull", "join", "lateral", "leading",
        "left", "like", "limit", "localtime", "localtimestamp", "natural", "not",
        "notnull", "null", "offset", "on", "only", "or", "order", "outer",
        "overlaps", "placing", "primary", "references", "returning", "right",
        "select", "session_user", "similar", "some", "symmetric", "system_user",
        "table", "tablesample", "then", "to", "trailing", "true", "union",
        "unique", "user", "using", "variadic", "verbose", "when", "where",
        "window", "with",
    ];

    let mut chars = name.chars();
    let is_plain = match chars.next() {
        Some(first) => {
            (first.is_ascii_lowercase() || first == '_')
                && chars.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_')
        }
        None => false,
    };

    if is_plain && !RESERVED.contains(&name) {
        name.to_string()
    } else {
        format!("\"{}\"", name.replace('"', "\"\""))
    }
}
519
/// Pairs a transaction-scoped tenant switch with the statement that resets it.
///
/// Nothing is executed automatically. The caller runs the `SET LOCAL`
/// statement returned by [`TenantGuard::new`] and, when done, the statement
/// returned by [`TenantGuard::reset_sql`].
#[derive(Debug, Clone)]
pub struct TenantGuard {
    reset_sql: String,
}

impl TenantGuard {
    /// Builds the guard for `session_var`.
    ///
    /// Returns the guard together with the statement
    /// `SET LOCAL <session_var> = '<tenant_id>';`. Single quotes in
    /// `tenant_id` are doubled; `session_var` is emitted verbatim.
    #[must_use]
    pub fn new(session_var: &str, tenant_id: &str) -> (Self, String) {
        let set_sql = format!(
            "SET LOCAL {} = '{}';",
            session_var,
            tenant_id.replace('\'', "''")
        );
        let reset_sql = format!("RESET {};", session_var);

        (Self { reset_sql }, set_sql)
    }

    /// Returns `RESET <session_var>;` for the variable this guard was built for.
    pub fn reset_sql(&self) -> &str {
        &self.reset_sql
    }
}
548
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rls_config() {
        let config = RlsConfig::new("org_id")
            .with_session_variable("app.org")
            .with_role("app_user")
            .add_tables(["users", "orders", "products"]);

        assert_eq!(config.tenant_column, "org_id");
        assert_eq!(config.session_variable, "app.org");
        assert!(config.tables.contains("users"));
        assert!(config.tables.contains("orders"));
    }

    #[test]
    fn test_set_tenant_sql() {
        let manager = RlsManager::simple("tenant_id", "app.tenant");

        assert_eq!(
            manager.set_tenant_sql("tenant-123"),
            "SET app.tenant = 'tenant-123';"
        );

        // Embedded quotes must be doubled so the value cannot escape the literal.
        assert_eq!(
            manager.set_tenant_sql("'; DROP TABLE users; --"),
            "SET app.tenant = '''; DROP TABLE users; --';"
        );
    }

    #[test]
    fn test_session_helpers_sql() {
        let manager = RlsManager::simple("tenant_id", "app.tenant");

        assert_eq!(
            manager.set_tenant_local_sql("t1"),
            "SET LOCAL app.tenant = 't1';"
        );
        assert_eq!(manager.reset_tenant_sql(), "RESET app.tenant;");
        assert_eq!(
            manager.current_tenant_sql(),
            "SELECT current_setting('app.tenant', true);"
        );
    }

    #[test]
    fn test_create_policy_sql() {
        let manager = RlsManager::simple("tenant_id", "app.current_tenant");

        let sql = manager.create_policy_sql("users");
        assert!(sql.contains("CREATE POLICY"));
        assert!(sql.contains("tenant_id = current_setting('app.current_tenant')"));
    }

    #[test]
    fn test_create_uuid_policy_sql() {
        let manager = RlsManager::simple("tenant_id", "app.tenant");

        let sql = manager.create_uuid_policy_sql("users");
        assert!(sql.contains("CREATE POLICY tenant_isolation_users ON users"));
        assert!(sql.contains("tenant_id = current_setting('app.tenant')::uuid"));
        assert!(!sql.contains("::text"));
    }

    #[test]
    fn test_drop_policy_sql() {
        let manager = RlsManager::simple("tenant_id", "app.tenant");

        assert_eq!(
            manager.drop_policy_sql("users"),
            "DROP POLICY IF EXISTS tenant_isolation_users ON users;"
        );
    }

    #[test]
    fn test_setup_sql() {
        let config = RlsConfig::new("tenant_id")
            .with_session_variable("app.tenant")
            .add_tables(["users", "orders"]);

        let manager = RlsManager::new(config);
        let sql = manager.setup_sql();

        assert!(sql.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(sql.contains("FORCE ROW LEVEL SECURITY"));
        assert!(sql.contains("CREATE POLICY"));
    }

    #[test]
    fn test_setup_sql_skips_excluded_tables() {
        let config = RlsConfig::new("tenant_id")
            .add_tables(["users", "audit_log"])
            .exclude_table("audit_log");

        let sql = RlsManager::new(config).setup_sql();

        assert!(sql.contains("-- Table: users"));
        assert!(!sql.contains("audit_log"));
    }

    #[test]
    fn test_setup_sql_bypass_role() {
        let with_bypass = RlsConfig::new("tenant_id")
            .with_role("app_user")
            .add_table("users");
        let sql = RlsManager::new(with_bypass.clone()).setup_sql();
        assert!(sql.contains("CREATE ROLE app_user_admin WITH BYPASSRLS;"));

        let sql = RlsManager::new(with_bypass.without_bypass()).setup_sql();
        assert!(!sql.contains("BYPASSRLS"));
    }

    #[test]
    fn test_builder() {
        let manager = RlsManagerBuilder::new()
            .tenant_column("org_id")
            .session_variable("app.org")
            .application_role("app_user")
            .tables(["users", "orders"])
            .exclude(["orders"])
            .policy_prefix("iso")
            .build();

        let config = manager.config();
        assert_eq!(config.tenant_column, "org_id");
        assert_eq!(config.session_variable, "app.org");
        assert_eq!(config.application_role.as_deref(), Some("app_user"));
        assert!(config.tables.contains("users"));
        assert!(config.excluded_tables.contains("orders"));
        assert_eq!(config.policy_prefix, "iso");
    }

    #[test]
    fn test_custom_policy() {
        let policy = RlsPolicy::new("owner_access", "documents")
            .command(PolicyCommand::All)
            .role("app_user")
            .using("owner_id = current_user_id()")
            .with_check("owner_id = current_user_id()");

        let sql = policy.to_sql();
        assert!(sql.contains("CREATE POLICY owner_access"));
        assert!(sql.contains("FOR ALL"));
        assert!(sql.contains("USING (owner_id = current_user_id())"));
    }

    #[test]
    fn test_restrictive_policy() {
        let sql = RlsPolicy::new("deny_all", "secrets")
            .command(PolicyCommand::Select)
            .restrictive()
            .using("false")
            .to_sql();

        assert!(sql.contains("AS RESTRICTIVE"));
        assert!(sql.contains("FOR SELECT"));
        assert!(sql.contains("TO PUBLIC"));
        assert!(sql.contains("USING (false)"));
        assert!(!sql.contains("WITH CHECK"));
        assert!(sql.ends_with(';'));
    }

    #[test]
    fn test_migration_sql() {
        let manager = RlsManager::simple("tenant_id", "app.tenant");

        let up = manager.migration_up_sql("invoices");
        assert!(up.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(up.contains("CREATE POLICY"));

        let down = manager.migration_down_sql("invoices");
        assert!(down.contains("DROP POLICY"));
        assert!(down.contains("DISABLE ROW LEVEL SECURITY"));
    }

    #[test]
    fn test_tenant_guard() {
        let (guard, set_sql) = TenantGuard::new("app.tenant", "o'brien");

        assert_eq!(set_sql, "SET LOCAL app.tenant = 'o''brien';");
        assert_eq!(guard.reset_sql(), "RESET app.tenant;");
    }

    #[test]
    fn test_quote_ident() {
        assert_eq!(quote_ident("users"), "users");
        assert_eq!(quote_ident("user-data"), "\"user-data\"");
        assert_eq!(quote_ident("User"), "\"User\"");
        assert_eq!(quote_ident("table\"name"), "\"table\"\"name\"");
    }
}