1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::DataType;
6use datafusion::logical_expr::LogicalPlan;
7use datafusion::prelude::*;
8use pgwire::api::auth::noop::NoopStartupHandler;
9use pgwire::api::copy::NoopCopyHandler;
10use pgwire::api::portal::{Format, Portal};
11use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
12use pgwire::api::results::{
13 DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
14 Response, Tag,
15};
16use pgwire::api::stmt::QueryParser;
17use pgwire::api::stmt::StoredStatement;
18use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
19use pgwire::error::{PgWireError, PgWireResult};
20use tokio::sync::Mutex;
21
22use crate::datatypes;
23use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
24
25pub struct HandlerFactory(pub Arc<DfSessionService>);
26
27impl NoopStartupHandler for DfSessionService {}
28
29impl PgWireServerHandlers for HandlerFactory {
30 type StartupHandler = DfSessionService;
31 type SimpleQueryHandler = DfSessionService;
32 type ExtendedQueryHandler = DfSessionService;
33 type CopyHandler = NoopCopyHandler;
34 type ErrorHandler = NoopErrorHandler;
35
36 fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
37 self.0.clone()
38 }
39
40 fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
41 self.0.clone()
42 }
43
44 fn startup_handler(&self) -> Arc<Self::StartupHandler> {
45 self.0.clone()
46 }
47
48 fn copy_handler(&self) -> Arc<Self::CopyHandler> {
49 Arc::new(NoopCopyHandler)
50 }
51
52 fn error_handler(&self) -> Arc<Self::ErrorHandler> {
53 Arc::new(NoopErrorHandler)
54 }
55}
56
57pub struct DfSessionService {
58 session_context: Arc<SessionContext>,
59 parser: Arc<Parser>,
60 timezone: Arc<Mutex<String>>,
61}
62
63impl DfSessionService {
64 pub fn new(session_context: Arc<SessionContext>) -> DfSessionService {
65 let parser = Arc::new(Parser {
66 session_context: session_context.clone(),
67 });
68 DfSessionService {
69 session_context,
70 parser,
71 timezone: Arc::new(Mutex::new("UTC".to_string())),
72 }
73 }
74
75 fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult<QueryResponse<'a>> {
76 let fields = vec![FieldInfo::new(
77 name.to_string(),
78 None,
79 None,
80 Type::VARCHAR,
81 FieldFormat::Text,
82 )];
83
84 let row = {
85 let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
86 encoder.encode_field(&Some(value))?;
87 encoder.finish()
88 };
89
90 let row_stream = futures::stream::once(async move { row });
91 Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
92 }
93
94 async fn try_respond_set_statements<'a>(
95 &self,
96 query_lower: &str,
97 ) -> PgWireResult<Option<Response<'a>>> {
98 if query_lower.starts_with("set") {
99 if query_lower.starts_with("set time zone") {
100 let parts: Vec<&str> = query_lower.split_whitespace().collect();
101 if parts.len() >= 4 {
102 let tz = parts[3].trim_matches('"');
103 let mut timezone = self.timezone.lock().await;
104 *timezone = tz.to_string();
105 Ok(Some(Response::Execution(Tag::new("SET"))))
106 } else {
107 Err(PgWireError::UserError(Box::new(
108 pgwire::error::ErrorInfo::new(
109 "ERROR".to_string(),
110 "42601".to_string(),
111 "Invalid SET TIME ZONE syntax".to_string(),
112 ),
113 )))
114 }
115 } else {
116 Ok(Some(Response::Execution(Tag::new("SET"))))
118 }
119 } else {
120 Ok(None)
121 }
122 }
123
124 async fn try_respond_show_statements<'a>(
125 &self,
126 query_lower: &str,
127 ) -> PgWireResult<Option<Response<'a>>> {
128 if query_lower.starts_with("show ") {
129 match query_lower.strip_suffix(";").unwrap_or(query_lower) {
130 "show time zone" => {
131 let timezone = self.timezone.lock().await.clone();
132 let resp = Self::mock_show_response("TimeZone", &timezone)?;
133 Ok(Some(Response::Query(resp)))
134 }
135 "show server_version" => {
136 let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
137 Ok(Some(Response::Query(resp)))
138 }
139 "show transaction_isolation" => {
140 let resp =
141 Self::mock_show_response("transaction_isolation", "read uncommitted")?;
142 Ok(Some(Response::Query(resp)))
143 }
144 "show catalogs" => {
145 let catalogs = self.session_context.catalog_names();
146 let value = catalogs.join(", ");
147 let resp = Self::mock_show_response("Catalogs", &value)?;
148 Ok(Some(Response::Query(resp)))
149 }
150 "show search_path" => {
151 let default_catalog = "datafusion";
152 let resp = Self::mock_show_response("search_path", default_catalog)?;
153 Ok(Some(Response::Query(resp)))
154 }
155 _ => Err(PgWireError::UserError(Box::new(
156 pgwire::error::ErrorInfo::new(
157 "ERROR".to_string(),
158 "42704".to_string(),
159 format!("Unrecognized SHOW command: {}", query_lower),
160 ),
161 ))),
162 }
163 } else {
164 Ok(None)
165 }
166 }
167}
168
169#[async_trait]
170impl SimpleQueryHandler for DfSessionService {
171 async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
172 where
173 C: ClientInfo + Unpin + Send + Sync,
174 {
175 let query_lower = query.to_lowercase().trim().to_string();
176 log::debug!("Received query: {}", query); if let Some(resp) = self.try_respond_set_statements(&query_lower).await? {
179 return Ok(vec![resp]);
180 }
181
182 if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
183 return Ok(vec![resp]);
184 }
185
186 let df = self
187 .session_context
188 .sql(query)
189 .await
190 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
191
192 if query_lower.starts_with("insert into") {
193 let result = df
196 .clone()
197 .collect()
198 .await
199 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
200
201 let rows_affected = result
203 .first()
204 .and_then(|batch| batch.column_by_name("count"))
205 .and_then(|col| {
206 col.as_any()
207 .downcast_ref::<datafusion::arrow::array::UInt64Array>()
208 })
209 .map_or(0, |array| array.value(0) as usize);
210
211 let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
213 Ok(vec![Response::Execution(tag)])
214 } else {
215 let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
217 Ok(vec![Response::Query(resp)])
218 }
219 }
220}
221
222#[async_trait]
223impl ExtendedQueryHandler for DfSessionService {
224 type Statement = (String, LogicalPlan);
225 type QueryParser = Parser;
226
227 fn query_parser(&self) -> Arc<Self::QueryParser> {
228 self.parser.clone()
229 }
230
231 async fn do_describe_statement<C>(
232 &self,
233 _client: &mut C,
234 target: &StoredStatement<Self::Statement>,
235 ) -> PgWireResult<DescribeStatementResponse>
236 where
237 C: ClientInfo + Unpin + Send + Sync,
238 {
239 let (_, plan) = &target.statement;
240 let schema = plan.schema();
241 let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
242 let params = plan
243 .get_parameter_types()
244 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
245
246 let mut param_types = Vec::with_capacity(params.len());
247 for param_type in ordered_param_types(¶ms).iter() {
248 if let Some(datatype) = param_type {
250 let pgtype = into_pg_type(datatype)?;
251 param_types.push(pgtype);
252 } else {
253 param_types.push(Type::UNKNOWN);
254 }
255 }
256
257 Ok(DescribeStatementResponse::new(param_types, fields))
258 }
259
260 async fn do_describe_portal<C>(
261 &self,
262 _client: &mut C,
263 target: &Portal<Self::Statement>,
264 ) -> PgWireResult<DescribePortalResponse>
265 where
266 C: ClientInfo + Unpin + Send + Sync,
267 {
268 let (_, plan) = &target.statement.statement;
269 let format = &target.result_column_format;
270 let schema = plan.schema();
271 let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
272
273 Ok(DescribePortalResponse::new(fields))
274 }
275
276 async fn do_query<'a, C>(
277 &self,
278 _client: &mut C,
279 portal: &Portal<Self::Statement>,
280 _max_rows: usize,
281 ) -> PgWireResult<Response<'a>>
282 where
283 C: ClientInfo + Unpin + Send + Sync,
284 {
285 let query = portal
286 .statement
287 .statement
288 .0
289 .to_lowercase()
290 .trim()
291 .to_string();
292 log::debug!("Received execute extended query: {}", query); if let Some(resp) = self.try_respond_set_statements(&query).await? {
295 return Ok(resp);
296 }
297
298 if let Some(resp) = self.try_respond_show_statements(&query).await? {
299 return Ok(resp);
300 }
301
302 let (_, plan) = &portal.statement.statement;
303
304 let param_types = plan
305 .get_parameter_types()
306 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
307 let param_values =
308 datatypes::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; let plan = plan
310 .clone()
311 .replace_params_with_values(¶m_values)
312 .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let dataframe = self
314 .session_context
315 .execute_logical_plan(plan)
316 .await
317 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
318 let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
319 Ok(Response::Query(resp))
320 }
321}
322
323pub struct Parser {
324 session_context: Arc<SessionContext>,
325}
326
327#[async_trait]
328impl QueryParser for Parser {
329 type Statement = (String, LogicalPlan);
330
331 async fn parse_sql<C>(
332 &self,
333 _client: &C,
334 sql: &str,
335 _types: &[Type],
336 ) -> PgWireResult<Self::Statement> {
337 log::debug!("Received parse extended query: {}", sql); let context = &self.session_context;
339 let state = context.state();
340 let logical_plan = state
341 .create_logical_plan(sql)
342 .await
343 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
344 let optimised = state
345 .optimize(&logical_plan)
346 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
347 Ok((sql.to_string(), optimised))
348 }
349}
350
351fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
352 let mut types = types.iter().collect::<Vec<_>>();
355 types.sort_by(|a, b| a.0.cmp(b.0));
356 types.into_iter().map(|pt| pt.1.as_ref()).collect()
357}