datafusion_postgres/hooks/transactions.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::common::ParamValues;
5use datafusion::logical_expr::LogicalPlan;
6use datafusion::prelude::SessionContext;
7use datafusion::sql::sqlparser::ast::Statement;
8use pgwire::api::results::{Response, Tag};
9use pgwire::api::ClientInfo;
10use pgwire::error::{PgWireError, PgWireResult};
11use pgwire::messages::response::TransactionStatus;
12
13use crate::QueryHook;
14
15/// Hook for processing transaction related statements
16///
17/// Note that this hook doesn't create actual transactions. It just responds
18/// with reasonable return values.
19#[derive(Debug)]
20pub struct TransactionStatementHook;
21
22#[async_trait]
23impl QueryHook for TransactionStatementHook {
24    /// called in simple query handler to return response directly
25    async fn handle_simple_query(
26        &self,
27        statement: &Statement,
28        _session_context: &SessionContext,
29        client: &mut (dyn ClientInfo + Send + Sync),
30    ) -> Option<PgWireResult<Response>> {
31        let resp = try_respond_transaction_statements(client, statement)
32            .await
33            .transpose();
34
35        if resp.is_some() {
36            return resp;
37        }
38
39        // Check if we're in a failed transaction and block non-transaction
40        // commands
41        if client.transaction_status() == TransactionStatus::Error {
42            return Some(Err(PgWireError::UserError(Box::new(
43                pgwire::error::ErrorInfo::new(
44                    "ERROR".to_string(),
45                    "25P01".to_string(),
46                    "current transaction is aborted, commands ignored until end of transaction block".to_string(),
47                ),
48            ))));
49        }
50
51        None
52    }
53
54    async fn handle_extended_parse_query(
55        &self,
56        stmt: &Statement,
57        _session_context: &SessionContext,
58        _client: &(dyn ClientInfo + Send + Sync),
59    ) -> Option<PgWireResult<LogicalPlan>> {
60        // We don't generate logical plan for these statements
61        if matches!(
62            stmt,
63            Statement::StartTransaction { .. }
64                | Statement::Commit { .. }
65                | Statement::Rollback { .. }
66        ) {
67            // Return a dummy plan for transaction commands - they'll be handled by transaction handler
68            let dummy_schema = datafusion::common::DFSchema::empty();
69            return Some(Ok(LogicalPlan::EmptyRelation(
70                datafusion::logical_expr::EmptyRelation {
71                    produce_one_row: false,
72                    schema: Arc::new(dummy_schema),
73                },
74            )));
75        }
76        None
77    }
78
79    async fn handle_extended_query(
80        &self,
81        statement: &Statement,
82        _logical_plan: &LogicalPlan,
83        _params: &ParamValues,
84        session_context: &SessionContext,
85        client: &mut (dyn ClientInfo + Send + Sync),
86    ) -> Option<PgWireResult<Response>> {
87        self.handle_simple_query(statement, session_context, client)
88            .await
89    }
90}
91
92async fn try_respond_transaction_statements<C>(
93    client: &C,
94    stmt: &Statement,
95) -> PgWireResult<Option<Response>>
96where
97    C: ClientInfo + Send + Sync + ?Sized,
98{
99    match stmt {
100        Statement::StartTransaction { .. } => {
101            match client.transaction_status() {
102                TransactionStatus::Idle => Ok(Some(Response::TransactionStart(Tag::new("BEGIN")))),
103                TransactionStatus::Transaction => {
104                    // PostgreSQL behavior: ignore nested BEGIN, just return SUCCESS
105                    // This matches PostgreSQL's handling of nested transaction blocks
106                    log::warn!("BEGIN command ignored: already in transaction block");
107                    Ok(Some(Response::Execution(Tag::new("BEGIN"))))
108                }
109                TransactionStatus::Error => {
110                    // Can't start new transaction from failed state
111                    Err(PgWireError::UserError(Box::new(
112                            pgwire::error::ErrorInfo::new(
113                                "ERROR".to_string(),
114                                "25P01".to_string(),
115                                "current transaction is aborted, commands ignored until end of transaction block".to_string(),
116                            ),
117                        )))
118                }
119            }
120        }
121        Statement::Commit { .. } => match client.transaction_status() {
122            TransactionStatus::Idle | TransactionStatus::Transaction => {
123                Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
124            }
125            TransactionStatus::Error => Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))),
126        },
127        Statement::Rollback { .. } => Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))),
128        _ => Ok(None),
129    }
130}