1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::{DataType, Field, Schema};
5use datafusion::common::{ParamValues, ToDFSchema};
6use datafusion::error::DataFusionError;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::SessionContext;
9use datafusion::sql::sqlparser::ast::{Expr, Set, Statement};
10use log::{info, warn};
11use pgwire::api::ClientInfo;
12use pgwire::api::auth::DefaultServerParameterProvider;
13use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
14use pgwire::error::{PgWireError, PgWireResult};
15use pgwire::messages::PgWireBackendMessage;
16use pgwire::messages::startup::ParameterStatus;
17use pgwire::types::format::FormatOptions;
18use postgres_types::Type;
19
20use crate::QueryHook;
21use crate::client;
22use crate::hooks::HookClient;
23
/// Query hook that answers `SET` and `SHOW` statements.
///
/// The [`QueryHook`] implementation below routes `Statement::Set`,
/// `Statement::ShowVariable` and `Statement::ShowStatus` to the
/// `try_respond_set_statements` / `try_respond_show_statements` helpers in
/// this file and declines (`None`) every other statement.
#[derive(Debug)]
pub struct SetShowHook;
26
27#[async_trait]
28impl QueryHook for SetShowHook {
29 async fn handle_simple_query(
31 &self,
32 statement: &Statement,
33 session_context: &SessionContext,
34 client: &mut dyn HookClient,
35 ) -> Option<PgWireResult<Response>> {
36 match statement {
37 Statement::Set { .. } => {
38 try_respond_set_statements(client, statement, session_context).await
39 }
40 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
41 try_respond_show_statements(client, statement, session_context).await
42 }
43 _ => None,
44 }
45 }
46
47 async fn handle_extended_parse_query(
48 &self,
49 stmt: &Statement,
50 _session_context: &SessionContext,
51 _client: &(dyn ClientInfo + Send + Sync),
52 ) -> Option<PgWireResult<LogicalPlan>> {
53 match stmt {
54 Statement::Set { .. } => {
55 let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
56 let result = show_schema
57 .to_dfschema()
58 .map(|df_schema| {
59 LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
60 produce_one_row: true,
61 schema: Arc::new(df_schema),
62 })
63 })
64 .map_err(|e| PgWireError::ApiError(Box::new(e)));
65 Some(result)
66 }
67 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
68 let show_schema =
69 Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
70 let result = show_schema
71 .to_dfschema()
72 .map(|df_schema| {
73 LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
74 produce_one_row: true,
75 schema: Arc::new(df_schema),
76 })
77 })
78 .map_err(|e| PgWireError::ApiError(Box::new(e)));
79 Some(result)
80 }
81 _ => None,
82 }
83 }
84
85 async fn handle_extended_query(
86 &self,
87 statement: &Statement,
88 _logical_plan: &LogicalPlan,
89 _params: &ParamValues,
90 session_context: &SessionContext,
91 client: &mut dyn HookClient,
92 ) -> Option<PgWireResult<Response>> {
93 match statement {
94 Statement::Set { .. } => {
95 try_respond_set_statements(client, statement, session_context).await
96 }
97 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
98 try_respond_show_statements(client, statement, session_context).await
99 }
100 _ => None,
101 }
102 }
103}
104
105fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
106 let fields = vec![FieldInfo::new(
107 name.to_string(),
108 None,
109 None,
110 Type::VARCHAR,
111 FieldFormat::Text,
112 )];
113
114 let row = {
115 let mut encoder = DataRowEncoder::new(Arc::new(fields.clone()));
116 encoder.encode_field(&Some(value))?;
117 Ok(encoder.take_row())
118 };
119
120 let row_stream = futures::stream::once(async move { row });
121 Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
122}
123
124async fn try_respond_set_statements(
125 client: &mut dyn HookClient,
126 statement: &Statement,
127 session_context: &SessionContext,
128) -> Option<PgWireResult<Response>> {
129 let Statement::Set(set_statement) = statement else {
130 return None;
131 };
132
133 match &set_statement {
134 Set::SingleAssignment {
135 scope: None,
136 hivevar: false,
137 variable,
138 values,
139 } => {
140 let var = variable.to_string().to_lowercase();
141 if var == "statement_timeout" {
142 let value = values[0].to_string();
143 let timeout_str = value.trim_matches('"').trim_matches('\'');
144
145 let timeout = if timeout_str == "0" || timeout_str.is_empty() {
146 None
147 } else {
148 let timeout_ms = if timeout_str.ends_with("ms") {
150 timeout_str.trim_end_matches("ms").parse::<u64>()
151 } else if timeout_str.ends_with("s") {
152 timeout_str
153 .trim_end_matches("s")
154 .parse::<u64>()
155 .map(|s| s * 1000)
156 } else if timeout_str.ends_with("min") {
157 timeout_str
158 .trim_end_matches("min")
159 .parse::<u64>()
160 .map(|m| m * 60 * 1000)
161 } else {
162 timeout_str.parse::<u64>()
164 };
165
166 match timeout_ms {
167 Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
168 _ => None,
169 }
170 };
171
172 client::set_statement_timeout(client, timeout);
173 return Some(Ok(Response::Execution(Tag::new("SET"))));
174 } else if matches!(
175 var.as_str(),
176 "datestyle"
177 | "bytea_output"
178 | "intervalstyle"
179 | "application_name"
180 | "extra_float_digits"
181 | "search_path"
182 ) && !values.is_empty()
183 {
184 let value = values[0].clone();
186 if let Expr::Value(value) = value {
187 let val_str = value.into_string().unwrap_or_else(|| "".to_string());
188 client.metadata_mut().insert(var.clone(), val_str);
189 if let Some((name, value)) = parameter_status_for_var(&var, &*client)
190 && let Err(e) = client
191 .send_message(PgWireBackendMessage::ParameterStatus(
192 ParameterStatus::new(name, value),
193 ))
194 .await
195 {
196 return Some(Err(e));
197 }
198 return Some(Ok(Response::Execution(Tag::new("SET"))));
199 }
200 }
201 }
202 Set::SetTimeZone {
203 local: false,
204 value,
205 } => {
206 let tz = value.to_string();
207 let tz = tz.trim_matches('"').trim_matches('\'');
208 client::set_timezone(client, Some(tz));
209 session_context
211 .state()
212 .config_mut()
213 .options_mut()
214 .execution
215 .time_zone = Some(tz.to_string());
216 let tz_value = client::get_timezone(client).unwrap_or("UTC").to_string();
217 if let Err(e) = client
218 .send_message(PgWireBackendMessage::ParameterStatus(ParameterStatus::new(
219 "TimeZone".to_string(),
220 tz_value,
221 )))
222 .await
223 {
224 return Some(Err(e));
225 }
226 return Some(Ok(Response::Execution(Tag::new("SET"))));
227 }
228 _ => {}
229 }
230
231 if let Err(e) = execute_set_statement(session_context, statement.clone()).await {
233 warn!(
234 "SET statement {statement} is not supported by datafusion, error {e}, statement ignored",
235 );
236 }
237
238 Some(Ok(Response::Execution(Tag::new("SET"))))
240}
241
242fn parameter_status_for_var(
243 var: &str,
244 client: &(impl ClientInfo + ?Sized),
245) -> Option<(String, String)> {
246 let display_name = match var {
247 "datestyle" => "DateStyle",
248 "intervalstyle" => "IntervalStyle",
249 "bytea_output" => "bytea_output",
250 "application_name" => "application_name",
251 "extra_float_digits" => "extra_float_digits",
252 "search_path" => "search_path",
253 _ => return None,
254 };
255 let value = client.metadata().get(var)?.clone();
256 Some((display_name.to_string(), value))
257}
258
259async fn execute_set_statement(
260 session_context: &SessionContext,
261 statement: Statement,
262) -> Result<(), DataFusionError> {
263 let state = session_context.state();
264 let logical_plan = state
265 .statement_to_plan(datafusion::sql::parser::Statement::Statement(Box::new(
266 statement,
267 )))
268 .await
269 .and_then(|logical_plan| state.optimize(&logical_plan))?;
270
271 session_context
272 .execute_logical_plan(logical_plan)
273 .await
274 .map(|_| ())
275}
276
277async fn try_respond_show_statements(
278 client: &dyn HookClient,
279 statement: &Statement,
280 session_context: &SessionContext,
281) -> Option<PgWireResult<Response>> {
282 let Statement::ShowVariable { variable } = statement else {
283 return None;
284 };
285
286 let variables = variable
287 .iter()
288 .map(|v| v.value.to_lowercase())
289 .collect::<Vec<_>>();
290 let variables_ref = variables.iter().map(|s| s.as_str()).collect::<Vec<_>>();
291
292 match variables_ref.as_slice() {
293 ["time", "zone"] => {
294 let timezone = client::get_timezone(client).unwrap_or("UTC");
295 Some(mock_show_response("TimeZone", timezone).map(Response::Query))
296 }
297 ["server_version"] => {
298 let version = format!(
299 "datafusion {} on {} {}",
300 session_context.state().version(),
301 env!("CARGO_PKG_NAME"),
302 env!("CARGO_PKG_VERSION")
303 );
304 Some(mock_show_response("server_version", &version).map(Response::Query))
305 }
306 ["transaction_isolation"] => Some(
307 mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query),
308 ),
309 ["catalogs"] => {
310 let catalogs = session_context.catalog_names();
311 let value = catalogs.join(", ");
312 Some(mock_show_response("Catalogs", &value).map(Response::Query))
313 }
314 ["statement_timeout"] => {
315 let timeout = client::get_statement_timeout(client);
316 let timeout_str = match timeout {
317 Some(duration) => format!("{}ms", duration.as_millis()),
318 None => "0".to_string(),
319 };
320 Some(mock_show_response("statement_timeout", &timeout_str).map(Response::Query))
321 }
322 ["transaction", "isolation", "level"] => {
323 Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
324 }
325 _ => {
326 let val = client
327 .metadata()
328 .get(&variables[0])
329 .map(|v| v.to_string())
330 .or_else(|| match variables[0].as_str() {
331 "bytea_output" => Some(FormatOptions::default().bytea_output),
332 "datestyle" => Some(FormatOptions::default().date_style),
333 "intervalstyle" => Some(FormatOptions::default().interval_style),
334 "extra_float_digits" => {
335 Some(FormatOptions::default().extra_float_digits.to_string())
336 }
337 "application_name" => Some(
338 DefaultServerParameterProvider::default()
339 .application_name
340 .unwrap_or("".to_owned()),
341 ),
342 "search_path" => Some(DefaultServerParameterProvider::default().search_path),
343 _ => None,
344 });
345 if let Some(val) = val {
346 Some(mock_show_response(&variables[0], &val).map(Response::Query))
347 } else {
348 info!("Unsupported show statement: {statement}");
349 Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
350 }
351 }
352 }
353}
354
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    use super::*;
    use crate::testing::MockClient;

    /// Parses `sql` as a single PostgreSQL statement, panicking on invalid input.
    fn parse_statement(sql: &str) -> Statement {
        Parser::new(&PostgreSqlDialect {})
            .try_with_sql(sql)
            .unwrap()
            .parse_statement()
            .unwrap()
    }

    /// Setting `statement_timeout` stores the duration on the client and
    /// `SHOW statement_timeout` answers successfully afterwards.
    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let statement = parse_statement("set statement_timeout to '5000ms'");
        let set_response =
            try_respond_set_statements(&mut client, &statement, &session_context).await;

        assert!(set_response.is_some());
        assert!(set_response.unwrap().is_ok());

        let timeout = client::get_statement_timeout(&client);
        assert_eq!(timeout, Some(Duration::from_millis(5000)));

        let statement = parse_statement("show statement_timeout");
        let show_response =
            try_respond_show_statements(&client, &statement, &session_context).await;

        assert!(show_response.is_some());
        assert!(show_response.unwrap().is_ok());
    }

    /// `SET bytea_output` lands in the client metadata and can be shown.
    #[tokio::test]
    async fn test_bytea_output_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let statement = parse_statement("set bytea_output = 'hex'");
        let set_response =
            try_respond_set_statements(&mut client, &statement, &session_context).await;

        assert!(set_response.is_some());
        assert!(set_response.unwrap().is_ok());

        let bytea_output = client.metadata().get("bytea_output").unwrap();
        assert_eq!(bytea_output, "hex");

        let statement = parse_statement("show bytea_output");
        let show_response =
            try_respond_show_statements(&client, &statement, &session_context).await;

        assert!(show_response.is_some());
        assert!(show_response.unwrap().is_ok());
    }

    /// `SET dateStyle` is stored under the lower-cased key and can be shown.
    #[tokio::test]
    async fn test_date_style_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let statement = parse_statement("set dateStyle = 'ISO, DMY'");
        let set_response =
            try_respond_set_statements(&mut client, &statement, &session_context).await;

        assert!(set_response.is_some());
        assert!(set_response.unwrap().is_ok());

        let date_style = client.metadata().get("datestyle").unwrap();
        assert_eq!(date_style, "ISO, DMY");

        let statement = parse_statement("show dateStyle");
        let show_response =
            try_respond_show_statements(&client, &statement, &session_context).await;

        assert!(show_response.is_some());
        assert!(show_response.unwrap().is_ok());
    }

    /// A `statement_timeout` of `0` clears a previously configured timeout.
    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let statement = parse_statement("set statement_timeout to '1000ms'");
        let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(resp.is_some());
        assert!(resp.unwrap().is_ok());

        let statement = parse_statement("set statement_timeout to '0'");
        let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(resp.is_some());
        assert!(resp.unwrap().is_ok());

        let timeout = client::get_statement_timeout(&client);
        assert_eq!(timeout, None);
    }

    /// Every reported variable produces exactly one `ParameterStatus` message
    /// carrying the client-facing name and the new value.
    #[tokio::test]
    async fn test_parameter_status_sent_for_all_set_vars() {
        use pgwire::messages::PgWireBackendMessage;

        let test_cases = vec![
            ("set bytea_output = 'escape'", "bytea_output", "escape"),
            (
                "set intervalstyle = 'postgres'",
                "IntervalStyle",
                "postgres",
            ),
            (
                "set application_name = 'myapp'",
                "application_name",
                "myapp",
            ),
            ("set search_path = 'public'", "search_path", "public"),
            ("set extra_float_digits = '2'", "extra_float_digits", "2"),
            ("set datestyle = 'ISO, MDY'", "DateStyle", "ISO, MDY"),
            (
                "set time zone 'America/New_York'",
                "TimeZone",
                "America/New_York",
            ),
        ];

        for (sql, expected_key, expected_value) in test_cases {
            let session_context = SessionContext::new();
            let mut client = MockClient::new();
            let statement = parse_statement(sql);

            let result =
                try_respond_set_statements(&mut client, &statement, &session_context).await;
            assert!(result.is_some(), "Expected Some for {sql}");
            assert!(result.unwrap().is_ok(), "Expected Ok for {sql}");

            let ps_msgs: Vec<_> = client
                .sent_messages()
                .iter()
                .filter_map(|m| match m {
                    PgWireBackendMessage::ParameterStatus(ps) => Some(ps),
                    _ => None,
                })
                .collect();

            assert_eq!(ps_msgs.len(), 1, "Expected 1 ParameterStatus for {sql}");
            assert_eq!(ps_msgs[0].name, expected_key, "Wrong key for {sql}");
            assert_eq!(ps_msgs[0].value, expected_value, "Wrong value for {sql}");
        }
    }

    /// `statement_timeout` is not a reported parameter, so no
    /// `ParameterStatus` message may be sent for it.
    #[tokio::test]
    async fn test_no_parameter_status_for_statement_timeout() {
        use pgwire::messages::PgWireBackendMessage;

        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let statement = parse_statement("set statement_timeout to '5000ms'");

        let result = try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(result.is_some());
        assert!(result.unwrap().is_ok());

        let has_ps = client
            .sent_messages()
            .iter()
            .any(|m| matches!(m, PgWireBackendMessage::ParameterStatus(_)));

        assert!(!has_ps, "statement_timeout should not send ParameterStatus");
    }

    /// Each supported `SHOW` form returns a `SELECT`-tagged response with a
    /// single column of the expected name.
    #[tokio::test]
    async fn test_supported_show_statements_returned_columns() {
        let session_context = SessionContext::new();
        let client = MockClient::new();

        let tests = [
            ("show time zone", "TimeZone"),
            ("show server_version", "server_version"),
            ("show transaction_isolation", "transaction_isolation"),
            ("show catalogs", "Catalogs"),
            ("show search_path", "search_path"),
            ("show statement_timeout", "statement_timeout"),
            ("show transaction isolation level", "transaction_isolation"),
        ];

        for (query, expected_response_col) in tests {
            let statement = parse_statement(query);
            let show_response =
                try_respond_show_statements(&client, &statement, &session_context).await;

            let Some(Ok(Response::Query(show_response))) = show_response else {
                panic!("unexpected show response");
            };

            assert_eq!(show_response.command_tag(), "SELECT");

            let row_schema = show_response.row_schema();
            assert_eq!(row_schema.len(), 1);
            assert_eq!(row_schema[0].name(), expected_response_col);
        }
    }
}