Skip to main content

datafusion_postgres/hooks/
transactions.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::common::ParamValues;
5use datafusion::logical_expr::LogicalPlan;
6use datafusion::prelude::SessionContext;
7use datafusion::sql::sqlparser::ast::Statement;
8use pgwire::api::results::{Response, Tag};
9use pgwire::api::ClientInfo;
10use pgwire::error::{PgWireError, PgWireResult};
11use pgwire::messages::response::TransactionStatus;
12
13use crate::hooks::HookClient;
14use crate::QueryHook;
15
/// Query hook that answers `BEGIN`, `COMMIT` and `ROLLBACK` by itself.
///
/// No real transaction is ever opened, committed or rolled back here: the hook
/// only produces the command tags (or errors) a PostgreSQL client expects to
/// receive for those statements.
#[derive(Debug)]
pub struct TransactionStatementHook;
22
23#[async_trait]
24impl QueryHook for TransactionStatementHook {
25    /// called in simple query handler to return response directly
26    async fn handle_simple_query(
27        &self,
28        statement: &Statement,
29        _session_context: &SessionContext,
30        client: &mut dyn HookClient,
31    ) -> Option<PgWireResult<Response>> {
32        let resp = try_respond_transaction_statements(client, statement)
33            .await
34            .transpose();
35
36        if let Some(result) = resp {
37            return Some(result);
38        }
39
40        // Check if we're in a failed transaction and block non-transaction
41        // commands
42        if client.transaction_status() == TransactionStatus::Error {
43            return Some(Err(PgWireError::UserError(Box::new(
44                pgwire::error::ErrorInfo::new(
45                    "ERROR".to_string(),
46                    "25P01".to_string(),
47                    "current transaction is aborted, commands ignored until end of transaction block".to_string(),
48                ),
49            ))));
50        }
51
52        None
53    }
54
55    async fn handle_extended_parse_query(
56        &self,
57        stmt: &Statement,
58        _session_context: &SessionContext,
59        _client: &(dyn ClientInfo + Send + Sync),
60    ) -> Option<PgWireResult<LogicalPlan>> {
61        // We don't generate logical plan for these statements
62        if matches!(
63            stmt,
64            Statement::StartTransaction { .. }
65                | Statement::Commit { .. }
66                | Statement::Rollback { .. }
67        ) {
68            // Return a dummy plan for transaction commands - they'll be handled by transaction handler
69            let dummy_schema = datafusion::common::DFSchema::empty();
70            return Some(Ok(LogicalPlan::EmptyRelation(
71                datafusion::logical_expr::EmptyRelation {
72                    produce_one_row: false,
73                    schema: Arc::new(dummy_schema),
74                },
75            )));
76        }
77        None
78    }
79
80    async fn handle_extended_query(
81        &self,
82        statement: &Statement,
83        _logical_plan: &LogicalPlan,
84        _params: &ParamValues,
85        session_context: &SessionContext,
86        client: &mut dyn HookClient,
87    ) -> Option<PgWireResult<Response>> {
88        self.handle_simple_query(statement, session_context, client)
89            .await
90    }
91}
92
93async fn try_respond_transaction_statements<C>(
94    client: &C,
95    stmt: &Statement,
96) -> PgWireResult<Option<Response>>
97where
98    C: ClientInfo + Send + Sync + ?Sized,
99{
100    match stmt {
101        Statement::StartTransaction { .. } => {
102            match client.transaction_status() {
103                TransactionStatus::Idle => Ok(Some(Response::TransactionStart(Tag::new("BEGIN")))),
104                TransactionStatus::Transaction => {
105                    // PostgreSQL behavior: ignore nested BEGIN, just return SUCCESS
106                    // This matches PostgreSQL's handling of nested transaction blocks
107                    log::warn!("BEGIN command ignored: already in transaction block");
108                    Ok(Some(Response::Execution(Tag::new("BEGIN"))))
109                }
110                TransactionStatus::Error => {
111                    // Can't start new transaction from failed state
112                    Err(PgWireError::UserError(Box::new(
113                            pgwire::error::ErrorInfo::new(
114                                "ERROR".to_string(),
115                                "25P01".to_string(),
116                                "current transaction is aborted, commands ignored until end of transaction block".to_string(),
117                            ),
118                        )))
119                }
120            }
121        }
122        Statement::Commit { .. } => match client.transaction_status() {
123            TransactionStatus::Idle | TransactionStatus::Transaction => {
124                Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
125            }
126            TransactionStatus::Error => Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))),
127        },
128        Statement::Rollback { .. } => Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))),
129        _ => Ok(None),
130    }
131}