use crate::sqlx;
use std::future::Future;
use cratestack_core::{CoolError, TransactionIsolation};
/// Number of retries `run_in_isolated_tx` allows after the first attempt
/// (passed through as `max_retries`).
const MAX_RETRIES_DEFAULT: u32 = 3;
/// PostgreSQL SQLSTATE for `serialization_failure`.
const PG_SERIALIZATION_FAILURE_SQLSTATE: &str = "40001";
/// PostgreSQL SQLSTATE for `deadlock_detected`.
const PG_DEADLOCK_DETECTED_SQLSTATE: &str = "40P01";
/// Runs `body` inside a PostgreSQL transaction at the requested isolation
/// level, using the default retry budget of [`MAX_RETRIES_DEFAULT`].
///
/// This simply delegates to [`run_in_isolated_tx_with_retries`] with
/// `max_retries` set to [`MAX_RETRIES_DEFAULT`]; `body` must hand the
/// transaction back together with its result so it can be committed.
pub async fn run_in_isolated_tx<F, Fut, T>(
    pool: &sqlx::PgPool,
    isolation: TransactionIsolation,
    body: F,
) -> Result<T, CoolError>
where
    F: FnMut(sqlx::Transaction<'static, sqlx::Postgres>) -> Fut,
    Fut: Future<Output = Result<(T, sqlx::Transaction<'static, sqlx::Postgres>), CoolError>>,
{
    let retry_budget = MAX_RETRIES_DEFAULT;
    run_in_isolated_tx_with_retries(pool, isolation, retry_budget, body).await
}
pub async fn run_in_isolated_tx_with_retries<F, Fut, T>(
pool: &sqlx::PgPool,
isolation: TransactionIsolation,
max_retries: u32,
mut body: F,
) -> Result<T, CoolError>
where
F: FnMut(sqlx::Transaction<'static, sqlx::Postgres>) -> Fut,
Fut: Future<Output = Result<(T, sqlx::Transaction<'static, sqlx::Postgres>), CoolError>>,
{
let mut attempts = 0u32;
loop {
attempts += 1;
let mut tx = pool
.begin()
.await
.map_err(|error| CoolError::Database(error.to_string()))?;
let set_stmt = format!("SET TRANSACTION ISOLATION LEVEL {}", isolation.as_sql());
sqlx::query(&set_stmt)
.execute(&mut *tx)
.await
.map_err(|error| CoolError::Database(error.to_string()))?;
match body(tx).await {
Ok((value, tx)) => match tx.commit().await {
Ok(()) => return Ok(value),
Err(commit_error) => {
let promoted = CoolError::Database(commit_error.to_string());
if attempts <= max_retries && is_retriable(&promoted) {
tokio::task::yield_now().await;
continue;
}
return Err(promoted);
}
},
Err(error) => {
if attempts <= max_retries && is_retriable(&error) {
tokio::task::yield_now().await;
continue;
}
return Err(error);
}
}
}
}
/// Returns `true` when `error`'s detail text describes a serialization failure
/// or a deadlock, i.e. a conflict worth retrying.
///
/// The check works on text because the retry loop stores driver errors as
/// strings in `CoolError::Database`. A SQLSTATE counts only when it stands
/// alone as a token. A number such as `140001` inside an unrelated message
/// (for example a key value in a unique-violation detail) therefore does not
/// trigger a retry. The PostgreSQL message texts are accepted as a fallback
/// for strings that do not include the SQLSTATE.
fn is_retriable(error: &CoolError) -> bool {
    let detail = error.detail().unwrap_or_default();
    contains_sqlstate(&detail, PG_SERIALIZATION_FAILURE_SQLSTATE)
        || contains_sqlstate(&detail, PG_DEADLOCK_DETECTED_SQLSTATE)
        || detail.contains("could not serialize access")
        || detail.contains("deadlock detected")
}

/// Returns `true` if `code` occurs in `haystack` with no ASCII letter or digit
/// directly before or after it.
fn contains_sqlstate(haystack: &str, code: &str) -> bool {
    let is_word_char = |c: Option<char>| matches!(c, Some(c) if c.is_ascii_alphanumeric());
    haystack.match_indices(code).any(|(start, _)| {
        // `start` and `start + code.len()` lie on char boundaries because
        // they delimit an exact match of `code`.
        let before = haystack[..start].chars().next_back();
        let after = haystack[start + code.len()..].chars().next();
        !is_word_char(before) && !is_word_char(after)
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a raw driver message in `CoolError::Database`, the variant the
    /// retry loop uses for driver errors.
    fn database_error(detail: &str) -> CoolError {
        CoolError::Database(detail.to_owned())
    }

    #[test]
    fn parses_all_isolation_levels() {
        let cases = [
            ("serializable", TransactionIsolation::Serializable),
            ("Repeatable_Read", TransactionIsolation::RepeatableRead),
            ("read committed", TransactionIsolation::ReadCommitted),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(TransactionIsolation::parse(*input).unwrap(), *expected);
        }
        assert!(TransactionIsolation::parse("snapshot").is_err());
    }

    #[test]
    fn sql_strings_match_pg_grammar() {
        let cases = [
            (TransactionIsolation::Serializable, "SERIALIZABLE"),
            (TransactionIsolation::RepeatableRead, "REPEATABLE READ"),
            (TransactionIsolation::ReadCommitted, "READ COMMITTED"),
        ];
        for (level, sql) in cases.iter() {
            assert_eq!(level.as_sql(), *sql);
        }
    }

    #[test]
    fn retriable_on_serialization_failure_sqlstate() {
        let err = database_error(
            "Database(PgDatabaseError { severity: ERROR, code: \"40001\", \
             message: \"could not serialize access due to concurrent update\" })",
        );
        assert!(is_retriable(&err));
    }

    #[test]
    fn retriable_on_deadlock_sqlstate() {
        let err = database_error(
            "Database(PgDatabaseError { code: \"40P01\", \
             message: \"deadlock detected\" })",
        );
        assert!(is_retriable(&err));
    }

    #[test]
    fn not_retriable_on_unique_violation() {
        let err =
            database_error("duplicate key value violates unique constraint \"accounts_pkey\"");
        assert!(!is_retriable(&err));
    }

    #[test]
    fn retriable_when_serialization_failure_is_raised_at_commit_time() {
        let err = database_error(
            "Database(PgDatabaseError { severity: ERROR, code: \"40001\", \
             message: \"could not serialize access due to read/write dependencies among transactions\" })",
        );
        assert!(is_retriable(&err));
    }
}