Skip to main content

datafusion_postgres/hooks/permissions.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::common::ParamValues;
5use datafusion::logical_expr::LogicalPlan;
6use datafusion::prelude::SessionContext;
7use datafusion::sql::sqlparser::ast::Statement;
8use pgwire::api::results::Response;
9use pgwire::api::ClientInfo;
10use pgwire::error::{PgWireError, PgWireResult};
11
12use crate::auth::AuthManager;
13use crate::hooks::HookClient;
14use crate::QueryHook;
15
16use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
17
18#[derive(Debug)]
19pub struct PermissionsHook {
20    auth_manager: Arc<AuthManager>,
21}
22
23impl PermissionsHook {
24    pub fn new(auth_manager: Arc<AuthManager>) -> Self {
25        PermissionsHook { auth_manager }
26    }
27
28    /// Check if the current user has permission to execute a statement
29    async fn check_statement_permission<C>(
30        &self,
31        client: &C,
32        statement: &Statement,
33    ) -> PgWireResult<()>
34    where
35        C: ClientInfo + ?Sized,
36    {
37        // Get the username from client metadata
38        let username = client
39            .metadata()
40            .get("user")
41            .map(|s| s.as_str())
42            .unwrap_or("anonymous");
43
44        // Determine required permissions based on Statement type
45        let (required_permission, resource) = match statement {
46            Statement::Query(_) => (Permission::Select, ResourceType::All),
47            Statement::Insert(_) => (Permission::Insert, ResourceType::All),
48            Statement::Update { .. } => (Permission::Update, ResourceType::All),
49            Statement::Delete(_) => (Permission::Delete, ResourceType::All),
50            Statement::CreateTable { .. } | Statement::CreateView { .. } => {
51                (Permission::Create, ResourceType::All)
52            }
53            Statement::Drop { .. } => (Permission::Drop, ResourceType::All),
54            Statement::AlterTable { .. } => (Permission::Alter, ResourceType::All),
55            // For other statements (SET, SHOW, EXPLAIN, transactions, etc.), allow all users
56            _ => return Ok(()),
57        };
58
59        // Check permission
60        let has_permission = self
61            .auth_manager
62            .check_permission(username, required_permission, resource)
63            .await;
64
65        if !has_permission {
66            return Err(PgWireError::UserError(Box::new(
67                pgwire::error::ErrorInfo::new(
68                    "ERROR".to_string(),
69                    "42501".to_string(), // insufficient_privilege
70                    format!("permission denied for user \"{username}\""),
71                ),
72            )));
73        }
74
75        Ok(())
76    }
77
78    /// Check if a statement should skip permission checks
79    fn should_skip_permission_check(statement: &Statement) -> bool {
80        matches!(
81            statement,
82            Statement::Set { .. }
83                | Statement::ShowVariable { .. }
84                | Statement::ShowStatus { .. }
85                | Statement::StartTransaction { .. }
86                | Statement::Commit { .. }
87                | Statement::Rollback { .. }
88                | Statement::Savepoint { .. }
89                | Statement::ReleaseSavepoint { .. }
90        )
91    }
92}
93
94#[async_trait]
95impl QueryHook for PermissionsHook {
96    /// called in simple query handler to return response directly
97    async fn handle_simple_query(
98        &self,
99        statement: &Statement,
100        _session_context: &SessionContext,
101        client: &mut dyn HookClient,
102    ) -> Option<PgWireResult<Response>> {
103        if Self::should_skip_permission_check(statement) {
104            return None;
105        }
106
107        // Check permissions for other statements
108        if let Err(e) = self.check_statement_permission(&*client, statement).await {
109            return Some(Err(e));
110        }
111
112        None
113    }
114
115    async fn handle_extended_parse_query(
116        &self,
117        _stmt: &Statement,
118        _session_context: &SessionContext,
119        _client: &(dyn ClientInfo + Send + Sync),
120    ) -> Option<PgWireResult<LogicalPlan>> {
121        None
122    }
123
124    async fn handle_extended_query(
125        &self,
126        statement: &Statement,
127        _logical_plan: &LogicalPlan,
128        _params: &ParamValues,
129        _session_context: &SessionContext,
130        client: &mut dyn HookClient,
131    ) -> Option<PgWireResult<Response>> {
132        if Self::should_skip_permission_check(statement) {
133            return None;
134        }
135
136        // Check permissions for other statements
137        if let Err(e) = self.check_statement_permission(&*client, statement).await {
138            return Some(Err(e));
139        }
140
141        None
142    }
143}