1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::DataType;
6use datafusion::common::ParamValues;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::*;
9use datafusion::sql::parser::Statement;
10use datafusion::sql::sqlparser;
11use log::info;
12use pgwire::api::auth::noop::NoopStartupHandler;
13use pgwire::api::auth::StartupHandler;
14use pgwire::api::portal::{Format, Portal};
15use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
16use pgwire::api::results::{FieldInfo, Response, Tag};
17use pgwire::api::stmt::QueryParser;
18use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
19use pgwire::error::{PgWireError, PgWireResult};
20use pgwire::types::format::FormatOptions;
21
22use crate::client;
23use crate::hooks::set_show::SetShowHook;
24use crate::hooks::transactions::TransactionStatementHook;
25use crate::hooks::QueryHook;
26use arrow_pg::datatypes::df;
27use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
28use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
29
30pub struct SimpleStartupHandler;
32
33#[async_trait::async_trait]
34impl NoopStartupHandler for SimpleStartupHandler {}
35
36pub struct HandlerFactory {
37 pub session_service: Arc<DfSessionService>,
38}
39
40impl HandlerFactory {
41 pub fn new(session_context: Arc<SessionContext>) -> Self {
42 let session_service = Arc::new(DfSessionService::new(session_context));
43 HandlerFactory { session_service }
44 }
45
46 pub fn new_with_hooks(
47 session_context: Arc<SessionContext>,
48 query_hooks: Vec<Arc<dyn QueryHook>>,
49 ) -> Self {
50 let session_service = Arc::new(DfSessionService::new_with_hooks(
51 session_context,
52 query_hooks,
53 ));
54 HandlerFactory { session_service }
55 }
56}
57
58impl PgWireServerHandlers for HandlerFactory {
59 fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
60 self.session_service.clone()
61 }
62
63 fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
64 self.session_service.clone()
65 }
66
67 fn startup_handler(&self) -> Arc<impl StartupHandler> {
68 Arc::new(SimpleStartupHandler)
69 }
70
71 fn error_handler(&self) -> Arc<impl ErrorHandler> {
72 Arc::new(LoggingErrorHandler)
73 }
74}
75
76struct LoggingErrorHandler;
77
78impl ErrorHandler for LoggingErrorHandler {
79 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
80 where
81 C: ClientInfo,
82 {
83 info!("Sending error: {error}")
84 }
85}
86
87pub struct DfSessionService {
89 session_context: Arc<SessionContext>,
90 parser: Arc<Parser>,
91 query_hooks: Vec<Arc<dyn QueryHook>>,
92}
93
94impl DfSessionService {
95 pub fn new(session_context: Arc<SessionContext>) -> DfSessionService {
96 let hooks: Vec<Arc<dyn QueryHook>> =
97 vec![Arc::new(SetShowHook), Arc::new(TransactionStatementHook)];
98 Self::new_with_hooks(session_context, hooks)
99 }
100
101 pub fn new_with_hooks(
102 session_context: Arc<SessionContext>,
103 query_hooks: Vec<Arc<dyn QueryHook>>,
104 ) -> DfSessionService {
105 let parser = Arc::new(Parser {
106 session_context: session_context.clone(),
107 sql_parser: PostgresCompatibilityParser::new(),
108 query_hooks: query_hooks.clone(),
109 });
110 DfSessionService {
111 session_context,
112 parser,
113 query_hooks,
114 }
115 }
116}
117
118#[async_trait]
119impl SimpleQueryHandler for DfSessionService {
120 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
121 where
122 C: ClientInfo + Unpin + Send + Sync,
123 {
124 log::debug!("Received query: {query}"); let statements = self
127 .parser
128 .sql_parser
129 .parse(query)
130 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
131
132 if statements.is_empty() {
134 return Ok(vec![Response::EmptyQuery]);
135 }
136
137 let mut results = vec![];
138 'stmt: for statement in statements {
139 let query = statement.to_string();
140
141 for hook in &self.query_hooks {
143 if let Some(result) = hook
144 .handle_simple_query(&statement, &self.session_context, client)
145 .await
146 {
147 results.push(result?);
148 continue 'stmt;
149 }
150 }
151
152 let df_result = {
153 let timeout = client::get_statement_timeout(client);
154 if let Some(timeout_duration) = timeout {
155 tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
156 .await
157 .map_err(|_| {
158 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
159 "ERROR".to_string(),
160 "57014".to_string(), "canceling statement due to statement timeout".to_string(),
162 )))
163 })?
164 } else {
165 self.session_context.sql(&query).await
166 }
167 };
168
169 let df = match df_result {
171 Ok(df) => df,
172 Err(e) => {
173 return Err(PgWireError::ApiError(Box::new(e)));
174 }
175 };
176
177 if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
178 let resp = map_rows_affected_for_insert(&df).await?;
179 results.push(resp);
180 } else {
181 let format_options =
183 Arc::new(FormatOptions::from_client_metadata(client.metadata()));
184 let resp =
185 df::encode_dataframe(df, &Format::UnifiedText, Some(format_options)).await?;
186 results.push(Response::Query(resp));
187 }
188 }
189 Ok(results)
190 }
191}
192
193#[async_trait]
194impl ExtendedQueryHandler for DfSessionService {
195 type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
196 type QueryParser = Parser;
197
198 fn query_parser(&self) -> Arc<Self::QueryParser> {
199 self.parser.clone()
200 }
201
202 async fn do_query<C>(
203 &self,
204 client: &mut C,
205 portal: &Portal<Self::Statement>,
206 _max_rows: usize,
207 ) -> PgWireResult<Response>
208 where
209 C: ClientInfo + Unpin + Send + Sync,
210 {
211 let query = &portal.statement.statement.0;
212 log::debug!("Received execute extended query: {query}"); if !self.query_hooks.is_empty() {
216 if let (_, Some((statement, plan))) = &portal.statement.statement {
217 let param_types = plan
219 .get_parameter_types()
220 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
221
222 let param_values: ParamValues =
223 df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?;
224
225 for hook in &self.query_hooks {
226 if let Some(result) = hook
227 .handle_extended_query(
228 statement,
229 plan,
230 ¶m_values,
231 &self.session_context,
232 client,
233 )
234 .await
235 {
236 return result;
237 }
238 }
239 }
240 }
241
242 if let (_, Some((statement, plan))) = &portal.statement.statement {
243 let param_types = plan
244 .get_parameter_types()
245 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
246
247 let param_values =
248 df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; let plan = plan
251 .clone()
252 .replace_params_with_values(¶m_values)
253 .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let optimised = self
256 .session_context
257 .state()
258 .optimize(&plan)
259 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
260
261 let dataframe = {
262 let timeout = client::get_statement_timeout(client);
263 if let Some(timeout_duration) = timeout {
264 tokio::time::timeout(
265 timeout_duration,
266 self.session_context.execute_logical_plan(optimised),
267 )
268 .await
269 .map_err(|_| {
270 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
271 "ERROR".to_string(),
272 "57014".to_string(), "canceling statement due to statement timeout".to_string(),
274 )))
275 })?
276 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
277 } else {
278 self.session_context
279 .execute_logical_plan(optimised)
280 .await
281 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
282 }
283 };
284
285 if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
286 let resp = map_rows_affected_for_insert(&dataframe).await?;
287
288 Ok(resp)
289 } else {
290 let format_options =
292 Arc::new(FormatOptions::from_client_metadata(client.metadata()));
293 let resp = df::encode_dataframe(
294 dataframe,
295 &portal.result_column_format,
296 Some(format_options),
297 )
298 .await?;
299 Ok(Response::Query(resp))
300 }
301 } else {
302 Ok(Response::EmptyQuery)
303 }
304 }
305}
306
307async fn map_rows_affected_for_insert(df: &DataFrame) -> PgWireResult<Response> {
308 let result = df
311 .clone()
312 .collect()
313 .await
314 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
315
316 let rows_affected = result
318 .first()
319 .and_then(|batch| batch.column_by_name("count"))
320 .and_then(|col| {
321 col.as_any()
322 .downcast_ref::<datafusion::arrow::array::UInt64Array>()
323 })
324 .map_or(0, |array| array.value(0) as usize);
325
326 let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
328 Ok(Response::Execution(tag))
329}
330
331pub struct Parser {
332 session_context: Arc<SessionContext>,
333 sql_parser: PostgresCompatibilityParser,
334 query_hooks: Vec<Arc<dyn QueryHook>>,
335}
336
337#[async_trait]
338impl QueryParser for Parser {
339 type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
340
341 async fn parse_sql<C>(
342 &self,
343 client: &C,
344 sql: &str,
345 _types: &[Option<Type>],
346 ) -> PgWireResult<Self::Statement>
347 where
348 C: ClientInfo + Unpin + Send + Sync,
349 {
350 log::debug!("Received parse extended query: {sql}"); let mut statements = self
353 .sql_parser
354 .parse(sql)
355 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
356 if statements.is_empty() {
357 return Ok((sql.to_string(), None));
358 }
359
360 let statement = statements.remove(0);
361 let query = statement.to_string();
362
363 let context = &self.session_context;
364 let state = context.state();
365
366 for hook in &self.query_hooks {
367 if let Some(logical_plan) = hook
368 .handle_extended_parse_query(&statement, context, client)
369 .await
370 {
371 return Ok((query, Some((statement, logical_plan?))));
372 }
373 }
374
375 let logical_plan = state
376 .statement_to_plan(Statement::Statement(Box::new(statement.clone())))
377 .await
378 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
379 Ok((query, Some((statement, logical_plan))))
380 }
381
382 fn get_parameter_types(&self, stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
383 if let (_, Some((_, plan))) = stmt {
384 let params = plan
385 .get_parameter_types()
386 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
387
388 let mut param_types = Vec::with_capacity(params.len());
389 for param_type in ordered_param_types(¶ms).iter() {
390 if let Some(datatype) = param_type {
392 let pgtype = into_pg_type(datatype)?;
393 param_types.push(pgtype);
394 } else {
395 param_types.push(Type::UNKNOWN);
396 }
397 }
398
399 Ok(param_types)
400 } else {
401 Ok(vec![])
402 }
403 }
404
405 fn get_result_schema(
406 &self,
407 stmt: &Self::Statement,
408 column_format: Option<&Format>,
409 ) -> PgWireResult<Vec<FieldInfo>> {
410 if let (_, Some((_, plan))) = stmt {
411 let schema = plan.schema();
412 let fields = arrow_schema_to_pg_fields(
413 schema.as_arrow(),
414 column_format.unwrap_or(&Format::UnifiedBinary),
415 None,
416 )?;
417
418 Ok(fields)
419 } else {
420 Ok(vec![])
421 }
422 }
423}
424
425fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
426 let mut types = types.iter().collect::<Vec<_>>();
429 types.sort_by(|a, b| a.0.cmp(b.0));
430 types.into_iter().map(|pt| pt.1.as_ref()).collect()
431}
432
#[cfg(test)]
mod tests {
    use datafusion::prelude::SessionContext;

    use super::*;
    use crate::testing::MockClient;

    /// Hook that claims any simple-protocol statement whose SQL mentions
    /// `magic`, answering it with `EmptyQuery`. It never handles
    /// extended-protocol work.
    struct TestHook;

    #[async_trait]
    impl QueryHook for TestHook {
        async fn handle_simple_query(
            &self,
            statement: &sqlparser::ast::Statement,
            _ctx: &SessionContext,
            _client: &mut (dyn ClientInfo + Sync + Send),
        ) -> Option<PgWireResult<Response>> {
            let is_magic = statement.to_string().contains("magic");
            is_magic.then(|| Ok(Response::EmptyQuery))
        }

        async fn handle_extended_parse_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _session_context: &SessionContext,
            _client: &(dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<LogicalPlan>> {
            None
        }

        async fn handle_extended_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _logical_plan: &LogicalPlan,
            _params: &ParamValues,
            _session_context: &SessionContext,
            _client: &mut (dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<Response>> {
            None
        }
    }

    #[tokio::test]
    async fn test_query_hooks() {
        let hook = TestHook;
        let ctx = SessionContext::new();
        let mut client = MockClient::new();
        let parser = PostgresCompatibilityParser::new();

        for (sql, expect_handled) in [("SELECT magic", true), ("SELECT 1", false)] {
            let parsed = parser.parse(sql).unwrap();
            let outcome = hook
                .handle_simple_query(&parsed[0], &ctx, &mut client)
                .await;
            assert_eq!(outcome.is_some(), expect_handled);
        }
    }

    #[tokio::test]
    async fn test_multiple_statements_with_hook_continue() {
        let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(TestHook)];
        let service = DfSessionService::new_with_hooks(Arc::new(SessionContext::new()), hooks);
        let mut client = MockClient::new();

        let responses = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "SELECT magic; SELECT 1; SELECT magic; SELECT 1",
        )
        .await
        .unwrap();

        assert_eq!(responses.len(), 4, "Expected 4 responses");
        // Even positions were claimed by the hook; odd ones went to DataFusion.
        for (position, response) in responses.iter().enumerate() {
            if position % 2 == 0 {
                assert!(matches!(response, Response::EmptyQuery));
            } else {
                assert!(matches!(response, Response::Query(_)));
            }
        }
    }
}