use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use async_trait::async_trait;
use datafusion_pg_catalog::pg_catalog::context::*;
use pgwire::api::auth::{AuthSource, LoginInfo, Password};
use pgwire::error::{PgWireError, PgWireResult};
use tokio::sync::RwLock;
/// In-memory registry of users and roles used for authentication and permission checks.
///
/// Both maps live behind `Arc<RwLock<..>>`, so every clone of an `AuthManager` shares the
/// same underlying users and roles and observes the others' changes.
#[derive(Clone, Debug)]
pub struct AuthManager {
    /// Registered users, keyed by `User::username`.
    users: Arc<RwLock<HashMap<String, User>>>,
    /// Registered roles, keyed by `Role::name`.
    roles: Arc<RwLock<HashMap<String, Role>>>,
}
impl Default for AuthManager {
fn default() -> Self {
Self::new()
}
}
impl AuthManager {
pub fn new() -> Self {
let mut users = HashMap::new();
let postgres_user = User {
username: "postgres".to_string(),
password_hash: "".to_string(), roles: vec!["postgres".to_string()],
is_superuser: true,
can_login: true,
connection_limit: None,
};
users.insert(postgres_user.username.clone(), postgres_user);
let mut roles = HashMap::new();
let postgres_role = Role {
name: "postgres".to_string(),
is_superuser: true,
can_login: true,
can_create_db: true,
can_create_role: true,
can_create_user: true,
can_replication: true,
grants: vec![Grant {
permission: Permission::All,
resource: ResourceType::All,
granted_by: "system".to_string(),
with_grant_option: true,
}],
inherited_roles: vec![],
};
roles.insert(postgres_role.name.clone(), postgres_role);
AuthManager {
users: Arc::new(RwLock::new(users)),
roles: Arc::new(RwLock::new(roles)),
}
}
pub async fn add_user(&self, user: User) -> PgWireResult<()> {
let mut users = self.users.write().await;
users.insert(user.username.clone(), user);
Ok(())
}
pub async fn add_role(&self, role: Role) -> PgWireResult<()> {
let mut roles = self.roles.write().await;
roles.insert(role.name.clone(), role);
Ok(())
}
pub async fn authenticate(&self, username: &str, password: &str) -> PgWireResult<bool> {
let users = self.users.read().await;
if let Some(user) = users.get(username) {
if !user.can_login {
return Ok(false);
}
if user.password_hash.is_empty() || password == user.password_hash {
return Ok(true);
}
}
Ok(false)
}
pub async fn get_user(&self, username: &str) -> Option<User> {
let users = self.users.read().await;
users.get(username).cloned()
}
pub async fn get_role(&self, role_name: &str) -> Option<Role> {
let roles = self.roles.read().await;
roles.get(role_name).cloned()
}
pub async fn user_has_role(&self, username: &str, role_name: &str) -> bool {
if let Some(user) = self.get_user(username).await {
return user.roles.contains(&role_name.to_string()) || user.is_superuser;
}
false
}
pub async fn list_users(&self) -> Vec<String> {
let users = self.users.read().await;
users.keys().cloned().collect()
}
pub async fn list_roles(&self) -> Vec<String> {
let roles = self.roles.read().await;
roles.keys().cloned().collect()
}
pub async fn grant_permission(
&self,
role_name: &str,
permission: Permission,
resource: ResourceType,
granted_by: &str,
with_grant_option: bool,
) -> PgWireResult<()> {
let mut roles = self.roles.write().await;
if let Some(role) = roles.get_mut(role_name) {
let grant = Grant {
permission,
resource,
granted_by: granted_by.to_string(),
with_grant_option,
};
role.grants.push(grant);
Ok(())
} else {
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42704".to_string(), format!("role \"{role_name}\" does not exist"),
),
)))
}
}
pub async fn revoke_permission(
&self,
role_name: &str,
permission: Permission,
resource: ResourceType,
) -> PgWireResult<()> {
let mut roles = self.roles.write().await;
if let Some(role) = roles.get_mut(role_name) {
role.grants
.retain(|grant| !(grant.permission == permission && grant.resource == resource));
Ok(())
} else {
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42704".to_string(), format!("role \"{role_name}\" does not exist"),
),
)))
}
}
pub async fn check_permission(
&self,
username: &str,
permission: Permission,
resource: ResourceType,
) -> bool {
if let Some(user) = self.get_user(username).await {
if user.is_superuser {
return true;
}
for role_name in &user.roles {
if let Some(role) = self.get_role(role_name).await {
if role.is_superuser {
return true;
}
for grant in &role.grants {
if self.permission_matches(&grant.permission, &permission)
&& self.resource_matches(&grant.resource, &resource)
{
return true;
}
}
for inherited_role in &role.inherited_roles {
if self
.check_role_permission(inherited_role, &permission, &resource)
.await
{
return true;
}
}
}
}
}
false
}
fn check_role_permission<'a>(
&'a self,
role_name: &'a str,
permission: &'a Permission,
resource: &'a ResourceType,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send + 'a>> {
Box::pin(async move {
if let Some(role) = self.get_role(role_name).await {
if role.is_superuser {
return true;
}
for grant in &role.grants {
if self.permission_matches(&grant.permission, permission)
&& self.resource_matches(&grant.resource, resource)
{
return true;
}
}
for inherited_role in &role.inherited_roles {
if self
.check_role_permission(inherited_role, permission, resource)
.await
{
return true;
}
}
}
false
})
}
fn permission_matches(&self, grant_permission: &Permission, requested: &Permission) -> bool {
grant_permission == requested || matches!(grant_permission, Permission::All)
}
fn resource_matches(&self, grant_resource: &ResourceType, requested: &ResourceType) -> bool {
match (grant_resource, requested) {
(a, b) if a == b => true,
(ResourceType::All, _) => true,
(ResourceType::Schema(schema), ResourceType::Table(table)) => {
table.starts_with(&format!("{schema}."))
}
_ => false,
}
}
pub async fn add_role_inheritance(
&self,
child_role: &str,
parent_role: &str,
) -> PgWireResult<()> {
let mut roles = self.roles.write().await;
if let Some(child) = roles.get_mut(child_role) {
if !child.inherited_roles.contains(&parent_role.to_string()) {
child.inherited_roles.push(parent_role.to_string());
}
Ok(())
} else {
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42704".to_string(), format!("role \"{child_role}\" does not exist"),
),
)))
}
}
pub async fn remove_role_inheritance(
&self,
child_role: &str,
parent_role: &str,
) -> PgWireResult<()> {
let mut roles = self.roles.write().await;
if let Some(child) = roles.get_mut(child_role) {
child.inherited_roles.retain(|role| role != parent_role);
Ok(())
} else {
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42704".to_string(), format!("role \"{child_role}\" does not exist"),
),
)))
}
}
pub async fn create_role(&self, config: RoleConfig) -> PgWireResult<()> {
let role = Role {
name: config.name.clone(),
is_superuser: config.is_superuser,
can_login: config.can_login,
can_create_db: config.can_create_db,
can_create_role: config.can_create_role,
can_create_user: config.can_create_user,
can_replication: config.can_replication,
grants: vec![],
inherited_roles: vec![],
};
self.add_role(role).await
}
pub async fn create_predefined_roles(&self) -> PgWireResult<()> {
self.create_role(RoleConfig {
name: "readonly".to_string(),
is_superuser: false,
can_login: false,
can_create_db: false,
can_create_role: false,
can_create_user: false,
can_replication: false,
})
.await?;
self.grant_permission(
"readonly",
Permission::Select,
ResourceType::All,
"system",
false,
)
.await?;
self.create_role(RoleConfig {
name: "readwrite".to_string(),
is_superuser: false,
can_login: false,
can_create_db: false,
can_create_role: false,
can_create_user: false,
can_replication: false,
})
.await?;
self.grant_permission(
"readwrite",
Permission::Select,
ResourceType::All,
"system",
false,
)
.await?;
self.grant_permission(
"readwrite",
Permission::Insert,
ResourceType::All,
"system",
false,
)
.await?;
self.grant_permission(
"readwrite",
Permission::Update,
ResourceType::All,
"system",
false,
)
.await?;
self.grant_permission(
"readwrite",
Permission::Delete,
ResourceType::All,
"system",
false,
)
.await?;
self.create_role(RoleConfig {
name: "dbadmin".to_string(),
is_superuser: false,
can_login: true,
can_create_db: true,
can_create_role: false,
can_create_user: false,
can_replication: false,
})
.await?;
self.grant_permission(
"dbadmin",
Permission::All,
ResourceType::All,
"system",
true,
)
.await?;
Ok(())
}
}
#[async_trait]
impl PgCatalogContextProvider for AuthManager {
async fn roles(&self) -> Vec<String> {
self.list_roles().await
}
async fn role(&self, name: &str) -> Option<Role> {
self.get_role(name).await
}
}
/// pgwire [`AuthSource`] backed by an [`AuthManager`]; see its `get_password` implementation.
#[derive(Debug, Clone)]
pub struct DfAuthSource {
    /// Registry consulted on every login attempt.
    pub auth_manager: Arc<AuthManager>,
}
impl DfAuthSource {
pub fn new(auth_manager: Arc<AuthManager>) -> Self {
DfAuthSource { auth_manager }
}
}
#[async_trait]
impl AuthSource for DfAuthSource {
async fn get_password(&self, login: &LoginInfo) -> PgWireResult<Password> {
if let Some(username) = login.user() {
if let Some(user) = self.auth_manager.get_user(username).await {
if user.can_login {
Ok(Password::new(None, user.password_hash.into_bytes()))
} else {
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"FATAL".to_string(),
"28000".to_string(), format!("User \"{username}\" is not allowed to login"),
),
)))
}
} else {
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"FATAL".to_string(),
"28P01".to_string(), format!("password authentication failed for user \"{username}\""),
),
)))
}
} else {
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"FATAL".to_string(),
"28P01".to_string(), "No username provided in login request".to_string(),
),
)))
}
}
}
/// pgwire [`AuthSource`] that answers every permitted login with an empty password.
#[derive(Debug)]
pub struct SimpleAuthSource {
    /// Registry used to decide whether a user may log in.
    auth_manager: Arc<AuthManager>,
}
impl SimpleAuthSource {
pub fn new(auth_manager: Arc<AuthManager>) -> Self {
SimpleAuthSource { auth_manager }
}
}
#[async_trait]
impl AuthSource for SimpleAuthSource {
    /// Returns an empty password for users permitted to log in.
    ///
    /// A login without a username is treated as user `"anonymous"`:
    /// - A registered user is admitted exactly when its `can_login` flag is set.
    /// - The name `postgres` is additionally admitted when no such user is registered.
    ///
    /// # Errors
    /// `FATAL` / `28P01` ("password authentication failed") for anyone else.
    async fn get_password(&self, login: &LoginInfo) -> PgWireResult<Password> {
        let username = login.user().unwrap_or("anonymous");

        // The `postgres` fallback applies only to an unregistered name. A registered
        // `postgres` user with `can_login == false` must be refused, not let in by name.
        let allowed = match self.auth_manager.get_user(username).await {
            Some(user) => user.can_login,
            None => username == "postgres",
        };

        if allowed {
            return Ok(Password::new(None, vec![]));
        }

        Err(PgWireError::UserError(Box::new(
            pgwire::error::ErrorInfo::new(
                "FATAL".to_string(),
                "28P01".to_string(),
                format!("password authentication failed for user \"{username}\""),
            ),
        )))
    }
}
pub fn create_auth_source(auth_manager: Arc<AuthManager>) -> SimpleAuthSource {
SimpleAuthSource::new(auth_manager)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A non-superuser that can log in, with an empty password and the given direct roles.
    fn login_user(username: &str, roles: &[&str]) -> User {
        User {
            username: username.to_string(),
            password_hash: String::new(),
            roles: roles.iter().map(|role| role.to_string()).collect(),
            is_superuser: false,
            can_login: true,
            connection_limit: None,
        }
    }

    /// A role config with every capability flag turned off.
    fn plain_role(name: &str) -> RoleConfig {
        RoleConfig {
            name: name.to_string(),
            is_superuser: false,
            can_login: false,
            can_create_db: false,
            can_create_role: false,
            can_create_user: false,
            can_replication: false,
        }
    }

    // `AuthManager::new` is synchronous, so no tests need to sleep before using it.

    #[tokio::test]
    async fn test_auth_manager_creation() {
        let auth_manager = AuthManager::new();
        let users = auth_manager.list_users().await;
        assert!(users.contains(&"postgres".to_string()));
    }

    #[tokio::test]
    async fn test_user_authentication() {
        let auth_manager = AuthManager::new();
        assert!(auth_manager.authenticate("postgres", "").await.unwrap());
        assert!(!auth_manager
            .authenticate("nonexistent", "password")
            .await
            .unwrap());
    }

    #[tokio::test]
    async fn test_role_management() {
        let auth_manager = AuthManager::new();
        assert!(auth_manager.user_has_role("postgres", "postgres").await);
        // Superusers are reported as holding every role.
        assert!(auth_manager.user_has_role("postgres", "any_role").await);
    }

    #[tokio::test]
    async fn test_predefined_role_permissions() {
        let auth_manager = AuthManager::new();
        auth_manager.create_predefined_roles().await.unwrap();
        auth_manager
            .add_user(login_user("reader", &["readonly"]))
            .await
            .unwrap();

        assert!(
            auth_manager
                .check_permission("reader", Permission::Select, ResourceType::All)
                .await
        );
        assert!(
            !auth_manager
                .check_permission("reader", Permission::Insert, ResourceType::All)
                .await
        );
        assert!(
            !auth_manager
                .check_permission("nobody", Permission::Select, ResourceType::All)
                .await
        );
    }

    #[tokio::test]
    async fn test_inherited_and_revoked_permissions() {
        let auth_manager = AuthManager::new();
        auth_manager.create_predefined_roles().await.unwrap();
        auth_manager.create_role(plain_role("app")).await.unwrap();
        auth_manager
            .add_role_inheritance("app", "readwrite")
            .await
            .unwrap();
        auth_manager
            .add_user(login_user("svc", &["app"]))
            .await
            .unwrap();

        assert!(
            auth_manager
                .check_permission("svc", Permission::Insert, ResourceType::All)
                .await
        );

        auth_manager
            .revoke_permission("readwrite", Permission::Insert, ResourceType::All)
            .await
            .unwrap();
        assert!(
            !auth_manager
                .check_permission("svc", Permission::Insert, ResourceType::All)
                .await
        );
        assert!(
            auth_manager
                .check_permission("svc", Permission::Select, ResourceType::All)
                .await
        );

        auth_manager
            .remove_role_inheritance("app", "readwrite")
            .await
            .unwrap();
        assert!(
            !auth_manager
                .check_permission("svc", Permission::Select, ResourceType::All)
                .await
        );
    }

    #[tokio::test]
    async fn test_unknown_role_errors() {
        let auth_manager = AuthManager::new();
        assert!(auth_manager
            .grant_permission(
                "missing",
                Permission::Select,
                ResourceType::All,
                "system",
                false
            )
            .await
            .is_err());
        assert!(auth_manager
            .revoke_permission("missing", Permission::Select, ResourceType::All)
            .await
            .is_err());
        assert!(auth_manager
            .add_role_inheritance("missing", "postgres")
            .await
            .is_err());
        assert!(auth_manager
            .remove_role_inheritance("missing", "postgres")
            .await
            .is_err());
    }
}