use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::*;
use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::copy::NoopCopyHandler;
use pgwire::api::portal::{Format, Portal};
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{
    DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
    Response, Tag,
};
use pgwire::api::stmt::{QueryParser, StoredStatement};
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
use pgwire::error::{PgWireError, PgWireResult};
use tokio::sync::Mutex;

use crate::datatypes;
use crate::information_schema::{columns_df, schemata_df, tables_df};

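/// Wraps a shared [`DfSessionService`] and hands it out as pgwire's startup,
/// simple-query, and extended-query handler, while using the no-op copy and
/// error handlers.
///
/// A minimal usage sketch (assuming a plain TCP accept loop and pgwire's
/// `process_socket` entry point; adjust to the pgwire version in use):
///
/// ```ignore
/// let service = Arc::new(DfSessionService::new(SessionContext::new(), None));
/// let listener = tokio::net::TcpListener::bind("127.0.0.1:5432").await?;
/// loop {
///     let (socket, _) = listener.accept().await?;
///     let factory = HandlerFactory(service.clone());
///     tokio::spawn(async move { pgwire::tokio::process_socket(socket, None, factory).await });
/// }
/// ```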
pub struct HandlerFactory(pub Arc<DfSessionService>);

impl NoopStartupHandler for DfSessionService {}

impl PgWireServerHandlers for HandlerFactory {
    type StartupHandler = DfSessionService;
    type SimpleQueryHandler = DfSessionService;
    type ExtendedQueryHandler = DfSessionService;
    type CopyHandler = NoopCopyHandler;
    type ErrorHandler = NoopErrorHandler;

    fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
        self.0.clone()
    }

    fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
        self.0.clone()
    }

    fn startup_handler(&self) -> Arc<Self::StartupHandler> {
        self.0.clone()
    }

    fn copy_handler(&self) -> Arc<Self::CopyHandler> {
        Arc::new(NoopCopyHandler)
    }

    fn error_handler(&self) -> Arc<Self::ErrorHandler> {
        Arc::new(NoopErrorHandler)
    }
}

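/// Query handler backed by a DataFusion [`SessionContext`]. Implements both
/// the simple and extended PostgreSQL wire protocols, with special-cased
/// handling for `SET TIME ZONE`, `SHOW ...`, `information_schema`, and
/// `pg_catalog.pg_namespace` queries that DataFusion does not serve natively.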
pub struct DfSessionService {
    session_context: Arc<SessionContext>,
    parser: Arc<Parser>,
    timezone: Arc<Mutex<String>>,
    catalog_name: String,
}

impl DfSessionService {
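    /// Builds a service around an existing `SessionContext`. If no catalog
    /// name is given, the context's first registered catalog is used,
    /// falling back to "datafusion".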
    pub fn new(session_context: SessionContext, catalog_name: Option<String>) -> DfSessionService {
        let session_context = Arc::new(session_context);
        let parser = Arc::new(Parser {
            session_context: session_context.clone(),
        });
        let catalog_name = catalog_name.unwrap_or_else(|| {
            session_context
                .catalog_names()
                .first()
                .cloned()
                .unwrap_or_else(|| "datafusion".to_string())
        });
        DfSessionService {
            session_context,
            parser,
            timezone: Arc::new(Mutex::new("UTC".to_string())),
            catalog_name,
        }
    }

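    /// Builds a single-row, single-column text response used to answer
    /// `SHOW ...` statements that have no DataFusion equivalent.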
    fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult<QueryResponse<'a>> {
        let fields = vec![FieldInfo::new(
            name.to_string(),
            None,
            None,
            Type::VARCHAR,
            FieldFormat::Text,
        )];

        let row = {
            let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
            encoder.encode_field(&Some(value))?;
            encoder.finish()
        };

        let row_stream = futures::stream::once(async move { row });
        Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
    }

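    /// Answers `pg_catalog.pg_namespace` lookups with a single `nspname` row
    /// containing the configured catalog name, which is enough for common
    /// client introspection queries.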
    async fn mock_pg_namespace<'a>(&self) -> PgWireResult<QueryResponse<'a>> {
        let fields = vec![FieldInfo::new(
            "nspname".to_string(),
            None,
            None,
            Type::VARCHAR,
            FieldFormat::Text,
        )];

        let row = {
            let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
            encoder.encode_field(&Some(&self.catalog_name))?;
            encoder.finish()
        };

        let row_stream = futures::stream::once(async move { row });
        Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
    }

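    /// Handles `SET TIME ZONE <tz>` by storing the timezone on the session.
    /// Returns `Ok(None)` if the query is not a `SET TIME ZONE` statement.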
    async fn try_respond_set_time_zone<'a>(
        &self,
        query_lower: &str,
    ) -> PgWireResult<Option<Vec<Response<'a>>>> {
        if query_lower.starts_with("set time zone") {
            let parts: Vec<&str> = query_lower.split_whitespace().collect();
            if parts.len() >= 4 {
                let tz = parts[3].trim_matches('"');
                let mut timezone = self.timezone.lock().await;
                *timezone = tz.to_string();
                Ok(Some(vec![Response::Execution(Tag::new("SET"))]))
            } else {
                Err(PgWireError::UserError(Box::new(
                    pgwire::error::ErrorInfo::new(
                        "ERROR".to_string(),
                        "42601".to_string(),
                        "Invalid SET TIME ZONE syntax".to_string(),
                    ),
                )))
            }
        } else {
            Ok(None)
        }
    }

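    /// Serves the handful of `SHOW ...` variables that clients commonly probe
    /// (TimeZone, server_version, transaction_isolation, catalogs,
    /// search_path). Returns `Ok(None)` if the query is not a SHOW statement.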
    async fn try_respond_show_statements<'a>(
        &self,
        query_lower: &str,
    ) -> PgWireResult<Option<Vec<Response<'a>>>> {
        if query_lower.starts_with("show ") {
            match query_lower {
                "show time zone" => {
                    let timezone = self.timezone.lock().await.clone();
                    let resp = Self::mock_show_response("TimeZone", &timezone)?;
                    Ok(Some(vec![Response::Query(resp)]))
                }
                "show server_version" => {
                    let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
                    Ok(Some(vec![Response::Query(resp)]))
                }
                "show transaction_isolation" => {
                    let resp =
                        Self::mock_show_response("transaction_isolation", "read uncommitted")?;
                    Ok(Some(vec![Response::Query(resp)]))
                }
                "show catalogs" => {
                    let catalogs = self.session_context.catalog_names();
                    let value = catalogs.join(", ");
                    let resp = Self::mock_show_response("Catalogs", &value)?;
                    Ok(Some(vec![Response::Query(resp)]))
                }
                "show search_path" => {
                    let resp = Self::mock_show_response("search_path", &self.catalog_name)?;
                    Ok(Some(vec![Response::Query(resp)]))
                }
                _ => Err(PgWireError::UserError(Box::new(
                    pgwire::error::ErrorInfo::new(
                        "ERROR".to_string(),
                        "42704".to_string(),
                        format!("Unrecognized SHOW command: {}", query_lower),
                    ),
                ))),
            }
        } else {
            Ok(None)
        }
    }

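    /// Routes `information_schema.{schemata,tables,columns}` and
    /// `pg_catalog.pg_namespace` queries to purpose-built data frames instead
    /// of DataFusion's SQL path. Returns `Ok(None)` for anything else.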
    async fn try_respond_information_schema<'a>(
        &self,
        query_lower: &str,
    ) -> PgWireResult<Option<Vec<Response<'a>>>> {
        if query_lower.contains("information_schema.schemata") {
            let df = schemata_df(&self.session_context)
                .await
                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
            return Ok(Some(vec![Response::Query(resp)]));
        } else if query_lower.contains("information_schema.tables") {
            let df = tables_df(&self.session_context)
                .await
                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
            return Ok(Some(vec![Response::Query(resp)]));
        } else if query_lower.contains("information_schema.columns") {
            let df = columns_df(&self.session_context)
                .await
                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
            return Ok(Some(vec![Response::Query(resp)]));
        }

        if query_lower.contains("pg_catalog.pg_namespace") {
            let resp = self.mock_pg_namespace().await?;
            return Ok(Some(vec![Response::Query(resp)]));
        }

        Ok(None)
    }
}

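// Simple-protocol path: intercept session statements, then fall back to
// DataFusion SQL execution for everything else.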
#[async_trait]
impl SimpleQueryHandler for DfSessionService {
    async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let query_lower = query.to_lowercase().trim().to_string();
        log::debug!("Received query: {}", query);

        if let Some(resp) = self.try_respond_set_time_zone(&query_lower).await? {
            return Ok(resp);
        }

        if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
            return Ok(resp);
        }

        if let Some(resp) = self.try_respond_information_schema(&query_lower).await? {
            return Ok(resp);
        }

        let df = self
            .session_context
            .sql(query)
            .await
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;

        if query_lower.starts_with("insert into") {
            let result = df
                .clone()
                .collect()
                .await
                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;

            let rows_affected = result
                .first()
                .and_then(|batch| batch.column_by_name("count"))
                .and_then(|col| {
                    col.as_any()
                        .downcast_ref::<datafusion::arrow::array::UInt64Array>()
                })
                .map_or(0, |array| array.value(0) as usize);

            let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
            Ok(vec![Response::Execution(tag)])
        } else {
            let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
            Ok(vec![Response::Query(resp)])
        }
    }
}

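// Extended-protocol path: statements are parsed to DataFusion logical plans,
// described from the plan schema, and executed with bound parameters.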
#[async_trait]
impl ExtendedQueryHandler for DfSessionService {
    type Statement = LogicalPlan;
    type QueryParser = Parser;

    fn query_parser(&self) -> Arc<Self::QueryParser> {
        self.parser.clone()
    }

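    // Describe a prepared statement: derive result fields from the plan's
    // schema and parameter types from the plan's placeholders.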
    async fn do_describe_statement<C>(
        &self,
        _client: &mut C,
        target: &StoredStatement<Self::Statement>,
    ) -> PgWireResult<DescribeStatementResponse>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let plan = &target.statement;
        let schema = plan.schema();
        let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?;
        let params = plan
            .get_parameter_types()
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;

        let mut param_types = Vec::with_capacity(params.len());
        for param_type in ordered_param_types(&params).iter() {
            if let Some(datatype) = param_type {
                let pgtype = datatypes::into_pg_type(datatype)?;
                param_types.push(pgtype);
            } else {
                param_types.push(Type::UNKNOWN);
            }
        }

        Ok(DescribeStatementResponse::new(param_types, fields))
    }

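    // Describe a bound portal: result fields follow the portal's requested
    // column formats.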
    async fn do_describe_portal<C>(
        &self,
        _client: &mut C,
        target: &Portal<Self::Statement>,
    ) -> PgWireResult<DescribePortalResponse>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let plan = &target.statement.statement;
        let format = &target.result_column_format;
        let schema = plan.schema();
        let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?;

        Ok(DescribePortalResponse::new(fields))
    }

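    // Execute a bound portal: session statements are intercepted the same way
    // as in the simple protocol, then parameters are bound into the logical
    // plan and it is executed through the DataFusion session.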
    async fn do_query<'a, C>(
        &self,
        _client: &mut C,
        portal: &Portal<Self::Statement>,
        _max_rows: usize,
    ) -> PgWireResult<Response<'a>>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let query = portal
            .statement
            .statement
            .to_string()
            .to_lowercase()
            .trim()
            .to_string();
        log::debug!("Received extended query: {}", query);

        if query.starts_with("show ") {
            match query.as_str() {
                "show time zone" => {
                    let timezone = self.timezone.lock().await.clone();
                    let resp = Self::mock_show_response("TimeZone", &timezone)?;
                    return Ok(Response::Query(resp));
                }
                "show server_version" => {
                    let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
                    return Ok(Response::Query(resp));
                }
                "show transaction_isolation" => {
                    let resp =
                        Self::mock_show_response("transaction_isolation", "read uncommitted")?;
                    return Ok(Response::Query(resp));
                }
                "show catalogs" => {
                    let catalogs = self.session_context.catalog_names();
                    let value = catalogs.join(", ");
                    let resp = Self::mock_show_response("Catalogs", &value)?;
                    return Ok(Response::Query(resp));
                }
                "show search_path" => {
                    let resp = Self::mock_show_response("search_path", &self.catalog_name)?;
                    return Ok(Response::Query(resp));
                }
                _ => {
                    return Err(PgWireError::UserError(Box::new(
                        pgwire::error::ErrorInfo::new(
                            "ERROR".to_string(),
                            "42704".to_string(),
                            format!("Unrecognized SHOW command: {}", query),
                        ),
                    )));
                }
            }
        }

        if query.contains("information_schema.schemata") {
            let df = schemata_df(&self.session_context)
                .await
                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
            let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?;
            return Ok(Response::Query(resp));
        } else if query.contains("information_schema.tables") {
            let df = tables_df(&self.session_context)
                .await
                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
            let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?;
            return Ok(Response::Query(resp));
        } else if query.contains("information_schema.columns") {
            let df = columns_df(&self.session_context)
                .await
                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
            let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?;
            return Ok(Response::Query(resp));
        }

        if query.contains("pg_catalog.pg_namespace") {
            let resp = self.mock_pg_namespace().await?;
            return Ok(Response::Query(resp));
        }

        let plan = &portal.statement.statement;
        let param_types = plan
            .get_parameter_types()
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
        let param_values =
            datatypes::deserialize_parameters(portal, &ordered_param_types(&param_types))?;
        let plan = plan
            .clone()
            .replace_params_with_values(&param_values)
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
        let dataframe = self
            .session_context
            .execute_logical_plan(plan)
            .await
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
        let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
        Ok(Response::Query(resp))
    }
}

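/// [`QueryParser`] implementation that turns SQL text into an optimized
/// DataFusion [`LogicalPlan`] for use by the extended query protocol.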
pub struct Parser {
    session_context: Arc<SessionContext>,
}

#[async_trait]
impl QueryParser for Parser {
    type Statement = LogicalPlan;

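    // Plan and optimize the statement up front; parameter placeholders are
    // kept in the plan and bound later at execution time.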
    async fn parse_sql<C>(
        &self,
        _client: &C,
        sql: &str,
        _types: &[Type],
    ) -> PgWireResult<Self::Statement> {
        let context = &self.session_context;
        let state = context.state();
        let logical_plan = state
            .create_logical_plan(sql)
            .await
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
        let optimised = state
            .optimize(&logical_plan)
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
        Ok(optimised)
    }
}

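/// Flattens the parameter-type map into a list sorted by placeholder name so
/// it can be matched positionally against the client's bind parameters.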
fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
    let mut types = types.iter().collect::<Vec<_>>();
    types.sort_by(|a, b| a.0.cmp(b.0));
    types.into_iter().map(|pt| pt.1.as_ref()).collect()
}