1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::DataType;
6use datafusion::logical_expr::LogicalPlan;
7use datafusion::prelude::*;
8use pgwire::api::auth::noop::NoopStartupHandler;
9use pgwire::api::copy::NoopCopyHandler;
10use pgwire::api::portal::{Format, Portal};
11use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
12use pgwire::api::results::{
13 DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
14 Response, Tag,
15};
16use pgwire::api::stmt::QueryParser;
17use pgwire::api::stmt::StoredStatement;
18use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
19use tokio::sync::Mutex;
20
21use crate::datatypes;
22use crate::information_schema::{columns_df, schemata_df, tables_df};
23use pgwire::error::{PgWireError, PgWireResult};
24
25pub struct HandlerFactory(pub Arc<DfSessionService>);
26
27impl NoopStartupHandler for DfSessionService {}
28
29impl PgWireServerHandlers for HandlerFactory {
30 type StartupHandler = DfSessionService;
31 type SimpleQueryHandler = DfSessionService;
32 type ExtendedQueryHandler = DfSessionService;
33 type CopyHandler = NoopCopyHandler;
34 type ErrorHandler = NoopErrorHandler;
35
36 fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
37 self.0.clone()
38 }
39
40 fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
41 self.0.clone()
42 }
43
44 fn startup_handler(&self) -> Arc<Self::StartupHandler> {
45 self.0.clone()
46 }
47
48 fn copy_handler(&self) -> Arc<Self::CopyHandler> {
49 Arc::new(NoopCopyHandler)
50 }
51
52 fn error_handler(&self) -> Arc<Self::ErrorHandler> {
53 Arc::new(NoopErrorHandler)
54 }
55}
56
57pub struct DfSessionService {
58 session_context: Arc<SessionContext>,
59 parser: Arc<Parser>,
60 timezone: Arc<Mutex<String>>,
61}
62
63impl DfSessionService {
64 pub fn new(session_context: Arc<SessionContext>) -> DfSessionService {
65 let parser = Arc::new(Parser {
66 session_context: session_context.clone(),
67 });
68 DfSessionService {
69 session_context,
70 parser,
71 timezone: Arc::new(Mutex::new("UTC".to_string())),
72 }
73 }
74
75 fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult<QueryResponse<'a>> {
76 let fields = vec![FieldInfo::new(
77 name.to_string(),
78 None,
79 None,
80 Type::VARCHAR,
81 FieldFormat::Text,
82 )];
83
84 let row = {
85 let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
86 encoder.encode_field(&Some(value))?;
87 encoder.finish()
88 };
89
90 let row_stream = futures::stream::once(async move { row });
91 Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
92 }
93
94 async fn mock_pg_namespace<'a>(&self) -> PgWireResult<QueryResponse<'a>> {
96 let fields = Arc::new(vec![FieldInfo::new(
97 "nspname".to_string(),
98 None,
99 None,
100 Type::VARCHAR,
101 FieldFormat::Text,
102 )]);
103
104 let fields_ref = fields.clone();
105 let rows = self
106 .session_context
107 .catalog_names()
108 .into_iter()
109 .map(move |name| {
110 let mut encoder = pgwire::api::results::DataRowEncoder::new(fields_ref.clone());
111 encoder.encode_field(&Some(&name))?; encoder.finish()
113 });
114
115 let row_stream = futures::stream::iter(rows);
116 Ok(QueryResponse::new(fields.clone(), Box::pin(row_stream)))
117 }
118
119 async fn try_respond_set_time_zone<'a>(
120 &self,
121 query_lower: &str,
122 ) -> PgWireResult<Option<Response<'a>>> {
123 if query_lower.starts_with("set time zone") {
124 let parts: Vec<&str> = query_lower.split_whitespace().collect();
125 if parts.len() >= 4 {
126 let tz = parts[3].trim_matches('"');
127 let mut timezone = self.timezone.lock().await;
128 *timezone = tz.to_string();
129 Ok(Some(Response::Execution(Tag::new("SET"))))
130 } else {
131 Err(PgWireError::UserError(Box::new(
132 pgwire::error::ErrorInfo::new(
133 "ERROR".to_string(),
134 "42601".to_string(),
135 "Invalid SET TIME ZONE syntax".to_string(),
136 ),
137 )))
138 }
139 } else {
140 Ok(None)
141 }
142 }
143
144 async fn try_respond_show_statements<'a>(
145 &self,
146 query_lower: &str,
147 ) -> PgWireResult<Option<Response<'a>>> {
148 if query_lower.starts_with("show ") {
149 match query_lower.strip_suffix(";").unwrap_or(query_lower) {
150 "show time zone" => {
151 let timezone = self.timezone.lock().await.clone();
152 let resp = Self::mock_show_response("TimeZone", &timezone)?;
153 Ok(Some(Response::Query(resp)))
154 }
155 "show server_version" => {
156 let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
157 Ok(Some(Response::Query(resp)))
158 }
159 "show transaction_isolation" => {
160 let resp =
161 Self::mock_show_response("transaction_isolation", "read uncommitted")?;
162 Ok(Some(Response::Query(resp)))
163 }
164 "show catalogs" => {
165 let catalogs = self.session_context.catalog_names();
166 let value = catalogs.join(", ");
167 let resp = Self::mock_show_response("Catalogs", &value)?;
168 Ok(Some(Response::Query(resp)))
169 }
170 "show search_path" => {
171 let default_catalog = "datafusion";
172 let resp = Self::mock_show_response("search_path", default_catalog)?;
173 Ok(Some(Response::Query(resp)))
174 }
175 _ => Err(PgWireError::UserError(Box::new(
176 pgwire::error::ErrorInfo::new(
177 "ERROR".to_string(),
178 "42704".to_string(),
179 format!("Unrecognized SHOW command: {}", query_lower),
180 ),
181 ))),
182 }
183 } else {
184 Ok(None)
185 }
186 }
187
188 async fn try_respond_information_schema<'a>(
189 &self,
190 query_lower: &str,
191 ) -> PgWireResult<Option<Response<'a>>> {
192 if query_lower.contains("information_schema.schemata") {
193 let df = schemata_df(&self.session_context)
194 .await
195 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
196 let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
197 return Ok(Some(Response::Query(resp)));
198 } else if query_lower.contains("information_schema.tables") {
199 let df = tables_df(&self.session_context)
200 .await
201 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
202 let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
203 return Ok(Some(Response::Query(resp)));
204 } else if query_lower.contains("information_schema.columns") {
205 let df = columns_df(&self.session_context)
206 .await
207 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
208 let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
209 return Ok(Some(Response::Query(resp)));
210 }
211
212 if query_lower.contains("pg_catalog.pg_namespace") {
214 let resp = self.mock_pg_namespace().await?;
215 return Ok(Some(Response::Query(resp)));
216 }
217
218 Ok(None)
219 }
220}
221
222#[async_trait]
223impl SimpleQueryHandler for DfSessionService {
224 async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
225 where
226 C: ClientInfo + Unpin + Send + Sync,
227 {
228 let query_lower = query.to_lowercase().trim().to_string();
229 log::debug!("Received query: {}", query); if let Some(resp) = self.try_respond_set_time_zone(&query_lower).await? {
232 return Ok(vec![resp]);
233 }
234
235 if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
236 return Ok(vec![resp]);
237 }
238
239 if let Some(resp) = self.try_respond_information_schema(&query_lower).await? {
240 return Ok(vec![resp]);
241 }
242
243 let df = self
244 .session_context
245 .sql(query)
246 .await
247 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
248
249 if query_lower.starts_with("insert into") {
250 let result = df
253 .clone()
254 .collect()
255 .await
256 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
257
258 let rows_affected = result
260 .first()
261 .and_then(|batch| batch.column_by_name("count"))
262 .and_then(|col| {
263 col.as_any()
264 .downcast_ref::<datafusion::arrow::array::UInt64Array>()
265 })
266 .map_or(0, |array| array.value(0) as usize);
267
268 let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
270 Ok(vec![Response::Execution(tag)])
271 } else {
272 let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
274 Ok(vec![Response::Query(resp)])
275 }
276 }
277}
278
279#[async_trait]
280impl ExtendedQueryHandler for DfSessionService {
281 type Statement = LogicalPlan;
282 type QueryParser = Parser;
283
284 fn query_parser(&self) -> Arc<Self::QueryParser> {
285 self.parser.clone()
286 }
287
288 async fn do_describe_statement<C>(
289 &self,
290 _client: &mut C,
291 target: &StoredStatement<Self::Statement>,
292 ) -> PgWireResult<DescribeStatementResponse>
293 where
294 C: ClientInfo + Unpin + Send + Sync,
295 {
296 let plan = &target.statement;
297 let schema = plan.schema();
298 let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?;
299 let params = plan
300 .get_parameter_types()
301 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
302
303 let mut param_types = Vec::with_capacity(params.len());
304 for param_type in ordered_param_types(¶ms).iter() {
305 if let Some(datatype) = param_type {
307 let pgtype = datatypes::into_pg_type(datatype)?;
308 param_types.push(pgtype);
309 } else {
310 param_types.push(Type::UNKNOWN);
311 }
312 }
313
314 Ok(DescribeStatementResponse::new(param_types, fields))
315 }
316
317 async fn do_describe_portal<C>(
318 &self,
319 _client: &mut C,
320 target: &Portal<Self::Statement>,
321 ) -> PgWireResult<DescribePortalResponse>
322 where
323 C: ClientInfo + Unpin + Send + Sync,
324 {
325 let plan = &target.statement.statement;
326 let format = &target.result_column_format;
327 let schema = plan.schema();
328 let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?;
329
330 Ok(DescribePortalResponse::new(fields))
331 }
332
333 async fn do_query<'a, C>(
334 &self,
335 _client: &mut C,
336 portal: &Portal<Self::Statement>,
337 _max_rows: usize,
338 ) -> PgWireResult<Response<'a>>
339 where
340 C: ClientInfo + Unpin + Send + Sync,
341 {
342 let query = portal
343 .statement
344 .statement
345 .to_string()
346 .to_lowercase()
347 .trim()
348 .to_string();
349 log::debug!("Received extended query: {}", query); if let Some(resp) = self.try_respond_show_statements(&query).await? {
352 return Ok(resp);
353 }
354
355 if let Some(resp) = self.try_respond_information_schema(&query).await? {
356 return Ok(resp);
357 }
358
359 let plan = &portal.statement.statement;
360 let param_types = plan
361 .get_parameter_types()
362 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
363 let param_values =
364 datatypes::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; let plan = plan
366 .clone()
367 .replace_params_with_values(¶m_values)
368 .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let dataframe = self
370 .session_context
371 .execute_logical_plan(plan)
372 .await
373 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
374 let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
375 Ok(Response::Query(resp))
376 }
377}
378
379pub struct Parser {
380 session_context: Arc<SessionContext>,
381}
382
383#[async_trait]
384impl QueryParser for Parser {
385 type Statement = LogicalPlan;
386
387 async fn parse_sql<C>(
388 &self,
389 _client: &C,
390 sql: &str,
391 _types: &[Type],
392 ) -> PgWireResult<Self::Statement> {
393 let context = &self.session_context;
394 let state = context.state();
395 let logical_plan = state
396 .create_logical_plan(sql)
397 .await
398 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
399 let optimised = state
400 .optimize(&logical_plan)
401 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
402 Ok(optimised)
403 }
404}
405
/// Returns the parameter types of a statement ordered by placeholder position.
///
/// `types` maps placeholder names to their (possibly unknown) type. Keys are
/// expected to look like `$1`, `$2`, ... (assumption based on how the map is
/// obtained from the plan's parameter types; confirm against DataFusion).
///
/// Keys of the form `$<number>` are ordered by their numeric value, so `$2`
/// comes before `$10`. A plain string sort would put `$10` first and mis-order
/// statements with ten or more parameters. Any other key is placed after the
/// numeric ones, in lexicographic order.
///
/// The function is generic over the type representation; callers passing a
/// `HashMap<String, Option<DataType>>` get a `Vec<Option<&DataType>>` as before.
fn ordered_param_types<T>(types: &HashMap<String, Option<T>>) -> Vec<Option<&T>> {
    /// Parses the numeric part of a `$<number>` placeholder name.
    fn placeholder_index(name: &str) -> Option<u64> {
        name.strip_prefix('$')?.parse().ok()
    }

    let mut entries: Vec<_> = types.iter().collect();
    entries.sort_by_key(|&(name, _)| {
        let index = placeholder_index(name);
        // `false < true`, so numbered placeholders sort first; the name is the
        // final tie-breaker to keep the order deterministic.
        (index.is_none(), index, name)
    });
    entries.into_iter().map(|(_, ty)| ty.as_ref()).collect()
}