1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::{DataType, Field, Schema};
5use datafusion::common::{ParamValues, ToDFSchema};
6use datafusion::error::DataFusionError;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::SessionContext;
9use datafusion::sql::sqlparser::ast::{Expr, Set, Statement};
10use log::{info, warn};
11use pgwire::api::auth::DefaultServerParameterProvider;
12use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
13use pgwire::api::ClientInfo;
14use pgwire::error::{PgWireError, PgWireResult};
15use pgwire::messages::startup::ParameterStatus;
16use pgwire::messages::PgWireBackendMessage;
17use pgwire::types::format::FormatOptions;
18use postgres_types::Type;
19
20use crate::client;
21use crate::hooks::HookClient;
22use crate::QueryHook;
23
24#[derive(Debug)]
25pub struct SetShowHook;
26
27#[async_trait]
28impl QueryHook for SetShowHook {
29 async fn handle_simple_query(
31 &self,
32 statement: &Statement,
33 session_context: &SessionContext,
34 client: &mut dyn HookClient,
35 ) -> Option<PgWireResult<Response>> {
36 match statement {
37 Statement::Set { .. } => {
38 try_respond_set_statements(client, statement, session_context).await
39 }
40 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
41 try_respond_show_statements(client, statement, session_context).await
42 }
43 _ => None,
44 }
45 }
46
47 async fn handle_extended_parse_query(
48 &self,
49 stmt: &Statement,
50 _session_context: &SessionContext,
51 _client: &(dyn ClientInfo + Send + Sync),
52 ) -> Option<PgWireResult<LogicalPlan>> {
53 match stmt {
54 Statement::Set { .. } => {
55 let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
56 let result = show_schema
57 .to_dfschema()
58 .map(|df_schema| {
59 LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
60 produce_one_row: true,
61 schema: Arc::new(df_schema),
62 })
63 })
64 .map_err(|e| PgWireError::ApiError(Box::new(e)));
65 Some(result)
66 }
67 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
68 let show_schema =
69 Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
70 let result = show_schema
71 .to_dfschema()
72 .map(|df_schema| {
73 LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
74 produce_one_row: true,
75 schema: Arc::new(df_schema),
76 })
77 })
78 .map_err(|e| PgWireError::ApiError(Box::new(e)));
79 Some(result)
80 }
81 _ => None,
82 }
83 }
84
85 async fn handle_extended_query(
86 &self,
87 statement: &Statement,
88 _logical_plan: &LogicalPlan,
89 _params: &ParamValues,
90 session_context: &SessionContext,
91 client: &mut dyn HookClient,
92 ) -> Option<PgWireResult<Response>> {
93 match statement {
94 Statement::Set { .. } => {
95 try_respond_set_statements(client, statement, session_context).await
96 }
97 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
98 try_respond_show_statements(client, statement, session_context).await
99 }
100 _ => None,
101 }
102 }
103}
104
105fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
106 let fields = vec![FieldInfo::new(
107 name.to_string(),
108 None,
109 None,
110 Type::VARCHAR,
111 FieldFormat::Text,
112 )];
113
114 let row = {
115 let mut encoder = DataRowEncoder::new(Arc::new(fields.clone()));
116 encoder.encode_field(&Some(value))?;
117 Ok(encoder.take_row())
118 };
119
120 let row_stream = futures::stream::once(async move { row });
121 Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
122}
123
124async fn try_respond_set_statements(
125 client: &mut dyn HookClient,
126 statement: &Statement,
127 session_context: &SessionContext,
128) -> Option<PgWireResult<Response>> {
129 let Statement::Set(set_statement) = statement else {
130 return None;
131 };
132
133 match &set_statement {
134 Set::SingleAssignment {
135 scope: None,
136 hivevar: false,
137 variable,
138 values,
139 } => {
140 let var = variable.to_string().to_lowercase();
141 if var == "statement_timeout" {
142 let value = values[0].to_string();
143 let timeout_str = value.trim_matches('"').trim_matches('\'');
144
145 let timeout = if timeout_str == "0" || timeout_str.is_empty() {
146 None
147 } else {
148 let timeout_ms = if timeout_str.ends_with("ms") {
150 timeout_str.trim_end_matches("ms").parse::<u64>()
151 } else if timeout_str.ends_with("s") {
152 timeout_str
153 .trim_end_matches("s")
154 .parse::<u64>()
155 .map(|s| s * 1000)
156 } else if timeout_str.ends_with("min") {
157 timeout_str
158 .trim_end_matches("min")
159 .parse::<u64>()
160 .map(|m| m * 60 * 1000)
161 } else {
162 timeout_str.parse::<u64>()
164 };
165
166 match timeout_ms {
167 Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
168 _ => None,
169 }
170 };
171
172 client::set_statement_timeout(client, timeout);
173 return Some(Ok(Response::Execution(Tag::new("SET"))));
174 } else if matches!(
175 var.as_str(),
176 "datestyle"
177 | "bytea_output"
178 | "intervalstyle"
179 | "application_name"
180 | "extra_float_digits"
181 | "search_path"
182 ) && !values.is_empty()
183 {
184 let value = values[0].clone();
186 if let Expr::Value(value) = value {
187 let val_str = value.into_string().unwrap_or_else(|| "".to_string());
188 client.metadata_mut().insert(var.clone(), val_str);
189 if let Some((name, value)) = parameter_status_for_var(&var, &*client) {
190 if let Err(e) = client
191 .send_message(PgWireBackendMessage::ParameterStatus(
192 ParameterStatus::new(name, value),
193 ))
194 .await
195 {
196 return Some(Err(e));
197 }
198 }
199 return Some(Ok(Response::Execution(Tag::new("SET"))));
200 }
201 }
202 }
203 Set::SetTimeZone {
204 local: false,
205 value,
206 } => {
207 let tz = value.to_string();
208 let tz = tz.trim_matches('"').trim_matches('\'');
209 client::set_timezone(client, Some(tz));
210 session_context
212 .state()
213 .config_mut()
214 .options_mut()
215 .execution
216 .time_zone = Some(tz.to_string());
217 let tz_value = client::get_timezone(client).unwrap_or("UTC").to_string();
218 if let Err(e) = client
219 .send_message(PgWireBackendMessage::ParameterStatus(ParameterStatus::new(
220 "TimeZone".to_string(),
221 tz_value,
222 )))
223 .await
224 {
225 return Some(Err(e));
226 }
227 return Some(Ok(Response::Execution(Tag::new("SET"))));
228 }
229 _ => {}
230 }
231
232 if let Err(e) = execute_set_statement(session_context, statement.clone()).await {
234 warn!(
235 "SET statement {statement} is not supported by datafusion, error {e}, statement ignored",
236 );
237 }
238
239 Some(Ok(Response::Execution(Tag::new("SET"))))
241}
242
243fn parameter_status_for_var(
244 var: &str,
245 client: &(impl ClientInfo + ?Sized),
246) -> Option<(String, String)> {
247 let display_name = match var {
248 "datestyle" => "DateStyle",
249 "intervalstyle" => "IntervalStyle",
250 "bytea_output" => "bytea_output",
251 "application_name" => "application_name",
252 "extra_float_digits" => "extra_float_digits",
253 "search_path" => "search_path",
254 _ => return None,
255 };
256 let value = client.metadata().get(var)?.clone();
257 Some((display_name.to_string(), value))
258}
259
260async fn execute_set_statement(
261 session_context: &SessionContext,
262 statement: Statement,
263) -> Result<(), DataFusionError> {
264 let state = session_context.state();
265 let logical_plan = state
266 .statement_to_plan(datafusion::sql::parser::Statement::Statement(Box::new(
267 statement,
268 )))
269 .await
270 .and_then(|logical_plan| state.optimize(&logical_plan))?;
271
272 session_context
273 .execute_logical_plan(logical_plan)
274 .await
275 .map(|_| ())
276}
277
278async fn try_respond_show_statements(
279 client: &dyn HookClient,
280 statement: &Statement,
281 session_context: &SessionContext,
282) -> Option<PgWireResult<Response>> {
283 let Statement::ShowVariable { variable } = statement else {
284 return None;
285 };
286
287 let variables = variable
288 .iter()
289 .map(|v| v.value.to_lowercase())
290 .collect::<Vec<_>>();
291 let variables_ref = variables.iter().map(|s| s.as_str()).collect::<Vec<_>>();
292
293 match variables_ref.as_slice() {
294 ["time", "zone"] => {
295 let timezone = client::get_timezone(client).unwrap_or("UTC");
296 Some(mock_show_response("TimeZone", timezone).map(Response::Query))
297 }
298 ["server_version"] => {
299 let version = format!(
300 "datafusion {} on {} {}",
301 session_context.state().version(),
302 env!("CARGO_PKG_NAME"),
303 env!("CARGO_PKG_VERSION")
304 );
305 Some(mock_show_response("server_version", &version).map(Response::Query))
306 }
307 ["transaction_isolation"] => Some(
308 mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query),
309 ),
310 ["catalogs"] => {
311 let catalogs = session_context.catalog_names();
312 let value = catalogs.join(", ");
313 Some(mock_show_response("Catalogs", &value).map(Response::Query))
314 }
315 ["statement_timeout"] => {
316 let timeout = client::get_statement_timeout(client);
317 let timeout_str = match timeout {
318 Some(duration) => format!("{}ms", duration.as_millis()),
319 None => "0".to_string(),
320 };
321 Some(mock_show_response("statement_timeout", &timeout_str).map(Response::Query))
322 }
323 ["transaction", "isolation", "level"] => {
324 Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
325 }
326 _ => {
327 let val = client
328 .metadata()
329 .get(&variables[0])
330 .map(|v| v.to_string())
331 .or_else(|| match variables[0].as_str() {
332 "bytea_output" => Some(FormatOptions::default().bytea_output),
333 "datestyle" => Some(FormatOptions::default().date_style),
334 "intervalstyle" => Some(FormatOptions::default().interval_style),
335 "extra_float_digits" => {
336 Some(FormatOptions::default().extra_float_digits.to_string())
337 }
338 "application_name" => Some(
339 DefaultServerParameterProvider::default()
340 .application_name
341 .unwrap_or("".to_owned()),
342 ),
343 "search_path" => Some(DefaultServerParameterProvider::default().search_path),
344 _ => None,
345 });
346 if let Some(val) = val {
347 Some(mock_show_response(&variables[0], &val).map(Response::Query))
348 } else {
349 info!("Unsupported show statement: {statement}");
350 Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
351 }
352 }
353 }
354}
355
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    use super::*;
    use crate::testing::MockClient;

    /// Parses exactly one PostgreSQL statement, panicking on malformed SQL.
    fn parse(sql: &str) -> Statement {
        Parser::new(&PostgreSqlDialect {})
            .try_with_sql(sql)
            .unwrap()
            .parse_statement()
            .unwrap()
    }

    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let set = parse("set statement_timeout to '5000ms'");
        let set_response = try_respond_set_statements(&mut client, &set, &session_context).await;
        assert!(matches!(set_response, Some(Ok(_))));
        assert_eq!(
            client::get_statement_timeout(&client),
            Some(Duration::from_millis(5000))
        );

        let show = parse("show statement_timeout");
        let show_response = try_respond_show_statements(&client, &show, &session_context).await;
        assert!(matches!(show_response, Some(Ok(_))));
    }

    #[tokio::test]
    async fn test_bytea_output_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let set = parse("set bytea_output = 'hex'");
        let set_response = try_respond_set_statements(&mut client, &set, &session_context).await;
        assert!(matches!(set_response, Some(Ok(_))));
        assert_eq!(
            client.metadata().get("bytea_output").map(String::as_str),
            Some("hex")
        );

        let show = parse("show bytea_output");
        let show_response = try_respond_show_statements(&client, &show, &session_context).await;
        assert!(matches!(show_response, Some(Ok(_))));
    }

    #[tokio::test]
    async fn test_date_style_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let set = parse("set dateStyle = 'ISO, DMY'");
        let set_response = try_respond_set_statements(&mut client, &set, &session_context).await;
        assert!(matches!(set_response, Some(Ok(_))));
        assert_eq!(
            client.metadata().get("datestyle").map(String::as_str),
            Some("ISO, DMY")
        );

        let show = parse("show dateStyle");
        let show_response = try_respond_show_statements(&client, &show, &session_context).await;
        assert!(matches!(show_response, Some(Ok(_))));
    }

    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        for sql in ["set statement_timeout to '1000ms'", "set statement_timeout to '0'"] {
            let statement = parse(sql);
            let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
            assert!(matches!(resp, Some(Ok(_))));
        }

        assert_eq!(client::get_statement_timeout(&client), None);
    }

    #[tokio::test]
    async fn test_parameter_status_sent_for_all_set_vars() {
        let test_cases = [
            ("set bytea_output = 'escape'", "bytea_output", "escape"),
            ("set intervalstyle = 'postgres'", "IntervalStyle", "postgres"),
            ("set application_name = 'myapp'", "application_name", "myapp"),
            ("set search_path = 'public'", "search_path", "public"),
            ("set extra_float_digits = '2'", "extra_float_digits", "2"),
            ("set datestyle = 'ISO, MDY'", "DateStyle", "ISO, MDY"),
            (
                "set time zone 'America/New_York'",
                "TimeZone",
                "America/New_York",
            ),
        ];

        for (sql, expected_key, expected_value) in test_cases {
            let session_context = SessionContext::new();
            let mut client = MockClient::new();
            let statement = parse(sql);

            let result =
                try_respond_set_statements(&mut client, &statement, &session_context).await;
            assert!(result.is_some(), "Expected Some for {sql}");
            assert!(result.unwrap().is_ok(), "Expected Ok for {sql}");

            let ps_msgs: Vec<_> = client
                .sent_messages()
                .iter()
                .filter_map(|msg| {
                    if let PgWireBackendMessage::ParameterStatus(ps) = msg {
                        Some(ps)
                    } else {
                        None
                    }
                })
                .collect();

            assert_eq!(ps_msgs.len(), 1, "Expected 1 ParameterStatus for {sql}");
            assert_eq!(ps_msgs[0].name, expected_key, "Wrong key for {sql}");
            assert_eq!(ps_msgs[0].value, expected_value, "Wrong value for {sql}");
        }
    }

    #[tokio::test]
    async fn test_no_parameter_status_for_statement_timeout() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let statement = parse("set statement_timeout to '5000ms'");
        let result = try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(matches!(result, Some(Ok(_))));

        let has_ps = client
            .sent_messages()
            .iter()
            .any(|m| matches!(m, PgWireBackendMessage::ParameterStatus(_)));
        assert!(!has_ps, "statement_timeout should not send ParameterStatus");
    }

    #[tokio::test]
    async fn test_supported_show_statements_returned_columns() {
        let session_context = SessionContext::new();
        let client = MockClient::new();

        let tests = [
            ("show time zone", "TimeZone"),
            ("show server_version", "server_version"),
            ("show transaction_isolation", "transaction_isolation"),
            ("show catalogs", "Catalogs"),
            ("show search_path", "search_path"),
            ("show statement_timeout", "statement_timeout"),
            ("show transaction isolation level", "transaction_isolation"),
        ];

        for (query, expected_response_col) in tests {
            let statement = parse(query);
            let show_response =
                try_respond_show_statements(&client, &statement, &session_context).await;

            let Some(Ok(Response::Query(show_response))) = show_response else {
                panic!("unexpected show response");
            };

            assert_eq!(show_response.command_tag(), "SELECT");

            let row_schema = show_response.row_schema();
            assert_eq!(row_schema.len(), 1);
            assert_eq!(row_schema[0].name(), expected_response_col);
        }
    }
}