use std::ops::Deref;
use std::sync::Mutex;
use std::sync::atomic::Ordering;
use deadpool::managed::Object;
use crate::job::Job;
use crate::pool::manager::JobManager;
/// Transaction state of a reserved connection, inferred from the
/// transaction-control statements that have executed successfully on it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TxState {
    /// Initial state: no transaction-control statement has succeeded yet.
    NotStarted,
    /// A `BEGIN` succeeded and has not yet been followed by `COMMIT`/`ROLLBACK`.
    Started,
    /// The most recent transaction was ended by `COMMIT` or `ROLLBACK`.
    Closed,
}
/// A reserved pooled [`Job`] connection that tracks transaction state.
///
/// Transaction-control statements run through [`Reserved::execute`] /
/// [`Reserved::execute_with`] are observed so that `Drop` can, when
/// [`Reserved::rollback_on_drop`] was requested, roll back a transaction that
/// was left open.
///
/// Fields are dropped in declaration order after `Drop::drop` has run, so
/// `obj` is released only once the custom drop logic has finished.
pub struct Reserved {
    // The underlying pooled object; derefs to `Job`.
    obj: Object<JobManager>,
    // Set by `rollback_on_drop()`; enables the best-effort ROLLBACK in `Drop`.
    rollback_on_drop: bool,
    // Last observed transaction state. Interior mutability is required because
    // `execute`/`execute_with` take `&self`.
    tx_state: Mutex<TxState>,
}
impl Reserved {
    /// Wrap a pooled object as a reserved connection.
    ///
    /// Stores `u32::MAX` into the job's `in_flight` counter; `Drop` resets it to
    /// `0`. NOTE(review): presumably this marks the connection as saturated so the
    /// pool does not route other work to it while reserved — confirm in `JobManager`.
    pub(crate) fn new(obj: Object<JobManager>) -> Self {
        obj.inner.in_flight.store(u32::MAX, Ordering::Relaxed);
        Self {
            obj,
            rollback_on_drop: false,
            tx_state: Mutex::new(TxState::NotStarted),
        }
    }

    /// Request a best-effort `ROLLBACK` when this value is dropped while a
    /// transaction opened with `BEGIN` is still open.
    #[must_use]
    pub fn rollback_on_drop(mut self) -> Self {
        self.rollback_on_drop = true;
        self
    }

    /// Execute `sql` on the reserved connection.
    ///
    /// On success, transaction-control statements (`BEGIN`, `COMMIT`,
    /// `ROLLBACK`) update the tracked transaction state.
    ///
    /// # Errors
    ///
    /// Returns the error from [`Job::execute`] unchanged; the tracked state is
    /// not modified in that case.
    pub async fn execute(&self, sql: &str) -> crate::Result<crate::query::Rows> {
        let job: &Job = &self.obj;
        let result = Job::execute(job, sql).await;
        if result.is_ok() {
            Self::observe_sql(&self.tx_state, sql);
        }
        result
    }

    /// Execute `sql` with bound `params` on the reserved connection.
    ///
    /// On success, transaction-control statements update the tracked
    /// transaction state exactly as in [`Reserved::execute`].
    ///
    /// # Errors
    ///
    /// Returns the error from [`Job::execute_with`] unchanged; the tracked state
    /// is not modified in that case.
    pub async fn execute_with(
        &self,
        sql: &str,
        params: &[serde_json::Value],
    ) -> crate::Result<crate::query::Rows> {
        let job: &Job = &self.obj;
        let result = Job::execute_with(job, sql, params).await;
        if result.is_ok() {
            Self::observe_sql(&self.tx_state, sql);
        }
        result
    }

    /// Send `BEGIN`.
    ///
    /// # Errors
    ///
    /// Same as [`Reserved::execute`].
    pub async fn begin(&self) -> crate::Result<crate::query::Rows> {
        self.execute("BEGIN").await
    }

    /// Send `COMMIT`.
    ///
    /// # Errors
    ///
    /// Same as [`Reserved::execute`].
    pub async fn commit(&self) -> crate::Result<crate::query::Rows> {
        self.execute("COMMIT").await
    }

    /// Send `ROLLBACK`.
    ///
    /// # Errors
    ///
    /// Same as [`Reserved::execute`].
    pub async fn rollback(&self) -> crate::Result<crate::query::Rows> {
        self.execute("ROLLBACK").await
    }

    /// Update `state` to reflect a statement that has just executed successfully.
    ///
    /// Only the leading keyword(s) are inspected. Matching is ASCII
    /// case-insensitive, and surrounding whitespace and trailing `;` are ignored:
    ///
    /// - `BEGIN` → [`TxState::Started`]
    /// - `COMMIT` / `ROLLBACK [WORK]` → [`TxState::Closed`]
    /// - `ROLLBACK [WORK] TO ...` → unchanged. A savepoint rollback undoes work
    ///   back to the savepoint but leaves the transaction open.
    /// - anything else (including empty input) → unchanged
    fn observe_sql(state: &Mutex<TxState>, sql: &str) {
        let mut words = sql.split_whitespace().map(|w| w.trim_end_matches(';'));
        let head = words.next().unwrap_or("");
        let next = if head.eq_ignore_ascii_case("BEGIN") {
            TxState::Started
        } else if head.eq_ignore_ascii_case("COMMIT") {
            TxState::Closed
        } else if head.eq_ignore_ascii_case("ROLLBACK") {
            // Skip the optional `WORK` noise word, then look for `TO`.
            let to_savepoint = matches!(
                words.find(|w| !w.eq_ignore_ascii_case("WORK")),
                Some(w) if w.eq_ignore_ascii_case("TO")
            );
            if to_savepoint {
                return;
            }
            TxState::Closed
        } else {
            return;
        };
        // `TxState` is `Copy` and always valid, so a poisoned lock is safe to reuse.
        *state.lock().unwrap_or_else(std::sync::PoisonError::into_inner) = next;
    }
}
/// Exposes the full [`Job`] API directly on a reserved connection.
impl Deref for Reserved {
    type Target = Job;

    fn deref(&self) -> &Self::Target {
        // Deref coercion from the pooled object down to the `Job` it wraps.
        let job: &Job = &self.obj;
        job
    }
}
impl Drop for Reserved {
    /// Release the reservation.
    ///
    /// If [`Reserved::rollback_on_drop`] was requested and the tracked state
    /// shows an open transaction, a `ROLLBACK` is sent from a task handed to
    /// `spawn_best_effort`, and any send error is ignored. Without that opt-in, an
    /// open transaction is left as-is. Finally, the job's `in_flight` counter
    /// is reset to `0`.
    ///
    /// NOTE(review): the ROLLBACK is sent asynchronously, and `obj` is released
    /// right after this returns. Confirm that `spawn_best_effort`/`handle` order
    /// it before any request from the next borrower of this connection.
    fn drop(&mut self) {
        // `&mut self` gives exclusive access, so no locking is needed. A poisoned
        // mutex still holds a valid `TxState`, so recover it rather than
        // panicking inside `drop`; a panic during unwinding would abort.
        let state = *self
            .tx_state
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let in_tx = state == TxState::Started;
        let rolled_back = self.rollback_on_drop && in_tx;
        if rolled_back {
            let handle = self.obj.inner.handle.clone();
            let id = self.obj.inner.ids.next();
            crate::job_helpers::spawn_best_effort(async move {
                let req = crate::protocol::Request::Sql {
                    id,
                    sql: "ROLLBACK".into(),
                    rows: None,
                    parameters: None,
                };
                // Best effort: a failed send is deliberately ignored.
                let _ = handle.send(req).await;
            });
            #[cfg(feature = "metrics")]
            metrics::counter!(crate::observability::POOL_RESERVED_ROLLBACK_TOTAL).increment(1);
        }
        #[cfg(feature = "tracing")]
        tracing::trace!(rolled_back, in_tx, "Reserved dropped");
        self.obj.inner.in_flight.store(0, Ordering::Relaxed);
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Read the state currently stored in `cell`.
    fn state(cell: &Mutex<TxState>) -> TxState {
        *cell.lock().unwrap()
    }

    /// Feed `sql` to the tracker and return the resulting state.
    fn step(cell: &Mutex<TxState>, sql: &str) -> TxState {
        Reserved::observe_sql(cell, sql);
        state(cell)
    }

    /// A tracker in its initial `NotStarted` state.
    fn fresh() -> Mutex<TxState> {
        Mutex::new(TxState::NotStarted)
    }

    #[test]
    fn observe_sql_transitions() {
        let tracker = fresh();
        assert_eq!(step(&tracker, "BEGIN"), TxState::Started);
        assert_eq!(
            step(&tracker, "UPDATE T SET C = 1"),
            TxState::Started,
            "DML should not change state"
        );
        assert_eq!(step(&tracker, "COMMIT"), TxState::Closed);
        // Keywords are case-insensitive and surrounding whitespace is ignored.
        assert_eq!(step(&tracker, "begin"), TxState::Started);
        assert_eq!(step(&tracker, " rollback "), TxState::Closed);
    }

    #[test]
    fn observe_sql_ignores_dml_and_select() {
        let tracker = fresh();
        for sql in [
            "SELECT * FROM SYSIBM.SYSDUMMY1",
            "INSERT INTO T VALUES (1)",
            "DELETE FROM T WHERE ID = 1",
        ] {
            assert_eq!(step(&tracker, sql), TxState::NotStarted);
        }
    }

    #[test]
    fn observe_sql_rollback_closes_started_state() {
        let tracker = fresh();
        assert_eq!(step(&tracker, "BEGIN"), TxState::Started);
        assert_eq!(step(&tracker, "ROLLBACK"), TxState::Closed);
    }

    #[test]
    fn typed_helpers_keywords_drive_correct_state_transitions() {
        // These are the exact statements sent by `begin()`, `commit()` and `rollback()`.
        let tracker = fresh();
        assert_eq!(
            step(&tracker, "BEGIN"),
            TxState::Started,
            "begin() keyword should transition to Started"
        );
        assert_eq!(
            step(&tracker, "COMMIT"),
            TxState::Closed,
            "commit() keyword should transition to Closed"
        );
        assert_eq!(step(&tracker, "BEGIN"), TxState::Started);
        assert_eq!(
            step(&tracker, "ROLLBACK"),
            TxState::Closed,
            "rollback() keyword should transition to Closed"
        );
    }

    #[test]
    fn observe_sql_empty_string_no_panic() {
        let tracker = fresh();
        assert_eq!(step(&tracker, ""), TxState::NotStarted);
        assert_eq!(step(&tracker, " "), TxState::NotStarted);
    }
}