use std::{future::Future, pin::Pin};
use sea_orm::{
ColumnTrait, ConnectionTrait, DatabaseBackend, DatabaseConnection, EntityTrait, Iterable,
QueryFilter, QuerySelect,
sea_query::{LockBehavior, LockType},
};
use crate::{
base::{Record, RecordError, RecordState},
querying::AsyncQuerying,
relation::resolve_column,
};
/// Boxed, pinned, `Send` future returned by the closure passed to
/// `PessimisticLocking::with_lock`; it resolves to `Result<T, RecordError>`
/// and may borrow data for the lifetime `'a`.
pub type LockFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, RecordError>> + Send + 'a>>;
/// Row-lock flavour requested by the `PessimisticLocking` helpers.
///
/// The default is [`LockOption::ForUpdate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LockOption {
    /// Exclusive lock; `lock_with_option` builds it with `lock_exclusive()`.
    #[default]
    ForUpdate,
    /// Update lock with `LockBehavior::Nowait`. Rejected with
    /// `RecordError::Invalid` when the backend is SQLite.
    Nowait,
    /// Update lock with `LockBehavior::SkipLocked`.
    SkipLocked,
}
/// Pessimistic (row-lock based) loading helpers for [`Record`] types.
///
/// Every method has a default implementation, so an implementor only needs an
/// empty `impl PessimisticLocking for MyRecord {}`.
#[allow(dead_code)]
pub(crate) trait PessimisticLocking: Record {
    /// Loads the row whose primary key equals `id`, requesting the default
    /// [`LockOption::ForUpdate`] lock.
    ///
    /// # Errors
    /// Returns whatever [`Self::lock_with_option`] returns, including
    /// `RecordError::NotFound` when no row matches.
    #[allow(private_bounds)]
    async fn lock(id: i64, db: &DatabaseConnection) -> Result<Self, RecordError>
    where
        Self: Sized + AsyncQuerying,
        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        Self::lock_with_option(id, LockOption::ForUpdate, db).await
    }

    /// Loads the row whose primary key equals `id`, requesting the lock
    /// described by `option`, and returns it marked as
    /// `RecordState::Persisted`.
    ///
    /// # Errors
    /// * `RecordError::Invalid` when `option` is [`LockOption::Nowait`] and the
    ///   backend is SQLite.
    /// * Any error produced while resolving the primary-key column or running
    ///   the query.
    /// * `RecordError::NotFound` when the query yields no row.
    #[allow(private_bounds)]
    async fn lock_with_option(
        id: i64,
        option: LockOption,
        db: &DatabaseConnection,
    ) -> Result<Self, RecordError>
    where
        Self: Sized + AsyncQuerying,
        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        // Fail fast instead of pretending that SQLite honours NOWAIT.
        if matches!(db.get_database_backend(), DatabaseBackend::Sqlite)
            && matches!(option, LockOption::Nowait)
        {
            return Err(RecordError::Invalid(
                "SQLite does not support NOWAIT row locks".to_owned(),
            ));
        }

        let primary_key = resolve_column::<Self>(Self::primary_key_name())?;

        // The lookup itself is the same for every option; only the lock
        // clause attached to it differs.
        let select = Self::Entity::find().filter(primary_key.eq(id));
        let query = match option {
            LockOption::ForUpdate => select.lock_exclusive(),
            LockOption::Nowait => {
                select.lock_with_behavior(LockType::Update, LockBehavior::Nowait)
            }
            LockOption::SkipLocked => {
                select.lock_with_behavior(LockType::Update, LockBehavior::SkipLocked)
            }
        };

        let model = query.one(db).await?.ok_or(RecordError::NotFound)?;
        let mut record = Self::from_sea_model(model);
        record.set_record_state(RecordState::Persisted);
        Ok(record)
    }

    /// Reloads `self` in place from the database using the default
    /// [`LockOption::ForUpdate`] lock.
    ///
    /// # Errors
    /// Returns whatever [`Self::lock_bang_with_option`] returns.
    async fn lock_bang(&mut self, db: &DatabaseConnection) -> Result<(), RecordError>
    where
        Self: Sized + AsyncQuerying,
        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        self.lock_bang_with_option(LockOption::ForUpdate, db).await
    }

    /// Reloads `self` in place from the database, requesting the lock
    /// described by `option`.
    ///
    /// A record without an id is left untouched and `Ok(())` is returned
    /// without issuing a query.
    ///
    /// # Errors
    /// Returns whatever [`Self::lock_with_option`] returns for the record's id.
    async fn lock_bang_with_option(
        &mut self,
        option: LockOption,
        db: &DatabaseConnection,
    ) -> Result<(), RecordError>
    where
        Self: Sized + AsyncQuerying,
        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let Some(id) = self.id() else {
            return Ok(());
        };
        *self = Self::lock_with_option(id, option, db).await?;
        Ok(())
    }

    /// Runs `f` against a freshly locked copy of `self` inside a transaction
    /// scope.
    ///
    /// The scope is opened with `begin_lock_scope`. When that call reports
    /// that it started a transaction, the transaction is committed if `f`
    /// succeeds and rolled back if locking or `f` fails. When it reports that
    /// no transaction was started, nothing is committed or rolled back here
    /// and the result of `f` is returned as is.
    ///
    /// NOTE(review): the transaction is driven by raw `BEGIN` / `COMMIT` /
    /// `ROLLBACK` statements executed on `db`. If `db` is backed by a pool
    /// with more than one connection, those statements may not run on the same
    /// connection. Confirm how the crate pins connections.
    ///
    /// # Errors
    /// * Any error from opening the scope or from reloading the record.
    /// * The error returned by `f`. It is returned even if the subsequent
    ///   `ROLLBACK` fails.
    /// * The error from `COMMIT` when `f` succeeded but the commit failed.
    async fn with_lock<F, T>(
        &mut self,
        db: &DatabaseConnection,
        option: LockOption,
        f: F,
    ) -> Result<T, RecordError>
    where
        Self: Sized + AsyncQuerying,
        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
        F: for<'a> FnOnce(&'a mut Self, &'a DatabaseConnection) -> LockFuture<'a, T> + Send,
        T: Send,
    {
        let started_transaction = begin_lock_scope(db, option).await?;

        // Locking and the user closure share one fallible unit, so a failure
        // in either one takes the same rollback path below.
        let result = async {
            self.lock_bang_with_option(option, db).await?;
            f(self, db).await
        }
        .await;

        // The scope belongs to someone else: leave commit/rollback to them.
        if !started_transaction {
            return result;
        }

        match result {
            Ok(value) => {
                db.execute_unprepared("COMMIT").await?;
                Ok(value)
            }
            Err(error) => {
                // Best-effort rollback. The ROLLBACK can itself fail, for
                // example when the failing statement already ended the
                // transaction. That failure must not replace the error that
                // caused the rollback, so it is deliberately discarded.
                let _ = db.execute_unprepared("ROLLBACK").await;
                Err(error)
            }
        }
    }
}
/// Opens the transaction scope used by `PessimisticLocking::with_lock`.
///
/// On SQLite, [`LockOption::ForUpdate`] issues `BEGIN EXCLUSIVE` and
/// [`LockOption::Nowait`] is rejected. Every other backend/option combination
/// issues a plain `BEGIN`.
///
/// Returns `Ok(true)` when the `BEGIN` statement succeeded, meaning the caller
/// owns the transaction and must finish it. Returns `Ok(false)` when the
/// statement failed with a message that `transaction_already_open` recognises,
/// meaning an enclosing transaction is reused.
///
/// NOTE(review): the `Ok(false)` path relies on the backend rejecting a nested
/// `BEGIN` with an error. A backend that accepts a nested `BEGIN`, or only
/// warns about it, would be reported as a newly started transaction. Verify
/// this per backend.
///
/// # Errors
/// * `RecordError::Invalid` for [`LockOption::Nowait`] on SQLite.
/// * Any other database error raised by the `BEGIN` statement.
#[allow(dead_code)]
async fn begin_lock_scope(
    db: &DatabaseConnection,
    option: LockOption,
) -> Result<bool, RecordError> {
    let begin_sql = match (db.get_database_backend(), option) {
        (DatabaseBackend::Sqlite, LockOption::Nowait) => {
            return Err(RecordError::Invalid(
                "SQLite does not support NOWAIT row locks".to_owned(),
            ));
        }
        (DatabaseBackend::Sqlite, LockOption::ForUpdate) => "BEGIN EXCLUSIVE",
        // SQLite with SkipLocked, and every option on the other backends.
        _ => "BEGIN",
    };

    match db.execute_unprepared(begin_sql).await {
        Ok(_) => Ok(true),
        Err(error) if transaction_already_open(&error) => Ok(false),
        Err(error) => Err(error.into()),
    }
}
/// Reports whether `error` looks like a backend complaint that a transaction
/// is already open.
///
/// The check is a case-insensitive substring search of the error's display
/// text against a fixed list of phrases.
#[allow(dead_code)]
fn transaction_already_open(error: &sea_orm::DbErr) -> bool {
    // Phrases are lowercase because the message is lowercased before matching.
    const MARKERS: [&str; 5] = [
        "within a transaction",
        "already a transaction",
        "transaction within a transaction",
        "cannot start a transaction",
        "transaction already in progress",
    ];

    let lowered = error.to_string().to_ascii_lowercase();
    MARKERS.iter().any(|marker| lowered.contains(*marker))
}
// Tests for the pessimistic-locking helpers. They use the `setup_db` and
// `seed_users` fixtures; the assertions below rely on the seed data holding
// three users where id 1 is "Alice", id 2 is "Bob" and id 3 is "Carol".
// The messages and names of the SQLite-specific tests indicate the fixture
// database is SQLite.
#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
    };
    use serde_json::json;
    use super::{LockOption, PessimisticLocking};
    use crate::{
        Record, RecordError,
        base::test_support::{TestUser, seed_users, setup_db},
        persistence::AsyncPersistence,
        querying::AsyncQuerying,
        transactions::transaction,
    };
    // Opt the test model into the trait; every method uses its default body.
    impl PessimisticLocking for TestUser {}
    // `lock` returns the row that matches the requested id.
    #[tokio::test]
    async fn lock_returns_matching_record() {
        let db = setup_db().await;
        seed_users(&db).await;
        let user = TestUser::lock(2, &db).await.expect("row should load");
        assert_eq!(user.name, "Bob");
    }
    // `lock` on an id with no row (nothing is seeded here) yields `NotFound`.
    #[tokio::test]
    async fn lock_returns_not_found_for_missing_row() {
        let db = setup_db().await;
        let error = TestUser::lock(404, &db)
            .await
            .expect_err("missing row should fail");
        assert!(matches!(error, crate::RecordError::NotFound));
    }
    // A record returned by `lock` reports itself as persisted.
    #[tokio::test]
    async fn lock_marks_record_as_persisted() {
        let db = setup_db().await;
        seed_users(&db).await;
        let user = TestUser::lock(1, &db).await.expect("row should load");
        assert!(user.persisted());
    }
    // Locking the same row twice in a row succeeds both times.
    #[tokio::test]
    async fn lock_can_be_called_repeatedly() {
        let db = setup_db().await;
        seed_users(&db).await;
        let first = TestUser::lock(1, &db)
            .await
            .expect("first lock should work");
        let second = TestUser::lock(1, &db)
            .await
            .expect("second lock should work");
        assert_eq!(first.name, second.name);
    }
    // `lock` is a read: the number of rows is unchanged afterwards.
    #[tokio::test]
    async fn lock_does_not_change_row_count() {
        let db = setup_db().await;
        seed_users(&db).await;
        let _ = TestUser::lock(3, &db).await.expect("row should load");
        assert_eq!(TestUser::count(&db).await.expect("count should succeed"), 3);
    }
    // The loaded record carries the stored id and email.
    #[tokio::test]
    async fn lock_preserves_identifier_and_email() {
        let db = setup_db().await;
        seed_users(&db).await;
        let user = TestUser::lock(2, &db).await.expect("row should load");
        assert_eq!(user.id(), Some(2));
        assert_eq!(user.email, "bob@example.com");
    }
    // SKIP LOCKED is accepted on the test database and still returns the row.
    #[tokio::test]
    async fn lock_with_option_skip_locked_degrades_on_sqlite() {
        let db = setup_db().await;
        seed_users(&db).await;
        let user = TestUser::lock_with_option(1, LockOption::SkipLocked, &db)
            .await
            .expect("skip-locked should degrade to a plain lookup on sqlite");
        assert_eq!(user.name, "Alice");
    }
    // NOWAIT is rejected with an `Invalid` error whose message mentions NOWAIT.
    #[tokio::test]
    async fn lock_with_option_nowait_returns_informative_error_on_sqlite() {
        let db = setup_db().await;
        seed_users(&db).await;
        let error = TestUser::lock_with_option(1, LockOption::Nowait, &db)
            .await
            .expect_err("sqlite should not claim NOWAIT support");
        assert!(matches!(error, RecordError::Invalid(message) if message.contains("NOWAIT")));
    }
    // A missing row is reported as `NotFound` with a non-default option too.
    #[tokio::test]
    async fn lock_with_option_returns_not_found_for_missing_row() {
        let db = setup_db().await;
        let error = TestUser::lock_with_option(99, LockOption::SkipLocked, &db)
            .await
            .expect_err("missing row should still be missing");
        assert!(matches!(error, RecordError::NotFound));
    }
    // `lock_bang` overwrites a stale in-memory copy with what is stored now.
    #[tokio::test]
    async fn lock_bang_reloads_latest_persisted_state() {
        let db = setup_db().await;
        seed_users(&db).await;
        let mut user = TestUser::find(1, &db).await.expect("row should exist");
        // Change the row through a second handle so `user` becomes stale.
        let mut other = TestUser::find(1, &db).await.expect("row should exist");
        other.name = "Alicia".to_owned();
        other.save(&db).await.expect("update should persist");
        user.lock_bang(&db).await.expect("reload should succeed");
        assert_eq!(user.name, "Alicia");
        assert_eq!(user.email, "alice@example.com");
    }
    // A record without an id is left untouched by `lock_bang`.
    #[tokio::test]
    async fn lock_bang_noops_for_new_records() {
        let db = setup_db().await;
        let mut user = TestUser::default();
        user.lock_bang(&db).await.expect("new records should no-op");
        assert!(user.new_record());
        assert_eq!(user.id(), None);
    }
    // Changes saved inside a successful `with_lock` closure are visible afterwards.
    #[tokio::test]
    async fn with_lock_commits_changes_on_success() {
        let db = setup_db().await;
        seed_users(&db).await;
        let mut user = TestUser::find(1, &db).await.expect("row should exist");
        let updated_name = user
            .with_lock(&db, LockOption::ForUpdate, |locked, txn| {
                Box::pin(async move {
                    locked.name = "Locked Alice".to_owned();
                    locked.save(txn).await?;
                    Ok(locked.name.clone())
                })
            })
            .await
            .expect("with_lock should commit successful changes");
        let reloaded = TestUser::find(1, &db)
            .await
            .expect("row should still exist");
        assert_eq!(updated_name, "Locked Alice");
        assert_eq!(reloaded.name, "Locked Alice");
    }
    // When the closure fails, its error is returned and the saved change is undone.
    #[tokio::test]
    async fn with_lock_rolls_back_changes_on_error() {
        let db = setup_db().await;
        seed_users(&db).await;
        let mut user = TestUser::find(1, &db).await.expect("row should exist");
        let error = user
            .with_lock(&db, LockOption::ForUpdate, |locked, txn| {
                Box::pin(async move {
                    locked.name = "Should Roll Back".to_owned();
                    locked.save(txn).await?;
                    Err::<(), RecordError>(RecordError::Invalid("force rollback".to_owned()))
                })
            })
            .await
            .expect_err("error should trigger rollback");
        assert!(matches!(error, RecordError::Invalid(message) if message == "force rollback"));
        let reloaded = TestUser::find(1, &db)
            .await
            .expect("row should still exist");
        assert_eq!(reloaded.name, "Alice");
    }
    // `with_lock` called inside `transaction` succeeds and its change is
    // visible once the outer call has returned.
    #[tokio::test]
    async fn with_lock_inside_existing_transaction_reuses_current_scope() {
        let db = setup_db().await;
        seed_users(&db).await;
        transaction(&db, |txn| {
            let txn = txn.clone();
            Box::pin(async move {
                let mut user = TestUser::find(2, &txn).await?;
                user.with_lock(&txn, LockOption::ForUpdate, |locked, inner| {
                    Box::pin(async move {
                        locked.name = "Nested Bob".to_owned();
                        locked.save(inner).await?;
                        Ok(())
                    })
                })
                .await?;
                Ok(())
            })
        })
        .await
        .expect("nested with_lock should succeed");
        let reloaded = TestUser::find(2, &db)
            .await
            .expect("row should still exist");
        assert_eq!(reloaded.name, "Nested Bob");
    }
    // With SKIP LOCKED the closure still runs and its in-memory edit is kept
    // on the caller's record (nothing is saved here).
    #[tokio::test]
    async fn with_lock_skip_locked_still_executes_closure_on_sqlite() {
        let db = setup_db().await;
        seed_users(&db).await;
        let mut user = TestUser::find(3, &db).await.expect("row should exist");
        let result = user
            .with_lock(&db, LockOption::SkipLocked, |locked, _| {
                Box::pin(async move {
                    locked.name.push_str("-seen");
                    Ok(locked.name.clone())
                })
            })
            .await
            .expect("skip-locked should still yield the record on sqlite");
        assert_eq!(result, "Carol-seen");
        assert_eq!(user.name, "Carol-seen");
    }
    // NOWAIT fails before the closure is invoked; the flag proves it never ran.
    #[tokio::test]
    async fn with_lock_nowait_returns_error_before_running_closure() {
        let db = setup_db().await;
        seed_users(&db).await;
        let mut user = TestUser::find(1, &db).await.expect("row should exist");
        let ran = Arc::new(AtomicBool::new(false));
        let error = user
            .with_lock(&db, LockOption::Nowait, {
                let ran = Arc::clone(&ran);
                move |_locked, _| {
                    ran.store(true, Ordering::SeqCst);
                    Box::pin(async { Ok(()) })
                }
            })
            .await
            .expect_err("sqlite NOWAIT should fail early");
        assert!(matches!(error, RecordError::Invalid(message) if message.contains("NOWAIT")));
        assert!(!ran.load(Ordering::SeqCst));
    }
    // `lock` and `find` produce equal records for the same id.
    #[tokio::test]
    async fn lock_matches_plain_find_for_same_row() {
        let db = setup_db().await;
        seed_users(&db).await;
        let locked = TestUser::lock(3, &db).await.expect("row should lock");
        let found = TestUser::find(3, &db).await.expect("row should find");
        assert_eq!(locked, found);
    }
    // After an update, a fresh `lock` sees the new value and leaves other
    // columns as they were.
    #[tokio::test]
    async fn lock_reads_latest_persisted_values_after_update() {
        let db = setup_db().await;
        seed_users(&db).await;
        let mut user = TestUser::lock(2, &db).await.expect("row should lock");
        user.update_attributes(HashMap::from([("name".to_owned(), json!("Bobby"))]), &db)
            .await
            .expect("update should succeed");
        let refreshed = TestUser::lock(2, &db)
            .await
            .expect("updated row should lock");
        assert_eq!(refreshed.name, "Bobby");
        assert_eq!(refreshed.email, "bob@example.com");
        assert!(refreshed.persisted());
    }
}