datafusion_postgres/
handlers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::DataType;
6use datafusion::logical_expr::LogicalPlan;
7use datafusion::prelude::*;
8use pgwire::api::auth::noop::NoopStartupHandler;
9use pgwire::api::copy::NoopCopyHandler;
10use pgwire::api::portal::{Format, Portal};
11use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
12use pgwire::api::results::{
13    DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
14    Response, Tag,
15};
16use pgwire::api::stmt::QueryParser;
17use pgwire::api::stmt::StoredStatement;
18use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
19use tokio::sync::Mutex;
20
21use crate::datatypes;
22use crate::information_schema::{columns_df, schemata_df, tables_df};
23use pgwire::error::{PgWireError, PgWireResult};
24
25pub struct HandlerFactory(pub Arc<DfSessionService>);
26
27impl NoopStartupHandler for DfSessionService {}
28
29impl PgWireServerHandlers for HandlerFactory {
30    type StartupHandler = DfSessionService;
31    type SimpleQueryHandler = DfSessionService;
32    type ExtendedQueryHandler = DfSessionService;
33    type CopyHandler = NoopCopyHandler;
34    type ErrorHandler = NoopErrorHandler;
35
36    fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
37        self.0.clone()
38    }
39
40    fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
41        self.0.clone()
42    }
43
44    fn startup_handler(&self) -> Arc<Self::StartupHandler> {
45        self.0.clone()
46    }
47
48    fn copy_handler(&self) -> Arc<Self::CopyHandler> {
49        Arc::new(NoopCopyHandler)
50    }
51
52    fn error_handler(&self) -> Arc<Self::ErrorHandler> {
53        Arc::new(NoopErrorHandler)
54    }
55}
56
57pub struct DfSessionService {
58    session_context: Arc<SessionContext>,
59    parser: Arc<Parser>,
60    timezone: Arc<Mutex<String>>,
61}
62
63impl DfSessionService {
64    pub fn new(session_context: Arc<SessionContext>) -> DfSessionService {
65        let parser = Arc::new(Parser {
66            session_context: session_context.clone(),
67        });
68        DfSessionService {
69            session_context,
70            parser,
71            timezone: Arc::new(Mutex::new("UTC".to_string())),
72        }
73    }
74
75    fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult<QueryResponse<'a>> {
76        let fields = vec![FieldInfo::new(
77            name.to_string(),
78            None,
79            None,
80            Type::VARCHAR,
81            FieldFormat::Text,
82        )];
83
84        let row = {
85            let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
86            encoder.encode_field(&Some(value))?;
87            encoder.finish()
88        };
89
90        let row_stream = futures::stream::once(async move { row });
91        Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
92    }
93
94    // Mock pg_namespace response
95    async fn mock_pg_namespace<'a>(&self) -> PgWireResult<QueryResponse<'a>> {
96        let fields = Arc::new(vec![FieldInfo::new(
97            "nspname".to_string(),
98            None,
99            None,
100            Type::VARCHAR,
101            FieldFormat::Text,
102        )]);
103
104        let fields_ref = fields.clone();
105        let rows = self
106            .session_context
107            .catalog_names()
108            .into_iter()
109            .map(move |name| {
110                let mut encoder = pgwire::api::results::DataRowEncoder::new(fields_ref.clone());
111                encoder.encode_field(&Some(&name))?; // Return catalog_name as a schema
112                encoder.finish()
113            });
114
115        let row_stream = futures::stream::iter(rows);
116        Ok(QueryResponse::new(fields.clone(), Box::pin(row_stream)))
117    }
118
119    async fn try_respond_set_time_zone<'a>(
120        &self,
121        query_lower: &str,
122    ) -> PgWireResult<Option<Response<'a>>> {
123        if query_lower.starts_with("set time zone") {
124            let parts: Vec<&str> = query_lower.split_whitespace().collect();
125            if parts.len() >= 4 {
126                let tz = parts[3].trim_matches('"');
127                let mut timezone = self.timezone.lock().await;
128                *timezone = tz.to_string();
129                Ok(Some(Response::Execution(Tag::new("SET"))))
130            } else {
131                Err(PgWireError::UserError(Box::new(
132                    pgwire::error::ErrorInfo::new(
133                        "ERROR".to_string(),
134                        "42601".to_string(),
135                        "Invalid SET TIME ZONE syntax".to_string(),
136                    ),
137                )))
138            }
139        } else {
140            Ok(None)
141        }
142    }
143
144    async fn try_respond_show_statements<'a>(
145        &self,
146        query_lower: &str,
147    ) -> PgWireResult<Option<Response<'a>>> {
148        if query_lower.starts_with("show ") {
149            match query_lower.strip_suffix(";").unwrap_or(query_lower) {
150                "show time zone" => {
151                    let timezone = self.timezone.lock().await.clone();
152                    let resp = Self::mock_show_response("TimeZone", &timezone)?;
153                    Ok(Some(Response::Query(resp)))
154                }
155                "show server_version" => {
156                    let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
157                    Ok(Some(Response::Query(resp)))
158                }
159                "show transaction_isolation" => {
160                    let resp =
161                        Self::mock_show_response("transaction_isolation", "read uncommitted")?;
162                    Ok(Some(Response::Query(resp)))
163                }
164                "show catalogs" => {
165                    let catalogs = self.session_context.catalog_names();
166                    let value = catalogs.join(", ");
167                    let resp = Self::mock_show_response("Catalogs", &value)?;
168                    Ok(Some(Response::Query(resp)))
169                }
170                "show search_path" => {
171                    let default_catalog = "datafusion";
172                    let resp = Self::mock_show_response("search_path", default_catalog)?;
173                    Ok(Some(Response::Query(resp)))
174                }
175                _ => Err(PgWireError::UserError(Box::new(
176                    pgwire::error::ErrorInfo::new(
177                        "ERROR".to_string(),
178                        "42704".to_string(),
179                        format!("Unrecognized SHOW command: {}", query_lower),
180                    ),
181                ))),
182            }
183        } else {
184            Ok(None)
185        }
186    }
187
188    async fn try_respond_information_schema<'a>(
189        &self,
190        query_lower: &str,
191    ) -> PgWireResult<Option<Response<'a>>> {
192        if query_lower.contains("information_schema.schemata") {
193            let df = schemata_df(&self.session_context)
194                .await
195                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
196            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
197            return Ok(Some(Response::Query(resp)));
198        } else if query_lower.contains("information_schema.tables") {
199            let df = tables_df(&self.session_context)
200                .await
201                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
202            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
203            return Ok(Some(Response::Query(resp)));
204        } else if query_lower.contains("information_schema.columns") {
205            let df = columns_df(&self.session_context)
206                .await
207                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
208            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
209            return Ok(Some(Response::Query(resp)));
210        }
211
212        // Handle pg_catalog.pg_namespace for pgcli compatibility
213        if query_lower.contains("pg_catalog.pg_namespace") {
214            let resp = self.mock_pg_namespace().await?;
215            return Ok(Some(Response::Query(resp)));
216        }
217
218        Ok(None)
219    }
220}
221
222#[async_trait]
223impl SimpleQueryHandler for DfSessionService {
224    async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
225    where
226        C: ClientInfo + Unpin + Send + Sync,
227    {
228        let query_lower = query.to_lowercase().trim().to_string();
229        log::debug!("Received query: {}", query); // Log the query for debugging
230
231        if let Some(resp) = self.try_respond_set_time_zone(&query_lower).await? {
232            return Ok(vec![resp]);
233        }
234
235        if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
236            return Ok(vec![resp]);
237        }
238
239        if let Some(resp) = self.try_respond_information_schema(&query_lower).await? {
240            return Ok(vec![resp]);
241        }
242
243        let df = self
244            .session_context
245            .sql(query)
246            .await
247            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
248
249        if query_lower.starts_with("insert into") {
250            // For INSERT queries, we need to execute the query to get the row count
251            // and return an Execution response with the proper tag
252            let result = df
253                .clone()
254                .collect()
255                .await
256                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
257
258            // Extract count field from the first batch
259            let rows_affected = result
260                .first()
261                .and_then(|batch| batch.column_by_name("count"))
262                .and_then(|col| {
263                    col.as_any()
264                        .downcast_ref::<datafusion::arrow::array::UInt64Array>()
265                })
266                .map_or(0, |array| array.value(0) as usize);
267
268            // Create INSERT tag with the affected row count
269            let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
270            Ok(vec![Response::Execution(tag)])
271        } else {
272            // For non-INSERT queries, return a regular Query response
273            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
274            Ok(vec![Response::Query(resp)])
275        }
276    }
277}
278
279#[async_trait]
280impl ExtendedQueryHandler for DfSessionService {
281    type Statement = LogicalPlan;
282    type QueryParser = Parser;
283
284    fn query_parser(&self) -> Arc<Self::QueryParser> {
285        self.parser.clone()
286    }
287
288    async fn do_describe_statement<C>(
289        &self,
290        _client: &mut C,
291        target: &StoredStatement<Self::Statement>,
292    ) -> PgWireResult<DescribeStatementResponse>
293    where
294        C: ClientInfo + Unpin + Send + Sync,
295    {
296        let plan = &target.statement;
297        let schema = plan.schema();
298        let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?;
299        let params = plan
300            .get_parameter_types()
301            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
302
303        let mut param_types = Vec::with_capacity(params.len());
304        for param_type in ordered_param_types(&params).iter() {
305            // Fixed: Use &params
306            if let Some(datatype) = param_type {
307                let pgtype = datatypes::into_pg_type(datatype)?;
308                param_types.push(pgtype);
309            } else {
310                param_types.push(Type::UNKNOWN);
311            }
312        }
313
314        Ok(DescribeStatementResponse::new(param_types, fields))
315    }
316
317    async fn do_describe_portal<C>(
318        &self,
319        _client: &mut C,
320        target: &Portal<Self::Statement>,
321    ) -> PgWireResult<DescribePortalResponse>
322    where
323        C: ClientInfo + Unpin + Send + Sync,
324    {
325        let plan = &target.statement.statement;
326        let format = &target.result_column_format;
327        let schema = plan.schema();
328        let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?;
329
330        Ok(DescribePortalResponse::new(fields))
331    }
332
333    async fn do_query<'a, C>(
334        &self,
335        _client: &mut C,
336        portal: &Portal<Self::Statement>,
337        _max_rows: usize,
338    ) -> PgWireResult<Response<'a>>
339    where
340        C: ClientInfo + Unpin + Send + Sync,
341    {
342        let query = portal
343            .statement
344            .statement
345            .to_string()
346            .to_lowercase()
347            .trim()
348            .to_string();
349        log::debug!("Received extended query: {}", query); // Log for debugging
350
351        if let Some(resp) = self.try_respond_show_statements(&query).await? {
352            return Ok(resp);
353        }
354
355        if let Some(resp) = self.try_respond_information_schema(&query).await? {
356            return Ok(resp);
357        }
358
359        let plan = &portal.statement.statement;
360        let param_types = plan
361            .get_parameter_types()
362            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
363        let param_values =
364            datatypes::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
365        let plan = plan
366            .clone()
367            .replace_params_with_values(&param_values)
368            .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use &param_values
369        let dataframe = self
370            .session_context
371            .execute_logical_plan(plan)
372            .await
373            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
374        let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
375        Ok(Response::Query(resp))
376    }
377}
378
379pub struct Parser {
380    session_context: Arc<SessionContext>,
381}
382
383#[async_trait]
384impl QueryParser for Parser {
385    type Statement = LogicalPlan;
386
387    async fn parse_sql<C>(
388        &self,
389        _client: &C,
390        sql: &str,
391        _types: &[Type],
392    ) -> PgWireResult<Self::Statement> {
393        let context = &self.session_context;
394        let state = context.state();
395        let logical_plan = state
396            .create_logical_plan(sql)
397            .await
398            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
399        let optimised = state
400            .optimize(&logical_plan)
401            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
402        Ok(optimised)
403    }
404}
405
406fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
407    // Datafusion stores the parameters as a map.  In our case, the keys will be
408    // `$1`, `$2` etc.  The values will be the parameter types.
409    let mut types = types.iter().collect::<Vec<_>>();
410    types.sort_by(|a, b| a.0.cmp(b.0));
411    types.into_iter().map(|pt| pt.1.as_ref()).collect()
412}