datafusion_postgres/
handlers.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::DataType;
5use datafusion::logical_expr::LogicalPlan;
6use datafusion::prelude::*;
7use pgwire::api::auth::noop::NoopStartupHandler;
8use pgwire::api::copy::NoopCopyHandler;
9use pgwire::api::portal::{Format, Portal};
10use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
11use pgwire::api::results::{DescribePortalResponse, DescribeStatementResponse, Response};
12use pgwire::api::stmt::QueryParser;
13use pgwire::api::stmt::StoredStatement;
14use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type};
15use pgwire::error::{PgWireError, PgWireResult};
16
17use crate::datatypes::{self, into_pg_type};
18
19pub struct HandlerFactory(pub Arc<DfSessionService>);
20
21impl PgWireHandlerFactory for HandlerFactory {
22 type StartupHandler = NoopStartupHandler;
23 type SimpleQueryHandler = DfSessionService;
24 type ExtendedQueryHandler = DfSessionService;
25 type CopyHandler = NoopCopyHandler;
26
27 fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
28 self.0.clone()
29 }
30
31 fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
32 self.0.clone()
33 }
34
35 fn startup_handler(&self) -> Arc<Self::StartupHandler> {
36 Arc::new(NoopStartupHandler)
37 }
38
39 fn copy_handler(&self) -> Arc<Self::CopyHandler> {
40 Arc::new(NoopCopyHandler)
41 }
42}
43
44pub struct DfSessionService {
45 session_context: Arc<SessionContext>,
46 parser: Arc<Parser>,
47}
48
49impl DfSessionService {
50 pub fn new(session_context: SessionContext) -> DfSessionService {
51 let session_context = Arc::new(session_context);
52 let parser = Arc::new(Parser {
53 session_context: session_context.clone(),
54 });
55 DfSessionService {
56 session_context,
57 parser,
58 }
59 }
60}
61
62#[async_trait]
63impl SimpleQueryHandler for DfSessionService {
64 async fn do_query<'a, C>(
65 &self,
66 _client: &mut C,
67 query: &'a str,
68 ) -> PgWireResult<Vec<Response<'a>>>
69 where
70 C: ClientInfo + Unpin + Send + Sync,
71 {
72 let ctx = &self.session_context;
73 let df = ctx
74 .sql(query)
75 .await
76 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
77
78 let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
79 Ok(vec![Response::Query(resp)])
80 }
81}
82
83pub struct Parser {
84 session_context: Arc<SessionContext>,
85}
86
87#[async_trait]
88impl QueryParser for Parser {
89 type Statement = LogicalPlan;
90
91 async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
92 let context = &self.session_context;
93 let state = context.state();
94
95 let logical_plan = state
96 .create_logical_plan(sql)
97 .await
98 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
99 let optimised = state
100 .optimize(&logical_plan)
101 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
102
103 Ok(optimised)
104 }
105}
106
107#[async_trait]
108impl ExtendedQueryHandler for DfSessionService {
109 type Statement = LogicalPlan;
110
111 type QueryParser = Parser;
112
113 fn query_parser(&self) -> Arc<Self::QueryParser> {
114 self.parser.clone()
115 }
116
117 async fn do_describe_statement<C>(
118 &self,
119 _client: &mut C,
120 target: &StoredStatement<Self::Statement>,
121 ) -> PgWireResult<DescribeStatementResponse>
122 where
123 C: ClientInfo + Unpin + Send + Sync,
124 {
125 let plan = &target.statement;
126
127 let schema = plan.schema();
128 let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?;
129 let params = plan
130 .get_parameter_types()
131 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
132
133 let mut param_types = Vec::with_capacity(params.len());
134 for param_type in params.into_values() {
135 if let Some(datatype) = param_type {
136 let pgtype = into_pg_type(&datatype)?;
137 param_types.push(pgtype);
138 } else {
139 param_types.push(Type::UNKNOWN);
140 }
141 }
142
143 Ok(DescribeStatementResponse::new(param_types, fields))
144 }
145
146 async fn do_describe_portal<C>(
147 &self,
148 _client: &mut C,
149 target: &Portal<Self::Statement>,
150 ) -> PgWireResult<DescribePortalResponse>
151 where
152 C: ClientInfo + Unpin + Send + Sync,
153 {
154 let plan = &target.statement.statement;
155 let format = &target.result_column_format;
156 let schema = plan.schema();
157 let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?;
158
159 Ok(DescribePortalResponse::new(fields))
160 }
161
162 async fn do_query<'a, C>(
163 &self,
164 _client: &mut C,
165 portal: &'a Portal<Self::Statement>,
166 _max_rows: usize,
167 ) -> PgWireResult<Response<'a>>
168 where
169 C: ClientInfo + Unpin + Send + Sync,
170 {
171 let plan = &portal.statement.statement;
172
173 let param_values = datatypes::deserialize_parameters(
174 portal,
175 &plan
176 .get_parameter_types()
177 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
178 .values()
179 .map(|v| v.as_ref())
180 .collect::<Vec<Option<&DataType>>>(),
181 )?;
182
183 let plan = plan
184 .clone()
185 .replace_params_with_values(¶m_values)
186 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
187
188 let dataframe = self
189 .session_context
190 .execute_logical_plan(plan)
191 .await
192 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
193
194 let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
195 Ok(Response::Query(resp))
196 }
197}