datafusion_postgres/hooks/
set_show.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::{DataType, Field, Schema};
5use datafusion::common::{ParamValues, ToDFSchema};
6use datafusion::error::DataFusionError;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::SessionContext;
9use datafusion::sql::sqlparser::ast::{Expr, Set, Statement};
10use log::{info, warn};
11use pgwire::api::auth::DefaultServerParameterProvider;
12use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
13use pgwire::api::ClientInfo;
14use pgwire::error::{PgWireError, PgWireResult};
15use pgwire::types::format::FormatOptions;
16use postgres_types::Type;
17
18use crate::client;
19use crate::QueryHook;
20
21#[derive(Debug)]
22pub struct SetShowHook;
23
24#[async_trait]
25impl QueryHook for SetShowHook {
26    /// called in simple query handler to return response directly
27    async fn handle_simple_query(
28        &self,
29        statement: &Statement,
30        session_context: &SessionContext,
31        client: &mut (dyn ClientInfo + Send + Sync),
32    ) -> Option<PgWireResult<Response>> {
33        match statement {
34            Statement::Set { .. } => {
35                try_respond_set_statements(client, statement, session_context).await
36            }
37            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
38                try_respond_show_statements(client, statement, session_context).await
39            }
40            _ => None,
41        }
42    }
43
44    async fn handle_extended_parse_query(
45        &self,
46        stmt: &Statement,
47        _session_context: &SessionContext,
48        _client: &(dyn ClientInfo + Send + Sync),
49    ) -> Option<PgWireResult<LogicalPlan>> {
50        match stmt {
51            Statement::Set { .. } => {
52                let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
53                let result = show_schema
54                    .to_dfschema()
55                    .map(|df_schema| {
56                        LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
57                            produce_one_row: true,
58                            schema: Arc::new(df_schema),
59                        })
60                    })
61                    .map_err(|e| PgWireError::ApiError(Box::new(e)));
62                Some(result)
63            }
64            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
65                let show_schema =
66                    Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
67                let result = show_schema
68                    .to_dfschema()
69                    .map(|df_schema| {
70                        LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
71                            produce_one_row: true,
72                            schema: Arc::new(df_schema),
73                        })
74                    })
75                    .map_err(|e| PgWireError::ApiError(Box::new(e)));
76                Some(result)
77            }
78            _ => None,
79        }
80    }
81
82    async fn handle_extended_query(
83        &self,
84        statement: &Statement,
85        _logical_plan: &LogicalPlan,
86        _params: &ParamValues,
87        session_context: &SessionContext,
88        client: &mut (dyn ClientInfo + Send + Sync),
89    ) -> Option<PgWireResult<Response>> {
90        match statement {
91            Statement::Set { .. } => {
92                try_respond_set_statements(client, statement, session_context).await
93            }
94            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
95                try_respond_show_statements(client, statement, session_context).await
96            }
97            _ => None,
98        }
99    }
100}
101
102fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
103    let fields = vec![FieldInfo::new(
104        name.to_string(),
105        None,
106        None,
107        Type::VARCHAR,
108        FieldFormat::Text,
109    )];
110
111    let row = {
112        let mut encoder = DataRowEncoder::new(Arc::new(fields.clone()));
113        encoder.encode_field(&Some(value))?;
114        Ok(encoder.take_row())
115    };
116
117    let row_stream = futures::stream::once(async move { row });
118    Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
119}
120
121async fn try_respond_set_statements<C>(
122    client: &mut C,
123    statement: &Statement,
124    session_context: &SessionContext,
125) -> Option<PgWireResult<Response>>
126where
127    C: ClientInfo + Send + Sync + ?Sized,
128{
129    let Statement::Set(set_statement) = statement else {
130        return None;
131    };
132
133    match &set_statement {
134        Set::SingleAssignment {
135            scope: None,
136            hivevar: false,
137            variable,
138            values,
139        } => {
140            let var = variable.to_string().to_lowercase();
141            if var == "statement_timeout" {
142                let value = values[0].to_string();
143                let timeout_str = value.trim_matches('"').trim_matches('\'');
144
145                let timeout = if timeout_str == "0" || timeout_str.is_empty() {
146                    None
147                } else {
148                    // Parse timeout value (supports ms, s, min formats)
149                    let timeout_ms = if timeout_str.ends_with("ms") {
150                        timeout_str.trim_end_matches("ms").parse::<u64>()
151                    } else if timeout_str.ends_with("s") {
152                        timeout_str
153                            .trim_end_matches("s")
154                            .parse::<u64>()
155                            .map(|s| s * 1000)
156                    } else if timeout_str.ends_with("min") {
157                        timeout_str
158                            .trim_end_matches("min")
159                            .parse::<u64>()
160                            .map(|m| m * 60 * 1000)
161                    } else {
162                        // Default to milliseconds
163                        timeout_str.parse::<u64>()
164                    };
165
166                    match timeout_ms {
167                        Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
168                        _ => None,
169                    }
170                };
171
172                client::set_statement_timeout(client, timeout);
173                return Some(Ok(Response::Execution(Tag::new("SET"))));
174            } else if matches!(
175                var.as_str(),
176                "datestyle"
177                    | "bytea_output"
178                    | "intervalstyle"
179                    | "application_name"
180                    | "extra_float_digits"
181                    | "search_path"
182            ) && !values.is_empty()
183            {
184                // postgres configuration variables
185                let value = values[0].clone();
186                if let Expr::Value(value) = value {
187                    client
188                        .metadata_mut()
189                        .insert(var, value.into_string().unwrap_or_else(|| "".to_string()));
190                    return Some(Ok(Response::Execution(Tag::new("SET"))));
191                }
192            }
193        }
194        Set::SetTimeZone {
195            local: false,
196            value,
197        } => {
198            let tz = value.to_string();
199            let tz = tz.trim_matches('"').trim_matches('\'');
200            client::set_timezone(client, Some(tz));
201            // execution options for timezone
202            session_context
203                .state()
204                .config_mut()
205                .options_mut()
206                .execution
207                .time_zone = Some(tz.to_string());
208            return Some(Ok(Response::Execution(Tag::new("SET"))));
209        }
210        _ => {}
211    }
212
213    // fallback to datafusion and ignore all errors
214    if let Err(e) = execute_set_statement(session_context, statement.clone()).await {
215        warn!(
216            "SET statement {statement} is not supported by datafusion, error {e}, statement ignored",
217        );
218    }
219
220    // Always return SET success
221    Some(Ok(Response::Execution(Tag::new("SET"))))
222}
223
224async fn execute_set_statement(
225    session_context: &SessionContext,
226    statement: Statement,
227) -> Result<(), DataFusionError> {
228    let state = session_context.state();
229    let logical_plan = state
230        .statement_to_plan(datafusion::sql::parser::Statement::Statement(Box::new(
231            statement,
232        )))
233        .await
234        .and_then(|logical_plan| state.optimize(&logical_plan))?;
235
236    session_context
237        .execute_logical_plan(logical_plan)
238        .await
239        .map(|_| ())
240}
241
242async fn try_respond_show_statements<C>(
243    client: &C,
244    statement: &Statement,
245    session_context: &SessionContext,
246) -> Option<PgWireResult<Response>>
247where
248    C: ClientInfo + ?Sized,
249{
250    let Statement::ShowVariable { variable } = statement else {
251        return None;
252    };
253
254    let variables = variable
255        .iter()
256        .map(|v| v.value.to_lowercase())
257        .collect::<Vec<_>>();
258    let variables_ref = variables.iter().map(|s| s.as_str()).collect::<Vec<_>>();
259
260    match variables_ref.as_slice() {
261        ["time", "zone"] => {
262            let timezone = client::get_timezone(client).unwrap_or("UTC");
263            Some(mock_show_response("TimeZone", timezone).map(Response::Query))
264        }
265        ["server_version"] => {
266            let version = format!(
267                "datafusion {} on {} {}",
268                session_context.state().version(),
269                env!("CARGO_PKG_NAME"),
270                env!("CARGO_PKG_VERSION")
271            );
272            Some(mock_show_response("server_version", &version).map(Response::Query))
273        }
274        ["transaction_isolation"] => Some(
275            mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query),
276        ),
277        ["catalogs"] => {
278            let catalogs = session_context.catalog_names();
279            let value = catalogs.join(", ");
280            Some(mock_show_response("Catalogs", &value).map(Response::Query))
281        }
282        ["statement_timeout"] => {
283            let timeout = client::get_statement_timeout(client);
284            let timeout_str = match timeout {
285                Some(duration) => format!("{}ms", duration.as_millis()),
286                None => "0".to_string(),
287            };
288            Some(mock_show_response("statement_timeout", &timeout_str).map(Response::Query))
289        }
290        ["transaction", "isolation", "level"] => {
291            Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
292        }
293        _ => {
294            let val = client
295                .metadata()
296                .get(&variables[0])
297                .map(|v| v.to_string())
298                .or_else(|| match variables[0].as_str() {
299                    "bytea_output" => Some(FormatOptions::default().bytea_output),
300                    "datestyle" => Some(FormatOptions::default().date_style),
301                    "intervalstyle" => Some(FormatOptions::default().interval_style),
302                    "extra_float_digits" => {
303                        Some(FormatOptions::default().extra_float_digits.to_string())
304                    }
305                    "application_name" => Some(
306                        DefaultServerParameterProvider::default()
307                            .application_name
308                            .unwrap_or("".to_owned()),
309                    ),
310                    "search_path" => Some(DefaultServerParameterProvider::default().search_path),
311                    _ => None,
312                });
313            if let Some(val) = val {
314                Some(mock_show_response(&variables[0], &val).map(Response::Query))
315            } else {
316                info!("Unsupported show statement: {statement}");
317                Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
318            }
319        }
320    }
321}
322
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    use super::*;
    use crate::testing::MockClient;

    /// Parses a single PostgreSQL statement, panicking on invalid SQL.
    fn parse_sql(sql: &str) -> Statement {
        Parser::new(&PostgreSqlDialect {})
            .try_with_sql(sql)
            .unwrap()
            .parse_statement()
            .unwrap()
    }

    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Set the timeout to 5000ms.
        let statement = parse_sql("set statement_timeout to '5000ms'");
        let set_response =
            try_respond_set_statements(&mut client, &statement, &session_context).await;

        assert!(set_response.is_some());
        assert!(set_response.unwrap().is_ok());

        // The timeout is stored on the client.
        let timeout = client::get_statement_timeout(&client);
        assert_eq!(timeout, Some(Duration::from_millis(5000)));

        // SHOW statement_timeout is answered.
        let statement = parse_sql("show statement_timeout");
        let show_response =
            try_respond_show_statements(&client, &statement, &session_context).await;

        assert!(show_response.is_some());
        assert!(show_response.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_bytea_output_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Set bytea_output to 'hex'.
        let statement = parse_sql("set bytea_output = 'hex'");
        let set_response =
            try_respond_set_statements(&mut client, &statement, &session_context).await;

        assert!(set_response.is_some());
        assert!(set_response.unwrap().is_ok());

        // bytea_output is stored in the client metadata.
        let bytea_output = client.metadata().get("bytea_output").unwrap();
        assert_eq!(bytea_output, "hex");

        // SHOW bytea_output is answered.
        let statement = parse_sql("show bytea_output");
        let show_response =
            try_respond_show_statements(&client, &statement, &session_context).await;

        assert!(show_response.is_some());
        assert!(show_response.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_date_style_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Set DateStyle (mixed case) to 'ISO, DMY'.
        let statement = parse_sql("set dateStyle = 'ISO, DMY'");
        let set_response =
            try_respond_set_statements(&mut client, &statement, &session_context).await;

        assert!(set_response.is_some());
        assert!(set_response.unwrap().is_ok());

        // datestyle is stored in the client metadata under its lower-cased name.
        let date_style = client.metadata().get("datestyle").unwrap();
        assert_eq!(date_style, "ISO, DMY");

        // SHOW DateStyle is answered.
        let statement = parse_sql("show dateStyle");
        let show_response =
            try_respond_show_statements(&client, &statement, &session_context).await;

        assert!(show_response.is_some());
        assert!(show_response.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Set a timeout first.
        let statement = parse_sql("set statement_timeout to '1000ms'");
        let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(resp.is_some());
        assert!(resp.unwrap().is_ok());

        // Disable the timeout with 0.
        let statement = parse_sql("set statement_timeout to '0'");
        let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(resp.is_some());
        assert!(resp.unwrap().is_ok());

        let timeout = client::get_statement_timeout(&client);
        assert_eq!(timeout, None);
    }

    #[tokio::test]
    async fn test_supported_show_statements_returned_columns() {
        let session_context = SessionContext::new();
        let client = MockClient::new();

        let tests = [
            ("show time zone", "TimeZone"),
            ("show server_version", "server_version"),
            ("show transaction_isolation", "transaction_isolation"),
            ("show catalogs", "Catalogs"),
            ("show search_path", "search_path"),
            ("show statement_timeout", "statement_timeout"),
            ("show transaction isolation level", "transaction_isolation"),
        ];

        for (query, expected_response_col) in tests {
            let statement = parse_sql(query);
            let show_response =
                try_respond_show_statements(&client, &statement, &session_context).await;

            let Some(Ok(Response::Query(show_response))) = show_response else {
                panic!("unexpected show response");
            };

            assert_eq!(show_response.command_tag(), "SELECT");

            let row_schema = show_response.row_schema();
            assert_eq!(row_schema.len(), 1);
            assert_eq!(row_schema[0].name(), expected_response_col);
        }
    }
}