datafusion_postgres/
handlers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::DataType;
6use datafusion::common::ParamValues;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::*;
9use datafusion::sql::parser::Statement;
10use datafusion::sql::sqlparser;
11use log::info;
12use pgwire::api::auth::noop::NoopStartupHandler;
13use pgwire::api::auth::StartupHandler;
14use pgwire::api::portal::{Format, Portal};
15use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
16use pgwire::api::results::{FieldInfo, Response, Tag};
17use pgwire::api::stmt::QueryParser;
18use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
19use pgwire::error::{PgWireError, PgWireResult};
20use pgwire::types::format::FormatOptions;
21
22use crate::client;
23use crate::hooks::set_show::SetShowHook;
24use crate::hooks::transactions::TransactionStatementHook;
25use crate::hooks::QueryHook;
26use arrow_pg::datatypes::df;
27use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
28use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
29
30/// Simple startup handler that does no authentication
31pub struct SimpleStartupHandler;
32
33#[async_trait::async_trait]
34impl NoopStartupHandler for SimpleStartupHandler {}
35
36pub struct HandlerFactory {
37    pub session_service: Arc<DfSessionService>,
38}
39
40impl HandlerFactory {
41    pub fn new(session_context: Arc<SessionContext>) -> Self {
42        let session_service = Arc::new(DfSessionService::new(session_context));
43        HandlerFactory { session_service }
44    }
45
46    pub fn new_with_hooks(
47        session_context: Arc<SessionContext>,
48        query_hooks: Vec<Arc<dyn QueryHook>>,
49    ) -> Self {
50        let session_service = Arc::new(DfSessionService::new_with_hooks(
51            session_context,
52            query_hooks,
53        ));
54        HandlerFactory { session_service }
55    }
56}
57
58impl PgWireServerHandlers for HandlerFactory {
59    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
60        self.session_service.clone()
61    }
62
63    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
64        self.session_service.clone()
65    }
66
67    fn startup_handler(&self) -> Arc<impl StartupHandler> {
68        Arc::new(SimpleStartupHandler)
69    }
70
71    fn error_handler(&self) -> Arc<impl ErrorHandler> {
72        Arc::new(LoggingErrorHandler)
73    }
74}
75
76struct LoggingErrorHandler;
77
78impl ErrorHandler for LoggingErrorHandler {
79    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
80    where
81        C: ClientInfo,
82    {
83        info!("Sending error: {error}")
84    }
85}
86
87/// The pgwire handler backed by a datafusion `SessionContext`
88pub struct DfSessionService {
89    session_context: Arc<SessionContext>,
90    parser: Arc<Parser>,
91    query_hooks: Vec<Arc<dyn QueryHook>>,
92}
93
94impl DfSessionService {
95    pub fn new(session_context: Arc<SessionContext>) -> DfSessionService {
96        let hooks: Vec<Arc<dyn QueryHook>> =
97            vec![Arc::new(SetShowHook), Arc::new(TransactionStatementHook)];
98        Self::new_with_hooks(session_context, hooks)
99    }
100
101    pub fn new_with_hooks(
102        session_context: Arc<SessionContext>,
103        query_hooks: Vec<Arc<dyn QueryHook>>,
104    ) -> DfSessionService {
105        let parser = Arc::new(Parser {
106            session_context: session_context.clone(),
107            sql_parser: PostgresCompatibilityParser::new(),
108            query_hooks: query_hooks.clone(),
109        });
110        DfSessionService {
111            session_context,
112            parser,
113            query_hooks,
114        }
115    }
116}
117
118#[async_trait]
119impl SimpleQueryHandler for DfSessionService {
120    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
121    where
122        C: ClientInfo + Unpin + Send + Sync,
123    {
124        log::debug!("Received query: {query}"); // Log the query for debugging
125
126        let statements = self
127            .parser
128            .sql_parser
129            .parse(query)
130            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
131
132        // empty query
133        if statements.is_empty() {
134            return Ok(vec![Response::EmptyQuery]);
135        }
136
137        let mut results = vec![];
138        'stmt: for statement in statements {
139            let query = statement.to_string();
140
141            // Call query hooks with the parsed statement
142            for hook in &self.query_hooks {
143                if let Some(result) = hook
144                    .handle_simple_query(&statement, &self.session_context, client)
145                    .await
146                {
147                    results.push(result?);
148                    continue 'stmt;
149                }
150            }
151
152            let df_result = {
153                let timeout = client::get_statement_timeout(client);
154                if let Some(timeout_duration) = timeout {
155                    tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
156                        .await
157                        .map_err(|_| {
158                            PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
159                                "ERROR".to_string(),
160                                "57014".to_string(), // query_canceled error code
161                                "canceling statement due to statement timeout".to_string(),
162                            )))
163                        })?
164                } else {
165                    self.session_context.sql(&query).await
166                }
167            };
168
169            // Handle query execution errors and transaction state
170            let df = match df_result {
171                Ok(df) => df,
172                Err(e) => {
173                    return Err(PgWireError::ApiError(Box::new(e)));
174                }
175            };
176
177            if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
178                let resp = map_rows_affected_for_insert(&df).await?;
179                results.push(resp);
180            } else {
181                // For non-INSERT queries, return a regular Query response
182                let format_options =
183                    Arc::new(FormatOptions::from_client_metadata(client.metadata()));
184                let resp =
185                    df::encode_dataframe(df, &Format::UnifiedText, Some(format_options)).await?;
186                results.push(Response::Query(resp));
187            }
188        }
189        Ok(results)
190    }
191}
192
193#[async_trait]
194impl ExtendedQueryHandler for DfSessionService {
195    type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
196    type QueryParser = Parser;
197
198    fn query_parser(&self) -> Arc<Self::QueryParser> {
199        self.parser.clone()
200    }
201
202    async fn do_query<C>(
203        &self,
204        client: &mut C,
205        portal: &Portal<Self::Statement>,
206        _max_rows: usize,
207    ) -> PgWireResult<Response>
208    where
209        C: ClientInfo + Unpin + Send + Sync,
210    {
211        let query = &portal.statement.statement.0;
212        log::debug!("Received execute extended query: {query}"); // Log for debugging
213
214        // Check query hooks first
215        if !self.query_hooks.is_empty() {
216            if let (_, Some((statement, plan))) = &portal.statement.statement {
217                // TODO: in the case where query hooks all return None, we do the param handling again later.
218                let param_types = plan
219                    .get_parameter_types()
220                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
221
222                let param_values: ParamValues =
223                    df::deserialize_parameters(portal, &ordered_param_types(&param_types))?;
224
225                for hook in &self.query_hooks {
226                    if let Some(result) = hook
227                        .handle_extended_query(
228                            statement,
229                            plan,
230                            &param_values,
231                            &self.session_context,
232                            client,
233                        )
234                        .await
235                    {
236                        return result;
237                    }
238                }
239            }
240        }
241
242        if let (_, Some((statement, plan))) = &portal.statement.statement {
243            let param_types = plan
244                .get_parameter_types()
245                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
246
247            let param_values =
248                df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
249
250            let plan = plan
251                .clone()
252                .replace_params_with_values(&param_values)
253                .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use
254                                                                   // &param_values
255            let optimised = self
256                .session_context
257                .state()
258                .optimize(&plan)
259                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
260
261            let dataframe = {
262                let timeout = client::get_statement_timeout(client);
263                if let Some(timeout_duration) = timeout {
264                    tokio::time::timeout(
265                        timeout_duration,
266                        self.session_context.execute_logical_plan(optimised),
267                    )
268                    .await
269                    .map_err(|_| {
270                        PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
271                            "ERROR".to_string(),
272                            "57014".to_string(), // query_canceled error code
273                            "canceling statement due to statement timeout".to_string(),
274                        )))
275                    })?
276                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?
277                } else {
278                    self.session_context
279                        .execute_logical_plan(optimised)
280                        .await
281                        .map_err(|e| PgWireError::ApiError(Box::new(e)))?
282                }
283            };
284
285            if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
286                let resp = map_rows_affected_for_insert(&dataframe).await?;
287
288                Ok(resp)
289            } else {
290                // For non-INSERT queries, return a regular Query response
291                let format_options =
292                    Arc::new(FormatOptions::from_client_metadata(client.metadata()));
293                let resp = df::encode_dataframe(
294                    dataframe,
295                    &portal.result_column_format,
296                    Some(format_options),
297                )
298                .await?;
299                Ok(Response::Query(resp))
300            }
301        } else {
302            Ok(Response::EmptyQuery)
303        }
304    }
305}
306
307async fn map_rows_affected_for_insert(df: &DataFrame) -> PgWireResult<Response> {
308    // For INSERT queries, we need to execute the query to get the row count
309    // and return an Execution response with the proper tag
310    let result = df
311        .clone()
312        .collect()
313        .await
314        .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
315
316    // Extract count field from the first batch
317    let rows_affected = result
318        .first()
319        .and_then(|batch| batch.column_by_name("count"))
320        .and_then(|col| {
321            col.as_any()
322                .downcast_ref::<datafusion::arrow::array::UInt64Array>()
323        })
324        .map_or(0, |array| array.value(0) as usize);
325
326    // Create INSERT tag with the affected row count
327    let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
328    Ok(Response::Execution(tag))
329}
330
331pub struct Parser {
332    session_context: Arc<SessionContext>,
333    sql_parser: PostgresCompatibilityParser,
334    query_hooks: Vec<Arc<dyn QueryHook>>,
335}
336
337#[async_trait]
338impl QueryParser for Parser {
339    type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
340
341    async fn parse_sql<C>(
342        &self,
343        client: &C,
344        sql: &str,
345        _types: &[Option<Type>],
346    ) -> PgWireResult<Self::Statement>
347    where
348        C: ClientInfo + Unpin + Send + Sync,
349    {
350        log::debug!("Received parse extended query: {sql}"); // Log for debugging
351
352        let mut statements = self
353            .sql_parser
354            .parse(sql)
355            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
356        if statements.is_empty() {
357            return Ok((sql.to_string(), None));
358        }
359
360        let statement = statements.remove(0);
361        let query = statement.to_string();
362
363        let context = &self.session_context;
364        let state = context.state();
365
366        for hook in &self.query_hooks {
367            if let Some(logical_plan) = hook
368                .handle_extended_parse_query(&statement, context, client)
369                .await
370            {
371                return Ok((query, Some((statement, logical_plan?))));
372            }
373        }
374
375        let logical_plan = state
376            .statement_to_plan(Statement::Statement(Box::new(statement.clone())))
377            .await
378            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
379        Ok((query, Some((statement, logical_plan))))
380    }
381
382    fn get_parameter_types(&self, stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
383        if let (_, Some((_, plan))) = stmt {
384            let params = plan
385                .get_parameter_types()
386                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
387
388            let mut param_types = Vec::with_capacity(params.len());
389            for param_type in ordered_param_types(&params).iter() {
390                // Fixed: Use &params
391                if let Some(datatype) = param_type {
392                    let pgtype = into_pg_type(datatype)?;
393                    param_types.push(pgtype);
394                } else {
395                    param_types.push(Type::UNKNOWN);
396                }
397            }
398
399            Ok(param_types)
400        } else {
401            Ok(vec![])
402        }
403    }
404
405    fn get_result_schema(
406        &self,
407        stmt: &Self::Statement,
408        column_format: Option<&Format>,
409    ) -> PgWireResult<Vec<FieldInfo>> {
410        if let (_, Some((_, plan))) = stmt {
411            let schema = plan.schema();
412            let fields = arrow_schema_to_pg_fields(
413                schema.as_arrow(),
414                column_format.unwrap_or(&Format::UnifiedBinary),
415                None,
416            )?;
417
418            Ok(fields)
419        } else {
420            Ok(vec![])
421        }
422    }
423}
424
425fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
426    // Datafusion stores the parameters as a map.  In our case, the keys will be
427    // `$1`, `$2` etc.  The values will be the parameter types.
428    let mut types = types.iter().collect::<Vec<_>>();
429    types.sort_by(|a, b| a.0.cmp(b.0));
430    types.into_iter().map(|pt| pt.1.as_ref()).collect()
431}
432
#[cfg(test)]
mod tests {
    use datafusion::prelude::SessionContext;

    use super::*;
    use crate::testing::MockClient;

    /// Hook that claims any simple query whose SQL text contains "magic" and
    /// answers it with `Response::EmptyQuery`; it never claims extended queries.
    struct TestHook;

    #[async_trait]
    impl QueryHook for TestHook {
        async fn handle_simple_query(
            &self,
            statement: &sqlparser::ast::Statement,
            _ctx: &SessionContext,
            _client: &mut (dyn ClientInfo + Sync + Send),
        ) -> Option<PgWireResult<Response>> {
            let is_magic = statement.to_string().contains("magic");
            is_magic.then(|| Ok(Response::EmptyQuery))
        }

        async fn handle_extended_parse_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _session_context: &SessionContext,
            _client: &(dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<LogicalPlan>> {
            None
        }

        async fn handle_extended_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _logical_plan: &LogicalPlan,
            _params: &ParamValues,
            _session_context: &SessionContext,
            _client: &mut (dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<Response>> {
            None
        }
    }

    /// The hook claims statements mentioning "magic" and leaves others alone.
    #[tokio::test]
    async fn test_query_hooks() {
        let hook = TestHook;
        let ctx = SessionContext::new();
        let mut client = MockClient::new();
        let parser = PostgresCompatibilityParser::new();

        // A statement containing "magic" must be intercepted.
        let magic = parser.parse("SELECT magic").unwrap();
        let intercepted = hook
            .handle_simple_query(&magic[0], &ctx, &mut client)
            .await;
        assert!(intercepted.is_some());

        // An ordinary statement must be passed through.
        let plain = parser.parse("SELECT 1").unwrap();
        let passed_through = hook
            .handle_simple_query(&plain[0], &ctx, &mut client)
            .await;
        assert!(passed_through.is_none());
    }

    /// Regression test for bug #227: a hook result used to `break 'stmt`,
    /// leaving the statement loop and skipping every remaining statement.
    #[tokio::test]
    async fn test_multiple_statements_with_hook_continue() {
        let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(TestHook)];
        let service = DfSessionService::new_with_hooks(Arc::new(SessionContext::new()), hooks);
        let mut client = MockClient::new();

        // Hooked and regular statements alternate.
        let query = "SELECT magic; SELECT 1; SELECT magic; SELECT 1";

        let results =
            <DfSessionService as SimpleQueryHandler>::do_query(&service, &mut client, query)
                .await
                .unwrap();

        assert_eq!(results.len(), 4, "Expected 4 responses");

        // Even positions were answered by the hook, odd ones by DataFusion.
        for (position, response) in results.iter().enumerate() {
            if position % 2 == 0 {
                assert!(matches!(response, Response::EmptyQuery));
            } else {
                assert!(matches!(response, Response::Query(_)));
            }
        }
    }
}