datafusion_postgres/hooks/
set_show.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::{DataType, Field, Schema};
5use datafusion::common::{ParamValues, ToDFSchema};
6use datafusion::logical_expr::LogicalPlan;
7use datafusion::prelude::SessionContext;
8use datafusion::sql::sqlparser::ast::{Set, Statement};
9use log::{info, warn};
10use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
11use pgwire::api::ClientInfo;
12use pgwire::error::{PgWireError, PgWireResult};
13use postgres_types::Type;
14
15use crate::client;
16use crate::QueryHook;
17
/// Query hook that answers PostgreSQL `SET` and `SHOW` statements directly
/// (statement timeout, time zone and a handful of mocked server settings)
/// instead of letting DataFusion plan them.
///
/// Stateless unit struct; all per-connection state lives on the client.
#[derive(Debug)]
pub struct SetShowHook;
20
21#[async_trait]
22impl QueryHook for SetShowHook {
23 async fn handle_simple_query(
25 &self,
26 statement: &Statement,
27 session_context: &SessionContext,
28 client: &mut (dyn ClientInfo + Send + Sync),
29 ) -> Option<PgWireResult<Response>> {
30 match statement {
31 Statement::Set { .. } => {
32 try_respond_set_statements(client, statement, session_context).await
33 }
34 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
35 try_respond_show_statements(client, statement, session_context).await
36 }
37 _ => None,
38 }
39 }
40
41 async fn handle_extended_parse_query(
42 &self,
43 stmt: &Statement,
44 _session_context: &SessionContext,
45 _client: &(dyn ClientInfo + Send + Sync),
46 ) -> Option<PgWireResult<LogicalPlan>> {
47 let sql_lower = stmt.to_string().to_lowercase();
48 let sql_trimmed = sql_lower.trim();
49
50 if sql_trimmed.starts_with("show") {
51 let show_schema =
52 Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
53 let result = show_schema
54 .to_dfschema()
55 .map(|df_schema| {
56 LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
57 produce_one_row: true,
58 schema: Arc::new(df_schema),
59 })
60 })
61 .map_err(|e| PgWireError::ApiError(Box::new(e)));
62 Some(result)
63 } else if sql_trimmed.starts_with("set") {
64 let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
65 let result = show_schema
66 .to_dfschema()
67 .map(|df_schema| {
68 LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
69 produce_one_row: true,
70 schema: Arc::new(df_schema),
71 })
72 })
73 .map_err(|e| PgWireError::ApiError(Box::new(e)));
74 Some(result)
75 } else {
76 None
77 }
78 }
79
80 async fn handle_extended_query(
81 &self,
82 statement: &Statement,
83 _logical_plan: &LogicalPlan,
84 _params: &ParamValues,
85 session_context: &SessionContext,
86 client: &mut (dyn ClientInfo + Send + Sync),
87 ) -> Option<PgWireResult<Response>> {
88 match statement {
89 Statement::Set { .. } => {
90 try_respond_set_statements(client, statement, session_context).await
91 }
92 Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
93 try_respond_show_statements(client, statement, session_context).await
94 }
95 _ => None,
96 }
97 }
98}
99
100fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
101 let fields = vec![FieldInfo::new(
102 name.to_string(),
103 None,
104 None,
105 Type::VARCHAR,
106 FieldFormat::Text,
107 )];
108
109 let row = {
110 let mut encoder = DataRowEncoder::new(Arc::new(fields.clone()));
111 encoder.encode_field(&Some(value))?;
112 encoder.finish()
113 };
114
115 let row_stream = futures::stream::once(async move { row });
116 Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
117}
118
119async fn try_respond_set_statements<C>(
120 client: &mut C,
121 statement: &Statement,
122 session_context: &SessionContext,
123) -> Option<PgWireResult<Response>>
124where
125 C: ClientInfo + Send + Sync + ?Sized,
126{
127 let Statement::Set(set_statement) = statement else {
128 return None;
129 };
130
131 match &set_statement {
132 Set::SingleAssignment {
133 scope: None,
134 hivevar: false,
135 variable,
136 values,
137 } if &variable.to_string() == "statement_timeout" => {
138 let value = values[0].to_string();
139 let timeout_str = value.trim_matches('"').trim_matches('\'');
140
141 let timeout = if timeout_str == "0" || timeout_str.is_empty() {
142 None
143 } else {
144 let timeout_ms = if timeout_str.ends_with("ms") {
146 timeout_str.trim_end_matches("ms").parse::<u64>()
147 } else if timeout_str.ends_with("s") {
148 timeout_str
149 .trim_end_matches("s")
150 .parse::<u64>()
151 .map(|s| s * 1000)
152 } else if timeout_str.ends_with("min") {
153 timeout_str
154 .trim_end_matches("min")
155 .parse::<u64>()
156 .map(|m| m * 60 * 1000)
157 } else {
158 timeout_str.parse::<u64>()
160 };
161
162 match timeout_ms {
163 Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
164 _ => None,
165 }
166 };
167
168 client::set_statement_timeout(client, timeout);
169 Some(Ok(Response::Execution(Tag::new("SET"))))
170 }
171 Set::SetTimeZone {
172 local: false,
173 value,
174 } => {
175 let tz = value.to_string();
176 let tz = tz.trim_matches('"').trim_matches('\'');
177 client::set_timezone(client, Some(tz));
178 Some(Ok(Response::Execution(Tag::new("SET"))))
179 }
180 _ => {
181 let query = statement.to_string();
183 if let Err(e) = session_context.sql(&query).await {
184 warn!("SET statement {query} is not supported by datafusion, error {e}, statement ignored");
185 }
186
187 Some(Ok(Response::Execution(Tag::new("SET"))))
189 }
190 }
191}
192
193async fn try_respond_show_statements<C>(
194 client: &C,
195 statement: &Statement,
196 session_context: &SessionContext,
197) -> Option<PgWireResult<Response>>
198where
199 C: ClientInfo + ?Sized,
200{
201 let Statement::ShowVariable { variable } = statement else {
202 return None;
203 };
204
205 let variables = variable
206 .iter()
207 .map(|v| &v.value as &str)
208 .collect::<Vec<_>>();
209
210 match &variables as &[&str] {
211 ["time", "zone"] => {
212 let timezone = client::get_timezone(client).unwrap_or("UTC");
213 Some(mock_show_response("TimeZone", timezone).map(Response::Query))
214 }
215 ["server_version"] => {
216 Some(mock_show_response("server_version", "15.0 (DataFusion)").map(Response::Query))
217 }
218 ["transaction_isolation"] => Some(
219 mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query),
220 ),
221 ["catalogs"] => {
222 let catalogs = session_context.catalog_names();
223 let value = catalogs.join(", ");
224 Some(mock_show_response("Catalogs", &value).map(Response::Query))
225 }
226 ["search_path"] => {
227 let default_schema = "public";
228 Some(mock_show_response("search_path", default_schema).map(Response::Query))
229 }
230 ["statement_timeout"] => {
231 let timeout = client::get_statement_timeout(client);
232 let timeout_str = match timeout {
233 Some(duration) => format!("{}ms", duration.as_millis()),
234 None => "0".to_string(),
235 };
236 Some(mock_show_response("statement_timeout", &timeout_str).map(Response::Query))
237 }
238 ["transaction", "isolation", "level"] => {
239 Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
240 }
241 _ => {
242 info!("Unsupported show statement: {}", statement);
243 Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
244 }
245 }
246}
247
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    use super::*;
    use crate::testing::MockClient;

    /// Parses one SQL statement with the PostgreSQL dialect; panics on
    /// malformed input, which is acceptable for fixed test SQL.
    fn parse_sql(sql: &str) -> Statement {
        Parser::new(&PostgreSqlDialect {})
            .try_with_sql(sql)
            .unwrap()
            .parse_statement()
            .unwrap()
    }

    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let set_stmt = parse_sql("set statement_timeout to '5000ms'");
        let set_outcome = try_respond_set_statements(&mut client, &set_stmt, &session_context)
            .await
            .expect("SET statement_timeout should be handled");
        assert!(set_outcome.is_ok());

        assert_eq!(
            client::get_statement_timeout(&client),
            Some(Duration::from_millis(5000))
        );

        let show_stmt = parse_sql("show statement_timeout");
        let show_outcome = try_respond_show_statements(&client, &show_stmt, &session_context)
            .await
            .expect("SHOW statement_timeout should be handled");
        assert!(show_outcome.is_ok());
    }

    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Enable a timeout first, then clear it again with '0'.
        for sql in [
            "set statement_timeout to '1000ms'",
            "set statement_timeout to '0'",
        ] {
            let stmt = parse_sql(sql);
            let outcome = try_respond_set_statements(&mut client, &stmt, &session_context)
                .await
                .expect("SET statement_timeout should be handled");
            assert!(outcome.is_ok());
        }

        assert_eq!(client::get_statement_timeout(&client), None);
    }

    #[tokio::test]
    async fn test_supported_show_statements_returned_columns() {
        let session_context = SessionContext::new();
        let client = MockClient::new();

        let cases = [
            ("show time zone", "TimeZone"),
            ("show server_version", "server_version"),
            ("show transaction_isolation", "transaction_isolation"),
            ("show catalogs", "Catalogs"),
            ("show search_path", "search_path"),
            ("show statement_timeout", "statement_timeout"),
            ("show transaction isolation level", "transaction_isolation"),
        ];

        for (sql, expected_column) in cases {
            let stmt = parse_sql(sql);
            let outcome = try_respond_show_statements(&client, &stmt, &session_context).await;

            let Some(Ok(Response::Query(query_response))) = outcome else {
                panic!("unexpected show response");
            };

            assert_eq!(query_response.command_tag(), "SELECT");

            let columns = query_response.row_schema();
            assert_eq!(columns.len(), 1);
            assert_eq!(columns[0].name(), expected_column);
        }
    }
}