datafusion_postgres/
handlers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::{DataType, Field, Schema};
6use datafusion::common::ToDFSchema;
7use datafusion::error::DataFusionError;
8use datafusion::logical_expr::LogicalPlan;
9use datafusion::prelude::*;
10use datafusion::sql::parser::Statement;
11use log::{info, warn};
12use pgwire::api::auth::noop::NoopStartupHandler;
13use pgwire::api::auth::StartupHandler;
14use pgwire::api::portal::{Format, Portal};
15use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
16use pgwire::api::results::{
17    DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
18    Response, Tag,
19};
20use pgwire::api::stmt::QueryParser;
21use pgwire::api::stmt::StoredStatement;
22use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
23use pgwire::error::{PgWireError, PgWireResult};
24use pgwire::messages::response::TransactionStatus;
25use tokio::sync::Mutex;
26
27use crate::auth::AuthManager;
28use arrow_pg::datatypes::df;
29use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
30use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
31use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
32
/// Client-metadata key under which the session's statement timeout is kept,
/// stored as a decimal number of milliseconds.
const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms";
35
36/// Simple startup handler that does no authentication
37/// For production, use DfAuthSource with proper pgwire authentication handlers
38pub struct SimpleStartupHandler;
39
40#[async_trait::async_trait]
41impl NoopStartupHandler for SimpleStartupHandler {}
42
43pub struct HandlerFactory {
44    pub session_service: Arc<DfSessionService>,
45}
46
47impl HandlerFactory {
48    pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
49        let session_service =
50            Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
51        HandlerFactory { session_service }
52    }
53}
54
55impl PgWireServerHandlers for HandlerFactory {
56    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
57        self.session_service.clone()
58    }
59
60    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
61        self.session_service.clone()
62    }
63
64    fn startup_handler(&self) -> Arc<impl StartupHandler> {
65        Arc::new(SimpleStartupHandler)
66    }
67
68    fn error_handler(&self) -> Arc<impl ErrorHandler> {
69        Arc::new(LoggingErrorHandler)
70    }
71}
72
73struct LoggingErrorHandler;
74
75impl ErrorHandler for LoggingErrorHandler {
76    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
77    where
78        C: ClientInfo,
79    {
80        info!("Sending error: {error}")
81    }
82}
83
84/// The pgwire handler backed by a datafusion `SessionContext`
85pub struct DfSessionService {
86    session_context: Arc<SessionContext>,
87    parser: Arc<Parser>,
88    timezone: Arc<Mutex<String>>,
89    auth_manager: Arc<AuthManager>,
90}
91
92impl DfSessionService {
93    pub fn new(
94        session_context: Arc<SessionContext>,
95        auth_manager: Arc<AuthManager>,
96    ) -> DfSessionService {
97        let parser = Arc::new(Parser {
98            session_context: session_context.clone(),
99            sql_parser: PostgresCompatibilityParser::new(),
100        });
101        DfSessionService {
102            session_context,
103            parser,
104            timezone: Arc::new(Mutex::new("UTC".to_string())),
105            auth_manager,
106        }
107    }
108
109    /// Get statement timeout from client metadata
110    fn get_statement_timeout<C>(client: &C) -> Option<std::time::Duration>
111    where
112        C: ClientInfo,
113    {
114        client
115            .metadata()
116            .get(METADATA_STATEMENT_TIMEOUT)
117            .and_then(|s| s.parse::<u64>().ok())
118            .map(std::time::Duration::from_millis)
119    }
120
121    /// Set statement timeout in client metadata
122    fn set_statement_timeout<C>(client: &mut C, timeout: Option<std::time::Duration>)
123    where
124        C: ClientInfo,
125    {
126        let metadata = client.metadata_mut();
127        if let Some(duration) = timeout {
128            metadata.insert(
129                METADATA_STATEMENT_TIMEOUT.to_string(),
130                duration.as_millis().to_string(),
131            );
132        } else {
133            metadata.remove(METADATA_STATEMENT_TIMEOUT);
134        }
135    }
136
137    /// Check if the current user has permission to execute a query
138    async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
139    where
140        C: ClientInfo,
141    {
142        // Get the username from client metadata
143        let username = client
144            .metadata()
145            .get("user")
146            .map(|s| s.as_str())
147            .unwrap_or("anonymous");
148
149        // Parse query to determine required permissions
150        let query_lower = query.to_lowercase();
151        let query_trimmed = query_lower.trim();
152
153        let (required_permission, resource) = if query_trimmed.starts_with("select") {
154            (Permission::Select, self.extract_table_from_query(query))
155        } else if query_trimmed.starts_with("insert") {
156            (Permission::Insert, self.extract_table_from_query(query))
157        } else if query_trimmed.starts_with("update") {
158            (Permission::Update, self.extract_table_from_query(query))
159        } else if query_trimmed.starts_with("delete") {
160            (Permission::Delete, self.extract_table_from_query(query))
161        } else if query_trimmed.starts_with("create table")
162            || query_trimmed.starts_with("create view")
163        {
164            (Permission::Create, ResourceType::All)
165        } else if query_trimmed.starts_with("drop") {
166            (Permission::Drop, self.extract_table_from_query(query))
167        } else if query_trimmed.starts_with("alter") {
168            (Permission::Alter, self.extract_table_from_query(query))
169        } else {
170            // For other queries (SHOW, EXPLAIN, etc.), allow all users
171            return Ok(());
172        };
173
174        // Check permission
175        let has_permission = self
176            .auth_manager
177            .check_permission(username, required_permission, resource)
178            .await;
179
180        if !has_permission {
181            return Err(PgWireError::UserError(Box::new(
182                pgwire::error::ErrorInfo::new(
183                    "ERROR".to_string(),
184                    "42501".to_string(), // insufficient_privilege
185                    format!("permission denied for user \"{username}\""),
186                ),
187            )));
188        }
189
190        Ok(())
191    }
192
193    /// Extract table name from query (simplified parsing)
194    fn extract_table_from_query(&self, query: &str) -> ResourceType {
195        let words: Vec<&str> = query.split_whitespace().collect();
196
197        // Simple heuristic to find table names
198        for (i, word) in words.iter().enumerate() {
199            let word_lower = word.to_lowercase();
200            if (word_lower == "from" || word_lower == "into" || word_lower == "table")
201                && i + 1 < words.len()
202            {
203                let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
204                return ResourceType::Table(table_name.to_string());
205            }
206        }
207
208        // If we can't determine the table, default to All
209        ResourceType::All
210    }
211
212    fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult<QueryResponse<'a>> {
213        let fields = vec![FieldInfo::new(
214            name.to_string(),
215            None,
216            None,
217            Type::VARCHAR,
218            FieldFormat::Text,
219        )];
220
221        let row = {
222            let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
223            encoder.encode_field(&Some(value))?;
224            encoder.finish()
225        };
226
227        let row_stream = futures::stream::once(async move { row });
228        Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
229    }
230
231    async fn try_respond_set_statements<'a, C>(
232        &self,
233        client: &mut C,
234        query_lower: &str,
235    ) -> PgWireResult<Option<Response<'a>>>
236    where
237        C: ClientInfo,
238    {
239        if query_lower.starts_with("set") {
240            if query_lower.starts_with("set time zone") {
241                let parts: Vec<&str> = query_lower.split_whitespace().collect();
242                if parts.len() >= 4 {
243                    let tz = parts[3].trim_matches('"');
244                    let mut timezone = self.timezone.lock().await;
245                    *timezone = tz.to_string();
246                    Ok(Some(Response::Execution(Tag::new("SET"))))
247                } else {
248                    Err(PgWireError::UserError(Box::new(
249                        pgwire::error::ErrorInfo::new(
250                            "ERROR".to_string(),
251                            "42601".to_string(),
252                            "Invalid SET TIME ZONE syntax".to_string(),
253                        ),
254                    )))
255                }
256            } else if query_lower.starts_with("set statement_timeout") {
257                let parts: Vec<&str> = query_lower.split_whitespace().collect();
258                if parts.len() >= 3 {
259                    let timeout_str = parts[2].trim_matches('"').trim_matches('\'');
260
261                    let timeout = if timeout_str == "0" || timeout_str.is_empty() {
262                        None
263                    } else {
264                        // Parse timeout value (supports ms, s, min formats)
265                        let timeout_ms = if timeout_str.ends_with("ms") {
266                            timeout_str.trim_end_matches("ms").parse::<u64>()
267                        } else if timeout_str.ends_with("s") {
268                            timeout_str
269                                .trim_end_matches("s")
270                                .parse::<u64>()
271                                .map(|s| s * 1000)
272                        } else if timeout_str.ends_with("min") {
273                            timeout_str
274                                .trim_end_matches("min")
275                                .parse::<u64>()
276                                .map(|m| m * 60 * 1000)
277                        } else {
278                            // Default to milliseconds
279                            timeout_str.parse::<u64>()
280                        };
281
282                        match timeout_ms {
283                            Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
284                            _ => None,
285                        }
286                    };
287
288                    Self::set_statement_timeout(client, timeout);
289                    Ok(Some(Response::Execution(Tag::new("SET"))))
290                } else {
291                    Err(PgWireError::UserError(Box::new(
292                        pgwire::error::ErrorInfo::new(
293                            "ERROR".to_string(),
294                            "42601".to_string(),
295                            "Invalid SET statement_timeout syntax".to_string(),
296                        ),
297                    )))
298                }
299            } else {
300                // pass SET query to datafusion
301                if let Err(e) = self.session_context.sql(query_lower).await {
302                    warn!("SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored");
303                }
304
305                // Always return SET success
306                Ok(Some(Response::Execution(Tag::new("SET"))))
307            }
308        } else {
309            Ok(None)
310        }
311    }
312
313    async fn try_respond_transaction_statements<'a, C>(
314        &self,
315        client: &C,
316        query_lower: &str,
317    ) -> PgWireResult<Option<Response<'a>>>
318    where
319        C: ClientInfo,
320    {
321        // Transaction handling based on pgwire example:
322        // https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
323        match query_lower.trim() {
324            "begin" | "begin transaction" | "begin work" | "start transaction" => {
325                match client.transaction_status() {
326                    TransactionStatus::Idle => {
327                        Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
328                    }
329                    TransactionStatus::Transaction => {
330                        // PostgreSQL behavior: ignore nested BEGIN, just return SUCCESS
331                        // This matches PostgreSQL's handling of nested transaction blocks
332                        log::warn!("BEGIN command ignored: already in transaction block");
333                        Ok(Some(Response::Execution(Tag::new("BEGIN"))))
334                    }
335                    TransactionStatus::Error => {
336                        // Can't start new transaction from failed state
337                        Err(PgWireError::UserError(Box::new(
338                            pgwire::error::ErrorInfo::new(
339                                "ERROR".to_string(),
340                                "25P01".to_string(),
341                                "current transaction is aborted, commands ignored until end of transaction block".to_string(),
342                            ),
343                        )))
344                    }
345                }
346            }
347            "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
348                match client.transaction_status() {
349                    TransactionStatus::Idle | TransactionStatus::Transaction => {
350                        Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
351                    }
352                    TransactionStatus::Error => {
353                        Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
354                    }
355                }
356            }
357            "rollback" | "rollback transaction" | "rollback work" | "abort" => {
358                Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
359            }
360            _ => Ok(None),
361        }
362    }
363
364    async fn try_respond_show_statements<'a, C>(
365        &self,
366        client: &C,
367        query_lower: &str,
368    ) -> PgWireResult<Option<Response<'a>>>
369    where
370        C: ClientInfo,
371    {
372        if query_lower.starts_with("show ") {
373            match query_lower.strip_suffix(";").unwrap_or(query_lower) {
374                "show time zone" => {
375                    let timezone = self.timezone.lock().await.clone();
376                    let resp = Self::mock_show_response("TimeZone", &timezone)?;
377                    Ok(Some(Response::Query(resp)))
378                }
379                "show server_version" => {
380                    let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
381                    Ok(Some(Response::Query(resp)))
382                }
383                "show transaction_isolation" => {
384                    let resp =
385                        Self::mock_show_response("transaction_isolation", "read uncommitted")?;
386                    Ok(Some(Response::Query(resp)))
387                }
388                "show catalogs" => {
389                    let catalogs = self.session_context.catalog_names();
390                    let value = catalogs.join(", ");
391                    let resp = Self::mock_show_response("Catalogs", &value)?;
392                    Ok(Some(Response::Query(resp)))
393                }
394                "show search_path" => {
395                    let default_schema = "public";
396                    let resp = Self::mock_show_response("search_path", default_schema)?;
397                    Ok(Some(Response::Query(resp)))
398                }
399                "show statement_timeout" => {
400                    let timeout = Self::get_statement_timeout(client);
401                    let timeout_str = match timeout {
402                        Some(duration) => format!("{}ms", duration.as_millis()),
403                        None => "0".to_string(),
404                    };
405                    let resp = Self::mock_show_response("statement_timeout", &timeout_str)?;
406                    Ok(Some(Response::Query(resp)))
407                }
408                "show transaction isolation level" => {
409                    let resp = Self::mock_show_response("transaction_isolation", "read_committed")?;
410                    Ok(Some(Response::Query(resp)))
411                }
412                _ => {
413                    info!("Unsupported show statement: {query_lower}");
414                    let resp = Self::mock_show_response("unsupported_show_statement", "")?;
415                    Ok(Some(Response::Query(resp)))
416                }
417            }
418        } else {
419            Ok(None)
420        }
421    }
422}
423
424#[async_trait]
425impl SimpleQueryHandler for DfSessionService {
426    async fn do_query<'a, C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
427    where
428        C: ClientInfo + Unpin + Send + Sync,
429    {
430        log::debug!("Received query: {query}"); // Log the query for debugging
431
432        // Check for transaction commands early to avoid SQL parsing issues with ABORT
433        let query_lower = query.to_lowercase().trim().to_string();
434        if let Some(resp) = self
435            .try_respond_transaction_statements(client, &query_lower)
436            .await?
437        {
438            return Ok(vec![resp]);
439        }
440
441        let mut statements = self
442            .parser
443            .sql_parser
444            .parse(query)
445            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
446
447        // TODO: deal with multiple statements
448        let statement = statements.remove(0);
449
450        // TODO: improve statement check by using statement directly
451        let query = statement.to_string();
452        let query_lower = query.to_lowercase().trim().to_string();
453
454        // Check permissions for the query (skip for SET, transaction, and SHOW statements)
455        if !query_lower.starts_with("set")
456            && !query_lower.starts_with("begin")
457            && !query_lower.starts_with("commit")
458            && !query_lower.starts_with("rollback")
459            && !query_lower.starts_with("start")
460            && !query_lower.starts_with("end")
461            && !query_lower.starts_with("abort")
462            && !query_lower.starts_with("show")
463        {
464            self.check_query_permission(client, &query).await?;
465        }
466
467        if let Some(resp) = self
468            .try_respond_set_statements(client, &query_lower)
469            .await?
470        {
471            return Ok(vec![resp]);
472        }
473
474        if let Some(resp) = self
475            .try_respond_show_statements(client, &query_lower)
476            .await?
477        {
478            return Ok(vec![resp]);
479        }
480
481        // Check if we're in a failed transaction and block non-transaction
482        // commands
483        if client.transaction_status() == TransactionStatus::Error {
484            return Err(PgWireError::UserError(Box::new(
485                pgwire::error::ErrorInfo::new(
486                    "ERROR".to_string(),
487                    "25P01".to_string(),
488                    "current transaction is aborted, commands ignored until end of transaction block".to_string(),
489                ),
490            )));
491        }
492
493        let df_result = {
494            let timeout = Self::get_statement_timeout(client);
495            if let Some(timeout_duration) = timeout {
496                tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
497                    .await
498                    .map_err(|_| {
499                        PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
500                            "ERROR".to_string(),
501                            "57014".to_string(), // query_canceled error code
502                            "canceling statement due to statement timeout".to_string(),
503                        )))
504                    })?
505            } else {
506                self.session_context.sql(&query).await
507            }
508        };
509
510        // Handle query execution errors and transaction state
511        let df = match df_result {
512            Ok(df) => df,
513            Err(e) => {
514                return Err(PgWireError::ApiError(Box::new(e)));
515            }
516        };
517
518        if query_lower.starts_with("insert into") {
519            // For INSERT queries, we need to execute the query to get the row count
520            // and return an Execution response with the proper tag
521            let result = df
522                .clone()
523                .collect()
524                .await
525                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
526
527            // Extract count field from the first batch
528            let rows_affected = result
529                .first()
530                .and_then(|batch| batch.column_by_name("count"))
531                .and_then(|col| {
532                    col.as_any()
533                        .downcast_ref::<datafusion::arrow::array::UInt64Array>()
534                })
535                .map_or(0, |array| array.value(0) as usize);
536
537            // Create INSERT tag with the affected row count
538            let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
539            Ok(vec![Response::Execution(tag)])
540        } else {
541            // For non-INSERT queries, return a regular Query response
542            let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
543            Ok(vec![Response::Query(resp)])
544        }
545    }
546}
547
548#[async_trait]
549impl ExtendedQueryHandler for DfSessionService {
550    type Statement = (String, LogicalPlan);
551    type QueryParser = Parser;
552
553    fn query_parser(&self) -> Arc<Self::QueryParser> {
554        self.parser.clone()
555    }
556
557    async fn do_describe_statement<C>(
558        &self,
559        _client: &mut C,
560        target: &StoredStatement<Self::Statement>,
561    ) -> PgWireResult<DescribeStatementResponse>
562    where
563        C: ClientInfo + Unpin + Send + Sync,
564    {
565        let (_, plan) = &target.statement;
566        let schema = plan.schema();
567        let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
568        let params = plan
569            .get_parameter_types()
570            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
571
572        let mut param_types = Vec::with_capacity(params.len());
573        for param_type in ordered_param_types(&params).iter() {
574            // Fixed: Use &params
575            if let Some(datatype) = param_type {
576                let pgtype = into_pg_type(datatype)?;
577                param_types.push(pgtype);
578            } else {
579                param_types.push(Type::UNKNOWN);
580            }
581        }
582
583        Ok(DescribeStatementResponse::new(param_types, fields))
584    }
585
586    async fn do_describe_portal<C>(
587        &self,
588        _client: &mut C,
589        target: &Portal<Self::Statement>,
590    ) -> PgWireResult<DescribePortalResponse>
591    where
592        C: ClientInfo + Unpin + Send + Sync,
593    {
594        let (_, plan) = &target.statement.statement;
595        let format = &target.result_column_format;
596        let schema = plan.schema();
597        let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
598
599        Ok(DescribePortalResponse::new(fields))
600    }
601
602    async fn do_query<'a, C>(
603        &self,
604        client: &mut C,
605        portal: &Portal<Self::Statement>,
606        _max_rows: usize,
607    ) -> PgWireResult<Response<'a>>
608    where
609        C: ClientInfo + Unpin + Send + Sync,
610    {
611        let query = portal
612            .statement
613            .statement
614            .0
615            .to_lowercase()
616            .trim()
617            .to_string();
618        log::debug!("Received execute extended query: {query}"); // Log for debugging
619
620        // Check permissions for the query (skip for SET and SHOW statements)
621        if !query.starts_with("set") && !query.starts_with("show") {
622            self.check_query_permission(client, &portal.statement.statement.0)
623                .await?;
624        }
625
626        if let Some(resp) = self.try_respond_set_statements(client, &query).await? {
627            return Ok(resp);
628        }
629
630        if let Some(resp) = self
631            .try_respond_transaction_statements(client, &query)
632            .await?
633        {
634            return Ok(resp);
635        }
636
637        if let Some(resp) = self.try_respond_show_statements(client, &query).await? {
638            return Ok(resp);
639        }
640
641        // Check if we're in a failed transaction and block non-transaction
642        // commands
643        if client.transaction_status() == TransactionStatus::Error {
644            return Err(PgWireError::UserError(Box::new(
645                pgwire::error::ErrorInfo::new(
646                    "ERROR".to_string(),
647                    "25P01".to_string(),
648                    "current transaction is aborted, commands ignored until end of transaction block".to_string(),
649                ),
650            )));
651        }
652
653        let (_, plan) = &portal.statement.statement;
654
655        let param_types = plan
656            .get_parameter_types()
657            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
658
659        let param_values = df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
660
661        let plan = plan
662            .clone()
663            .replace_params_with_values(&param_values)
664            .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use
665                                                               // &param_values
666        let optimised = self
667            .session_context
668            .state()
669            .optimize(&plan)
670            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
671
672        let dataframe = {
673            let timeout = Self::get_statement_timeout(client);
674            if let Some(timeout_duration) = timeout {
675                tokio::time::timeout(
676                    timeout_duration,
677                    self.session_context.execute_logical_plan(optimised),
678                )
679                .await
680                .map_err(|_| {
681                    PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
682                        "ERROR".to_string(),
683                        "57014".to_string(), // query_canceled error code
684                        "canceling statement due to statement timeout".to_string(),
685                    )))
686                })?
687                .map_err(|e| PgWireError::ApiError(Box::new(e)))?
688            } else {
689                self.session_context
690                    .execute_logical_plan(optimised)
691                    .await
692                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?
693            }
694        };
695        let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
696        Ok(Response::Query(resp))
697    }
698}
699
700pub struct Parser {
701    session_context: Arc<SessionContext>,
702    sql_parser: PostgresCompatibilityParser,
703}
704
705impl Parser {
706    fn try_shortcut_parse_plan(&self, sql: &str) -> Result<Option<LogicalPlan>, DataFusionError> {
707        // Check for transaction commands that shouldn't be parsed by DataFusion
708        let sql_lower = sql.to_lowercase();
709        let sql_trimmed = sql_lower.trim();
710
711        if matches!(
712            sql_trimmed,
713            "" | "begin"
714                | "begin transaction"
715                | "begin work"
716                | "start transaction"
717                | "commit"
718                | "commit transaction"
719                | "commit work"
720                | "end"
721                | "end transaction"
722                | "rollback"
723                | "rollback transaction"
724                | "rollback work"
725                | "abort"
726        ) {
727            // Return a dummy plan for transaction commands - they'll be handled by transaction handler
728            let dummy_schema = datafusion::common::DFSchema::empty();
729            return Ok(Some(LogicalPlan::EmptyRelation(
730                datafusion::logical_expr::EmptyRelation {
731                    produce_one_row: false,
732                    schema: Arc::new(dummy_schema),
733                },
734            )));
735        }
736
737        // show statement may not be supported by datafusion
738        if sql_trimmed.starts_with("show") {
739            // Return a dummy plan for transaction commands - they'll be handled by transaction handler
740            let show_schema =
741                Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
742            let df_schema = show_schema.to_dfschema()?;
743            return Ok(Some(LogicalPlan::EmptyRelation(
744                datafusion::logical_expr::EmptyRelation {
745                    produce_one_row: true,
746                    schema: Arc::new(df_schema),
747                },
748            )));
749        }
750
751        Ok(None)
752    }
753}
754
755#[async_trait]
756impl QueryParser for Parser {
757    type Statement = (String, LogicalPlan);
758
759    async fn parse_sql<C>(
760        &self,
761        _client: &C,
762        sql: &str,
763        _types: &[Type],
764    ) -> PgWireResult<Self::Statement> {
765        log::debug!("Received parse extended query: {sql}"); // Log for debugging
766
767        // Check for transaction commands that shouldn't be parsed by DataFusion
768        if let Some(plan) = self
769            .try_shortcut_parse_plan(sql)
770            .map_err(|e| PgWireError::ApiError(Box::new(e)))?
771        {
772            return Ok((sql.to_string(), plan));
773        }
774
775        let mut statements = self
776            .sql_parser
777            .parse(sql)
778            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
779        let statement = statements.remove(0);
780
781        let query = statement.to_string();
782
783        let context = &self.session_context;
784        let state = context.state();
785        let logical_plan = state
786            .statement_to_plan(Statement::Statement(Box::new(statement)))
787            .await
788            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
789        Ok((query, logical_plan))
790    }
791}
792
/// Returns the parameter types ordered by placeholder position.
///
/// The map keys are placeholder names such as `$1`, `$2`, … and the values
/// are the (possibly unknown) parameter types. Keys are ordered by the number
/// after the `$`, so `$2` comes before `$10`; a plain string sort would put
/// `$10` first. Keys without a numeric index are placed after all numbered
/// ones, in lexicographic order.
///
/// The function is generic over the type stored in the map; callers pass a
/// map of `DataType` values.
fn ordered_param_types<T>(types: &HashMap<String, Option<T>>) -> Vec<Option<&T>> {
    /// Numeric position of a `$N` placeholder, if the name has that shape.
    fn placeholder_index(name: &str) -> Option<u64> {
        name.strip_prefix('$')?.parse().ok()
    }

    let mut entries: Vec<(&String, &Option<T>)> = types.iter().collect();
    entries.sort_by(|(a, _), (b, _)| {
        match (placeholder_index(a), placeholder_index(b)) {
            // Break ties such as `$1` / `$01` by name to keep the order stable.
            (Some(x), Some(y)) => x.cmp(&y).then_with(|| a.cmp(b)),
            (Some(_), None) => std::cmp::Ordering::Less,
            (None, Some(_)) => std::cmp::Ordering::Greater,
            (None, None) => a.cmp(b),
        }
    });

    entries.into_iter().map(|(_, ty)| ty.as_ref()).collect()
}
800
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::AuthManager;
    use datafusion::prelude::SessionContext;
    use std::collections::HashMap;
    use std::time::Duration;

    /// Minimal `ClientInfo` implementation for tests.
    ///
    /// Only the metadata map carries state; every other accessor returns a
    /// fixed value and every other setter is a no-op.
    struct MockClient {
        metadata: HashMap<String, String>,
    }

    impl MockClient {
        /// Creates a client with empty metadata.
        fn new() -> Self {
            Self {
                metadata: HashMap::new(),
            }
        }
    }

    impl ClientInfo for MockClient {
        fn socket_addr(&self) -> std::net::SocketAddr {
            "127.0.0.1:5432".parse().unwrap()
        }

        fn is_secure(&self) -> bool {
            false
        }

        fn protocol_version(&self) -> pgwire::messages::ProtocolVersion {
            pgwire::messages::ProtocolVersion::PROTOCOL3_0
        }

        fn set_protocol_version(&mut self, _version: pgwire::messages::ProtocolVersion) {}

        fn pid_and_secret_key(&self) -> (i32, pgwire::messages::startup::SecretKey) {
            (0, pgwire::messages::startup::SecretKey::I32(0))
        }

        fn set_pid_and_secret_key(
            &mut self,
            _pid: i32,
            _secret_key: pgwire::messages::startup::SecretKey,
        ) {
        }

        fn state(&self) -> pgwire::api::PgWireConnectionState {
            pgwire::api::PgWireConnectionState::ReadyForQuery
        }

        fn set_state(&mut self, _new_state: pgwire::api::PgWireConnectionState) {}

        fn transaction_status(&self) -> pgwire::messages::response::TransactionStatus {
            pgwire::messages::response::TransactionStatus::Idle
        }

        fn set_transaction_status(
            &mut self,
            _new_status: pgwire::messages::response::TransactionStatus,
        ) {
        }

        // The metadata map is the only real state: the tests below read the
        // statement timeout back from this client after issuing `SET`.
        fn metadata(&self) -> &HashMap<String, String> {
            &self.metadata
        }

        fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
            &mut self.metadata
        }

        fn client_certificates<'a>(&self) -> Option<&[rustls_pki_types::CertificateDer<'a>]> {
            None
        }
    }

    /// Builds a session service over a fresh session context and auth manager.
    fn new_service() -> DfSessionService {
        let session_context = Arc::new(SessionContext::new());
        let auth_manager = Arc::new(AuthManager::new());
        DfSessionService::new(session_context, auth_manager)
    }

    /// `SET statement_timeout` is handled, stores the timeout on the client,
    /// and `SHOW statement_timeout` is handled as well.
    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let service = new_service();
        let mut client = MockClient::new();

        // Test setting timeout to 5000ms
        let set_response = service
            .try_respond_set_statements(&mut client, "set statement_timeout '5000ms'")
            .await
            .unwrap();
        assert!(set_response.is_some());

        // Verify the timeout was set in client metadata
        let timeout = DfSessionService::get_statement_timeout(&client);
        assert_eq!(timeout, Some(Duration::from_millis(5000)));

        // Test SHOW statement_timeout
        let show_response = service
            .try_respond_show_statements(&client, "show statement_timeout")
            .await
            .unwrap();
        assert!(show_response.is_some());
    }

    /// Setting the timeout to `0` clears a previously configured timeout.
    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let service = new_service();
        let mut client = MockClient::new();

        // Set timeout first
        service
            .try_respond_set_statements(&mut client, "set statement_timeout '1000ms'")
            .await
            .unwrap();

        // Make sure a timeout is actually in place before disabling it;
        // otherwise the final `None` check would pass even if SET did nothing.
        assert_eq!(
            DfSessionService::get_statement_timeout(&client),
            Some(Duration::from_millis(1000))
        );

        // Disable timeout with 0
        service
            .try_respond_set_statements(&mut client, "set statement_timeout '0'")
            .await
            .unwrap();

        let timeout = DfSessionService::get_statement_timeout(&client);
        assert_eq!(timeout, None);
    }
}