datafusion_postgres/hooks/
transactions.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use datafusion::common::ParamValues;
5use datafusion::logical_expr::LogicalPlan;
6use datafusion::prelude::SessionContext;
7use datafusion::sql::sqlparser::ast::Statement;
8use pgwire::api::results::{Response, Tag};
9use pgwire::api::ClientInfo;
10use pgwire::error::{PgWireError, PgWireResult};
11use pgwire::messages::response::TransactionStatus;
12
13use crate::QueryHook;
14
/// Query hook that answers transaction-control statements (`BEGIN` /
/// `START TRANSACTION`, `COMMIT`, `ROLLBACK`) directly from the client's
/// transaction status, and rejects every other statement while that
/// transaction is in the failed (`TransactionStatus::Error`) state.
///
/// The hook carries no state, so it is cheap to construct (`Default`) and to
/// copy between handlers.
#[derive(Debug, Default, Clone, Copy)]
pub struct TransactionStatementHook;
21
22#[async_trait]
23impl QueryHook for TransactionStatementHook {
24 async fn handle_simple_query(
26 &self,
27 statement: &Statement,
28 _session_context: &SessionContext,
29 client: &mut (dyn ClientInfo + Send + Sync),
30 ) -> Option<PgWireResult<Response>> {
31 let resp = try_respond_transaction_statements(client, statement)
32 .await
33 .transpose();
34
35 if resp.is_some() {
36 return resp;
37 }
38
39 if client.transaction_status() == TransactionStatus::Error {
42 return Some(Err(PgWireError::UserError(Box::new(
43 pgwire::error::ErrorInfo::new(
44 "ERROR".to_string(),
45 "25P01".to_string(),
46 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
47 ),
48 ))));
49 }
50
51 None
52 }
53
54 async fn handle_extended_parse_query(
55 &self,
56 stmt: &Statement,
57 _session_context: &SessionContext,
58 _client: &(dyn ClientInfo + Send + Sync),
59 ) -> Option<PgWireResult<LogicalPlan>> {
60 if matches!(
62 stmt,
63 Statement::StartTransaction { .. }
64 | Statement::Commit { .. }
65 | Statement::Rollback { .. }
66 ) {
67 let dummy_schema = datafusion::common::DFSchema::empty();
69 return Some(Ok(LogicalPlan::EmptyRelation(
70 datafusion::logical_expr::EmptyRelation {
71 produce_one_row: false,
72 schema: Arc::new(dummy_schema),
73 },
74 )));
75 }
76 None
77 }
78
79 async fn handle_extended_query(
80 &self,
81 statement: &Statement,
82 _logical_plan: &LogicalPlan,
83 _params: &ParamValues,
84 session_context: &SessionContext,
85 client: &mut (dyn ClientInfo + Send + Sync),
86 ) -> Option<PgWireResult<Response>> {
87 self.handle_simple_query(statement, session_context, client)
88 .await
89 }
90}
91
92async fn try_respond_transaction_statements<C>(
93 client: &C,
94 stmt: &Statement,
95) -> PgWireResult<Option<Response>>
96where
97 C: ClientInfo + Send + Sync + ?Sized,
98{
99 match stmt {
100 Statement::StartTransaction { .. } => {
101 match client.transaction_status() {
102 TransactionStatus::Idle => Ok(Some(Response::TransactionStart(Tag::new("BEGIN")))),
103 TransactionStatus::Transaction => {
104 log::warn!("BEGIN command ignored: already in transaction block");
107 Ok(Some(Response::Execution(Tag::new("BEGIN"))))
108 }
109 TransactionStatus::Error => {
110 Err(PgWireError::UserError(Box::new(
112 pgwire::error::ErrorInfo::new(
113 "ERROR".to_string(),
114 "25P01".to_string(),
115 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
116 ),
117 )))
118 }
119 }
120 }
121 Statement::Commit { .. } => match client.transaction_status() {
122 TransactionStatus::Idle | TransactionStatus::Transaction => {
123 Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
124 }
125 TransactionStatus::Error => Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))),
126 },
127 Statement::Rollback { .. } => Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))),
128 _ => Ok(None),
129 }
130}