use std::collections::HashSet;
use std::fmt::Write;
/// Declarative settings for the row-level-security SQL generated by `RlsManager`.
///
/// Start from [`RlsConfig::new`] (or `Default`) and refine it with the chained
/// `with_*`, `add_*` and `exclude_*` methods.
#[derive(Debug, Clone)]
pub struct RlsConfig {
    /// Column on each managed table that holds the owning tenant's id.
    pub tenant_column: String,
    /// Name of the setting (e.g. `app.current_tenant`) that carries the active
    /// tenant id and is read by the generated policies.
    pub session_variable: String,
    /// Role the tenant policies are granted to; `None` targets `PUBLIC`.
    pub application_role: Option<String>,
    /// Tables that should receive tenant-isolation policies.
    pub tables: HashSet<String>,
    /// Tables that are skipped even when they also appear in `tables`.
    pub excluded_tables: HashSet<String>,
    /// When `true` (and `application_role` is set) the setup script also
    /// creates a `<role>_admin` role with `BYPASSRLS`.
    pub allow_bypass: bool,
    /// Prefix for generated policy names, which take the form `<prefix>_<table>`.
    pub policy_prefix: String,
}

impl Default for RlsConfig {
    /// `tenant_id` column, `app.current_tenant` setting, no role, no tables,
    /// bypass role allowed, `tenant_isolation` policy prefix.
    fn default() -> Self {
        let tenant_column = String::from("tenant_id");
        let session_variable = String::from("app.current_tenant");
        let policy_prefix = String::from("tenant_isolation");
        Self {
            tenant_column,
            session_variable,
            application_role: None,
            tables: HashSet::default(),
            excluded_tables: HashSet::default(),
            allow_bypass: true,
            policy_prefix,
        }
    }
}

impl RlsConfig {
    /// Default configuration with only the tenant column replaced.
    pub fn new(tenant_column: impl Into<String>) -> Self {
        Self {
            tenant_column: tenant_column.into(),
            ..Self::default()
        }
    }

    /// Replaces the setting name the policies read the tenant id from.
    pub fn with_session_variable(self, var: impl Into<String>) -> Self {
        Self {
            session_variable: var.into(),
            ..self
        }
    }

    /// Grants the generated policies to `role` instead of `PUBLIC`.
    pub fn with_role(self, role: impl Into<String>) -> Self {
        Self {
            application_role: Some(role.into()),
            ..self
        }
    }

    /// Adds one table to the managed set.
    pub fn add_table(mut self, table: impl Into<String>) -> Self {
        let name: String = table.into();
        self.tables.insert(name);
        self
    }

    /// Adds every table yielded by `tables` to the managed set.
    pub fn add_tables<I, S>(mut self, tables: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let names = tables.into_iter().map(|table| table.into());
        self.tables.extend(names);
        self
    }

    /// Marks `table` as excluded from policy generation.
    pub fn exclude_table(mut self, table: impl Into<String>) -> Self {
        let name: String = table.into();
        self.excluded_tables.insert(name);
        self
    }

    /// Disables creation of the `<role>_admin` bypass role.
    pub fn without_bypass(self) -> Self {
        Self {
            allow_bypass: false,
            ..self
        }
    }

    /// Replaces the prefix used for generated policy names.
    pub fn with_policy_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            policy_prefix: prefix.into(),
            ..self
        }
    }
}
/// Generates PostgreSQL row-level-security (RLS) statements for the tables and
/// tenant settings described by an [`RlsConfig`].
///
/// Every method only builds SQL text; nothing is executed here.
#[derive(Debug, Clone)]
pub struct RlsManager {
    config: RlsConfig,
}

impl RlsManager {
    /// Wraps `config`.
    pub fn new(config: RlsConfig) -> Self {
        Self { config }
    }

    /// Builds a manager that overrides only the tenant column and session
    /// variable, taking every other setting from `RlsConfig::new`.
    pub fn simple(tenant_column: impl Into<String>, session_var: impl Into<String>) -> Self {
        Self::new(RlsConfig::new(tenant_column).with_session_variable(session_var))
    }

    /// The configuration this manager generates SQL from.
    pub fn config(&self) -> &RlsConfig {
        &self.config
    }

    /// `ALTER TABLE <table> ENABLE ROW LEVEL SECURITY;`
    pub fn enable_rls_sql(&self, table: &str) -> String {
        format!("ALTER TABLE {} ENABLE ROW LEVEL SECURITY;", quote_ident(table))
    }

    /// `ALTER TABLE <table> FORCE ROW LEVEL SECURITY;` (in PostgreSQL this makes
    /// the policies apply to the table owner as well).
    pub fn force_rls_sql(&self, table: &str) -> String {
        format!("ALTER TABLE {} FORCE ROW LEVEL SECURITY;", quote_ident(table))
    }

    /// Permissive `FOR ALL` policy comparing the tenant column with the session
    /// variable cast to `text`, for both reads (`USING`) and writes (`WITH CHECK`).
    ///
    /// `current_setting` is called without `missing_ok`, so PostgreSQL raises an
    /// error rather than returning `NULL` when the setting does not exist.
    pub fn create_policy_sql(&self, table: &str) -> String {
        self.tenant_policy_sql(table, "text")
    }

    /// Same as [`create_policy_sql`](Self::create_policy_sql) but casts the
    /// session variable to `uuid`.
    pub fn create_uuid_policy_sql(&self, table: &str) -> String {
        self.tenant_policy_sql(table, "uuid")
    }

    /// `DROP POLICY IF EXISTS` for the policy the `create_*policy_sql` methods create.
    pub fn drop_policy_sql(&self, table: &str) -> String {
        format!(
            "DROP POLICY IF EXISTS {} ON {};",
            quote_ident(&self.policy_name(table)),
            quote_ident(table)
        )
    }

    /// Session-level `SET <session_variable> = '<tenant_id>';`.
    ///
    /// `tenant_id` is embedded as a string literal with `'` doubled, which
    /// assumes `standard_conforming_strings = on` (the PostgreSQL default). The
    /// session variable name is emitted verbatim and must come from trusted
    /// configuration.
    pub fn set_tenant_sql(&self, tenant_id: &str) -> String {
        format!(
            "SET {} = {};",
            self.config.session_variable,
            Self::sql_literal(tenant_id)
        )
    }

    /// Transaction-scoped variant of [`set_tenant_sql`](Self::set_tenant_sql)
    /// using `SET LOCAL`.
    pub fn set_tenant_local_sql(&self, tenant_id: &str) -> String {
        format!(
            "SET LOCAL {} = {};",
            self.config.session_variable,
            Self::sql_literal(tenant_id)
        )
    }

    /// `RESET <session_variable>;`
    pub fn reset_tenant_sql(&self) -> String {
        format!("RESET {};", self.config.session_variable)
    }

    /// Query reading the current tenant with `missing_ok = true`.
    pub fn current_tenant_sql(&self) -> String {
        format!(
            "SELECT current_setting({}, true);",
            Self::sql_literal(&self.config.session_variable)
        )
    }

    /// Full setup script for every managed table.
    ///
    /// If `allow_bypass` is set and an application role is configured, the
    /// script first creates `<role>_admin` with `BYPASSRLS`, ignoring
    /// `duplicate_object`. Then, for each configured table that is not excluded,
    /// it enables and forces RLS, drops any previous tenant policy and creates it
    /// again. Tables are emitted in sorted order so the script is identical from
    /// run to run (`HashSet` iteration order is randomized).
    pub fn setup_sql(&self) -> String {
        let mut sql = String::with_capacity(4096);
        // Writing into a `String` cannot fail, so the `fmt::Result`s are unwrapped.
        writeln!(sql, "-- Prax Multi-Tenant RLS Setup").unwrap();
        writeln!(sql, "-- Generated for column: {}", self.config.tenant_column).unwrap();
        writeln!(sql, "-- Session variable: {}", self.config.session_variable).unwrap();
        writeln!(sql).unwrap();

        let bypass_role = self
            .config
            .application_role
            .as_deref()
            .filter(|_| self.config.allow_bypass);
        if let Some(role) = bypass_role {
            writeln!(sql, "-- Admin role with BYPASSRLS").unwrap();
            writeln!(sql, "DO $$").unwrap();
            writeln!(sql, "BEGIN").unwrap();
            writeln!(sql, " CREATE ROLE {}_admin WITH BYPASSRLS;", role).unwrap();
            writeln!(sql, "EXCEPTION WHEN duplicate_object THEN NULL;").unwrap();
            writeln!(sql, "END $$;").unwrap();
            writeln!(sql).unwrap();
        }

        for table in self.managed_tables() {
            writeln!(sql, "-- Table: {}", table).unwrap();
            writeln!(sql, "{}", self.enable_rls_sql(table)).unwrap();
            writeln!(sql, "{}", self.force_rls_sql(table)).unwrap();
            writeln!(sql, "{}", self.drop_policy_sql(table)).unwrap();
            writeln!(sql, "{}", self.create_policy_sql(table)).unwrap();
            writeln!(sql).unwrap();
        }
        sql
    }

    /// Forward migration: enable and force RLS on `table`, then create its policy.
    pub fn migration_up_sql(&self, table: &str) -> String {
        let mut sql = String::with_capacity(512);
        writeln!(sql, "-- Enable RLS on {}", table).unwrap();
        writeln!(sql, "{}", self.enable_rls_sql(table)).unwrap();
        writeln!(sql, "{}", self.force_rls_sql(table)).unwrap();
        writeln!(sql, "{}", self.create_policy_sql(table)).unwrap();
        sql
    }

    /// Reverse of [`migration_up_sql`](Self::migration_up_sql), undoing its steps
    /// in reverse order.
    ///
    /// The `NO FORCE` step clears the flag set by `force_rls_sql`. Without it,
    /// the table would silently keep `FORCE` and reapply it to the owner the next
    /// time RLS is enabled.
    pub fn migration_down_sql(&self, table: &str) -> String {
        let mut sql = String::with_capacity(256);
        writeln!(sql, "-- Disable RLS on {}", table).unwrap();
        writeln!(sql, "{}", self.drop_policy_sql(table)).unwrap();
        writeln!(
            sql,
            "ALTER TABLE {} NO FORCE ROW LEVEL SECURITY;",
            quote_ident(table)
        )
        .unwrap();
        writeln!(
            sql,
            "ALTER TABLE {} DISABLE ROW LEVEL SECURITY;",
            quote_ident(table)
        )
        .unwrap();
        sql
    }

    /// Configured tables minus excluded ones, sorted for reproducible output.
    fn managed_tables(&self) -> Vec<&str> {
        let mut tables: Vec<&str> = self
            .config
            .tables
            .iter()
            .map(String::as_str)
            .filter(|t| !self.config.excluded_tables.contains(*t))
            .collect();
        tables.sort_unstable();
        tables
    }

    /// Unquoted policy name for `table`: `<policy_prefix>_<table>`.
    fn policy_name(&self, table: &str) -> String {
        format!("{}_{}", self.config.policy_prefix, table)
    }

    /// Grantee of the tenant policies: the configured role verbatim (not
    /// quoted), or `PUBLIC` when none is set.
    fn policy_role(&self) -> &str {
        self.config.application_role.as_deref().unwrap_or("PUBLIC")
    }

    /// Shared body of the tenant policy generators. `cast` is the SQL type the
    /// session setting is cast to before comparison.
    fn tenant_policy_sql(&self, table: &str, cast: &str) -> String {
        let name = quote_ident(&self.policy_name(table));
        let target = quote_ident(table);
        let role = self.policy_role();
        let column = quote_ident(&self.config.tenant_column);
        let setting = Self::sql_literal(&self.config.session_variable);
        format!(
            "CREATE POLICY {name} ON {target}\n\
             AS PERMISSIVE\n\
             FOR ALL\n\
             TO {role}\n\
             USING ({column} = current_setting({setting})::{cast})\n\
             WITH CHECK ({column} = current_setting({setting})::{cast});"
        )
    }

    /// Renders `value` as a single-quoted SQL string literal, doubling any `'`.
    fn sql_literal(value: &str) -> String {
        format!("'{}'", value.replace('\'', "''"))
    }
}
/// Fluent builder for [`RlsManager`], starting from `RlsConfig::default()`.
#[derive(Debug, Clone, Default)]
pub struct RlsManagerBuilder {
    config: RlsConfig,
}

impl RlsManagerBuilder {
    /// Builder holding the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the tenant column compared by the generated policies.
    pub fn tenant_column(mut self, column: impl Into<String>) -> Self {
        self.config.tenant_column = column.into();
        self
    }

    /// Sets the session variable the policies read the tenant id from.
    pub fn session_variable(mut self, var: impl Into<String>) -> Self {
        self.config.session_variable = var.into();
        self
    }

    /// Grants the policies to `role` instead of `PUBLIC`.
    pub fn application_role(mut self, role: impl Into<String>) -> Self {
        self.config.application_role = Some(role.into());
        self
    }

    /// Adds tables to manage. Can be called repeatedly; the sets accumulate.
    pub fn tables<I, S>(mut self, tables: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.config
            .tables
            .extend(tables.into_iter().map(Into::into));
        self
    }

    /// Adds tables to skip, even if they were also passed to [`tables`](Self::tables).
    pub fn exclude<I, S>(mut self, tables: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.config
            .excluded_tables
            .extend(tables.into_iter().map(Into::into));
        self
    }

    /// Sets the prefix of generated policy names (`<prefix>_<table>`).
    pub fn policy_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.config.policy_prefix = prefix.into();
        self
    }

    /// Controls whether the setup script creates the `<role>_admin` role with
    /// `BYPASSRLS` (default `true`). This is the builder counterpart of
    /// `RlsConfig::without_bypass`.
    pub fn allow_bypass(mut self, allow: bool) -> Self {
        self.config.allow_bypass = allow;
        self
    }

    /// Finishes the builder and returns the configured manager.
    pub fn build(self) -> RlsManager {
        RlsManager::new(self.config)
    }
}
/// A hand-written row-level-security policy, rendered by `RlsPolicy::to_sql`.
#[derive(Debug, Clone)]
pub struct RlsPolicy {
    /// Policy name; quoted as needed when rendered.
    pub name: String,
    /// Table the policy is attached to; quoted as needed when rendered.
    pub table: String,
    /// Statement type(s) the policy covers (`FOR …`).
    pub command: PolicyCommand,
    /// Role the policy applies to; `None` renders as `PUBLIC`.
    pub role: Option<String>,
    /// Expression for the `USING (...)` clause, emitted verbatim.
    pub using_expr: Option<String>,
    /// Expression for the `WITH CHECK (...)` clause, emitted verbatim.
    pub with_check_expr: Option<String>,
    /// `true` renders `AS PERMISSIVE`, `false` renders `AS RESTRICTIVE`.
    pub permissive: bool,
}

/// The `FOR` clause of a policy: which statement types it applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolicyCommand {
    /// `FOR ALL`
    All,
    /// `FOR SELECT`
    Select,
    /// `FOR INSERT`
    Insert,
    /// `FOR UPDATE`
    Update,
    /// `FOR DELETE`
    Delete,
}

impl PolicyCommand {
    /// SQL keyword for this command, as written after `FOR`.
    const fn as_str(self) -> &'static str {
        match self {
            PolicyCommand::All => "ALL",
            PolicyCommand::Select => "SELECT",
            PolicyCommand::Insert => "INSERT",
            PolicyCommand::Update => "UPDATE",
            PolicyCommand::Delete => "DELETE",
        }
    }
}
impl RlsPolicy {
pub fn new(name: impl Into<String>, table: impl Into<String>) -> Self {
Self {
name: name.into(),
table: table.into(),
command: PolicyCommand::All,
role: None,
using_expr: None,
with_check_expr: None,
permissive: true,
}
}
pub fn command(mut self, cmd: PolicyCommand) -> Self {
self.command = cmd;
self
}
pub fn role(mut self, role: impl Into<String>) -> Self {
self.role = Some(role.into());
self
}
pub fn using(mut self, expr: impl Into<String>) -> Self {
self.using_expr = Some(expr.into());
self
}
pub fn with_check(mut self, expr: impl Into<String>) -> Self {
self.with_check_expr = Some(expr.into());
self
}
pub fn restrictive(mut self) -> Self {
self.permissive = false;
self
}
pub fn to_sql(&self) -> String {
let mut sql = String::with_capacity(256);
let policy_type = if self.permissive {
"PERMISSIVE"
} else {
"RESTRICTIVE"
};
write!(
sql,
"CREATE POLICY {} ON {}\n AS {}\n FOR {}\n TO {}",
quote_ident(&self.name),
quote_ident(&self.table),
policy_type,
self.command.as_str(),
self.role.as_deref().unwrap_or("PUBLIC"),
)
.unwrap();
if let Some(ref using) = self.using_expr {
write!(sql, "\n USING ({})", using).unwrap();
}
if let Some(ref check) = self.with_check_expr {
write!(sql, "\n WITH CHECK ({})", check).unwrap();
}
sql.push(';');
sql
}
}
/// Renders `name` as a PostgreSQL identifier, quoting it only when required.
///
/// A name is left bare only if all of the following hold:
/// - it is non-empty;
/// - it starts with a lowercase ASCII letter or `_`;
/// - it contains only lowercase ASCII letters, digits and `_`;
/// - it is not a keyword PostgreSQL refuses as a bare table, column or policy
///   name (e.g. `user`, `order`, `group`).
///
/// Otherwise it is wrapped in double quotes, with embedded `"` doubled.
fn quote_ident(name: &str) -> String {
    // PostgreSQL's "reserved" and "reserved (can be function or type)"
    // keywords. Unlike the earlier character-only check, these must be quoted:
    // `ALTER TABLE user ...` is a syntax error. Quoting an all-lowercase name
    // never changes what it refers to, so erring towards quoting is safe.
    const RESERVED_KEYWORDS: &[&str] = &[
        "all", "analyse", "analyze", "and", "any", "array", "as", "asc", "asymmetric",
        "authorization", "binary", "both", "case", "cast", "check", "collate", "collation",
        "column", "concurrently", "constraint", "create", "cross", "current_catalog",
        "current_date", "current_role", "current_schema", "current_time", "current_timestamp",
        "current_user", "default", "deferrable", "desc", "distinct", "do", "else", "end",
        "except", "false", "fetch", "for", "foreign", "freeze", "from", "full", "grant",
        "group", "having", "ilike", "in", "initially", "inner", "intersect", "into", "is",
        "isnull", "join", "lateral", "leading", "left", "like", "limit", "localtime",
        "localtimestamp", "natural", "not", "notnull", "null", "offset", "on", "only", "or",
        "order", "outer", "overlaps", "placing", "primary", "references", "returning",
        "right", "select", "session_user", "similar", "some", "symmetric", "system_user",
        "table", "tablesample", "then", "to", "trailing", "true", "union", "unique", "user",
        "using", "variadic", "verbose", "when", "where", "window", "with",
    ];

    let mut chars = name.chars();
    let starts_plain = matches!(chars.next(), Some(c) if c.is_ascii_lowercase() || c == '_');
    let is_plain = starts_plain
        && chars.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_')
        && !RESERVED_KEYWORDS.contains(&name);

    if is_plain {
        name.to_string()
    } else {
        format!("\"{}\"", name.replace('"', "\"\""))
    }
}
/// Holds the statements for scoping a tenant to a single transaction.
///
/// Despite its name, this type has no `Drop` impl and executes nothing. The
/// caller runs the `SET LOCAL` statement returned by [`TenantGuard::new`] and,
/// if desired, [`TenantGuard::reset_sql`] afterwards.
pub struct TenantGuard {
    reset_sql: String,
}

impl TenantGuard {
    /// Returns the guard together with `SET LOCAL <session_var> = '<tenant_id>';`.
    ///
    /// `tenant_id` has every `'` doubled; `session_var` is emitted verbatim and
    /// must come from trusted configuration.
    pub fn new(session_var: &str, tenant_id: &str) -> (Self, String) {
        let escaped_tenant = tenant_id.replace('\'', "''");
        let guard = TenantGuard {
            reset_sql: format!("RESET {};", session_var),
        };
        let set_local = format!("SET LOCAL {} = '{}';", session_var, escaped_tenant);
        (guard, set_local)
    }

    /// The `RESET <session_var>;` statement for this guard's variable.
    ///
    /// NOTE(review): `SET LOCAL` already reverts at transaction end, and a
    /// `RESET` issued inside the transaction also clears any session-level
    /// value. Confirm that this is intended.
    pub fn reset_sql(&self) -> &str {
        self.reset_sql.as_str()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the test unless every needle occurs somewhere in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(needle),
                "expected {:?} in:\n{}",
                needle,
                haystack
            );
        }
    }

    #[test]
    fn test_rls_config() {
        let cfg = RlsConfig::new("org_id")
            .with_session_variable("app.org")
            .with_role("app_user")
            .add_tables(["users", "orders", "products"]);
        assert_eq!(cfg.tenant_column, "org_id");
        assert_eq!(cfg.session_variable, "app.org");
        for expected in ["users", "orders"] {
            assert!(cfg.tables.contains(expected));
        }
    }

    #[test]
    fn test_set_tenant_sql() {
        let manager = RlsManager::simple("tenant_id", "app.tenant");
        let cases = [
            ("tenant-123", "SET app.tenant = 'tenant-123';"),
            (
                "'; DROP TABLE users; --",
                "SET app.tenant = '''; DROP TABLE users; --';",
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(manager.set_tenant_sql(input), expected);
        }
    }

    #[test]
    fn test_create_policy_sql() {
        let sql = RlsManager::simple("tenant_id", "app.current_tenant").create_policy_sql("users");
        assert_contains_all(
            &sql,
            &[
                "CREATE POLICY",
                "tenant_id = current_setting('app.current_tenant')",
            ],
        );
    }

    #[test]
    fn test_setup_sql() {
        let manager = RlsManager::new(
            RlsConfig::new("tenant_id")
                .with_session_variable("app.tenant")
                .add_tables(["users", "orders"]),
        );
        assert_contains_all(
            &manager.setup_sql(),
            &[
                "ENABLE ROW LEVEL SECURITY",
                "FORCE ROW LEVEL SECURITY",
                "CREATE POLICY",
            ],
        );
    }

    #[test]
    fn test_custom_policy() {
        let sql = RlsPolicy::new("owner_access", "documents")
            .command(PolicyCommand::All)
            .role("app_user")
            .using("owner_id = current_user_id()")
            .with_check("owner_id = current_user_id()")
            .to_sql();
        assert_contains_all(
            &sql,
            &[
                "CREATE POLICY owner_access",
                "FOR ALL",
                "USING (owner_id = current_user_id())",
            ],
        );
    }

    #[test]
    fn test_migration_sql() {
        let manager = RlsManager::simple("tenant_id", "app.tenant");
        assert_contains_all(
            &manager.migration_up_sql("invoices"),
            &["ENABLE ROW LEVEL SECURITY", "CREATE POLICY"],
        );
        assert_contains_all(
            &manager.migration_down_sql("invoices"),
            &["DROP POLICY", "DISABLE ROW LEVEL SECURITY"],
        );
    }

    #[test]
    fn test_quote_ident() {
        let cases = [
            ("users", "users"),
            ("user-data", "\"user-data\""),
            ("User", "\"User\""),
            ("table\"name", "\"table\"\"name\""),
        ];
        for (raw, quoted) in cases {
            assert_eq!(quote_ident(raw), quoted);
        }
    }
}