// Source listing: datafusion_postgres/handlers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::DataType;
6use datafusion::common::ParamValues;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::*;
9use datafusion::sql::parser::Statement;
10use datafusion::sql::sqlparser;
11use log::info;
12use pgwire::api::auth::StartupHandler;
13use pgwire::api::auth::noop::NoopStartupHandler;
14use pgwire::api::cancel::{CancelHandler, DefaultCancelHandler};
15use pgwire::api::portal::{Format, Portal};
16use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
17use pgwire::api::results::{FieldInfo, Response, Tag};
18use pgwire::api::stmt::QueryParser;
19use pgwire::api::store::PortalStore;
20use pgwire::api::{
21    ClientInfo, ClientPortalStore, ConnectionManager, ErrorHandler, PgWireServerHandlers, Type,
22};
23use pgwire::error::{PgWireError, PgWireResult};
24use pgwire::messages::PgWireBackendMessage;
25use pgwire::types::format::FormatOptions;
26
27use crate::hooks::QueryHook;
28use crate::hooks::cursor::CursorStatementHook;
29use crate::hooks::set_show::SetShowHook;
30use crate::hooks::transactions::TransactionStatementHook;
31use crate::{client, planner};
32use arrow_pg::datatypes::df;
33use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
34use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
35
36/// Simple startup handler that does no authentication
37pub struct SimpleStartupHandler {
38    connection_manager: Arc<ConnectionManager>,
39}
40
41#[async_trait::async_trait]
42impl NoopStartupHandler for SimpleStartupHandler {
43    fn connection_manager(&self) -> Option<Arc<ConnectionManager>> {
44        Some(self.connection_manager.clone())
45    }
46}
47
48pub struct HandlerFactory {
49    pub session_service: Arc<DfSessionService>,
50    cancel_handler: Arc<DefaultCancelHandler>,
51    startup_handler: Arc<SimpleStartupHandler>,
52}
53
54impl HandlerFactory {
55    pub fn new(session_context: Arc<SessionContext>) -> Self {
56        let session_service = Arc::new(DfSessionService::new(session_context));
57        let connection_manager = Arc::new(ConnectionManager::new());
58        HandlerFactory {
59            session_service,
60            cancel_handler: Arc::new(DefaultCancelHandler::new(connection_manager.clone())),
61            startup_handler: Arc::new(SimpleStartupHandler {
62                connection_manager: connection_manager.clone(),
63            }),
64        }
65    }
66
67    pub fn new_with_hooks(
68        session_context: Arc<SessionContext>,
69        query_hooks: Vec<Arc<dyn QueryHook>>,
70    ) -> Self {
71        let session_service = Arc::new(DfSessionService::new_with_hooks(
72            session_context,
73            query_hooks,
74        ));
75        let connection_manager = Arc::new(ConnectionManager::new());
76        HandlerFactory {
77            session_service,
78            cancel_handler: Arc::new(DefaultCancelHandler::new(connection_manager.clone())),
79            startup_handler: Arc::new(SimpleStartupHandler {
80                connection_manager: connection_manager.clone(),
81            }),
82        }
83    }
84}
85
86impl PgWireServerHandlers for HandlerFactory {
87    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
88        self.session_service.clone()
89    }
90
91    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
92        self.session_service.clone()
93    }
94
95    fn startup_handler(&self) -> Arc<impl StartupHandler> {
96        self.startup_handler.clone()
97    }
98
99    fn error_handler(&self) -> Arc<impl ErrorHandler> {
100        Arc::new(LoggingErrorHandler)
101    }
102
103    fn cancel_handler(&self) -> Arc<impl CancelHandler> {
104        self.cancel_handler.clone()
105    }
106}
107
108struct LoggingErrorHandler;
109
110impl ErrorHandler for LoggingErrorHandler {
111    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
112    where
113        C: ClientInfo,
114    {
115        info!("Sending error: {error}")
116    }
117}
118
119/// The pgwire handler backed by a datafusion `SessionContext`
120pub struct DfSessionService {
121    session_context: Arc<SessionContext>,
122    parser: Arc<Parser>,
123    query_hooks: Vec<Arc<dyn QueryHook>>,
124}
125
126impl DfSessionService {
127    pub fn new(session_context: Arc<SessionContext>) -> DfSessionService {
128        let hooks: Vec<Arc<dyn QueryHook>> = vec![
129            Arc::new(CursorStatementHook),
130            Arc::new(SetShowHook),
131            Arc::new(TransactionStatementHook),
132        ];
133        Self::new_with_hooks(session_context, hooks)
134    }
135
136    pub fn new_with_hooks(
137        session_context: Arc<SessionContext>,
138        query_hooks: Vec<Arc<dyn QueryHook>>,
139    ) -> DfSessionService {
140        let parser = Arc::new(Parser {
141            session_context: session_context.clone(),
142            sql_parser: PostgresCompatibilityParser::new(),
143            query_hooks: query_hooks.clone(),
144        });
145        DfSessionService {
146            session_context,
147            parser,
148            query_hooks,
149        }
150    }
151}
152
153#[async_trait]
154impl SimpleQueryHandler for DfSessionService {
155    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
156    where
157        C: ClientInfo
158            + ClientPortalStore
159            + futures::Sink<PgWireBackendMessage>
160            + Unpin
161            + Send
162            + Sync,
163        C::PortalStore: PortalStore,
164        C::Error: std::fmt::Debug,
165        PgWireError: From<<C as futures::Sink<PgWireBackendMessage>>::Error>,
166    {
167        log::debug!("Received query: {query}");
168
169        let statements = self
170            .parser
171            .sql_parser
172            .parse(query)
173            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
174
175        // empty query
176        if statements.is_empty() {
177            return Ok(vec![Response::EmptyQuery]);
178        }
179
180        let mut results = vec![];
181        'stmt: for statement in statements {
182            // Call query hooks with the parsed statement
183            for hook in &self.query_hooks {
184                if let Some(result) = hook
185                    .handle_simple_query(&statement, &self.session_context, client)
186                    .await
187                {
188                    results.push(result?);
189                    continue 'stmt;
190                }
191            }
192
193            let df_result = {
194                let query = statement.to_string();
195
196                let timeout = client::get_statement_timeout(client);
197                if let Some(timeout_duration) = timeout {
198                    tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
199                        .await
200                        .map_err(|_| {
201                            PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
202                                "ERROR".to_string(),
203                                "57014".to_string(), // query_canceled error code
204                                "canceling statement due to statement timeout".to_string(),
205                            )))
206                        })?
207                } else {
208                    self.session_context.sql(&query).await
209                }
210            };
211
212            // Handle query execution errors and transaction state
213            let df = match df_result {
214                Ok(df) => df,
215                Err(e) => {
216                    return Err(PgWireError::ApiError(Box::new(e)));
217                }
218            };
219
220            if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
221                let resp = map_rows_affected_for_insert(&df).await?;
222                results.push(resp);
223            } else {
224                // For non-INSERT queries, return a regular Query response
225                let format_options =
226                    Arc::new(FormatOptions::from_client_metadata(client.metadata()));
227                let resp =
228                    df::encode_dataframe(df, &Format::UnifiedText, Some(format_options)).await?;
229                results.push(Response::Query(resp));
230            }
231        }
232        Ok(results)
233    }
234}
235
236#[async_trait]
237impl ExtendedQueryHandler for DfSessionService {
238    type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
239    type QueryParser = Parser;
240
241    fn query_parser(&self) -> Arc<Self::QueryParser> {
242        self.parser.clone()
243    }
244
245    async fn do_query<C>(
246        &self,
247        client: &mut C,
248        portal: &Portal<Self::Statement>,
249        _max_rows: usize,
250    ) -> PgWireResult<Response>
251    where
252        C: ClientInfo
253            + ClientPortalStore
254            + futures::Sink<PgWireBackendMessage>
255            + Unpin
256            + Send
257            + Sync,
258        C::PortalStore: PortalStore,
259        C::Error: std::fmt::Debug,
260        PgWireError: From<<C as futures::Sink<PgWireBackendMessage>>::Error>,
261    {
262        let query = &portal.statement.statement.0;
263        log::debug!("Received execute extended query: {query}");
264        // Check query hooks first
265        if !self.query_hooks.is_empty()
266            && let (_, Some((statement, plan))) = &portal.statement.statement
267        {
268            // TODO: in the case where query hooks all return None, we do the param handling again later.
269            let param_types = planner::get_inferred_parameter_types(plan)
270                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
271
272            let param_values: ParamValues =
273                df::deserialize_parameters(portal, &ordered_param_types(&param_types))?;
274
275            for hook in &self.query_hooks {
276                if let Some(result) = hook
277                    .handle_extended_query(
278                        statement,
279                        plan,
280                        &param_values,
281                        &self.session_context,
282                        client,
283                    )
284                    .await
285                {
286                    return result;
287                }
288            }
289        }
290
291        if let (_, Some((statement, plan))) = &portal.statement.statement {
292            let param_types = planner::get_inferred_parameter_types(plan)
293                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
294
295            let param_values =
296                df::deserialize_parameters(portal, &ordered_param_types(&param_types))?;
297
298            let plan = plan
299                .clone()
300                .replace_params_with_values(&param_values)
301                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
302            let optimised = self
303                .session_context
304                .state()
305                .optimize(&plan)
306                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
307
308            let dataframe = {
309                let timeout = client::get_statement_timeout(client);
310                if let Some(timeout_duration) = timeout {
311                    tokio::time::timeout(
312                        timeout_duration,
313                        self.session_context.execute_logical_plan(optimised),
314                    )
315                    .await
316                    .map_err(|_| {
317                        PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
318                            "ERROR".to_string(),
319                            "57014".to_string(), // query_canceled error code
320                            "canceling statement due to statement timeout".to_string(),
321                        )))
322                    })?
323                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?
324                } else {
325                    self.session_context
326                        .execute_logical_plan(optimised)
327                        .await
328                        .map_err(|e| PgWireError::ApiError(Box::new(e)))?
329                }
330            };
331
332            if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
333                let resp = map_rows_affected_for_insert(&dataframe).await?;
334
335                Ok(resp)
336            } else {
337                // For non-INSERT queries, return a regular Query response
338                let format_options =
339                    Arc::new(FormatOptions::from_client_metadata(client.metadata()));
340                let resp = df::encode_dataframe(
341                    dataframe,
342                    &portal.result_column_format,
343                    Some(format_options),
344                )
345                .await?;
346                Ok(Response::Query(resp))
347            }
348        } else {
349            Ok(Response::EmptyQuery)
350        }
351    }
352}
353
354async fn map_rows_affected_for_insert(df: &DataFrame) -> PgWireResult<Response> {
355    // For INSERT queries, we need to execute the query to get the row count
356    // and return an Execution response with the proper tag
357    let result = df
358        .clone()
359        .collect()
360        .await
361        .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
362
363    // Extract count field from the first batch
364    let rows_affected = result
365        .first()
366        .and_then(|batch| batch.column_by_name("count"))
367        .and_then(|col| {
368            col.as_any()
369                .downcast_ref::<datafusion::arrow::array::UInt64Array>()
370        })
371        .map_or(0, |array| array.value(0) as usize);
372
373    // Create INSERT tag with the affected row count
374    let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
375    Ok(Response::Execution(tag))
376}
377
378pub struct Parser {
379    session_context: Arc<SessionContext>,
380    sql_parser: PostgresCompatibilityParser,
381    query_hooks: Vec<Arc<dyn QueryHook>>,
382}
383
384#[async_trait]
385impl QueryParser for Parser {
386    type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
387
388    async fn parse_sql<C>(
389        &self,
390        client: &C,
391        sql: &str,
392        _types: &[Option<Type>],
393    ) -> PgWireResult<Self::Statement>
394    where
395        C: ClientInfo + Unpin + Send + Sync,
396    {
397        log::debug!("Received parse extended query: {sql}");
398        let mut statements = self
399            .sql_parser
400            .parse(sql)
401            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
402        if statements.is_empty() {
403            return Ok((sql.to_string(), None));
404        }
405
406        let statement = statements.remove(0);
407        let query = statement.to_string();
408
409        let context = &self.session_context;
410        let state = context.state();
411
412        for hook in &self.query_hooks {
413            if let Some(logical_plan) = hook
414                .handle_extended_parse_query(&statement, context, client)
415                .await
416            {
417                return Ok((query, Some((statement, logical_plan?))));
418            }
419        }
420
421        let logical_plan = state
422            .statement_to_plan(Statement::Statement(Box::new(statement.clone())))
423            .await
424            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
425        Ok((query, Some((statement, logical_plan))))
426    }
427
428    fn get_parameter_types(&self, stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
429        if let (_, Some((_, plan))) = stmt {
430            let params = planner::get_inferred_parameter_types(plan)
431                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
432
433            let mut param_types = Vec::with_capacity(params.len());
434            for param_type in ordered_param_types(&params).iter() {
435                if let Some(datatype) = param_type {
436                    let pgtype = into_pg_type(datatype)?;
437                    param_types.push(pgtype);
438                } else {
439                    param_types.push(Type::UNKNOWN);
440                }
441            }
442
443            Ok(param_types)
444        } else {
445            Ok(vec![])
446        }
447    }
448
449    fn get_result_schema(
450        &self,
451        stmt: &Self::Statement,
452        column_format: Option<&Format>,
453    ) -> PgWireResult<Vec<FieldInfo>> {
454        if let (_, Some((_, plan))) = stmt {
455            if !matches!(plan, LogicalPlan::Ddl(_) | LogicalPlan::Dml(_)) {
456                let schema = plan.schema();
457                let fields = arrow_schema_to_pg_fields(
458                    schema.as_arrow(),
459                    column_format.unwrap_or(&Format::UnifiedText),
460                    None,
461                )?;
462
463                Ok(fields)
464            } else {
465                Ok(vec![])
466            }
467        } else {
468            Ok(vec![])
469        }
470    }
471}
472
473fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
474    // Datafusion stores the parameters as a map.  In our case, the keys will be
475    // `$1`, `$2` etc.  The values will be the parameter types.
476    let mut types = types.iter().collect::<Vec<_>>();
477    types.sort_by(|a, b| a.0.cmp(b.0));
478    types.into_iter().map(|pt| pt.1.as_ref()).collect()
479}
480
// Unit tests for the handlers: query-hook dispatch, `SET` parameter-status
// reporting, and cursor commands (DECLARE / FETCH / CLOSE).
#[cfg(test)]
mod tests {
    use datafusion::prelude::SessionContext;

    use super::*;
    use crate::testing::MockClient;

    use crate::hooks::HookClient;

    /// Test hook that answers any simple query whose text contains `magic`
    /// with `EmptyQuery`; it never intercepts extended-protocol parse or
    /// execute requests.
    struct TestHook;

    #[async_trait]
    impl QueryHook for TestHook {
        async fn handle_simple_query(
            &self,
            statement: &sqlparser::ast::Statement,
            _ctx: &SessionContext,
            _client: &mut dyn HookClient,
        ) -> Option<PgWireResult<Response>> {
            if statement.to_string().contains("magic") {
                Some(Ok(Response::EmptyQuery))
            } else {
                None
            }
        }

        async fn handle_extended_parse_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _session_context: &SessionContext,
            _client: &(dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<LogicalPlan>> {
            None
        }

        async fn handle_extended_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _logical_plan: &LogicalPlan,
            _params: &ParamValues,
            _session_context: &SessionContext,
            _client: &mut dyn HookClient,
        ) -> Option<PgWireResult<Response>> {
            None
        }
    }

    // Calls `TestHook` directly (not through `DfSessionService`) to check
    // which statements it intercepts.
    #[tokio::test]
    async fn test_query_hooks() {
        let hook = TestHook;
        let ctx = SessionContext::new();
        let mut client = MockClient::new();

        // Parse a statement that contains "magic"
        let parser = PostgresCompatibilityParser::new();
        let statements = parser.parse("SELECT magic").unwrap();
        let stmt = &statements[0];

        // Hook should intercept
        let result = hook.handle_simple_query(stmt, &ctx, &mut client).await;
        assert!(result.is_some());

        // Parse a normal statement
        let statements = parser.parse("SELECT 1").unwrap();
        let stmt = &statements[0];

        // Hook should not intercept
        let result = hook.handle_simple_query(stmt, &ctx, &mut client).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_multiple_statements_with_hook_continue() {
        // Bug #227: when a hook returned a result, the code used `break 'stmt`
        // which would exit the entire statement loop, preventing subsequent statements
        // from being processed.
        let session_context = Arc::new(SessionContext::new());

        let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(TestHook)];
        let service = DfSessionService::new_with_hooks(session_context, hooks);

        let mut client = MockClient::new();

        // Mix of queries with hooks and those without
        let query = "SELECT magic; SELECT 1; SELECT magic; SELECT 1";

        let results =
            <DfSessionService as SimpleQueryHandler>::do_query(&service, &mut client, query)
                .await
                .unwrap();

        assert_eq!(results.len(), 4, "Expected 4 responses");

        // Hooked statements answer EmptyQuery; the others reach DataFusion.
        assert!(matches!(results[0], Response::EmptyQuery));
        assert!(matches!(results[1], Response::Query(_)));
        assert!(matches!(results[2], Response::EmptyQuery));
        assert!(matches!(results[3], Response::Query(_)));
    }

    // Each `SET` below must answer with an Execution response and push exactly
    // one ParameterStatus message whose name/value match the expected pair
    // (note the reported names differ from the SET keys for some settings).
    #[tokio::test]
    async fn test_set_sends_parameter_status_via_sink() {
        use pgwire::messages::PgWireBackendMessage;

        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        let test_cases = vec![
            ("SET datestyle = 'ISO, MDY'", "DateStyle", "ISO, MDY"),
            (
                "SET intervalstyle = 'postgres'",
                "IntervalStyle",
                "postgres",
            ),
            ("SET bytea_output = 'hex'", "bytea_output", "hex"),
            (
                "SET application_name = 'myapp'",
                "application_name",
                "myapp",
            ),
            ("SET search_path = 'public'", "search_path", "public"),
            ("SET extra_float_digits = '2'", "extra_float_digits", "2"),
            (
                "SET TIME ZONE 'America/New_York'",
                "TimeZone",
                "America/New_York",
            ),
        ];

        for (sql, expected_key, expected_value) in test_cases {
            // Only count messages produced by the current statement.
            client.sent_messages.clear();

            let responses =
                <DfSessionService as SimpleQueryHandler>::do_query(&service, &mut client, sql)
                    .await
                    .unwrap();

            assert!(
                matches!(responses[0], Response::Execution(_)),
                "Expected SET tag for {sql}"
            );

            let ps_msgs: Vec<_> = client
                .sent_messages()
                .iter()
                .filter_map(|m| match m {
                    PgWireBackendMessage::ParameterStatus(ps) => Some(ps),
                    _ => None,
                })
                .collect();

            assert_eq!(ps_msgs.len(), 1, "Expected 1 ParameterStatus for {sql}");
            assert_eq!(ps_msgs[0].name, expected_key, "Wrong key for {sql}");
            assert_eq!(ps_msgs[0].value, expected_value, "Wrong value for {sql}");
        }
    }

    // Unlike the settings above, `statement_timeout` must not emit a
    // ParameterStatus message.
    #[tokio::test]
    async fn test_set_statement_timeout_no_parameter_status() {
        use pgwire::messages::PgWireBackendMessage;

        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "SET statement_timeout TO '5000ms'",
        )
        .await
        .unwrap();

        let has_ps = client
            .sent_messages()
            .iter()
            .any(|m| matches!(m, PgWireBackendMessage::ParameterStatus(_)));

        assert!(!has_ps, "statement_timeout should not send ParameterStatus");
    }

    /// Asserts `response` is an Execution response whose command-complete tag
    /// text equals `expected`.
    fn assert_execution_tag(response: &Response, expected: &str) {
        match response {
            Response::Execution(tag) => {
                let cc = pgwire::messages::response::CommandComplete::from(tag.clone());
                assert_eq!(cc.tag, expected, "Unexpected execution tag");
            }
            other => panic!("Expected Execution response, got: {other:?}"),
        }
    }

    /// Drains the data rows of a Query response and asserts there were none.
    async fn assert_query_response_empty(response: &mut Response) {
        use futures::StreamExt;

        let Response::Query(qr) = response else {
            panic!("Expected Query response, got: {response:?}");
        };

        let mut count = 0;
        while qr.data_rows().next().await.is_some() {
            count += 1;
        }
        assert_eq!(count, 0, "Expected no rows from exhausted cursor");
    }

    // Full lifecycle of a single-row cursor: DECLARE, FETCH the row, FETCH
    // again (exhausted, no rows), then CLOSE.
    #[tokio::test]
    async fn test_declare_fetch_close_cursor() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        let responses = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "DECLARE test_cursor CURSOR FOR SELECT 1 AS col",
        )
        .await
        .unwrap();

        assert_eq!(responses.len(), 1);
        assert_execution_tag(&responses[0], "DECLARE CURSOR");

        let responses = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "FETCH NEXT FROM test_cursor",
        )
        .await
        .unwrap();

        assert_eq!(responses.len(), 1);
        assert!(
            matches!(&responses[0], Response::Query(_)),
            "Expected Query response for FETCH"
        );

        let mut responses = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "FETCH NEXT FROM test_cursor",
        )
        .await
        .unwrap();

        assert_eq!(responses.len(), 1);
        assert_query_response_empty(&mut responses[0]).await;

        let responses = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "CLOSE test_cursor",
        )
        .await
        .unwrap();

        assert_eq!(responses.len(), 1);
        assert_execution_tag(&responses[0], "CLOSE CURSOR");
    }

    // FETCH from a cursor that was never declared must fail.
    #[tokio::test]
    async fn test_fetch_nonexistent_cursor() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        let result = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "FETCH NEXT FROM nonexistent",
        )
        .await;

        assert!(result.is_err());
    }

    // `CLOSE ALL` must drop every declared cursor (checked via c1 only).
    #[tokio::test]
    async fn test_close_all_portals() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "DECLARE c1 CURSOR FOR SELECT 1",
        )
        .await
        .unwrap();

        <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "DECLARE c2 CURSOR FOR SELECT 2",
        )
        .await
        .unwrap();

        let responses =
            <DfSessionService as SimpleQueryHandler>::do_query(&service, &mut client, "CLOSE ALL")
                .await
                .unwrap();

        assert!(matches!(&responses[0], Response::Execution(_)),);

        let result = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "FETCH NEXT FROM c1",
        )
        .await;
        assert!(result.is_err(), "c1 should be closed");
    }

    // Over a five-row cursor, FORWARD 3 followed by FORWARD ALL should leave
    // nothing for the final FETCH NEXT. Row counts of the first two fetches
    // are not checked, only their response kind.
    #[tokio::test]
    async fn test_fetch_forward_n() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "CREATE TABLE nums AS SELECT 1 AS n UNION ALL SELECT 2 UNION ALL SELECT 3 UNION ALL SELECT 4 UNION ALL SELECT 5",
        )
        .await
        .unwrap();

        <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "DECLARE mycur CURSOR FOR SELECT n FROM nums ORDER BY n",
        )
        .await
        .unwrap();

        let responses = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "FETCH FORWARD 3 FROM mycur",
        )
        .await
        .unwrap();

        assert!(
            matches!(&responses[0], Response::Query(_)),
            "Expected Query response for FORWARD 3"
        );

        let responses = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "FETCH FORWARD ALL FROM mycur",
        )
        .await
        .unwrap();

        // Human-readable description of the response, used only in the
        // assertion message below.
        let resp_desc = match &responses[0] {
            Response::Query(_) => "Query".to_string(),
            Response::Execution(tag) => {
                let cc = pgwire::messages::response::CommandComplete::from(tag.clone());
                format!("Execution({})", cc.tag)
            }
            other => format!("{:?}", other),
        };
        assert!(
            matches!(&responses[0], Response::Query(_)),
            "Expected Query response for remaining rows, got: {resp_desc}"
        );

        let mut responses = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "FETCH NEXT FROM mycur",
        )
        .await
        .unwrap();

        assert_query_response_empty(&mut responses[0]).await;
    }

    // Backward movement (PRIOR) must be rejected on a cursor declared
    // without SCROLL.
    #[tokio::test]
    async fn test_scroll_cursor_error() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "DECLARE mycur CURSOR FOR SELECT 1",
        )
        .await
        .unwrap();

        let result = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "FETCH PRIOR FROM mycur",
        )
        .await;

        assert!(result.is_err(), "PRIOR should fail on forward-only cursor");
    }

    // NOTE(review): despite its name this test never issues a `MOVE`
    // statement; it only runs `FETCH FORWARD 3` — confirm whether MOVE
    // coverage was intended here.
    #[tokio::test]
    async fn test_move_cursor() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "DECLARE mycur CURSOR FOR SELECT generate_series(1, 5) AS n",
        )
        .await
        .unwrap();

        let responses = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "FETCH FORWARD 3 FROM mycur",
        )
        .await
        .unwrap();

        assert!(matches!(&responses[0], Response::Query(_)));
    }
}