use std::sync::Arc;
use async_trait::async_trait;
use datafusion::common::ParamValues;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::SessionContext;
use datafusion::sql::sqlparser::ast::Statement;
use pgwire::api::results::{Response, Tag};
use pgwire::api::ClientInfo;
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::response::TransactionStatus;
use crate::hooks::HookClient;
use crate::QueryHook;
/// [`QueryHook`] that answers transaction-control statements (`BEGIN` /
/// `START TRANSACTION`, `COMMIT`, `ROLLBACK`) itself instead of passing them on.
///
/// While the client's transaction status is [`TransactionStatus::Error`], it also
/// rejects every other statement with SQLSTATE `25P02` (`in_failed_sql_transaction`),
/// mirroring PostgreSQL's "current transaction is aborted" behaviour.
#[derive(Debug)]
pub struct TransactionStatementHook;

#[async_trait]
impl QueryHook for TransactionStatementHook {
    /// Handles a statement arriving through the simple query protocol.
    ///
    /// Returns `Some(..)` when the statement was answered here: either it is a
    /// transaction-control statement, or it was rejected because the current
    /// transaction has failed. Returns `None` to let other handlers process it.
    async fn handle_simple_query(
        &self,
        statement: &Statement,
        _session_context: &SessionContext,
        client: &mut dyn HookClient,
    ) -> Option<PgWireResult<Response>> {
        // Transaction-control statements are checked first so that COMMIT/ROLLBACK
        // still work inside a failed transaction; they are the only way out of it.
        let resp = try_respond_transaction_statements(client, statement)
            .await
            .transpose();
        if resp.is_some() {
            return resp;
        }

        if client.transaction_status() == TransactionStatus::Error {
            // SQLSTATE 25P02 is `in_failed_sql_transaction`, the code PostgreSQL uses
            // for this message. (25P01 would be `no_active_sql_transaction`.)
            return Some(Err(PgWireError::UserError(Box::new(
                pgwire::error::ErrorInfo::new(
                    "ERROR".to_string(),
                    "25P02".to_string(),
                    "current transaction is aborted, commands ignored until end of transaction block".to_string(),
                ),
            ))));
        }

        None
    }

    /// Produces a logical plan for transaction-control statements in the extended
    /// query protocol.
    ///
    /// These statements return no rows, so they are planned as an empty relation
    /// with an empty schema. The response itself comes from `handle_extended_query`,
    /// which ignores this plan. Any other statement yields `None`.
    async fn handle_extended_parse_query(
        &self,
        stmt: &Statement,
        _session_context: &SessionContext,
        _client: &(dyn ClientInfo + Send + Sync),
    ) -> Option<PgWireResult<LogicalPlan>> {
        let is_transaction_statement = matches!(
            stmt,
            Statement::StartTransaction { .. }
                | Statement::Commit { .. }
                | Statement::Rollback { .. }
        );
        if !is_transaction_statement {
            return None;
        }

        Some(Ok(LogicalPlan::EmptyRelation(
            datafusion::logical_expr::EmptyRelation {
                produce_one_row: false,
                schema: Arc::new(datafusion::common::DFSchema::empty()),
            },
        )))
    }

    /// Handles a statement arriving through the extended query protocol.
    ///
    /// Transaction handling does not depend on the bound plan or parameters, so
    /// this delegates to `handle_simple_query`.
    async fn handle_extended_query(
        &self,
        statement: &Statement,
        _logical_plan: &LogicalPlan,
        _params: &ParamValues,
        session_context: &SessionContext,
        client: &mut dyn HookClient,
    ) -> Option<PgWireResult<Response>> {
        self.handle_simple_query(statement, session_context, client)
            .await
    }
}
async fn try_respond_transaction_statements<C>(
client: &C,
stmt: &Statement,
) -> PgWireResult<Option<Response>>
where
C: ClientInfo + Send + Sync + ?Sized,
{
match stmt {
Statement::StartTransaction { .. } => {
match client.transaction_status() {
TransactionStatus::Idle => Ok(Some(Response::TransactionStart(Tag::new("BEGIN")))),
TransactionStatus::Transaction => {
log::warn!("BEGIN command ignored: already in transaction block");
Ok(Some(Response::Execution(Tag::new("BEGIN"))))
}
TransactionStatus::Error => {
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"25P01".to_string(),
"current transaction is aborted, commands ignored until end of transaction block".to_string(),
),
)))
}
}
}
Statement::Commit { .. } => match client.transaction_status() {
TransactionStatus::Idle | TransactionStatus::Transaction => {
Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
}
TransactionStatus::Error => Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))),
},
Statement::Rollback { .. } => Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))),
_ => Ok(None),
}
}