1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::DataType;
6use datafusion::common::ParamValues;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::*;
9use datafusion::sql::parser::Statement;
10use datafusion::sql::sqlparser;
11use log::info;
12use pgwire::api::auth::StartupHandler;
13use pgwire::api::auth::noop::NoopStartupHandler;
14use pgwire::api::cancel::{CancelHandler, DefaultCancelHandler};
15use pgwire::api::portal::{Format, Portal};
16use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
17use pgwire::api::results::{FieldInfo, Response, Tag};
18use pgwire::api::stmt::QueryParser;
19use pgwire::api::store::PortalStore;
20use pgwire::api::{
21 ClientInfo, ClientPortalStore, ConnectionManager, ErrorHandler, PgWireServerHandlers, Type,
22};
23use pgwire::error::{PgWireError, PgWireResult};
24use pgwire::messages::PgWireBackendMessage;
25use pgwire::types::format::FormatOptions;
26
27use crate::hooks::QueryHook;
28use crate::hooks::cursor::CursorStatementHook;
29use crate::hooks::set_show::SetShowHook;
30use crate::hooks::transactions::TransactionStatementHook;
31use crate::{client, planner};
32use arrow_pg::datatypes::df;
33use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
34use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
35
36pub struct SimpleStartupHandler {
38 connection_manager: Arc<ConnectionManager>,
39}
40
41#[async_trait::async_trait]
42impl NoopStartupHandler for SimpleStartupHandler {
43 fn connection_manager(&self) -> Option<Arc<ConnectionManager>> {
44 Some(self.connection_manager.clone())
45 }
46}
47
48pub struct HandlerFactory {
49 pub session_service: Arc<DfSessionService>,
50 cancel_handler: Arc<DefaultCancelHandler>,
51 startup_handler: Arc<SimpleStartupHandler>,
52}
53
54impl HandlerFactory {
55 pub fn new(session_context: Arc<SessionContext>) -> Self {
56 let session_service = Arc::new(DfSessionService::new(session_context));
57 let connection_manager = Arc::new(ConnectionManager::new());
58 HandlerFactory {
59 session_service,
60 cancel_handler: Arc::new(DefaultCancelHandler::new(connection_manager.clone())),
61 startup_handler: Arc::new(SimpleStartupHandler {
62 connection_manager: connection_manager.clone(),
63 }),
64 }
65 }
66
67 pub fn new_with_hooks(
68 session_context: Arc<SessionContext>,
69 query_hooks: Vec<Arc<dyn QueryHook>>,
70 ) -> Self {
71 let session_service = Arc::new(DfSessionService::new_with_hooks(
72 session_context,
73 query_hooks,
74 ));
75 let connection_manager = Arc::new(ConnectionManager::new());
76 HandlerFactory {
77 session_service,
78 cancel_handler: Arc::new(DefaultCancelHandler::new(connection_manager.clone())),
79 startup_handler: Arc::new(SimpleStartupHandler {
80 connection_manager: connection_manager.clone(),
81 }),
82 }
83 }
84}
85
86impl PgWireServerHandlers for HandlerFactory {
87 fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
88 self.session_service.clone()
89 }
90
91 fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
92 self.session_service.clone()
93 }
94
95 fn startup_handler(&self) -> Arc<impl StartupHandler> {
96 self.startup_handler.clone()
97 }
98
99 fn error_handler(&self) -> Arc<impl ErrorHandler> {
100 Arc::new(LoggingErrorHandler)
101 }
102
103 fn cancel_handler(&self) -> Arc<impl CancelHandler> {
104 self.cancel_handler.clone()
105 }
106}
107
108struct LoggingErrorHandler;
109
110impl ErrorHandler for LoggingErrorHandler {
111 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
112 where
113 C: ClientInfo,
114 {
115 info!("Sending error: {error}")
116 }
117}
118
119pub struct DfSessionService {
121 session_context: Arc<SessionContext>,
122 parser: Arc<Parser>,
123 query_hooks: Vec<Arc<dyn QueryHook>>,
124}
125
126impl DfSessionService {
127 pub fn new(session_context: Arc<SessionContext>) -> DfSessionService {
128 let hooks: Vec<Arc<dyn QueryHook>> = vec![
129 Arc::new(CursorStatementHook),
130 Arc::new(SetShowHook),
131 Arc::new(TransactionStatementHook),
132 ];
133 Self::new_with_hooks(session_context, hooks)
134 }
135
136 pub fn new_with_hooks(
137 session_context: Arc<SessionContext>,
138 query_hooks: Vec<Arc<dyn QueryHook>>,
139 ) -> DfSessionService {
140 let parser = Arc::new(Parser {
141 session_context: session_context.clone(),
142 sql_parser: PostgresCompatibilityParser::new(),
143 query_hooks: query_hooks.clone(),
144 });
145 DfSessionService {
146 session_context,
147 parser,
148 query_hooks,
149 }
150 }
151}
152
153#[async_trait]
154impl SimpleQueryHandler for DfSessionService {
155 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
156 where
157 C: ClientInfo
158 + ClientPortalStore
159 + futures::Sink<PgWireBackendMessage>
160 + Unpin
161 + Send
162 + Sync,
163 C::PortalStore: PortalStore,
164 C::Error: std::fmt::Debug,
165 PgWireError: From<<C as futures::Sink<PgWireBackendMessage>>::Error>,
166 {
167 log::debug!("Received query: {query}");
168
169 let statements = self
170 .parser
171 .sql_parser
172 .parse(query)
173 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
174
175 if statements.is_empty() {
177 return Ok(vec![Response::EmptyQuery]);
178 }
179
180 let mut results = vec![];
181 'stmt: for statement in statements {
182 for hook in &self.query_hooks {
184 if let Some(result) = hook
185 .handle_simple_query(&statement, &self.session_context, client)
186 .await
187 {
188 results.push(result?);
189 continue 'stmt;
190 }
191 }
192
193 let df_result = {
194 let query = statement.to_string();
195
196 let timeout = client::get_statement_timeout(client);
197 if let Some(timeout_duration) = timeout {
198 tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
199 .await
200 .map_err(|_| {
201 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
202 "ERROR".to_string(),
203 "57014".to_string(), "canceling statement due to statement timeout".to_string(),
205 )))
206 })?
207 } else {
208 self.session_context.sql(&query).await
209 }
210 };
211
212 let df = match df_result {
214 Ok(df) => df,
215 Err(e) => {
216 return Err(PgWireError::ApiError(Box::new(e)));
217 }
218 };
219
220 if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
221 let resp = map_rows_affected_for_insert(&df).await?;
222 results.push(resp);
223 } else {
224 let format_options =
226 Arc::new(FormatOptions::from_client_metadata(client.metadata()));
227 let resp =
228 df::encode_dataframe(df, &Format::UnifiedText, Some(format_options)).await?;
229 results.push(Response::Query(resp));
230 }
231 }
232 Ok(results)
233 }
234}
235
236#[async_trait]
237impl ExtendedQueryHandler for DfSessionService {
238 type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
239 type QueryParser = Parser;
240
241 fn query_parser(&self) -> Arc<Self::QueryParser> {
242 self.parser.clone()
243 }
244
245 async fn do_query<C>(
246 &self,
247 client: &mut C,
248 portal: &Portal<Self::Statement>,
249 _max_rows: usize,
250 ) -> PgWireResult<Response>
251 where
252 C: ClientInfo
253 + ClientPortalStore
254 + futures::Sink<PgWireBackendMessage>
255 + Unpin
256 + Send
257 + Sync,
258 C::PortalStore: PortalStore,
259 C::Error: std::fmt::Debug,
260 PgWireError: From<<C as futures::Sink<PgWireBackendMessage>>::Error>,
261 {
262 let query = &portal.statement.statement.0;
263 log::debug!("Received execute extended query: {query}");
264 if !self.query_hooks.is_empty()
266 && let (_, Some((statement, plan))) = &portal.statement.statement
267 {
268 let param_types = planner::get_inferred_parameter_types(plan)
270 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
271
272 let param_values: ParamValues =
273 df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?;
274
275 for hook in &self.query_hooks {
276 if let Some(result) = hook
277 .handle_extended_query(
278 statement,
279 plan,
280 ¶m_values,
281 &self.session_context,
282 client,
283 )
284 .await
285 {
286 return result;
287 }
288 }
289 }
290
291 if let (_, Some((statement, plan))) = &portal.statement.statement {
292 let param_types = planner::get_inferred_parameter_types(plan)
293 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
294
295 let param_values =
296 df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?;
297
298 let plan = plan
299 .clone()
300 .replace_params_with_values(¶m_values)
301 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
302 let optimised = self
303 .session_context
304 .state()
305 .optimize(&plan)
306 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
307
308 let dataframe = {
309 let timeout = client::get_statement_timeout(client);
310 if let Some(timeout_duration) = timeout {
311 tokio::time::timeout(
312 timeout_duration,
313 self.session_context.execute_logical_plan(optimised),
314 )
315 .await
316 .map_err(|_| {
317 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
318 "ERROR".to_string(),
319 "57014".to_string(), "canceling statement due to statement timeout".to_string(),
321 )))
322 })?
323 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
324 } else {
325 self.session_context
326 .execute_logical_plan(optimised)
327 .await
328 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
329 }
330 };
331
332 if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
333 let resp = map_rows_affected_for_insert(&dataframe).await?;
334
335 Ok(resp)
336 } else {
337 let format_options =
339 Arc::new(FormatOptions::from_client_metadata(client.metadata()));
340 let resp = df::encode_dataframe(
341 dataframe,
342 &portal.result_column_format,
343 Some(format_options),
344 )
345 .await?;
346 Ok(Response::Query(resp))
347 }
348 } else {
349 Ok(Response::EmptyQuery)
350 }
351 }
352}
353
354async fn map_rows_affected_for_insert(df: &DataFrame) -> PgWireResult<Response> {
355 let result = df
358 .clone()
359 .collect()
360 .await
361 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
362
363 let rows_affected = result
365 .first()
366 .and_then(|batch| batch.column_by_name("count"))
367 .and_then(|col| {
368 col.as_any()
369 .downcast_ref::<datafusion::arrow::array::UInt64Array>()
370 })
371 .map_or(0, |array| array.value(0) as usize);
372
373 let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
375 Ok(Response::Execution(tag))
376}
377
378pub struct Parser {
379 session_context: Arc<SessionContext>,
380 sql_parser: PostgresCompatibilityParser,
381 query_hooks: Vec<Arc<dyn QueryHook>>,
382}
383
384#[async_trait]
385impl QueryParser for Parser {
386 type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
387
388 async fn parse_sql<C>(
389 &self,
390 client: &C,
391 sql: &str,
392 _types: &[Option<Type>],
393 ) -> PgWireResult<Self::Statement>
394 where
395 C: ClientInfo + Unpin + Send + Sync,
396 {
397 log::debug!("Received parse extended query: {sql}");
398 let mut statements = self
399 .sql_parser
400 .parse(sql)
401 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
402 if statements.is_empty() {
403 return Ok((sql.to_string(), None));
404 }
405
406 let statement = statements.remove(0);
407 let query = statement.to_string();
408
409 let context = &self.session_context;
410 let state = context.state();
411
412 for hook in &self.query_hooks {
413 if let Some(logical_plan) = hook
414 .handle_extended_parse_query(&statement, context, client)
415 .await
416 {
417 return Ok((query, Some((statement, logical_plan?))));
418 }
419 }
420
421 let logical_plan = state
422 .statement_to_plan(Statement::Statement(Box::new(statement.clone())))
423 .await
424 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
425 Ok((query, Some((statement, logical_plan))))
426 }
427
428 fn get_parameter_types(&self, stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
429 if let (_, Some((_, plan))) = stmt {
430 let params = planner::get_inferred_parameter_types(plan)
431 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
432
433 let mut param_types = Vec::with_capacity(params.len());
434 for param_type in ordered_param_types(¶ms).iter() {
435 if let Some(datatype) = param_type {
436 let pgtype = into_pg_type(datatype)?;
437 param_types.push(pgtype);
438 } else {
439 param_types.push(Type::UNKNOWN);
440 }
441 }
442
443 Ok(param_types)
444 } else {
445 Ok(vec![])
446 }
447 }
448
449 fn get_result_schema(
450 &self,
451 stmt: &Self::Statement,
452 column_format: Option<&Format>,
453 ) -> PgWireResult<Vec<FieldInfo>> {
454 if let (_, Some((_, plan))) = stmt {
455 if !matches!(plan, LogicalPlan::Ddl(_) | LogicalPlan::Dml(_)) {
456 let schema = plan.schema();
457 let fields = arrow_schema_to_pg_fields(
458 schema.as_arrow(),
459 column_format.unwrap_or(&Format::UnifiedText),
460 None,
461 )?;
462
463 Ok(fields)
464 } else {
465 Ok(vec![])
466 }
467 } else {
468 Ok(vec![])
469 }
470 }
471}
472
473fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
474 let mut types = types.iter().collect::<Vec<_>>();
477 types.sort_by(|a, b| a.0.cmp(b.0));
478 types.into_iter().map(|pt| pt.1.as_ref()).collect()
479}
480
#[cfg(test)]
mod tests {
    use datafusion::prelude::SessionContext;

    use super::*;
    use crate::hooks::HookClient;
    use crate::testing::MockClient;

    /// Hook that answers any simple query whose text contains `magic` with an
    /// empty response and declines everything else.
    struct TestHook;

    #[async_trait]
    impl QueryHook for TestHook {
        async fn handle_simple_query(
            &self,
            statement: &sqlparser::ast::Statement,
            _ctx: &SessionContext,
            _client: &mut dyn HookClient,
        ) -> Option<PgWireResult<Response>> {
            let is_magic = statement.to_string().contains("magic");
            is_magic.then(|| Ok(Response::EmptyQuery))
        }

        async fn handle_extended_parse_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _session_context: &SessionContext,
            _client: &(dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<LogicalPlan>> {
            None
        }

        async fn handle_extended_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _logical_plan: &LogicalPlan,
            _params: &ParamValues,
            _session_context: &SessionContext,
            _client: &mut dyn HookClient,
        ) -> Option<PgWireResult<Response>> {
            None
        }
    }

    /// Runs `sql` through the simple-query entry point of `service`.
    async fn simple_query(
        service: &DfSessionService,
        client: &mut MockClient,
        sql: &str,
    ) -> PgWireResult<Vec<Response>> {
        <DfSessionService as SimpleQueryHandler>::do_query(service, client, sql).await
    }

    /// Asserts that `response` is an execution response with tag `expected`.
    fn assert_execution_tag(response: &Response, expected: &str) {
        let Response::Execution(tag) = response else {
            panic!("Expected Execution response, got: {response:?}");
        };
        let cc = pgwire::messages::response::CommandComplete::from(tag.clone());
        assert_eq!(cc.tag, expected, "Unexpected execution tag");
    }

    /// Asserts that `response` is a query response that yields no rows.
    async fn assert_query_response_empty(response: &mut Response) {
        use futures::StreamExt;

        let Response::Query(qr) = response else {
            panic!("Expected Query response, got: {response:?}");
        };

        let mut count = 0;
        while qr.data_rows().next().await.is_some() {
            count += 1;
        }
        assert_eq!(count, 0, "Expected no rows from exhausted cursor");
    }

    #[tokio::test]
    async fn test_query_hooks() {
        let hook = TestHook;
        let ctx = SessionContext::new();
        let mut client = MockClient::new();
        let parser = PostgresCompatibilityParser::new();

        // A statement mentioning "magic" is claimed by the hook.
        let magic = parser.parse("SELECT magic").unwrap();
        let claimed = hook.handle_simple_query(&magic[0], &ctx, &mut client).await;
        assert!(claimed.is_some());

        // Anything else is passed through.
        let plain = parser.parse("SELECT 1").unwrap();
        let declined = hook.handle_simple_query(&plain[0], &ctx, &mut client).await;
        assert!(declined.is_none());
    }

    #[tokio::test]
    async fn test_multiple_statements_with_hook_continue() {
        let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(TestHook)];
        let service = DfSessionService::new_with_hooks(Arc::new(SessionContext::new()), hooks);
        let mut client = MockClient::new();

        let results = simple_query(
            &service,
            &mut client,
            "SELECT magic; SELECT 1; SELECT magic; SELECT 1",
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 4, "Expected 4 responses");

        assert!(matches!(results[0], Response::EmptyQuery));
        assert!(matches!(results[1], Response::Query(_)));
        assert!(matches!(results[2], Response::EmptyQuery));
        assert!(matches!(results[3], Response::Query(_)));
    }

    #[tokio::test]
    async fn test_set_sends_parameter_status_via_sink() {
        use pgwire::messages::PgWireBackendMessage;

        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        // (statement, expected ParameterStatus name, expected value)
        let test_cases = [
            ("SET datestyle = 'ISO, MDY'", "DateStyle", "ISO, MDY"),
            ("SET intervalstyle = 'postgres'", "IntervalStyle", "postgres"),
            ("SET bytea_output = 'hex'", "bytea_output", "hex"),
            ("SET application_name = 'myapp'", "application_name", "myapp"),
            ("SET search_path = 'public'", "search_path", "public"),
            ("SET extra_float_digits = '2'", "extra_float_digits", "2"),
            ("SET TIME ZONE 'America/New_York'", "TimeZone", "America/New_York"),
        ];

        for (sql, expected_key, expected_value) in test_cases {
            client.sent_messages.clear();

            let responses = simple_query(&service, &mut client, sql).await.unwrap();

            assert!(
                matches!(responses[0], Response::Execution(_)),
                "Expected SET tag for {sql}"
            );

            let ps_msgs: Vec<_> = client
                .sent_messages()
                .iter()
                .filter_map(|m| match m {
                    PgWireBackendMessage::ParameterStatus(ps) => Some(ps),
                    _ => None,
                })
                .collect();

            assert_eq!(ps_msgs.len(), 1, "Expected 1 ParameterStatus for {sql}");
            assert_eq!(ps_msgs[0].name, expected_key, "Wrong key for {sql}");
            assert_eq!(ps_msgs[0].value, expected_value, "Wrong value for {sql}");
        }
    }

    #[tokio::test]
    async fn test_set_statement_timeout_no_parameter_status() {
        use pgwire::messages::PgWireBackendMessage;

        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        simple_query(&service, &mut client, "SET statement_timeout TO '5000ms'")
            .await
            .unwrap();

        let has_ps = client
            .sent_messages()
            .iter()
            .any(|m| matches!(m, PgWireBackendMessage::ParameterStatus(_)));

        assert!(!has_ps, "statement_timeout should not send ParameterStatus");
    }

    #[tokio::test]
    async fn test_declare_fetch_close_cursor() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        let declared = simple_query(
            &service,
            &mut client,
            "DECLARE test_cursor CURSOR FOR SELECT 1 AS col",
        )
        .await
        .unwrap();
        assert_eq!(declared.len(), 1);
        assert_execution_tag(&declared[0], "DECLARE CURSOR");

        let first_fetch = simple_query(&service, &mut client, "FETCH NEXT FROM test_cursor")
            .await
            .unwrap();
        assert_eq!(first_fetch.len(), 1);
        assert!(
            matches!(&first_fetch[0], Response::Query(_)),
            "Expected Query response for FETCH"
        );

        // The single row has been consumed, so the next fetch is empty.
        let mut second_fetch = simple_query(&service, &mut client, "FETCH NEXT FROM test_cursor")
            .await
            .unwrap();
        assert_eq!(second_fetch.len(), 1);
        assert_query_response_empty(&mut second_fetch[0]).await;

        let closed = simple_query(&service, &mut client, "CLOSE test_cursor")
            .await
            .unwrap();
        assert_eq!(closed.len(), 1);
        assert_execution_tag(&closed[0], "CLOSE CURSOR");
    }

    #[tokio::test]
    async fn test_fetch_nonexistent_cursor() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        let result = simple_query(&service, &mut client, "FETCH NEXT FROM nonexistent").await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_close_all_portals() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        simple_query(&service, &mut client, "DECLARE c1 CURSOR FOR SELECT 1")
            .await
            .unwrap();
        simple_query(&service, &mut client, "DECLARE c2 CURSOR FOR SELECT 2")
            .await
            .unwrap();

        let responses = simple_query(&service, &mut client, "CLOSE ALL")
            .await
            .unwrap();
        assert!(matches!(&responses[0], Response::Execution(_)));

        let result = simple_query(&service, &mut client, "FETCH NEXT FROM c1").await;
        assert!(result.is_err(), "c1 should be closed");
    }

    #[tokio::test]
    async fn test_fetch_forward_n() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        simple_query(
            &service,
            &mut client,
            "CREATE TABLE nums AS SELECT 1 AS n UNION ALL SELECT 2 UNION ALL SELECT 3 UNION ALL SELECT 4 UNION ALL SELECT 5",
        )
        .await
        .unwrap();

        simple_query(
            &service,
            &mut client,
            "DECLARE mycur CURSOR FOR SELECT n FROM nums ORDER BY n",
        )
        .await
        .unwrap();

        let first_three = simple_query(&service, &mut client, "FETCH FORWARD 3 FROM mycur")
            .await
            .unwrap();
        assert!(
            matches!(&first_three[0], Response::Query(_)),
            "Expected Query response for FORWARD 3"
        );

        let remaining = simple_query(&service, &mut client, "FETCH FORWARD ALL FROM mycur")
            .await
            .unwrap();
        let resp_desc = match &remaining[0] {
            Response::Query(_) => "Query".to_string(),
            Response::Execution(tag) => {
                let cc = pgwire::messages::response::CommandComplete::from(tag.clone());
                format!("Execution({})", cc.tag)
            }
            other => format!("{:?}", other),
        };
        assert!(
            matches!(&remaining[0], Response::Query(_)),
            "Expected Query response for remaining rows, got: {resp_desc}"
        );

        let mut exhausted = simple_query(&service, &mut client, "FETCH NEXT FROM mycur")
            .await
            .unwrap();
        assert_query_response_empty(&mut exhausted[0]).await;
    }

    #[tokio::test]
    async fn test_scroll_cursor_error() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        simple_query(&service, &mut client, "DECLARE mycur CURSOR FOR SELECT 1")
            .await
            .unwrap();

        let result = simple_query(&service, &mut client, "FETCH PRIOR FROM mycur").await;

        assert!(result.is_err(), "PRIOR should fail on forward-only cursor");
    }

    #[tokio::test]
    async fn test_move_cursor() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        simple_query(
            &service,
            &mut client,
            "DECLARE mycur CURSOR FOR SELECT generate_series(1, 5) AS n",
        )
        .await
        .unwrap();

        let responses = simple_query(&service, &mut client, "FETCH FORWARD 3 FROM mycur")
            .await
            .unwrap();

        assert!(matches!(&responses[0], Response::Query(_)));
    }
}