datafusion_postgres/hooks/set_show.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::{DataType, Field, Schema};
5use datafusion::common::{ParamValues, ToDFSchema};
6use datafusion::logical_expr::LogicalPlan;
7use datafusion::prelude::SessionContext;
8use datafusion::sql::sqlparser::ast::{Set, Statement};
9use log::{info, warn};
10use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
11use pgwire::api::ClientInfo;
12use pgwire::error::{PgWireError, PgWireResult};
13use postgres_types::Type;
14
15use crate::client;
16use crate::QueryHook;
17
18#[derive(Debug)]
19pub struct SetShowHook;
20
21#[async_trait]
22impl QueryHook for SetShowHook {
23    /// called in simple query handler to return response directly
24    async fn handle_simple_query(
25        &self,
26        statement: &Statement,
27        session_context: &SessionContext,
28        client: &mut (dyn ClientInfo + Send + Sync),
29    ) -> Option<PgWireResult<Response>> {
30        match statement {
31            Statement::Set { .. } => {
32                try_respond_set_statements(client, statement, session_context).await
33            }
34            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
35                try_respond_show_statements(client, statement, session_context).await
36            }
37            _ => None,
38        }
39    }
40
41    async fn handle_extended_parse_query(
42        &self,
43        stmt: &Statement,
44        _session_context: &SessionContext,
45        _client: &(dyn ClientInfo + Send + Sync),
46    ) -> Option<PgWireResult<LogicalPlan>> {
47        let sql_lower = stmt.to_string().to_lowercase();
48        let sql_trimmed = sql_lower.trim();
49
50        if sql_trimmed.starts_with("show") {
51            let show_schema =
52                Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
53            let result = show_schema
54                .to_dfschema()
55                .map(|df_schema| {
56                    LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
57                        produce_one_row: true,
58                        schema: Arc::new(df_schema),
59                    })
60                })
61                .map_err(|e| PgWireError::ApiError(Box::new(e)));
62            Some(result)
63        } else if sql_trimmed.starts_with("set") {
64            let show_schema = Arc::new(Schema::new(Vec::<Field>::new()));
65            let result = show_schema
66                .to_dfschema()
67                .map(|df_schema| {
68                    LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
69                        produce_one_row: true,
70                        schema: Arc::new(df_schema),
71                    })
72                })
73                .map_err(|e| PgWireError::ApiError(Box::new(e)));
74            Some(result)
75        } else {
76            None
77        }
78    }
79
80    async fn handle_extended_query(
81        &self,
82        statement: &Statement,
83        _logical_plan: &LogicalPlan,
84        _params: &ParamValues,
85        session_context: &SessionContext,
86        client: &mut (dyn ClientInfo + Send + Sync),
87    ) -> Option<PgWireResult<Response>> {
88        match statement {
89            Statement::Set { .. } => {
90                try_respond_set_statements(client, statement, session_context).await
91            }
92            Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
93                try_respond_show_statements(client, statement, session_context).await
94            }
95            _ => None,
96        }
97    }
98}
99
100fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
101    let fields = vec![FieldInfo::new(
102        name.to_string(),
103        None,
104        None,
105        Type::VARCHAR,
106        FieldFormat::Text,
107    )];
108
109    let row = {
110        let mut encoder = DataRowEncoder::new(Arc::new(fields.clone()));
111        encoder.encode_field(&Some(value))?;
112        encoder.finish()
113    };
114
115    let row_stream = futures::stream::once(async move { row });
116    Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
117}
118
119async fn try_respond_set_statements<C>(
120    client: &mut C,
121    statement: &Statement,
122    session_context: &SessionContext,
123) -> Option<PgWireResult<Response>>
124where
125    C: ClientInfo + Send + Sync + ?Sized,
126{
127    let Statement::Set(set_statement) = statement else {
128        return None;
129    };
130
131    match &set_statement {
132        Set::SingleAssignment {
133            scope: None,
134            hivevar: false,
135            variable,
136            values,
137        } if &variable.to_string() == "statement_timeout" => {
138            let value = values[0].to_string();
139            let timeout_str = value.trim_matches('"').trim_matches('\'');
140
141            let timeout = if timeout_str == "0" || timeout_str.is_empty() {
142                None
143            } else {
144                // Parse timeout value (supports ms, s, min formats)
145                let timeout_ms = if timeout_str.ends_with("ms") {
146                    timeout_str.trim_end_matches("ms").parse::<u64>()
147                } else if timeout_str.ends_with("s") {
148                    timeout_str
149                        .trim_end_matches("s")
150                        .parse::<u64>()
151                        .map(|s| s * 1000)
152                } else if timeout_str.ends_with("min") {
153                    timeout_str
154                        .trim_end_matches("min")
155                        .parse::<u64>()
156                        .map(|m| m * 60 * 1000)
157                } else {
158                    // Default to milliseconds
159                    timeout_str.parse::<u64>()
160                };
161
162                match timeout_ms {
163                    Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
164                    _ => None,
165                }
166            };
167
168            client::set_statement_timeout(client, timeout);
169            Some(Ok(Response::Execution(Tag::new("SET"))))
170        }
171        Set::SetTimeZone {
172            local: false,
173            value,
174        } => {
175            let tz = value.to_string();
176            let tz = tz.trim_matches('"').trim_matches('\'');
177            client::set_timezone(client, Some(tz));
178            Some(Ok(Response::Execution(Tag::new("SET"))))
179        }
180        _ => {
181            // pass SET query to datafusion
182            let query = statement.to_string();
183            if let Err(e) = session_context.sql(&query).await {
184                warn!("SET statement {query} is not supported by datafusion, error {e}, statement ignored");
185            }
186
187            // Always return SET success
188            Some(Ok(Response::Execution(Tag::new("SET"))))
189        }
190    }
191}
192
193async fn try_respond_show_statements<C>(
194    client: &C,
195    statement: &Statement,
196    session_context: &SessionContext,
197) -> Option<PgWireResult<Response>>
198where
199    C: ClientInfo + ?Sized,
200{
201    let Statement::ShowVariable { variable } = statement else {
202        return None;
203    };
204
205    let variables = variable
206        .iter()
207        .map(|v| &v.value as &str)
208        .collect::<Vec<_>>();
209
210    match &variables as &[&str] {
211        ["time", "zone"] => {
212            let timezone = client::get_timezone(client).unwrap_or("UTC");
213            Some(mock_show_response("TimeZone", timezone).map(Response::Query))
214        }
215        ["server_version"] => {
216            Some(mock_show_response("server_version", "15.0 (DataFusion)").map(Response::Query))
217        }
218        ["transaction_isolation"] => Some(
219            mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query),
220        ),
221        ["catalogs"] => {
222            let catalogs = session_context.catalog_names();
223            let value = catalogs.join(", ");
224            Some(mock_show_response("Catalogs", &value).map(Response::Query))
225        }
226        ["search_path"] => {
227            let default_schema = "public";
228            Some(mock_show_response("search_path", default_schema).map(Response::Query))
229        }
230        ["statement_timeout"] => {
231            let timeout = client::get_statement_timeout(client);
232            let timeout_str = match timeout {
233                Some(duration) => format!("{}ms", duration.as_millis()),
234                None => "0".to_string(),
235            };
236            Some(mock_show_response("statement_timeout", &timeout_str).map(Response::Query))
237        }
238        ["transaction", "isolation", "level"] => {
239            Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
240        }
241        _ => {
242            info!("Unsupported show statement: {}", statement);
243            Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
244        }
245    }
246}
247
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

    use super::*;
    use crate::testing::MockClient;

    /// Parses exactly one PostgreSQL statement; panics on malformed test SQL.
    fn parse_sql(sql: &str) -> Statement {
        Parser::new(&PostgreSqlDialect {})
            .try_with_sql(sql)
            .unwrap()
            .parse_statement()
            .unwrap()
    }

    /// Runs `sql` through the SET handler and asserts it was handled successfully.
    async fn run_set(client: &mut MockClient, session_context: &SessionContext, sql: &str) {
        let statement = parse_sql(sql);
        let outcome = try_respond_set_statements(client, &statement, session_context).await;
        assert!(outcome.is_some());
        assert!(outcome.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_statement_timeout_set_and_show() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // A 5000ms timeout must end up in the client metadata.
        run_set(&mut client, &session_context, "set statement_timeout to '5000ms'").await;
        assert_eq!(
            client::get_statement_timeout(&client),
            Some(Duration::from_millis(5000))
        );

        // SHOW must answer for the parameter that was just set.
        let statement = parse_sql("show statement_timeout");
        let outcome = try_respond_show_statements(&client, &statement, &session_context).await;
        assert!(outcome.is_some());
        assert!(outcome.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_statement_timeout_disable() {
        let session_context = SessionContext::new();
        let mut client = MockClient::new();

        // Enable a timeout, then switch it off again with '0'.
        run_set(&mut client, &session_context, "set statement_timeout to '1000ms'").await;
        run_set(&mut client, &session_context, "set statement_timeout to '0'").await;

        assert_eq!(client::get_statement_timeout(&client), None);
    }

    #[tokio::test]
    async fn test_supported_show_statements_returned_columns() {
        let session_context = SessionContext::new();
        let client = MockClient::new();

        let cases = [
            ("show time zone", "TimeZone"),
            ("show server_version", "server_version"),
            ("show transaction_isolation", "transaction_isolation"),
            ("show catalogs", "Catalogs"),
            ("show search_path", "search_path"),
            ("show statement_timeout", "statement_timeout"),
            ("show transaction isolation level", "transaction_isolation"),
        ];

        for (sql, expected_column) in cases {
            let statement = parse_sql(sql);
            let outcome =
                try_respond_show_statements(&client, &statement, &session_context).await;

            let Some(Ok(Response::Query(query_response))) = outcome else {
                panic!("unexpected show response");
            };

            assert_eq!(query_response.command_tag(), "SELECT");

            let schema = query_response.row_schema();
            assert_eq!(schema.len(), 1);
            assert_eq!(schema[0].name(), expected_column);
        }
    }
}