Skip to main content

datafusion_postgres/hooks/
set_show.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::{DataType, Field, Schema};
5use datafusion::common::{ParamValues, ToDFSchema};
6use datafusion::error::DataFusionError;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::SessionContext;
9use datafusion::sql::sqlparser::ast::{Expr, Set, Statement};
10use log::{info, warn};
11use pgwire::api::ClientInfo;
12use pgwire::api::auth::DefaultServerParameterProvider;
13use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
14use pgwire::error::{PgWireError, PgWireResult};
15use pgwire::messages::PgWireBackendMessage;
16use pgwire::messages::startup::ParameterStatus;
17use pgwire::types::format::FormatOptions;
18use postgres_types::Type;
19
20use crate::QueryHook;
21use crate::client;
22use crate::hooks::HookClient;
23
/// [`QueryHook`] that handles PostgreSQL `SET` and `SHOW` statements directly
/// where it can, falling back to DataFusion for `SET` variables it does not
/// manage itself.
///
/// The hook carries no state of its own. Settings are kept on the client
/// (metadata, statement timeout, time zone) and in the `SessionContext` passed
/// to each hook call, so the hook is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SetShowHook;
26
27#[async_trait]
28impl QueryHook for SetShowHook {
29    /// called in simple query handler to return response directly
30    async fn handle_simple_query(
31        &self,
32        statement: &Statement,
33        session_context: &SessionContext,
34        client: &mut dyn HookClient,
35    ) -> Option<PgWireResult<Response>> {
36        match statement {
37            Statement::Set { .. } => {
38                try_respond_set_statements(client, statement, session_context).await
39            }
40            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
41                try_respond_show_statements(client, statement, session_context).await
42            }
43            _ => None,
44        }
45    }
46
47    async fn handle_extended_parse_query(
48        &self,
49        stmt: &Statement,
50        _session_context: &SessionContext,
51        _client: &(dyn ClientInfo + Send + Sync),
52    ) -> Option<PgWireResult<LogicalPlan>> {
53        match stmt {
54            Statement::Set { .. } => {
55                let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
56                let result = show_schema
57                    .to_dfschema()
58                    .map(|df_schema| {
59                        LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
60                            produce_one_row: true,
61                            schema: Arc::new(df_schema),
62                        })
63                    })
64                    .map_err(|e| PgWireError::ApiError(Box::new(e)));
65                Some(result)
66            }
67            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
68                let show_schema =
69                    Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
70                let result = show_schema
71                    .to_dfschema()
72                    .map(|df_schema| {
73                        LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
74                            produce_one_row: true,
75                            schema: Arc::new(df_schema),
76                        })
77                    })
78                    .map_err(|e| PgWireError::ApiError(Box::new(e)));
79                Some(result)
80            }
81            _ => None,
82        }
83    }
84
85    async fn handle_extended_query(
86        &self,
87        statement: &Statement,
88        _logical_plan: &LogicalPlan,
89        _params: &ParamValues,
90        session_context: &SessionContext,
91        client: &mut dyn HookClient,
92    ) -> Option<PgWireResult<Response>> {
93        match statement {
94            Statement::Set { .. } => {
95                try_respond_set_statements(client, statement, session_context).await
96            }
97            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
98                try_respond_show_statements(client, statement, session_context).await
99            }
100            _ => None,
101        }
102    }
103}
104
105fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
106    let fields = vec![FieldInfo::new(
107        name.to_string(),
108        None,
109        None,
110        Type::VARCHAR,
111        FieldFormat::Text,
112    )];
113
114    let row = {
115        let mut encoder = DataRowEncoder::new(Arc::new(fields.clone()));
116        encoder.encode_field(&Some(value))?;
117        Ok(encoder.take_row())
118    };
119
120    let row_stream = futures::stream::once(async move { row });
121    Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
122}
123
124async fn try_respond_set_statements(
125    client: &mut dyn HookClient,
126    statement: &Statement,
127    session_context: &SessionContext,
128) -> Option<PgWireResult<Response>> {
129    let Statement::Set(set_statement) = statement else {
130        return None;
131    };
132
133    match &set_statement {
134        Set::SingleAssignment {
135            scope: None,
136            hivevar: false,
137            variable,
138            values,
139        } => {
140            let var = variable.to_string().to_lowercase();
141            if var == "statement_timeout" {
142                let value = values[0].to_string();
143                let timeout_str = value.trim_matches('"').trim_matches('\'');
144
145                let timeout = if timeout_str == "0" || timeout_str.is_empty() {
146                    None
147                } else {
148                    // Parse timeout value (supports ms, s, min formats)
149                    let timeout_ms = if timeout_str.ends_with("ms") {
150                        timeout_str.trim_end_matches("ms").parse::<u64>()
151                    } else if timeout_str.ends_with("s") {
152                        timeout_str
153                            .trim_end_matches("s")
154                            .parse::<u64>()
155                            .map(|s| s * 1000)
156                    } else if timeout_str.ends_with("min") {
157                        timeout_str
158                            .trim_end_matches("min")
159                            .parse::<u64>()
160                            .map(|m| m * 60 * 1000)
161                    } else {
162                        // Default to milliseconds
163                        timeout_str.parse::<u64>()
164                    };
165
166                    match timeout_ms {
167                        Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
168                        _ => None,
169                    }
170                };
171
172                client::set_statement_timeout(client, timeout);
173                return Some(Ok(Response::Execution(Tag::new("SET"))));
174            } else if matches!(
175                var.as_str(),
176                "datestyle"
177                    | "bytea_output"
178                    | "intervalstyle"
179                    | "application_name"
180                    | "extra_float_digits"
181                    | "search_path"
182            ) && !values.is_empty()
183            {
184                // postgres configuration variables
185                let value = values[0].clone();
186                if let Expr::Value(value) = value {
187                    let val_str = value.into_string().unwrap_or_else(|| "".to_string());
188                    client.metadata_mut().insert(var.clone(), val_str);
189                    if let Some((name, value)) = parameter_status_for_var(&var, &*client)
190                        && let Err(e) = client
191                            .send_message(PgWireBackendMessage::ParameterStatus(
192                                ParameterStatus::new(name, value),
193                            ))
194                            .await
195                    {
196                        return Some(Err(e));
197                    }
198                    return Some(Ok(Response::Execution(Tag::new("SET"))));
199                }
200            }
201        }
202        Set::SetTimeZone {
203            local: false,
204            value,
205        } => {
206            let tz = value.to_string();
207            let tz = tz.trim_matches('"').trim_matches('\'');
208            client::set_timezone(client, Some(tz));
209            // execution options for timezone
210            session_context
211                .state()
212                .config_mut()
213                .options_mut()
214                .execution
215                .time_zone = Some(tz.to_string());
216            let tz_value = client::get_timezone(client).unwrap_or("UTC").to_string();
217            if let Err(e) = client
218                .send_message(PgWireBackendMessage::ParameterStatus(ParameterStatus::new(
219                    "TimeZone".to_string(),
220                    tz_value,
221                )))
222                .await
223            {
224                return Some(Err(e));
225            }
226            return Some(Ok(Response::Execution(Tag::new("SET"))));
227        }
228        _ => {}
229    }
230
231    // fallback to datafusion and ignore all errors
232    if let Err(e) = execute_set_statement(session_context, statement.clone()).await {
233        warn!(
234            "SET statement {statement} is not supported by datafusion, error {e}, statement ignored",
235        );
236    }
237
238    // Always return SET success
239    Some(Ok(Response::Execution(Tag::new("SET"))))
240}
241
242fn parameter_status_for_var(
243    var: &str,
244    client: &(impl ClientInfo + ?Sized),
245) -> Option<(String, String)> {
246    let display_name = match var {
247        "datestyle" => "DateStyle",
248        "intervalstyle" => "IntervalStyle",
249        "bytea_output" => "bytea_output",
250        "application_name" => "application_name",
251        "extra_float_digits" => "extra_float_digits",
252        "search_path" => "search_path",
253        _ => return None,
254    };
255    let value = client.metadata().get(var)?.clone();
256    Some((display_name.to_string(), value))
257}
258
259async fn execute_set_statement(
260    session_context: &SessionContext,
261    statement: Statement,
262) -> Result<(), DataFusionError> {
263    let state = session_context.state();
264    let logical_plan = state
265        .statement_to_plan(datafusion::sql::parser::Statement::Statement(Box::new(
266            statement,
267        )))
268        .await
269        .and_then(|logical_plan| state.optimize(&logical_plan))?;
270
271    session_context
272        .execute_logical_plan(logical_plan)
273        .await
274        .map(|_| ())
275}
276
277async fn try_respond_show_statements(
278    client: &dyn HookClient,
279    statement: &Statement,
280    session_context: &SessionContext,
281) -> Option<PgWireResult<Response>> {
282    let Statement::ShowVariable { variable } = statement else {
283        return None;
284    };
285
286    let variables = variable
287        .iter()
288        .map(|v| v.value.to_lowercase())
289        .collect::<Vec<_>>();
290    let variables_ref = variables.iter().map(|s| s.as_str()).collect::<Vec<_>>();
291
292    match variables_ref.as_slice() {
293        ["time", "zone"] => {
294            let timezone = client::get_timezone(client).unwrap_or("UTC");
295            Some(mock_show_response("TimeZone", timezone).map(Response::Query))
296        }
297        ["server_version"] => {
298            let version = format!(
299                "datafusion {} on {} {}",
300                session_context.state().version(),
301                env!("CARGO_PKG_NAME"),
302                env!("CARGO_PKG_VERSION")
303            );
304            Some(mock_show_response("server_version", &version).map(Response::Query))
305        }
306        ["transaction_isolation"] => Some(
307            mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query),
308        ),
309        ["catalogs"] => {
310            let catalogs = session_context.catalog_names();
311            let value = catalogs.join(", ");
312            Some(mock_show_response("Catalogs", &value).map(Response::Query))
313        }
314        ["statement_timeout"] => {
315            let timeout = client::get_statement_timeout(client);
316            let timeout_str = match timeout {
317                Some(duration) => format!("{}ms", duration.as_millis()),
318                None => "0".to_string(),
319            };
320            Some(mock_show_response("statement_timeout", &timeout_str).map(Response::Query))
321        }
322        ["transaction", "isolation", "level"] => {
323            Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
324        }
325        _ => {
326            let val = client
327                .metadata()
328                .get(&variables[0])
329                .map(|v| v.to_string())
330                .or_else(|| match variables[0].as_str() {
331                    "bytea_output" => Some(FormatOptions::default().bytea_output),
332                    "datestyle" => Some(FormatOptions::default().date_style),
333                    "intervalstyle" => Some(FormatOptions::default().interval_style),
334                    "extra_float_digits" => {
335                        Some(FormatOptions::default().extra_float_digits.to_string())
336                    }
337                    "application_name" => Some(
338                        DefaultServerParameterProvider::default()
339                            .application_name
340                            .unwrap_or("".to_owned()),
341                    ),
342                    "search_path" => Some(DefaultServerParameterProvider::default().search_path),
343                    _ => None,
344                });
345            if let Some(val) = val {
346                Some(mock_show_response(&variables[0], &val).map(Response::Query))
347            } else {
348                info!("Unsupported show statement: {statement}");
349                Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
350            }
351        }
352    }
353}
354
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    use super::*;
    use crate::testing::MockClient;

    /// Parses exactly one PostgreSQL statement, panicking on malformed SQL.
    fn parse(sql: &str) -> Statement {
        Parser::new(&PostgreSqlDialect {})
            .try_with_sql(sql)
            .unwrap()
            .parse_statement()
            .unwrap()
    }

    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // SET with an explicit millisecond unit.
        let set_stmt = parse("set statement_timeout to '5000ms'");
        let set_response =
            try_respond_set_statements(&mut client, &set_stmt, &session_context).await;
        assert!(matches!(set_response, Some(Ok(_))));

        // The parsed timeout is visible through the client helpers.
        assert_eq!(
            client::get_statement_timeout(&client),
            Some(Duration::from_millis(5000))
        );

        // SHOW reads it back successfully.
        let show_stmt = parse("show statement_timeout");
        let show_response =
            try_respond_show_statements(&client, &show_stmt, &session_context).await;
        assert!(matches!(show_response, Some(Ok(_))));
    }

    #[tokio::test]
    async fn test_bytea_output_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // SET bytea_output to hex.
        let set_stmt = parse("set bytea_output = 'hex'");
        let set_response =
            try_respond_set_statements(&mut client, &set_stmt, &session_context).await;
        assert!(matches!(set_response, Some(Ok(_))));

        // The value lands in the client metadata.
        assert_eq!(
            client.metadata().get("bytea_output").map(String::as_str),
            Some("hex")
        );

        // SHOW answers successfully.
        let show_stmt = parse("show bytea_output");
        let show_response =
            try_respond_show_statements(&client, &show_stmt, &session_context).await;
        assert!(matches!(show_response, Some(Ok(_))));
    }

    #[tokio::test]
    async fn test_date_style_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Mixed-case variable name must still be recognized.
        let set_stmt = parse("set dateStyle = 'ISO, DMY'");
        let set_response =
            try_respond_set_statements(&mut client, &set_stmt, &session_context).await;
        assert!(matches!(set_response, Some(Ok(_))));

        // Stored under the lowercase key.
        assert_eq!(
            client.metadata().get("datestyle").map(String::as_str),
            Some("ISO, DMY")
        );

        let show_stmt = parse("show dateStyle");
        let show_response =
            try_respond_show_statements(&client, &show_stmt, &session_context).await;
        assert!(matches!(show_response, Some(Ok(_))));
    }

    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Enable a timeout first ...
        let enable = parse("set statement_timeout to '1000ms'");
        let response = try_respond_set_statements(&mut client, &enable, &session_context).await;
        assert!(matches!(response, Some(Ok(_))));

        // ... then disable it with 0.
        let disable = parse("set statement_timeout to '0'");
        let response = try_respond_set_statements(&mut client, &disable, &session_context).await;
        assert!(matches!(response, Some(Ok(_))));

        assert_eq!(client::get_statement_timeout(&client), None);
    }

    #[tokio::test]
    async fn test_parameter_status_sent_for_all_set_vars() {
        let cases = [
            ("set bytea_output = 'escape'", "bytea_output", "escape"),
            (
                "set intervalstyle = 'postgres'",
                "IntervalStyle",
                "postgres",
            ),
            (
                "set application_name = 'myapp'",
                "application_name",
                "myapp",
            ),
            ("set search_path = 'public'", "search_path", "public"),
            ("set extra_float_digits = '2'", "extra_float_digits", "2"),
            ("set datestyle = 'ISO, MDY'", "DateStyle", "ISO, MDY"),
            (
                "set time zone 'America/New_York'",
                "TimeZone",
                "America/New_York",
            ),
        ];

        for (sql, expected_key, expected_value) in cases {
            // Fresh session and client per case so messages don't accumulate.
            let session_context = SessionContext::new();
            let mut client = MockClient::new();
            let statement = parse(sql);

            let result =
                try_respond_set_statements(&mut client, &statement, &session_context).await;
            assert!(result.is_some(), "Expected Some for {sql}");
            assert!(result.unwrap().is_ok(), "Expected Ok for {sql}");

            let statuses: Vec<_> = client
                .sent_messages()
                .iter()
                .filter_map(|message| {
                    if let PgWireBackendMessage::ParameterStatus(status) = message {
                        Some(status)
                    } else {
                        None
                    }
                })
                .collect();

            assert_eq!(statuses.len(), 1, "Expected 1 ParameterStatus for {sql}");
            assert_eq!(statuses[0].name, expected_key, "Wrong key for {sql}");
            assert_eq!(statuses[0].value, expected_value, "Wrong value for {sql}");
        }
    }

    #[tokio::test]
    async fn test_no_parameter_status_for_statement_timeout() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let statement = parse("set statement_timeout to '5000ms'");
        let result = try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(matches!(result, Some(Ok(_))));

        let sent_parameter_status = client
            .sent_messages()
            .iter()
            .any(|message| matches!(message, PgWireBackendMessage::ParameterStatus(_)));

        assert!(
            !sent_parameter_status,
            "statement_timeout should not send ParameterStatus"
        );
    }

    #[tokio::test]
    async fn test_supported_show_statements_returned_columns() {
        let session_context = SessionContext::new();
        let client = MockClient::new();

        let cases = [
            ("show time zone", "TimeZone"),
            ("show server_version", "server_version"),
            ("show transaction_isolation", "transaction_isolation"),
            ("show catalogs", "Catalogs"),
            ("show search_path", "search_path"),
            ("show statement_timeout", "statement_timeout"),
            ("show transaction isolation level", "transaction_isolation"),
        ];

        for (query, expected_column) in cases {
            let statement = parse(query);
            let show_response =
                try_respond_show_statements(&client, &statement, &session_context).await;

            let Some(Ok(Response::Query(query_response))) = show_response else {
                panic!("unexpected show response");
            };

            assert_eq!(query_response.command_tag(), "SELECT");

            let columns = query_response.row_schema();
            assert_eq!(columns.len(), 1);
            assert_eq!(columns[0].name(), expected_column);
        }
    }
}