datafusion_postgres/hooks/
transactions.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::common::ParamValues;
5use datafusion::logical_expr::LogicalPlan;
6use datafusion::prelude::SessionContext;
7use datafusion::sql::sqlparser::ast::Statement;
8use pgwire::api::results::{Response, Tag};
9use pgwire::api::ClientInfo;
10use pgwire::error::{PgWireError, PgWireResult};
11use pgwire::messages::response::TransactionStatus;
12
13use crate::hooks::HookClient;
14use crate::QueryHook;
15
/// Stateless query hook that answers transaction-control statements
/// (`BEGIN`/`START TRANSACTION`, `COMMIT`, `ROLLBACK`) from the client's
/// current transaction status, and refuses other statements while that
/// transaction is in the failed state.
///
/// The hook carries no data, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct TransactionStatementHook;
22
23#[async_trait]
24impl QueryHook for TransactionStatementHook {
25 async fn handle_simple_query(
27 &self,
28 statement: &Statement,
29 _session_context: &SessionContext,
30 client: &mut dyn HookClient,
31 ) -> Option<PgWireResult<Response>> {
32 let resp = try_respond_transaction_statements(client, statement)
33 .await
34 .transpose();
35
36 if let Some(result) = resp {
37 return Some(result);
38 }
39
40 if client.transaction_status() == TransactionStatus::Error {
43 return Some(Err(PgWireError::UserError(Box::new(
44 pgwire::error::ErrorInfo::new(
45 "ERROR".to_string(),
46 "25P01".to_string(),
47 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
48 ),
49 ))));
50 }
51
52 None
53 }
54
55 async fn handle_extended_parse_query(
56 &self,
57 stmt: &Statement,
58 _session_context: &SessionContext,
59 _client: &(dyn ClientInfo + Send + Sync),
60 ) -> Option<PgWireResult<LogicalPlan>> {
61 if matches!(
63 stmt,
64 Statement::StartTransaction { .. }
65 | Statement::Commit { .. }
66 | Statement::Rollback { .. }
67 ) {
68 let dummy_schema = datafusion::common::DFSchema::empty();
70 return Some(Ok(LogicalPlan::EmptyRelation(
71 datafusion::logical_expr::EmptyRelation {
72 produce_one_row: false,
73 schema: Arc::new(dummy_schema),
74 },
75 )));
76 }
77 None
78 }
79
80 async fn handle_extended_query(
81 &self,
82 statement: &Statement,
83 _logical_plan: &LogicalPlan,
84 _params: &ParamValues,
85 session_context: &SessionContext,
86 client: &mut dyn HookClient,
87 ) -> Option<PgWireResult<Response>> {
88 self.handle_simple_query(statement, session_context, client)
89 .await
90 }
91}
92
93async fn try_respond_transaction_statements<C>(
94 client: &C,
95 stmt: &Statement,
96) -> PgWireResult<Option<Response>>
97where
98 C: ClientInfo + Send + Sync + ?Sized,
99{
100 match stmt {
101 Statement::StartTransaction { .. } => {
102 match client.transaction_status() {
103 TransactionStatus::Idle => Ok(Some(Response::TransactionStart(Tag::new("BEGIN")))),
104 TransactionStatus::Transaction => {
105 log::warn!("BEGIN command ignored: already in transaction block");
108 Ok(Some(Response::Execution(Tag::new("BEGIN"))))
109 }
110 TransactionStatus::Error => {
111 Err(PgWireError::UserError(Box::new(
113 pgwire::error::ErrorInfo::new(
114 "ERROR".to_string(),
115 "25P01".to_string(),
116 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
117 ),
118 )))
119 }
120 }
121 }
122 Statement::Commit { .. } => match client.transaction_status() {
123 TransactionStatus::Idle | TransactionStatus::Transaction => {
124 Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
125 }
126 TransactionStatus::Error => Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))),
127 },
128 Statement::Rollback { .. } => Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))),
129 _ => Ok(None),
130 }
131}