1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::DataType;
6use datafusion::common::ParamValues;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::*;
9use datafusion::sql::parser::Statement;
10use datafusion::sql::sqlparser;
11use log::info;
12use pgwire::api::auth::noop::NoopStartupHandler;
13use pgwire::api::auth::StartupHandler;
14use pgwire::api::portal::{Format, Portal};
15use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
16use pgwire::api::results::{FieldInfo, Response, Tag};
17use pgwire::api::stmt::QueryParser;
18use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
19use pgwire::error::{PgWireError, PgWireResult};
20use pgwire::messages::PgWireBackendMessage;
21use pgwire::types::format::FormatOptions;
22
23use crate::hooks::set_show::SetShowHook;
24use crate::hooks::transactions::TransactionStatementHook;
25use crate::hooks::QueryHook;
26use crate::{client, planner};
27use arrow_pg::datatypes::df;
28use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
29use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
30
31pub struct SimpleStartupHandler;
33
34#[async_trait::async_trait]
35impl NoopStartupHandler for SimpleStartupHandler {}
36
37pub struct HandlerFactory {
38 pub session_service: Arc<DfSessionService>,
39}
40
41impl HandlerFactory {
42 pub fn new(session_context: Arc<SessionContext>) -> Self {
43 let session_service = Arc::new(DfSessionService::new(session_context));
44 HandlerFactory { session_service }
45 }
46
47 pub fn new_with_hooks(
48 session_context: Arc<SessionContext>,
49 query_hooks: Vec<Arc<dyn QueryHook>>,
50 ) -> Self {
51 let session_service = Arc::new(DfSessionService::new_with_hooks(
52 session_context,
53 query_hooks,
54 ));
55 HandlerFactory { session_service }
56 }
57}
58
59impl PgWireServerHandlers for HandlerFactory {
60 fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
61 self.session_service.clone()
62 }
63
64 fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
65 self.session_service.clone()
66 }
67
68 fn startup_handler(&self) -> Arc<impl StartupHandler> {
69 Arc::new(SimpleStartupHandler)
70 }
71
72 fn error_handler(&self) -> Arc<impl ErrorHandler> {
73 Arc::new(LoggingErrorHandler)
74 }
75}
76
77struct LoggingErrorHandler;
78
79impl ErrorHandler for LoggingErrorHandler {
80 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
81 where
82 C: ClientInfo,
83 {
84 info!("Sending error: {error}")
85 }
86}
87
88pub struct DfSessionService {
90 session_context: Arc<SessionContext>,
91 parser: Arc<Parser>,
92 query_hooks: Vec<Arc<dyn QueryHook>>,
93}
94
95impl DfSessionService {
96 pub fn new(session_context: Arc<SessionContext>) -> DfSessionService {
97 let hooks: Vec<Arc<dyn QueryHook>> =
98 vec![Arc::new(SetShowHook), Arc::new(TransactionStatementHook)];
99 Self::new_with_hooks(session_context, hooks)
100 }
101
102 pub fn new_with_hooks(
103 session_context: Arc<SessionContext>,
104 query_hooks: Vec<Arc<dyn QueryHook>>,
105 ) -> DfSessionService {
106 let parser = Arc::new(Parser {
107 session_context: session_context.clone(),
108 sql_parser: PostgresCompatibilityParser::new(),
109 query_hooks: query_hooks.clone(),
110 });
111 DfSessionService {
112 session_context,
113 parser,
114 query_hooks,
115 }
116 }
117}
118
119#[async_trait]
120impl SimpleQueryHandler for DfSessionService {
121 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
122 where
123 C: ClientInfo + futures::Sink<PgWireBackendMessage> + Unpin + Send + Sync,
124 C::Error: std::fmt::Debug,
125 PgWireError: From<<C as futures::Sink<PgWireBackendMessage>>::Error>,
126 {
127 log::debug!("Received query: {query}");
128 let statements = self
129 .parser
130 .sql_parser
131 .parse(query)
132 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
133
134 if statements.is_empty() {
136 return Ok(vec![Response::EmptyQuery]);
137 }
138
139 let mut results = vec![];
140 'stmt: for statement in statements {
141 for hook in &self.query_hooks {
143 if let Some(result) = hook
144 .handle_simple_query(&statement, &self.session_context, client)
145 .await
146 {
147 results.push(result?);
148 continue 'stmt;
149 }
150 }
151
152 let df_result = {
153 let query = statement.to_string();
154
155 let timeout = client::get_statement_timeout(client);
156 if let Some(timeout_duration) = timeout {
157 tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
158 .await
159 .map_err(|_| {
160 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
161 "ERROR".to_string(),
162 "57014".to_string(), "canceling statement due to statement timeout".to_string(),
164 )))
165 })?
166 } else {
167 self.session_context.sql(&query).await
168 }
169 };
170
171 let df = match df_result {
173 Ok(df) => df,
174 Err(e) => {
175 return Err(PgWireError::ApiError(Box::new(e)));
176 }
177 };
178
179 if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
180 let resp = map_rows_affected_for_insert(&df).await?;
181 results.push(resp);
182 } else {
183 let format_options =
185 Arc::new(FormatOptions::from_client_metadata(client.metadata()));
186 let resp =
187 df::encode_dataframe(df, &Format::UnifiedText, Some(format_options)).await?;
188 results.push(Response::Query(resp));
189 }
190 }
191 Ok(results)
192 }
193}
194
195#[async_trait]
196impl ExtendedQueryHandler for DfSessionService {
197 type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
198 type QueryParser = Parser;
199
200 fn query_parser(&self) -> Arc<Self::QueryParser> {
201 self.parser.clone()
202 }
203
204 async fn do_query<C>(
205 &self,
206 client: &mut C,
207 portal: &Portal<Self::Statement>,
208 _max_rows: usize,
209 ) -> PgWireResult<Response>
210 where
211 C: ClientInfo + futures::Sink<PgWireBackendMessage> + Unpin + Send + Sync,
212 C::Error: std::fmt::Debug,
213 PgWireError: From<<C as futures::Sink<PgWireBackendMessage>>::Error>,
214 {
215 let query = &portal.statement.statement.0;
216 log::debug!("Received execute extended query: {query}");
217 if !self.query_hooks.is_empty() {
219 if let (_, Some((statement, plan))) = &portal.statement.statement {
220 let param_types = planner::get_inferred_parameter_types(plan)
222 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
223
224 let param_values: ParamValues =
225 df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?;
226
227 for hook in &self.query_hooks {
228 if let Some(result) = hook
229 .handle_extended_query(
230 statement,
231 plan,
232 ¶m_values,
233 &self.session_context,
234 client,
235 )
236 .await
237 {
238 return result;
239 }
240 }
241 }
242 }
243
244 if let (_, Some((statement, plan))) = &portal.statement.statement {
245 let param_types = planner::get_inferred_parameter_types(plan)
246 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
247
248 let param_values =
249 df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?;
250
251 let plan = plan
252 .clone()
253 .replace_params_with_values(¶m_values)
254 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
255 let optimised = self
256 .session_context
257 .state()
258 .optimize(&plan)
259 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
260
261 let dataframe = {
262 let timeout = client::get_statement_timeout(client);
263 if let Some(timeout_duration) = timeout {
264 tokio::time::timeout(
265 timeout_duration,
266 self.session_context.execute_logical_plan(optimised),
267 )
268 .await
269 .map_err(|_| {
270 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
271 "ERROR".to_string(),
272 "57014".to_string(), "canceling statement due to statement timeout".to_string(),
274 )))
275 })?
276 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
277 } else {
278 self.session_context
279 .execute_logical_plan(optimised)
280 .await
281 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
282 }
283 };
284
285 if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
286 let resp = map_rows_affected_for_insert(&dataframe).await?;
287
288 Ok(resp)
289 } else {
290 let format_options =
292 Arc::new(FormatOptions::from_client_metadata(client.metadata()));
293 let resp = df::encode_dataframe(
294 dataframe,
295 &portal.result_column_format,
296 Some(format_options),
297 )
298 .await?;
299 Ok(Response::Query(resp))
300 }
301 } else {
302 Ok(Response::EmptyQuery)
303 }
304 }
305}
306
307async fn map_rows_affected_for_insert(df: &DataFrame) -> PgWireResult<Response> {
308 let result = df
311 .clone()
312 .collect()
313 .await
314 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
315
316 let rows_affected = result
318 .first()
319 .and_then(|batch| batch.column_by_name("count"))
320 .and_then(|col| {
321 col.as_any()
322 .downcast_ref::<datafusion::arrow::array::UInt64Array>()
323 })
324 .map_or(0, |array| array.value(0) as usize);
325
326 let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
328 Ok(Response::Execution(tag))
329}
330
331pub struct Parser {
332 session_context: Arc<SessionContext>,
333 sql_parser: PostgresCompatibilityParser,
334 query_hooks: Vec<Arc<dyn QueryHook>>,
335}
336
337#[async_trait]
338impl QueryParser for Parser {
339 type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
340
341 async fn parse_sql<C>(
342 &self,
343 client: &C,
344 sql: &str,
345 _types: &[Option<Type>],
346 ) -> PgWireResult<Self::Statement>
347 where
348 C: ClientInfo + Unpin + Send + Sync,
349 {
350 log::debug!("Received parse extended query: {sql}");
351 let mut statements = self
352 .sql_parser
353 .parse(sql)
354 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
355 if statements.is_empty() {
356 return Ok((sql.to_string(), None));
357 }
358
359 let statement = statements.remove(0);
360 let query = statement.to_string();
361
362 let context = &self.session_context;
363 let state = context.state();
364
365 for hook in &self.query_hooks {
366 if let Some(logical_plan) = hook
367 .handle_extended_parse_query(&statement, context, client)
368 .await
369 {
370 return Ok((query, Some((statement, logical_plan?))));
371 }
372 }
373
374 let logical_plan = state
375 .statement_to_plan(Statement::Statement(Box::new(statement.clone())))
376 .await
377 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
378 Ok((query, Some((statement, logical_plan))))
379 }
380
381 fn get_parameter_types(&self, stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
382 if let (_, Some((_, plan))) = stmt {
383 let params = planner::get_inferred_parameter_types(plan)
384 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
385
386 let mut param_types = Vec::with_capacity(params.len());
387 for param_type in ordered_param_types(¶ms).iter() {
388 if let Some(datatype) = param_type {
389 let pgtype = into_pg_type(datatype)?;
390 param_types.push(pgtype);
391 } else {
392 param_types.push(Type::UNKNOWN);
393 }
394 }
395
396 Ok(param_types)
397 } else {
398 Ok(vec![])
399 }
400 }
401
402 fn get_result_schema(
403 &self,
404 stmt: &Self::Statement,
405 column_format: Option<&Format>,
406 ) -> PgWireResult<Vec<FieldInfo>> {
407 if let (_, Some((_, plan))) = stmt {
408 let schema = plan.schema();
409 let fields = arrow_schema_to_pg_fields(
410 schema.as_arrow(),
411 column_format.unwrap_or(&Format::UnifiedBinary),
412 None,
413 )?;
414
415 Ok(fields)
416 } else {
417 Ok(vec![])
418 }
419 }
420}
421
422fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
423 let mut types = types.iter().collect::<Vec<_>>();
426 types.sort_by(|a, b| a.0.cmp(b.0));
427 types.into_iter().map(|pt| pt.1.as_ref()).collect()
428}
429
#[cfg(test)]
mod tests {
    use datafusion::prelude::SessionContext;
    use pgwire::messages::PgWireBackendMessage;

    use super::*;
    use crate::hooks::HookClient;
    use crate::testing::MockClient;

    /// Hook that answers any simple-protocol statement whose text contains `magic` with
    /// `EmptyQuery`, and declines everything else (including all extended-protocol traffic).
    struct TestHook;

    #[async_trait]
    impl QueryHook for TestHook {
        async fn handle_simple_query(
            &self,
            statement: &sqlparser::ast::Statement,
            _ctx: &SessionContext,
            _client: &mut dyn HookClient,
        ) -> Option<PgWireResult<Response>> {
            if !statement.to_string().contains("magic") {
                return None;
            }
            Some(Ok(Response::EmptyQuery))
        }

        async fn handle_extended_parse_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _session_context: &SessionContext,
            _client: &(dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<LogicalPlan>> {
            None
        }

        async fn handle_extended_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _logical_plan: &LogicalPlan,
            _params: &ParamValues,
            _session_context: &SessionContext,
            _client: &mut dyn HookClient,
        ) -> Option<PgWireResult<Response>> {
            None
        }
    }

    #[tokio::test]
    async fn test_query_hooks() {
        let hook = TestHook;
        let ctx = SessionContext::new();
        let mut client = MockClient::new();
        let parser = PostgresCompatibilityParser::new();

        for (sql, expect_handled) in [("SELECT magic", true), ("SELECT 1", false)] {
            let statements = parser.parse(sql).unwrap();
            let outcome = hook
                .handle_simple_query(&statements[0], &ctx, &mut client)
                .await;
            assert_eq!(
                outcome.is_some(),
                expect_handled,
                "unexpected hook outcome for {sql}"
            );
        }
    }

    #[tokio::test]
    async fn test_multiple_statements_with_hook_continue() {
        let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(TestHook)];
        let service = DfSessionService::new_with_hooks(Arc::new(SessionContext::new()), hooks);
        let mut client = MockClient::new();

        let results = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "SELECT magic; SELECT 1; SELECT magic; SELECT 1",
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 4, "Expected 4 responses");

        // Statements containing `magic` are answered by the hook; the others reach DataFusion.
        let handled_by_hook = [true, false, true, false];
        for (response, hooked) in results.iter().zip(handled_by_hook) {
            if hooked {
                assert!(matches!(response, Response::EmptyQuery));
            } else {
                assert!(matches!(response, Response::Query(_)));
            }
        }
    }

    #[tokio::test]
    async fn test_set_sends_parameter_status_via_sink() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        let cases = [
            ("SET datestyle = 'ISO, MDY'", "DateStyle", "ISO, MDY"),
            (
                "SET intervalstyle = 'postgres'",
                "IntervalStyle",
                "postgres",
            ),
            ("SET bytea_output = 'hex'", "bytea_output", "hex"),
            (
                "SET application_name = 'myapp'",
                "application_name",
                "myapp",
            ),
            ("SET search_path = 'public'", "search_path", "public"),
            ("SET extra_float_digits = '2'", "extra_float_digits", "2"),
            (
                "SET TIME ZONE 'America/New_York'",
                "TimeZone",
                "America/New_York",
            ),
        ];

        for (sql, expected_key, expected_value) in cases {
            client.sent_messages.clear();

            let responses =
                <DfSessionService as SimpleQueryHandler>::do_query(&service, &mut client, sql)
                    .await
                    .unwrap();
            assert!(
                matches!(responses[0], Response::Execution(_)),
                "Expected SET tag for {sql}"
            );

            let statuses: Vec<_> = client
                .sent_messages()
                .iter()
                .filter_map(|msg| match msg {
                    PgWireBackendMessage::ParameterStatus(status) => Some(status),
                    _ => None,
                })
                .collect();

            assert_eq!(statuses.len(), 1, "Expected 1 ParameterStatus for {sql}");
            assert_eq!(statuses[0].name, expected_key, "Wrong key for {sql}");
            assert_eq!(statuses[0].value, expected_value, "Wrong value for {sql}");
        }
    }

    #[tokio::test]
    async fn test_set_statement_timeout_no_parameter_status() {
        let service = crate::testing::setup_handlers();
        let mut client = MockClient::new();

        <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "SET statement_timeout TO '5000ms'",
        )
        .await
        .unwrap();

        let sent_parameter_status = client
            .sent_messages()
            .iter()
            .any(|msg| matches!(msg, PgWireBackendMessage::ParameterStatus(_)));
        assert!(
            !sent_parameter_status,
            "statement_timeout should not send ParameterStatus"
        );
    }
}