use std::{
cell::RefCell,
sync::atomic::{AtomicU32, AtomicU64, Ordering},
};
use rustrails_support::{database, runtime};
use sea_orm::{ConnectionTrait, DatabaseConnection};
use crate::base::{Record, RecordError};
/// A deferred side effect queued on a transaction scope; a boxed `Send` closure invoked at most
/// once.
type TransactionCallback = Box<dyn FnOnce() + Send>;

/// Bookkeeping for one open transaction scope on the current thread.
#[derive(Default)]
struct TransactionLevel {
    /// `None` for the outermost (`BEGIN`) transaction; otherwise the savepoint backing this
    /// nested scope.
    savepoint_name: Option<String>,
    /// Callbacks waiting for the outermost transaction to commit, in registration order.
    after_commit: Vec<TransactionCallback>,
    /// Callbacks to run if this scope (or, once merged, an enclosing scope) rolls back.
    after_rollback: Vec<TransactionCallback>,
}

impl TransactionLevel {
    /// State for a freshly begun outermost transaction: no savepoint and no callbacks.
    fn outermost() -> Self {
        Self {
            savepoint_name: None,
            after_commit: Vec::new(),
            after_rollback: Vec::new(),
        }
    }

    /// State for a nested scope backed by the savepoint `savepoint_name`.
    fn nested(savepoint_name: String) -> Self {
        let mut level = Self::outermost();
        level.savepoint_name = Some(savepoint_name);
        level
    }

    /// Moves the callbacks of a finished nested scope into this level.
    ///
    /// They are appended after this level's own callbacks, so registration order is preserved.
    fn absorb(&mut self, mut nested: Self) {
        self.after_commit.append(&mut nested.after_commit);
        self.after_rollback.append(&mut nested.after_rollback);
    }
}
/// How the outermost transaction ended; selects which callback queue is handed back to run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FinalizeAction {
    /// `COMMIT` succeeded: the `after_commit` callbacks are due.
    Commit,
    /// `ROLLBACK` succeeded: the `after_rollback` callbacks are due.
    Rollback,
}

/// Kind of scope currently at the top of the thread's transaction stack.
#[derive(Debug, Clone, PartialEq, Eq)]
enum TransactionBoundary {
    /// The real database transaction opened with `BEGIN`.
    Outermost,
    /// A nested scope backed by the named savepoint.
    Nested(String),
}
// NOTE(review): all transaction bookkeeping below is per OS thread. If a `transaction(...)`
// future were polled on a multi-threaded runtime and resumed on a different worker thread
// between awaits, it would observe that thread's state instead — confirm callers drive
// transactions on a single thread (e.g. a current-thread runtime).
thread_local! {
    /// Depth of the transaction scopes open on this thread (0 = none).
    static OPEN_TRANSACTION_COUNT: AtomicU32 = const { AtomicU32::new(0) };
    /// Counter from which savepoint names (`sp_1`, `sp_2`, ...) are derived; reset whenever an
    /// outermost transaction begins or ends.
    static SAVEPOINT_SEQUENCE: AtomicU32 = const { AtomicU32::new(0) };
}

thread_local! {
    /// Id (`tx-N`) of the outermost transaction open on this thread, shared by nested scopes.
    static CURRENT_TRANSACTION_ID: RefCell<Option<String>> = const { RefCell::new(None) };
    /// One entry per open scope, outermost first, holding its savepoint and pending callbacks.
    static TRANSACTION_LEVELS: RefCell<Vec<TransactionLevel>> = const { RefCell::new(Vec::new()) };
}

/// Process-wide id source, so transaction ids stay unique across threads.
static NEXT_TRANSACTION_ID: AtomicU64 = AtomicU64::new(1);
/// Returns how many transaction scopes are open on the current thread.
///
/// The count includes the outermost transaction plus every nested savepoint; `0` means no
/// transaction is open.
#[must_use]
pub fn open_transactions() -> u32 {
    OPEN_TRANSACTION_COUNT.with(|depth| depth.load(Ordering::SeqCst))
}
/// Returns `true` when at least one transaction scope is open on the current thread.
#[must_use]
pub fn transaction_open() -> bool {
    open_transactions() != 0
}
/// Returns the id (`tx-N`) of the outermost transaction open on this thread.
///
/// Returns `None` outside any transaction. Nested scopes report their outermost transaction's id.
#[must_use]
pub fn current_transaction_id() -> Option<String> {
    CURRENT_TRANSACTION_ID.with(|slot| slot.borrow().as_ref().cloned())
}
/// Registers `callback` to run after the outermost transaction on this thread commits.
///
/// - Callbacks registered inside a nested scope move to the enclosing scope when its savepoint
///   is released.
/// - They are discarded if that nested scope, or the outermost transaction, rolls back.
/// - Outside any transaction the callback runs immediately.
pub fn after_commit<F>(callback: F)
where
    F: FnOnce() + Send + 'static,
{
    let callback: TransactionCallback = Box::new(callback);
    // `LocalKey::with` accepts an `FnOnce`, so the callback can be moved in and handed back
    // when there is no open scope to queue it on.
    let run_now = TRANSACTION_LEVELS.with(|levels| {
        let mut levels = levels.borrow_mut();
        match levels.last_mut() {
            Some(level) => {
                level.after_commit.push(callback);
                None
            }
            None => Some(callback),
        }
    });
    // Invoke outside the `RefCell` borrow so the callback may register further callbacks.
    if let Some(callback) = run_now {
        callback();
    }
}
pub fn after_rollback<F>(callback: F)
where
F: FnOnce() + Send + 'static,
{
let mut callback = Some(Box::new(callback) as TransactionCallback);
TRANSACTION_LEVELS.with(|levels| {
let mut levels = levels.borrow_mut();
if let Some(level) = levels.last_mut() {
level.after_rollback.push(
callback
.take()
.expect("after_rollback callback should exist"),
);
}
});
}
/// Runs `f` inside a transaction scope: commits when `f` returns `Ok`, rolls back on `Err`.
///
/// With no transaction open on this thread, the scope is a real `BEGIN` ... `COMMIT`/`ROLLBACK`
/// transaction. Otherwise it is a savepoint nested in the open transaction, so an inner failure
/// undoes only the inner work.
///
/// # Errors
///
/// Returns `f`'s error after rolling back. If beginning, committing or rolling back the scope
/// fails, that error is returned instead; a rollback failure replaces `f`'s error.
///
/// NOTE(review): finalization happens only after `f` resolves. If this future is dropped
/// mid-way or `f` panics, the scope stays open in this thread's bookkeeping — confirm callers
/// never cancel it.
pub async fn transaction<F, Fut, T>(db: &DatabaseConnection, f: F) -> Result<T, RecordError>
where
    F: FnOnce(&DatabaseConnection) -> Fut + Send,
    Fut: std::future::Future<Output = Result<T, RecordError>> + Send,
    T: Send,
{
    begin_transaction_scope(db).await?;
    let outcome = f(db).await;
    if outcome.is_ok() {
        commit_transaction_scope(db).await?;
    } else {
        rollback_transaction_scope(db).await?;
    }
    outcome
}
/// Blocking counterpart of `transaction`.
///
/// Runs `f` in a transaction scope on the connection supplied by `database::with_db`, driving
/// the future to completion with `runtime::block_on`.
///
/// NOTE(review): presumably must not be called from inside an async task, since it blocks on
/// the runtime — confirm against `runtime::block_on`.
///
/// # Errors
///
/// Same as `transaction`.
pub fn transaction_sync<F, Fut, T>(f: F) -> Result<T, RecordError>
where
    F: FnOnce(&DatabaseConnection) -> Fut + Send,
    Fut: std::future::Future<Output = Result<T, RecordError>> + Send,
    T: Send,
{
    database::with_db(move |db| {
        let pending = transaction(db, f);
        runtime::block_on(pending)
    })
}
/// Opens a transaction scope on this thread.
///
/// With no transaction open this issues a real `BEGIN`; otherwise it creates a uniquely named
/// savepoint nested in the current transaction.
///
/// Depth and the level stack are updated only after the statement succeeds. A failed
/// `BEGIN`/`SAVEPOINT` only advances the id/savepoint counters.
///
/// # Errors
///
/// Returns the error from executing `BEGIN` or `SAVEPOINT`.
async fn begin_transaction_scope(db: &DatabaseConnection) -> Result<(), RecordError> {
    if !transaction_open() {
        let id = next_transaction_id();
        execute_transaction_control(db, "BEGIN").await?;
        OPEN_TRANSACTION_COUNT.with(|depth| depth.store(1, Ordering::SeqCst));
        // Savepoint names only need to be unique within one outermost transaction.
        SAVEPOINT_SEQUENCE.with(|seq| seq.store(0, Ordering::SeqCst));
        CURRENT_TRANSACTION_ID.with(|slot| *slot.borrow_mut() = Some(id));
        TRANSACTION_LEVELS.with(|stack| stack.borrow_mut().push(TransactionLevel::outermost()));
        return Ok(());
    }

    let name = next_savepoint_name();
    execute_transaction_control(db, &format!("SAVEPOINT {name}")).await?;
    OPEN_TRANSACTION_COUNT.with(|depth| {
        depth.fetch_add(1, Ordering::SeqCst);
    });
    TRANSACTION_LEVELS.with(|stack| stack.borrow_mut().push(TransactionLevel::nested(name)));
    Ok(())
}
/// Commits the innermost open scope.
///
/// - Outermost scope: issues `COMMIT`, clears the bookkeeping, then runs the accumulated
///   `after_commit` callbacks.
/// - Nested scope: issues `RELEASE SAVEPOINT` and hands its callbacks to the enclosing scope.
/// - No open scope: does nothing.
///
/// # Errors
///
/// Returns the error from the control statement. State recovery on failure is delegated to
/// `execute_or_reset_state`.
async fn commit_transaction_scope(db: &DatabaseConnection) -> Result<(), RecordError> {
    let Some(boundary) = current_transaction_boundary() else {
        return Ok(());
    };
    match boundary {
        TransactionBoundary::Nested(savepoint) => {
            let release = format!("RELEASE SAVEPOINT {savepoint}");
            execute_or_reset_state(db, &release).await?;
            merge_nested_transaction_into_parent();
        }
        TransactionBoundary::Outermost => {
            execute_or_reset_state(db, "COMMIT").await?;
            // Callbacks run only after the bookkeeping is cleared.
            run_callbacks(finish_outermost_transaction(FinalizeAction::Commit));
        }
    }
    Ok(())
}
/// Rolls back the innermost open scope and runs the `after_rollback` callbacks that became due.
///
/// - Outermost scope: issues `ROLLBACK`, clears the bookkeeping and runs the transaction's
///   `after_rollback` callbacks; its `after_commit` callbacks are discarded.
/// - Nested scope: issues `ROLLBACK TO SAVEPOINT`, pops the scope and runs only that scope's
///   `after_rollback` callbacks, leaving the enclosing transaction open. The savepoint itself is
///   not released; names stay unique because the savepoint counter only resets between
///   outermost transactions.
/// - No open scope: does nothing.
///
/// # Errors
///
/// Returns the error from the control statement. State recovery on failure is delegated to
/// `execute_or_reset_state`.
async fn rollback_transaction_scope(db: &DatabaseConnection) -> Result<(), RecordError> {
    let Some(boundary) = current_transaction_boundary() else {
        return Ok(());
    };
    let due = match boundary {
        TransactionBoundary::Nested(savepoint) => {
            let undo = format!("ROLLBACK TO SAVEPOINT {savepoint}");
            execute_or_reset_state(db, &undo).await?;
            rollback_nested_transaction()
        }
        TransactionBoundary::Outermost => {
            execute_or_reset_state(db, "ROLLBACK").await?;
            finish_outermost_transaction(FinalizeAction::Rollback)
        }
    };
    run_callbacks(due);
    Ok(())
}
async fn execute_or_reset_state(db: &DatabaseConnection, sql: &str) -> Result<(), RecordError> {
if let Err(error) = execute_transaction_control(db, sql).await {
reset_transaction_state();
Err(error)
} else {
Ok(())
}
}
async fn execute_transaction_control(
db: &DatabaseConnection,
sql: &str,
) -> Result<(), RecordError> {
db.execute_unprepared(sql).await?;
Ok(())
}
/// Describes the innermost open scope on this thread, or `None` when no scope is open.
fn current_transaction_boundary() -> Option<TransactionBoundary> {
    TRANSACTION_LEVELS.with(|levels| {
        let stack = levels.borrow();
        let top = stack.last()?;
        let boundary = match top.savepoint_name.as_deref() {
            Some(name) => TransactionBoundary::Nested(name.to_owned()),
            None => TransactionBoundary::Outermost,
        };
        Some(boundary)
    })
}
/// Finishes a nested scope whose savepoint was released.
///
/// The scope's pending callbacks are appended to the enclosing scope's, and the depth drops by
/// one.
///
/// # Panics
///
/// Panics if fewer than two levels are open.
fn merge_nested_transaction_into_parent() {
    TRANSACTION_LEVELS.with(|levels| {
        let mut stack = levels.borrow_mut();
        let child = stack
            .pop()
            .expect("nested transaction state should exist during commit");
        stack
            .last_mut()
            .expect("parent transaction state should exist during nested commit")
            .absorb(child);
    });
    OPEN_TRANSACTION_COUNT.with(|depth| {
        depth.fetch_sub(1, Ordering::SeqCst);
    });
}
/// Pops the innermost (nested) scope after its savepoint was rolled back and lowers the depth.
///
/// Returns that scope's `after_rollback` callbacks for the caller to run; its `after_commit`
/// callbacks are discarded.
///
/// # Panics
///
/// Panics if no level is open.
fn rollback_nested_transaction() -> Vec<TransactionCallback> {
    let due = TRANSACTION_LEVELS.with(|levels| {
        let popped = levels.borrow_mut().pop();
        popped
            .expect("nested transaction state should exist during rollback")
            .after_rollback
    });
    OPEN_TRANSACTION_COUNT.with(|depth| {
        depth.fetch_sub(1, Ordering::SeqCst);
    });
    due
}
/// Clears all of this thread's transaction bookkeeping after the outermost transaction ended.
///
/// Returns the callbacks selected by `action`: `after_commit` for a commit, `after_rollback` for
/// a rollback. The other queue is dropped.
///
/// # Panics
///
/// Panics if no level is open.
fn finish_outermost_transaction(action: FinalizeAction) -> Vec<TransactionCallback> {
    OPEN_TRANSACTION_COUNT.with(|depth| depth.store(0, Ordering::SeqCst));
    SAVEPOINT_SEQUENCE.with(|seq| seq.store(0, Ordering::SeqCst));
    CURRENT_TRANSACTION_ID.with(|slot| {
        *slot.borrow_mut() = None;
    });
    TRANSACTION_LEVELS.with(|levels| {
        let mut stack = levels.borrow_mut();
        let finished = stack
            .pop()
            .expect("outermost transaction state should exist during finalization");
        // Defensive: nothing should remain below the outermost level, but never keep stale scopes.
        stack.clear();
        match action {
            FinalizeAction::Commit => finished.after_commit,
            FinalizeAction::Rollback => finished.after_rollback,
        }
    })
}
/// Allocates a process-wide unique transaction id of the form `tx-N`.
///
/// `Relaxed` suffices: `fetch_add` is atomic under any ordering, and only uniqueness matters.
fn next_transaction_id() -> String {
    let sequence = NEXT_TRANSACTION_ID.fetch_add(1, Ordering::Relaxed);
    format!("tx-{sequence}")
}
/// Returns the next savepoint name (`sp_1`, `sp_2`, ...) for the current outermost transaction.
fn next_savepoint_name() -> String {
    let ordinal = SAVEPOINT_SEQUENCE.with(|sequence| sequence.fetch_add(1, Ordering::SeqCst) + 1);
    format!("sp_{ordinal}")
}
/// Forgets every open scope on this thread, dropping all pending callbacks without running them.
fn reset_transaction_state() {
    OPEN_TRANSACTION_COUNT.with(|depth| depth.store(0, Ordering::SeqCst));
    SAVEPOINT_SEQUENCE.with(|seq| seq.store(0, Ordering::SeqCst));
    CURRENT_TRANSACTION_ID.with(|slot| {
        *slot.borrow_mut() = None;
    });
    TRANSACTION_LEVELS.with(|stack| stack.borrow_mut().clear());
}
/// Invokes each callback once, in order, consuming the list.
fn run_callbacks(callbacks: Vec<TransactionCallback>) {
    callbacks.into_iter().for_each(|callback| callback());
}
pub trait Transactional: Record {
async fn transaction<F, Fut, T>(db: &DatabaseConnection, f: F) -> Result<T, RecordError>
where
F: FnOnce(&DatabaseConnection) -> Fut + Send,
Fut: std::future::Future<Output = Result<T, RecordError>> + Send,
T: Send,
{
crate::transactions::transaction(db, f).await
}
}
#[cfg(test)]
mod tests {
    //! Tests for transaction scopes. They cover:
    //!
    //! - commit/rollback of the outermost transaction;
    //! - savepoint-backed nesting;
    //! - `after_commit` / `after_rollback` callbacks;
    //! - the thread-local bookkeeping exposed by `open_transactions`, `transaction_open` and
    //!   `current_transaction_id`.
    //!
    //! The bookkeeping is thread-local and each test runs on its own test thread, so every test
    //! starts with no open transaction.

    use std::{
        collections::HashMap,
        sync::{
            Arc, Mutex,
            atomic::{AtomicUsize, Ordering as AtomicOrdering},
        },
    };

    use sea_orm::{ConnectionTrait, Schema};
    use serde_json::{Value, json};

    use super::{
        Transactional, after_commit, after_rollback, current_transaction_id, open_transactions,
        transaction, transaction_open, transaction_sync,
    };
    use crate::{
        Record, RecordError,
        base::test_support::{TestUser, setup_db, test_user},
        persistence::AsyncPersistence,
        querying::AsyncQuerying,
    };
    use rustrails_support::{database, runtime};

    /// Runs `test` on a fresh OS thread that has its own runtime and an in-memory SQLite
    /// database. The database is registered via `database::establish` and already contains the
    /// `test_users` table, so the blocking `transaction_sync` API can be exercised outside an
    /// async test.
    fn run_sync_transaction_test(test: impl FnOnce() + Send + 'static) {
        std::thread::spawn(move || {
            let _rt = runtime::init_runtime();
            database::establish("sqlite::memory:")
                .expect("sqlite in-memory connection should succeed");
            runtime::block_on(async {
                let db = database::db();
                let schema = Schema::new(db.get_database_backend());
                db.execute(&schema.create_table_from_entity(test_user::Entity))
                    .await
                    .expect("test_users table should be created");
            });
            test();
        })
        .join()
        .unwrap();
    }

    /// Builds the attribute map for a `TestUser` with the given name and email.
    fn user_attrs(name: &str, email: &str) -> HashMap<String, Value> {
        HashMap::from([
            ("name".to_owned(), json!(name)),
            ("email".to_owned(), json!(email)),
        ])
    }

    impl Transactional for TestUser {}

    #[tokio::test]
    async fn transaction_commits_on_success() {
        let db = setup_db().await;
        transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                TestUser::create(user_attrs("Alice", "alice@example.com"), &txn).await?;
                Ok(())
            }
        })
        .await
        .expect("transaction should commit");
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 1);
    }

    #[tokio::test]
    async fn transaction_rolls_back_on_error() {
        let db = setup_db().await;
        let error = transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                TestUser::create(user_attrs("Alice", "alice@example.com"), &txn).await?;
                Err::<(), RecordError>(RecordError::Invalid("force rollback".to_owned()))
            }
        })
        .await
        .expect_err("transaction should fail");
        assert!(matches!(error, RecordError::Invalid(message) if message == "force rollback"));
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
    }

    #[tokio::test]
    async fn transaction_returns_closure_value() {
        let db = setup_db().await;
        let id = transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                let user = TestUser::create(user_attrs("Alice", "alice@example.com"), &txn).await?;
                user.id().ok_or(RecordError::NotSaved)
            }
        })
        .await
        .expect("transaction should return a value");
        assert_eq!(id, 1);
    }

    #[tokio::test]
    async fn transactional_trait_delegates_to_helper() {
        let db = setup_db().await;
        TestUser::transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                TestUser::create(user_attrs("Bob", "bob@example.com"), &txn).await?;
                Ok(())
            }
        })
        .await
        .expect("trait helper should commit");
        let user = TestUser::find(1, &db).await.expect("user should exist");
        assert_eq!(user.name, "Bob");
    }

    #[tokio::test]
    async fn rollback_preserves_rows_outside_failed_transaction() {
        let db = setup_db().await;
        TestUser::create(user_attrs("Alice", "alice@example.com"), &db)
            .await
            .expect("seed insert should succeed");
        let _: Result<(), RecordError> = transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                TestUser::create(user_attrs("Bob", "bob@example.com"), &txn).await?;
                Err(RecordError::Invalid("rollback".to_owned()))
            }
        })
        .await;
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 1);
    }

    #[tokio::test]
    async fn transaction_commits_multiple_writes() {
        let db = setup_db().await;
        transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                for (name, email) in [("Alice", "alice@example.com"), ("Bob", "bob@example.com")] {
                    TestUser::create(user_attrs(name, email), &txn).await?;
                }
                Ok(())
            }
        })
        .await
        .expect("multi-write transaction should commit");
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 2);
    }

    #[tokio::test]
    async fn transaction_rolls_back_multiple_writes() {
        let db = setup_db().await;
        let error = transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                for (name, email) in [("Alice", "alice@example.com"), ("Bob", "bob@example.com")] {
                    TestUser::create(user_attrs(name, email), &txn).await?;
                }
                Err::<(), RecordError>(RecordError::Invalid("rollback all writes".to_owned()))
            }
        })
        .await
        .expect_err("multi-write transaction should fail");
        assert!(matches!(error, RecordError::Invalid(message) if message == "rollback all writes"));
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
    }

    #[tokio::test]
    async fn transaction_exposes_uncommitted_writes_inside_closure() {
        let db = setup_db().await;
        let visible_count = transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                TestUser::create(user_attrs("Alice", "alice@example.com"), &txn).await?;
                TestUser::count(&txn).await
            }
        })
        .await
        .expect("transaction should return the in-transaction count");
        assert_eq!(visible_count, 1);
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 1);
    }

    #[tokio::test]
    async fn transaction_rolls_back_writes_visible_inside_failed_closure() {
        let db = setup_db().await;
        let visible_count = transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                TestUser::create(user_attrs("Alice", "alice@example.com"), &txn).await?;
                let count = TestUser::count(&txn).await?;
                Err::<u64, RecordError>(RecordError::Invalid(format!(
                    "count before rollback: {count}"
                )))
            }
        })
        .await
        .expect_err("transaction should fail");
        assert!(
            matches!(visible_count, RecordError::Invalid(message) if message == "count before rollback: 1")
        );
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
    }

    #[tokio::test]
    async fn transaction_can_read_seeded_rows_and_insert_more() {
        let db = setup_db().await;
        TestUser::create(user_attrs("Alice", "alice@example.com"), &db)
            .await
            .expect("seed insert should succeed");
        let counts = transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                let before = TestUser::count(&txn).await?;
                TestUser::create(user_attrs("Bob", "bob@example.com"), &txn).await?;
                let after = TestUser::count(&txn).await?;
                Ok((before, after))
            }
        })
        .await
        .expect("transaction should commit");
        assert_eq!(counts, (1, 2));
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 2);
    }

    #[tokio::test]
    async fn transaction_rollback_preserves_seeded_rows_on_multiwrite_failure() {
        let db = setup_db().await;
        TestUser::create(user_attrs("Alice", "alice@example.com"), &db)
            .await
            .expect("seed insert should succeed");
        let _ = transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                for (name, email) in [("Bob", "bob@example.com"), ("Carol", "carol@example.com")] {
                    TestUser::create(user_attrs(name, email), &txn).await?;
                }
                Err::<(), RecordError>(RecordError::NotSaved)
            }
        })
        .await;
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 1);
        assert_eq!(
            TestUser::find(1, &db)
                .await
                .expect("seed row should still exist")
                .name,
            "Alice"
        );
    }

    #[tokio::test]
    async fn manual_rollback_via_not_saved_error_rolls_back() {
        let db = setup_db().await;
        let error = transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                TestUser::create(user_attrs("Alice", "alice@example.com"), &txn).await?;
                Err::<(), RecordError>(RecordError::NotSaved)
            }
        })
        .await
        .expect_err("manual rollback should bubble the original error");
        assert!(matches!(error, RecordError::NotSaved));
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 0);
    }

    #[tokio::test]
    async fn transactional_trait_can_return_tuple_values() {
        let db = setup_db().await;
        let result = TestUser::transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                let user = TestUser::create(user_attrs("Alice", "alice@example.com"), &txn).await?;
                Ok((
                    user.id().expect("id should be assigned"),
                    TestUser::count(&txn).await?,
                ))
            }
        })
        .await
        .expect("trait helper should return tuple values");
        assert_eq!(result, (1, 1));
    }

    #[tokio::test]
    async fn transaction_without_writes_can_return_existing_count() {
        let db = setup_db().await;
        TestUser::create(user_attrs("Alice", "alice@example.com"), &db)
            .await
            .expect("seed insert should succeed");
        let count = transaction(&db, |txn| {
            let txn = txn.clone();
            async move { TestUser::count(&txn).await }
        })
        .await
        .expect("read-only transaction should commit");
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn transaction_commits_updates_to_existing_rows() {
        let db = setup_db().await;
        let user = TestUser::create(user_attrs("Alice", "alice@example.com"), &db)
            .await
            .expect("seed insert should succeed");
        let id = user.id().expect("seed row should have an id");
        transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                let mut user = TestUser::find(id, &txn).await?;
                user.update_attributes(
                    HashMap::from([("name".to_owned(), json!("Updated Alice"))]),
                    &txn,
                )
                .await?;
                Ok(())
            }
        })
        .await
        .expect("update transaction should commit");
        let reloaded = TestUser::find(id, &db)
            .await
            .expect("updated row should load after commit");
        assert_eq!(reloaded.name, "Updated Alice");
        assert_eq!(reloaded.email, "alice@example.com");
    }

    #[tokio::test]
    async fn transaction_rolls_back_updates_to_existing_rows() {
        let db = setup_db().await;
        let user = TestUser::create(user_attrs("Alice", "alice@example.com"), &db)
            .await
            .expect("seed insert should succeed");
        let id = user.id().expect("seed row should have an id");
        let error = transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                let mut user = TestUser::find(id, &txn).await?;
                user.update_attributes(
                    HashMap::from([("name".to_owned(), json!("Updated Alice"))]),
                    &txn,
                )
                .await?;
                Err::<(), RecordError>(RecordError::Invalid("rollback update".to_owned()))
            }
        })
        .await
        .expect_err("update transaction should fail");
        assert!(matches!(error, RecordError::Invalid(message) if message == "rollback update"));
        let reloaded = TestUser::find(id, &db)
            .await
            .expect("seed row should still load after rollback");
        assert_eq!(reloaded.name, "Alice");
        assert_eq!(reloaded.email, "alice@example.com");
    }

    /// A nested `transaction` on the same connection is backed by a savepoint; releasing it keeps
    /// both rows and the outer transaction id.
    #[tokio::test]
    async fn nested_transaction_commits_with_savepoint_release() {
        let db = setup_db().await;
        let (outer_id, inner_id) = transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                TestUser::create(user_attrs("Outer", "outer@example.com"), &txn).await?;
                let outer_id = current_transaction_id().expect("outer transaction id should exist");
                let inner_id = transaction(&txn, |inner_txn| {
                    let inner_txn = inner_txn.clone();
                    async move {
                        assert_eq!(open_transactions(), 2);
                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
                            .await?;
                        Ok::<String, RecordError>(
                            current_transaction_id()
                                .expect("nested transaction should reuse the outer id"),
                        )
                    }
                })
                .await?;
                assert_eq!(TestUser::count(&txn).await?, 2);
                Ok((outer_id, inner_id))
            }
        })
        .await
        .expect("nested transaction should commit");
        assert_eq!(outer_id, inner_id);
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 2);
    }

    #[test]
    fn transaction_sync_commits_on_success() {
        run_sync_transaction_test(|| {
            transaction_sync(|txn| {
                let txn = txn.clone();
                async move {
                    TestUser::create(user_attrs("Alice", "alice@example.com"), &txn).await?;
                    Ok(())
                }
            })
            .expect("transaction should commit");
            let count = runtime::block_on(async {
                let db = database::db();
                TestUser::count(&db).await.expect("count should succeed")
            });
            assert_eq!(count, 1);
        });
    }

    #[tokio::test]
    async fn nested_transaction_rollback_to_savepoint_preserves_outer_changes() {
        let db = setup_db().await;
        transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                TestUser::create(user_attrs("Outer", "outer@example.com"), &txn).await?;
                let error = transaction(&txn, |inner_txn| {
                    let inner_txn = inner_txn.clone();
                    async move {
                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
                            .await?;
                        Err::<(), RecordError>(RecordError::Invalid("rollback inner".to_owned()))
                    }
                })
                .await
                .expect_err("inner transaction should roll back to its savepoint");
                assert!(
                    matches!(error, RecordError::Invalid(message) if message == "rollback inner")
                );
                assert_eq!(open_transactions(), 1);
                assert_eq!(TestUser::count(&txn).await?, 1);
                TestUser::create(user_attrs("AfterInner", "after@example.com"), &txn).await?;
                Ok(())
            }
        })
        .await
        .expect("outer transaction should still commit");
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 2);
        assert_eq!(
            TestUser::find(1, &db)
                .await
                .expect("outer row should persist")
                .name,
            "Outer"
        );
        assert_eq!(
            TestUser::find(2, &db)
                .await
                .expect("post-rollback outer row should persist")
                .name,
            "AfterInner"
        );
    }

    #[tokio::test]
    async fn after_commit_fires_after_outermost_commit_only() {
        let db = setup_db().await;
        let calls = Arc::new(AtomicUsize::new(0));
        let transaction_calls = Arc::clone(&calls);
        transaction(&db, move |txn| {
            let txn = txn.clone();
            let calls = Arc::clone(&transaction_calls);
            let outer_calls = Arc::clone(&calls);
            async move {
                after_commit(move || {
                    outer_calls.fetch_add(1, AtomicOrdering::SeqCst);
                });
                let nested_calls = Arc::clone(&calls);
                transaction(&txn, move |inner_txn| {
                    let inner_txn = inner_txn.clone();
                    let calls = Arc::clone(&nested_calls);
                    let inner_calls = Arc::clone(&calls);
                    async move {
                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
                            .await?;
                        after_commit(move || {
                            inner_calls.fetch_add(1, AtomicOrdering::SeqCst);
                        });
                        assert_eq!(calls.load(AtomicOrdering::SeqCst), 0);
                        Ok(())
                    }
                })
                .await?;
                // Releasing the savepoint must not fire anything yet.
                assert_eq!(calls.load(AtomicOrdering::SeqCst), 0);
                Ok(())
            }
        })
        .await
        .expect("outer transaction should commit");
        assert_eq!(calls.load(AtomicOrdering::SeqCst), 2);
    }

    #[tokio::test]
    async fn after_commit_callbacks_fire_in_registration_order_across_nested_transactions() {
        let db = setup_db().await;
        let events = Arc::new(Mutex::new(Vec::new()));
        let transaction_events = Arc::clone(&events);
        transaction(&db, move |txn| {
            let txn = txn.clone();
            let events = Arc::clone(&transaction_events);
            let outer_events = Arc::clone(&events);
            async move {
                after_commit(move || outer_events.lock().unwrap().push("outer-1".to_owned()));
                let nested_events = Arc::clone(&events);
                transaction(&txn, move |inner_txn| {
                    let inner_txn = inner_txn.clone();
                    let inner_events = Arc::clone(&nested_events);
                    async move {
                        after_commit(move || {
                            inner_events.lock().unwrap().push("inner-1".to_owned());
                        });
                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
                            .await?;
                        Ok(())
                    }
                })
                .await?;
                let trailing_events = Arc::clone(&events);
                after_commit(move || {
                    trailing_events.lock().unwrap().push("outer-2".to_owned());
                });
                Ok(())
            }
        })
        .await
        .expect("callback ordering transaction should commit");
        assert_eq!(
            *events.lock().unwrap(),
            vec![
                "outer-1".to_owned(),
                "inner-1".to_owned(),
                "outer-2".to_owned()
            ]
        );
    }

    #[tokio::test]
    async fn after_commit_callbacks_do_not_fire_when_outer_transaction_rolls_back() {
        let db = setup_db().await;
        let calls = Arc::new(AtomicUsize::new(0));
        let transaction_calls = Arc::clone(&calls);
        let error = transaction(&db, move |txn| {
            let txn = txn.clone();
            let calls = Arc::clone(&transaction_calls);
            let outer_calls = Arc::clone(&calls);
            async move {
                after_commit(move || {
                    outer_calls.fetch_add(1, AtomicOrdering::SeqCst);
                });
                let nested_calls = Arc::clone(&calls);
                transaction(&txn, move |inner_txn| {
                    let inner_txn = inner_txn.clone();
                    let calls = Arc::clone(&nested_calls);
                    let inner_calls = Arc::clone(&calls);
                    async move {
                        after_commit(move || {
                            inner_calls.fetch_add(1, AtomicOrdering::SeqCst);
                        });
                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
                            .await?;
                        Ok(())
                    }
                })
                .await?;
                Err::<(), RecordError>(RecordError::Invalid("rollback outer".to_owned()))
            }
        })
        .await
        .expect_err("outer transaction should roll back");
        assert!(matches!(error, RecordError::Invalid(message) if message == "rollback outer"));
        assert_eq!(calls.load(AtomicOrdering::SeqCst), 0);
    }

    #[tokio::test]
    async fn after_commit_callbacks_are_cleared_after_commit() {
        let db = setup_db().await;
        let calls = Arc::new(AtomicUsize::new(0));
        transaction(&db, |_| {
            let calls = Arc::clone(&calls);
            async move {
                after_commit(move || {
                    calls.fetch_add(1, AtomicOrdering::SeqCst);
                });
                Ok(())
            }
        })
        .await
        .expect("first transaction should commit");
        transaction(&db, |_| async move { Ok(()) })
            .await
            .expect("second transaction should commit");
        assert_eq!(calls.load(AtomicOrdering::SeqCst), 1);
    }

    #[tokio::test]
    async fn after_rollback_fires_on_outermost_rollback() {
        let db = setup_db().await;
        let calls = Arc::new(AtomicUsize::new(0));
        let callback_calls = Arc::clone(&calls);
        let error = transaction(&db, move |_txn| {
            let outer_calls = Arc::clone(&callback_calls);
            async move {
                after_rollback(move || {
                    outer_calls.fetch_add(1, AtomicOrdering::SeqCst);
                });
                Err::<(), RecordError>(RecordError::Invalid("rollback outer".to_owned()))
            }
        })
        .await
        .expect_err("transaction should roll back");
        assert!(matches!(error, RecordError::Invalid(message) if message == "rollback outer"));
        assert_eq!(calls.load(AtomicOrdering::SeqCst), 1);
    }

    #[tokio::test]
    async fn nested_rollback_fires_only_inner_after_rollback_callbacks() {
        let db = setup_db().await;
        let calls = Arc::new(Mutex::new(Vec::new()));
        let transaction_calls = Arc::clone(&calls);
        transaction(&db, move |txn| {
            let txn = txn.clone();
            let calls = Arc::clone(&transaction_calls);
            let outer_calls = Arc::clone(&calls);
            async move {
                after_rollback(move || outer_calls.lock().unwrap().push("outer".to_owned()));
                let nested_calls = Arc::clone(&calls);
                let error = transaction(&txn, move |inner_txn| {
                    let inner_txn = inner_txn.clone();
                    let calls = Arc::clone(&nested_calls);
                    let inner_calls = Arc::clone(&calls);
                    async move {
                        after_rollback(move || {
                            inner_calls.lock().unwrap().push("inner".to_owned())
                        });
                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
                            .await?;
                        Err::<(), RecordError>(RecordError::Invalid("rollback inner".to_owned()))
                    }
                })
                .await
                .expect_err("inner transaction should roll back");
                assert!(
                    matches!(error, RecordError::Invalid(message) if message == "rollback inner")
                );
                assert_eq!(*calls.lock().unwrap(), vec!["inner".to_owned()]);
                Ok(())
            }
        })
        .await
        .expect("outer transaction should commit");
        assert_eq!(*calls.lock().unwrap(), vec!["inner".to_owned()]);
    }

    #[tokio::test]
    async fn nested_successful_after_rollback_callbacks_fire_if_outer_transaction_rolls_back() {
        let db = setup_db().await;
        let calls = Arc::new(Mutex::new(Vec::new()));
        let transaction_calls = Arc::clone(&calls);
        let error = transaction(&db, move |txn| {
            let txn = txn.clone();
            let calls = Arc::clone(&transaction_calls);
            let outer_calls = Arc::clone(&calls);
            async move {
                after_rollback(move || outer_calls.lock().unwrap().push("outer".to_owned()));
                let nested_calls = Arc::clone(&calls);
                transaction(&txn, move |inner_txn| {
                    let inner_txn = inner_txn.clone();
                    let calls = Arc::clone(&nested_calls);
                    let inner_calls = Arc::clone(&calls);
                    async move {
                        after_rollback(move || {
                            inner_calls.lock().unwrap().push("inner".to_owned())
                        });
                        TestUser::create(user_attrs("Inner", "inner@example.com"), &inner_txn)
                            .await?;
                        Ok(())
                    }
                })
                .await?;
                Err::<(), RecordError>(RecordError::Invalid("rollback outer".to_owned()))
            }
        })
        .await
        .expect_err("outer transaction should roll back");
        assert!(matches!(error, RecordError::Invalid(message) if message == "rollback outer"));
        assert_eq!(
            *calls.lock().unwrap(),
            vec!["outer".to_owned(), "inner".to_owned()]
        );
    }

    #[tokio::test]
    async fn after_rollback_callbacks_are_cleared_after_rollback() {
        let db = setup_db().await;
        let calls = Arc::new(AtomicUsize::new(0));
        let callback_calls = Arc::clone(&calls);
        let _ = transaction(&db, move |_txn| {
            let outer_calls = Arc::clone(&callback_calls);
            async move {
                after_rollback(move || {
                    outer_calls.fetch_add(1, AtomicOrdering::SeqCst);
                });
                Err::<(), RecordError>(RecordError::Invalid("rollback once".to_owned()))
            }
        })
        .await;
        transaction(&db, |_| async move { Ok(()) })
            .await
            .expect("later transaction should commit");
        assert_eq!(calls.load(AtomicOrdering::SeqCst), 1);
    }

    #[tokio::test]
    async fn open_transactions_starts_closed() {
        assert_eq!(open_transactions(), 0);
        assert!(!transaction_open());
    }

    #[tokio::test]
    async fn open_transactions_tracks_nested_depth() {
        let db = setup_db().await;
        assert_eq!(open_transactions(), 0);
        transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                assert_eq!(open_transactions(), 1);
                transaction(&txn, |_| async move {
                    assert_eq!(open_transactions(), 2);
                    Ok(())
                })
                .await?;
                assert_eq!(open_transactions(), 1);
                Ok(())
            }
        })
        .await
        .expect("nested transaction should commit");
        assert_eq!(open_transactions(), 0);
    }

    #[tokio::test]
    async fn transaction_open_reflects_current_state() {
        let db = setup_db().await;
        assert!(!transaction_open());
        transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                assert!(transaction_open());
                transaction(&txn, |_| async move {
                    assert!(transaction_open());
                    Ok(())
                })
                .await?;
                assert!(transaction_open());
                Ok(())
            }
        })
        .await
        .expect("transaction should commit");
        assert!(!transaction_open());
    }

    #[tokio::test]
    async fn current_transaction_id_is_none_outside_transactions() {
        assert_eq!(current_transaction_id(), None);
        let db = setup_db().await;
        transaction(&db, |_| async move { Ok(()) })
            .await
            .expect("transaction should commit");
        assert_eq!(current_transaction_id(), None);
    }

    #[tokio::test]
    async fn current_transaction_id_is_stable_across_nested_transactions() {
        let db = setup_db().await;
        let (outer_id, inner_id, after_inner_id) = transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                let outer_id = current_transaction_id().expect("outer transaction id should exist");
                let inner_id = transaction(&txn, |_| async move {
                    Ok::<String, RecordError>(
                        current_transaction_id().expect("nested transaction id should exist"),
                    )
                })
                .await?;
                let after_inner_id =
                    current_transaction_id().expect("outer transaction id should still exist");
                Ok((outer_id, inner_id, after_inner_id))
            }
        })
        .await
        .expect("transaction should commit");
        assert_eq!(outer_id, inner_id);
        assert_eq!(outer_id, after_inner_id);
    }

    #[tokio::test]
    async fn current_transaction_id_changes_between_outer_transactions() {
        let db = setup_db().await;
        let first = transaction(&db, |_| async move {
            Ok::<String, RecordError>(
                current_transaction_id().expect("transaction id should exist"),
            )
        })
        .await
        .expect("first transaction should commit");
        let second = transaction(&db, |_| async move {
            Ok::<String, RecordError>(
                current_transaction_id().expect("transaction id should exist"),
            )
        })
        .await
        .expect("second transaction should commit");
        assert_ne!(first, second);
    }

    #[tokio::test]
    async fn transaction_state_clears_after_outer_rollback() {
        let db = setup_db().await;
        let error = transaction(&db, |_| async move {
            assert_eq!(open_transactions(), 1);
            assert!(transaction_open());
            assert!(current_transaction_id().is_some());
            Err::<(), RecordError>(RecordError::Invalid("rollback outer".to_owned()))
        })
        .await
        .expect_err("transaction should roll back");
        assert!(matches!(error, RecordError::Invalid(message) if message == "rollback outer"));
        assert_eq!(open_transactions(), 0);
        assert!(!transaction_open());
        assert_eq!(current_transaction_id(), None);
    }

    #[tokio::test]
    async fn nested_rollback_restores_outer_transaction_state() {
        let db = setup_db().await;
        transaction(&db, |txn| {
            let txn = txn.clone();
            async move {
                let outer_id = current_transaction_id().expect("outer transaction id should exist");
                let _ = transaction(&txn, |_| async move {
                    assert_eq!(open_transactions(), 2);
                    Err::<(), RecordError>(RecordError::Invalid("rollback inner".to_owned()))
                })
                .await
                .expect_err("inner transaction should roll back");
                assert_eq!(open_transactions(), 1);
                assert!(transaction_open());
                assert_eq!(
                    current_transaction_id().expect("outer transaction id should remain"),
                    outer_id
                );
                Ok(())
            }
        })
        .await
        .expect("outer transaction should commit");
        assert_eq!(open_transactions(), 0);
        assert_eq!(current_transaction_id(), None);
    }
}