use crate::acl::{AclConfig, AclManager};
use crate::error::{Error, Result};
use async_trait::async_trait;
use sea_orm::{ConnectionTrait, DatabaseConnection, Statement};
/// PostgreSQL implementation of [`AclManager`]: each tenant gets its own login role, and
/// row-level-security policies restrict that role to rows whose `tenant_id` matches it.
pub struct PostgresAclManager {
    /// Connection over which every ACL statement (role DDL, grants, RLS policies and
    /// catalog lookups) is executed.
    db: DatabaseConnection,
    /// Administrative role name; `new` falls back to `"postgres"` when given `None`.
    _admin_user: String,
}
impl PostgresAclManager {
    /// Tables placed under row-level security for tenant isolation.
    ///
    /// Each table listed here receives a per-tenant policy filtering on a `tenant_id`
    /// column, so each one must have that column. Kept in one place so enabling RLS,
    /// creating policies and dropping policies always cover exactly the same tables.
    const TENANT_TABLES: &'static [&'static str] = &[
        "tasks",
        "queues",
        "servers",
        "workers",
        "stats",
        "schedulers",
        "scheduler_entries",
        "scheduler_events",
    ];

    /// Prefix of each tenant's isolation policy name; the tenant username is appended.
    const TENANT_POLICY_PREFIX: &'static str = "tenant_isolation_policy_";

    /// Creates a manager that issues all of its statements over `db`.
    ///
    /// `admin_user` names the administrative role and defaults to `"postgres"`.
    pub fn new(db: DatabaseConnection, admin_user: Option<String>) -> Self {
        Self {
            db,
            _admin_user: admin_user.unwrap_or_else(|| "postgres".to_string()),
        }
    }

    /// Validates `identifier` as a conservative SQL identifier and returns it double-quoted,
    /// ready to be spliced into a statement.
    ///
    /// Accepted identifiers are 1-63 bytes long, start with an ASCII letter or `_`, and
    /// contain only ASCII letters, digits and `_`. Quoting preserves case, so `User1` and
    /// `user1` refer to different objects.
    ///
    /// # Errors
    /// Returns a config error naming the first rule that `identifier` breaks.
    fn validate_sql_identifier(identifier: &str) -> Result<String> {
        // 63 bytes is PostgreSQL's default identifier limit (NAMEDATALEN - 1). For anything
        // that passes the ASCII-only checks below, bytes and characters coincide.
        if identifier.is_empty() || identifier.len() > 63 {
            return Err(Error::config("Identifier must be 1-63 characters"));
        }
        let starts_validly = matches!(
            identifier.chars().next(),
            Some(c) if c.is_ascii_alphabetic() || c == '_'
        );
        if !starts_validly {
            return Err(Error::config(
                "Identifier must start with letter or underscore",
            ));
        }
        if !identifier
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '_')
        {
            return Err(Error::config(
                "Identifier can only contain letters, numbers, and underscores",
            ));
        }
        // The checks above exclude `"`, so the quoted form cannot be broken out of.
        Ok(format!("\"{}\"", identifier))
    }

    /// Escapes `value` for use inside a single-quoted SQL string literal by doubling `'`.
    ///
    /// NOTE(review): backslashes are left untouched, which is only safe while the server
    /// runs with `standard_conforming_strings = on` (the PostgreSQL default).
    fn escape_sql_string(value: &str) -> String {
        value.replace('\'', "''")
    }

    /// Executes one statement built by this manager, discarding the affected-row count.
    async fn execute_sql(&self, sql: String) -> Result<()> {
        self.db
            .execute(Statement::from_string(self.db.get_database_backend(), sql))
            .await?;
        Ok(())
    }

    /// Returns the quoted name of `tenant_user`'s isolation policy.
    ///
    /// # Errors
    /// Fails when the combined name is not a valid identifier — in particular when
    /// `tenant_user` is longer than 39 bytes, which pushes the policy name past 63.
    fn tenant_policy_name(tenant_user: &str) -> Result<String> {
        Self::validate_sql_identifier(&format!("{}{}", Self::TENANT_POLICY_PREFIX, tenant_user))
    }

    /// Turns on row-level security for `table_name`.
    async fn enable_rls_for_table(&self, table_name: &str) -> Result<()> {
        let table = Self::validate_sql_identifier(table_name)?;
        self.execute_sql(format!("ALTER TABLE {} ENABLE ROW LEVEL SECURITY;", table))
            .await
    }

    /// (Re)creates `tenant_user`'s isolation policy on `table_name`.
    ///
    /// The policy applies to every command issued by that role and only admits rows
    /// whose `tenant_id` equals the username.
    async fn create_tenant_policy(&self, table_name: &str, tenant_user: &str) -> Result<()> {
        let table = Self::validate_sql_identifier(table_name)?;
        let role = Self::validate_sql_identifier(tenant_user)?;
        let policy = Self::tenant_policy_name(tenant_user)?;
        let tenant_literal = Self::escape_sql_string(tenant_user);

        // Drop first so re-running this replaces an existing policy instead of colliding
        // with it. Errors are ignored here; a genuinely broken table still fails the CREATE.
        let _ = self
            .execute_sql(format!("DROP POLICY IF EXISTS {} ON {};", policy, table))
            .await;
        self.execute_sql(format!(
            "CREATE POLICY {} ON {} FOR ALL TO {} USING (tenant_id = '{}');",
            policy, table, role, tenant_literal
        ))
        .await
    }

    /// Enables row-level security on every table in [`Self::TENANT_TABLES`].
    async fn enable_rls_for_all_tables(&self) -> Result<()> {
        for table in Self::TENANT_TABLES.iter().copied() {
            self.enable_rls_for_table(table).await?;
        }
        Ok(())
    }

    /// (Re)creates `tenant_user`'s isolation policy on every tenant table.
    async fn create_all_tenant_policies(&self, tenant_user: &str) -> Result<()> {
        for table in Self::TENANT_TABLES.iter().copied() {
            self.create_tenant_policy(table, tenant_user).await?;
        }
        Ok(())
    }

    /// Drops `tenant_user`'s isolation policy from every tenant table.
    ///
    /// Individual DROP failures are ignored (best-effort cleanup); only an invalid
    /// username / policy name is reported.
    async fn drop_all_tenant_policies(&self, tenant_user: &str) -> Result<()> {
        // The policy name depends only on the user, so validate it once.
        let policy = Self::tenant_policy_name(tenant_user)?;
        for table in Self::TENANT_TABLES.iter().copied() {
            let table = Self::validate_sql_identifier(table)?;
            let _ = self
                .execute_sql(format!("DROP POLICY IF EXISTS {} ON {};", policy, table))
                .await;
        }
        Ok(())
    }
}
#[async_trait]
impl AclManager for PostgresAclManager {
    /// Provisions a tenant. It creates the tenant's login role if missing, grants DML on
    /// every table and usage on every sequence in `public`, enables RLS on the tenant
    /// tables and (re)creates the tenant's isolation policies.
    ///
    /// An existing role is reused as-is; its password is NOT changed (see
    /// `update_tenant_user`).
    ///
    /// NOTE(review): the table grant covers every table in `public`, including any that
    /// are not protected by RLS — confirm that is intended.
    ///
    /// # Errors
    /// Fails when ACL is disabled, when the username (or the policy name derived from it)
    /// is not a valid identifier, or when any statement other than a duplicate-role
    /// `CREATE USER` fails.
    async fn create_tenant_user(&self, config: &AclConfig) -> Result<()> {
        if !config.enabled {
            return Err(Error::config("ACL feature is not enabled"));
        }
        let username = &config.node_config.username;
        let password = &config.node_config.password;
        let username_escaped = Self::validate_sql_identifier(username)?;
        // Each tenant table gets a policy named `tenant_isolation_policy_<username>`.
        // Validate that name before creating anything so that a username too long for it
        // fails cleanly instead of leaving a half-provisioned tenant behind.
        Self::validate_sql_identifier(&format!("tenant_isolation_policy_{}", username))?;
        let backend = self.db.get_database_backend();

        // PostgreSQL has no `CREATE USER IF NOT EXISTS`, and with the default
        // `log_min_error_statement = error` every failed statement — plaintext password
        // included — is written to the server log. So only issue CREATE USER when the
        // role is actually absent.
        if !self.tenant_user_exists(username).await? {
            let create_user_sql = format!(
                "CREATE USER {} WITH PASSWORD '{}';",
                username_escaped,
                Self::escape_sql_string(password)
            );
            if let Err(e) = self
                .db
                .execute(Statement::from_string(backend, create_user_sql))
                .await
            {
                // Another caller may have created the role after the existence check;
                // treat "already exists" (SQLSTATE 42710) as success.
                let error_str = e.to_string().to_lowercase();
                let is_duplicate_error = error_str.contains("already exists")
                    || error_str.contains("duplicate")
                    || error_str.contains("42710");
                if !is_duplicate_error {
                    return Err(e.into());
                }
            }
        }

        let grants = [
            format!(
                "GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO {};",
                username_escaped
            ),
            format!(
                "GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO {};",
                username_escaped
            ),
        ];
        for grant_sql in grants {
            self.db
                .execute(Statement::from_string(backend, grant_sql))
                .await?;
        }
        self.enable_rls_for_all_tables().await?;
        self.create_all_tenant_policies(username).await?;
        Ok(())
    }

    /// Removes a tenant. It drops the tenant's isolation policies, revokes the table and
    /// sequence privileges granted at creation, then drops the role if it exists.
    ///
    /// # Errors
    /// Fails when `username` (or the policy name derived from it) is not a valid
    /// identifier, or when `DROP USER` fails.
    async fn delete_tenant_user(&self, username: &str) -> Result<()> {
        let username_escaped = Self::validate_sql_identifier(username)?;
        // The policies name the role in their `TO` clause, so they must go before the role.
        self.drop_all_tenant_policies(username).await?;
        let backend = self.db.get_database_backend();

        // PostgreSQL refuses to drop a role that still holds privileges on any object, so
        // revoke both privilege kinds granted on creation: tables AND sequences.
        // Best-effort: REVOKE errors when the role does not exist, a case that
        // DROP USER IF EXISTS below already tolerates.
        let revokes = [
            format!(
                "REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {};",
                username_escaped
            ),
            format!(
                "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {};",
                username_escaped
            ),
        ];
        for revoke_sql in revokes {
            let _ = self
                .db
                .execute(Statement::from_string(backend, revoke_sql))
                .await;
        }

        let drop_sql = format!("DROP USER IF EXISTS {};", username_escaped);
        self.db
            .execute(Statement::from_string(backend, drop_sql))
            .await?;
        Ok(())
    }

    /// Lists login roles treated as tenants. That is every entry in `pg_user` except
    /// names starting with `pg_`, the `postgres` role, and the configured admin user.
    ///
    /// NOTE(review): any other non-tenant login role (e.g. the application's own
    /// connection user) is still included.
    ///
    /// # Errors
    /// Propagates query and row-decoding errors.
    async fn list_tenant_users(&self) -> Result<Vec<String>> {
        // `!~ '^pg_'` instead of `NOT LIKE 'pg_%'`: in LIKE, `_` is a single-character
        // wildcard, which would also hide tenants such as `pgadmin` or `pg1`.
        // `usename` has type `name`; casting to `text` makes it decode as a String.
        let query_sql = format!(
            "SELECT usename::text AS usename FROM pg_user \
             WHERE usename !~ '^pg_' AND usename != 'postgres' AND usename != '{}';",
            Self::escape_sql_string(&self._admin_user)
        );
        let rows = self
            .db
            .query_all(Statement::from_string(
                self.db.get_database_backend(),
                query_sql,
            ))
            .await?;
        let mut users = Vec::with_capacity(rows.len());
        for row in rows {
            users.push(row.try_get::<String>("", "usename")?);
        }
        Ok(users)
    }

    /// Reports whether a login role named exactly `username` exists.
    ///
    /// # Errors
    /// Propagates database errors instead of reporting them as "does not exist".
    async fn tenant_user_exists(&self, username: &str) -> Result<bool> {
        let query_sql = format!(
            "SELECT 1 FROM pg_user WHERE usename = '{}';",
            Self::escape_sql_string(username)
        );
        // `query_one` yields `Ok(None)` when no row matches, so the answer is whether a
        // row came back, not whether the call succeeded.
        let row = self
            .db
            .query_one(Statement::from_string(
                self.db.get_database_backend(),
                query_sql,
            ))
            .await?;
        Ok(row.is_some())
    }

    /// Changes an existing tenant's password and refreshes its isolation policies.
    ///
    /// # Errors
    /// Fails when ACL is disabled, when the username is not a valid identifier, when the
    /// role does not exist, or when any statement fails.
    async fn update_tenant_user(&self, config: &AclConfig) -> Result<()> {
        if !config.enabled {
            return Err(Error::config("ACL feature is not enabled"));
        }
        let username = &config.node_config.username;
        let password = &config.node_config.password;
        // Validate before touching the database; the check is purely local.
        let username_escaped = Self::validate_sql_identifier(username)?;
        if !self.tenant_user_exists(username).await? {
            return Err(Error::config("Tenant user does not exist"));
        }
        let alter_sql = format!(
            "ALTER USER {} WITH PASSWORD '{}';",
            username_escaped,
            Self::escape_sql_string(password)
        );
        self.db
            .execute(Statement::from_string(
                self.db.get_database_backend(),
                alter_sql,
            ))
            .await?;
        // `create_tenant_policy` drops any existing policy before creating it, so this
        // alone refreshes the policies; a separate drop pass would only repeat the DROPs.
        self.create_all_tenant_policies(username).await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_sql_identifier() {
        assert!(PostgresAclManager::validate_sql_identifier("tenant_user1").is_ok());
        assert!(PostgresAclManager::validate_sql_identifier("_user").is_ok());
        assert!(PostgresAclManager::validate_sql_identifier("User123").is_ok());
        assert!(PostgresAclManager::validate_sql_identifier("").is_err());
        assert!(PostgresAclManager::validate_sql_identifier("123user").is_err());
        assert!(PostgresAclManager::validate_sql_identifier("user-name").is_err());
        assert!(PostgresAclManager::validate_sql_identifier("user name").is_err());
        assert!(PostgresAclManager::validate_sql_identifier("user;DROP TABLE").is_err());
    }

    /// Valid identifiers come back wrapped in double quotes and otherwise unchanged.
    #[test]
    fn test_validate_sql_identifier_quotes_output() {
        assert_eq!(
            PostgresAclManager::validate_sql_identifier("tenant_user1").ok(),
            Some("\"tenant_user1\"".to_string())
        );
        assert_eq!(
            PostgresAclManager::validate_sql_identifier("User123").ok(),
            Some("\"User123\"".to_string())
        );
    }

    /// 63 bytes is the last accepted length; 64 is rejected.
    #[test]
    fn test_validate_sql_identifier_length_bounds() {
        let longest = "a".repeat(63);
        assert_eq!(
            PostgresAclManager::validate_sql_identifier(&longest).ok(),
            Some(format!("\"{}\"", longest))
        );
        assert!(PostgresAclManager::validate_sql_identifier(&"a".repeat(64)).is_err());
    }

    /// Non-ASCII letters and embedded double quotes must never reach the quoted output.
    #[test]
    fn test_validate_sql_identifier_rejects_non_ascii_and_quotes() {
        assert!(PostgresAclManager::validate_sql_identifier("usér").is_err());
        assert!(PostgresAclManager::validate_sql_identifier("user\"x").is_err());
        assert!(PostgresAclManager::validate_sql_identifier("\"user\"").is_err());
    }

    #[test]
    fn test_escape_sql_string() {
        assert_eq!(
            PostgresAclManager::escape_sql_string("password"),
            "password"
        );
        assert_eq!(
            PostgresAclManager::escape_sql_string("pass'word"),
            "pass''word"
        );
        assert_eq!(PostgresAclManager::escape_sql_string("it's"), "it''s");
        assert_eq!(PostgresAclManager::escape_sql_string(""), "");
        assert_eq!(PostgresAclManager::escape_sql_string("''"), "''''");
    }
}