datafusion_postgres/hooks/
set_show.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::{DataType, Field, Schema};
5use datafusion::common::{ParamValues, ToDFSchema};
6use datafusion::error::DataFusionError;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::SessionContext;
9use datafusion::sql::sqlparser::ast::{Expr, Set, Statement};
10use log::{info, warn};
11use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
12use pgwire::api::ClientInfo;
13use pgwire::error::{PgWireError, PgWireResult};
14use postgres_types::Type;
15
16use crate::client;
17use crate::QueryHook;
18
/// Query hook that intercepts PostgreSQL `SET` and `SHOW` statements and
/// answers them directly (emulated settings are kept on the client; `SET`s that
/// are not emulated are forwarded to DataFusion). All other statements are
/// left to the next hook / the default execution path.
#[derive(Debug)]
pub struct SetShowHook;
21
22#[async_trait]
23impl QueryHook for SetShowHook {
24 async fn handle_simple_query(
26 &self,
27 statement: &Statement,
28 session_context: &SessionContext,
29 client: &mut (dyn ClientInfo + Send + Sync),
30 ) -> Option<PgWireResult<Response>> {
31 match statement {
32 Statement::Set { .. } => {
33 try_respond_set_statements(client, statement, session_context).await
34 }
35 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
36 try_respond_show_statements(client, statement, session_context).await
37 }
38 _ => None,
39 }
40 }
41
42 async fn handle_extended_parse_query(
43 &self,
44 stmt: &Statement,
45 _session_context: &SessionContext,
46 _client: &(dyn ClientInfo + Send + Sync),
47 ) -> Option<PgWireResult<LogicalPlan>> {
48 let sql_lower = stmt.to_string().to_lowercase();
49 let sql_trimmed = sql_lower.trim();
50
51 if sql_trimmed.starts_with("show") {
52 let show_schema =
53 Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
54 let result = show_schema
55 .to_dfschema()
56 .map(|df_schema| {
57 LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
58 produce_one_row: true,
59 schema: Arc::new(df_schema),
60 })
61 })
62 .map_err(|e| PgWireError::ApiError(Box::new(e)));
63 Some(result)
64 } else if sql_trimmed.starts_with("set") {
65 let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
66 let result = show_schema
67 .to_dfschema()
68 .map(|df_schema| {
69 LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
70 produce_one_row: true,
71 schema: Arc::new(df_schema),
72 })
73 })
74 .map_err(|e| PgWireError::ApiError(Box::new(e)));
75 Some(result)
76 } else {
77 None
78 }
79 }
80
81 async fn handle_extended_query(
82 &self,
83 statement: &Statement,
84 _logical_plan: &LogicalPlan,
85 _params: &ParamValues,
86 session_context: &SessionContext,
87 client: &mut (dyn ClientInfo + Send + Sync),
88 ) -> Option<PgWireResult<Response>> {
89 match statement {
90 Statement::Set { .. } => {
91 try_respond_set_statements(client, statement, session_context).await
92 }
93 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
94 try_respond_show_statements(client, statement, session_context).await
95 }
96 _ => None,
97 }
98 }
99}
100
101fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
102 let fields = vec![FieldInfo::new(
103 name.to_string(),
104 None,
105 None,
106 Type::VARCHAR,
107 FieldFormat::Text,
108 )];
109
110 let row = {
111 let mut encoder = DataRowEncoder::new(Arc::new(fields.clone()));
112 encoder.encode_field(&Some(value))?;
113 encoder.finish()
114 };
115
116 let row_stream = futures::stream::once(async move { row });
117 Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
118}
119
120async fn try_respond_set_statements<C>(
121 client: &mut C,
122 statement: &Statement,
123 session_context: &SessionContext,
124) -> Option<PgWireResult<Response>>
125where
126 C: ClientInfo + Send + Sync + ?Sized,
127{
128 let Statement::Set(set_statement) = statement else {
129 return None;
130 };
131
132 match &set_statement {
133 Set::SingleAssignment {
134 scope: None,
135 hivevar: false,
136 variable,
137 values,
138 } => {
139 let var = variable.to_string().to_lowercase();
140 if var == "statement_timeout" {
141 let value = values[0].to_string();
142 let timeout_str = value.trim_matches('"').trim_matches('\'');
143
144 let timeout = if timeout_str == "0" || timeout_str.is_empty() {
145 None
146 } else {
147 let timeout_ms = if timeout_str.ends_with("ms") {
149 timeout_str.trim_end_matches("ms").parse::<u64>()
150 } else if timeout_str.ends_with("s") {
151 timeout_str
152 .trim_end_matches("s")
153 .parse::<u64>()
154 .map(|s| s * 1000)
155 } else if timeout_str.ends_with("min") {
156 timeout_str
157 .trim_end_matches("min")
158 .parse::<u64>()
159 .map(|m| m * 60 * 1000)
160 } else {
161 timeout_str.parse::<u64>()
163 };
164
165 match timeout_ms {
166 Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
167 _ => None,
168 }
169 };
170
171 client::set_statement_timeout(client, timeout);
172 return Some(Ok(Response::Execution(Tag::new("SET"))));
173 } else if matches!(
174 var.as_str(),
175 "datestyle"
176 | "bytea_output"
177 | "intervalstyle"
178 | "application_name"
179 | "extra_float_digits"
180 | "search_path"
181 ) && !values.is_empty()
182 {
183 let value = values[0].clone();
185 if let Expr::Value(value) = value {
186 client
187 .metadata_mut()
188 .insert(var, value.into_string().unwrap_or_else(|| "".to_string()));
189 return Some(Ok(Response::Execution(Tag::new("SET"))));
190 }
191 }
192 }
193 Set::SetTimeZone {
194 local: false,
195 value,
196 } => {
197 let tz = value.to_string();
198 let tz = tz.trim_matches('"').trim_matches('\'');
199 client::set_timezone(client, Some(tz));
200 session_context
202 .state()
203 .config_mut()
204 .options_mut()
205 .execution
206 .time_zone = tz.to_string();
207 return Some(Ok(Response::Execution(Tag::new("SET"))));
208 }
209 _ => {}
210 }
211
212 if let Err(e) = execute_set_statement(session_context, statement.clone()).await {
214 warn!(
215 "SET statement {} is not supported by datafusion, error {e}, statement ignored",
216 statement
217 );
218 }
219
220 Some(Ok(Response::Execution(Tag::new("SET"))))
222}
223
224async fn execute_set_statement(
225 session_context: &SessionContext,
226 statement: Statement,
227) -> Result<(), DataFusionError> {
228 let state = session_context.state();
229 let logical_plan = state
230 .statement_to_plan(datafusion::sql::parser::Statement::Statement(Box::new(
231 statement,
232 )))
233 .await
234 .and_then(|logical_plan| state.optimize(&logical_plan))?;
235
236 session_context
237 .execute_logical_plan(logical_plan)
238 .await
239 .map(|_| ())
240}
241
242async fn try_respond_show_statements<C>(
243 client: &C,
244 statement: &Statement,
245 session_context: &SessionContext,
246) -> Option<PgWireResult<Response>>
247where
248 C: ClientInfo + ?Sized,
249{
250 let Statement::ShowVariable { variable } = statement else {
251 return None;
252 };
253
254 let variables = variable
255 .iter()
256 .map(|v| v.value.to_lowercase())
257 .collect::<Vec<_>>();
258 let variables_ref = variables.iter().map(|s| s.as_str()).collect::<Vec<_>>();
259
260 match variables_ref.as_slice() {
261 ["time", "zone"] => {
262 let timezone = client::get_timezone(client).unwrap_or("UTC");
263 Some(mock_show_response("TimeZone", timezone).map(Response::Query))
264 }
265 ["server_version"] => {
266 let version = format!(
267 "datafusion {} on {} {}",
268 session_context.state().version(),
269 env!("CARGO_PKG_NAME"),
270 env!("CARGO_PKG_VERSION")
271 );
272 Some(mock_show_response("server_version", &version).map(Response::Query))
273 }
274 ["transaction_isolation"] => Some(
275 mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query),
276 ),
277 ["catalogs"] => {
278 let catalogs = session_context.catalog_names();
279 let value = catalogs.join(", ");
280 Some(mock_show_response("Catalogs", &value).map(Response::Query))
281 }
282 ["statement_timeout"] => {
283 let timeout = client::get_statement_timeout(client);
284 let timeout_str = match timeout {
285 Some(duration) => format!("{}ms", duration.as_millis()),
286 None => "0".to_string(),
287 };
288 Some(mock_show_response("statement_timeout", &timeout_str).map(Response::Query))
289 }
290 ["transaction", "isolation", "level"] => {
291 Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
292 }
293 ["bytea_output"]
294 | ["datestyle"]
295 | ["intervalstyle"]
296 | ["application_name"]
297 | ["extra_float_digits"]
298 | ["search_path"] => {
299 let val = client
300 .metadata()
301 .get(&variables[0])
302 .map(|v| v.as_str())
303 .unwrap_or("");
304 Some(mock_show_response(&variables[0], val).map(Response::Query))
305 }
306 _ => {
307 info!("Unsupported show statement: {}", statement);
308 Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
309 }
310 }
311}
312
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    use super::*;
    use crate::testing::MockClient;

    /// Parses exactly one PostgreSQL statement, panicking on invalid SQL.
    fn parse_pg(sql: &str) -> Statement {
        Parser::new(&PostgreSqlDialect {})
            .try_with_sql(sql)
            .unwrap()
            .parse_statement()
            .unwrap()
    }

    /// Asserts that a hook claimed the statement and answered it successfully.
    fn assert_responded(response: Option<PgWireResult<Response>>) {
        assert!(
            matches!(response, Some(Ok(_))),
            "expected the hook to answer the statement successfully"
        );
    }

    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let set = parse_pg("set statement_timeout to '5000ms'");
        assert_responded(try_respond_set_statements(&mut client, &set, &session_context).await);
        assert_eq!(
            client::get_statement_timeout(&client),
            Some(Duration::from_millis(5000))
        );

        let show = parse_pg("show statement_timeout");
        assert_responded(try_respond_show_statements(&client, &show, &session_context).await);
    }

    #[tokio::test]
    async fn test_bytea_output_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let set = parse_pg("set bytea_output = 'hex'");
        assert_responded(try_respond_set_statements(&mut client, &set, &session_context).await);

        let bytea_output = client.metadata().get("bytea_output").unwrap();
        assert_eq!(bytea_output, "hex");

        let show = parse_pg("show bytea_output");
        assert_responded(try_respond_show_statements(&client, &show, &session_context).await);
    }

    #[tokio::test]
    async fn test_date_style_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let set = parse_pg("set dateStyle = 'ISO, DMY'");
        assert_responded(try_respond_set_statements(&mut client, &set, &session_context).await);

        let date_style = client.metadata().get("datestyle").unwrap();
        assert_eq!(date_style, "ISO, DMY");

        let show = parse_pg("show dateStyle");
        assert_responded(try_respond_show_statements(&client, &show, &session_context).await);
    }

    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        for sql in ["set statement_timeout to '1000ms'", "set statement_timeout to '0'"] {
            let set = parse_pg(sql);
            assert_responded(
                try_respond_set_statements(&mut client, &set, &session_context).await,
            );
        }

        assert_eq!(client::get_statement_timeout(&client), None);
    }

    #[tokio::test]
    async fn test_supported_show_statements_returned_columns() {
        let session_context = SessionContext::new();
        let client = MockClient::new();

        let cases = [
            ("show time zone", "TimeZone"),
            ("show server_version", "server_version"),
            ("show transaction_isolation", "transaction_isolation"),
            ("show catalogs", "Catalogs"),
            ("show search_path", "search_path"),
            ("show statement_timeout", "statement_timeout"),
            ("show transaction isolation level", "transaction_isolation"),
        ];

        for (query, expected_column) in cases {
            let statement = parse_pg(query);
            let Some(Ok(Response::Query(response))) =
                try_respond_show_statements(&client, &statement, &session_context).await
            else {
                panic!("unexpected show response");
            };

            assert_eq!(response.command_tag(), "SELECT");

            let row_schema = response.row_schema();
            assert_eq!(row_schema.len(), 1);
            assert_eq!(row_schema[0].name(), expected_column);
        }
    }
}