prax_query/tenant/
rls.rs

1//! PostgreSQL Row-Level Security (RLS) integration.
2//!
3//! This module provides high-performance RLS support for multi-tenant applications
4//! using PostgreSQL's native row-level security features.
5//!
6//! # Performance Benefits
7//!
8//! Using database-level RLS provides:
9//! - **Zero application overhead** - Filtering happens in the database engine
//! - **Database-enforced isolation** - Even raw SQL queries are filtered; only superusers and roles with `BYPASSRLS` skip the policies
11//! - **Index utilization** - RLS policies can use indexes efficiently
12//! - **Prepared statement caching** - Same statements work for all tenants
13//!
14//! # Example
15//!
16//! ```rust,ignore
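//! use prax_query::tenant::rls::{RlsConfig, RlsManager};
//!
//! // Create RLS manager for the tables that hold tenant-scoped rows
//! let rls = RlsManager::new(
//!     RlsConfig::new("tenant_id")
//!         .with_session_variable("app.current_tenant")
//!         .add_tables(["users", "orders", "products"]),
//! );
//!
//! // Generate setup SQL for every configured table
//! let setup = rls.setup_sql();
//! conn.execute_batch(&setup).await?;
//!
//! // Set tenant context for session (the manager only builds SQL; run it yourself)
//! let set_tenant = rls.set_tenant_sql("tenant-123");
//! conn.execute_batch(&set_tenant).await?;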
28//! ```
29
30use std::collections::HashSet;
31use std::fmt::Write;
32
/// Settings that drive the SQL generated by [`RlsManager`].
///
/// Start from [`RlsConfig::new`] (or [`Default`]) and chain the `with_*` /
/// `add_*` methods to customise it.
#[derive(Debug, Clone)]
pub struct RlsConfig {
    /// Column holding the tenant id in every tenant-scoped table.
    pub tenant_column: String,
    /// PostgreSQL setting that carries the current tenant (e.g. `"app.current_tenant"`).
    pub session_variable: String,
    /// Role the generated policies target; `None` makes them apply to `PUBLIC`.
    pub application_role: Option<String>,
    /// Tables that should receive row-level security.
    pub tables: HashSet<String>,
    /// Tables skipped even when listed in `tables` (e.g. shared lookup tables).
    pub excluded_tables: HashSet<String>,
    /// Whether setup also creates a `BYPASSRLS` admin role (only generated when
    /// `application_role` is set as well).
    pub allow_bypass: bool,
    /// Prefix of generated policy names (`<prefix>_<table>`).
    pub policy_prefix: String,
}

impl Default for RlsConfig {
    /// `tenant_id` column, `app.current_tenant` variable, no role, no tables,
    /// bypass allowed, `tenant_isolation` policy prefix.
    fn default() -> Self {
        RlsConfig {
            tenant_column: String::from("tenant_id"),
            session_variable: String::from("app.current_tenant"),
            application_role: None,
            tables: HashSet::new(),
            excluded_tables: HashSet::new(),
            allow_bypass: true,
            policy_prefix: String::from("tenant_isolation"),
        }
    }
}

impl RlsConfig {
    /// Build the default configuration, overriding only the tenant column name.
    pub fn new(tenant_column: impl Into<String>) -> Self {
        Self {
            tenant_column: tenant_column.into(),
            ..Self::default()
        }
    }

    /// Replace the session variable name.
    pub fn with_session_variable(self, var: impl Into<String>) -> Self {
        Self {
            session_variable: var.into(),
            ..self
        }
    }

    /// Target the generated policies at `role` instead of `PUBLIC`.
    pub fn with_role(self, role: impl Into<String>) -> Self {
        Self {
            application_role: Some(role.into()),
            ..self
        }
    }

    /// Register one table for RLS (duplicates are ignored).
    pub fn add_table(mut self, table: impl Into<String>) -> Self {
        self.tables.insert(table.into());
        self
    }

    /// Register every table yielded by `tables` for RLS.
    pub fn add_tables<I, S>(mut self, tables: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        for name in tables {
            self.tables.insert(name.into());
        }
        self
    }

    /// Mark a table as exempt from RLS setup.
    pub fn exclude_table(mut self, table: impl Into<String>) -> Self {
        self.excluded_tables.insert(table.into());
        self
    }

    /// Turn off generation of the `BYPASSRLS` admin role.
    pub fn without_bypass(self) -> Self {
        Self {
            allow_bypass: false,
            ..self
        }
    }

    /// Replace the prefix used for generated policy names.
    pub fn with_policy_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            policy_prefix: prefix.into(),
            ..self
        }
    }
}
121
122/// Manager for PostgreSQL RLS operations.
123#[derive(Debug, Clone)]
124pub struct RlsManager {
125    config: RlsConfig,
126}
127
128impl RlsManager {
129    /// Create a new RLS manager with the given config.
130    pub fn new(config: RlsConfig) -> Self {
131        Self { config }
132    }
133
134    /// Create with simple defaults.
135    pub fn simple(tenant_column: impl Into<String>, session_var: impl Into<String>) -> Self {
136        Self::new(
137            RlsConfig::new(tenant_column)
138                .with_session_variable(session_var),
139        )
140    }
141
142    /// Get the config.
143    pub fn config(&self) -> &RlsConfig {
144        &self.config
145    }
146
147    /// Generate SQL to enable RLS on a table.
148    pub fn enable_rls_sql(&self, table: &str) -> String {
149        format!("ALTER TABLE {} ENABLE ROW LEVEL SECURITY;", quote_ident(table))
150    }
151
152    /// Generate SQL to force RLS even for table owners.
153    pub fn force_rls_sql(&self, table: &str) -> String {
154        format!(
155            "ALTER TABLE {} FORCE ROW LEVEL SECURITY;",
156            quote_ident(table)
157        )
158    }
159
160    /// Generate SQL for the tenant isolation policy.
161    pub fn create_policy_sql(&self, table: &str) -> String {
162        let policy_name = format!("{}_{}", self.config.policy_prefix, table);
163        let role = self
164            .config
165            .application_role
166            .as_deref()
167            .unwrap_or("PUBLIC");
168
169        // Create policy that filters by tenant_id = current_setting('app.current_tenant')
170        format!(
171            r#"CREATE POLICY {} ON {}
172    AS PERMISSIVE
173    FOR ALL
174    TO {}
175    USING ({} = current_setting('{}')::text)
176    WITH CHECK ({} = current_setting('{}')::text);"#,
177            quote_ident(&policy_name),
178            quote_ident(table),
179            role,
180            quote_ident(&self.config.tenant_column),
181            self.config.session_variable,
182            quote_ident(&self.config.tenant_column),
183            self.config.session_variable,
184        )
185    }
186
187    /// Generate SQL for UUID tenant columns.
188    pub fn create_uuid_policy_sql(&self, table: &str) -> String {
189        let policy_name = format!("{}_{}", self.config.policy_prefix, table);
190        let role = self
191            .config
192            .application_role
193            .as_deref()
194            .unwrap_or("PUBLIC");
195
196        format!(
197            r#"CREATE POLICY {} ON {}
198    AS PERMISSIVE
199    FOR ALL
200    TO {}
201    USING ({} = current_setting('{}')::uuid)
202    WITH CHECK ({} = current_setting('{}')::uuid);"#,
203            quote_ident(&policy_name),
204            quote_ident(table),
205            role,
206            quote_ident(&self.config.tenant_column),
207            self.config.session_variable,
208            quote_ident(&self.config.tenant_column),
209            self.config.session_variable,
210        )
211    }
212
213    /// Generate SQL to drop a policy.
214    pub fn drop_policy_sql(&self, table: &str) -> String {
215        let policy_name = format!("{}_{}", self.config.policy_prefix, table);
216        format!(
217            "DROP POLICY IF EXISTS {} ON {};",
218            quote_ident(&policy_name),
219            quote_ident(table)
220        )
221    }
222
223    /// Generate SQL to set the current tenant for a session.
224    pub fn set_tenant_sql(&self, tenant_id: &str) -> String {
225        format!(
226            "SET {} = '{}';",
227            self.config.session_variable,
228            tenant_id.replace('\'', "''")
229        )
230    }
231
232    /// Generate SQL to set the current tenant locally (transaction only).
233    pub fn set_tenant_local_sql(&self, tenant_id: &str) -> String {
234        format!(
235            "SET LOCAL {} = '{}';",
236            self.config.session_variable,
237            tenant_id.replace('\'', "''")
238        )
239    }
240
241    /// Generate SQL to reset the tenant context.
242    pub fn reset_tenant_sql(&self) -> String {
243        format!("RESET {};", self.config.session_variable)
244    }
245
246    /// Generate SQL to check the current tenant.
247    pub fn current_tenant_sql(&self) -> String {
248        format!(
249            "SELECT current_setting('{}', true);",
250            self.config.session_variable
251        )
252    }
253
254    /// Generate complete setup SQL for all configured tables.
255    pub fn setup_sql(&self) -> String {
256        let mut sql = String::with_capacity(4096);
257
258        // Header
259        writeln!(sql, "-- Prax Multi-Tenant RLS Setup").unwrap();
260        writeln!(sql, "-- Generated for column: {}", self.config.tenant_column).unwrap();
261        writeln!(sql, "-- Session variable: {}", self.config.session_variable).unwrap();
262        writeln!(sql).unwrap();
263
264        // Create admin role if bypass is enabled
265        if self.config.allow_bypass {
266            if let Some(ref role) = self.config.application_role {
267                writeln!(sql, "-- Admin role with BYPASSRLS").unwrap();
268                writeln!(sql, "DO $$").unwrap();
269                writeln!(sql, "BEGIN").unwrap();
270                writeln!(
271                    sql,
272                    "    CREATE ROLE {}_admin WITH BYPASSRLS;",
273                    role
274                )
275                .unwrap();
276                writeln!(sql, "EXCEPTION WHEN duplicate_object THEN NULL;").unwrap();
277                writeln!(sql, "END $$;").unwrap();
278                writeln!(sql).unwrap();
279            }
280        }
281
282        // Enable RLS and create policies for each table
283        for table in &self.config.tables {
284            if self.config.excluded_tables.contains(table) {
285                continue;
286            }
287
288            writeln!(sql, "-- Table: {}", table).unwrap();
289            writeln!(sql, "{}", self.enable_rls_sql(table)).unwrap();
290            writeln!(sql, "{}", self.force_rls_sql(table)).unwrap();
291            writeln!(sql, "{}", self.drop_policy_sql(table)).unwrap();
292            writeln!(sql, "{}", self.create_policy_sql(table)).unwrap();
293            writeln!(sql).unwrap();
294        }
295
296        sql
297    }
298
299    /// Generate migration SQL to add RLS to a new table.
300    pub fn migration_up_sql(&self, table: &str) -> String {
301        let mut sql = String::with_capacity(512);
302
303        writeln!(sql, "-- Enable RLS on {}", table).unwrap();
304        writeln!(sql, "{}", self.enable_rls_sql(table)).unwrap();
305        writeln!(sql, "{}", self.force_rls_sql(table)).unwrap();
306        writeln!(sql, "{}", self.create_policy_sql(table)).unwrap();
307
308        sql
309    }
310
311    /// Generate migration SQL to remove RLS from a table.
312    pub fn migration_down_sql(&self, table: &str) -> String {
313        let mut sql = String::with_capacity(256);
314
315        writeln!(sql, "-- Disable RLS on {}", table).unwrap();
316        writeln!(sql, "{}", self.drop_policy_sql(table)).unwrap();
317        writeln!(
318            sql,
319            "ALTER TABLE {} DISABLE ROW LEVEL SECURITY;",
320            quote_ident(table)
321        )
322        .unwrap();
323
324        sql
325    }
326}
327
328/// Builder for RLS manager.
329#[derive(Default)]
330pub struct RlsManagerBuilder {
331    config: RlsConfig,
332}
333
334impl RlsManagerBuilder {
335    /// Create a new builder.
336    pub fn new() -> Self {
337        Self::default()
338    }
339
340    /// Set the tenant column.
341    pub fn tenant_column(mut self, column: impl Into<String>) -> Self {
342        self.config.tenant_column = column.into();
343        self
344    }
345
346    /// Set the session variable.
347    pub fn session_variable(mut self, var: impl Into<String>) -> Self {
348        self.config.session_variable = var.into();
349        self
350    }
351
352    /// Set the application role.
353    pub fn application_role(mut self, role: impl Into<String>) -> Self {
354        self.config.application_role = Some(role.into());
355        self
356    }
357
358    /// Add tables.
359    pub fn tables<I, S>(mut self, tables: I) -> Self
360    where
361        I: IntoIterator<Item = S>,
362        S: Into<String>,
363    {
364        self.config.tables.extend(tables.into_iter().map(Into::into));
365        self
366    }
367
368    /// Exclude tables.
369    pub fn exclude<I, S>(mut self, tables: I) -> Self
370    where
371        I: IntoIterator<Item = S>,
372        S: Into<String>,
373    {
374        self.config
375            .excluded_tables
376            .extend(tables.into_iter().map(Into::into));
377        self
378    }
379
380    /// Set policy prefix.
381    pub fn policy_prefix(mut self, prefix: impl Into<String>) -> Self {
382        self.config.policy_prefix = prefix.into();
383        self
384    }
385
386    /// Build the manager.
387    pub fn build(self) -> RlsManager {
388        RlsManager::new(self.config)
389    }
390}
391
392/// Represents a custom RLS policy.
393#[derive(Debug, Clone)]
394pub struct RlsPolicy {
395    /// Policy name.
396    pub name: String,
397    /// Table the policy applies to.
398    pub table: String,
399    /// Command the policy applies to (ALL, SELECT, INSERT, UPDATE, DELETE).
400    pub command: PolicyCommand,
401    /// Role the policy applies to.
402    pub role: Option<String>,
403    /// USING expression (for SELECT, UPDATE, DELETE).
404    pub using_expr: Option<String>,
405    /// WITH CHECK expression (for INSERT, UPDATE).
406    pub with_check_expr: Option<String>,
407    /// Whether this is a permissive or restrictive policy.
408    pub permissive: bool,
409}
410
/// SQL command that a policy applies to, rendered as the `FOR` clause of
/// `CREATE POLICY`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PolicyCommand {
    /// `FOR ALL`: every command.
    All,
    /// `FOR SELECT`.
    Select,
    /// `FOR INSERT`.
    Insert,
    /// `FOR UPDATE`.
    Update,
    /// `FOR DELETE`.
    Delete,
}

impl PolicyCommand {
    /// The SQL keyword for this command, as emitted after `FOR`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::All => "ALL",
            Self::Select => "SELECT",
            Self::Insert => "INSERT",
            Self::Update => "UPDATE",
            Self::Delete => "DELETE",
        }
    }
}

impl std::fmt::Display for PolicyCommand {
    /// Formats the command as its SQL keyword (same text as [`PolicyCommand::as_str`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
432
433impl RlsPolicy {
434    /// Create a new policy.
435    pub fn new(name: impl Into<String>, table: impl Into<String>) -> Self {
436        Self {
437            name: name.into(),
438            table: table.into(),
439            command: PolicyCommand::All,
440            role: None,
441            using_expr: None,
442            with_check_expr: None,
443            permissive: true,
444        }
445    }
446
447    /// Set the command.
448    pub fn command(mut self, cmd: PolicyCommand) -> Self {
449        self.command = cmd;
450        self
451    }
452
453    /// Set the role.
454    pub fn role(mut self, role: impl Into<String>) -> Self {
455        self.role = Some(role.into());
456        self
457    }
458
459    /// Set the USING expression.
460    pub fn using(mut self, expr: impl Into<String>) -> Self {
461        self.using_expr = Some(expr.into());
462        self
463    }
464
465    /// Set the WITH CHECK expression.
466    pub fn with_check(mut self, expr: impl Into<String>) -> Self {
467        self.with_check_expr = Some(expr.into());
468        self
469    }
470
471    /// Make this a restrictive policy.
472    pub fn restrictive(mut self) -> Self {
473        self.permissive = false;
474        self
475    }
476
477    /// Generate the CREATE POLICY SQL.
478    pub fn to_sql(&self) -> String {
479        let mut sql = String::with_capacity(256);
480
481        let policy_type = if self.permissive {
482            "PERMISSIVE"
483        } else {
484            "RESTRICTIVE"
485        };
486
487        write!(
488            sql,
489            "CREATE POLICY {} ON {}\n    AS {}\n    FOR {}\n    TO {}",
490            quote_ident(&self.name),
491            quote_ident(&self.table),
492            policy_type,
493            self.command.as_str(),
494            self.role.as_deref().unwrap_or("PUBLIC"),
495        )
496        .unwrap();
497
498        if let Some(ref using) = self.using_expr {
499            write!(sql, "\n    USING ({})", using).unwrap();
500        }
501
502        if let Some(ref check) = self.with_check_expr {
503            write!(sql, "\n    WITH CHECK ({})", check).unwrap();
504        }
505
506        sql.push(';');
507        sql
508    }
509}
510
/// Quote a PostgreSQL identifier unless it can safely be written bare.
///
/// A name is emitted unquoted only when all of the following hold:
/// - it is non-empty;
/// - it consists solely of ASCII lowercase letters, digits and `_`;
/// - it does not start with a digit;
/// - it is not a reserved PostgreSQL keyword (e.g. `user`, `order`, `table`).
///
/// Every other name is wrapped in double quotes with embedded `"` doubled,
/// which preserves its exact case and characters.
///
/// An empty name yields `""`, which PostgreSQL rejects as a zero-length
/// identifier; callers should not pass empty names.
fn quote_ident(name: &str) -> String {
    // Keywords PostgreSQL will not accept as bare table/column/policy names:
    // the "reserved" and "reserved (can be function or type)" categories.
    // Quoting a word unnecessarily is harmless, so over-inclusion is safe.
    const RESERVED: &[&str] = &[
        "all", "analyse", "analyze", "and", "any", "array", "as", "asc",
        "asymmetric", "authorization", "binary", "both", "case", "cast", "check",
        "collate", "collation", "column", "concurrently", "constraint", "create",
        "cross", "current_catalog", "current_date", "current_role",
        "current_schema", "current_time", "current_timestamp", "current_user",
        "default", "deferrable", "desc", "distinct", "do", "else", "end", "except",
        "false", "fetch", "for", "foreign", "freeze", "from", "full", "grant",
        "group", "having", "ilike", "in", "initially", "inner", "intersect", "into",
        "is", "isnull", "join", "lateral", "leading", "left", "like", "limit",
        "localtime", "localtimestamp", "natural", "not", "notnull", "null",
        "offset", "on", "only", "or", "order", "outer", "overlaps", "placing",
        "primary", "references", "returning", "right", "select", "session_user",
        "similar", "some", "symmetric", "system_user", "table", "tablesample",
        "then", "to", "trailing", "true", "union", "unique", "user", "using",
        "variadic", "verbose", "when", "where", "window", "with",
    ];

    let is_plain = match name.as_bytes().first() {
        Some(first) if !first.is_ascii_digit() => name
            .bytes()
            .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'_'),
        // Empty names and names starting with a digit must be quoted.
        _ => false,
    };

    if is_plain && !RESERVED.contains(&name) {
        name.to_string()
    } else {
        format!("\"{}\"", name.replace('"', "\"\""))
    }
}
525
/// Holder for the SQL that scopes a tenant to the current transaction.
///
/// [`TenantGuard::new`] builds a `SET LOCAL` statement, returned alongside the
/// guard, and keeps the matching `RESET` statement. The guard executes nothing
/// and has no `Drop` behavior. The caller must run the returned set SQL inside
/// a transaction, and may run [`TenantGuard::reset_sql`] to clear the value
/// before the transaction ends.
///
/// Uses PostgreSQL's `SET LOCAL` so the setting only applies to the current
/// transaction.
#[derive(Debug, Clone)]
pub struct TenantGuard {
    reset_sql: String,
}

impl TenantGuard {
    /// Create a new tenant guard.
    ///
    /// Returns `(guard, set_sql)`, where `set_sql` is
    /// `SET LOCAL <session_var> = '<tenant_id>';` with single quotes in
    /// `tenant_id` doubled. The caller should execute `set_sql` before using
    /// the connection.
    ///
    /// `session_var` is inserted verbatim into both statements, so it must be
    /// a trusted, valid setting name.
    #[must_use]
    pub fn new(session_var: &str, tenant_id: &str) -> (Self, String) {
        let escaped_tenant = tenant_id.replace('\'', "''");
        let set_sql = format!("SET LOCAL {} = '{}';", session_var, escaped_tenant);
        let guard = Self {
            reset_sql: format!("RESET {};", session_var),
        };

        (guard, set_sql)
    }

    /// Get the SQL to reset the tenant context.
    pub fn reset_sql(&self) -> &str {
        &self.reset_sql
    }
}
554
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rls_config() {
        let conf = RlsConfig::new("org_id")
            .with_session_variable("app.org")
            .with_role("app_user")
            .add_tables(["users", "orders", "products"]);

        assert_eq!(conf.tenant_column, "org_id");
        assert_eq!(conf.session_variable, "app.org");
        for table in ["users", "orders"] {
            assert!(conf.tables.contains(table));
        }
    }

    #[test]
    fn test_set_tenant_sql() {
        let mgr = RlsManager::simple("tenant_id", "app.tenant");

        let cases = [
            ("tenant-123", "SET app.tenant = 'tenant-123';"),
            // A quote inside the tenant id must be doubled so it cannot close the literal.
            (
                "'; DROP TABLE users; --",
                "SET app.tenant = '''; DROP TABLE users; --';",
            ),
        ];
        for (tenant, expected) in cases {
            assert_eq!(mgr.set_tenant_sql(tenant), expected);
        }
    }

    #[test]
    fn test_create_policy_sql() {
        let policy =
            RlsManager::simple("tenant_id", "app.current_tenant").create_policy_sql("users");

        for needle in [
            "CREATE POLICY",
            "tenant_id = current_setting('app.current_tenant')",
        ] {
            assert!(policy.contains(needle));
        }
    }

    #[test]
    fn test_setup_sql() {
        let script = RlsManager::new(
            RlsConfig::new("tenant_id")
                .with_session_variable("app.tenant")
                .add_tables(["users", "orders"]),
        )
        .setup_sql();

        for needle in [
            "ENABLE ROW LEVEL SECURITY",
            "FORCE ROW LEVEL SECURITY",
            "CREATE POLICY",
        ] {
            assert!(script.contains(needle));
        }
    }

    #[test]
    fn test_custom_policy() {
        let rendered = RlsPolicy::new("owner_access", "documents")
            .command(PolicyCommand::All)
            .role("app_user")
            .using("owner_id = current_user_id()")
            .with_check("owner_id = current_user_id()")
            .to_sql();

        for needle in [
            "CREATE POLICY owner_access",
            "FOR ALL",
            "USING (owner_id = current_user_id())",
        ] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_migration_sql() {
        let mgr = RlsManager::simple("tenant_id", "app.tenant");

        let up_sql = mgr.migration_up_sql("invoices");
        for needle in ["ENABLE ROW LEVEL SECURITY", "CREATE POLICY"] {
            assert!(up_sql.contains(needle));
        }

        let down_sql = mgr.migration_down_sql("invoices");
        for needle in ["DROP POLICY", "DISABLE ROW LEVEL SECURITY"] {
            assert!(down_sql.contains(needle));
        }
    }

    #[test]
    fn test_quote_ident() {
        let expectations = [
            ("users", "users"),
            ("user-data", "\"user-data\""),
            ("User", "\"User\""),
            ("table\"name", "\"table\"\"name\""),
        ];
        for (raw, quoted) in expectations {
            assert_eq!(quote_ident(raw), quoted);
        }
    }
}
646
647