datafusion_postgres/
handlers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use crate::auth::{AuthManager, Permission, ResourceType};
5use async_trait::async_trait;
6use datafusion::arrow::datatypes::DataType;
7use datafusion::logical_expr::LogicalPlan;
8use datafusion::prelude::*;
9use pgwire::api::auth::noop::NoopStartupHandler;
10use pgwire::api::auth::StartupHandler;
11use pgwire::api::portal::{Format, Portal};
12use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
13use pgwire::api::results::{
14    DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
15    Response, Tag,
16};
17use pgwire::api::stmt::QueryParser;
18use pgwire::api::stmt::StoredStatement;
19use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
20use pgwire::error::{PgWireError, PgWireResult};
21use tokio::sync::Mutex;
22
23use arrow_pg::datatypes::df;
24use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
25
/// Transaction status tracked by the session service.
///
/// The value only records which state the (emulated) transaction block is in;
/// the handlers in this file do not implement any rollback of executed work.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionState {
    /// No transaction block is open. This is the initial state.
    None,
    /// A transaction block was opened with `BEGIN` / `START TRANSACTION`.
    Active,
    /// A statement failed while a transaction block was active; further
    /// statements are rejected until the block is ended.
    Failed,
}
32
33/// Simple startup handler that does no authentication
34/// For production, use DfAuthSource with proper pgwire authentication handlers
35pub struct SimpleStartupHandler;
36
37#[async_trait::async_trait]
38impl NoopStartupHandler for SimpleStartupHandler {}
39
40pub struct HandlerFactory {
41    pub session_service: Arc<DfSessionService>,
42}
43
44impl HandlerFactory {
45    pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
46        let session_service =
47            Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
48        HandlerFactory { session_service }
49    }
50}
51
52impl PgWireServerHandlers for HandlerFactory {
53    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
54        self.session_service.clone()
55    }
56
57    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
58        self.session_service.clone()
59    }
60
61    fn startup_handler(&self) -> Arc<impl StartupHandler> {
62        Arc::new(SimpleStartupHandler)
63    }
64}
65
66/// The pgwire handler backed by a datafusion `SessionContext`
67pub struct DfSessionService {
68    session_context: Arc<SessionContext>,
69    parser: Arc<Parser>,
70    timezone: Arc<Mutex<String>>,
71    transaction_state: Arc<Mutex<TransactionState>>,
72    auth_manager: Arc<AuthManager>,
73}
74
75impl DfSessionService {
76    pub fn new(
77        session_context: Arc<SessionContext>,
78        auth_manager: Arc<AuthManager>,
79    ) -> DfSessionService {
80        let parser = Arc::new(Parser {
81            session_context: session_context.clone(),
82        });
83        DfSessionService {
84            session_context,
85            parser,
86            timezone: Arc::new(Mutex::new("UTC".to_string())),
87            transaction_state: Arc::new(Mutex::new(TransactionState::None)),
88            auth_manager,
89        }
90    }
91
92    /// Check if the current user has permission to execute a query
93    async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
94    where
95        C: ClientInfo,
96    {
97        // Get the username from client metadata
98        let username = client
99            .metadata()
100            .get("user")
101            .map(|s| s.as_str())
102            .unwrap_or("anonymous");
103
104        // Parse query to determine required permissions
105        let query_lower = query.to_lowercase();
106        let query_trimmed = query_lower.trim();
107
108        let (required_permission, resource) = if query_trimmed.starts_with("select") {
109            (Permission::Select, self.extract_table_from_query(query))
110        } else if query_trimmed.starts_with("insert") {
111            (Permission::Insert, self.extract_table_from_query(query))
112        } else if query_trimmed.starts_with("update") {
113            (Permission::Update, self.extract_table_from_query(query))
114        } else if query_trimmed.starts_with("delete") {
115            (Permission::Delete, self.extract_table_from_query(query))
116        } else if query_trimmed.starts_with("create table")
117            || query_trimmed.starts_with("create view")
118        {
119            (Permission::Create, ResourceType::All)
120        } else if query_trimmed.starts_with("drop") {
121            (Permission::Drop, self.extract_table_from_query(query))
122        } else if query_trimmed.starts_with("alter") {
123            (Permission::Alter, self.extract_table_from_query(query))
124        } else {
125            // For other queries (SHOW, EXPLAIN, etc.), allow all users
126            return Ok(());
127        };
128
129        // Check permission
130        let has_permission = self
131            .auth_manager
132            .check_permission(username, required_permission, resource)
133            .await;
134
135        if !has_permission {
136            return Err(PgWireError::UserError(Box::new(
137                pgwire::error::ErrorInfo::new(
138                    "ERROR".to_string(),
139                    "42501".to_string(), // insufficient_privilege
140                    format!("permission denied for user \"{username}\""),
141                ),
142            )));
143        }
144
145        Ok(())
146    }
147
148    /// Extract table name from query (simplified parsing)
149    fn extract_table_from_query(&self, query: &str) -> ResourceType {
150        let words: Vec<&str> = query.split_whitespace().collect();
151
152        // Simple heuristic to find table names
153        for (i, word) in words.iter().enumerate() {
154            let word_lower = word.to_lowercase();
155            if (word_lower == "from" || word_lower == "into" || word_lower == "table")
156                && i + 1 < words.len()
157            {
158                let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
159                return ResourceType::Table(table_name.to_string());
160            }
161        }
162
163        // If we can't determine the table, default to All
164        ResourceType::All
165    }
166
167    fn mock_show_response<'a>(name: &str, value: &str) -> PgWireResult<QueryResponse<'a>> {
168        let fields = vec![FieldInfo::new(
169            name.to_string(),
170            None,
171            None,
172            Type::VARCHAR,
173            FieldFormat::Text,
174        )];
175
176        let row = {
177            let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone()));
178            encoder.encode_field(&Some(value))?;
179            encoder.finish()
180        };
181
182        let row_stream = futures::stream::once(async move { row });
183        Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream)))
184    }
185
186    async fn try_respond_set_statements<'a>(
187        &self,
188        query_lower: &str,
189    ) -> PgWireResult<Option<Response<'a>>> {
190        if query_lower.starts_with("set") {
191            if query_lower.starts_with("set time zone") {
192                let parts: Vec<&str> = query_lower.split_whitespace().collect();
193                if parts.len() >= 4 {
194                    let tz = parts[3].trim_matches('"');
195                    let mut timezone = self.timezone.lock().await;
196                    *timezone = tz.to_string();
197                    Ok(Some(Response::Execution(Tag::new("SET"))))
198                } else {
199                    Err(PgWireError::UserError(Box::new(
200                        pgwire::error::ErrorInfo::new(
201                            "ERROR".to_string(),
202                            "42601".to_string(),
203                            "Invalid SET TIME ZONE syntax".to_string(),
204                        ),
205                    )))
206                }
207            } else {
208                // pass SET query to datafusion
209                let df = self
210                    .session_context
211                    .sql(query_lower)
212                    .await
213                    .map_err(|err| PgWireError::ApiError(Box::new(err)))?;
214
215                let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
216                Ok(Some(Response::Query(resp)))
217            }
218        } else {
219            Ok(None)
220        }
221    }
222
223    async fn try_respond_transaction_statements<'a>(
224        &self,
225        query_lower: &str,
226    ) -> PgWireResult<Option<Response<'a>>> {
227        // Transaction handling based on pgwire example:
228        // https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
229        match query_lower.trim() {
230            "begin" | "begin transaction" | "begin work" | "start transaction" => {
231                let mut state = self.transaction_state.lock().await;
232                match *state {
233                    TransactionState::None => {
234                        *state = TransactionState::Active;
235                        Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
236                    }
237                    TransactionState::Active => {
238                        // Already in transaction, PostgreSQL allows this but issues a warning
239                        // For simplicity, we'll just return BEGIN again
240                        Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
241                    }
242                    TransactionState::Failed => {
243                        // Can't start new transaction from failed state
244                        Err(PgWireError::UserError(Box::new(
245                            pgwire::error::ErrorInfo::new(
246                                "ERROR".to_string(),
247                                "25P01".to_string(),
248                                "current transaction is aborted, commands ignored until end of transaction block".to_string(),
249                            ),
250                        )))
251                    }
252                }
253            }
254            "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
255                let mut state = self.transaction_state.lock().await;
256                match *state {
257                    TransactionState::Active => {
258                        *state = TransactionState::None;
259                        Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
260                    }
261                    TransactionState::None => {
262                        // PostgreSQL allows COMMIT outside transaction with warning
263                        Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
264                    }
265                    TransactionState::Failed => {
266                        // COMMIT in failed transaction is treated as ROLLBACK
267                        *state = TransactionState::None;
268                        Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
269                    }
270                }
271            }
272            "rollback" | "rollback transaction" | "rollback work" | "abort" => {
273                let mut state = self.transaction_state.lock().await;
274                *state = TransactionState::None;
275                Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
276            }
277            _ => Ok(None),
278        }
279    }
280
281    async fn try_respond_show_statements<'a>(
282        &self,
283        query_lower: &str,
284    ) -> PgWireResult<Option<Response<'a>>> {
285        if query_lower.starts_with("show ") {
286            match query_lower.strip_suffix(";").unwrap_or(query_lower) {
287                "show time zone" => {
288                    let timezone = self.timezone.lock().await.clone();
289                    let resp = Self::mock_show_response("TimeZone", &timezone)?;
290                    Ok(Some(Response::Query(resp)))
291                }
292                "show server_version" => {
293                    let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?;
294                    Ok(Some(Response::Query(resp)))
295                }
296                "show transaction_isolation" => {
297                    let resp =
298                        Self::mock_show_response("transaction_isolation", "read uncommitted")?;
299                    Ok(Some(Response::Query(resp)))
300                }
301                "show catalogs" => {
302                    let catalogs = self.session_context.catalog_names();
303                    let value = catalogs.join(", ");
304                    let resp = Self::mock_show_response("Catalogs", &value)?;
305                    Ok(Some(Response::Query(resp)))
306                }
307                "show search_path" => {
308                    let default_catalog = "datafusion";
309                    let resp = Self::mock_show_response("search_path", default_catalog)?;
310                    Ok(Some(Response::Query(resp)))
311                }
312                _ => Err(PgWireError::UserError(Box::new(
313                    pgwire::error::ErrorInfo::new(
314                        "ERROR".to_string(),
315                        "42704".to_string(),
316                        format!("Unrecognized SHOW command: {query_lower}"),
317                    ),
318                ))),
319            }
320        } else {
321            Ok(None)
322        }
323    }
324}
325
326#[async_trait]
327impl SimpleQueryHandler for DfSessionService {
328    async fn do_query<'a, C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
329    where
330        C: ClientInfo + Unpin + Send + Sync,
331    {
332        let query_lower = query.to_lowercase().trim().to_string();
333        log::debug!("Received query: {}", query); // Log the query for debugging
334
335        // Check permissions for the query (skip for SET, transaction, and SHOW statements)
336        if !query_lower.starts_with("set")
337            && !query_lower.starts_with("begin")
338            && !query_lower.starts_with("commit")
339            && !query_lower.starts_with("rollback")
340            && !query_lower.starts_with("start")
341            && !query_lower.starts_with("end")
342            && !query_lower.starts_with("abort")
343            && !query_lower.starts_with("show")
344        {
345            self.check_query_permission(client, query).await?;
346        }
347
348        if let Some(resp) = self.try_respond_set_statements(&query_lower).await? {
349            return Ok(vec![resp]);
350        }
351
352        if let Some(resp) = self
353            .try_respond_transaction_statements(&query_lower)
354            .await?
355        {
356            return Ok(vec![resp]);
357        }
358
359        if let Some(resp) = self.try_respond_show_statements(&query_lower).await? {
360            return Ok(vec![resp]);
361        }
362
363        // Check if we're in a failed transaction and block non-transaction commands
364        {
365            let state = self.transaction_state.lock().await;
366            if *state == TransactionState::Failed {
367                return Err(PgWireError::UserError(Box::new(
368                    pgwire::error::ErrorInfo::new(
369                        "ERROR".to_string(),
370                        "25P01".to_string(),
371                        "current transaction is aborted, commands ignored until end of transaction block".to_string(),
372                    ),
373                )));
374            }
375        }
376
377        let df_result = self.session_context.sql(query).await;
378
379        // Handle query execution errors and transaction state
380        let df = match df_result {
381            Ok(df) => df,
382            Err(e) => {
383                // If we're in a transaction and a query fails, mark transaction as failed
384                {
385                    let mut state = self.transaction_state.lock().await;
386                    if *state == TransactionState::Active {
387                        *state = TransactionState::Failed;
388                    }
389                }
390                return Err(PgWireError::ApiError(Box::new(e)));
391            }
392        };
393
394        if query_lower.starts_with("insert into") {
395            // For INSERT queries, we need to execute the query to get the row count
396            // and return an Execution response with the proper tag
397            let result = df
398                .clone()
399                .collect()
400                .await
401                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
402
403            // Extract count field from the first batch
404            let rows_affected = result
405                .first()
406                .and_then(|batch| batch.column_by_name("count"))
407                .and_then(|col| {
408                    col.as_any()
409                        .downcast_ref::<datafusion::arrow::array::UInt64Array>()
410                })
411                .map_or(0, |array| array.value(0) as usize);
412
413            // Create INSERT tag with the affected row count
414            let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
415            Ok(vec![Response::Execution(tag)])
416        } else {
417            // For non-INSERT queries, return a regular Query response
418            let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
419            Ok(vec![Response::Query(resp)])
420        }
421    }
422}
423
424#[async_trait]
425impl ExtendedQueryHandler for DfSessionService {
426    type Statement = (String, LogicalPlan);
427    type QueryParser = Parser;
428
429    fn query_parser(&self) -> Arc<Self::QueryParser> {
430        self.parser.clone()
431    }
432
433    async fn do_describe_statement<C>(
434        &self,
435        _client: &mut C,
436        target: &StoredStatement<Self::Statement>,
437    ) -> PgWireResult<DescribeStatementResponse>
438    where
439        C: ClientInfo + Unpin + Send + Sync,
440    {
441        let (_, plan) = &target.statement;
442        let schema = plan.schema();
443        let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
444        let params = plan
445            .get_parameter_types()
446            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
447
448        let mut param_types = Vec::with_capacity(params.len());
449        for param_type in ordered_param_types(&params).iter() {
450            // Fixed: Use &params
451            if let Some(datatype) = param_type {
452                let pgtype = into_pg_type(datatype)?;
453                param_types.push(pgtype);
454            } else {
455                param_types.push(Type::UNKNOWN);
456            }
457        }
458
459        Ok(DescribeStatementResponse::new(param_types, fields))
460    }
461
462    async fn do_describe_portal<C>(
463        &self,
464        _client: &mut C,
465        target: &Portal<Self::Statement>,
466    ) -> PgWireResult<DescribePortalResponse>
467    where
468        C: ClientInfo + Unpin + Send + Sync,
469    {
470        let (_, plan) = &target.statement.statement;
471        let format = &target.result_column_format;
472        let schema = plan.schema();
473        let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
474
475        Ok(DescribePortalResponse::new(fields))
476    }
477
478    async fn do_query<'a, C>(
479        &self,
480        client: &mut C,
481        portal: &Portal<Self::Statement>,
482        _max_rows: usize,
483    ) -> PgWireResult<Response<'a>>
484    where
485        C: ClientInfo + Unpin + Send + Sync,
486    {
487        let query = portal
488            .statement
489            .statement
490            .0
491            .to_lowercase()
492            .trim()
493            .to_string();
494        log::debug!("Received execute extended query: {}", query); // Log for debugging
495
496        // Check permissions for the query (skip for SET and SHOW statements)
497        if !query.starts_with("set") && !query.starts_with("show") {
498            self.check_query_permission(client, &portal.statement.statement.0)
499                .await?;
500        }
501
502        if let Some(resp) = self.try_respond_set_statements(&query).await? {
503            return Ok(resp);
504        }
505
506        if let Some(resp) = self.try_respond_show_statements(&query).await? {
507            return Ok(resp);
508        }
509
510        let (_, plan) = &portal.statement.statement;
511
512        let param_types = plan
513            .get_parameter_types()
514            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
515        let param_values = df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
516        let plan = plan
517            .clone()
518            .replace_params_with_values(&param_values)
519            .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use &param_values
520        let dataframe = self
521            .session_context
522            .execute_logical_plan(plan)
523            .await
524            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
525        let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
526        Ok(Response::Query(resp))
527    }
528}
529
530pub struct Parser {
531    session_context: Arc<SessionContext>,
532}
533
534#[async_trait]
535impl QueryParser for Parser {
536    type Statement = (String, LogicalPlan);
537
538    async fn parse_sql<C>(
539        &self,
540        _client: &C,
541        sql: &str,
542        _types: &[Type],
543    ) -> PgWireResult<Self::Statement> {
544        log::debug!("Received parse extended query: {}", sql); // Log for debugging
545        let context = &self.session_context;
546        let state = context.state();
547        let logical_plan = state
548            .create_logical_plan(sql)
549            .await
550            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
551        let optimised = state
552            .optimize(&logical_plan)
553            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
554        Ok((sql.to_string(), optimised))
555    }
556}
557
558fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
559    // Datafusion stores the parameters as a map.  In our case, the keys will be
560    // `$1`, `$2` etc.  The values will be the parameter types.
561    let mut types = types.iter().collect::<Vec<_>>();
562    types.sort_by(|a, b| a.0.cmp(b.0));
563    types.into_iter().map(|pt| pt.1.as_ref()).collect()
564}