use std::sync::Arc;
use async_trait::async_trait;
use datafusion::common::ParamValues;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::SessionContext;
use datafusion::sql::sqlparser::ast::Statement;
use pgwire::api::results::Response;
use pgwire::api::ClientInfo;
use pgwire::error::{PgWireError, PgWireResult};
use crate::auth::AuthManager;
use crate::hooks::HookClient;
use crate::QueryHook;
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
/// Query hook that checks whether the connected user may run a SQL statement,
/// delegating the decision to a shared [`AuthManager`].
#[derive(Debug)]
pub struct PermissionsHook {
// Shared handle to the authorization manager consulted for every permission check.
auth_manager: Arc<AuthManager>,
}
impl PermissionsHook {
pub fn new(auth_manager: Arc<AuthManager>) -> Self {
PermissionsHook { auth_manager }
}
async fn check_statement_permission<C>(
&self,
client: &C,
statement: &Statement,
) -> PgWireResult<()>
where
C: ClientInfo + ?Sized,
{
let username = client
.metadata()
.get("user")
.map(|s| s.as_str())
.unwrap_or("anonymous");
let (required_permission, resource) = match statement {
Statement::Query(_) => (Permission::Select, ResourceType::All),
Statement::Insert(_) => (Permission::Insert, ResourceType::All),
Statement::Update { .. } => (Permission::Update, ResourceType::All),
Statement::Delete(_) => (Permission::Delete, ResourceType::All),
Statement::CreateTable { .. } | Statement::CreateView { .. } => {
(Permission::Create, ResourceType::All)
}
Statement::Drop { .. } => (Permission::Drop, ResourceType::All),
Statement::AlterTable { .. } => (Permission::Alter, ResourceType::All),
_ => return Ok(()),
};
let has_permission = self
.auth_manager
.check_permission(username, required_permission, resource)
.await;
if !has_permission {
return Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42501".to_string(), format!("permission denied for user \"{username}\""),
),
)));
}
Ok(())
}
fn should_skip_permission_check(statement: &Statement) -> bool {
matches!(
statement,
Statement::Set { .. }
| Statement::ShowVariable { .. }
| Statement::ShowStatus { .. }
| Statement::StartTransaction { .. }
| Statement::Commit { .. }
| Statement::Rollback { .. }
| Statement::Savepoint { .. }
| Statement::ReleaseSavepoint { .. }
)
}
}
#[async_trait]
impl QueryHook for PermissionsHook {
    /// Permission gate for the simple query protocol.
    ///
    /// Returns `Some(Err(..))` when the permission check fails and `None`
    /// otherwise, both for exempt statements and for permitted ones.
    /// NOTE(review): `None` presumably tells the caller to continue with its
    /// normal handling; confirm against the `QueryHook` contract.
    async fn handle_simple_query(
        &self,
        statement: &Statement,
        _session_context: &SessionContext,
        client: &mut dyn HookClient,
    ) -> Option<PgWireResult<Response>> {
        if Self::should_skip_permission_check(statement) {
            return None;
        }

        // Reborrow immutably: the check only reads client metadata.
        match self.check_statement_permission(&*client, statement).await {
            Ok(()) => None,
            Err(denied) => Some(Err(denied)),
        }
    }

    /// Parse-phase hook of the extended protocol; this hook does nothing here
    /// and always returns `None`.
    async fn handle_extended_parse_query(
        &self,
        _stmt: &Statement,
        _session_context: &SessionContext,
        _client: &(dyn ClientInfo + Send + Sync),
    ) -> Option<PgWireResult<LogicalPlan>> {
        None
    }

    /// Permission gate for the extended query protocol.
    ///
    /// Applies the same logic as `handle_simple_query`: exempt statements and
    /// permitted statements yield `None`, a failed check yields
    /// `Some(Err(..))`. The logical plan and bound parameters are not
    /// inspected.
    async fn handle_extended_query(
        &self,
        statement: &Statement,
        _logical_plan: &LogicalPlan,
        _params: &ParamValues,
        _session_context: &SessionContext,
        client: &mut dyn HookClient,
    ) -> Option<PgWireResult<Response>> {
        if Self::should_skip_permission_check(statement) {
            return None;
        }

        match self.check_statement_permission(&*client, statement).await {
            Ok(()) => None,
            Err(denied) => Some(Err(denied)),
        }
    }
}