1use std::collections::HashMap;
2use std::sync::Arc;
3
4use crate::auth::{AuthManager, Permission, ResourceType};
5use async_trait::async_trait;
6use datafusion::arrow::datatypes::DataType;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::*;
9use pgwire::api::auth::noop::NoopStartupHandler;
10use pgwire::api::auth::StartupHandler;
11use pgwire::api::portal::{Format, Portal};
12use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
13use pgwire::api::results::{
14 DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
15 Response, Tag,
16};
17use pgwire::api::stmt::QueryParser;
18use pgwire::api::stmt::StoredStatement;
19use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
20use pgwire::error::{PgWireError, PgWireResult};
21use tokio::sync::Mutex;
22
23use arrow_pg::datatypes::df;
24use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
25
/// Emulated transaction status kept by the session service.
///
/// The service does not run real transactions; it only tracks enough state to answer
/// `BEGIN` / `COMMIT` / `ROLLBACK` the way a PostgreSQL client expects.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionState {
    /// No transaction block is open.
    None,
    /// A transaction block was opened with `BEGIN` / `START TRANSACTION`.
    Active,
    /// A statement failed while a transaction block was active; further statements are
    /// rejected until the block is ended with `COMMIT` or `ROLLBACK`.
    Failed,
}
32
33pub struct SimpleStartupHandler;
36
37#[async_trait::async_trait]
38impl NoopStartupHandler for SimpleStartupHandler {}
39
40pub struct HandlerFactory {
41 pub session_service: Arc<DfSessionService>,
42}
43
44impl HandlerFactory {
45 pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
46 let session_service =
47 Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
48 HandlerFactory { session_service }
49 }
50}
51
52impl PgWireServerHandlers for HandlerFactory {
53 fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
54 self.session_service.clone()
55 }
56
57 fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
58 self.session_service.clone()
59 }
60
61 fn startup_handler(&self) -> Arc<impl StartupHandler> {
62 Arc::new(SimpleStartupHandler)
63 }
64}
65
66pub struct DfSessionService {
68 session_context: Arc<SessionContext>,
69 parser: Arc<Parser>,
70 timezone: Arc<Mutex<String>>,
71 transaction_state: Arc<Mutex<TransactionState>>,
72 auth_manager: Arc<AuthManager>,
73}
74
75impl DfSessionService {
76 pub fn new(
77 session_context: Arc<SessionContext>,
78 auth_manager: Arc<AuthManager>,
79 ) -> DfSessionService {
80 let parser = Arc::new(Parser {
81 session_context: session_context.clone(),
82 });
83 DfSessionService {
84 session_context,
85 parser,
86 timezone: Arc::new(Mutex::new("UTC".to_string())),
87 transaction_state: Arc::new(Mutex::new(TransactionState::None)),
88 auth_manager,
89 }
90 }
91
92 async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
94 where
95 C: ClientInfo,
96 {
97 let username = client
99 .metadata()
100 .get("user")
101 .map(|s| s.as_str())
102 .unwrap_or("anonymous");
103
104 let query_lower = query.to_lowercase();
106 let query_trimmed = query_lower.trim();
107
108 let (required_permission, resource) = if query_trimmed.starts_with("select") {
109 (Permission::Select, self.extract_table_from_query(query))
110 } else if query_trimmed.starts_with("insert") {
111 (Permission::Insert, self.extract_table_from_query(query))
112 } else if query_trimmed.starts_with("update") {
113 (Permission::Update, self.extract_table_from_query(query))
114 } else if query_trimmed.starts_with("delete") {
115 (Permission::Delete, self.extract_table_from_query(query))
116 } else if query_trimmed.starts_with("create table")
117 || query_trimmed.starts_with("create view")
118 {
119 (Permission::Create, ResourceType::All)
120 } else if query_trimmed.starts_with("drop") {
121 (Permission::Drop, self.extract_table_from_query(query))
122 } else if query_trimmed.starts_with("alter") {
123 (Permission::Alter, self.extract_table_from_query(query))
124 } else {
125 return Ok(());
127 };
128
129 let has_permission = self
131 .auth_manager
132 .check_permission(username, required_permission, resource)
133 .await;
134
135 if !has_permission {
136 return Err(PgWireError::UserError(Box::new(
137 pgwire::error::ErrorInfo::new(
138 "ERROR".to_string(),
139 "42501".to_string(), format!("permission denied for user \"{username}\""),
141 ),
142 )));
143 }
144
145 Ok(())
146 }
147
148 fn extract_table_from_query(&self, query: &str) -> ResourceType {
150 let words: Vec<&str> = query.split_whitespace().collect();
151
152 for (i, word) in words.iter().enumerate() {
154 let word_lower = word.to_lowercase();
155 if (word_lower == "from" || word_lower == "into" || word_lower == "table")
156 && i + 1 < words.len()
157 {
158 let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
159 return ResourceType::Table(table_name.to_string());
160 }
161 }
162
163 ResourceType::All
165 }
166
167 fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult<QueryResponse<'a>> {
168 let fields = vec![FieldInfo::new(
169 name.to_string(),
170 None,
171 None,
172 Type::VARCHAR,
173 FieldFormat::Text,
174 )];
175
176 let row = {
177 let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
178 encoder.encode_field(&Some(value))?;
179 encoder.finish()
180 };
181
182 let row_stream = futures::stream::once(async move { row });
183 Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
184 }
185
186 async fn try_respond_set_statements<'a>(
187 &self,
188 query_lower: &str,
189 ) -> PgWireResult<Option<Response<'a>>> {
190 if query_lower.starts_with("set") {
191 if query_lower.starts_with("set time zone") {
192 let parts: Vec<&str> = query_lower.split_whitespace().collect();
193 if parts.len() >= 4 {
194 let tz = parts[3].trim_matches('"');
195 let mut timezone = self.timezone.lock().await;
196 *timezone = tz.to_string();
197 Ok(Some(Response::Execution(Tag::new("SET"))))
198 } else {
199 Err(PgWireError::UserError(Box::new(
200 pgwire::error::ErrorInfo::new(
201 "ERROR".to_string(),
202 "42601".to_string(),
203 "Invalid SET TIME ZONE syntax".to_string(),
204 ),
205 )))
206 }
207 } else {
208 let df = self
210 .session_context
211 .sql(query_lower)
212 .await
213 .map_err(|err| PgWireError::ApiError(Box::new(err)))?;
214
215 let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
216 Ok(Some(Response::Query(resp)))
217 }
218 } else {
219 Ok(None)
220 }
221 }
222
223 async fn try_respond_transaction_statements<'a>(
224 &self,
225 query_lower: &str,
226 ) -> PgWireResult<Option<Response<'a>>> {
227 match query_lower.trim() {
230 "begin" | "begin transaction" | "begin work" | "start transaction" => {
231 let mut state = self.transaction_state.lock().await;
232 match *state {
233 TransactionState::None => {
234 *state = TransactionState::Active;
235 Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
236 }
237 TransactionState::Active => {
238 Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
241 }
242 TransactionState::Failed => {
243 Err(PgWireError::UserError(Box::new(
245 pgwire::error::ErrorInfo::new(
246 "ERROR".to_string(),
247 "25P01".to_string(),
248 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
249 ),
250 )))
251 }
252 }
253 }
254 "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
255 let mut state = self.transaction_state.lock().await;
256 match *state {
257 TransactionState::Active => {
258 *state = TransactionState::None;
259 Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
260 }
261 TransactionState::None => {
262 Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
264 }
265 TransactionState::Failed => {
266 *state = TransactionState::None;
268 Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
269 }
270 }
271 }
272 "rollback" | "rollback transaction" | "rollback work" | "abort" => {
273 let mut state = self.transaction_state.lock().await;
274 *state = TransactionState::None;
275 Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
276 }
277 _ => Ok(None),
278 }
279 }
280
281 async fn try_respond_show_statements<'a>(
282 &self,
283 query_lower: &str,
284 ) -> PgWireResult<Option<Response<'a>>> {
285 if query_lower.starts_with("show ") {
286 match query_lower.strip_suffix(";").unwrap_or(query_lower) {
287 "show time zone" => {
288 let timezone = self.timezone.lock().await.clone();
289 let resp = Self::mock_show_response("TimeZone", &timezone)?;
290 Ok(Some(Response::Query(resp)))
291 }
292 "show server_version" => {
293 let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
294 Ok(Some(Response::Query(resp)))
295 }
296 "show transaction_isolation" => {
297 let resp =
298 Self::mock_show_response("transaction_isolation", "read uncommitted")?;
299 Ok(Some(Response::Query(resp)))
300 }
301 "show catalogs" => {
302 let catalogs = self.session_context.catalog_names();
303 let value = catalogs.join(", ");
304 let resp = Self::mock_show_response("Catalogs", &value)?;
305 Ok(Some(Response::Query(resp)))
306 }
307 "show search_path" => {
308 let default_catalog = "datafusion";
309 let resp = Self::mock_show_response("search_path", default_catalog)?;
310 Ok(Some(Response::Query(resp)))
311 }
312 _ => Err(PgWireError::UserError(Box::new(
313 pgwire::error::ErrorInfo::new(
314 "ERROR".to_string(),
315 "42704".to_string(),
316 format!("Unrecognized SHOW command: {query_lower}"),
317 ),
318 ))),
319 }
320 } else {
321 Ok(None)
322 }
323 }
324}
325
326#[async_trait]
327impl SimpleQueryHandler for DfSessionService {
328 async fn do_query<'a, C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
329 where
330 C: ClientInfo + Unpin + Send + Sync,
331 {
332 let query_lower = query.to_lowercase().trim().to_string();
333 log::debug!("Received query: {}", query); if !query_lower.starts_with("set")
337 && !query_lower.starts_with("begin")
338 && !query_lower.starts_with("commit")
339 && !query_lower.starts_with("rollback")
340 && !query_lower.starts_with("start")
341 && !query_lower.starts_with("end")
342 && !query_lower.starts_with("abort")
343 && !query_lower.starts_with("show")
344 {
345 self.check_query_permission(client, query).await?;
346 }
347
348 if let Some(resp) = self.try_respond_set_statements(&query_lower).await? {
349 return Ok(vec![resp]);
350 }
351
352 if let Some(resp) = self
353 .try_respond_transaction_statements(&query_lower)
354 .await?
355 {
356 return Ok(vec![resp]);
357 }
358
359 if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
360 return Ok(vec![resp]);
361 }
362
363 {
365 let state = self.transaction_state.lock().await;
366 if *state == TransactionState::Failed {
367 return Err(PgWireError::UserError(Box::new(
368 pgwire::error::ErrorInfo::new(
369 "ERROR".to_string(),
370 "25P01".to_string(),
371 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
372 ),
373 )));
374 }
375 }
376
377 let df_result = self.session_context.sql(query).await;
378
379 let df = match df_result {
381 Ok(df) => df,
382 Err(e) => {
383 {
385 let mut state = self.transaction_state.lock().await;
386 if *state == TransactionState::Active {
387 *state = TransactionState::Failed;
388 }
389 }
390 return Err(PgWireError::ApiError(Box::new(e)));
391 }
392 };
393
394 if query_lower.starts_with("insert into") {
395 let result = df
398 .clone()
399 .collect()
400 .await
401 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
402
403 let rows_affected = result
405 .first()
406 .and_then(|batch| batch.column_by_name("count"))
407 .and_then(|col| {
408 col.as_any()
409 .downcast_ref::<datafusion::arrow::array::UInt64Array>()
410 })
411 .map_or(0, |array| array.value(0) as usize);
412
413 let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
415 Ok(vec![Response::Execution(tag)])
416 } else {
417 let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
419 Ok(vec![Response::Query(resp)])
420 }
421 }
422}
423
424#[async_trait]
425impl ExtendedQueryHandler for DfSessionService {
426 type Statement = (String, LogicalPlan);
427 type QueryParser = Parser;
428
429 fn query_parser(&self) -> Arc<Self::QueryParser> {
430 self.parser.clone()
431 }
432
433 async fn do_describe_statement<C>(
434 &self,
435 _client: &mut C,
436 target: &StoredStatement<Self::Statement>,
437 ) -> PgWireResult<DescribeStatementResponse>
438 where
439 C: ClientInfo + Unpin + Send + Sync,
440 {
441 let (_, plan) = &target.statement;
442 let schema = plan.schema();
443 let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
444 let params = plan
445 .get_parameter_types()
446 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
447
448 let mut param_types = Vec::with_capacity(params.len());
449 for param_type in ordered_param_types(¶ms).iter() {
450 if let Some(datatype) = param_type {
452 let pgtype = into_pg_type(datatype)?;
453 param_types.push(pgtype);
454 } else {
455 param_types.push(Type::UNKNOWN);
456 }
457 }
458
459 Ok(DescribeStatementResponse::new(param_types, fields))
460 }
461
462 async fn do_describe_portal<C>(
463 &self,
464 _client: &mut C,
465 target: &Portal<Self::Statement>,
466 ) -> PgWireResult<DescribePortalResponse>
467 where
468 C: ClientInfo + Unpin + Send + Sync,
469 {
470 let (_, plan) = &target.statement.statement;
471 let format = &target.result_column_format;
472 let schema = plan.schema();
473 let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
474
475 Ok(DescribePortalResponse::new(fields))
476 }
477
478 async fn do_query<'a, C>(
479 &self,
480 client: &mut C,
481 portal: &Portal<Self::Statement>,
482 _max_rows: usize,
483 ) -> PgWireResult<Response<'a>>
484 where
485 C: ClientInfo + Unpin + Send + Sync,
486 {
487 let query = portal
488 .statement
489 .statement
490 .0
491 .to_lowercase()
492 .trim()
493 .to_string();
494 log::debug!("Received execute extended query: {}", query); if !query.starts_with("set") && !query.starts_with("show") {
498 self.check_query_permission(client, &portal.statement.statement.0)
499 .await?;
500 }
501
502 if let Some(resp) = self.try_respond_set_statements(&query).await? {
503 return Ok(resp);
504 }
505
506 if let Some(resp) = self.try_respond_show_statements(&query).await? {
507 return Ok(resp);
508 }
509
510 let (_, plan) = &portal.statement.statement;
511
512 let param_types = plan
513 .get_parameter_types()
514 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
515 let param_values = df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; let plan = plan
517 .clone()
518 .replace_params_with_values(¶m_values)
519 .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let dataframe = self
521 .session_context
522 .execute_logical_plan(plan)
523 .await
524 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
525 let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
526 Ok(Response::Query(resp))
527 }
528}
529
530pub struct Parser {
531 session_context: Arc<SessionContext>,
532}
533
534#[async_trait]
535impl QueryParser for Parser {
536 type Statement = (String, LogicalPlan);
537
538 async fn parse_sql<C>(
539 &self,
540 _client: &C,
541 sql: &str,
542 _types: &[Type],
543 ) -> PgWireResult<Self::Statement> {
544 log::debug!("Received parse extended query: {}", sql); let context = &self.session_context;
546 let state = context.state();
547 let logical_plan = state
548 .create_logical_plan(sql)
549 .await
550 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
551 let optimised = state
552 .optimize(&logical_plan)
553 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
554 Ok((sql.to_string(), optimised))
555 }
556}
557
/// Returns the parameter types ordered by placeholder position.
///
/// Keys are placeholder names. Names of the form `$<number>` are ordered by their numeric
/// value, so `$2` comes before `$10`; a plain string comparison would put `$10` first and
/// misalign the types of statements with ten or more parameters. Any other names follow the
/// numbered ones, in lexicographic order.
///
/// The function is generic over the element type because the ordering depends only on the
/// keys; existing callers passing a map of `DataType`s are unaffected.
fn ordered_param_types<T>(types: &HashMap<String, Option<T>>) -> Vec<Option<&T>> {
    let mut entries: Vec<_> = types
        .iter()
        .map(|(name, ty)| {
            let index = name
                .strip_prefix('$')
                .and_then(|digits| digits.parse::<usize>().ok());
            // Sort key: numbered placeholders first (by number), then the rest by name.
            ((index.is_none(), index, name.as_str()), ty)
        })
        .collect();

    // Map keys are unique, so an unstable sort cannot reorder equal elements.
    entries.sort_unstable_by(|a, b| a.0.cmp(&b.0));
    entries.into_iter().map(|(_, ty)| ty.as_ref()).collect()
}