Skip to main content

datafusion_postgres/
handlers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::DataType;
6use datafusion::common::ParamValues;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::*;
9use datafusion::sql::parser::Statement;
10use datafusion::sql::sqlparser;
11use log::info;
12use pgwire::api::auth::noop::NoopStartupHandler;
13use pgwire::api::auth::StartupHandler;
14use pgwire::api::portal::{Format, Portal};
15use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
16use pgwire::api::results::{FieldInfo, Response, Tag};
17use pgwire::api::stmt::QueryParser;
18use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
19use pgwire::error::{PgWireError, PgWireResult};
20use pgwire::messages::PgWireBackendMessage;
21use pgwire::types::format::FormatOptions;
22
23use crate::hooks::set_show::SetShowHook;
24use crate::hooks::transactions::TransactionStatementHook;
25use crate::hooks::QueryHook;
26use crate::{client, planner};
27use arrow_pg::datatypes::df;
28use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
29use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
30
31/// Simple startup handler that does no authentication
32pub struct SimpleStartupHandler;
33
34#[async_trait::async_trait]
35impl NoopStartupHandler for SimpleStartupHandler {}
36
37pub struct HandlerFactory {
38    pub session_service: Arc<DfSessionService>,
39}
40
41impl HandlerFactory {
42    pub fn new(session_context: Arc<SessionContext>) -> Self {
43        let session_service = Arc::new(DfSessionService::new(session_context));
44        HandlerFactory { session_service }
45    }
46
47    pub fn new_with_hooks(
48        session_context: Arc<SessionContext>,
49        query_hooks: Vec<Arc<dyn QueryHook>>,
50    ) -> Self {
51        let session_service = Arc::new(DfSessionService::new_with_hooks(
52            session_context,
53            query_hooks,
54        ));
55        HandlerFactory { session_service }
56    }
57}
58
59impl PgWireServerHandlers for HandlerFactory {
60    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
61        self.session_service.clone()
62    }
63
64    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
65        self.session_service.clone()
66    }
67
68    fn startup_handler(&self) -> Arc<impl StartupHandler> {
69        Arc::new(SimpleStartupHandler)
70    }
71
72    fn error_handler(&self) -> Arc<impl ErrorHandler> {
73        Arc::new(LoggingErrorHandler)
74    }
75}
76
77struct LoggingErrorHandler;
78
79impl ErrorHandler for LoggingErrorHandler {
80    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
81    where
82        C: ClientInfo,
83    {
84        info!("Sending error: {error}")
85    }
86}
87
88/// The pgwire handler backed by a datafusion `SessionContext`
89pub struct DfSessionService {
90    session_context: Arc<SessionContext>,
91    parser: Arc<Parser>,
92    query_hooks: Vec<Arc<dyn QueryHook>>,
93}
94
95impl DfSessionService {
96    pub fn new(session_context: Arc<SessionContext>) -> DfSessionService {
97        let hooks: Vec<Arc<dyn QueryHook>> =
98            vec![Arc::new(SetShowHook), Arc::new(TransactionStatementHook)];
99        Self::new_with_hooks(session_context, hooks)
100    }
101
102    pub fn new_with_hooks(
103        session_context: Arc<SessionContext>,
104        query_hooks: Vec<Arc<dyn QueryHook>>,
105    ) -> DfSessionService {
106        let parser = Arc::new(Parser {
107            session_context: session_context.clone(),
108            sql_parser: PostgresCompatibilityParser::new(),
109            query_hooks: query_hooks.clone(),
110        });
111        DfSessionService {
112            session_context,
113            parser,
114            query_hooks,
115        }
116    }
117}
118
119#[async_trait]
120impl SimpleQueryHandler for DfSessionService {
121    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
122    where
123        C: ClientInfo + futures::Sink<PgWireBackendMessage> + Unpin + Send + Sync,
124        C::Error: std::fmt::Debug,
125        PgWireError: From<<C as futures::Sink<PgWireBackendMessage>>::Error>,
126    {
127        log::debug!("Received query: {query}");
128        let statements = self
129            .parser
130            .sql_parser
131            .parse(query)
132            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
133
134        // empty query
135        if statements.is_empty() {
136            return Ok(vec![Response::EmptyQuery]);
137        }
138
139        let mut results = vec![];
140        'stmt: for statement in statements {
141            // Call query hooks with the parsed statement
142            for hook in &self.query_hooks {
143                if let Some(result) = hook
144                    .handle_simple_query(&statement, &self.session_context, client)
145                    .await
146                {
147                    results.push(result?);
148                    continue 'stmt;
149                }
150            }
151
152            let df_result = {
153                let query = statement.to_string();
154
155                let timeout = client::get_statement_timeout(client);
156                if let Some(timeout_duration) = timeout {
157                    tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
158                        .await
159                        .map_err(|_| {
160                            PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
161                                "ERROR".to_string(),
162                                "57014".to_string(), // query_canceled error code
163                                "canceling statement due to statement timeout".to_string(),
164                            )))
165                        })?
166                } else {
167                    self.session_context.sql(&query).await
168                }
169            };
170
171            // Handle query execution errors and transaction state
172            let df = match df_result {
173                Ok(df) => df,
174                Err(e) => {
175                    return Err(PgWireError::ApiError(Box::new(e)));
176                }
177            };
178
179            if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
180                let resp = map_rows_affected_for_insert(&df).await?;
181                results.push(resp);
182            } else {
183                // For non-INSERT queries, return a regular Query response
184                let format_options =
185                    Arc::new(FormatOptions::from_client_metadata(client.metadata()));
186                let resp =
187                    df::encode_dataframe(df, &Format::UnifiedText, Some(format_options)).await?;
188                results.push(Response::Query(resp));
189            }
190        }
191        Ok(results)
192    }
193}
194
195#[async_trait]
196impl ExtendedQueryHandler for DfSessionService {
197    type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
198    type QueryParser = Parser;
199
200    fn query_parser(&self) -> Arc<Self::QueryParser> {
201        self.parser.clone()
202    }
203
204    async fn do_query<C>(
205        &self,
206        client: &mut C,
207        portal: &Portal<Self::Statement>,
208        _max_rows: usize,
209    ) -> PgWireResult<Response>
210    where
211        C: ClientInfo + futures::Sink<PgWireBackendMessage> + Unpin + Send + Sync,
212        C::Error: std::fmt::Debug,
213        PgWireError: From<<C as futures::Sink<PgWireBackendMessage>>::Error>,
214    {
215        let query = &portal.statement.statement.0;
216        log::debug!("Received execute extended query: {query}");
217        // Check query hooks first
218        if !self.query_hooks.is_empty() {
219            if let (_, Some((statement, plan))) = &portal.statement.statement {
220                // TODO: in the case where query hooks all return None, we do the param handling again later.
221                let param_types = planner::get_inferred_parameter_types(plan)
222                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
223
224                let param_values: ParamValues =
225                    df::deserialize_parameters(portal, &ordered_param_types(&param_types))?;
226
227                for hook in &self.query_hooks {
228                    if let Some(result) = hook
229                        .handle_extended_query(
230                            statement,
231                            plan,
232                            &param_values,
233                            &self.session_context,
234                            client,
235                        )
236                        .await
237                    {
238                        return result;
239                    }
240                }
241            }
242        }
243
244        if let (_, Some((statement, plan))) = &portal.statement.statement {
245            let param_types = planner::get_inferred_parameter_types(plan)
246                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
247
248            let param_values =
249                df::deserialize_parameters(portal, &ordered_param_types(&param_types))?;
250
251            let plan = plan
252                .clone()
253                .replace_params_with_values(&param_values)
254                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
255            let optimised = self
256                .session_context
257                .state()
258                .optimize(&plan)
259                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
260
261            let dataframe = {
262                let timeout = client::get_statement_timeout(client);
263                if let Some(timeout_duration) = timeout {
264                    tokio::time::timeout(
265                        timeout_duration,
266                        self.session_context.execute_logical_plan(optimised),
267                    )
268                    .await
269                    .map_err(|_| {
270                        PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
271                            "ERROR".to_string(),
272                            "57014".to_string(), // query_canceled error code
273                            "canceling statement due to statement timeout".to_string(),
274                        )))
275                    })?
276                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?
277                } else {
278                    self.session_context
279                        .execute_logical_plan(optimised)
280                        .await
281                        .map_err(|e| PgWireError::ApiError(Box::new(e)))?
282                }
283            };
284
285            if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
286                let resp = map_rows_affected_for_insert(&dataframe).await?;
287
288                Ok(resp)
289            } else {
290                // For non-INSERT queries, return a regular Query response
291                let format_options =
292                    Arc::new(FormatOptions::from_client_metadata(client.metadata()));
293                let resp = df::encode_dataframe(
294                    dataframe,
295                    &portal.result_column_format,
296                    Some(format_options),
297                )
298                .await?;
299                Ok(Response::Query(resp))
300            }
301        } else {
302            Ok(Response::EmptyQuery)
303        }
304    }
305}
306
307async fn map_rows_affected_for_insert(df: &DataFrame) -> PgWireResult<Response> {
308    // For INSERT queries, we need to execute the query to get the row count
309    // and return an Execution response with the proper tag
310    let result = df
311        .clone()
312        .collect()
313        .await
314        .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
315
316    // Extract count field from the first batch
317    let rows_affected = result
318        .first()
319        .and_then(|batch| batch.column_by_name("count"))
320        .and_then(|col| {
321            col.as_any()
322                .downcast_ref::<datafusion::arrow::array::UInt64Array>()
323        })
324        .map_or(0, |array| array.value(0) as usize);
325
326    // Create INSERT tag with the affected row count
327    let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
328    Ok(Response::Execution(tag))
329}
330
331pub struct Parser {
332    session_context: Arc<SessionContext>,
333    sql_parser: PostgresCompatibilityParser,
334    query_hooks: Vec<Arc<dyn QueryHook>>,
335}
336
337#[async_trait]
338impl QueryParser for Parser {
339    type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
340
341    async fn parse_sql<C>(
342        &self,
343        client: &C,
344        sql: &str,
345        _types: &[Option<Type>],
346    ) -> PgWireResult<Self::Statement>
347    where
348        C: ClientInfo + Unpin + Send + Sync,
349    {
350        log::debug!("Received parse extended query: {sql}");
351        let mut statements = self
352            .sql_parser
353            .parse(sql)
354            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
355        if statements.is_empty() {
356            return Ok((sql.to_string(), None));
357        }
358
359        let statement = statements.remove(0);
360        let query = statement.to_string();
361
362        let context = &self.session_context;
363        let state = context.state();
364
365        for hook in &self.query_hooks {
366            if let Some(logical_plan) = hook
367                .handle_extended_parse_query(&statement, context, client)
368                .await
369            {
370                return Ok((query, Some((statement, logical_plan?))));
371            }
372        }
373
374        let logical_plan = state
375            .statement_to_plan(Statement::Statement(Box::new(statement.clone())))
376            .await
377            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
378        Ok((query, Some((statement, logical_plan))))
379    }
380
381    fn get_parameter_types(&self, stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
382        if let (_, Some((_, plan))) = stmt {
383            let params = planner::get_inferred_parameter_types(plan)
384                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
385
386            let mut param_types = Vec::with_capacity(params.len());
387            for param_type in ordered_param_types(&params).iter() {
388                if let Some(datatype) = param_type {
389                    let pgtype = into_pg_type(datatype)?;
390                    param_types.push(pgtype);
391                } else {
392                    param_types.push(Type::UNKNOWN);
393                }
394            }
395
396            Ok(param_types)
397        } else {
398            Ok(vec![])
399        }
400    }
401
402    fn get_result_schema(
403        &self,
404        stmt: &Self::Statement,
405        column_format: Option<&Format>,
406    ) -> PgWireResult<Vec<FieldInfo>> {
407        if let (_, Some((_, plan))) = stmt {
408            let schema = plan.schema();
409            let fields = arrow_schema_to_pg_fields(
410                schema.as_arrow(),
411                column_format.unwrap_or(&Format::UnifiedBinary),
412                None,
413            )?;
414
415            Ok(fields)
416        } else {
417            Ok(vec![])
418        }
419    }
420}
421
422fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
423    // Datafusion stores the parameters as a map.  In our case, the keys will be
424    // `$1`, `$2` etc.  The values will be the parameter types.
425    let mut types = types.iter().collect::<Vec<_>>();
426    types.sort_by(|a, b| a.0.cmp(b.0));
427    types.into_iter().map(|pt| pt.1.as_ref()).collect()
428}
429
#[cfg(test)]
mod tests {
    use datafusion::prelude::SessionContext;

    use super::*;
    use crate::hooks::HookClient;
    use crate::testing::MockClient;

    /// Hook that answers any simple query whose SQL text contains "magic"
    /// with an empty-query response and ignores everything else.
    struct TestHook;

    #[async_trait]
    impl QueryHook for TestHook {
        async fn handle_simple_query(
            &self,
            statement: &sqlparser::ast::Statement,
            _ctx: &SessionContext,
            _client: &mut dyn HookClient,
        ) -> Option<PgWireResult<Response>> {
            let sql = statement.to_string();
            if !sql.contains("magic") {
                return None;
            }
            Some(Ok(Response::EmptyQuery))
        }

        async fn handle_extended_parse_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _session_context: &SessionContext,
            _client: &(dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<LogicalPlan>> {
            None
        }

        async fn handle_extended_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _logical_plan: &LogicalPlan,
            _params: &ParamValues,
            _session_context: &SessionContext,
            _client: &mut dyn HookClient,
        ) -> Option<PgWireResult<Response>> {
            None
        }
    }

    /// The hook intercepts statements mentioning "magic" and nothing else.
    #[tokio::test]
    async fn test_query_hooks() {
        let hook = TestHook;
        let ctx = SessionContext::new();
        let mut client = MockClient::new();
        let parser = PostgresCompatibilityParser::new();

        // A statement that contains "magic" must be intercepted.
        let magic = parser.parse("SELECT magic").unwrap();
        let intercepted = hook.handle_simple_query(&magic[0], &ctx, &mut client).await;
        assert!(intercepted.is_some());

        // A normal statement must be left alone.
        let plain = parser.parse("SELECT 1").unwrap();
        let passed_through = hook.handle_simple_query(&plain[0], &ctx, &mut client).await;
        assert!(passed_through.is_none());
    }

    /// Regression test for bug #227: a hook result used to `break` out of the
    /// whole statement loop, so statements following a hooked one were never
    /// processed.
    #[tokio::test]
    async fn test_multiple_statements_with_hook_continue() {
        let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(TestHook)];
        let service = DfSessionService::new_with_hooks(Arc::new(SessionContext::new()), hooks);
        let mut client = MockClient::new();

        // Hooked and regular statements alternate.
        let query = "SELECT magic; SELECT 1; SELECT magic; SELECT 1";
        let results =
            <DfSessionService as SimpleQueryHandler>::do_query(&service, &mut client, query)
                .await
                .unwrap();

        assert_eq!(results.len(), 4, "Expected 4 responses");

        for (position, response) in results.iter().enumerate() {
            if position % 2 == 0 {
                assert!(matches!(response, Response::EmptyQuery));
            } else {
                assert!(matches!(response, Response::Query(_)));
            }
        }
    }

    /// Each listed `SET` statement must answer with an execution tag and push
    /// exactly one matching `ParameterStatus` message to the client.
    #[tokio::test]
    async fn test_set_sends_parameter_status_via_sink() {
        use pgwire::messages::PgWireBackendMessage;

        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        // (statement, expected parameter name, expected parameter value)
        let test_cases = [
            ("SET datestyle = 'ISO, MDY'", "DateStyle", "ISO, MDY"),
            ("SET intervalstyle = 'postgres'", "IntervalStyle", "postgres"),
            ("SET bytea_output = 'hex'", "bytea_output", "hex"),
            ("SET application_name = 'myapp'", "application_name", "myapp"),
            ("SET search_path = 'public'", "search_path", "public"),
            ("SET extra_float_digits = '2'", "extra_float_digits", "2"),
            ("SET TIME ZONE 'America/New_York'", "TimeZone", "America/New_York"),
        ];

        for (sql, expected_key, expected_value) in test_cases {
            // Start every case with an empty message log.
            client.sent_messages.clear();

            let responses =
                <DfSessionService as SimpleQueryHandler>::do_query(&service, &mut client, sql)
                    .await
                    .unwrap();

            assert!(
                matches!(responses[0], Response::Execution(_)),
                "Expected SET tag for {sql}"
            );

            let ps_msgs: Vec<_> = client
                .sent_messages()
                .iter()
                .filter_map(|message| {
                    if let PgWireBackendMessage::ParameterStatus(status) = message {
                        Some(status)
                    } else {
                        None
                    }
                })
                .collect();

            assert_eq!(ps_msgs.len(), 1, "Expected 1 ParameterStatus for {sql}");
            let status = ps_msgs[0];
            assert_eq!(status.name, expected_key, "Wrong key for {sql}");
            assert_eq!(status.value, expected_value, "Wrong value for {sql}");
        }
    }

    /// `SET statement_timeout` must not emit any `ParameterStatus` message.
    #[tokio::test]
    async fn test_set_statement_timeout_no_parameter_status() {
        use pgwire::messages::PgWireBackendMessage;

        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        let sql = "SET statement_timeout TO '5000ms'";
        <DfSessionService as SimpleQueryHandler>::do_query(&service, &mut client, sql)
            .await
            .unwrap();

        let has_ps = client
            .sent_messages()
            .iter()
            .any(|message| matches!(message, PgWireBackendMessage::ParameterStatus(_)));

        assert!(!has_ps, "statement_timeout should not send ParameterStatus");
    }
}