datafusion_postgres/
handlers.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::DataType;
5use datafusion::logical_expr::LogicalPlan;
6use datafusion::prelude::*;
7use pgwire::api::auth::noop::NoopStartupHandler;
8use pgwire::api::copy::NoopCopyHandler;
9use pgwire::api::portal::{Format, Portal};
10use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
11use pgwire::api::results::{DescribePortalResponse, DescribeStatementResponse, Response};
12use pgwire::api::stmt::QueryParser;
13use pgwire::api::stmt::StoredStatement;
14use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type};
15use pgwire::error::{PgWireError, PgWireResult};
16
17use crate::datatypes::{self, into_pg_type};
18
19pub struct HandlerFactory(pub Arc<DfSessionService>);
20
21impl PgWireHandlerFactory for HandlerFactory {
22    type StartupHandler = NoopStartupHandler;
23    type SimpleQueryHandler = DfSessionService;
24    type ExtendedQueryHandler = DfSessionService;
25    type CopyHandler = NoopCopyHandler;
26
27    fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
28        self.0.clone()
29    }
30
31    fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
32        self.0.clone()
33    }
34
35    fn startup_handler(&self) -> Arc<Self::StartupHandler> {
36        Arc::new(NoopStartupHandler)
37    }
38
39    fn copy_handler(&self) -> Arc<Self::CopyHandler> {
40        Arc::new(NoopCopyHandler)
41    }
42}
43
44pub struct DfSessionService {
45    session_context: Arc<SessionContext>,
46    parser: Arc<Parser>,
47}
48
49impl DfSessionService {
50    pub fn new(session_context: SessionContext) -> DfSessionService {
51        let session_context = Arc::new(session_context);
52        let parser = Arc::new(Parser {
53            session_context: session_context.clone(),
54        });
55        DfSessionService {
56            session_context,
57            parser,
58        }
59    }
60}
61
62#[async_trait]
63impl SimpleQueryHandler for DfSessionService {
64    async fn do_query<'a, C>(
65        &self,
66        _client: &mut C,
67        query: &'a str,
68    ) -> PgWireResult<Vec<Response<'a>>>
69    where
70        C: ClientInfo + Unpin + Send + Sync,
71    {
72        let ctx = &self.session_context;
73        let df = ctx
74            .sql(query)
75            .await
76            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
77
78        let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
79        Ok(vec![Response::Query(resp)])
80    }
81}
82
83pub struct Parser {
84    session_context: Arc<SessionContext>,
85}
86
87#[async_trait]
88impl QueryParser for Parser {
89    type Statement = LogicalPlan;
90
91    async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
92        let context = &self.session_context;
93        let state = context.state();
94
95        let logical_plan = state
96            .create_logical_plan(sql)
97            .await
98            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
99        let optimised = state
100            .optimize(&logical_plan)
101            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
102
103        Ok(optimised)
104    }
105}
106
107#[async_trait]
108impl ExtendedQueryHandler for DfSessionService {
109    type Statement = LogicalPlan;
110
111    type QueryParser = Parser;
112
113    fn query_parser(&self) -> Arc<Self::QueryParser> {
114        self.parser.clone()
115    }
116
117    async fn do_describe_statement<C>(
118        &self,
119        _client: &mut C,
120        target: &StoredStatement<Self::Statement>,
121    ) -> PgWireResult<DescribeStatementResponse>
122    where
123        C: ClientInfo + Unpin + Send + Sync,
124    {
125        let plan = &target.statement;
126
127        let schema = plan.schema();
128        let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?;
129        let params = plan
130            .get_parameter_types()
131            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
132
133        let mut param_types = Vec::with_capacity(params.len());
134        for param_type in params.into_values() {
135            if let Some(datatype) = param_type {
136                let pgtype = into_pg_type(&datatype)?;
137                param_types.push(pgtype);
138            } else {
139                param_types.push(Type::UNKNOWN);
140            }
141        }
142
143        Ok(DescribeStatementResponse::new(param_types, fields))
144    }
145
146    async fn do_describe_portal<C>(
147        &self,
148        _client: &mut C,
149        target: &Portal<Self::Statement>,
150    ) -> PgWireResult<DescribePortalResponse>
151    where
152        C: ClientInfo + Unpin + Send + Sync,
153    {
154        let plan = &target.statement.statement;
155        let format = &target.result_column_format;
156        let schema = plan.schema();
157        let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?;
158
159        Ok(DescribePortalResponse::new(fields))
160    }
161
162    async fn do_query<'a, C>(
163        &self,
164        _client: &mut C,
165        portal: &'a Portal<Self::Statement>,
166        _max_rows: usize,
167    ) -> PgWireResult<Response<'a>>
168    where
169        C: ClientInfo + Unpin + Send + Sync,
170    {
171        let plan = &portal.statement.statement;
172
173        let param_values = datatypes::deserialize_parameters(
174            portal,
175            &plan
176                .get_parameter_types()
177                .map_err(|e| PgWireError::ApiError(Box::new(e)))?
178                .values()
179                .map(|v| v.as_ref())
180                .collect::<Vec<Option<&DataType>>>(),
181        )?;
182
183        let plan = plan
184            .clone()
185            .replace_params_with_values(&param_values)
186            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
187
188        let dataframe = self
189            .session_context
190            .execute_logical_plan(plan)
191            .await
192            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
193
194        let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
195        Ok(Response::Query(resp))
196    }
197}