datafusion_postgres/hooks/
permissions.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::common::ParamValues;
5use datafusion::logical_expr::LogicalPlan;
6use datafusion::prelude::SessionContext;
7use datafusion::sql::sqlparser::ast::Statement;
8use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
9use pgwire::api::results::Response;
10use pgwire::api::ClientInfo;
11use pgwire::error::{PgWireError, PgWireResult};
12
13use crate::auth::AuthManager;
14use crate::QueryHook;
15
16#[derive(Debug)]
17pub struct PermissionsHook {
18    auth_manager: Arc<AuthManager>,
19}
20
21impl PermissionsHook {
22    pub fn new(auth_manager: Arc<AuthManager>) -> Self {
23        PermissionsHook { auth_manager }
24    }
25
26    /// Check if the current user has permission to execute a statement
27    async fn check_statement_permission<C>(
28        &self,
29        client: &C,
30        statement: &Statement,
31    ) -> PgWireResult<()>
32    where
33        C: ClientInfo + ?Sized,
34    {
35        // Get the username from client metadata
36        let username = client
37            .metadata()
38            .get("user")
39            .map(|s| s.as_str())
40            .unwrap_or("anonymous");
41
42        // Determine required permissions based on Statement type
43        let (required_permission, resource) = match statement {
44            Statement::Query(_) => (Permission::Select, ResourceType::All),
45            Statement::Insert(_) => (Permission::Insert, ResourceType::All),
46            Statement::Update { .. } => (Permission::Update, ResourceType::All),
47            Statement::Delete(_) => (Permission::Delete, ResourceType::All),
48            Statement::CreateTable { .. } | Statement::CreateView { .. } => {
49                (Permission::Create, ResourceType::All)
50            }
51            Statement::Drop { .. } => (Permission::Drop, ResourceType::All),
52            Statement::AlterTable { .. } => (Permission::Alter, ResourceType::All),
53            // For other statements (SET, SHOW, EXPLAIN, transactions, etc.), allow all users
54            _ => return Ok(()),
55        };
56
57        // Check permission
58        let has_permission = self
59            .auth_manager
60            .check_permission(username, required_permission, resource)
61            .await;
62
63        if !has_permission {
64            return Err(PgWireError::UserError(Box::new(
65                pgwire::error::ErrorInfo::new(
66                    "ERROR".to_string(),
67                    "42501".to_string(), // insufficient_privilege
68                    format!("permission denied for user \"{username}\""),
69                ),
70            )));
71        }
72
73        Ok(())
74    }
75
76    /// Check if a statement should skip permission checks
77    fn should_skip_permission_check(statement: &Statement) -> bool {
78        matches!(
79            statement,
80            Statement::Set { .. }
81                | Statement::ShowVariable { .. }
82                | Statement::ShowStatus { .. }
83                | Statement::StartTransaction { .. }
84                | Statement::Commit { .. }
85                | Statement::Rollback { .. }
86                | Statement::Savepoint { .. }
87                | Statement::ReleaseSavepoint { .. }
88        )
89    }
90}
91
92#[async_trait]
93impl QueryHook for PermissionsHook {
94    /// called in simple query handler to return response directly
95    async fn handle_simple_query(
96        &self,
97        statement: &Statement,
98        _session_context: &SessionContext,
99        client: &mut (dyn ClientInfo + Send + Sync),
100    ) -> Option<PgWireResult<Response>> {
101        if Self::should_skip_permission_check(statement) {
102            return None;
103        }
104
105        // Check permissions for other statements
106        if let Err(e) = self.check_statement_permission(&*client, statement).await {
107            return Some(Err(e));
108        }
109
110        None
111    }
112
113    async fn handle_extended_parse_query(
114        &self,
115        _stmt: &Statement,
116        _session_context: &SessionContext,
117        _client: &(dyn ClientInfo + Send + Sync),
118    ) -> Option<PgWireResult<LogicalPlan>> {
119        None
120    }
121
122    async fn handle_extended_query(
123        &self,
124        statement: &Statement,
125        _logical_plan: &LogicalPlan,
126        _params: &ParamValues,
127        _session_context: &SessionContext,
128        client: &mut (dyn ClientInfo + Send + Sync),
129    ) -> Option<PgWireResult<Response>> {
130        if Self::should_skip_permission_check(statement) {
131            return None;
132        }
133
134        // Check permissions for other statements
135        if let Err(e) = self.check_statement_permission(&*client, statement).await {
136            return Some(Err(e));
137        }
138
139        None
140    }
141}