use crate::sqlx;
use std::future::Future;
use cratestack_core::{CoolError, TransactionIsolation};
use crate::error::cool_error_from_sqlx;
/// Retry budget used by `run_in_isolated_tx`: the number of extra attempts
/// allowed after the first one when a retriable error occurs.
const MAX_RETRIES_DEFAULT: u32 = 3;
/// PostgreSQL SQLSTATE `serialization_failure`; classified as retriable.
const PG_SERIALIZATION_FAILURE_SQLSTATE: &str = "40001";
/// PostgreSQL SQLSTATE `deadlock_detected`; classified as retriable.
const PG_DEADLOCK_DETECTED_SQLSTATE: &str = "40P01";
/// Runs `body` inside a transaction at `isolation`, using the default retry
/// budget (`MAX_RETRIES_DEFAULT`, i.e. up to 3 extra attempts) for
/// serialization failures and deadlocks.
///
/// Each attempt starts a fresh transaction and passes it to `body`, which must
/// hand the transaction back alongside its result so it can be committed.
/// Because attempts may be repeated, `body` can be invoked more than once.
///
/// This delegates to [`run_in_isolated_tx_with_retries`].
pub async fn run_in_isolated_tx<F, Fut, T>(
    pool: &sqlx::PgPool,
    isolation: TransactionIsolation,
    body: F,
) -> Result<T, CoolError>
where
    F: FnMut(sqlx::Transaction<'static, sqlx::Postgres>) -> Fut,
    Fut: Future<Output = Result<(T, sqlx::Transaction<'static, sqlx::Postgres>), CoolError>>,
{
    let retry_budget = MAX_RETRIES_DEFAULT;
    run_in_isolated_tx_with_retries(pool, isolation, retry_budget, body).await
}
pub async fn run_in_isolated_tx_with_retries<F, Fut, T>(
pool: &sqlx::PgPool,
isolation: TransactionIsolation,
max_retries: u32,
mut body: F,
) -> Result<T, CoolError>
where
F: FnMut(sqlx::Transaction<'static, sqlx::Postgres>) -> Fut,
Fut: Future<Output = Result<(T, sqlx::Transaction<'static, sqlx::Postgres>), CoolError>>,
{
let mut attempts = 0u32;
loop {
attempts += 1;
let mut tx = pool.begin().await.map_err(cool_error_from_sqlx)?;
let set_stmt = format!("SET TRANSACTION ISOLATION LEVEL {}", isolation.as_sql());
sqlx::query(&set_stmt)
.execute(&mut *tx)
.await
.map_err(cool_error_from_sqlx)?;
match body(tx).await {
Ok((value, tx)) => match tx.commit().await {
Ok(()) => return Ok(value),
Err(commit_error) => {
let promoted = cool_error_from_sqlx(commit_error);
if attempts <= max_retries && is_retriable(&promoted) {
tokio::task::yield_now().await;
continue;
}
return Err(promoted);
}
},
Err(error) => {
if attempts <= max_retries && is_retriable(&error) {
tokio::task::yield_now().await;
continue;
}
return Err(error);
}
}
}
}
/// Decides whether a failed transaction attempt should be retried.
///
/// An error is retriable when it represents a PostgreSQL serialization
/// failure (`40001`) or a detected deadlock (`40P01`). The structured SQLSTATE
/// from `db_sqlstate()` is checked first. If it is absent or is some other
/// code, the textual `detail()` is scanned as a fallback for:
/// - either SQLSTATE code appearing as a standalone token, and
/// - PostgreSQL's English message text for those two conditions.
///
/// The message-text checks only help when the server reports messages in
/// English; the SQLSTATE checks do not depend on the message language.
fn is_retriable(error: &CoolError) -> bool {
    if let Some(code) = error.db_sqlstate() {
        if code == PG_SERIALIZATION_FAILURE_SQLSTATE || code == PG_DEADLOCK_DETECTED_SQLSTATE {
            return true;
        }
    }
    let detail = error.detail().unwrap_or_default();
    // A bare substring search would also match unrelated numbers such as
    // `140001` or `400012` in FK / unique / cast error text, causing
    // non-retriable errors to re-run `body`. Require a standalone token.
    contains_sqlstate_token(detail, PG_SERIALIZATION_FAILURE_SQLSTATE)
        || contains_sqlstate_token(detail, PG_DEADLOCK_DETECTED_SQLSTATE)
        || detail.contains("could not serialize access")
        || detail.contains("deadlock detected")
}

/// Reports whether `code` occurs in `haystack` as a standalone token, i.e.
/// not directly preceded or followed by an ASCII letter or digit.
///
/// This still matches rendered forms such as `code: "40001"` or
/// `SQLSTATE 40001`, but rejects longer alphanumeric runs that merely
/// contain the code.
fn contains_sqlstate_token(haystack: &str, code: &str) -> bool {
    haystack.match_indices(code).any(|(start, matched)| {
        let before = haystack[..start].chars().next_back();
        let after = haystack[start + matched.len()..].chars().next();
        !matches!(before, Some(c) if c.is_ascii_alphanumeric())
            && !matches!(after, Some(c) if c.is_ascii_alphanumeric())
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use cratestack_core::DbErrorInfo;

    /// Wraps a rendered driver message in the untyped `Database` variant.
    fn untyped_db_error(message: &str) -> CoolError {
        CoolError::Database(message.to_owned())
    }

    /// Builds a `DatabaseTyped` error from its structured parts.
    fn typed_db_error(detail: &str, sqlstate: &str, constraint: Option<&str>) -> CoolError {
        CoolError::DatabaseTyped(DbErrorInfo {
            detail: detail.to_owned(),
            sqlstate: Some(sqlstate.to_owned()),
            constraint: constraint.map(str::to_owned),
        })
    }

    #[test]
    fn parses_all_isolation_levels() {
        let cases = [
            ("serializable", TransactionIsolation::Serializable),
            ("Repeatable_Read", TransactionIsolation::RepeatableRead),
            ("read committed", TransactionIsolation::ReadCommitted),
        ];
        for (input, expected) in cases {
            assert_eq!(TransactionIsolation::parse(input).unwrap(), expected);
        }
        assert!(TransactionIsolation::parse("snapshot").is_err());
    }

    #[test]
    fn sql_strings_match_pg_grammar() {
        let cases = [
            (TransactionIsolation::Serializable, "SERIALIZABLE"),
            (TransactionIsolation::RepeatableRead, "REPEATABLE READ"),
            (TransactionIsolation::ReadCommitted, "READ COMMITTED"),
        ];
        for (level, expected) in cases {
            assert_eq!(level.as_sql(), expected);
        }
    }

    #[test]
    fn retriable_on_serialization_failure_sqlstate() {
        let err = untyped_db_error(
            "Database(PgDatabaseError { severity: ERROR, code: \"40001\", \
             message: \"could not serialize access due to concurrent update\" })",
        );
        assert!(is_retriable(&err));
    }

    #[test]
    fn retriable_on_deadlock_sqlstate() {
        let err = untyped_db_error(
            "Database(PgDatabaseError { code: \"40P01\", \
             message: \"deadlock detected\" })",
        );
        assert!(is_retriable(&err));
    }

    #[test]
    fn not_retriable_on_unique_violation() {
        let err =
            untyped_db_error("duplicate key value violates unique constraint \"accounts_pkey\"");
        assert!(!is_retriable(&err));
    }

    #[test]
    fn retriable_when_serialization_failure_is_raised_at_commit_time() {
        let err = untyped_db_error(
            "Database(PgDatabaseError { severity: ERROR, code: \"40001\", \
             message: \"could not serialize access due to read/write dependencies among transactions\" })",
        );
        assert!(is_retriable(&err));
    }

    #[test]
    fn retriable_typed_serialization_failure() {
        let err = typed_db_error(
            "could not serialize access due to concurrent update",
            "40001",
            None,
        );
        assert!(
            is_retriable(&err),
            "DatabaseTyped with 40001 sqlstate must be retriable via the fast path",
        );
    }

    #[test]
    fn retriable_typed_deadlock() {
        let err = typed_db_error("deadlock detected", "40P01", None);
        assert!(
            is_retriable(&err),
            "DatabaseTyped with 40P01 sqlstate must be retriable via the fast path",
        );
    }

    #[test]
    fn not_retriable_typed_unique_violation() {
        let err = typed_db_error(
            "duplicate key value violates unique constraint \"accounts_pkey\"",
            "23505",
            Some("accounts_pkey"),
        );
        assert!(
            !is_retriable(&err),
            "unique_violation (23505) must not be retried",
        );
    }

    #[test]
    fn typed_variant_with_unknown_sqlstate_falls_through_to_detail_match() {
        let err = typed_db_error(
            "could not serialize access due to read/write dependencies",
            "XX999",
            None,
        );
        assert!(
            is_retriable(&err),
            "unknown sqlstate must fall through to detail-substring fallback",
        );
    }

    #[test]
    fn typed_variant_exposes_constraint_for_unique_violation() {
        let err = typed_db_error(
            "duplicate key value violates unique constraint \"wallets_owner_key\"",
            "23505",
            Some("wallets_owner_key"),
        );
        assert_eq!(err.db_sqlstate(), Some("23505"));
        assert_eq!(err.db_constraint(), Some("wallets_owner_key"));
        assert_eq!(err.public_message(), "internal error");
    }
}