datafusion_postgres/
handlers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::DataType;
6use datafusion::logical_expr::LogicalPlan;
7use datafusion::prelude::*;
8use pgwire::api::auth::noop::NoopStartupHandler;
9use pgwire::api::copy::NoopCopyHandler;
10use pgwire::api::portal::{Format, Portal};
11use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
12use pgwire::api::results::{
13    DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
14    Response, Tag,
15};
16use pgwire::api::stmt::QueryParser;
17use pgwire::api::stmt::StoredStatement;
18use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
19use tokio::sync::Mutex;
20
21use crate::datatypes;
22use crate::information_schema::{columns_df, schemata_df, tables_df};
23use pgwire::error::{PgWireError, PgWireResult};
24
25pub struct HandlerFactory(pub Arc<DfSessionService>);
26
27impl NoopStartupHandler for DfSessionService {}
28
29impl PgWireServerHandlers for HandlerFactory {
30    type StartupHandler = DfSessionService;
31    type SimpleQueryHandler = DfSessionService;
32    type ExtendedQueryHandler = DfSessionService;
33    type CopyHandler = NoopCopyHandler;
34    type ErrorHandler = NoopErrorHandler;
35
36    fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
37        self.0.clone()
38    }
39
40    fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
41        self.0.clone()
42    }
43
44    fn startup_handler(&self) -> Arc<Self::StartupHandler> {
45        self.0.clone()
46    }
47
48    fn copy_handler(&self) -> Arc<Self::CopyHandler> {
49        Arc::new(NoopCopyHandler)
50    }
51
52    fn error_handler(&self) -> Arc<Self::ErrorHandler> {
53        Arc::new(NoopErrorHandler)
54    }
55}
56
57pub struct DfSessionService {
58    session_context: Arc<SessionContext>,
59    parser: Arc<Parser>,
60    timezone: Arc<Mutex<String>>,
61    catalog_name: String,
62}
63
64impl DfSessionService {
65    pub fn new(session_context: SessionContext, catalog_name: Option<String>) -> DfSessionService {
66        let session_context = Arc::new(session_context);
67        let parser = Arc::new(Parser {
68            session_context: session_context.clone(),
69        });
70        let catalog_name = catalog_name.unwrap_or_else(|| {
71            session_context
72                .catalog_names()
73                .first()
74                .cloned()
75                .unwrap_or_else(|| "datafusion".to_string())
76        });
77        DfSessionService {
78            session_context,
79            parser,
80            timezone: Arc::new(Mutex::new("UTC".to_string())),
81            catalog_name,
82        }
83    }
84
85    fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult<QueryResponse<'a>> {
86        let fields = vec![FieldInfo::new(
87            name.to_string(),
88            None,
89            None,
90            Type::VARCHAR,
91            FieldFormat::Text,
92        )];
93
94        let row = {
95            let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
96            encoder.encode_field(&Some(value))?;
97            encoder.finish()
98        };
99
100        let row_stream = futures::stream::once(async move { row });
101        Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
102    }
103
104    // Mock pg_namespace response
105    async fn mock_pg_namespace<'a>(&self) -> PgWireResult<QueryResponse<'a>> {
106        let fields = vec![FieldInfo::new(
107            "nspname".to_string(),
108            None,
109            None,
110            Type::VARCHAR,
111            FieldFormat::Text,
112        )];
113
114        let row = {
115            let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
116            encoder.encode_field(&Some(&self.catalog_name))?; // Return catalog_name as a schema
117            encoder.finish()
118        };
119
120        let row_stream = futures::stream::once(async move { row });
121        Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
122    }
123
124    async fn try_respond_set_time_zone<'a>(
125        &self,
126        query_lower: &str,
127    ) -> PgWireResult<Option<Vec<Response<'a>>>> {
128        if query_lower.starts_with("set time zone") {
129            let parts: Vec<&str> = query_lower.split_whitespace().collect();
130            if parts.len() >= 4 {
131                let tz = parts[3].trim_matches('"');
132                let mut timezone = self.timezone.lock().await;
133                *timezone = tz.to_string();
134                Ok(Some(vec![Response::Execution(Tag::new("SET"))]))
135            } else {
136                Err(PgWireError::UserError(Box::new(
137                    pgwire::error::ErrorInfo::new(
138                        "ERROR".to_string(),
139                        "42601".to_string(),
140                        "Invalid SET TIME ZONE syntax".to_string(),
141                    ),
142                )))
143            }
144        } else {
145            Ok(None)
146        }
147    }
148
149    async fn try_respond_show_statements<'a>(
150        &self,
151        query_lower: &str,
152    ) -> PgWireResult<Option<Vec<Response<'a>>>> {
153        if query_lower.starts_with("show ") {
154            match query_lower {
155                "show time zone" => {
156                    let timezone = self.timezone.lock().await.clone();
157                    let resp = Self::mock_show_response("TimeZone", &timezone)?;
158                    Ok(Some(vec![Response::Query(resp)]))
159                }
160                "show server_version" => {
161                    let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
162                    Ok(Some(vec![Response::Query(resp)]))
163                }
164                "show transaction_isolation" => {
165                    let resp =
166                        Self::mock_show_response("transaction_isolation", "read uncommitted")?;
167                    Ok(Some(vec![Response::Query(resp)]))
168                }
169                "show catalogs" => {
170                    let catalogs = self.session_context.catalog_names();
171                    let value = catalogs.join(", ");
172                    let resp = Self::mock_show_response("Catalogs", &value)?;
173                    Ok(Some(vec![Response::Query(resp)]))
174                }
175                "show search_path" => {
176                    let resp = Self::mock_show_response("search_path", &self.catalog_name)?;
177                    Ok(Some(vec![Response::Query(resp)]))
178                }
179                _ => Err(PgWireError::UserError(Box::new(
180                    pgwire::error::ErrorInfo::new(
181                        "ERROR".to_string(),
182                        "42704".to_string(),
183                        format!("Unrecognized SHOW command: {}", query_lower),
184                    ),
185                ))),
186            }
187        } else {
188            Ok(None)
189        }
190    }
191
192    async fn try_respond_information_schema<'a>(
193        &self,
194        query_lower: &str,
195    ) -> PgWireResult<Option<Vec<Response<'a>>>> {
196        if query_lower.contains("information_schema.schemata") {
197            let df = schemata_df(&self.session_context)
198                .await
199                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
200            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
201            return Ok(Some(vec![Response::Query(resp)]));
202        } else if query_lower.contains("information_schema.tables") {
203            let df = tables_df(&self.session_context)
204                .await
205                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
206            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
207            return Ok(Some(vec![Response::Query(resp)]));
208        } else if query_lower.contains("information_schema.columns") {
209            let df = columns_df(&self.session_context)
210                .await
211                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
212            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
213            return Ok(Some(vec![Response::Query(resp)]));
214        }
215
216        // Handle pg_catalog.pg_namespace for pgcli compatibility
217        if query_lower.contains("pg_catalog.pg_namespace") {
218            let resp = self.mock_pg_namespace().await?;
219            return Ok(Some(vec![Response::Query(resp)]));
220        }
221
222        Ok(None)
223    }
224}
225
226#[async_trait]
227impl SimpleQueryHandler for DfSessionService {
228    async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
229    where
230        C: ClientInfo + Unpin + Send + Sync,
231    {
232        let query_lower = query.to_lowercase().trim().to_string();
233        log::debug!("Received query: {}", query); // Log the query for debugging
234
235        if let Some(resp) = self.try_respond_set_time_zone(&query_lower).await? {
236            return Ok(resp);
237        }
238
239        if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
240            return Ok(resp);
241        }
242
243        if let Some(resp) = self.try_respond_information_schema(&query_lower).await? {
244            return Ok(resp);
245        }
246
247        let df = self
248            .session_context
249            .sql(query)
250            .await
251            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
252
253        if query_lower.starts_with("insert into") {
254            // For INSERT queries, we need to execute the query to get the row count
255            // and return an Execution response with the proper tag
256            let result = df
257                .clone()
258                .collect()
259                .await
260                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
261
262            // Extract count field from the first batch
263            let rows_affected = result
264                .first()
265                .and_then(|batch| batch.column_by_name("count"))
266                .and_then(|col| {
267                    col.as_any()
268                        .downcast_ref::<datafusion::arrow::array::UInt64Array>()
269                })
270                .map_or(0, |array| array.value(0) as usize);
271
272            // Create INSERT tag with the affected row count
273            let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
274            Ok(vec![Response::Execution(tag)])
275        } else {
276            // For non-INSERT queries, return a regular Query response
277            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
278            Ok(vec![Response::Query(resp)])
279        }
280    }
281}
282
283#[async_trait]
284impl ExtendedQueryHandler for DfSessionService {
285    type Statement = LogicalPlan;
286    type QueryParser = Parser;
287
288    fn query_parser(&self) -> Arc<Self::QueryParser> {
289        self.parser.clone()
290    }
291
292    async fn do_describe_statement<C>(
293        &self,
294        _client: &mut C,
295        target: &StoredStatement<Self::Statement>,
296    ) -> PgWireResult<DescribeStatementResponse>
297    where
298        C: ClientInfo + Unpin + Send + Sync,
299    {
300        let plan = &target.statement;
301        let schema = plan.schema();
302        let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?;
303        let params = plan
304            .get_parameter_types()
305            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
306
307        let mut param_types = Vec::with_capacity(params.len());
308        for param_type in ordered_param_types(&params).iter() {
309            // Fixed: Use &params
310            if let Some(datatype) = param_type {
311                let pgtype = datatypes::into_pg_type(datatype)?;
312                param_types.push(pgtype);
313            } else {
314                param_types.push(Type::UNKNOWN);
315            }
316        }
317
318        Ok(DescribeStatementResponse::new(param_types, fields))
319    }
320
321    async fn do_describe_portal<C>(
322        &self,
323        _client: &mut C,
324        target: &Portal<Self::Statement>,
325    ) -> PgWireResult<DescribePortalResponse>
326    where
327        C: ClientInfo + Unpin + Send + Sync,
328    {
329        let plan = &target.statement.statement;
330        let format = &target.result_column_format;
331        let schema = plan.schema();
332        let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?;
333
334        Ok(DescribePortalResponse::new(fields))
335    }
336
337    async fn do_query<'a, C>(
338        &self,
339        _client: &mut C,
340        portal: &Portal<Self::Statement>,
341        _max_rows: usize,
342    ) -> PgWireResult<Response<'a>>
343    where
344        C: ClientInfo + Unpin + Send + Sync,
345    {
346        let query = portal
347            .statement
348            .statement
349            .to_string()
350            .to_lowercase()
351            .trim()
352            .to_string();
353        log::debug!("Received extended query: {}", query); // Log for debugging
354
355        if query.starts_with("show ") {
356            match query.as_str() {
357                "show time zone" => {
358                    let timezone = self.timezone.lock().await.clone();
359                    let resp = Self::mock_show_response("TimeZone", &timezone)?;
360                    return Ok(Response::Query(resp));
361                }
362                "show server_version" => {
363                    let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
364                    return Ok(Response::Query(resp));
365                }
366                "show transaction_isolation" => {
367                    let resp =
368                        Self::mock_show_response("transaction_isolation", "read uncommitted")?;
369                    return Ok(Response::Query(resp));
370                }
371                "show catalogs" => {
372                    let catalogs = self.session_context.catalog_names();
373                    let value = catalogs.join(", ");
374                    let resp = Self::mock_show_response("Catalogs", &value)?;
375                    return Ok(Response::Query(resp));
376                }
377                "show search_path" => {
378                    let resp = Self::mock_show_response("search_path", &self.catalog_name)?;
379                    return Ok(Response::Query(resp));
380                }
381                _ => {
382                    return Err(PgWireError::UserError(Box::new(
383                        pgwire::error::ErrorInfo::new(
384                            "ERROR".to_string(),
385                            "42704".to_string(),
386                            format!("Unrecognized SHOW command: {}", query),
387                        ),
388                    )));
389                }
390            }
391        }
392
393        if query.contains("information_schema.schemata") {
394            let df = schemata_df(&self.session_context)
395                .await
396                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
397            let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?;
398            return Ok(Response::Query(resp));
399        } else if query.contains("information_schema.tables") {
400            let df = tables_df(&self.session_context)
401                .await
402                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
403            let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?;
404            return Ok(Response::Query(resp));
405        } else if query.contains("information_schema.columns") {
406            let df = columns_df(&self.session_context)
407                .await
408                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
409            let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?;
410            return Ok(Response::Query(resp));
411        }
412
413        if query.contains("pg_catalog.pg_namespace") {
414            let resp = self.mock_pg_namespace().await?;
415            return Ok(Response::Query(resp));
416        }
417
418        let plan = &portal.statement.statement;
419        let param_types = plan
420            .get_parameter_types()
421            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
422        let param_values =
423            datatypes::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
424        let plan = plan
425            .clone()
426            .replace_params_with_values(&param_values)
427            .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use &param_values
428        let dataframe = self
429            .session_context
430            .execute_logical_plan(plan)
431            .await
432            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
433        let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
434        Ok(Response::Query(resp))
435    }
436}
437
438pub struct Parser {
439    session_context: Arc<SessionContext>,
440}
441
442#[async_trait]
443impl QueryParser for Parser {
444    type Statement = LogicalPlan;
445
446    async fn parse_sql<C>(
447        &self,
448        _client: &C,
449        sql: &str,
450        _types: &[Type],
451    ) -> PgWireResult<Self::Statement> {
452        let context = &self.session_context;
453        let state = context.state();
454        let logical_plan = state
455            .create_logical_plan(sql)
456            .await
457            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
458        let optimised = state
459            .optimize(&logical_plan)
460            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
461        Ok(optimised)
462    }
463}
464
465fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
466    // Datafusion stores the parameters as a map.  In our case, the keys will be
467    // `$1`, `$2` etc.  The values will be the parameter types.
468    let mut types = types.iter().collect::<Vec<_>>();
469    types.sort_by(|a, b| a.0.cmp(b.0));
470    types.into_iter().map(|pt| pt.1.as_ref()).collect()
471}