datafusion_postgres/
handlers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::DataType;
6use datafusion::logical_expr::LogicalPlan;
7use datafusion::prelude::*;
8use pgwire::api::auth::noop::NoopStartupHandler;
9use pgwire::api::copy::NoopCopyHandler;
10use pgwire::api::portal::{Format, Portal};
11use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
12use pgwire::api::results::{
13    DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
14    Response, Tag,
15};
16use pgwire::api::stmt::QueryParser;
17use pgwire::api::stmt::StoredStatement;
18use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
19use pgwire::error::{PgWireError, PgWireResult};
20use tokio::sync::Mutex;
21
22use crate::datatypes;
23use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
24
25pub struct HandlerFactory(pub Arc<DfSessionService>);
26
27impl NoopStartupHandler for DfSessionService {}
28
29impl PgWireServerHandlers for HandlerFactory {
30    type StartupHandler = DfSessionService;
31    type SimpleQueryHandler = DfSessionService;
32    type ExtendedQueryHandler = DfSessionService;
33    type CopyHandler = NoopCopyHandler;
34    type ErrorHandler = NoopErrorHandler;
35
36    fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
37        self.0.clone()
38    }
39
40    fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
41        self.0.clone()
42    }
43
44    fn startup_handler(&self) -> Arc<Self::StartupHandler> {
45        self.0.clone()
46    }
47
48    fn copy_handler(&self) -> Arc<Self::CopyHandler> {
49        Arc::new(NoopCopyHandler)
50    }
51
52    fn error_handler(&self) -> Arc<Self::ErrorHandler> {
53        Arc::new(NoopErrorHandler)
54    }
55}
56
57pub struct DfSessionService {
58    session_context: Arc<SessionContext>,
59    parser: Arc<Parser>,
60    timezone: Arc<Mutex<String>>,
61}
62
63impl DfSessionService {
64    pub fn new(session_context: Arc<SessionContext>) -> DfSessionService {
65        let parser = Arc::new(Parser {
66            session_context: session_context.clone(),
67        });
68        DfSessionService {
69            session_context,
70            parser,
71            timezone: Arc::new(Mutex::new("UTC".to_string())),
72        }
73    }
74
75    fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult<QueryResponse<'a>> {
76        let fields = vec![FieldInfo::new(
77            name.to_string(),
78            None,
79            None,
80            Type::VARCHAR,
81            FieldFormat::Text,
82        )];
83
84        let row = {
85            let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
86            encoder.encode_field(&Some(value))?;
87            encoder.finish()
88        };
89
90        let row_stream = futures::stream::once(async move { row });
91        Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
92    }
93
94    async fn try_respond_set_statements<'a>(
95        &self,
96        query_lower: &str,
97    ) -> PgWireResult<Option<Response<'a>>> {
98        if query_lower.starts_with("set") {
99            if query_lower.starts_with("set time zone") {
100                let parts: Vec<&str> = query_lower.split_whitespace().collect();
101                if parts.len() >= 4 {
102                    let tz = parts[3].trim_matches('"');
103                    let mut timezone = self.timezone.lock().await;
104                    *timezone = tz.to_string();
105                    Ok(Some(Response::Execution(Tag::new("SET"))))
106                } else {
107                    Err(PgWireError::UserError(Box::new(
108                        pgwire::error::ErrorInfo::new(
109                            "ERROR".to_string(),
110                            "42601".to_string(),
111                            "Invalid SET TIME ZONE syntax".to_string(),
112                        ),
113                    )))
114                }
115            } else {
116                // noop: skip any unsupported set statements
117                Ok(Some(Response::Execution(Tag::new("SET"))))
118            }
119        } else {
120            Ok(None)
121        }
122    }
123
124    async fn try_respond_show_statements<'a>(
125        &self,
126        query_lower: &str,
127    ) -> PgWireResult<Option<Response<'a>>> {
128        if query_lower.starts_with("show ") {
129            match query_lower.strip_suffix(";").unwrap_or(query_lower) {
130                "show time zone" => {
131                    let timezone = self.timezone.lock().await.clone();
132                    let resp = Self::mock_show_response("TimeZone", &timezone)?;
133                    Ok(Some(Response::Query(resp)))
134                }
135                "show server_version" => {
136                    let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
137                    Ok(Some(Response::Query(resp)))
138                }
139                "show transaction_isolation" => {
140                    let resp =
141                        Self::mock_show_response("transaction_isolation", "read uncommitted")?;
142                    Ok(Some(Response::Query(resp)))
143                }
144                "show catalogs" => {
145                    let catalogs = self.session_context.catalog_names();
146                    let value = catalogs.join(", ");
147                    let resp = Self::mock_show_response("Catalogs", &value)?;
148                    Ok(Some(Response::Query(resp)))
149                }
150                "show search_path" => {
151                    let default_catalog = "datafusion";
152                    let resp = Self::mock_show_response("search_path", default_catalog)?;
153                    Ok(Some(Response::Query(resp)))
154                }
155                _ => Err(PgWireError::UserError(Box::new(
156                    pgwire::error::ErrorInfo::new(
157                        "ERROR".to_string(),
158                        "42704".to_string(),
159                        format!("Unrecognized SHOW command: {}", query_lower),
160                    ),
161                ))),
162            }
163        } else {
164            Ok(None)
165        }
166    }
167}
168
169#[async_trait]
170impl SimpleQueryHandler for DfSessionService {
171    async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
172    where
173        C: ClientInfo + Unpin + Send + Sync,
174    {
175        let query_lower = query.to_lowercase().trim().to_string();
176        log::debug!("Received query: {}", query); // Log the query for debugging
177
178        if let Some(resp) = self.try_respond_set_statements(&query_lower).await? {
179            return Ok(vec![resp]);
180        }
181
182        if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
183            return Ok(vec![resp]);
184        }
185
186        let df = self
187            .session_context
188            .sql(query)
189            .await
190            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
191
192        if query_lower.starts_with("insert into") {
193            // For INSERT queries, we need to execute the query to get the row count
194            // and return an Execution response with the proper tag
195            let result = df
196                .clone()
197                .collect()
198                .await
199                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
200
201            // Extract count field from the first batch
202            let rows_affected = result
203                .first()
204                .and_then(|batch| batch.column_by_name("count"))
205                .and_then(|col| {
206                    col.as_any()
207                        .downcast_ref::<datafusion::arrow::array::UInt64Array>()
208                })
209                .map_or(0, |array| array.value(0) as usize);
210
211            // Create INSERT tag with the affected row count
212            let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
213            Ok(vec![Response::Execution(tag)])
214        } else {
215            // For non-INSERT queries, return a regular Query response
216            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
217            Ok(vec![Response::Query(resp)])
218        }
219    }
220}
221
222#[async_trait]
223impl ExtendedQueryHandler for DfSessionService {
224    type Statement = (String, LogicalPlan);
225    type QueryParser = Parser;
226
227    fn query_parser(&self) -> Arc<Self::QueryParser> {
228        self.parser.clone()
229    }
230
231    async fn do_describe_statement<C>(
232        &self,
233        _client: &mut C,
234        target: &StoredStatement<Self::Statement>,
235    ) -> PgWireResult<DescribeStatementResponse>
236    where
237        C: ClientInfo + Unpin + Send + Sync,
238    {
239        let (_, plan) = &target.statement;
240        let schema = plan.schema();
241        let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
242        let params = plan
243            .get_parameter_types()
244            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
245
246        let mut param_types = Vec::with_capacity(params.len());
247        for param_type in ordered_param_types(&params).iter() {
248            // Fixed: Use &params
249            if let Some(datatype) = param_type {
250                let pgtype = into_pg_type(datatype)?;
251                param_types.push(pgtype);
252            } else {
253                param_types.push(Type::UNKNOWN);
254            }
255        }
256
257        Ok(DescribeStatementResponse::new(param_types, fields))
258    }
259
260    async fn do_describe_portal<C>(
261        &self,
262        _client: &mut C,
263        target: &Portal<Self::Statement>,
264    ) -> PgWireResult<DescribePortalResponse>
265    where
266        C: ClientInfo + Unpin + Send + Sync,
267    {
268        let (_, plan) = &target.statement.statement;
269        let format = &target.result_column_format;
270        let schema = plan.schema();
271        let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
272
273        Ok(DescribePortalResponse::new(fields))
274    }
275
276    async fn do_query<'a, C>(
277        &self,
278        _client: &mut C,
279        portal: &Portal<Self::Statement>,
280        _max_rows: usize,
281    ) -> PgWireResult<Response<'a>>
282    where
283        C: ClientInfo + Unpin + Send + Sync,
284    {
285        let query = portal
286            .statement
287            .statement
288            .0
289            .to_lowercase()
290            .trim()
291            .to_string();
292        log::debug!("Received execute extended query: {}", query); // Log for debugging
293
294        if let Some(resp) = self.try_respond_set_statements(&query).await? {
295            return Ok(resp);
296        }
297
298        if let Some(resp) = self.try_respond_show_statements(&query).await? {
299            return Ok(resp);
300        }
301
302        let (_, plan) = &portal.statement.statement;
303
304        let param_types = plan
305            .get_parameter_types()
306            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
307        let param_values =
308            datatypes::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
309        let plan = plan
310            .clone()
311            .replace_params_with_values(&param_values)
312            .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use &param_values
313        let dataframe = self
314            .session_context
315            .execute_logical_plan(plan)
316            .await
317            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
318        let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
319        Ok(Response::Query(resp))
320    }
321}
322
323pub struct Parser {
324    session_context: Arc<SessionContext>,
325}
326
327#[async_trait]
328impl QueryParser for Parser {
329    type Statement = (String, LogicalPlan);
330
331    async fn parse_sql<C>(
332        &self,
333        _client: &C,
334        sql: &str,
335        _types: &[Type],
336    ) -> PgWireResult<Self::Statement> {
337        log::debug!("Received parse extended query: {}", sql); // Log for debugging
338        let context = &self.session_context;
339        let state = context.state();
340        let logical_plan = state
341            .create_logical_plan(sql)
342            .await
343            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
344        let optimised = state
345            .optimize(&logical_plan)
346            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
347        Ok((sql.to_string(), optimised))
348    }
349}
350
351fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
352    // Datafusion stores the parameters as a map.  In our case, the keys will be
353    // `$1`, `$2` etc.  The values will be the parameter types.
354    let mut types = types.iter().collect::<Vec<_>>();
355    types.sort_by(|a, b| a.0.cmp(b.0));
356    types.into_iter().map(|pt| pt.1.as_ref()).collect()
357}