Skip to main content

datafusion_postgres/hooks/
set_show.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::{DataType, Field, Schema};
5use datafusion::common::{ParamValues, ToDFSchema};
6use datafusion::error::DataFusionError;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::SessionContext;
9use datafusion::sql::sqlparser::ast::{Expr, Set, Statement};
10use log::{info, warn};
11use pgwire::api::auth::DefaultServerParameterProvider;
12use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
13use pgwire::api::ClientInfo;
14use pgwire::error::{PgWireError, PgWireResult};
15use pgwire::messages::startup::ParameterStatus;
16use pgwire::messages::PgWireBackendMessage;
17use pgwire::types::format::FormatOptions;
18use postgres_types::Type;
19
20use crate::client;
21use crate::hooks::HookClient;
22use crate::QueryHook;
23
24#[derive(Debug)]
25pub struct SetShowHook;
26
27#[async_trait]
28impl QueryHook for SetShowHook {
29    /// called in simple query handler to return response directly
30    async fn handle_simple_query(
31        &self,
32        statement: &Statement,
33        session_context: &SessionContext,
34        client: &mut dyn HookClient,
35    ) -> Option<PgWireResult<Response>> {
36        match statement {
37            Statement::Set { .. } => {
38                try_respond_set_statements(client, statement, session_context).await
39            }
40            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
41                try_respond_show_statements(client, statement, session_context).await
42            }
43            _ => None,
44        }
45    }
46
47    async fn handle_extended_parse_query(
48        &self,
49        stmt: &Statement,
50        _session_context: &SessionContext,
51        _client: &(dyn ClientInfo + Send + Sync),
52    ) -> Option<PgWireResult<LogicalPlan>> {
53        match stmt {
54            Statement::Set { .. } => {
55                let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
56                let result = show_schema
57                    .to_dfschema()
58                    .map(|df_schema| {
59                        LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
60                            produce_one_row: true,
61                            schema: Arc::new(df_schema),
62                        })
63                    })
64                    .map_err(|e| PgWireError::ApiError(Box::new(e)));
65                Some(result)
66            }
67            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
68                let show_schema =
69                    Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
70                let result = show_schema
71                    .to_dfschema()
72                    .map(|df_schema| {
73                        LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
74                            produce_one_row: true,
75                            schema: Arc::new(df_schema),
76                        })
77                    })
78                    .map_err(|e| PgWireError::ApiError(Box::new(e)));
79                Some(result)
80            }
81            _ => None,
82        }
83    }
84
85    async fn handle_extended_query(
86        &self,
87        statement: &Statement,
88        _logical_plan: &LogicalPlan,
89        _params: &ParamValues,
90        session_context: &SessionContext,
91        client: &mut dyn HookClient,
92    ) -> Option<PgWireResult<Response>> {
93        match statement {
94            Statement::Set { .. } => {
95                try_respond_set_statements(client, statement, session_context).await
96            }
97            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
98                try_respond_show_statements(client, statement, session_context).await
99            }
100            _ => None,
101        }
102    }
103}
104
105fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
106    let fields = vec![FieldInfo::new(
107        name.to_string(),
108        None,
109        None,
110        Type::VARCHAR,
111        FieldFormat::Text,
112    )];
113
114    let row = {
115        let mut encoder = DataRowEncoder::new(Arc::new(fields.clone()));
116        encoder.encode_field(&Some(value))?;
117        Ok(encoder.take_row())
118    };
119
120    let row_stream = futures::stream::once(async move { row });
121    Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
122}
123
124async fn try_respond_set_statements(
125    client: &mut dyn HookClient,
126    statement: &Statement,
127    session_context: &SessionContext,
128) -> Option<PgWireResult<Response>> {
129    let Statement::Set(set_statement) = statement else {
130        return None;
131    };
132
133    match &set_statement {
134        Set::SingleAssignment {
135            scope: None,
136            hivevar: false,
137            variable,
138            values,
139        } => {
140            let var = variable.to_string().to_lowercase();
141            if var == "statement_timeout" {
142                let value = values[0].to_string();
143                let timeout_str = value.trim_matches('"').trim_matches('\'');
144
145                let timeout = if timeout_str == "0" || timeout_str.is_empty() {
146                    None
147                } else {
148                    // Parse timeout value (supports ms, s, min formats)
149                    let timeout_ms = if timeout_str.ends_with("ms") {
150                        timeout_str.trim_end_matches("ms").parse::<u64>()
151                    } else if timeout_str.ends_with("s") {
152                        timeout_str
153                            .trim_end_matches("s")
154                            .parse::<u64>()
155                            .map(|s| s * 1000)
156                    } else if timeout_str.ends_with("min") {
157                        timeout_str
158                            .trim_end_matches("min")
159                            .parse::<u64>()
160                            .map(|m| m * 60 * 1000)
161                    } else {
162                        // Default to milliseconds
163                        timeout_str.parse::<u64>()
164                    };
165
166                    match timeout_ms {
167                        Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
168                        _ => None,
169                    }
170                };
171
172                client::set_statement_timeout(client, timeout);
173                return Some(Ok(Response::Execution(Tag::new("SET"))));
174            } else if matches!(
175                var.as_str(),
176                "datestyle"
177                    | "bytea_output"
178                    | "intervalstyle"
179                    | "application_name"
180                    | "extra_float_digits"
181                    | "search_path"
182            ) && !values.is_empty()
183            {
184                // postgres configuration variables
185                let value = values[0].clone();
186                if let Expr::Value(value) = value {
187                    let val_str = value.into_string().unwrap_or_else(|| "".to_string());
188                    client.metadata_mut().insert(var.clone(), val_str);
189                    if let Some((name, value)) = parameter_status_for_var(&var, &*client) {
190                        if let Err(e) = client
191                            .send_message(PgWireBackendMessage::ParameterStatus(
192                                ParameterStatus::new(name, value),
193                            ))
194                            .await
195                        {
196                            return Some(Err(e));
197                        }
198                    }
199                    return Some(Ok(Response::Execution(Tag::new("SET"))));
200                }
201            }
202        }
203        Set::SetTimeZone {
204            local: false,
205            value,
206        } => {
207            let tz = value.to_string();
208            let tz = tz.trim_matches('"').trim_matches('\'');
209            client::set_timezone(client, Some(tz));
210            // execution options for timezone
211            session_context
212                .state()
213                .config_mut()
214                .options_mut()
215                .execution
216                .time_zone = Some(tz.to_string());
217            let tz_value = client::get_timezone(client).unwrap_or("UTC").to_string();
218            if let Err(e) = client
219                .send_message(PgWireBackendMessage::ParameterStatus(ParameterStatus::new(
220                    "TimeZone".to_string(),
221                    tz_value,
222                )))
223                .await
224            {
225                return Some(Err(e));
226            }
227            return Some(Ok(Response::Execution(Tag::new("SET"))));
228        }
229        _ => {}
230    }
231
232    // fallback to datafusion and ignore all errors
233    if let Err(e) = execute_set_statement(session_context, statement.clone()).await {
234        warn!(
235            "SET statement {statement} is not supported by datafusion, error {e}, statement ignored",
236        );
237    }
238
239    // Always return SET success
240    Some(Ok(Response::Execution(Tag::new("SET"))))
241}
242
243fn parameter_status_for_var(
244    var: &str,
245    client: &(impl ClientInfo + ?Sized),
246) -> Option<(String, String)> {
247    let display_name = match var {
248        "datestyle" => "DateStyle",
249        "intervalstyle" => "IntervalStyle",
250        "bytea_output" => "bytea_output",
251        "application_name" => "application_name",
252        "extra_float_digits" => "extra_float_digits",
253        "search_path" => "search_path",
254        _ => return None,
255    };
256    let value = client.metadata().get(var)?.clone();
257    Some((display_name.to_string(), value))
258}
259
260async fn execute_set_statement(
261    session_context: &SessionContext,
262    statement: Statement,
263) -> Result<(), DataFusionError> {
264    let state = session_context.state();
265    let logical_plan = state
266        .statement_to_plan(datafusion::sql::parser::Statement::Statement(Box::new(
267            statement,
268        )))
269        .await
270        .and_then(|logical_plan| state.optimize(&logical_plan))?;
271
272    session_context
273        .execute_logical_plan(logical_plan)
274        .await
275        .map(|_| ())
276}
277
278async fn try_respond_show_statements(
279    client: &dyn HookClient,
280    statement: &Statement,
281    session_context: &SessionContext,
282) -> Option<PgWireResult<Response>> {
283    let Statement::ShowVariable { variable } = statement else {
284        return None;
285    };
286
287    let variables = variable
288        .iter()
289        .map(|v| v.value.to_lowercase())
290        .collect::<Vec<_>>();
291    let variables_ref = variables.iter().map(|s| s.as_str()).collect::<Vec<_>>();
292
293    match variables_ref.as_slice() {
294        ["time", "zone"] => {
295            let timezone = client::get_timezone(client).unwrap_or("UTC");
296            Some(mock_show_response("TimeZone", timezone).map(Response::Query))
297        }
298        ["server_version"] => {
299            let version = format!(
300                "datafusion {} on {} {}",
301                session_context.state().version(),
302                env!("CARGO_PKG_NAME"),
303                env!("CARGO_PKG_VERSION")
304            );
305            Some(mock_show_response("server_version", &version).map(Response::Query))
306        }
307        ["transaction_isolation"] => Some(
308            mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query),
309        ),
310        ["catalogs"] => {
311            let catalogs = session_context.catalog_names();
312            let value = catalogs.join(", ");
313            Some(mock_show_response("Catalogs", &value).map(Response::Query))
314        }
315        ["statement_timeout"] => {
316            let timeout = client::get_statement_timeout(client);
317            let timeout_str = match timeout {
318                Some(duration) => format!("{}ms", duration.as_millis()),
319                None => "0".to_string(),
320            };
321            Some(mock_show_response("statement_timeout", &timeout_str).map(Response::Query))
322        }
323        ["transaction", "isolation", "level"] => {
324            Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
325        }
326        _ => {
327            let val = client
328                .metadata()
329                .get(&variables[0])
330                .map(|v| v.to_string())
331                .or_else(|| match variables[0].as_str() {
332                    "bytea_output" => Some(FormatOptions::default().bytea_output),
333                    "datestyle" => Some(FormatOptions::default().date_style),
334                    "intervalstyle" => Some(FormatOptions::default().interval_style),
335                    "extra_float_digits" => {
336                        Some(FormatOptions::default().extra_float_digits.to_string())
337                    }
338                    "application_name" => Some(
339                        DefaultServerParameterProvider::default()
340                            .application_name
341                            .unwrap_or("".to_owned()),
342                    ),
343                    "search_path" => Some(DefaultServerParameterProvider::default().search_path),
344                    _ => None,
345                });
346            if let Some(val) = val {
347                Some(mock_show_response(&variables[0], &val).map(Response::Query))
348            } else {
349                info!("Unsupported show statement: {statement}");
350                Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
351            }
352        }
353    }
354}
355
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    use super::*;
    use crate::testing::MockClient;

    /// Parses exactly one PostgreSQL statement, panicking on invalid SQL.
    fn parse_sql(sql: &str) -> Statement {
        Parser::new(&PostgreSqlDialect {})
            .try_with_sql(sql)
            .unwrap()
            .parse_statement()
            .unwrap()
    }

    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Test setting timeout to 5000ms
        let statement = parse_sql("set statement_timeout to '5000ms'");
        let set_response =
            try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(set_response.is_some());
        assert!(set_response.unwrap().is_ok());

        // Verify the timeout was stored on the client
        let timeout = client::get_statement_timeout(&client);
        assert_eq!(timeout, Some(Duration::from_millis(5000)));

        // Test SHOW statement_timeout
        let statement = parse_sql("show statement_timeout");
        let show_response =
            try_respond_show_statements(&client, &statement, &session_context).await;
        assert!(show_response.is_some());
        assert!(show_response.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_bytea_output_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Test setting bytea_output to hex
        let statement = parse_sql("set bytea_output = 'hex'");
        let set_response =
            try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(set_response.is_some());
        assert!(set_response.unwrap().is_ok());

        // Verify the value was set in client metadata
        let bytea_output = client.metadata().get("bytea_output").unwrap();
        assert_eq!(bytea_output, "hex");

        // Test SHOW bytea_output
        let statement = parse_sql("show bytea_output");
        let show_response =
            try_respond_show_statements(&client, &statement, &session_context).await;
        assert!(show_response.is_some());
        assert!(show_response.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_date_style_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Test setting dateStyle
        let statement = parse_sql("set dateStyle = 'ISO, DMY'");
        let set_response =
            try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(set_response.is_some());
        assert!(set_response.unwrap().is_ok());

        // Verify the value was set in client metadata (key is lower-cased)
        let date_style = client.metadata().get("datestyle").unwrap();
        assert_eq!(date_style, "ISO, DMY");

        // Test SHOW dateStyle
        let statement = parse_sql("show dateStyle");
        let show_response =
            try_respond_show_statements(&client, &statement, &session_context).await;
        assert!(show_response.is_some());
        assert!(show_response.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Set timeout first
        let statement = parse_sql("set statement_timeout to '1000ms'");
        let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(resp.is_some());
        assert!(resp.unwrap().is_ok());

        // Disable timeout with 0
        let statement = parse_sql("set statement_timeout to '0'");
        let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(resp.is_some());
        assert!(resp.unwrap().is_ok());

        let timeout = client::get_statement_timeout(&client);
        assert_eq!(timeout, None);
    }

    #[tokio::test]
    async fn test_parameter_status_sent_for_all_set_vars() {
        let test_cases = [
            ("set bytea_output = 'escape'", "bytea_output", "escape"),
            ("set intervalstyle = 'postgres'", "IntervalStyle", "postgres"),
            ("set application_name = 'myapp'", "application_name", "myapp"),
            ("set search_path = 'public'", "search_path", "public"),
            ("set extra_float_digits = '2'", "extra_float_digits", "2"),
            ("set datestyle = 'ISO, MDY'", "DateStyle", "ISO, MDY"),
            (
                "set time zone 'America/New_York'",
                "TimeZone",
                "America/New_York",
            ),
        ];

        for (sql, expected_key, expected_value) in test_cases {
            let session_context = SessionContext::new();
            let mut client = MockClient::new();
            let statement = parse_sql(sql);

            let result =
                try_respond_set_statements(&mut client, &statement, &session_context).await;
            assert!(result.is_some(), "Expected Some for {sql}");
            assert!(result.unwrap().is_ok(), "Expected Ok for {sql}");

            let ps_msgs: Vec<_> = client
                .sent_messages()
                .iter()
                .filter_map(|m| match m {
                    PgWireBackendMessage::ParameterStatus(ps) => Some(ps),
                    _ => None,
                })
                .collect();

            assert_eq!(ps_msgs.len(), 1, "Expected 1 ParameterStatus for {sql}");
            assert_eq!(ps_msgs[0].name, expected_key, "Wrong key for {sql}");
            assert_eq!(ps_msgs[0].value, expected_value, "Wrong value for {sql}");
        }
    }

    #[tokio::test]
    async fn test_no_parameter_status_for_statement_timeout() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        let statement = parse_sql("set statement_timeout to '5000ms'");
        let result = try_respond_set_statements(&mut client, &statement, &session_context).await;
        assert!(result.is_some());
        assert!(result.unwrap().is_ok());

        let has_ps = client
            .sent_messages()
            .iter()
            .any(|m| matches!(m, PgWireBackendMessage::ParameterStatus(_)));

        assert!(!has_ps, "statement_timeout should not send ParameterStatus");
    }

    #[tokio::test]
    async fn test_supported_show_statements_returned_columns() {
        let session_context = SessionContext::new();
        let client = MockClient::new();

        let tests = [
            ("show time zone", "TimeZone"),
            ("show server_version", "server_version"),
            ("show transaction_isolation", "transaction_isolation"),
            ("show catalogs", "Catalogs"),
            ("show search_path", "search_path"),
            ("show statement_timeout", "statement_timeout"),
            ("show transaction isolation level", "transaction_isolation"),
        ];

        for (query, expected_response_col) in tests {
            let statement = parse_sql(query);
            let show_response =
                try_respond_show_statements(&client, &statement, &session_context).await;

            let Some(Ok(Response::Query(show_response))) = show_response else {
                panic!("unexpected show response");
            };

            assert_eq!(show_response.command_tag(), "SELECT");

            let row_schema = show_response.row_schema();
            assert_eq!(row_schema.len(), 1);
            assert_eq!(row_schema[0].name(), expected_response_col);
        }
    }
}