datafusion_postgres/
handlers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::{DataType, Field, Schema};
6use datafusion::common::{ParamValues, ToDFSchema};
7use datafusion::error::DataFusionError;
8use datafusion::logical_expr::LogicalPlan;
9use datafusion::prelude::*;
10use datafusion::sql::parser::Statement;
11use datafusion::sql::sqlparser;
12use log::info;
13use pgwire::api::auth::noop::NoopStartupHandler;
14use pgwire::api::auth::StartupHandler;
15use pgwire::api::portal::{Format, Portal};
16use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
17use pgwire::api::results::{
18    DescribePortalResponse, DescribeResponse, DescribeStatementResponse, Response, Tag,
19};
20use pgwire::api::stmt::QueryParser;
21use pgwire::api::stmt::StoredStatement;
22use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
23use pgwire::error::{PgWireError, PgWireResult};
24use pgwire::messages::response::TransactionStatus;
25use pgwire::types::format::FormatOptions;
26
27use crate::auth::AuthManager;
28use crate::client;
29use crate::hooks::set_show::SetShowHook;
30use crate::hooks::QueryHook;
31use arrow_pg::datatypes::df;
32use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
33use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
34use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
35
36/// Simple startup handler that does no authentication
37/// For production, use DfAuthSource with proper pgwire authentication handlers
38pub struct SimpleStartupHandler;
39
40#[async_trait::async_trait]
41impl NoopStartupHandler for SimpleStartupHandler {}
42
43pub struct HandlerFactory {
44    pub session_service: Arc<DfSessionService>,
45}
46
47impl HandlerFactory {
48    pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
49        let session_service =
50            Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
51        HandlerFactory { session_service }
52    }
53
54    pub fn new_with_hooks(
55        session_context: Arc<SessionContext>,
56        auth_manager: Arc<AuthManager>,
57        query_hooks: Vec<Arc<dyn QueryHook>>,
58    ) -> Self {
59        let session_service = Arc::new(DfSessionService::new_with_hooks(
60            session_context,
61            auth_manager.clone(),
62            query_hooks,
63        ));
64        HandlerFactory { session_service }
65    }
66}
67
68impl PgWireServerHandlers for HandlerFactory {
69    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
70        self.session_service.clone()
71    }
72
73    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
74        self.session_service.clone()
75    }
76
77    fn startup_handler(&self) -> Arc<impl StartupHandler> {
78        Arc::new(SimpleStartupHandler)
79    }
80
81    fn error_handler(&self) -> Arc<impl ErrorHandler> {
82        Arc::new(LoggingErrorHandler)
83    }
84}
85
86struct LoggingErrorHandler;
87
88impl ErrorHandler for LoggingErrorHandler {
89    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
90    where
91        C: ClientInfo,
92    {
93        info!("Sending error: {error}")
94    }
95}
96
97/// The pgwire handler backed by a datafusion `SessionContext`
98pub struct DfSessionService {
99    session_context: Arc<SessionContext>,
100    parser: Arc<Parser>,
101    auth_manager: Arc<AuthManager>,
102    query_hooks: Vec<Arc<dyn QueryHook>>,
103}
104
105impl DfSessionService {
106    pub fn new(
107        session_context: Arc<SessionContext>,
108        auth_manager: Arc<AuthManager>,
109    ) -> DfSessionService {
110        let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(SetShowHook)];
111        Self::new_with_hooks(session_context, auth_manager, hooks)
112    }
113
114    pub fn new_with_hooks(
115        session_context: Arc<SessionContext>,
116        auth_manager: Arc<AuthManager>,
117        query_hooks: Vec<Arc<dyn QueryHook>>,
118    ) -> DfSessionService {
119        let parser = Arc::new(Parser {
120            session_context: session_context.clone(),
121            sql_parser: PostgresCompatibilityParser::new(),
122            query_hooks: query_hooks.clone(),
123        });
124        DfSessionService {
125            session_context,
126            parser,
127            auth_manager,
128            query_hooks,
129        }
130    }
131
132    /// Check if the current user has permission to execute a query
133    async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
134    where
135        C: ClientInfo,
136    {
137        // Get the username from client metadata
138        let username = client
139            .metadata()
140            .get("user")
141            .map(|s| s.as_str())
142            .unwrap_or("anonymous");
143
144        // Parse query to determine required permissions
145        let query_lower = query.to_lowercase();
146        let query_trimmed = query_lower.trim();
147
148        let (required_permission, resource) = if query_trimmed.starts_with("select") {
149            (Permission::Select, self.extract_table_from_query(query))
150        } else if query_trimmed.starts_with("insert") {
151            (Permission::Insert, self.extract_table_from_query(query))
152        } else if query_trimmed.starts_with("update") {
153            (Permission::Update, self.extract_table_from_query(query))
154        } else if query_trimmed.starts_with("delete") {
155            (Permission::Delete, self.extract_table_from_query(query))
156        } else if query_trimmed.starts_with("create table")
157            || query_trimmed.starts_with("create view")
158        {
159            (Permission::Create, ResourceType::All)
160        } else if query_trimmed.starts_with("drop") {
161            (Permission::Drop, self.extract_table_from_query(query))
162        } else if query_trimmed.starts_with("alter") {
163            (Permission::Alter, self.extract_table_from_query(query))
164        } else {
165            // For other queries (SHOW, EXPLAIN, etc.), allow all users
166            return Ok(());
167        };
168
169        // Check permission
170        let has_permission = self
171            .auth_manager
172            .check_permission(username, required_permission, resource)
173            .await;
174
175        if !has_permission {
176            return Err(PgWireError::UserError(Box::new(
177                pgwire::error::ErrorInfo::new(
178                    "ERROR".to_string(),
179                    "42501".to_string(), // insufficient_privilege
180                    format!("permission denied for user \"{username}\""),
181                ),
182            )));
183        }
184
185        Ok(())
186    }
187
188    /// Extract table name from query (simplified parsing)
189    fn extract_table_from_query(&self, query: &str) -> ResourceType {
190        let words: Vec<&str> = query.split_whitespace().collect();
191
192        // Simple heuristic to find table names
193        for (i, word) in words.iter().enumerate() {
194            let word_lower = word.to_lowercase();
195            if (word_lower == "from" || word_lower == "into" || word_lower == "table")
196                && i + 1 < words.len()
197            {
198                let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
199                return ResourceType::Table(table_name.to_string());
200            }
201        }
202
203        // If we can't determine the table, default to All
204        ResourceType::All
205    }
206
207    async fn try_respond_transaction_statements<C>(
208        &self,
209        client: &C,
210        query_lower: &str,
211    ) -> PgWireResult<Option<Response>>
212    where
213        C: ClientInfo,
214    {
215        // Transaction handling based on pgwire example:
216        // https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
217        match query_lower.trim() {
218            "begin" | "begin transaction" | "begin work" | "start transaction" => {
219                match client.transaction_status() {
220                    TransactionStatus::Idle => {
221                        Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
222                    }
223                    TransactionStatus::Transaction => {
224                        // PostgreSQL behavior: ignore nested BEGIN, just return SUCCESS
225                        // This matches PostgreSQL's handling of nested transaction blocks
226                        log::warn!("BEGIN command ignored: already in transaction block");
227                        Ok(Some(Response::Execution(Tag::new("BEGIN"))))
228                    }
229                    TransactionStatus::Error => {
230                        // Can't start new transaction from failed state
231                        Err(PgWireError::UserError(Box::new(
232                            pgwire::error::ErrorInfo::new(
233                                "ERROR".to_string(),
234                                "25P01".to_string(),
235                                "current transaction is aborted, commands ignored until end of transaction block".to_string(),
236                            ),
237                        )))
238                    }
239                }
240            }
241            "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
242                match client.transaction_status() {
243                    TransactionStatus::Idle | TransactionStatus::Transaction => {
244                        Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
245                    }
246                    TransactionStatus::Error => {
247                        Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
248                    }
249                }
250            }
251            "rollback" | "rollback transaction" | "rollback work" | "abort" => {
252                Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
253            }
254            _ => Ok(None),
255        }
256    }
257}
258
259#[async_trait]
260impl SimpleQueryHandler for DfSessionService {
261    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
262    where
263        C: ClientInfo + Unpin + Send + Sync,
264    {
265        log::debug!("Received query: {query}"); // Log the query for debugging
266
267        // Check for transaction commands early to avoid SQL parsing issues with ABORT
268        let query_lower = query.to_lowercase().trim().to_string();
269        if let Some(resp) = self
270            .try_respond_transaction_statements(client, &query_lower)
271            .await?
272        {
273            return Ok(vec![resp]);
274        }
275
276        let statements = self
277            .parser
278            .sql_parser
279            .parse(query)
280            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
281
282        // empty query
283        if statements.is_empty() {
284            return Ok(vec![Response::EmptyQuery]);
285        }
286
287        let mut results = vec![];
288        'stmt: for statement in statements {
289            // TODO: improve statement check by using statement directly
290            let query = statement.to_string();
291            let query_lower = query.to_lowercase().trim().to_string();
292
293            // Check permissions for the query (skip for SET, transaction, and SHOW statements)
294            if !query_lower.starts_with("set")
295                && !query_lower.starts_with("begin")
296                && !query_lower.starts_with("commit")
297                && !query_lower.starts_with("rollback")
298                && !query_lower.starts_with("start")
299                && !query_lower.starts_with("end")
300                && !query_lower.starts_with("abort")
301                && !query_lower.starts_with("show")
302            {
303                self.check_query_permission(client, &query).await?;
304            }
305
306            // Call query hooks with the parsed statement
307            for hook in &self.query_hooks {
308                if let Some(result) = hook
309                    .handle_simple_query(&statement, &self.session_context, client)
310                    .await
311                {
312                    results.push(result?);
313                    continue 'stmt;
314                }
315            }
316
317            // Check if we're in a failed transaction and block non-transaction
318            // commands
319            if client.transaction_status() == TransactionStatus::Error {
320                return Err(PgWireError::UserError(Box::new(
321                pgwire::error::ErrorInfo::new(
322                    "ERROR".to_string(),
323                    "25P01".to_string(),
324                    "current transaction is aborted, commands ignored until end of transaction block".to_string(),
325                ),
326            )));
327            }
328
329            let df_result = {
330                let timeout = client::get_statement_timeout(client);
331                if let Some(timeout_duration) = timeout {
332                    tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
333                        .await
334                        .map_err(|_| {
335                            PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
336                                "ERROR".to_string(),
337                                "57014".to_string(), // query_canceled error code
338                                "canceling statement due to statement timeout".to_string(),
339                            )))
340                        })?
341                } else {
342                    self.session_context.sql(&query).await
343                }
344            };
345
346            // Handle query execution errors and transaction state
347            let df = match df_result {
348                Ok(df) => df,
349                Err(e) => {
350                    return Err(PgWireError::ApiError(Box::new(e)));
351                }
352            };
353
354            if query_lower.starts_with("insert into") {
355                let resp = map_rows_affected_for_insert(&df).await?;
356                results.push(resp);
357            } else {
358                // For non-INSERT queries, return a regular Query response
359                let format_options =
360                    Arc::new(FormatOptions::from_client_metadata(client.metadata()));
361                let resp =
362                    df::encode_dataframe(df, &Format::UnifiedText, Some(format_options)).await?;
363                results.push(Response::Query(resp));
364            }
365        }
366        Ok(results)
367    }
368}
369
370#[async_trait]
371impl ExtendedQueryHandler for DfSessionService {
372    type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
373    type QueryParser = Parser;
374
375    fn query_parser(&self) -> Arc<Self::QueryParser> {
376        self.parser.clone()
377    }
378
379    async fn do_describe_statement<C>(
380        &self,
381        _client: &mut C,
382        target: &StoredStatement<Self::Statement>,
383    ) -> PgWireResult<DescribeStatementResponse>
384    where
385        C: ClientInfo + Unpin + Send + Sync,
386    {
387        if let (_, Some((_, plan))) = &target.statement {
388            let schema = plan.schema();
389            let fields =
390                arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary, None)?;
391            let params = plan
392                .get_parameter_types()
393                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
394
395            let mut param_types = Vec::with_capacity(params.len());
396            for param_type in ordered_param_types(&params).iter() {
397                // Fixed: Use &params
398                if let Some(datatype) = param_type {
399                    let pgtype = into_pg_type(datatype)?;
400                    param_types.push(pgtype);
401                } else {
402                    param_types.push(Type::UNKNOWN);
403                }
404            }
405
406            Ok(DescribeStatementResponse::new(param_types, fields))
407        } else {
408            Ok(DescribeStatementResponse::no_data())
409        }
410    }
411
412    async fn do_describe_portal<C>(
413        &self,
414        _client: &mut C,
415        target: &Portal<Self::Statement>,
416    ) -> PgWireResult<DescribePortalResponse>
417    where
418        C: ClientInfo + Unpin + Send + Sync,
419    {
420        if let (_, Some((_, plan))) = &target.statement.statement {
421            let format = &target.result_column_format;
422            let schema = plan.schema();
423            let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format, None)?;
424
425            Ok(DescribePortalResponse::new(fields))
426        } else {
427            Ok(DescribePortalResponse::no_data())
428        }
429    }
430
431    async fn do_query<C>(
432        &self,
433        client: &mut C,
434        portal: &Portal<Self::Statement>,
435        _max_rows: usize,
436    ) -> PgWireResult<Response>
437    where
438        C: ClientInfo + Unpin + Send + Sync,
439    {
440        let query = portal
441            .statement
442            .statement
443            .0
444            .to_lowercase()
445            .trim()
446            .to_string();
447        log::debug!("Received execute extended query: {query}"); // Log for debugging
448
449        // Check query hooks first
450        if !self.query_hooks.is_empty() {
451            if let (_, Some((statement, plan))) = &portal.statement.statement {
452                // TODO: in the case where query hooks all return None, we do the param handling again later.
453                let param_types = plan
454                    .get_parameter_types()
455                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
456
457                let param_values: ParamValues =
458                    df::deserialize_parameters(portal, &ordered_param_types(&param_types))?;
459
460                for hook in &self.query_hooks {
461                    if let Some(result) = hook
462                        .handle_extended_query(
463                            statement,
464                            plan,
465                            &param_values,
466                            &self.session_context,
467                            client,
468                        )
469                        .await
470                    {
471                        return result;
472                    }
473                }
474            }
475        }
476
477        // Check permissions for the query (skip for SET and SHOW statements)
478        if !query.starts_with("set") && !query.starts_with("show") {
479            self.check_query_permission(client, &portal.statement.statement.0)
480                .await?;
481        }
482
483        if let Some(resp) = self
484            .try_respond_transaction_statements(client, &query)
485            .await?
486        {
487            return Ok(resp);
488        }
489
490        // Check if we're in a failed transaction and block non-transaction
491        // commands
492        if client.transaction_status() == TransactionStatus::Error {
493            return Err(PgWireError::UserError(Box::new(
494                pgwire::error::ErrorInfo::new(
495                    "ERROR".to_string(),
496                    "25P01".to_string(),
497                    "current transaction is aborted, commands ignored until end of transaction block".to_string(),
498                ),
499            )));
500        }
501
502        if let (_, Some((_, plan))) = &portal.statement.statement {
503            let param_types = plan
504                .get_parameter_types()
505                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
506
507            let param_values =
508                df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
509
510            let plan = plan
511                .clone()
512                .replace_params_with_values(&param_values)
513                .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use
514                                                                   // &param_values
515            let optimised = self
516                .session_context
517                .state()
518                .optimize(&plan)
519                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
520
521            let dataframe = {
522                let timeout = client::get_statement_timeout(client);
523                if let Some(timeout_duration) = timeout {
524                    tokio::time::timeout(
525                        timeout_duration,
526                        self.session_context.execute_logical_plan(optimised),
527                    )
528                    .await
529                    .map_err(|_| {
530                        PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
531                            "ERROR".to_string(),
532                            "57014".to_string(), // query_canceled error code
533                            "canceling statement due to statement timeout".to_string(),
534                        )))
535                    })?
536                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?
537                } else {
538                    self.session_context
539                        .execute_logical_plan(optimised)
540                        .await
541                        .map_err(|e| PgWireError::ApiError(Box::new(e)))?
542                }
543            };
544
545            if query.starts_with("insert into") {
546                let resp = map_rows_affected_for_insert(&dataframe).await?;
547
548                Ok(resp)
549            } else {
550                // For non-INSERT queries, return a regular Query response
551                let format_options =
552                    Arc::new(FormatOptions::from_client_metadata(client.metadata()));
553                let resp = df::encode_dataframe(
554                    dataframe,
555                    &portal.result_column_format,
556                    Some(format_options),
557                )
558                .await?;
559                Ok(Response::Query(resp))
560            }
561        } else {
562            Ok(Response::EmptyQuery)
563        }
564    }
565}
566
567async fn map_rows_affected_for_insert(df: &DataFrame) -> PgWireResult<Response> {
568    // For INSERT queries, we need to execute the query to get the row count
569    // and return an Execution response with the proper tag
570    let result = df
571        .clone()
572        .collect()
573        .await
574        .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
575
576    // Extract count field from the first batch
577    let rows_affected = result
578        .first()
579        .and_then(|batch| batch.column_by_name("count"))
580        .and_then(|col| {
581            col.as_any()
582                .downcast_ref::<datafusion::arrow::array::UInt64Array>()
583        })
584        .map_or(0, |array| array.value(0) as usize);
585
586    // Create INSERT tag with the affected row count
587    let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
588    Ok(Response::Execution(tag))
589}
590
591pub struct Parser {
592    session_context: Arc<SessionContext>,
593    sql_parser: PostgresCompatibilityParser,
594    query_hooks: Vec<Arc<dyn QueryHook>>,
595}
596
597impl Parser {
598    fn try_shortcut_parse_plan(&self, sql: &str) -> Result<Option<LogicalPlan>, DataFusionError> {
599        // Check for transaction commands that shouldn't be parsed by DataFusion
600        let sql_lower = sql.to_lowercase();
601        let sql_trimmed = sql_lower.trim();
602
603        if matches!(
604            sql_trimmed,
605            "" | "begin"
606                | "begin transaction"
607                | "begin work"
608                | "start transaction"
609                | "commit"
610                | "commit transaction"
611                | "commit work"
612                | "end"
613                | "end transaction"
614                | "rollback"
615                | "rollback transaction"
616                | "rollback work"
617                | "abort"
618        ) {
619            // Return a dummy plan for transaction commands - they'll be handled by transaction handler
620            let dummy_schema = datafusion::common::DFSchema::empty();
621            return Ok(Some(LogicalPlan::EmptyRelation(
622                datafusion::logical_expr::EmptyRelation {
623                    produce_one_row: false,
624                    schema: Arc::new(dummy_schema),
625                },
626            )));
627        }
628
629        // show statement may not be supported by datafusion
630        if sql_trimmed.starts_with("show") {
631            let show_schema =
632                Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
633            let df_schema = show_schema.to_dfschema()?;
634            return Ok(Some(LogicalPlan::EmptyRelation(
635                datafusion::logical_expr::EmptyRelation {
636                    produce_one_row: true,
637                    schema: Arc::new(df_schema),
638                },
639            )));
640        }
641
642        Ok(None)
643    }
644}
645
646#[async_trait]
647impl QueryParser for Parser {
648    type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
649
650    async fn parse_sql<C>(
651        &self,
652        client: &C,
653        sql: &str,
654        _types: &[Option<Type>],
655    ) -> PgWireResult<Self::Statement>
656    where
657        C: ClientInfo + Unpin + Send + Sync,
658    {
659        log::debug!("Received parse extended query: {sql}"); // Log for debugging
660
661        let mut statements = self
662            .sql_parser
663            .parse(sql)
664            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
665        if statements.is_empty() {
666            return Ok((sql.to_string(), None));
667        }
668
669        let statement = statements.remove(0);
670
671        // Check for transaction commands that shouldn't be parsed by DataFusion
672        if let Some(plan) = self
673            .try_shortcut_parse_plan(sql)
674            .map_err(|e| PgWireError::ApiError(Box::new(e)))?
675        {
676            return Ok((sql.to_string(), Some((statement, plan))));
677        }
678
679        let query = statement.to_string();
680
681        let context = &self.session_context;
682        let state = context.state();
683
684        for hook in &self.query_hooks {
685            if let Some(logical_plan) = hook
686                .handle_extended_parse_query(&statement, context, client)
687                .await
688            {
689                return Ok((query, Some((statement, logical_plan?))));
690            }
691        }
692
693        let logical_plan = state
694            .statement_to_plan(Statement::Statement(Box::new(statement.clone())))
695            .await
696            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
697        Ok((query, Some((statement, logical_plan))))
698    }
699}
700
/// Returns a plan's parameter types ordered by parameter position.
///
/// DataFusion reports parameter types as a map keyed by placeholder name. In our case
/// the keys are `$1`, `$2`, ... and the values are the (possibly unknown) types.
/// Entries are ordered by the numeric position, so `$10` comes after `$9`; a plain
/// string sort would put `$10` before `$2` and misalign every later parameter.
/// Keys that are not of the form `$<n>` are placed after all numbered ones, in string order.
fn ordered_param_types<T>(types: &HashMap<String, Option<T>>) -> Vec<Option<&T>> {
    // `$<n>` -> n; anything else sorts after every numbered placeholder.
    fn position(name: &str) -> usize {
        name.strip_prefix('$')
            .and_then(|digits| digits.parse().ok())
            .unwrap_or(usize::MAX)
    }

    let mut entries: Vec<_> = types.iter().collect();
    entries.sort_by(|(a, _), (b, _)| position(a).cmp(&position(b)).then_with(|| a.cmp(b)));
    entries.into_iter().map(|(_, ty)| ty.as_ref()).collect()
}
708
#[cfg(test)]
mod tests {
    use datafusion::prelude::SessionContext;

    use super::*;
    use crate::testing::MockClient;

    /// Hook that answers any simple-protocol statement whose text contains `magic`
    /// with `EmptyQuery`, and declines everything else (including extended queries).
    struct TestHook;

    #[async_trait]
    impl QueryHook for TestHook {
        async fn handle_simple_query(
            &self,
            statement: &sqlparser::ast::Statement,
            _ctx: &SessionContext,
            _client: &mut (dyn ClientInfo + Sync + Send),
        ) -> Option<PgWireResult<Response>> {
            let is_magic = statement.to_string().contains("magic");
            is_magic.then(|| Ok(Response::EmptyQuery))
        }

        async fn handle_extended_parse_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _session_context: &SessionContext,
            _client: &(dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<LogicalPlan>> {
            None
        }

        async fn handle_extended_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _logical_plan: &LogicalPlan,
            _params: &ParamValues,
            _session_context: &SessionContext,
            _client: &mut (dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<Response>> {
            None
        }
    }

    #[tokio::test]
    async fn test_query_hooks() {
        let hook = TestHook;
        let ctx = SessionContext::new();
        let mut client = MockClient::new();
        let parser = PostgresCompatibilityParser::new();

        // A statement mentioning "magic" is intercepted by the hook...
        let magic = parser.parse("SELECT magic").unwrap();
        let intercepted = hook.handle_simple_query(&magic[0], &ctx, &mut client).await;
        assert!(intercepted.is_some());

        // ...while an ordinary statement falls through.
        let plain = parser.parse("SELECT 1").unwrap();
        let passed_through = hook.handle_simple_query(&plain[0], &ctx, &mut client).await;
        assert!(passed_through.is_none());
    }

    #[tokio::test]
    async fn test_multiple_statements_with_hook_continue() {
        // Regression test for bug #227: a hook-handled statement used `break 'stmt`,
        // which ended the whole statement loop and dropped every later statement.
        let session_context = Arc::new(SessionContext::new());
        let auth_manager = Arc::new(AuthManager::new());
        let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(TestHook)];
        let service = DfSessionService::new_with_hooks(session_context, auth_manager, hooks);
        let mut client = MockClient::new();

        // Alternate hook-handled and DataFusion-handled statements.
        let batch = "SELECT magic; SELECT 1; SELECT magic; SELECT 1";
        let results =
            <DfSessionService as SimpleQueryHandler>::do_query(&service, &mut client, batch)
                .await
                .unwrap();

        assert_eq!(results.len(), 4, "Expected 4 responses");

        // Even positions were answered by the hook, odd ones by DataFusion.
        for (index, response) in results.iter().enumerate() {
            if index % 2 == 0 {
                assert!(matches!(response, Response::EmptyQuery));
            } else {
                assert!(matches!(response, Response::Query(_)));
            }
        }
    }
}