datafusion_postgres/
handlers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use crate::auth::{AuthManager, Permission, ResourceType};
5use async_trait::async_trait;
6use datafusion::arrow::datatypes::DataType;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::*;
9use pgwire::api::auth::noop::NoopStartupHandler;
10use pgwire::api::auth::StartupHandler;
11use pgwire::api::portal::{Format, Portal};
12use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
13use pgwire::api::results::{
14    DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
15    Response, Tag,
16};
17use pgwire::api::stmt::QueryParser;
18use pgwire::api::stmt::StoredStatement;
19use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
20use pgwire::error::{PgWireError, PgWireResult};
21use tokio::sync::Mutex;
22
23use arrow_pg::datatypes::df;
24use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
25
/// Emulated state of a PostgreSQL transaction block for one session.
///
/// The server has no real transactions; this state only drives the responses to
/// `BEGIN`/`COMMIT`/`ROLLBACK` and the "current transaction is aborted" guard.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionState {
    /// No explicit transaction block is open.
    None,
    /// A block was opened with `BEGIN` (or an equivalent) and has not ended yet.
    Active,
    /// A statement failed inside an open block; other commands are rejected until
    /// the block is ended with `COMMIT` or `ROLLBACK`.
    Failed,
}
32
33/// Simple startup handler that does no authentication
34/// For production, use DfAuthSource with proper pgwire authentication handlers
35pub struct SimpleStartupHandler;
36
37#[async_trait::async_trait]
38impl NoopStartupHandler for SimpleStartupHandler {}
39
40pub struct HandlerFactory {
41    pub session_service: Arc<DfSessionService>,
42}
43
44impl HandlerFactory {
45    pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
46        let session_service =
47            Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
48        HandlerFactory { session_service }
49    }
50}
51
52impl PgWireServerHandlers for HandlerFactory {
53    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
54        self.session_service.clone()
55    }
56
57    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
58        self.session_service.clone()
59    }
60
61    fn startup_handler(&self) -> Arc<impl StartupHandler> {
62        Arc::new(SimpleStartupHandler)
63    }
64}
65
66/// The pgwire handler backed by a datafusion `SessionContext`
67pub struct DfSessionService {
68    session_context: Arc<SessionContext>,
69    parser: Arc<Parser>,
70    timezone: Arc<Mutex<String>>,
71    transaction_state: Arc<Mutex<TransactionState>>,
72    auth_manager: Arc<AuthManager>,
73}
74
75impl DfSessionService {
76    pub fn new(
77        session_context: Arc<SessionContext>,
78        auth_manager: Arc<AuthManager>,
79    ) -> DfSessionService {
80        let parser = Arc::new(Parser {
81            session_context: session_context.clone(),
82        });
83        DfSessionService {
84            session_context,
85            parser,
86            timezone: Arc::new(Mutex::new("UTC".to_string())),
87            transaction_state: Arc::new(Mutex::new(TransactionState::None)),
88            auth_manager,
89        }
90    }
91
92    /// Check if the current user has permission to execute a query
93    async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
94    where
95        C: ClientInfo,
96    {
97        // Get the username from client metadata
98        let username = client
99            .metadata()
100            .get("user")
101            .map(|s| s.as_str())
102            .unwrap_or("anonymous");
103
104        // Parse query to determine required permissions
105        let query_lower = query.to_lowercase();
106        let query_trimmed = query_lower.trim();
107
108        let (required_permission, resource) = if query_trimmed.starts_with("select") {
109            (Permission::Select, self.extract_table_from_query(query))
110        } else if query_trimmed.starts_with("insert") {
111            (Permission::Insert, self.extract_table_from_query(query))
112        } else if query_trimmed.starts_with("update") {
113            (Permission::Update, self.extract_table_from_query(query))
114        } else if query_trimmed.starts_with("delete") {
115            (Permission::Delete, self.extract_table_from_query(query))
116        } else if query_trimmed.starts_with("create table")
117            || query_trimmed.starts_with("create view")
118        {
119            (Permission::Create, ResourceType::All)
120        } else if query_trimmed.starts_with("drop") {
121            (Permission::Drop, self.extract_table_from_query(query))
122        } else if query_trimmed.starts_with("alter") {
123            (Permission::Alter, self.extract_table_from_query(query))
124        } else {
125            // For other queries (SHOW, EXPLAIN, etc.), allow all users
126            return Ok(());
127        };
128
129        // Check permission
130        let has_permission = self
131            .auth_manager
132            .check_permission(username, required_permission, resource)
133            .await;
134
135        if !has_permission {
136            return Err(PgWireError::UserError(Box::new(
137                pgwire::error::ErrorInfo::new(
138                    "ERROR".to_string(),
139                    "42501".to_string(), // insufficient_privilege
140                    format!("permission denied for user \"{username}\""),
141                ),
142            )));
143        }
144
145        Ok(())
146    }
147
148    /// Extract table name from query (simplified parsing)
149    fn extract_table_from_query(&self, query: &str) -> ResourceType {
150        let words: Vec<&str> = query.split_whitespace().collect();
151
152        // Simple heuristic to find table names
153        for (i, word) in words.iter().enumerate() {
154            let word_lower = word.to_lowercase();
155            if (word_lower == "from" || word_lower == "into" || word_lower == "table")
156                && i + 1 < words.len()
157            {
158                let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
159                return ResourceType::Table(table_name.to_string());
160            }
161        }
162
163        // If we can't determine the table, default to All
164        ResourceType::All
165    }
166
167    fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult<QueryResponse<'a>> {
168        let fields = vec![FieldInfo::new(
169            name.to_string(),
170            None,
171            None,
172            Type::VARCHAR,
173            FieldFormat::Text,
174        )];
175
176        let row = {
177            let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
178            encoder.encode_field(&Some(value))?;
179            encoder.finish()
180        };
181
182        let row_stream = futures::stream::once(async move { row });
183        Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
184    }
185
186    async fn try_respond_set_statements<'a>(
187        &self,
188        query_lower: &str,
189    ) -> PgWireResult<Option<Response<'a>>> {
190        if query_lower.starts_with("set") {
191            if query_lower.starts_with("set time zone") {
192                let parts: Vec<&str> = query_lower.split_whitespace().collect();
193                if parts.len() >= 4 {
194                    let tz = parts[3].trim_matches('"');
195                    let mut timezone = self.timezone.lock().await;
196                    *timezone = tz.to_string();
197                    Ok(Some(Response::Execution(Tag::new("SET"))))
198                } else {
199                    Err(PgWireError::UserError(Box::new(
200                        pgwire::error::ErrorInfo::new(
201                            "ERROR".to_string(),
202                            "42601".to_string(),
203                            "Invalid SET TIME ZONE syntax".to_string(),
204                        ),
205                    )))
206                }
207            } else {
208                // noop: skip any unsupported set statements
209                Ok(Some(Response::Execution(Tag::new("SET"))))
210            }
211        } else {
212            Ok(None)
213        }
214    }
215
216    async fn try_respond_transaction_statements<'a>(
217        &self,
218        query_lower: &str,
219    ) -> PgWireResult<Option<Response<'a>>> {
220        // Transaction handling based on pgwire example:
221        // https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
222        match query_lower.trim() {
223            "begin" | "begin transaction" | "begin work" | "start transaction" => {
224                let mut state = self.transaction_state.lock().await;
225                match *state {
226                    TransactionState::None => {
227                        *state = TransactionState::Active;
228                        Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
229                    }
230                    TransactionState::Active => {
231                        // Already in transaction, PostgreSQL allows this but issues a warning
232                        // For simplicity, we'll just return BEGIN again
233                        Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
234                    }
235                    TransactionState::Failed => {
236                        // Can't start new transaction from failed state
237                        Err(PgWireError::UserError(Box::new(
238                            pgwire::error::ErrorInfo::new(
239                                "ERROR".to_string(),
240                                "25P01".to_string(),
241                                "current transaction is aborted, commands ignored until end of transaction block".to_string(),
242                            ),
243                        )))
244                    }
245                }
246            }
247            "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
248                let mut state = self.transaction_state.lock().await;
249                match *state {
250                    TransactionState::Active => {
251                        *state = TransactionState::None;
252                        Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
253                    }
254                    TransactionState::None => {
255                        // PostgreSQL allows COMMIT outside transaction with warning
256                        Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
257                    }
258                    TransactionState::Failed => {
259                        // COMMIT in failed transaction is treated as ROLLBACK
260                        *state = TransactionState::None;
261                        Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
262                    }
263                }
264            }
265            "rollback" | "rollback transaction" | "rollback work" | "abort" => {
266                let mut state = self.transaction_state.lock().await;
267                *state = TransactionState::None;
268                Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
269            }
270            _ => Ok(None),
271        }
272    }
273
274    async fn try_respond_show_statements<'a>(
275        &self,
276        query_lower: &str,
277    ) -> PgWireResult<Option<Response<'a>>> {
278        if query_lower.starts_with("show ") {
279            match query_lower.strip_suffix(";").unwrap_or(query_lower) {
280                "show time zone" => {
281                    let timezone = self.timezone.lock().await.clone();
282                    let resp = Self::mock_show_response("TimeZone", &timezone)?;
283                    Ok(Some(Response::Query(resp)))
284                }
285                "show server_version" => {
286                    let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
287                    Ok(Some(Response::Query(resp)))
288                }
289                "show transaction_isolation" => {
290                    let resp =
291                        Self::mock_show_response("transaction_isolation", "read uncommitted")?;
292                    Ok(Some(Response::Query(resp)))
293                }
294                "show catalogs" => {
295                    let catalogs = self.session_context.catalog_names();
296                    let value = catalogs.join(", ");
297                    let resp = Self::mock_show_response("Catalogs", &value)?;
298                    Ok(Some(Response::Query(resp)))
299                }
300                "show search_path" => {
301                    let default_catalog = "datafusion";
302                    let resp = Self::mock_show_response("search_path", default_catalog)?;
303                    Ok(Some(Response::Query(resp)))
304                }
305                _ => Err(PgWireError::UserError(Box::new(
306                    pgwire::error::ErrorInfo::new(
307                        "ERROR".to_string(),
308                        "42704".to_string(),
309                        format!("Unrecognized SHOW command: {query_lower}"),
310                    ),
311                ))),
312            }
313        } else {
314            Ok(None)
315        }
316    }
317}
318
319#[async_trait]
320impl SimpleQueryHandler for DfSessionService {
321    async fn do_query<'a, C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
322    where
323        C: ClientInfo + Unpin + Send + Sync,
324    {
325        let query_lower = query.to_lowercase().trim().to_string();
326        log::debug!("Received query: {}", query); // Log the query for debugging
327
328        // Check permissions for the query (skip for SET, transaction, and SHOW statements)
329        if !query_lower.starts_with("set")
330            && !query_lower.starts_with("begin")
331            && !query_lower.starts_with("commit")
332            && !query_lower.starts_with("rollback")
333            && !query_lower.starts_with("start")
334            && !query_lower.starts_with("end")
335            && !query_lower.starts_with("abort")
336            && !query_lower.starts_with("show")
337        {
338            self.check_query_permission(client, query).await?;
339        }
340
341        if let Some(resp) = self.try_respond_set_statements(&query_lower).await? {
342            return Ok(vec![resp]);
343        }
344
345        if let Some(resp) = self
346            .try_respond_transaction_statements(&query_lower)
347            .await?
348        {
349            return Ok(vec![resp]);
350        }
351
352        if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
353            return Ok(vec![resp]);
354        }
355
356        // Check if we're in a failed transaction and block non-transaction commands
357        {
358            let state = self.transaction_state.lock().await;
359            if *state == TransactionState::Failed {
360                return Err(PgWireError::UserError(Box::new(
361                    pgwire::error::ErrorInfo::new(
362                        "ERROR".to_string(),
363                        "25P01".to_string(),
364                        "current transaction is aborted, commands ignored until end of transaction block".to_string(),
365                    ),
366                )));
367            }
368        }
369
370        let df_result = self.session_context.sql(query).await;
371
372        // Handle query execution errors and transaction state
373        let df = match df_result {
374            Ok(df) => df,
375            Err(e) => {
376                // If we're in a transaction and a query fails, mark transaction as failed
377                {
378                    let mut state = self.transaction_state.lock().await;
379                    if *state == TransactionState::Active {
380                        *state = TransactionState::Failed;
381                    }
382                }
383                return Err(PgWireError::ApiError(Box::new(e)));
384            }
385        };
386
387        if query_lower.starts_with("insert into") {
388            // For INSERT queries, we need to execute the query to get the row count
389            // and return an Execution response with the proper tag
390            let result = df
391                .clone()
392                .collect()
393                .await
394                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
395
396            // Extract count field from the first batch
397            let rows_affected = result
398                .first()
399                .and_then(|batch| batch.column_by_name("count"))
400                .and_then(|col| {
401                    col.as_any()
402                        .downcast_ref::<datafusion::arrow::array::UInt64Array>()
403                })
404                .map_or(0, |array| array.value(0) as usize);
405
406            // Create INSERT tag with the affected row count
407            let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
408            Ok(vec![Response::Execution(tag)])
409        } else {
410            // For non-INSERT queries, return a regular Query response
411            let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
412            Ok(vec![Response::Query(resp)])
413        }
414    }
415}
416
417#[async_trait]
418impl ExtendedQueryHandler for DfSessionService {
419    type Statement = (String, LogicalPlan);
420    type QueryParser = Parser;
421
422    fn query_parser(&self) -> Arc<Self::QueryParser> {
423        self.parser.clone()
424    }
425
426    async fn do_describe_statement<C>(
427        &self,
428        _client: &mut C,
429        target: &StoredStatement<Self::Statement>,
430    ) -> PgWireResult<DescribeStatementResponse>
431    where
432        C: ClientInfo + Unpin + Send + Sync,
433    {
434        let (_, plan) = &target.statement;
435        let schema = plan.schema();
436        let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
437        let params = plan
438            .get_parameter_types()
439            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
440
441        let mut param_types = Vec::with_capacity(params.len());
442        for param_type in ordered_param_types(&params).iter() {
443            // Fixed: Use &params
444            if let Some(datatype) = param_type {
445                let pgtype = into_pg_type(datatype)?;
446                param_types.push(pgtype);
447            } else {
448                param_types.push(Type::UNKNOWN);
449            }
450        }
451
452        Ok(DescribeStatementResponse::new(param_types, fields))
453    }
454
455    async fn do_describe_portal<C>(
456        &self,
457        _client: &mut C,
458        target: &Portal<Self::Statement>,
459    ) -> PgWireResult<DescribePortalResponse>
460    where
461        C: ClientInfo + Unpin + Send + Sync,
462    {
463        let (_, plan) = &target.statement.statement;
464        let format = &target.result_column_format;
465        let schema = plan.schema();
466        let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
467
468        Ok(DescribePortalResponse::new(fields))
469    }
470
471    async fn do_query<'a, C>(
472        &self,
473        client: &mut C,
474        portal: &Portal<Self::Statement>,
475        _max_rows: usize,
476    ) -> PgWireResult<Response<'a>>
477    where
478        C: ClientInfo + Unpin + Send + Sync,
479    {
480        let query = portal
481            .statement
482            .statement
483            .0
484            .to_lowercase()
485            .trim()
486            .to_string();
487        log::debug!("Received execute extended query: {}", query); // Log for debugging
488
489        // Check permissions for the query (skip for SET and SHOW statements)
490        if !query.starts_with("set") && !query.starts_with("show") {
491            self.check_query_permission(client, &portal.statement.statement.0)
492                .await?;
493        }
494
495        if let Some(resp) = self.try_respond_set_statements(&query).await? {
496            return Ok(resp);
497        }
498
499        if let Some(resp) = self.try_respond_show_statements(&query).await? {
500            return Ok(resp);
501        }
502
503        let (_, plan) = &portal.statement.statement;
504
505        let param_types = plan
506            .get_parameter_types()
507            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
508        let param_values = df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
509        let plan = plan
510            .clone()
511            .replace_params_with_values(&param_values)
512            .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use &param_values
513        let dataframe = self
514            .session_context
515            .execute_logical_plan(plan)
516            .await
517            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
518        let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
519        Ok(Response::Query(resp))
520    }
521}
522
523pub struct Parser {
524    session_context: Arc<SessionContext>,
525}
526
527#[async_trait]
528impl QueryParser for Parser {
529    type Statement = (String, LogicalPlan);
530
531    async fn parse_sql<C>(
532        &self,
533        _client: &C,
534        sql: &str,
535        _types: &[Type],
536    ) -> PgWireResult<Self::Statement> {
537        log::debug!("Received parse extended query: {}", sql); // Log for debugging
538        let context = &self.session_context;
539        let state = context.state();
540        let logical_plan = state
541            .create_logical_plan(sql)
542            .await
543            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
544        let optimised = state
545            .optimize(&logical_plan)
546            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
547        Ok((sql.to_string(), optimised))
548    }
549}
550
/// Orders a DataFusion parameter-type map by placeholder position.
///
/// Datafusion stores the parameters as a map whose keys are `$1`, `$2`, ... and whose
/// values are the (possibly unknown) parameter types. The keys are sorted by their
/// numeric position, so `$2` comes before `$10`; a plain string sort would put `$10`
/// first and bind parameters to the wrong types. Keys that are not `$<number>` sort
/// last, ordered by name.
///
/// Generic over the type payload; callers pass `HashMap<String, Option<DataType>>`.
fn ordered_param_types<T>(types: &HashMap<String, Option<T>>) -> Vec<Option<&T>> {
    // Numeric position of a `$N` placeholder; anything else sorts after all of them.
    fn position(name: &str) -> usize {
        name.strip_prefix('$')
            .and_then(|digits| digits.parse::<usize>().ok())
            .unwrap_or(usize::MAX)
    }

    let mut entries: Vec<(&String, &Option<T>)> = types.iter().collect();
    entries.sort_by(|(a, _), (b, _)| position(a).cmp(&position(b)).then_with(|| a.cmp(b)));
    entries.into_iter().map(|(_, ty)| ty.as_ref()).collect()
}