Skip to main content

prax_query/tenant/
rls.rs

1//! PostgreSQL Row-Level Security (RLS) integration.
2//!
3//! This module provides high-performance RLS support for multi-tenant applications
4//! using PostgreSQL's native row-level security features.
5//!
6//! # Performance Benefits
7//!
8//! Using database-level RLS provides:
9//! - **Zero application overhead** - Filtering happens in the database engine
10//! - **Guaranteed isolation** - Even raw SQL queries are filtered
11//! - **Index utilization** - RLS policies can use indexes efficiently
12//! - **Prepared statement caching** - Same statements work for all tenants
13//!
//! # Example
//!
//! ```rust,ignore
//! use prax_query::tenant::rls::{RlsConfig, RlsManager};
//!
//! // Create an RLS manager for the tenant-scoped tables
//! let rls = RlsManager::new(
//!     RlsConfig::new("tenant_id")
//!         .with_session_variable("app.current_tenant")
//!         .add_tables(["users", "orders", "products"]),
//! );
//!
//! // Generate setup SQL for every configured table
//! let setup = rls.setup_sql();
//! conn.execute_batch(&setup).await?;
//!
//! // Set tenant context for the session
//! conn.execute_batch(&rls.set_tenant_sql("tenant-123")).await?;
//! ```
29
30use std::collections::HashSet;
31use std::fmt::Write;
32
/// Settings that drive PostgreSQL row-level-security SQL generation.
///
/// Start from [`RlsConfig::new`] (or `Default`) and refine the result with the
/// chained `with_*`, `add_*`, `exclude_table` and `without_bypass` methods.
#[derive(Debug, Clone)]
pub struct RlsConfig {
    /// Column that stores each row's tenant identifier.
    pub tenant_column: String,
    /// Custom PostgreSQL setting holding the active tenant (e.g. `"app.current_tenant"`).
    pub session_variable: String,
    /// Role the generated policies target; `None` targets `PUBLIC`.
    pub application_role: Option<String>,
    /// Tables that should receive RLS.
    pub tables: HashSet<String>,
    /// Tables skipped during setup even when listed in `tables` (e.g. shared lookup tables).
    pub excluded_tables: HashSet<String>,
    /// When `true` and an application role is set, setup also creates a companion
    /// `<role>_admin` role with `BYPASSRLS` for admin operations.
    pub allow_bypass: bool,
    /// Prefix joined to the table name with `_` to form each policy name.
    pub policy_prefix: String,
}

impl Default for RlsConfig {
    /// Column `tenant_id`, variable `app.current_tenant`, no role, no tables,
    /// bypass allowed, policy prefix `tenant_isolation`.
    fn default() -> Self {
        let tenant_column = String::from("tenant_id");
        let session_variable = String::from("app.current_tenant");
        let policy_prefix = String::from("tenant_isolation");

        RlsConfig {
            tenant_column,
            session_variable,
            application_role: None,
            tables: HashSet::new(),
            excluded_tables: HashSet::new(),
            allow_bypass: true,
            policy_prefix,
        }
    }
}

impl RlsConfig {
    /// Build a default configuration that uses `tenant_column` as the tenant column.
    pub fn new(tenant_column: impl Into<String>) -> Self {
        let mut config = Self::default();
        config.tenant_column = tenant_column.into();
        config
    }

    /// Replace the session variable name.
    pub fn with_session_variable(self, var: impl Into<String>) -> Self {
        Self {
            session_variable: var.into(),
            ..self
        }
    }

    /// Target the generated policies at `role` instead of `PUBLIC`.
    pub fn with_role(self, role: impl Into<String>) -> Self {
        Self {
            application_role: Some(role.into()),
            ..self
        }
    }

    /// Register one table for RLS.
    pub fn add_table(mut self, table: impl Into<String>) -> Self {
        let name: String = table.into();
        self.tables.insert(name);
        self
    }

    /// Register every table yielded by `tables` for RLS.
    pub fn add_tables<I, S>(mut self, tables: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        for table in tables {
            self.tables.insert(table.into());
        }
        self
    }

    /// Keep `table` out of RLS setup.
    pub fn exclude_table(mut self, table: impl Into<String>) -> Self {
        let name: String = table.into();
        self.excluded_tables.insert(name);
        self
    }

    /// Do not create the `BYPASSRLS` admin role during setup.
    pub fn without_bypass(self) -> Self {
        Self {
            allow_bypass: false,
            ..self
        }
    }

    /// Replace the policy-name prefix.
    pub fn with_policy_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            policy_prefix: prefix.into(),
            ..self
        }
    }
}
121
122/// Manager for PostgreSQL RLS operations.
123#[derive(Debug, Clone)]
124pub struct RlsManager {
125    config: RlsConfig,
126}
127
128impl RlsManager {
129    /// Create a new RLS manager with the given config.
130    pub fn new(config: RlsConfig) -> Self {
131        Self { config }
132    }
133
134    /// Create with simple defaults.
135    pub fn simple(tenant_column: impl Into<String>, session_var: impl Into<String>) -> Self {
136        Self::new(RlsConfig::new(tenant_column).with_session_variable(session_var))
137    }
138
139    /// Get the config.
140    pub fn config(&self) -> &RlsConfig {
141        &self.config
142    }
143
144    /// Generate SQL to enable RLS on a table.
145    pub fn enable_rls_sql(&self, table: &str) -> String {
146        format!(
147            "ALTER TABLE {} ENABLE ROW LEVEL SECURITY;",
148            quote_ident(table)
149        )
150    }
151
152    /// Generate SQL to force RLS even for table owners.
153    pub fn force_rls_sql(&self, table: &str) -> String {
154        format!(
155            "ALTER TABLE {} FORCE ROW LEVEL SECURITY;",
156            quote_ident(table)
157        )
158    }
159
160    /// Generate SQL for the tenant isolation policy.
161    pub fn create_policy_sql(&self, table: &str) -> String {
162        let policy_name = format!("{}_{}", self.config.policy_prefix, table);
163        let role = self.config.application_role.as_deref().unwrap_or("PUBLIC");
164
165        // Create policy that filters by tenant_id = current_setting('app.current_tenant')
166        format!(
167            r#"CREATE POLICY {} ON {}
168    AS PERMISSIVE
169    FOR ALL
170    TO {}
171    USING ({} = current_setting('{}')::text)
172    WITH CHECK ({} = current_setting('{}')::text);"#,
173            quote_ident(&policy_name),
174            quote_ident(table),
175            role,
176            quote_ident(&self.config.tenant_column),
177            self.config.session_variable,
178            quote_ident(&self.config.tenant_column),
179            self.config.session_variable,
180        )
181    }
182
183    /// Generate SQL for UUID tenant columns.
184    pub fn create_uuid_policy_sql(&self, table: &str) -> String {
185        let policy_name = format!("{}_{}", self.config.policy_prefix, table);
186        let role = self.config.application_role.as_deref().unwrap_or("PUBLIC");
187
188        format!(
189            r#"CREATE POLICY {} ON {}
190    AS PERMISSIVE
191    FOR ALL
192    TO {}
193    USING ({} = current_setting('{}')::uuid)
194    WITH CHECK ({} = current_setting('{}')::uuid);"#,
195            quote_ident(&policy_name),
196            quote_ident(table),
197            role,
198            quote_ident(&self.config.tenant_column),
199            self.config.session_variable,
200            quote_ident(&self.config.tenant_column),
201            self.config.session_variable,
202        )
203    }
204
205    /// Generate SQL to drop a policy.
206    pub fn drop_policy_sql(&self, table: &str) -> String {
207        let policy_name = format!("{}_{}", self.config.policy_prefix, table);
208        format!(
209            "DROP POLICY IF EXISTS {} ON {};",
210            quote_ident(&policy_name),
211            quote_ident(table)
212        )
213    }
214
215    /// Generate SQL to set the current tenant for a session.
216    pub fn set_tenant_sql(&self, tenant_id: &str) -> String {
217        format!(
218            "SET {} = '{}';",
219            self.config.session_variable,
220            tenant_id.replace('\'', "''")
221        )
222    }
223
224    /// Generate SQL to set the current tenant locally (transaction only).
225    pub fn set_tenant_local_sql(&self, tenant_id: &str) -> String {
226        format!(
227            "SET LOCAL {} = '{}';",
228            self.config.session_variable,
229            tenant_id.replace('\'', "''")
230        )
231    }
232
233    /// Generate SQL to reset the tenant context.
234    pub fn reset_tenant_sql(&self) -> String {
235        format!("RESET {};", self.config.session_variable)
236    }
237
238    /// Generate SQL to check the current tenant.
239    pub fn current_tenant_sql(&self) -> String {
240        format!(
241            "SELECT current_setting('{}', true);",
242            self.config.session_variable
243        )
244    }
245
246    /// Generate complete setup SQL for all configured tables.
247    pub fn setup_sql(&self) -> String {
248        let mut sql = String::with_capacity(4096);
249
250        // Header
251        writeln!(sql, "-- Prax Multi-Tenant RLS Setup").unwrap();
252        writeln!(
253            sql,
254            "-- Generated for column: {}",
255            self.config.tenant_column
256        )
257        .unwrap();
258        writeln!(sql, "-- Session variable: {}", self.config.session_variable).unwrap();
259        writeln!(sql).unwrap();
260
261        // Create admin role if bypass is enabled
262        if self.config.allow_bypass {
263            if let Some(ref role) = self.config.application_role {
264                writeln!(sql, "-- Admin role with BYPASSRLS").unwrap();
265                writeln!(sql, "DO $$").unwrap();
266                writeln!(sql, "BEGIN").unwrap();
267                writeln!(sql, "    CREATE ROLE {}_admin WITH BYPASSRLS;", role).unwrap();
268                writeln!(sql, "EXCEPTION WHEN duplicate_object THEN NULL;").unwrap();
269                writeln!(sql, "END $$;").unwrap();
270                writeln!(sql).unwrap();
271            }
272        }
273
274        // Enable RLS and create policies for each table
275        for table in &self.config.tables {
276            if self.config.excluded_tables.contains(table) {
277                continue;
278            }
279
280            writeln!(sql, "-- Table: {}", table).unwrap();
281            writeln!(sql, "{}", self.enable_rls_sql(table)).unwrap();
282            writeln!(sql, "{}", self.force_rls_sql(table)).unwrap();
283            writeln!(sql, "{}", self.drop_policy_sql(table)).unwrap();
284            writeln!(sql, "{}", self.create_policy_sql(table)).unwrap();
285            writeln!(sql).unwrap();
286        }
287
288        sql
289    }
290
291    /// Generate migration SQL to add RLS to a new table.
292    pub fn migration_up_sql(&self, table: &str) -> String {
293        let mut sql = String::with_capacity(512);
294
295        writeln!(sql, "-- Enable RLS on {}", table).unwrap();
296        writeln!(sql, "{}", self.enable_rls_sql(table)).unwrap();
297        writeln!(sql, "{}", self.force_rls_sql(table)).unwrap();
298        writeln!(sql, "{}", self.create_policy_sql(table)).unwrap();
299
300        sql
301    }
302
303    /// Generate migration SQL to remove RLS from a table.
304    pub fn migration_down_sql(&self, table: &str) -> String {
305        let mut sql = String::with_capacity(256);
306
307        writeln!(sql, "-- Disable RLS on {}", table).unwrap();
308        writeln!(sql, "{}", self.drop_policy_sql(table)).unwrap();
309        writeln!(
310            sql,
311            "ALTER TABLE {} DISABLE ROW LEVEL SECURITY;",
312            quote_ident(table)
313        )
314        .unwrap();
315
316        sql
317    }
318}
319
320/// Builder for RLS manager.
321#[derive(Default)]
322pub struct RlsManagerBuilder {
323    config: RlsConfig,
324}
325
326impl RlsManagerBuilder {
327    /// Create a new builder.
328    pub fn new() -> Self {
329        Self::default()
330    }
331
332    /// Set the tenant column.
333    pub fn tenant_column(mut self, column: impl Into<String>) -> Self {
334        self.config.tenant_column = column.into();
335        self
336    }
337
338    /// Set the session variable.
339    pub fn session_variable(mut self, var: impl Into<String>) -> Self {
340        self.config.session_variable = var.into();
341        self
342    }
343
344    /// Set the application role.
345    pub fn application_role(mut self, role: impl Into<String>) -> Self {
346        self.config.application_role = Some(role.into());
347        self
348    }
349
350    /// Add tables.
351    pub fn tables<I, S>(mut self, tables: I) -> Self
352    where
353        I: IntoIterator<Item = S>,
354        S: Into<String>,
355    {
356        self.config
357            .tables
358            .extend(tables.into_iter().map(Into::into));
359        self
360    }
361
362    /// Exclude tables.
363    pub fn exclude<I, S>(mut self, tables: I) -> Self
364    where
365        I: IntoIterator<Item = S>,
366        S: Into<String>,
367    {
368        self.config
369            .excluded_tables
370            .extend(tables.into_iter().map(Into::into));
371        self
372    }
373
374    /// Set policy prefix.
375    pub fn policy_prefix(mut self, prefix: impl Into<String>) -> Self {
376        self.config.policy_prefix = prefix.into();
377        self
378    }
379
380    /// Build the manager.
381    pub fn build(self) -> RlsManager {
382        RlsManager::new(self.config)
383    }
384}
385
386/// Represents a custom RLS policy.
387#[derive(Debug, Clone)]
388pub struct RlsPolicy {
389    /// Policy name.
390    pub name: String,
391    /// Table the policy applies to.
392    pub table: String,
393    /// Command the policy applies to (ALL, SELECT, INSERT, UPDATE, DELETE).
394    pub command: PolicyCommand,
395    /// Role the policy applies to.
396    pub role: Option<String>,
397    /// USING expression (for SELECT, UPDATE, DELETE).
398    pub using_expr: Option<String>,
399    /// WITH CHECK expression (for INSERT, UPDATE).
400    pub with_check_expr: Option<String>,
401    /// Whether this is a permissive or restrictive policy.
402    pub permissive: bool,
403}
404
/// The SQL command(s) governed by a policy, rendered in its `FOR` clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolicyCommand {
    /// Every command (`FOR ALL`).
    All,
    /// `FOR SELECT`.
    Select,
    /// `FOR INSERT`.
    Insert,
    /// `FOR UPDATE`.
    Update,
    /// `FOR DELETE`.
    Delete,
}

impl PolicyCommand {
    /// The keyword written after `FOR` in `CREATE POLICY`.
    fn as_str(&self) -> &'static str {
        match *self {
            PolicyCommand::All => "ALL",
            PolicyCommand::Select => "SELECT",
            PolicyCommand::Insert => "INSERT",
            PolicyCommand::Update => "UPDATE",
            PolicyCommand::Delete => "DELETE",
        }
    }
}
426
427impl RlsPolicy {
428    /// Create a new policy.
429    pub fn new(name: impl Into<String>, table: impl Into<String>) -> Self {
430        Self {
431            name: name.into(),
432            table: table.into(),
433            command: PolicyCommand::All,
434            role: None,
435            using_expr: None,
436            with_check_expr: None,
437            permissive: true,
438        }
439    }
440
441    /// Set the command.
442    pub fn command(mut self, cmd: PolicyCommand) -> Self {
443        self.command = cmd;
444        self
445    }
446
447    /// Set the role.
448    pub fn role(mut self, role: impl Into<String>) -> Self {
449        self.role = Some(role.into());
450        self
451    }
452
453    /// Set the USING expression.
454    pub fn using(mut self, expr: impl Into<String>) -> Self {
455        self.using_expr = Some(expr.into());
456        self
457    }
458
459    /// Set the WITH CHECK expression.
460    pub fn with_check(mut self, expr: impl Into<String>) -> Self {
461        self.with_check_expr = Some(expr.into());
462        self
463    }
464
465    /// Make this a restrictive policy.
466    pub fn restrictive(mut self) -> Self {
467        self.permissive = false;
468        self
469    }
470
471    /// Generate the CREATE POLICY SQL.
472    pub fn to_sql(&self) -> String {
473        let mut sql = String::with_capacity(256);
474
475        let policy_type = if self.permissive {
476            "PERMISSIVE"
477        } else {
478            "RESTRICTIVE"
479        };
480
481        write!(
482            sql,
483            "CREATE POLICY {} ON {}\n    AS {}\n    FOR {}\n    TO {}",
484            quote_ident(&self.name),
485            quote_ident(&self.table),
486            policy_type,
487            self.command.as_str(),
488            self.role.as_deref().unwrap_or("PUBLIC"),
489        )
490        .unwrap();
491
492        if let Some(ref using) = self.using_expr {
493            write!(sql, "\n    USING ({})", using).unwrap();
494        }
495
496        if let Some(ref check) = self.with_check_expr {
497            write!(sql, "\n    WITH CHECK ({})", check).unwrap();
498        }
499
500        sql.push(';');
501        sql
502    }
503}
504
/// Quote a PostgreSQL identifier when it cannot be written bare.
///
/// A name is left unquoted only when all of the following hold:
/// - it starts with `a-z` or `_`;
/// - it contains only `a-z`, `0-9` and `_`;
/// - it is not a reserved keyword.
///
/// Anything else is wrapped in double quotes with embedded `"` doubled. Quoting
/// reserved words matters because common table names such as `user`, `order` or
/// `group` are otherwise syntax errors (e.g. `ALTER TABLE user ...`). The empty
/// string becomes `""`.
fn quote_ident(name: &str) -> String {
    // Lowercase PostgreSQL keywords that cannot be used unquoted as a table or column
    // name: the "reserved" and "reserved (can be function or type)" categories of the
    // SQL Key Words appendix. Over-including is harmless; quoting is always valid.
    const RESERVED: &[&str] = &[
        "all", "analyse", "analyze", "and", "any", "array", "as", "asc", "asymmetric",
        "authorization", "binary", "both", "case", "cast", "check", "collate", "collation",
        "column", "concurrently", "constraint", "create", "cross", "current_catalog",
        "current_date", "current_role", "current_schema", "current_time",
        "current_timestamp", "current_user", "default", "deferrable", "desc", "distinct",
        "do", "else", "end", "except", "false", "fetch", "for", "foreign", "freeze", "from",
        "full", "grant", "group", "having", "ilike", "in", "initially", "inner", "intersect",
        "into", "is", "isnull", "join", "lateral", "leading", "left", "like", "limit",
        "localtime", "localtimestamp", "natural", "not", "notnull", "null", "offset", "on",
        "only", "or", "order", "outer", "overlaps", "placing", "primary", "references",
        "returning", "right", "select", "session_user", "similar", "some", "symmetric",
        "system_user", "table", "tablesample", "then", "to", "trailing", "true", "union",
        "unique", "user", "using", "variadic", "verbose", "when", "where", "window", "with",
    ];

    let starts_plain = matches!(name.chars().next(), Some(c) if c.is_ascii_lowercase() || c == '_');
    let is_plain = starts_plain
        && name
            .chars()
            .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_')
        && !RESERVED.contains(&name);

    if is_plain {
        name.to_string()
    } else {
        format!("\"{}\"", name.replace('"', "\"\""))
    }
}
519
520/// Context guard that sets tenant for the duration of its lifetime.
521///
522/// Uses PostgreSQL's SET LOCAL to ensure the setting only applies to
523/// the current transaction.
524pub struct TenantGuard {
525    reset_sql: String,
526}
527
528impl TenantGuard {
529    /// Create a new tenant guard.
530    ///
531    /// The caller should execute `set_sql()` before using the connection.
532    pub fn new(session_var: &str, tenant_id: &str) -> (Self, String) {
533        let set_sql = format!(
534            "SET LOCAL {} = '{}';",
535            session_var,
536            tenant_id.replace('\'', "''")
537        );
538        let reset_sql = format!("RESET {};", session_var);
539
540        (Self { reset_sql }, set_sql)
541    }
542
543    /// Get the SQL to reset the tenant context.
544    pub fn reset_sql(&self) -> &str {
545        &self.reset_sql
546    }
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rls_config() {
        let cfg = RlsConfig::new("org_id")
            .with_session_variable("app.org")
            .with_role("app_user")
            .add_tables(["users", "orders", "products"]);

        assert_eq!(cfg.tenant_column, "org_id");
        assert_eq!(cfg.session_variable, "app.org");
        for table in ["users", "orders"] {
            assert!(cfg.tables.contains(table), "table {} was not registered", table);
        }
    }

    #[test]
    fn test_set_tenant_sql() {
        let mgr = RlsManager::simple("tenant_id", "app.tenant");

        let plain = mgr.set_tenant_sql("tenant-123");
        assert_eq!(plain, "SET app.tenant = 'tenant-123';");

        // An embedded quote must be doubled so it cannot close the literal early.
        let hostile = mgr.set_tenant_sql("'; DROP TABLE users; --");
        assert_eq!(hostile, "SET app.tenant = '''; DROP TABLE users; --';");
    }

    #[test]
    fn test_create_policy_sql() {
        let policy =
            RlsManager::simple("tenant_id", "app.current_tenant").create_policy_sql("users");

        for fragment in [
            "CREATE POLICY",
            "tenant_id = current_setting('app.current_tenant')",
        ] {
            assert!(policy.contains(fragment), "policy SQL lacks {:?}", fragment);
        }
    }

    #[test]
    fn test_setup_sql() {
        let manager = RlsManager::new(
            RlsConfig::new("tenant_id")
                .with_session_variable("app.tenant")
                .add_tables(["users", "orders"]),
        );
        let script = manager.setup_sql();

        for fragment in [
            "ENABLE ROW LEVEL SECURITY",
            "FORCE ROW LEVEL SECURITY",
            "CREATE POLICY",
        ] {
            assert!(script.contains(fragment), "setup SQL lacks {:?}", fragment);
        }
    }

    #[test]
    fn test_custom_policy() {
        let sql = RlsPolicy::new("owner_access", "documents")
            .command(PolicyCommand::All)
            .role("app_user")
            .using("owner_id = current_user_id()")
            .with_check("owner_id = current_user_id()")
            .to_sql();

        assert!(sql.contains("CREATE POLICY owner_access"));
        assert!(sql.contains("FOR ALL"));
        assert!(sql.contains("USING (owner_id = current_user_id())"));
    }

    #[test]
    fn test_migration_sql() {
        let manager = RlsManager::simple("tenant_id", "app.tenant");
        let up = manager.migration_up_sql("invoices");
        let down = manager.migration_down_sql("invoices");

        assert!(up.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(up.contains("CREATE POLICY"));
        assert!(down.contains("DROP POLICY"));
        assert!(down.contains("DISABLE ROW LEVEL SECURITY"));
    }

    #[test]
    fn test_quote_ident() {
        let cases = [
            ("users", "users"),
            ("user-data", "\"user-data\""),
            ("User", "\"User\""),
            ("table\"name", "\"table\"\"name\""),
        ];
        for (raw, quoted) in cases.iter() {
            assert_eq!(quote_ident(raw), *quoted);
        }
    }
}