1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::{DataType, Field, Schema};
6use datafusion::common::ToDFSchema;
7use datafusion::error::DataFusionError;
8use datafusion::logical_expr::LogicalPlan;
9use datafusion::prelude::*;
10use datafusion::sql::parser::Statement;
11use log::{info, warn};
12use pgwire::api::auth::noop::NoopStartupHandler;
13use pgwire::api::auth::StartupHandler;
14use pgwire::api::portal::{Format, Portal};
15use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
16use pgwire::api::results::{
17 DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
18 Response, Tag,
19};
20use pgwire::api::stmt::QueryParser;
21use pgwire::api::stmt::StoredStatement;
22use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
23use pgwire::error::{PgWireError, PgWireResult};
24use pgwire::messages::response::TransactionStatus;
25use tokio::sync::Mutex;
26
27use crate::auth::AuthManager;
28use arrow_pg::datatypes::df;
29use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
30use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
31use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
32
/// Key in the client's connection metadata under which the session's
/// `statement_timeout` is stored, as a decimal number of milliseconds.
const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms";
35
36pub struct SimpleStartupHandler;
39
40#[async_trait::async_trait]
41impl NoopStartupHandler for SimpleStartupHandler {}
42
43pub struct HandlerFactory {
44 pub session_service: Arc<DfSessionService>,
45}
46
47impl HandlerFactory {
48 pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
49 let session_service =
50 Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
51 HandlerFactory { session_service }
52 }
53}
54
55impl PgWireServerHandlers for HandlerFactory {
56 fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
57 self.session_service.clone()
58 }
59
60 fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
61 self.session_service.clone()
62 }
63
64 fn startup_handler(&self) -> Arc<impl StartupHandler> {
65 Arc::new(SimpleStartupHandler)
66 }
67
68 fn error_handler(&self) -> Arc<impl ErrorHandler> {
69 Arc::new(LoggingErrorHandler)
70 }
71}
72
73struct LoggingErrorHandler;
74
75impl ErrorHandler for LoggingErrorHandler {
76 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
77 where
78 C: ClientInfo,
79 {
80 info!("Sending error: {error}")
81 }
82}
83
84pub struct DfSessionService {
86 session_context: Arc<SessionContext>,
87 parser: Arc<Parser>,
88 timezone: Arc<Mutex<String>>,
89 auth_manager: Arc<AuthManager>,
90}
91
92impl DfSessionService {
93 pub fn new(
94 session_context: Arc<SessionContext>,
95 auth_manager: Arc<AuthManager>,
96 ) -> DfSessionService {
97 let parser = Arc::new(Parser {
98 session_context: session_context.clone(),
99 sql_parser: PostgresCompatibilityParser::new(),
100 });
101 DfSessionService {
102 session_context,
103 parser,
104 timezone: Arc::new(Mutex::new("UTC".to_string())),
105 auth_manager,
106 }
107 }
108
109 fn get_statement_timeout<C>(client: &C) -> Option<std::time::Duration>
111 where
112 C: ClientInfo,
113 {
114 client
115 .metadata()
116 .get(METADATA_STATEMENT_TIMEOUT)
117 .and_then(|s| s.parse::<u64>().ok())
118 .map(std::time::Duration::from_millis)
119 }
120
121 fn set_statement_timeout<C>(client: &mut C, timeout: Option<std::time::Duration>)
123 where
124 C: ClientInfo,
125 {
126 let metadata = client.metadata_mut();
127 if let Some(duration) = timeout {
128 metadata.insert(
129 METADATA_STATEMENT_TIMEOUT.to_string(),
130 duration.as_millis().to_string(),
131 );
132 } else {
133 metadata.remove(METADATA_STATEMENT_TIMEOUT);
134 }
135 }
136
137 async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
139 where
140 C: ClientInfo,
141 {
142 let username = client
144 .metadata()
145 .get("user")
146 .map(|s| s.as_str())
147 .unwrap_or("anonymous");
148
149 let query_lower = query.to_lowercase();
151 let query_trimmed = query_lower.trim();
152
153 let (required_permission, resource) = if query_trimmed.starts_with("select") {
154 (Permission::Select, self.extract_table_from_query(query))
155 } else if query_trimmed.starts_with("insert") {
156 (Permission::Insert, self.extract_table_from_query(query))
157 } else if query_trimmed.starts_with("update") {
158 (Permission::Update, self.extract_table_from_query(query))
159 } else if query_trimmed.starts_with("delete") {
160 (Permission::Delete, self.extract_table_from_query(query))
161 } else if query_trimmed.starts_with("create table")
162 || query_trimmed.starts_with("create view")
163 {
164 (Permission::Create, ResourceType::All)
165 } else if query_trimmed.starts_with("drop") {
166 (Permission::Drop, self.extract_table_from_query(query))
167 } else if query_trimmed.starts_with("alter") {
168 (Permission::Alter, self.extract_table_from_query(query))
169 } else {
170 return Ok(());
172 };
173
174 let has_permission = self
176 .auth_manager
177 .check_permission(username, required_permission, resource)
178 .await;
179
180 if !has_permission {
181 return Err(PgWireError::UserError(Box::new(
182 pgwire::error::ErrorInfo::new(
183 "ERROR".to_string(),
184 "42501".to_string(), format!("permission denied for user \"{username}\""),
186 ),
187 )));
188 }
189
190 Ok(())
191 }
192
193 fn extract_table_from_query(&self, query: &str) -> ResourceType {
195 let words: Vec<&str> = query.split_whitespace().collect();
196
197 for (i, word) in words.iter().enumerate() {
199 let word_lower = word.to_lowercase();
200 if (word_lower == "from" || word_lower == "into" || word_lower == "table")
201 && i + 1 < words.len()
202 {
203 let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
204 return ResourceType::Table(table_name.to_string());
205 }
206 }
207
208 ResourceType::All
210 }
211
212 fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult<QueryResponse<'a>> {
213 let fields = vec![FieldInfo::new(
214 name.to_string(),
215 None,
216 None,
217 Type::VARCHAR,
218 FieldFormat::Text,
219 )];
220
221 let row = {
222 let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
223 encoder.encode_field(&Some(value))?;
224 encoder.finish()
225 };
226
227 let row_stream = futures::stream::once(async move { row });
228 Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
229 }
230
231 async fn try_respond_set_statements<'a, C>(
232 &self,
233 client: &mut C,
234 query_lower: &str,
235 ) -> PgWireResult<Option<Response<'a>>>
236 where
237 C: ClientInfo,
238 {
239 if query_lower.starts_with("set") {
240 if query_lower.starts_with("set time zone") {
241 let parts: Vec<&str> = query_lower.split_whitespace().collect();
242 if parts.len() >= 4 {
243 let tz = parts[3].trim_matches('"');
244 let mut timezone = self.timezone.lock().await;
245 *timezone = tz.to_string();
246 Ok(Some(Response::Execution(Tag::new("SET"))))
247 } else {
248 Err(PgWireError::UserError(Box::new(
249 pgwire::error::ErrorInfo::new(
250 "ERROR".to_string(),
251 "42601".to_string(),
252 "Invalid SET TIME ZONE syntax".to_string(),
253 ),
254 )))
255 }
256 } else if query_lower.starts_with("set statement_timeout") {
257 let parts: Vec<&str> = query_lower.split_whitespace().collect();
258 if parts.len() >= 3 {
259 let timeout_str = parts[2].trim_matches('"').trim_matches('\'');
260
261 let timeout = if timeout_str == "0" || timeout_str.is_empty() {
262 None
263 } else {
264 let timeout_ms = if timeout_str.ends_with("ms") {
266 timeout_str.trim_end_matches("ms").parse::<u64>()
267 } else if timeout_str.ends_with("s") {
268 timeout_str
269 .trim_end_matches("s")
270 .parse::<u64>()
271 .map(|s| s * 1000)
272 } else if timeout_str.ends_with("min") {
273 timeout_str
274 .trim_end_matches("min")
275 .parse::<u64>()
276 .map(|m| m * 60 * 1000)
277 } else {
278 timeout_str.parse::<u64>()
280 };
281
282 match timeout_ms {
283 Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
284 _ => None,
285 }
286 };
287
288 Self::set_statement_timeout(client, timeout);
289 Ok(Some(Response::Execution(Tag::new("SET"))))
290 } else {
291 Err(PgWireError::UserError(Box::new(
292 pgwire::error::ErrorInfo::new(
293 "ERROR".to_string(),
294 "42601".to_string(),
295 "Invalid SET statement_timeout syntax".to_string(),
296 ),
297 )))
298 }
299 } else {
300 if let Err(e) = self.session_context.sql(query_lower).await {
302 warn!("SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored");
303 }
304
305 Ok(Some(Response::Execution(Tag::new("SET"))))
307 }
308 } else {
309 Ok(None)
310 }
311 }
312
313 async fn try_respond_transaction_statements<'a, C>(
314 &self,
315 client: &C,
316 query_lower: &str,
317 ) -> PgWireResult<Option<Response<'a>>>
318 where
319 C: ClientInfo,
320 {
321 match query_lower.trim() {
324 "begin" | "begin transaction" | "begin work" | "start transaction" => {
325 match client.transaction_status() {
326 TransactionStatus::Idle => {
327 Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
328 }
329 TransactionStatus::Transaction => {
330 log::warn!("BEGIN command ignored: already in transaction block");
333 Ok(Some(Response::Execution(Tag::new("BEGIN"))))
334 }
335 TransactionStatus::Error => {
336 Err(PgWireError::UserError(Box::new(
338 pgwire::error::ErrorInfo::new(
339 "ERROR".to_string(),
340 "25P01".to_string(),
341 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
342 ),
343 )))
344 }
345 }
346 }
347 "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
348 match client.transaction_status() {
349 TransactionStatus::Idle | TransactionStatus::Transaction => {
350 Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
351 }
352 TransactionStatus::Error => {
353 Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
354 }
355 }
356 }
357 "rollback" | "rollback transaction" | "rollback work" | "abort" => {
358 Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
359 }
360 _ => Ok(None),
361 }
362 }
363
364 async fn try_respond_show_statements<'a, C>(
365 &self,
366 client: &C,
367 query_lower: &str,
368 ) -> PgWireResult<Option<Response<'a>>>
369 where
370 C: ClientInfo,
371 {
372 if query_lower.starts_with("show ") {
373 match query_lower.strip_suffix(";").unwrap_or(query_lower) {
374 "show time zone" => {
375 let timezone = self.timezone.lock().await.clone();
376 let resp = Self::mock_show_response("TimeZone", &timezone)?;
377 Ok(Some(Response::Query(resp)))
378 }
379 "show server_version" => {
380 let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
381 Ok(Some(Response::Query(resp)))
382 }
383 "show transaction_isolation" => {
384 let resp =
385 Self::mock_show_response("transaction_isolation", "read uncommitted")?;
386 Ok(Some(Response::Query(resp)))
387 }
388 "show catalogs" => {
389 let catalogs = self.session_context.catalog_names();
390 let value = catalogs.join(", ");
391 let resp = Self::mock_show_response("Catalogs", &value)?;
392 Ok(Some(Response::Query(resp)))
393 }
394 "show search_path" => {
395 let default_schema = "public";
396 let resp = Self::mock_show_response("search_path", default_schema)?;
397 Ok(Some(Response::Query(resp)))
398 }
399 "show statement_timeout" => {
400 let timeout = Self::get_statement_timeout(client);
401 let timeout_str = match timeout {
402 Some(duration) => format!("{}ms", duration.as_millis()),
403 None => "0".to_string(),
404 };
405 let resp = Self::mock_show_response("statement_timeout", &timeout_str)?;
406 Ok(Some(Response::Query(resp)))
407 }
408 "show transaction isolation level" => {
409 let resp = Self::mock_show_response("transaction_isolation", "read_committed")?;
410 Ok(Some(Response::Query(resp)))
411 }
412 _ => {
413 info!("Unsupported show statement: {query_lower}");
414 let resp = Self::mock_show_response("unsupported_show_statement", "")?;
415 Ok(Some(Response::Query(resp)))
416 }
417 }
418 } else {
419 Ok(None)
420 }
421 }
422}
423
424#[async_trait]
425impl SimpleQueryHandler for DfSessionService {
426 async fn do_query<'a, C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
427 where
428 C: ClientInfo + Unpin + Send + Sync,
429 {
430 log::debug!("Received query: {query}"); let query_lower = query.to_lowercase().trim().to_string();
434 if let Some(resp) = self
435 .try_respond_transaction_statements(client, &query_lower)
436 .await?
437 {
438 return Ok(vec![resp]);
439 }
440
441 let mut statements = self
442 .parser
443 .sql_parser
444 .parse(query)
445 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
446
447 let statement = statements.remove(0);
449
450 let query = statement.to_string();
452 let query_lower = query.to_lowercase().trim().to_string();
453
454 if !query_lower.starts_with("set")
456 && !query_lower.starts_with("begin")
457 && !query_lower.starts_with("commit")
458 && !query_lower.starts_with("rollback")
459 && !query_lower.starts_with("start")
460 && !query_lower.starts_with("end")
461 && !query_lower.starts_with("abort")
462 && !query_lower.starts_with("show")
463 {
464 self.check_query_permission(client, &query).await?;
465 }
466
467 if let Some(resp) = self
468 .try_respond_set_statements(client, &query_lower)
469 .await?
470 {
471 return Ok(vec![resp]);
472 }
473
474 if let Some(resp) = self
475 .try_respond_show_statements(client, &query_lower)
476 .await?
477 {
478 return Ok(vec![resp]);
479 }
480
481 if client.transaction_status() == TransactionStatus::Error {
484 return Err(PgWireError::UserError(Box::new(
485 pgwire::error::ErrorInfo::new(
486 "ERROR".to_string(),
487 "25P01".to_string(),
488 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
489 ),
490 )));
491 }
492
493 let df_result = {
494 let timeout = Self::get_statement_timeout(client);
495 if let Some(timeout_duration) = timeout {
496 tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
497 .await
498 .map_err(|_| {
499 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
500 "ERROR".to_string(),
501 "57014".to_string(), "canceling statement due to statement timeout".to_string(),
503 )))
504 })?
505 } else {
506 self.session_context.sql(&query).await
507 }
508 };
509
510 let df = match df_result {
512 Ok(df) => df,
513 Err(e) => {
514 return Err(PgWireError::ApiError(Box::new(e)));
515 }
516 };
517
518 if query_lower.starts_with("insert into") {
519 let result = df
522 .clone()
523 .collect()
524 .await
525 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
526
527 let rows_affected = result
529 .first()
530 .and_then(|batch| batch.column_by_name("count"))
531 .and_then(|col| {
532 col.as_any()
533 .downcast_ref::<datafusion::arrow::array::UInt64Array>()
534 })
535 .map_or(0, |array| array.value(0) as usize);
536
537 let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
539 Ok(vec![Response::Execution(tag)])
540 } else {
541 let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
543 Ok(vec![Response::Query(resp)])
544 }
545 }
546}
547
548#[async_trait]
549impl ExtendedQueryHandler for DfSessionService {
550 type Statement = (String, LogicalPlan);
551 type QueryParser = Parser;
552
553 fn query_parser(&self) -> Arc<Self::QueryParser> {
554 self.parser.clone()
555 }
556
557 async fn do_describe_statement<C>(
558 &self,
559 _client: &mut C,
560 target: &StoredStatement<Self::Statement>,
561 ) -> PgWireResult<DescribeStatementResponse>
562 where
563 C: ClientInfo + Unpin + Send + Sync,
564 {
565 let (_, plan) = &target.statement;
566 let schema = plan.schema();
567 let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
568 let params = plan
569 .get_parameter_types()
570 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
571
572 let mut param_types = Vec::with_capacity(params.len());
573 for param_type in ordered_param_types(¶ms).iter() {
574 if let Some(datatype) = param_type {
576 let pgtype = into_pg_type(datatype)?;
577 param_types.push(pgtype);
578 } else {
579 param_types.push(Type::UNKNOWN);
580 }
581 }
582
583 Ok(DescribeStatementResponse::new(param_types, fields))
584 }
585
586 async fn do_describe_portal<C>(
587 &self,
588 _client: &mut C,
589 target: &Portal<Self::Statement>,
590 ) -> PgWireResult<DescribePortalResponse>
591 where
592 C: ClientInfo + Unpin + Send + Sync,
593 {
594 let (_, plan) = &target.statement.statement;
595 let format = &target.result_column_format;
596 let schema = plan.schema();
597 let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
598
599 Ok(DescribePortalResponse::new(fields))
600 }
601
602 async fn do_query<'a, C>(
603 &self,
604 client: &mut C,
605 portal: &Portal<Self::Statement>,
606 _max_rows: usize,
607 ) -> PgWireResult<Response<'a>>
608 where
609 C: ClientInfo + Unpin + Send + Sync,
610 {
611 let query = portal
612 .statement
613 .statement
614 .0
615 .to_lowercase()
616 .trim()
617 .to_string();
618 log::debug!("Received execute extended query: {query}"); if !query.starts_with("set") && !query.starts_with("show") {
622 self.check_query_permission(client, &portal.statement.statement.0)
623 .await?;
624 }
625
626 if let Some(resp) = self.try_respond_set_statements(client, &query).await? {
627 return Ok(resp);
628 }
629
630 if let Some(resp) = self
631 .try_respond_transaction_statements(client, &query)
632 .await?
633 {
634 return Ok(resp);
635 }
636
637 if let Some(resp) = self.try_respond_show_statements(client, &query).await? {
638 return Ok(resp);
639 }
640
641 if client.transaction_status() == TransactionStatus::Error {
644 return Err(PgWireError::UserError(Box::new(
645 pgwire::error::ErrorInfo::new(
646 "ERROR".to_string(),
647 "25P01".to_string(),
648 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
649 ),
650 )));
651 }
652
653 let (_, plan) = &portal.statement.statement;
654
655 let param_types = plan
656 .get_parameter_types()
657 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
658
659 let param_values = df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; let plan = plan
662 .clone()
663 .replace_params_with_values(¶m_values)
664 .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let optimised = self
667 .session_context
668 .state()
669 .optimize(&plan)
670 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
671
672 let dataframe = {
673 let timeout = Self::get_statement_timeout(client);
674 if let Some(timeout_duration) = timeout {
675 tokio::time::timeout(
676 timeout_duration,
677 self.session_context.execute_logical_plan(optimised),
678 )
679 .await
680 .map_err(|_| {
681 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
682 "ERROR".to_string(),
683 "57014".to_string(), "canceling statement due to statement timeout".to_string(),
685 )))
686 })?
687 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
688 } else {
689 self.session_context
690 .execute_logical_plan(optimised)
691 .await
692 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
693 }
694 };
695 let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
696 Ok(Response::Query(resp))
697 }
698}
699
700pub struct Parser {
701 session_context: Arc<SessionContext>,
702 sql_parser: PostgresCompatibilityParser,
703}
704
705impl Parser {
706 fn try_shortcut_parse_plan(&self, sql: &str) -> Result<Option<LogicalPlan>, DataFusionError> {
707 let sql_lower = sql.to_lowercase();
709 let sql_trimmed = sql_lower.trim();
710
711 if matches!(
712 sql_trimmed,
713 "" | "begin"
714 | "begin transaction"
715 | "begin work"
716 | "start transaction"
717 | "commit"
718 | "commit transaction"
719 | "commit work"
720 | "end"
721 | "end transaction"
722 | "rollback"
723 | "rollback transaction"
724 | "rollback work"
725 | "abort"
726 ) {
727 let dummy_schema = datafusion::common::DFSchema::empty();
729 return Ok(Some(LogicalPlan::EmptyRelation(
730 datafusion::logical_expr::EmptyRelation {
731 produce_one_row: false,
732 schema: Arc::new(dummy_schema),
733 },
734 )));
735 }
736
737 if sql_trimmed.starts_with("show") {
739 let show_schema =
741 Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
742 let df_schema = show_schema.to_dfschema()?;
743 return Ok(Some(LogicalPlan::EmptyRelation(
744 datafusion::logical_expr::EmptyRelation {
745 produce_one_row: true,
746 schema: Arc::new(df_schema),
747 },
748 )));
749 }
750
751 Ok(None)
752 }
753}
754
755#[async_trait]
756impl QueryParser for Parser {
757 type Statement = (String, LogicalPlan);
758
759 async fn parse_sql<C>(
760 &self,
761 _client: &C,
762 sql: &str,
763 _types: &[Type],
764 ) -> PgWireResult<Self::Statement> {
765 log::debug!("Received parse extended query: {sql}"); if let Some(plan) = self
769 .try_shortcut_parse_plan(sql)
770 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
771 {
772 return Ok((sql.to_string(), plan));
773 }
774
775 let mut statements = self
776 .sql_parser
777 .parse(sql)
778 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
779 let statement = statements.remove(0);
780
781 let query = statement.to_string();
782
783 let context = &self.session_context;
784 let state = context.state();
785 let logical_plan = state
786 .statement_to_plan(Statement::Statement(Box::new(statement)))
787 .await
788 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
789 Ok((query, logical_plan))
790 }
791}
792
/// Returns a statement's parameter types ordered by placeholder position.
///
/// `types` maps placeholder names such as `"$1"` to their (possibly unknown) type, in the
/// shape returned by `LogicalPlan::get_parameter_types`. Placeholders are ordered by their
/// numeric index, so `$2` precedes `$10`. A plain string sort would put `$10` before `$2`
/// and bind values to the wrong placeholders once a statement has ten or more of them.
/// Names that are not of the form `$N` sort after the numbered ones, by name.
fn ordered_param_types<T>(types: &HashMap<String, Option<T>>) -> Vec<Option<&T>> {
    let position = |name: &str| name.strip_prefix('$').and_then(|n| n.parse::<usize>().ok());

    let mut entries: Vec<(&String, &Option<T>)> = types.iter().collect();
    entries.sort_unstable_by(|(a, _), (b, _)| {
        let (pa, pb) = (position(a.as_str()), position(b.as_str()));
        pa.is_none()
            .cmp(&pb.is_none())
            .then(pa.cmp(&pb))
            .then_with(|| a.cmp(b))
    });
    entries.into_iter().map(|(_, ty)| ty.as_ref()).collect()
}
800
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::AuthManager;
    use datafusion::prelude::SessionContext;
    use std::collections::HashMap;
    use std::time::Duration;

    /// In-memory `ClientInfo` whose only real state is its metadata map.
    #[derive(Default)]
    struct MockClient {
        metadata: HashMap<String, String>,
    }

    impl MockClient {
        fn new() -> Self {
            Self::default()
        }
    }

    /// Service over an empty `SessionContext` and a fresh `AuthManager`.
    fn new_service() -> DfSessionService {
        let session_context = Arc::new(SessionContext::new());
        let auth_manager = Arc::new(AuthManager::new());
        DfSessionService::new(session_context, auth_manager)
    }

    impl ClientInfo for MockClient {
        fn socket_addr(&self) -> std::net::SocketAddr {
            "127.0.0.1:5432".parse().unwrap()
        }

        fn is_secure(&self) -> bool {
            false
        }

        fn protocol_version(&self) -> pgwire::messages::ProtocolVersion {
            pgwire::messages::ProtocolVersion::PROTOCOL3_0
        }

        fn set_protocol_version(&mut self, _version: pgwire::messages::ProtocolVersion) {}

        fn pid_and_secret_key(&self) -> (i32, pgwire::messages::startup::SecretKey) {
            (0, pgwire::messages::startup::SecretKey::I32(0))
        }

        fn set_pid_and_secret_key(
            &mut self,
            _pid: i32,
            _secret_key: pgwire::messages::startup::SecretKey,
        ) {
        }

        fn state(&self) -> pgwire::api::PgWireConnectionState {
            pgwire::api::PgWireConnectionState::ReadyForQuery
        }

        fn set_state(&mut self, _new_state: pgwire::api::PgWireConnectionState) {}

        fn transaction_status(&self) -> pgwire::messages::response::TransactionStatus {
            pgwire::messages::response::TransactionStatus::Idle
        }

        fn set_transaction_status(
            &mut self,
            _new_status: pgwire::messages::response::TransactionStatus,
        ) {
        }

        fn metadata(&self) -> &HashMap<String, String> {
            &self.metadata
        }

        fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
            &mut self.metadata
        }

        fn client_certificates<'a>(&self) -> Option<&[rustls_pki_types::CertificateDer<'a>]> {
            None
        }
    }

    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let service = new_service();
        let mut client = MockClient::new();

        let set_response = service
            .try_respond_set_statements(&mut client, "set statement_timeout '5000ms'")
            .await
            .unwrap();
        assert!(set_response.is_some());
        assert_eq!(
            DfSessionService::get_statement_timeout(&client),
            Some(Duration::from_millis(5000))
        );

        let show_response = service
            .try_respond_show_statements(&client, "show statement_timeout")
            .await
            .unwrap();
        assert!(show_response.is_some());
    }

    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let service = new_service();
        let mut client = MockClient::new();

        for statement in [
            "set statement_timeout '1000ms'",
            "set statement_timeout '0'",
        ] {
            service
                .try_respond_set_statements(&mut client, statement)
                .await
                .unwrap();
        }

        assert_eq!(DfSessionService::get_statement_timeout(&client), None);
    }
}