1use std::collections::HashMap;
2use std::sync::Arc;
3
4use crate::auth::{AuthManager, Permission, ResourceType};
5use async_trait::async_trait;
6use datafusion::arrow::datatypes::DataType;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::*;
9use pgwire::api::auth::noop::NoopStartupHandler;
10use pgwire::api::auth::StartupHandler;
11use pgwire::api::portal::{Format, Portal};
12use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
13use pgwire::api::results::{
14 DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
15 Response, Tag,
16};
17use pgwire::api::stmt::QueryParser;
18use pgwire::api::stmt::StoredStatement;
19use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
20use pgwire::error::{PgWireError, PgWireResult};
21use tokio::sync::Mutex;
22
23use arrow_pg::datatypes::df;
24use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
25
/// Transaction status tracked by the session service.
///
/// The service only emulates transaction blocks: it records whether a block is
/// open or has failed, but statements are not undone on rollback.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TransactionState {
    /// No transaction block is open.
    #[default]
    None,
    /// A `BEGIN` has been accepted and the block has not ended yet.
    Active,
    /// A statement failed inside the block; further statements are rejected
    /// until the block is ended with `COMMIT`/`ROLLBACK`.
    Failed,
}
32
33pub struct SimpleStartupHandler;
36
37#[async_trait::async_trait]
38impl NoopStartupHandler for SimpleStartupHandler {}
39
40pub struct HandlerFactory {
41 pub session_service: Arc<DfSessionService>,
42}
43
44impl HandlerFactory {
45 pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
46 let session_service =
47 Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
48 HandlerFactory { session_service }
49 }
50}
51
52impl PgWireServerHandlers for HandlerFactory {
53 fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
54 self.session_service.clone()
55 }
56
57 fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
58 self.session_service.clone()
59 }
60
61 fn startup_handler(&self) -> Arc<impl StartupHandler> {
62 Arc::new(SimpleStartupHandler)
63 }
64}
65
66pub struct DfSessionService {
68 session_context: Arc<SessionContext>,
69 parser: Arc<Parser>,
70 timezone: Arc<Mutex<String>>,
71 transaction_state: Arc<Mutex<TransactionState>>,
72 auth_manager: Arc<AuthManager>,
73}
74
75impl DfSessionService {
76 pub fn new(
77 session_context: Arc<SessionContext>,
78 auth_manager: Arc<AuthManager>,
79 ) -> DfSessionService {
80 let parser = Arc::new(Parser {
81 session_context: session_context.clone(),
82 });
83 DfSessionService {
84 session_context,
85 parser,
86 timezone: Arc::new(Mutex::new("UTC".to_string())),
87 transaction_state: Arc::new(Mutex::new(TransactionState::None)),
88 auth_manager,
89 }
90 }
91
92 async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
94 where
95 C: ClientInfo,
96 {
97 let username = client
99 .metadata()
100 .get("user")
101 .map(|s| s.as_str())
102 .unwrap_or("anonymous");
103
104 let query_lower = query.to_lowercase();
106 let query_trimmed = query_lower.trim();
107
108 let (required_permission, resource) = if query_trimmed.starts_with("select") {
109 (Permission::Select, self.extract_table_from_query(query))
110 } else if query_trimmed.starts_with("insert") {
111 (Permission::Insert, self.extract_table_from_query(query))
112 } else if query_trimmed.starts_with("update") {
113 (Permission::Update, self.extract_table_from_query(query))
114 } else if query_trimmed.starts_with("delete") {
115 (Permission::Delete, self.extract_table_from_query(query))
116 } else if query_trimmed.starts_with("create table")
117 || query_trimmed.starts_with("create view")
118 {
119 (Permission::Create, ResourceType::All)
120 } else if query_trimmed.starts_with("drop") {
121 (Permission::Drop, self.extract_table_from_query(query))
122 } else if query_trimmed.starts_with("alter") {
123 (Permission::Alter, self.extract_table_from_query(query))
124 } else {
125 return Ok(());
127 };
128
129 let has_permission = self
131 .auth_manager
132 .check_permission(username, required_permission, resource)
133 .await;
134
135 if !has_permission {
136 return Err(PgWireError::UserError(Box::new(
137 pgwire::error::ErrorInfo::new(
138 "ERROR".to_string(),
139 "42501".to_string(), format!("permission denied for user \"{username}\""),
141 ),
142 )));
143 }
144
145 Ok(())
146 }
147
148 fn extract_table_from_query(&self, query: &str) -> ResourceType {
150 let words: Vec<&str> = query.split_whitespace().collect();
151
152 for (i, word) in words.iter().enumerate() {
154 let word_lower = word.to_lowercase();
155 if (word_lower == "from" || word_lower == "into" || word_lower == "table")
156 && i + 1 < words.len()
157 {
158 let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
159 return ResourceType::Table(table_name.to_string());
160 }
161 }
162
163 ResourceType::All
165 }
166
167 fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult<QueryResponse<'a>> {
168 let fields = vec![FieldInfo::new(
169 name.to_string(),
170 None,
171 None,
172 Type::VARCHAR,
173 FieldFormat::Text,
174 )];
175
176 let row = {
177 let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
178 encoder.encode_field(&Some(value))?;
179 encoder.finish()
180 };
181
182 let row_stream = futures::stream::once(async move { row });
183 Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
184 }
185
186 async fn try_respond_set_statements<'a>(
187 &self,
188 query_lower: &str,
189 ) -> PgWireResult<Option<Response<'a>>> {
190 if query_lower.starts_with("set") {
191 if query_lower.starts_with("set time zone") {
192 let parts: Vec<&str> = query_lower.split_whitespace().collect();
193 if parts.len() >= 4 {
194 let tz = parts[3].trim_matches('"');
195 let mut timezone = self.timezone.lock().await;
196 *timezone = tz.to_string();
197 Ok(Some(Response::Execution(Tag::new("SET"))))
198 } else {
199 Err(PgWireError::UserError(Box::new(
200 pgwire::error::ErrorInfo::new(
201 "ERROR".to_string(),
202 "42601".to_string(),
203 "Invalid SET TIME ZONE syntax".to_string(),
204 ),
205 )))
206 }
207 } else {
208 Ok(Some(Response::Execution(Tag::new("SET"))))
210 }
211 } else {
212 Ok(None)
213 }
214 }
215
216 async fn try_respond_transaction_statements<'a>(
217 &self,
218 query_lower: &str,
219 ) -> PgWireResult<Option<Response<'a>>> {
220 match query_lower.trim() {
223 "begin" | "begin transaction" | "begin work" | "start transaction" => {
224 let mut state = self.transaction_state.lock().await;
225 match *state {
226 TransactionState::None => {
227 *state = TransactionState::Active;
228 Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
229 }
230 TransactionState::Active => {
231 Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
234 }
235 TransactionState::Failed => {
236 Err(PgWireError::UserError(Box::new(
238 pgwire::error::ErrorInfo::new(
239 "ERROR".to_string(),
240 "25P01".to_string(),
241 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
242 ),
243 )))
244 }
245 }
246 }
247 "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
248 let mut state = self.transaction_state.lock().await;
249 match *state {
250 TransactionState::Active => {
251 *state = TransactionState::None;
252 Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
253 }
254 TransactionState::None => {
255 Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
257 }
258 TransactionState::Failed => {
259 *state = TransactionState::None;
261 Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
262 }
263 }
264 }
265 "rollback" | "rollback transaction" | "rollback work" | "abort" => {
266 let mut state = self.transaction_state.lock().await;
267 *state = TransactionState::None;
268 Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
269 }
270 _ => Ok(None),
271 }
272 }
273
274 async fn try_respond_show_statements<'a>(
275 &self,
276 query_lower: &str,
277 ) -> PgWireResult<Option<Response<'a>>> {
278 if query_lower.starts_with("show ") {
279 match query_lower.strip_suffix(";").unwrap_or(query_lower) {
280 "show time zone" => {
281 let timezone = self.timezone.lock().await.clone();
282 let resp = Self::mock_show_response("TimeZone", &timezone)?;
283 Ok(Some(Response::Query(resp)))
284 }
285 "show server_version" => {
286 let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
287 Ok(Some(Response::Query(resp)))
288 }
289 "show transaction_isolation" => {
290 let resp =
291 Self::mock_show_response("transaction_isolation", "read uncommitted")?;
292 Ok(Some(Response::Query(resp)))
293 }
294 "show catalogs" => {
295 let catalogs = self.session_context.catalog_names();
296 let value = catalogs.join(", ");
297 let resp = Self::mock_show_response("Catalogs", &value)?;
298 Ok(Some(Response::Query(resp)))
299 }
300 "show search_path" => {
301 let default_catalog = "datafusion";
302 let resp = Self::mock_show_response("search_path", default_catalog)?;
303 Ok(Some(Response::Query(resp)))
304 }
305 _ => Err(PgWireError::UserError(Box::new(
306 pgwire::error::ErrorInfo::new(
307 "ERROR".to_string(),
308 "42704".to_string(),
309 format!("Unrecognized SHOW command: {query_lower}"),
310 ),
311 ))),
312 }
313 } else {
314 Ok(None)
315 }
316 }
317}
318
319#[async_trait]
320impl SimpleQueryHandler for DfSessionService {
321 async fn do_query<'a, C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
322 where
323 C: ClientInfo + Unpin + Send + Sync,
324 {
325 let query_lower = query.to_lowercase().trim().to_string();
326 log::debug!("Received query: {}", query); if !query_lower.starts_with("set")
330 && !query_lower.starts_with("begin")
331 && !query_lower.starts_with("commit")
332 && !query_lower.starts_with("rollback")
333 && !query_lower.starts_with("start")
334 && !query_lower.starts_with("end")
335 && !query_lower.starts_with("abort")
336 && !query_lower.starts_with("show")
337 {
338 self.check_query_permission(client, query).await?;
339 }
340
341 if let Some(resp) = self.try_respond_set_statements(&query_lower).await? {
342 return Ok(vec![resp]);
343 }
344
345 if let Some(resp) = self
346 .try_respond_transaction_statements(&query_lower)
347 .await?
348 {
349 return Ok(vec![resp]);
350 }
351
352 if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
353 return Ok(vec![resp]);
354 }
355
356 {
358 let state = self.transaction_state.lock().await;
359 if *state == TransactionState::Failed {
360 return Err(PgWireError::UserError(Box::new(
361 pgwire::error::ErrorInfo::new(
362 "ERROR".to_string(),
363 "25P01".to_string(),
364 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
365 ),
366 )));
367 }
368 }
369
370 let df_result = self.session_context.sql(query).await;
371
372 let df = match df_result {
374 Ok(df) => df,
375 Err(e) => {
376 {
378 let mut state = self.transaction_state.lock().await;
379 if *state == TransactionState::Active {
380 *state = TransactionState::Failed;
381 }
382 }
383 return Err(PgWireError::ApiError(Box::new(e)));
384 }
385 };
386
387 if query_lower.starts_with("insert into") {
388 let result = df
391 .clone()
392 .collect()
393 .await
394 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
395
396 let rows_affected = result
398 .first()
399 .and_then(|batch| batch.column_by_name("count"))
400 .and_then(|col| {
401 col.as_any()
402 .downcast_ref::<datafusion::arrow::array::UInt64Array>()
403 })
404 .map_or(0, |array| array.value(0) as usize);
405
406 let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
408 Ok(vec![Response::Execution(tag)])
409 } else {
410 let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
412 Ok(vec![Response::Query(resp)])
413 }
414 }
415}
416
417#[async_trait]
418impl ExtendedQueryHandler for DfSessionService {
419 type Statement = (String, LogicalPlan);
420 type QueryParser = Parser;
421
422 fn query_parser(&self) -> Arc<Self::QueryParser> {
423 self.parser.clone()
424 }
425
426 async fn do_describe_statement<C>(
427 &self,
428 _client: &mut C,
429 target: &StoredStatement<Self::Statement>,
430 ) -> PgWireResult<DescribeStatementResponse>
431 where
432 C: ClientInfo + Unpin + Send + Sync,
433 {
434 let (_, plan) = &target.statement;
435 let schema = plan.schema();
436 let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
437 let params = plan
438 .get_parameter_types()
439 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
440
441 let mut param_types = Vec::with_capacity(params.len());
442 for param_type in ordered_param_types(¶ms).iter() {
443 if let Some(datatype) = param_type {
445 let pgtype = into_pg_type(datatype)?;
446 param_types.push(pgtype);
447 } else {
448 param_types.push(Type::UNKNOWN);
449 }
450 }
451
452 Ok(DescribeStatementResponse::new(param_types, fields))
453 }
454
455 async fn do_describe_portal<C>(
456 &self,
457 _client: &mut C,
458 target: &Portal<Self::Statement>,
459 ) -> PgWireResult<DescribePortalResponse>
460 where
461 C: ClientInfo + Unpin + Send + Sync,
462 {
463 let (_, plan) = &target.statement.statement;
464 let format = &target.result_column_format;
465 let schema = plan.schema();
466 let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
467
468 Ok(DescribePortalResponse::new(fields))
469 }
470
471 async fn do_query<'a, C>(
472 &self,
473 client: &mut C,
474 portal: &Portal<Self::Statement>,
475 _max_rows: usize,
476 ) -> PgWireResult<Response<'a>>
477 where
478 C: ClientInfo + Unpin + Send + Sync,
479 {
480 let query = portal
481 .statement
482 .statement
483 .0
484 .to_lowercase()
485 .trim()
486 .to_string();
487 log::debug!("Received execute extended query: {}", query); if !query.starts_with("set") && !query.starts_with("show") {
491 self.check_query_permission(client, &portal.statement.statement.0)
492 .await?;
493 }
494
495 if let Some(resp) = self.try_respond_set_statements(&query).await? {
496 return Ok(resp);
497 }
498
499 if let Some(resp) = self.try_respond_show_statements(&query).await? {
500 return Ok(resp);
501 }
502
503 let (_, plan) = &portal.statement.statement;
504
505 let param_types = plan
506 .get_parameter_types()
507 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
508 let param_values = df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; let plan = plan
510 .clone()
511 .replace_params_with_values(¶m_values)
512 .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let dataframe = self
514 .session_context
515 .execute_logical_plan(plan)
516 .await
517 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
518 let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
519 Ok(Response::Query(resp))
520 }
521}
522
523pub struct Parser {
524 session_context: Arc<SessionContext>,
525}
526
527#[async_trait]
528impl QueryParser for Parser {
529 type Statement = (String, LogicalPlan);
530
531 async fn parse_sql<C>(
532 &self,
533 _client: &C,
534 sql: &str,
535 _types: &[Type],
536 ) -> PgWireResult<Self::Statement> {
537 log::debug!("Received parse extended query: {}", sql); let context = &self.session_context;
539 let state = context.state();
540 let logical_plan = state
541 .create_logical_plan(sql)
542 .await
543 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
544 let optimised = state
545 .optimize(&logical_plan)
546 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
547 Ok((sql.to_string(), optimised))
548 }
549}
550
551fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
552 let mut types = types.iter().collect::<Vec<_>>();
555 types.sort_by(|a, b| a.0.cmp(b.0));
556 types.into_iter().map(|pt| pt.1.as_ref()).collect()
557}