datafusion_postgres/hooks/set_show.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::{DataType, Field, Schema};
5use datafusion::common::{ParamValues, ToDFSchema};
6use datafusion::error::DataFusionError;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::SessionContext;
9use datafusion::sql::sqlparser::ast::{Expr, Set, Statement};
10use log::{info, warn};
11use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
12use pgwire::api::ClientInfo;
13use pgwire::error::{PgWireError, PgWireResult};
14use postgres_types::Type;
15
16use crate::client;
17use crate::QueryHook;
18
19#[derive(Debug)]
20pub struct SetShowHook;
21
22#[async_trait]
23impl QueryHook for SetShowHook {
24    /// called in simple query handler to return response directly
25    async fn handle_simple_query(
26        &self,
27        statement: &Statement,
28        session_context: &SessionContext,
29        client: &mut (dyn ClientInfo + Send + Sync),
30    ) -> Option<PgWireResult<Response>> {
31        match statement {
32            Statement::Set { .. } => {
33                try_respond_set_statements(client, statement, session_context).await
34            }
35            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
36                try_respond_show_statements(client, statement, session_context).await
37            }
38            _ => None,
39        }
40    }
41
42    async fn handle_extended_parse_query(
43        &self,
44        stmt: &Statement,
45        _session_context: &SessionContext,
46        _client: &(dyn ClientInfo + Send + Sync),
47    ) -> Option<PgWireResult<LogicalPlan>> {
48        let sql_lower = stmt.to_string().to_lowercase();
49        let sql_trimmed = sql_lower.trim();
50
51        if sql_trimmed.starts_with("show") {
52            let show_schema =
53                Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
54            let result = show_schema
55                .to_dfschema()
56                .map(|df_schema| {
57                    LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
58                        produce_one_row: true,
59                        schema: Arc::new(df_schema),
60                    })
61                })
62                .map_err(|e| PgWireError::ApiError(Box::new(e)));
63            Some(result)
64        } else if sql_trimmed.starts_with("set") {
65            let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
66            let result = show_schema
67                .to_dfschema()
68                .map(|df_schema| {
69                    LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
70                        produce_one_row: true,
71                        schema: Arc::new(df_schema),
72                    })
73                })
74                .map_err(|e| PgWireError::ApiError(Box::new(e)));
75            Some(result)
76        } else {
77            None
78        }
79    }
80
81    async fn handle_extended_query(
82        &self,
83        statement: &Statement,
84        _logical_plan: &LogicalPlan,
85        _params: &ParamValues,
86        session_context: &SessionContext,
87        client: &mut (dyn ClientInfo + Send + Sync),
88    ) -> Option<PgWireResult<Response>> {
89        match statement {
90            Statement::Set { .. } => {
91                try_respond_set_statements(client, statement, session_context).await
92            }
93            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
94                try_respond_show_statements(client, statement, session_context).await
95            }
96            _ => None,
97        }
98    }
99}
100
101fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
102    let fields = vec![FieldInfo::new(
103        name.to_string(),
104        None,
105        None,
106        Type::VARCHAR,
107        FieldFormat::Text,
108    )];
109
110    let row = {
111        let mut encoder = DataRowEncoder::new(Arc::new(fields.clone()));
112        encoder.encode_field(&Some(value))?;
113        encoder.finish()
114    };
115
116    let row_stream = futures::stream::once(async move { row });
117    Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
118}
119
120async fn try_respond_set_statements<C>(
121    client: &mut C,
122    statement: &Statement,
123    session_context: &SessionContext,
124) -> Option<PgWireResult<Response>>
125where
126    C: ClientInfo + Send + Sync + ?Sized,
127{
128    let Statement::Set(set_statement) = statement else {
129        return None;
130    };
131
132    match &set_statement {
133        Set::SingleAssignment {
134            scope: None,
135            hivevar: false,
136            variable,
137            values,
138        } => {
139            let var = variable.to_string().to_lowercase();
140            if var == "statement_timeout" {
141                let value = values[0].to_string();
142                let timeout_str = value.trim_matches('"').trim_matches('\'');
143
144                let timeout = if timeout_str == "0" || timeout_str.is_empty() {
145                    None
146                } else {
147                    // Parse timeout value (supports ms, s, min formats)
148                    let timeout_ms = if timeout_str.ends_with("ms") {
149                        timeout_str.trim_end_matches("ms").parse::<u64>()
150                    } else if timeout_str.ends_with("s") {
151                        timeout_str
152                            .trim_end_matches("s")
153                            .parse::<u64>()
154                            .map(|s| s * 1000)
155                    } else if timeout_str.ends_with("min") {
156                        timeout_str
157                            .trim_end_matches("min")
158                            .parse::<u64>()
159                            .map(|m| m * 60 * 1000)
160                    } else {
161                        // Default to milliseconds
162                        timeout_str.parse::<u64>()
163                    };
164
165                    match timeout_ms {
166                        Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
167                        _ => None,
168                    }
169                };
170
171                client::set_statement_timeout(client, timeout);
172                return Some(Ok(Response::Execution(Tag::new("SET"))));
173            } else if matches!(
174                var.as_str(),
175                "datestyle"
176                    | "bytea_output"
177                    | "intervalstyle"
178                    | "application_name"
179                    | "extra_float_digits"
180                    | "search_path"
181            ) && !values.is_empty()
182            {
183                // postgres configuration variables
184                let value = values[0].clone();
185                if let Expr::Value(value) = value {
186                    client
187                        .metadata_mut()
188                        .insert(var, value.into_string().unwrap_or_else(|| "".to_string()));
189                    return Some(Ok(Response::Execution(Tag::new("SET"))));
190                }
191            }
192        }
193        Set::SetTimeZone {
194            local: false,
195            value,
196        } => {
197            let tz = value.to_string();
198            let tz = tz.trim_matches('"').trim_matches('\'');
199            client::set_timezone(client, Some(tz));
200            // execution options for timezone
201            session_context
202                .state()
203                .config_mut()
204                .options_mut()
205                .execution
206                .time_zone = tz.to_string();
207            return Some(Ok(Response::Execution(Tag::new("SET"))));
208        }
209        _ => {}
210    }
211
212    // fallback to datafusion and ignore all errors
213    if let Err(e) = execute_set_statement(session_context, statement.clone()).await {
214        warn!(
215            "SET statement {} is not supported by datafusion, error {e}, statement ignored",
216            statement
217        );
218    }
219
220    // Always return SET success
221    Some(Ok(Response::Execution(Tag::new("SET"))))
222}
223
224async fn execute_set_statement(
225    session_context: &SessionContext,
226    statement: Statement,
227) -> Result<(), DataFusionError> {
228    let state = session_context.state();
229    let logical_plan = state
230        .statement_to_plan(datafusion::sql::parser::Statement::Statement(Box::new(
231            statement,
232        )))
233        .await
234        .and_then(|logical_plan| state.optimize(&logical_plan))?;
235
236    session_context
237        .execute_logical_plan(logical_plan)
238        .await
239        .map(|_| ())
240}
241
242async fn try_respond_show_statements<C>(
243    client: &C,
244    statement: &Statement,
245    session_context: &SessionContext,
246) -> Option<PgWireResult<Response>>
247where
248    C: ClientInfo + ?Sized,
249{
250    let Statement::ShowVariable { variable } = statement else {
251        return None;
252    };
253
254    let variables = variable
255        .iter()
256        .map(|v| v.value.to_lowercase())
257        .collect::<Vec<_>>();
258    let variables_ref = variables.iter().map(|s| s.as_str()).collect::<Vec<_>>();
259
260    match variables_ref.as_slice() {
261        ["time", "zone"] => {
262            let timezone = client::get_timezone(client).unwrap_or("UTC");
263            Some(mock_show_response("TimeZone", timezone).map(Response::Query))
264        }
265        ["server_version"] => {
266            let version = format!(
267                "datafusion {} on {} {}",
268                session_context.state().version(),
269                env!("CARGO_PKG_NAME"),
270                env!("CARGO_PKG_VERSION")
271            );
272            Some(mock_show_response("server_version", &version).map(Response::Query))
273        }
274        ["transaction_isolation"] => Some(
275            mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query),
276        ),
277        ["catalogs"] => {
278            let catalogs = session_context.catalog_names();
279            let value = catalogs.join(", ");
280            Some(mock_show_response("Catalogs", &value).map(Response::Query))
281        }
282        ["statement_timeout"] => {
283            let timeout = client::get_statement_timeout(client);
284            let timeout_str = match timeout {
285                Some(duration) => format!("{}ms", duration.as_millis()),
286                None => "0".to_string(),
287            };
288            Some(mock_show_response("statement_timeout", &timeout_str).map(Response::Query))
289        }
290        ["transaction", "isolation", "level"] => {
291            Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
292        }
293        ["bytea_output"]
294        | ["datestyle"]
295        | ["intervalstyle"]
296        | ["application_name"]
297        | ["extra_float_digits"]
298        | ["search_path"] => {
299            let val = client
300                .metadata()
301                .get(&variables[0])
302                .map(|v| v.as_str())
303                .unwrap_or("");
304            Some(mock_show_response(&variables[0], val).map(Response::Query))
305        }
306        _ => {
307            info!("Unsupported show statement: {}", statement);
308            Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
309        }
310    }
311}
312
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    use super::*;
    use crate::testing::MockClient;

    /// Parses a single PostgreSQL statement, panicking on invalid SQL.
    fn parse(sql: &str) -> Statement {
        Parser::new(&PostgreSqlDialect {})
            .try_with_sql(sql)
            .unwrap()
            .parse_statement()
            .unwrap()
    }

    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // SET statement_timeout to 5000ms succeeds.
        let statement = parse("set statement_timeout to '5000ms'");
        let set_response =
            try_respond_set_statements(&mut client, &statement, &session_context).await;

        assert!(set_response.is_some());
        assert!(set_response.unwrap().is_ok());

        // The timeout is recorded on the client.
        let timeout = client::get_statement_timeout(&client);
        assert_eq!(timeout, Some(Duration::from_millis(5000)));

        // SHOW statement_timeout is answered.
        let statement = parse("show statement_timeout");
        let show_response =
            try_respond_show_statements(&client, &statement, &session_context).await;

        assert!(show_response.is_some());
        assert!(show_response.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_bytea_output_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // SET bytea_output = 'hex' succeeds.
        let statement = parse("set bytea_output = 'hex'");
        let set_response =
            try_respond_set_statements(&mut client, &statement, &session_context).await;

        assert!(set_response.is_some());
        assert!(set_response.unwrap().is_ok());

        // The unquoted value is stored in client metadata.
        let bytea_output = client.metadata().get("bytea_output").unwrap();
        assert_eq!(bytea_output, "hex");

        // SHOW bytea_output is answered.
        let statement = parse("show bytea_output");
        let show_response =
            try_respond_show_statements(&client, &statement, &session_context).await;

        assert!(show_response.is_some());
        assert!(show_response.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_date_style_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // SET with mixed-case variable name succeeds.
        let statement = parse("set dateStyle = 'ISO, DMY'");
        let set_response =
            try_respond_set_statements(&mut client, &statement, &session_context).await;

        assert!(set_response.is_some());
        assert!(set_response.unwrap().is_ok());

        // The value is stored under the lowercased variable name.
        let date_style = client.metadata().get("datestyle").unwrap();
        assert_eq!(date_style, "ISO, DMY");

        // SHOW dateStyle is answered.
        let statement = parse("show dateStyle");
        let show_response =
            try_respond_show_statements(&client, &statement, &session_context).await;

        assert!(show_response.is_some());
        assert!(show_response.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Set a timeout first.
        let statement = parse("set statement_timeout to '1000ms'");
        let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(resp.is_some());
        assert!(resp.unwrap().is_ok());

        // Setting it to 0 disables it.
        let statement = parse("set statement_timeout to '0'");
        let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(resp.is_some());
        assert!(resp.unwrap().is_ok());

        let timeout = client::get_statement_timeout(&client);
        assert_eq!(timeout, None);
    }

    #[tokio::test]
    async fn test_supported_show_statements_returned_columns() {
        let session_context = SessionContext::new();
        let client = MockClient::new();

        let tests = [
            ("show time zone", "TimeZone"),
            ("show server_version", "server_version"),
            ("show transaction_isolation", "transaction_isolation"),
            ("show catalogs", "Catalogs"),
            ("show search_path", "search_path"),
            ("show statement_timeout", "statement_timeout"),
            ("show transaction isolation level", "transaction_isolation"),
        ];

        for (query, expected_response_col) in tests {
            let statement = parse(query);
            let show_response =
                try_respond_show_statements(&client, &statement, &session_context).await;

            let Some(Ok(Response::Query(show_response))) = show_response else {
                panic!("unexpected show response");
            };

            assert_eq!(show_response.command_tag(), "SELECT");

            let row_schema = show_response.row_schema();
            assert_eq!(row_schema.len(), 1);
            assert_eq!(row_schema[0].name(), expected_response_col);
        }
    }
}