use sea_orm::{
ActiveModelBehavior, ColumnTrait, ConnectionTrait, DatabaseConnection, EntityTrait,
IntoActiveModel, Iterable, Statement,
};
use serde::{Serialize, de::DeserializeOwned};
use crate::{
base::{Record, RecordError, RecordState},
persistence::AsyncPersistence,
querying::AsyncQuerying,
relation::json_to_sea_value,
};
#[allow(dead_code)]
pub(crate) trait OptimisticLocking: Record {
fn lock_version(&self) -> i64;
fn increment_lock_version(&mut self);
fn locking_column() -> &'static str {
"lock_version"
}
fn lock_optimistically() -> bool {
true
}
fn locking_enabled() -> bool {
Self::lock_optimistically()
}
fn check_lock_version(&self, expected: i64) -> Result<(), StaleObjectError> {
if self.lock_version() == expected {
Ok(())
} else {
Err(StaleObjectError)
}
}
async fn current_lock_version(&self, db: &DatabaseConnection) -> Result<i64, RecordError>
where
Self: Sized + AsyncQuerying,
<Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
{
let id = self.id().ok_or(RecordError::NotSaved)?;
let fresh = Self::find(id, db).await?;
Ok(fresh.lock_version())
}
async fn save_with_optimistic_lock(
&mut self,
db: &DatabaseConnection,
) -> Result<(), RecordError>
where
Self: Sized + AsyncPersistence + AsyncQuerying + Serialize + DeserializeOwned,
<Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
<Self::Entity as EntityTrait>::Model:
IntoActiveModel<<Self::Entity as EntityTrait>::ActiveModel>,
<Self::Entity as EntityTrait>::ActiveModel: ActiveModelBehavior + Send,
{
if self.destroyed() {
return Err(RecordError::NotSaved);
}
if !Self::locking_enabled() || self.new_record() {
return AsyncPersistence::save(self, db).await;
}
let id = self.id().ok_or(RecordError::NotSaved)?;
let expected = self.lock_version();
let current = self.current_lock_version(db).await?;
if current != expected {
return Err(StaleObjectError.into());
}
let assignments =
serialized_assignments(self, &[Self::primary_key_name(), Self::locking_column()])?;
let sql = build_update_sql::<Self>(&assignments);
let mut values = assignments
.into_iter()
.map(|(_, value)| value)
.collect::<Vec<_>>();
values.push(id.into());
values.push(expected.into());
let result = db
.execute_raw(Statement::from_sql_and_values(
db.get_database_backend(),
sql,
values,
))
.await?;
if result.rows_affected() == 0 {
return Err(StaleObjectError.into());
}
self.increment_lock_version();
self.set_record_state(RecordState::Persisted);
Ok(())
}
async fn destroy_with_optimistic_lock(
&mut self,
db: &DatabaseConnection,
) -> Result<(), RecordError>
where
Self: Sized + AsyncPersistence,
<Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
{
if self.destroyed() {
return Err(RecordError::NotSaved);
}
if !Self::locking_enabled() {
return AsyncPersistence::destroy(self, db).await;
}
let id = self.id().ok_or(RecordError::NotSaved)?;
let result = db
.execute_raw(Statement::from_sql_and_values(
db.get_database_backend(),
format!(
"DELETE FROM {table} WHERE {primary_key} = ? AND {locking_column} = ?",
table = Self::table_name(),
primary_key = Self::primary_key_name(),
locking_column = Self::locking_column(),
),
[id.into(), self.lock_version().into()],
))
.await?;
if result.rows_affected() == 0 {
return Err(StaleObjectError.into());
}
self.set_record_state(RecordState::Destroyed);
Ok(())
}
}
/// Converts `record` into `(column, value)` pairs for an `UPDATE ... SET` clause.
///
/// The record is serialized with `serde_json`. Every top-level field whose name
/// appears in `excluded_columns` is skipped. Each remaining JSON value is
/// converted with `json_to_sea_value`.
///
/// Pairs follow the iteration order of the serialized JSON map.
///
/// # Errors
/// * `RecordError::Invalid` if serialization fails or the record does not
///   serialize to a JSON object.
/// * Errors from `json_to_sea_value`, converted into `RecordError`. Processing
///   stops at the first failing field.
// Referenced from `OptimisticLocking`'s default methods.
#[allow(dead_code)]
fn serialized_assignments<T: Serialize>(
    record: &T,
    excluded_columns: &[&str],
) -> Result<Vec<(String, sea_orm::Value)>, RecordError> {
    let serialized =
        serde_json::to_value(record).map_err(|error| RecordError::Invalid(error.to_string()))?;
    let Some(fields) = serialized.as_object() else {
        return Err(RecordError::Invalid(
            "record must serialize to a JSON object".to_owned(),
        ));
    };
    fields
        .iter()
        .filter(|(name, _)| !excluded_columns.contains(&name.as_str()))
        .map(|(name, json)| -> Result<_, RecordError> {
            Ok((name.clone(), json_to_sea_value(json)?))
        })
        .collect()
}
/// Builds the guarded `UPDATE` statement used by optimistic saves.
///
/// The statement has the shape
/// `UPDATE <table> SET <col> = ?, ..., <lock> = <lock> + 1 WHERE <pk> = ? AND <lock> = ?`.
///
/// Only the column names in `assignments` are read. Callers must bind the
/// assignment values in the same order, followed by the primary key and then
/// the expected lock version.
///
/// Identifiers are interpolated unquoted, and placeholders are emitted as `?`.
// Referenced from `OptimisticLocking::save_with_optimistic_lock`.
#[allow(dead_code)]
fn build_update_sql<T: OptimisticLocking>(assignments: &[(String, sea_orm::Value)]) -> String {
    let lock_column = T::locking_column();
    // The version bump is always the final SET item, so an empty assignment
    // list still produces a valid statement.
    let set_clause = assignments
        .iter()
        .map(|(column, _)| format!("{column} = ?"))
        .chain(std::iter::once(format!("{lock_column} = {lock_column} + 1")))
        .collect::<Vec<_>>()
        .join(", ");
    format!(
        "UPDATE {table} SET {set_clause} WHERE {primary_key} = ? AND {lock_column} = ?",
        table = T::table_name(),
        primary_key = T::primary_key_name(),
    )
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
#[error("stale object error: record was updated by another process")]
pub struct StaleObjectError;
impl From<StaleObjectError> for RecordError {
fn from(_: StaleObjectError) -> Self {
RecordError::StaleObject
}
}
// Tests for `OptimisticLocking`.
//
// Each async test calls `setup_db`, which opens a fresh in-memory SQLite
// database. Every test therefore starts from an empty `versioned_users` table.
#[cfg(test)]
mod tests {
    use sea_orm::{
        ActiveValue::{NotSet, Set},
        ConnectionTrait, Database, DatabaseConnection, EntityTrait, Statement,
    };
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::collections::HashMap;
    use super::{OptimisticLocking, StaleObjectError};
    use crate::{
        base::{Record, RecordError, RecordState},
        persistence::AsyncPersistence,
        querying::AsyncQuerying,
    };

    // Minimal SeaORM entity for the `versioned_users` table. Both test record
    // types below are backed by it.
    mod versioned_user {
        use sea_orm::entity::prelude::*;
        #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
        #[sea_orm(table_name = "versioned_users")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i32,
            pub name: String,
            pub email: String,
            // Version column; its name matches the default `locking_column()`.
            pub lock_version: i64,
        }
        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}
        impl ActiveModelBehavior for ActiveModel {}
    }

    // Serde default for the skipped `state` field, so deserialized records start out as `New`.
    fn default_record_state() -> RecordState {
        RecordState::New
    }

    // Test record whose `lock_optimistically()` returns true (see the macro invocation below).
    #[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
    #[serde(deny_unknown_fields)]
    struct VersionedUser {
        id: Option<i64>,
        name: String,
        email: String,
        lock_version: i64,
        // Not a column: excluded from (de)serialization so it never reaches SQL.
        #[serde(skip, default = "default_record_state")]
        state: RecordState,
    }

    // Same shape as `VersionedUser`, but its `lock_optimistically()` returns false.
    #[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
    #[serde(deny_unknown_fields)]
    struct UnlockedVersionedUser {
        id: Option<i64>,
        name: String,
        email: String,
        lock_version: i64,
        #[serde(skip, default = "default_record_state")]
        state: RecordState,
    }

    // Implements `Record`, `AsyncPersistence`, `AsyncQuerying` and
    // `OptimisticLocking` for a test record type backed by `versioned_user`.
    // The second argument becomes the value returned by `lock_optimistically()`.
    macro_rules! impl_versioned_record {
        ($name:ident, $lock_optimistically:expr) => {
            impl Record for $name {
                type Entity = versioned_user::Entity;
                fn table_name() -> &'static str {
                    "versioned_users"
                }
                fn id(&self) -> Option<i64> {
                    self.id
                }
                fn record_state(&self) -> RecordState {
                    self.state
                }
                fn set_record_state(&mut self, state: RecordState) {
                    self.state = state;
                }
                fn from_sea_model(model: <Self::Entity as EntityTrait>::Model) -> Self {
                    Self {
                        id: Some(i64::from(model.id)),
                        name: model.name,
                        email: model.email,
                        lock_version: model.lock_version,
                        state: RecordState::Persisted,
                    }
                }
                fn to_active_model(&self) -> <Self::Entity as EntityTrait>::ActiveModel {
                    versioned_user::ActiveModel {
                        // Missing ids, or ids that don't fit the i32 column, are left NotSet.
                        id: match self.id.and_then(|value| i32::try_from(value).ok()) {
                            Some(value) => Set(value),
                            None => NotSet,
                        },
                        name: Set(self.name.clone()),
                        email: Set(self.email.clone()),
                        lock_version: Set(self.lock_version),
                    }
                }
            }
            impl AsyncPersistence for $name {}
            impl AsyncQuerying for $name {}
            impl OptimisticLocking for $name {
                fn lock_version(&self) -> i64 {
                    self.lock_version
                }
                fn increment_lock_version(&mut self) {
                    self.lock_version += 1;
                }
                fn lock_optimistically() -> bool {
                    $lock_optimistically
                }
            }
        };
    }
    impl_versioned_record!(VersionedUser, true);
    impl_versioned_record!(UnlockedVersionedUser, false);

    // Opens a fresh in-memory SQLite database and creates `versioned_users` from the entity.
    async fn setup_db() -> DatabaseConnection {
        let db = Database::connect("sqlite::memory:")
            .await
            .expect("in-memory sqlite connection should succeed");
        let schema = sea_orm::Schema::new(db.get_database_backend());
        db.execute(&schema.create_table_from_entity(versioned_user::Entity))
            .await
            .expect("versioned_users table should be created");
        db
    }

    // Inserts a `VersionedUser` through `create`, supplying only name and email.
    async fn create_user(db: &DatabaseConnection, name: &str, email: &str) -> VersionedUser {
        VersionedUser::create(
            HashMap::from([
                ("name".to_owned(), json!(name)),
                ("email".to_owned(), json!(email)),
            ]),
            db,
        )
        .await
        .expect("create should succeed")
    }

    // Reads `lock_version` directly from the table, bypassing the record layer.
    async fn db_lock_version(db: &DatabaseConnection, id: i64) -> i64 {
        db.query_one_raw(Statement::from_sql_and_values(
            db.get_database_backend(),
            "SELECT lock_version FROM versioned_users WHERE id = ?".to_owned(),
            [id.into()],
        ))
        .await
        .expect("lock-version query should succeed")
        .expect("row should exist")
        .try_get("", "lock_version")
        .expect("lock_version should be readable")
    }

    #[test]
    fn locking_enabled_defaults_to_true() {
        assert!(VersionedUser::locking_enabled());
        assert!(VersionedUser::lock_optimistically());
    }
    #[test]
    fn lock_optimistically_can_disable_locking() {
        assert!(!UnlockedVersionedUser::lock_optimistically());
        assert!(!UnlockedVersionedUser::locking_enabled());
    }
    #[test]
    fn stale_object_error_display_is_descriptive() {
        assert_eq!(
            StaleObjectError.to_string(),
            "stale object error: record was updated by another process"
        );
    }
    #[test]
    fn stale_object_error_converts_to_record_error() {
        let error: RecordError = StaleObjectError.into();
        assert!(matches!(error, RecordError::StaleObject));
    }
    #[test]
    fn check_lock_version_accepts_matching_version() {
        let user = VersionedUser {
            lock_version: 3,
            ..Default::default()
        };
        user.check_lock_version(3)
            .expect("matching version should succeed");
    }
    #[test]
    fn check_lock_version_rejects_stale_version() {
        let user = VersionedUser {
            lock_version: 4,
            ..Default::default()
        };
        let error = user
            .check_lock_version(3)
            .expect_err("mismatched version should fail");
        assert_eq!(error, StaleObjectError);
    }
    #[test]
    fn increment_lock_version_updates_counter() {
        let mut user = VersionedUser::default();
        user.increment_lock_version();
        user.increment_lock_version();
        assert_eq!(user.lock_version(), 2);
    }
    #[tokio::test]
    async fn current_lock_version_reads_database_value() {
        let db = setup_db().await;
        let user = create_user(&db, "Alice", "alice@example.com").await;
        assert_eq!(user.current_lock_version(&db).await.unwrap(), 0);
    }
    // A record without an id takes the regular insert path and keeps version 0.
    #[tokio::test]
    async fn save_with_optimistic_lock_inserts_new_record_at_version_zero() {
        let db = setup_db().await;
        let mut user = VersionedUser {
            name: "Alice".to_owned(),
            email: "alice@example.com".to_owned(),
            ..Default::default()
        };
        user.save_with_optimistic_lock(&db)
            .await
            .expect("insert should succeed");
        assert!(user.persisted());
        assert_eq!(user.lock_version(), 0);
        assert_eq!(db_lock_version(&db, user.id().unwrap()).await, 0);
    }
    #[tokio::test]
    async fn save_with_optimistic_lock_updates_row_and_bumps_version() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;
        user.name = "Alicia".to_owned();
        user.save_with_optimistic_lock(&db)
            .await
            .expect("optimistic update should succeed");
        let reloaded = VersionedUser::find(user.id().unwrap(), &db)
            .await
            .expect("reloaded row should exist");
        assert_eq!(user.lock_version(), 1);
        assert_eq!(reloaded.name, "Alicia");
        assert_eq!(reloaded.lock_version(), 1);
    }
    // Two copies of the same row: once `fresh` saves, `stale` still holds
    // version 0 while the row is at 1, so its save must be rejected.
    #[tokio::test]
    async fn save_with_optimistic_lock_rejects_stale_version() {
        let db = setup_db().await;
        let mut fresh = create_user(&db, "Alice", "alice@example.com").await;
        let mut stale = VersionedUser::find(fresh.id().unwrap(), &db)
            .await
            .expect("second copy should exist");
        fresh.name = "Alicia".to_owned();
        fresh
            .save_with_optimistic_lock(&db)
            .await
            .expect("first update should succeed");
        stale.name = "Outdated".to_owned();
        let error = stale
            .save_with_optimistic_lock(&db)
            .await
            .expect_err("stale update should fail");
        assert!(matches!(error, RecordError::StaleObject));
        // A rejected save must not advance the in-memory version.
        assert_eq!(stale.lock_version(), 0);
        assert_eq!(db_lock_version(&db, fresh.id().unwrap()).await, 1);
    }
    // Simulates another process bumping the version with raw SQL. This relies
    // on the user created in this fresh database having id 1.
    #[tokio::test]
    async fn save_with_optimistic_lock_detects_external_version_bump() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;
        db.execute_unprepared("UPDATE versioned_users SET lock_version = 2 WHERE id = 1")
            .await
            .expect("direct version bump should succeed");
        let error = user
            .save_with_optimistic_lock(&db)
            .await
            .expect_err("external version bump should cause stale error");
        assert!(matches!(error, RecordError::StaleObject));
        assert_eq!(user.lock_version(), 0);
    }
    #[tokio::test]
    async fn save_with_optimistic_lock_returns_not_saved_for_destroyed_records() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;
        user.set_record_state(RecordState::Destroyed);
        let error = user
            .save_with_optimistic_lock(&db)
            .await
            .expect_err("destroyed records cannot be saved");
        assert!(matches!(error, RecordError::NotSaved));
    }
    #[tokio::test]
    async fn destroy_with_optimistic_lock_deletes_matching_version() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;
        user.destroy_with_optimistic_lock(&db)
            .await
            .expect("destroy should succeed");
        assert!(user.destroyed());
        assert_eq!(VersionedUser::count(&db).await.unwrap(), 0);
    }
    // Bumps the row's version behind `stale`'s back (the first row has id 1),
    // then checks that the guarded DELETE leaves both the row and the record state intact.
    #[tokio::test]
    async fn destroy_with_optimistic_lock_rejects_stale_record() {
        let db = setup_db().await;
        let fresh = create_user(&db, "Alice", "alice@example.com").await;
        let mut stale = VersionedUser::find(fresh.id().unwrap(), &db)
            .await
            .expect("stale copy should exist");
        db.execute_unprepared("UPDATE versioned_users SET lock_version = 1 WHERE id = 1")
            .await
            .expect("direct version bump should succeed");
        let error = stale
            .destroy_with_optimistic_lock(&db)
            .await
            .expect_err("stale destroy should fail");
        assert!(matches!(error, RecordError::StaleObject));
        assert!(stale.persisted());
        assert_eq!(VersionedUser::count(&db).await.unwrap(), 1);
    }
    #[tokio::test]
    async fn destroy_with_optimistic_lock_returns_not_saved_without_id() {
        let db = setup_db().await;
        let mut user = VersionedUser::default();
        let error = user
            .destroy_with_optimistic_lock(&db)
            .await
            .expect_err("new records cannot be destroyed optimistically");
        assert!(matches!(error, RecordError::NotSaved));
    }
    // With locking disabled, updates go through the regular save, so the
    // version column stays at 0 even after an update.
    #[tokio::test]
    async fn save_with_optimistic_lock_uses_regular_save_when_disabled() {
        let db = setup_db().await;
        let mut user = UnlockedVersionedUser {
            name: "Alice".to_owned(),
            email: "alice@example.com".to_owned(),
            ..Default::default()
        };
        user.save_with_optimistic_lock(&db)
            .await
            .expect("save should delegate when locking is disabled");
        user.name = "Alicia".to_owned();
        user.save_with_optimistic_lock(&db)
            .await
            .expect("update should delegate when locking is disabled");
        let reloaded = UnlockedVersionedUser::find(user.id().unwrap(), &db)
            .await
            .expect("row should still reload");
        assert_eq!(reloaded.name, "Alicia");
        assert_eq!(reloaded.lock_version(), 0);
    }
    #[tokio::test]
    async fn destroy_with_optimistic_lock_uses_regular_destroy_when_disabled() {
        let db = setup_db().await;
        let mut user = UnlockedVersionedUser::create(
            HashMap::from([
                ("name".to_owned(), json!("Alice")),
                ("email".to_owned(), json!("alice@example.com")),
            ]),
            &db,
        )
        .await
        .expect("create should succeed");
        user.destroy_with_optimistic_lock(&db)
            .await
            .expect("destroy should delegate when locking is disabled");
        assert!(user.destroyed());
        assert_eq!(UnlockedVersionedUser::count(&db).await.unwrap(), 0);
    }
    // Each successful optimistic save bumps both the in-memory and stored version.
    #[tokio::test]
    async fn multiple_successive_optimistic_saves_keep_advancing_version() {
        let db = setup_db().await;
        let mut user = create_user(&db, "Alice", "alice@example.com").await;
        user.name = "Alicia".to_owned();
        user.save_with_optimistic_lock(&db).await.unwrap();
        user.email = "alicia@example.com".to_owned();
        user.save_with_optimistic_lock(&db).await.unwrap();
        assert_eq!(user.lock_version(), 2);
        let reloaded = VersionedUser::find(user.id().unwrap(), &db).await.unwrap();
        assert_eq!(reloaded.name, "Alicia");
        assert_eq!(reloaded.email, "alicia@example.com");
        assert_eq!(reloaded.lock_version(), 2);
    }
}