1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::{DataType, Field, Schema};
6use datafusion::common::{ParamValues, ToDFSchema};
7use datafusion::error::DataFusionError;
8use datafusion::logical_expr::LogicalPlan;
9use datafusion::prelude::*;
10use datafusion::sql::parser::Statement;
11use datafusion::sql::sqlparser;
12use log::info;
13use pgwire::api::auth::noop::NoopStartupHandler;
14use pgwire::api::auth::StartupHandler;
15use pgwire::api::portal::{Format, Portal};
16use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
17use pgwire::api::results::{
18 DescribePortalResponse, DescribeResponse, DescribeStatementResponse, Response, Tag,
19};
20use pgwire::api::stmt::QueryParser;
21use pgwire::api::stmt::StoredStatement;
22use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
23use pgwire::error::{PgWireError, PgWireResult};
24use pgwire::messages::response::TransactionStatus;
25use pgwire::types::format::FormatOptions;
26
27use crate::auth::AuthManager;
28use crate::client;
29use crate::hooks::set_show::SetShowHook;
30use crate::hooks::QueryHook;
31use arrow_pg::datatypes::df;
32use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
33use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
34use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
35
36pub struct SimpleStartupHandler;
39
40#[async_trait::async_trait]
41impl NoopStartupHandler for SimpleStartupHandler {}
42
43pub struct HandlerFactory {
44 pub session_service: Arc<DfSessionService>,
45}
46
47impl HandlerFactory {
48 pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
49 let session_service =
50 Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
51 HandlerFactory { session_service }
52 }
53
54 pub fn new_with_hooks(
55 session_context: Arc<SessionContext>,
56 auth_manager: Arc<AuthManager>,
57 query_hooks: Vec<Arc<dyn QueryHook>>,
58 ) -> Self {
59 let session_service = Arc::new(DfSessionService::new_with_hooks(
60 session_context,
61 auth_manager.clone(),
62 query_hooks,
63 ));
64 HandlerFactory { session_service }
65 }
66}
67
68impl PgWireServerHandlers for HandlerFactory {
69 fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
70 self.session_service.clone()
71 }
72
73 fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
74 self.session_service.clone()
75 }
76
77 fn startup_handler(&self) -> Arc<impl StartupHandler> {
78 Arc::new(SimpleStartupHandler)
79 }
80
81 fn error_handler(&self) -> Arc<impl ErrorHandler> {
82 Arc::new(LoggingErrorHandler)
83 }
84}
85
86struct LoggingErrorHandler;
87
88impl ErrorHandler for LoggingErrorHandler {
89 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
90 where
91 C: ClientInfo,
92 {
93 info!("Sending error: {error}")
94 }
95}
96
97pub struct DfSessionService {
99 session_context: Arc<SessionContext>,
100 parser: Arc<Parser>,
101 auth_manager: Arc<AuthManager>,
102 query_hooks: Vec<Arc<dyn QueryHook>>,
103}
104
105impl DfSessionService {
106 pub fn new(
107 session_context: Arc<SessionContext>,
108 auth_manager: Arc<AuthManager>,
109 ) -> DfSessionService {
110 let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(SetShowHook)];
111 Self::new_with_hooks(session_context, auth_manager, hooks)
112 }
113
114 pub fn new_with_hooks(
115 session_context: Arc<SessionContext>,
116 auth_manager: Arc<AuthManager>,
117 query_hooks: Vec<Arc<dyn QueryHook>>,
118 ) -> DfSessionService {
119 let parser = Arc::new(Parser {
120 session_context: session_context.clone(),
121 sql_parser: PostgresCompatibilityParser::new(),
122 query_hooks: query_hooks.clone(),
123 });
124 DfSessionService {
125 session_context,
126 parser,
127 auth_manager,
128 query_hooks,
129 }
130 }
131
132 async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
134 where
135 C: ClientInfo,
136 {
137 let username = client
139 .metadata()
140 .get("user")
141 .map(|s| s.as_str())
142 .unwrap_or("anonymous");
143
144 let query_lower = query.to_lowercase();
146 let query_trimmed = query_lower.trim();
147
148 let (required_permission, resource) = if query_trimmed.starts_with("select") {
149 (Permission::Select, self.extract_table_from_query(query))
150 } else if query_trimmed.starts_with("insert") {
151 (Permission::Insert, self.extract_table_from_query(query))
152 } else if query_trimmed.starts_with("update") {
153 (Permission::Update, self.extract_table_from_query(query))
154 } else if query_trimmed.starts_with("delete") {
155 (Permission::Delete, self.extract_table_from_query(query))
156 } else if query_trimmed.starts_with("create table")
157 || query_trimmed.starts_with("create view")
158 {
159 (Permission::Create, ResourceType::All)
160 } else if query_trimmed.starts_with("drop") {
161 (Permission::Drop, self.extract_table_from_query(query))
162 } else if query_trimmed.starts_with("alter") {
163 (Permission::Alter, self.extract_table_from_query(query))
164 } else {
165 return Ok(());
167 };
168
169 let has_permission = self
171 .auth_manager
172 .check_permission(username, required_permission, resource)
173 .await;
174
175 if !has_permission {
176 return Err(PgWireError::UserError(Box::new(
177 pgwire::error::ErrorInfo::new(
178 "ERROR".to_string(),
179 "42501".to_string(), format!("permission denied for user \"{username}\""),
181 ),
182 )));
183 }
184
185 Ok(())
186 }
187
188 fn extract_table_from_query(&self, query: &str) -> ResourceType {
190 let words: Vec<&str> = query.split_whitespace().collect();
191
192 for (i, word) in words.iter().enumerate() {
194 let word_lower = word.to_lowercase();
195 if (word_lower == "from" || word_lower == "into" || word_lower == "table")
196 && i + 1 < words.len()
197 {
198 let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
199 return ResourceType::Table(table_name.to_string());
200 }
201 }
202
203 ResourceType::All
205 }
206
207 async fn try_respond_transaction_statements<C>(
208 &self,
209 client: &C,
210 query_lower: &str,
211 ) -> PgWireResult<Option<Response>>
212 where
213 C: ClientInfo,
214 {
215 match query_lower.trim() {
218 "begin" | "begin transaction" | "begin work" | "start transaction" => {
219 match client.transaction_status() {
220 TransactionStatus::Idle => {
221 Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
222 }
223 TransactionStatus::Transaction => {
224 log::warn!("BEGIN command ignored: already in transaction block");
227 Ok(Some(Response::Execution(Tag::new("BEGIN"))))
228 }
229 TransactionStatus::Error => {
230 Err(PgWireError::UserError(Box::new(
232 pgwire::error::ErrorInfo::new(
233 "ERROR".to_string(),
234 "25P01".to_string(),
235 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
236 ),
237 )))
238 }
239 }
240 }
241 "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
242 match client.transaction_status() {
243 TransactionStatus::Idle | TransactionStatus::Transaction => {
244 Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
245 }
246 TransactionStatus::Error => {
247 Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
248 }
249 }
250 }
251 "rollback" | "rollback transaction" | "rollback work" | "abort" => {
252 Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
253 }
254 _ => Ok(None),
255 }
256 }
257}
258
259#[async_trait]
260impl SimpleQueryHandler for DfSessionService {
261 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
262 where
263 C: ClientInfo + Unpin + Send + Sync,
264 {
265 log::debug!("Received query: {query}"); let query_lower = query.to_lowercase().trim().to_string();
269 if let Some(resp) = self
270 .try_respond_transaction_statements(client, &query_lower)
271 .await?
272 {
273 return Ok(vec![resp]);
274 }
275
276 let statements = self
277 .parser
278 .sql_parser
279 .parse(query)
280 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
281
282 if statements.is_empty() {
284 return Ok(vec![Response::EmptyQuery]);
285 }
286
287 let mut results = vec![];
288 'stmt: for statement in statements {
289 let query = statement.to_string();
291 let query_lower = query.to_lowercase().trim().to_string();
292
293 if !query_lower.starts_with("set")
295 && !query_lower.starts_with("begin")
296 && !query_lower.starts_with("commit")
297 && !query_lower.starts_with("rollback")
298 && !query_lower.starts_with("start")
299 && !query_lower.starts_with("end")
300 && !query_lower.starts_with("abort")
301 && !query_lower.starts_with("show")
302 {
303 self.check_query_permission(client, &query).await?;
304 }
305
306 for hook in &self.query_hooks {
308 if let Some(result) = hook
309 .handle_simple_query(&statement, &self.session_context, client)
310 .await
311 {
312 results.push(result?);
313 continue 'stmt;
314 }
315 }
316
317 if client.transaction_status() == TransactionStatus::Error {
320 return Err(PgWireError::UserError(Box::new(
321 pgwire::error::ErrorInfo::new(
322 "ERROR".to_string(),
323 "25P01".to_string(),
324 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
325 ),
326 )));
327 }
328
329 let df_result = {
330 let timeout = client::get_statement_timeout(client);
331 if let Some(timeout_duration) = timeout {
332 tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
333 .await
334 .map_err(|_| {
335 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
336 "ERROR".to_string(),
337 "57014".to_string(), "canceling statement due to statement timeout".to_string(),
339 )))
340 })?
341 } else {
342 self.session_context.sql(&query).await
343 }
344 };
345
346 let df = match df_result {
348 Ok(df) => df,
349 Err(e) => {
350 return Err(PgWireError::ApiError(Box::new(e)));
351 }
352 };
353
354 if query_lower.starts_with("insert into") {
355 let resp = map_rows_affected_for_insert(&df).await?;
356 results.push(resp);
357 } else {
358 let format_options =
360 Arc::new(FormatOptions::from_client_metadata(client.metadata()));
361 let resp =
362 df::encode_dataframe(df, &Format::UnifiedText, Some(format_options)).await?;
363 results.push(Response::Query(resp));
364 }
365 }
366 Ok(results)
367 }
368}
369
370#[async_trait]
371impl ExtendedQueryHandler for DfSessionService {
372 type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
373 type QueryParser = Parser;
374
375 fn query_parser(&self) -> Arc<Self::QueryParser> {
376 self.parser.clone()
377 }
378
379 async fn do_describe_statement<C>(
380 &self,
381 _client: &mut C,
382 target: &StoredStatement<Self::Statement>,
383 ) -> PgWireResult<DescribeStatementResponse>
384 where
385 C: ClientInfo + Unpin + Send + Sync,
386 {
387 if let (_, Some((_, plan))) = &target.statement {
388 let schema = plan.schema();
389 let fields =
390 arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary, None)?;
391 let params = plan
392 .get_parameter_types()
393 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
394
395 let mut param_types = Vec::with_capacity(params.len());
396 for param_type in ordered_param_types(¶ms).iter() {
397 if let Some(datatype) = param_type {
399 let pgtype = into_pg_type(datatype)?;
400 param_types.push(pgtype);
401 } else {
402 param_types.push(Type::UNKNOWN);
403 }
404 }
405
406 Ok(DescribeStatementResponse::new(param_types, fields))
407 } else {
408 Ok(DescribeStatementResponse::no_data())
409 }
410 }
411
412 async fn do_describe_portal<C>(
413 &self,
414 _client: &mut C,
415 target: &Portal<Self::Statement>,
416 ) -> PgWireResult<DescribePortalResponse>
417 where
418 C: ClientInfo + Unpin + Send + Sync,
419 {
420 if let (_, Some((_, plan))) = &target.statement.statement {
421 let format = &target.result_column_format;
422 let schema = plan.schema();
423 let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format, None)?;
424
425 Ok(DescribePortalResponse::new(fields))
426 } else {
427 Ok(DescribePortalResponse::no_data())
428 }
429 }
430
431 async fn do_query<C>(
432 &self,
433 client: &mut C,
434 portal: &Portal<Self::Statement>,
435 _max_rows: usize,
436 ) -> PgWireResult<Response>
437 where
438 C: ClientInfo + Unpin + Send + Sync,
439 {
440 let query = portal
441 .statement
442 .statement
443 .0
444 .to_lowercase()
445 .trim()
446 .to_string();
447 log::debug!("Received execute extended query: {query}"); if !self.query_hooks.is_empty() {
451 if let (_, Some((statement, plan))) = &portal.statement.statement {
452 let param_types = plan
454 .get_parameter_types()
455 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
456
457 let param_values: ParamValues =
458 df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?;
459
460 for hook in &self.query_hooks {
461 if let Some(result) = hook
462 .handle_extended_query(
463 statement,
464 plan,
465 ¶m_values,
466 &self.session_context,
467 client,
468 )
469 .await
470 {
471 return result;
472 }
473 }
474 }
475 }
476
477 if !query.starts_with("set") && !query.starts_with("show") {
479 self.check_query_permission(client, &portal.statement.statement.0)
480 .await?;
481 }
482
483 if let Some(resp) = self
484 .try_respond_transaction_statements(client, &query)
485 .await?
486 {
487 return Ok(resp);
488 }
489
490 if client.transaction_status() == TransactionStatus::Error {
493 return Err(PgWireError::UserError(Box::new(
494 pgwire::error::ErrorInfo::new(
495 "ERROR".to_string(),
496 "25P01".to_string(),
497 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
498 ),
499 )));
500 }
501
502 if let (_, Some((_, plan))) = &portal.statement.statement {
503 let param_types = plan
504 .get_parameter_types()
505 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
506
507 let param_values =
508 df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; let plan = plan
511 .clone()
512 .replace_params_with_values(¶m_values)
513 .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let optimised = self
516 .session_context
517 .state()
518 .optimize(&plan)
519 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
520
521 let dataframe = {
522 let timeout = client::get_statement_timeout(client);
523 if let Some(timeout_duration) = timeout {
524 tokio::time::timeout(
525 timeout_duration,
526 self.session_context.execute_logical_plan(optimised),
527 )
528 .await
529 .map_err(|_| {
530 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
531 "ERROR".to_string(),
532 "57014".to_string(), "canceling statement due to statement timeout".to_string(),
534 )))
535 })?
536 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
537 } else {
538 self.session_context
539 .execute_logical_plan(optimised)
540 .await
541 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
542 }
543 };
544
545 if query.starts_with("insert into") {
546 let resp = map_rows_affected_for_insert(&dataframe).await?;
547
548 Ok(resp)
549 } else {
550 let format_options =
552 Arc::new(FormatOptions::from_client_metadata(client.metadata()));
553 let resp = df::encode_dataframe(
554 dataframe,
555 &portal.result_column_format,
556 Some(format_options),
557 )
558 .await?;
559 Ok(Response::Query(resp))
560 }
561 } else {
562 Ok(Response::EmptyQuery)
563 }
564 }
565}
566
567async fn map_rows_affected_for_insert(df: &DataFrame) -> PgWireResult<Response> {
568 let result = df
571 .clone()
572 .collect()
573 .await
574 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
575
576 let rows_affected = result
578 .first()
579 .and_then(|batch| batch.column_by_name("count"))
580 .and_then(|col| {
581 col.as_any()
582 .downcast_ref::<datafusion::arrow::array::UInt64Array>()
583 })
584 .map_or(0, |array| array.value(0) as usize);
585
586 let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
588 Ok(Response::Execution(tag))
589}
590
591pub struct Parser {
592 session_context: Arc<SessionContext>,
593 sql_parser: PostgresCompatibilityParser,
594 query_hooks: Vec<Arc<dyn QueryHook>>,
595}
596
597impl Parser {
598 fn try_shortcut_parse_plan(&self, sql: &str) -> Result<Option<LogicalPlan>, DataFusionError> {
599 let sql_lower = sql.to_lowercase();
601 let sql_trimmed = sql_lower.trim();
602
603 if matches!(
604 sql_trimmed,
605 "" | "begin"
606 | "begin transaction"
607 | "begin work"
608 | "start transaction"
609 | "commit"
610 | "commit transaction"
611 | "commit work"
612 | "end"
613 | "end transaction"
614 | "rollback"
615 | "rollback transaction"
616 | "rollback work"
617 | "abort"
618 ) {
619 let dummy_schema = datafusion::common::DFSchema::empty();
621 return Ok(Some(LogicalPlan::EmptyRelation(
622 datafusion::logical_expr::EmptyRelation {
623 produce_one_row: false,
624 schema: Arc::new(dummy_schema),
625 },
626 )));
627 }
628
629 if sql_trimmed.starts_with("show") {
631 let show_schema =
632 Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
633 let df_schema = show_schema.to_dfschema()?;
634 return Ok(Some(LogicalPlan::EmptyRelation(
635 datafusion::logical_expr::EmptyRelation {
636 produce_one_row: true,
637 schema: Arc::new(df_schema),
638 },
639 )));
640 }
641
642 Ok(None)
643 }
644}
645
646#[async_trait]
647impl QueryParser for Parser {
648 type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
649
650 async fn parse_sql<C>(
651 &self,
652 client: &C,
653 sql: &str,
654 _types: &[Option<Type>],
655 ) -> PgWireResult<Self::Statement>
656 where
657 C: ClientInfo + Unpin + Send + Sync,
658 {
659 log::debug!("Received parse extended query: {sql}"); let mut statements = self
662 .sql_parser
663 .parse(sql)
664 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
665 if statements.is_empty() {
666 return Ok((sql.to_string(), None));
667 }
668
669 let statement = statements.remove(0);
670
671 if let Some(plan) = self
673 .try_shortcut_parse_plan(sql)
674 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
675 {
676 return Ok((sql.to_string(), Some((statement, plan))));
677 }
678
679 let query = statement.to_string();
680
681 let context = &self.session_context;
682 let state = context.state();
683
684 for hook in &self.query_hooks {
685 if let Some(logical_plan) = hook
686 .handle_extended_parse_query(&statement, context, client)
687 .await
688 {
689 return Ok((query, Some((statement, logical_plan?))));
690 }
691 }
692
693 let logical_plan = state
694 .statement_to_plan(Statement::Statement(Box::new(statement.clone())))
695 .await
696 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
697 Ok((query, Some((statement, logical_plan))))
698 }
699}
700
/// Returns the parameter types of a plan in positional order.
///
/// DataFusion reports parameters as a map keyed by placeholder name (`$1`,
/// `$2`, ...). Keys of the form `$N` are ordered by their numeric index, so
/// `$2` sorts before `$10`; a plain string sort would put `$10` first and
/// shift every later parameter. Keys that are not `$N` placeholders are
/// placed after the numbered ones in lexicographic order.
///
/// The function is generic over the type payload; callers pass a map of
/// `Option<DataType>` and receive `Vec<Option<&DataType>>`.
fn ordered_param_types<T>(types: &HashMap<String, Option<T>>) -> Vec<Option<&T>> {
    /// Extracts `N` from a `$N` placeholder name.
    fn placeholder_index(name: &str) -> Option<u64> {
        name.strip_prefix('$')?.parse().ok()
    }

    let mut entries: Vec<_> = types.iter().collect();
    entries.sort_by(|(name_a, _), (name_b, _)| {
        match (
            placeholder_index(name_a.as_str()),
            placeholder_index(name_b.as_str()),
        ) {
            // The name comparison is only a tie-break (e.g. `$1` vs `$01`).
            (Some(a), Some(b)) => a.cmp(&b).then_with(|| name_a.cmp(name_b)),
            (Some(_), None) => std::cmp::Ordering::Less,
            (None, Some(_)) => std::cmp::Ordering::Greater,
            (None, None) => name_a.cmp(name_b),
        }
    });
    entries.into_iter().map(|(_, ty)| ty.as_ref()).collect()
}
708
#[cfg(test)]
mod tests {
    use datafusion::prelude::SessionContext;

    use super::*;
    use crate::testing::MockClient;

    /// Hook that answers any simple query whose text contains `magic` with
    /// `EmptyQuery` and declines everything else.
    struct TestHook;

    #[async_trait]
    impl QueryHook for TestHook {
        async fn handle_simple_query(
            &self,
            statement: &sqlparser::ast::Statement,
            _ctx: &SessionContext,
            _client: &mut (dyn ClientInfo + Sync + Send),
        ) -> Option<PgWireResult<Response>> {
            let is_magic = statement.to_string().contains("magic");
            is_magic.then(|| Ok(Response::EmptyQuery))
        }

        async fn handle_extended_parse_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _session_context: &SessionContext,
            _client: &(dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<LogicalPlan>> {
            None
        }

        async fn handle_extended_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _logical_plan: &LogicalPlan,
            _params: &ParamValues,
            _session_context: &SessionContext,
            _client: &mut (dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<Response>> {
            None
        }
    }

    /// The hook claims statements mentioning `magic` and ignores others.
    #[tokio::test]
    async fn test_query_hooks() {
        let hook = TestHook;
        let ctx = SessionContext::new();
        let mut client = MockClient::new();
        let parser = PostgresCompatibilityParser::new();

        // Intercepted: the statement text contains the trigger word.
        let magic = parser.parse("SELECT magic").unwrap();
        let handled = hook.handle_simple_query(&magic[0], &ctx, &mut client).await;
        assert!(handled.is_some());

        // Passed through: nothing for the hook to react to.
        let plain = parser.parse("SELECT 1").unwrap();
        let passed = hook.handle_simple_query(&plain[0], &ctx, &mut client).await;
        assert!(passed.is_none());
    }

    /// A statement answered by a hook must not stop the remaining statements
    /// of the same query string from being processed.
    #[tokio::test]
    async fn test_multiple_statements_with_hook_continue() {
        let service = DfSessionService::new_with_hooks(
            Arc::new(SessionContext::new()),
            Arc::new(AuthManager::new()),
            vec![Arc::new(TestHook) as Arc<dyn QueryHook>],
        );
        let mut client = MockClient::new();

        let sql = "SELECT magic; SELECT 1; SELECT magic; SELECT 1";
        let results =
            <DfSessionService as SimpleQueryHandler>::do_query(&service, &mut client, sql)
                .await
                .unwrap();

        assert_eq!(results.len(), 4, "Expected 4 responses");

        // Hooked and regular statements alternate in the input.
        assert!(matches!(results[0], Response::EmptyQuery));
        assert!(matches!(results[1], Response::Query(_)));
        assert!(matches!(results[2], Response::EmptyQuery));
        assert!(matches!(results[3], Response::Query(_)));
    }
}