datafusion_postgres/hooks/
set_show.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::{DataType, Field, Schema};
5use datafusion::common::{ParamValues, ToDFSchema};
6use datafusion::error::DataFusionError;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::SessionContext;
9use datafusion::sql::sqlparser::ast::{Expr, Set, Statement};
10use log::{info, warn};
11use pgwire::api::auth::DefaultServerParameterProvider;
12use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
13use pgwire::api::ClientInfo;
14use pgwire::error::{PgWireError, PgWireResult};
15use pgwire::types::format::FormatOptions;
16use postgres_types::Type;
17
18use crate::client;
19use crate::QueryHook;
20
/// [`QueryHook`] that answers PostgreSQL `SET` and `SHOW` statements itself,
/// recording client-facing settings (statement timeout, date/bytea styles,
/// time zone, ...) instead of relying on DataFusion to understand them.
///
/// The hook carries no state, so it is trivially copyable and can be built
/// with `SetShowHook::default()` as well as the unit literal `SetShowHook`.
#[derive(Debug, Clone, Copy, Default)]
pub struct SetShowHook;
23
24#[async_trait]
25impl QueryHook for SetShowHook {
26 async fn handle_simple_query(
28 &self,
29 statement: &Statement,
30 session_context: &SessionContext,
31 client: &mut (dyn ClientInfo + Send + Sync),
32 ) -> Option<PgWireResult<Response>> {
33 match statement {
34 Statement::Set { .. } => {
35 try_respond_set_statements(client, statement, session_context).await
36 }
37 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
38 try_respond_show_statements(client, statement, session_context).await
39 }
40 _ => None,
41 }
42 }
43
44 async fn handle_extended_parse_query(
45 &self,
46 stmt: &Statement,
47 _session_context: &SessionContext,
48 _client: &(dyn ClientInfo + Send + Sync),
49 ) -> Option<PgWireResult<LogicalPlan>> {
50 match stmt {
51 Statement::Set { .. } => {
52 let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
53 let result = show_schema
54 .to_dfschema()
55 .map(|df_schema| {
56 LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
57 produce_one_row: true,
58 schema: Arc::new(df_schema),
59 })
60 })
61 .map_err(|e| PgWireError::ApiError(Box::new(e)));
62 Some(result)
63 }
64 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
65 let show_schema =
66 Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
67 let result = show_schema
68 .to_dfschema()
69 .map(|df_schema| {
70 LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
71 produce_one_row: true,
72 schema: Arc::new(df_schema),
73 })
74 })
75 .map_err(|e| PgWireError::ApiError(Box::new(e)));
76 Some(result)
77 }
78 _ => None,
79 }
80 }
81
82 async fn handle_extended_query(
83 &self,
84 statement: &Statement,
85 _logical_plan: &LogicalPlan,
86 _params: &ParamValues,
87 session_context: &SessionContext,
88 client: &mut (dyn ClientInfo + Send + Sync),
89 ) -> Option<PgWireResult<Response>> {
90 match statement {
91 Statement::Set { .. } => {
92 try_respond_set_statements(client, statement, session_context).await
93 }
94 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
95 try_respond_show_statements(client, statement, session_context).await
96 }
97 _ => None,
98 }
99 }
100}
101
102fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
103 let fields = vec![FieldInfo::new(
104 name.to_string(),
105 None,
106 None,
107 Type::VARCHAR,
108 FieldFormat::Text,
109 )];
110
111 let row = {
112 let mut encoder = DataRowEncoder::new(Arc::new(fields.clone()));
113 encoder.encode_field(&Some(value))?;
114 Ok(encoder.take_row())
115 };
116
117 let row_stream = futures::stream::once(async move { row });
118 Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
119}
120
121async fn try_respond_set_statements<C>(
122 client: &mut C,
123 statement: &Statement,
124 session_context: &SessionContext,
125) -> Option<PgWireResult<Response>>
126where
127 C: ClientInfo + Send + Sync + ?Sized,
128{
129 let Statement::Set(set_statement) = statement else {
130 return None;
131 };
132
133 match &set_statement {
134 Set::SingleAssignment {
135 scope: None,
136 hivevar: false,
137 variable,
138 values,
139 } => {
140 let var = variable.to_string().to_lowercase();
141 if var == "statement_timeout" {
142 let value = values[0].to_string();
143 let timeout_str = value.trim_matches('"').trim_matches('\'');
144
145 let timeout = if timeout_str == "0" || timeout_str.is_empty() {
146 None
147 } else {
148 let timeout_ms = if timeout_str.ends_with("ms") {
150 timeout_str.trim_end_matches("ms").parse::<u64>()
151 } else if timeout_str.ends_with("s") {
152 timeout_str
153 .trim_end_matches("s")
154 .parse::<u64>()
155 .map(|s| s * 1000)
156 } else if timeout_str.ends_with("min") {
157 timeout_str
158 .trim_end_matches("min")
159 .parse::<u64>()
160 .map(|m| m * 60 * 1000)
161 } else {
162 timeout_str.parse::<u64>()
164 };
165
166 match timeout_ms {
167 Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
168 _ => None,
169 }
170 };
171
172 client::set_statement_timeout(client, timeout);
173 return Some(Ok(Response::Execution(Tag::new("SET"))));
174 } else if matches!(
175 var.as_str(),
176 "datestyle"
177 | "bytea_output"
178 | "intervalstyle"
179 | "application_name"
180 | "extra_float_digits"
181 | "search_path"
182 ) && !values.is_empty()
183 {
184 let value = values[0].clone();
186 if let Expr::Value(value) = value {
187 client
188 .metadata_mut()
189 .insert(var, value.into_string().unwrap_or_else(|| "".to_string()));
190 return Some(Ok(Response::Execution(Tag::new("SET"))));
191 }
192 }
193 }
194 Set::SetTimeZone {
195 local: false,
196 value,
197 } => {
198 let tz = value.to_string();
199 let tz = tz.trim_matches('"').trim_matches('\'');
200 client::set_timezone(client, Some(tz));
201 session_context
203 .state()
204 .config_mut()
205 .options_mut()
206 .execution
207 .time_zone = Some(tz.to_string());
208 return Some(Ok(Response::Execution(Tag::new("SET"))));
209 }
210 _ => {}
211 }
212
213 if let Err(e) = execute_set_statement(session_context, statement.clone()).await {
215 warn!(
216 "SET statement {statement} is not supported by datafusion, error {e}, statement ignored",
217 );
218 }
219
220 Some(Ok(Response::Execution(Tag::new("SET"))))
222}
223
224async fn execute_set_statement(
225 session_context: &SessionContext,
226 statement: Statement,
227) -> Result<(), DataFusionError> {
228 let state = session_context.state();
229 let logical_plan = state
230 .statement_to_plan(datafusion::sql::parser::Statement::Statement(Box::new(
231 statement,
232 )))
233 .await
234 .and_then(|logical_plan| state.optimize(&logical_plan))?;
235
236 session_context
237 .execute_logical_plan(logical_plan)
238 .await
239 .map(|_| ())
240}
241
242async fn try_respond_show_statements<C>(
243 client: &C,
244 statement: &Statement,
245 session_context: &SessionContext,
246) -> Option<PgWireResult<Response>>
247where
248 C: ClientInfo + ?Sized,
249{
250 let Statement::ShowVariable { variable } = statement else {
251 return None;
252 };
253
254 let variables = variable
255 .iter()
256 .map(|v| v.value.to_lowercase())
257 .collect::<Vec<_>>();
258 let variables_ref = variables.iter().map(|s| s.as_str()).collect::<Vec<_>>();
259
260 match variables_ref.as_slice() {
261 ["time", "zone"] => {
262 let timezone = client::get_timezone(client).unwrap_or("UTC");
263 Some(mock_show_response("TimeZone", timezone).map(Response::Query))
264 }
265 ["server_version"] => {
266 let version = format!(
267 "datafusion {} on {} {}",
268 session_context.state().version(),
269 env!("CARGO_PKG_NAME"),
270 env!("CARGO_PKG_VERSION")
271 );
272 Some(mock_show_response("server_version", &version).map(Response::Query))
273 }
274 ["transaction_isolation"] => Some(
275 mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query),
276 ),
277 ["catalogs"] => {
278 let catalogs = session_context.catalog_names();
279 let value = catalogs.join(", ");
280 Some(mock_show_response("Catalogs", &value).map(Response::Query))
281 }
282 ["statement_timeout"] => {
283 let timeout = client::get_statement_timeout(client);
284 let timeout_str = match timeout {
285 Some(duration) => format!("{}ms", duration.as_millis()),
286 None => "0".to_string(),
287 };
288 Some(mock_show_response("statement_timeout", &timeout_str).map(Response::Query))
289 }
290 ["transaction", "isolation", "level"] => {
291 Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
292 }
293 _ => {
294 let val = client
295 .metadata()
296 .get(&variables[0])
297 .map(|v| v.to_string())
298 .or_else(|| match variables[0].as_str() {
299 "bytea_output" => Some(FormatOptions::default().bytea_output),
300 "datestyle" => Some(FormatOptions::default().date_style),
301 "intervalstyle" => Some(FormatOptions::default().interval_style),
302 "extra_float_digits" => {
303 Some(FormatOptions::default().extra_float_digits.to_string())
304 }
305 "application_name" => Some(
306 DefaultServerParameterProvider::default()
307 .application_name
308 .unwrap_or("".to_owned()),
309 ),
310 "search_path" => Some(DefaultServerParameterProvider::default().search_path),
311 _ => None,
312 });
313 if let Some(val) = val {
314 Some(mock_show_response(&variables[0], &val).map(Response::Query))
315 } else {
316 info!("Unsupported show statement: {statement}");
317 Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
318 }
319 }
320 }
321}
322
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    use super::*;
    use crate::testing::MockClient;

    /// Parses exactly one PostgreSQL statement, panicking on a syntax error.
    fn parse(sql: &str) -> Statement {
        Parser::new(&PostgreSqlDialect {})
            .try_with_sql(sql)
            .unwrap()
            .parse_statement()
            .unwrap()
    }

    /// Asserts that a hook claimed the statement and answered without error.
    fn assert_handled(response: Option<PgWireResult<Response>>) {
        let response = response.expect("statement was not handled by the hook");
        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let set = parse("set statement_timeout to '5000ms'");
        assert_handled(try_respond_set_statements(&mut client, &set, &session_context).await);
        assert_eq!(
            client::get_statement_timeout(&client),
            Some(Duration::from_millis(5000))
        );

        let show = parse("show statement_timeout");
        assert_handled(try_respond_show_statements(&client, &show, &session_context).await);
    }

    #[tokio::test]
    async fn test_bytea_output_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let set = parse("set bytea_output = 'hex'");
        assert_handled(try_respond_set_statements(&mut client, &set, &session_context).await);
        let bytea_output = client.metadata().get("bytea_output").unwrap();
        assert_eq!(bytea_output, "hex");

        let show = parse("show bytea_output");
        assert_handled(try_respond_show_statements(&client, &show, &session_context).await);
    }

    #[tokio::test]
    async fn test_date_style_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let set = parse("set dateStyle = 'ISO, DMY'");
        assert_handled(try_respond_set_statements(&mut client, &set, &session_context).await);
        let date_style = client.metadata().get("datestyle").unwrap();
        assert_eq!(date_style, "ISO, DMY");

        let show = parse("show dateStyle");
        assert_handled(try_respond_show_statements(&client, &show, &session_context).await);
    }

    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let enable = parse("set statement_timeout to '1000ms'");
        assert_handled(try_respond_set_statements(&mut client, &enable, &session_context).await);

        let disable = parse("set statement_timeout to '0'");
        assert_handled(try_respond_set_statements(&mut client, &disable, &session_context).await);

        assert_eq!(client::get_statement_timeout(&client), None);
    }

    #[tokio::test]
    async fn test_supported_show_statements_returned_columns() {
        let session_context = SessionContext::new();
        let client = MockClient::new();

        let cases = [
            ("show time zone", "TimeZone"),
            ("show server_version", "server_version"),
            ("show transaction_isolation", "transaction_isolation"),
            ("show catalogs", "Catalogs"),
            ("show search_path", "search_path"),
            ("show statement_timeout", "statement_timeout"),
            ("show transaction isolation level", "transaction_isolation"),
        ];

        for (query, expected_column) in cases {
            let statement = parse(query);
            let Some(Ok(Response::Query(show_response))) =
                try_respond_show_statements(&client, &statement, &session_context).await
            else {
                panic!("unexpected show response");
            };

            assert_eq!(show_response.command_tag(), "SELECT");

            let row_schema = show_response.row_schema();
            assert_eq!(row_schema.len(), 1);
            assert_eq!(row_schema[0].name(), expected_column);
        }
    }
}