use sea_orm::sea_query::{Condition, IntoCondition, SimpleExpr};
use sea_orm::{ConnectionTrait, EntityTrait, QueryFilter, Update, Value};
use crate::GuardedError;
/// Builder for an `UPDATE` that checks how many rows it touched.
///
/// Predicates added with `filter` are combined with `AND`; assignments added
/// with `set_expr` / `set_value` become the `SET` list. Executing the builder
/// (see [`GuardedUpdate::exec_one`] and [`GuardedUpdate::exec_at_most_one`])
/// sends a single `UPDATE ... WHERE ...` statement and then compares the
/// affected-row count with the method's expectation. The comparison happens
/// after the statement has run, so a count mismatch does not by itself undo
/// the write.
#[must_use = "a GuardedUpdate does nothing until one of its exec methods is awaited"]
pub struct GuardedUpdate<E: EntityTrait> {
    /// Entity (table) being updated.
    entity: E,
    /// Accumulated predicates, combined with `AND`.
    filters: Condition,
    /// `SET` assignments, emitted in the order they were added.
    sets: Vec<(E::Column, SimpleExpr)>,
}
impl<E: EntityTrait> GuardedUpdate<E> {
pub fn new(entity: E) -> Self {
Self {
entity,
filters: Condition::all(), sets: Vec::new(),
}
}
pub fn filter<F: IntoCondition>(mut self, f: F) -> Self {
self.filters = self.filters.add(f.into_condition());
self
}
pub fn set_expr(mut self, col: E::Column, expr: SimpleExpr) -> Self {
self.sets.push((col, expr));
self
}
pub fn set_value(mut self, col: E::Column, value: Value) -> Self {
self.sets.push((col, SimpleExpr::Value(value)));
self
}
pub async fn exec_one<C: ConnectionTrait>(self, conn: &C) -> Result<(), GuardedError> {
match self.exec_raw(conn).await? {
0 => Err(GuardedError::NoRowsAffected),
1 => Ok(()),
n => Err(GuardedError::TooManyRows { affected: n }),
}
}
pub async fn exec_at_most_one<C: ConnectionTrait>(
self,
conn: &C,
) -> Result<bool, GuardedError> {
match self.exec_raw(conn).await? {
0 => Ok(false),
1 => Ok(true),
n => Err(GuardedError::TooManyRows { affected: n }),
}
}
async fn exec_raw<C: ConnectionTrait>(self, conn: &C) -> Result<u64, GuardedError> {
if self.sets.is_empty() {
return Err(GuardedError::EmptyUpdate);
}
let mut stmt = Update::many(self.entity).filter(self.filters);
for (col, expr) in self.sets {
stmt = stmt.col_expr(col, expr);
}
let result = stmt.exec(conn).await?; Ok(result.rows_affected)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use sea_orm::sea_query::Expr;
    use sea_orm::{
        ColumnTrait, ConnectionTrait, Database, DatabaseBackend, EntityTrait, Schema, Set,
        TransactionTrait,
    };

    /// Minimal entity used as the update target in these tests.
    mod counters {
        use sea_orm::entity::prelude::*;

        #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
        #[sea_orm(table_name = "counters")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i32,
            pub quantity: i32,
            pub status: String,
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}

        impl ActiveModelBehavior for ActiveModel {}
    }

    /// Opens an in-memory SQLite database containing an empty `counters` table.
    async fn fresh_db() -> sea_orm::DatabaseConnection {
        let conn = Database::connect("sqlite::memory:")
            .await
            .expect("connect to in-memory sqlite");
        let schema = Schema::new(DatabaseBackend::Sqlite);
        let stmt = schema.create_table_from_entity(counters::Entity);
        conn.execute(conn.get_database_backend().build(&stmt))
            .await
            .expect("create counters table");
        conn
    }

    /// Inserts one `counters` row with the given values.
    async fn insert_row(conn: &sea_orm::DatabaseConnection, id: i32, quantity: i32, status: &str) {
        counters::Entity::insert(counters::ActiveModel {
            id: Set(id),
            quantity: Set(quantity),
            status: Set(status.to_string()),
        })
        .exec(conn)
        .await
        .expect("insert counters row");
    }

    /// Loads row `id`, panicking if the query fails or the row does not exist.
    async fn fetch<C: ConnectionTrait>(conn: &C, id: i32) -> counters::Model {
        counters::Entity::find_by_id(id)
            .one(conn)
            .await
            .expect("query counters row")
            .expect("counters row exists")
    }

    /// `quantity - by`, evaluated by the database.
    fn decrement(by: i32) -> sea_orm::sea_query::SimpleExpr {
        Expr::col(counters::Column::Quantity).sub(by)
    }

    #[tokio::test]
    async fn predicate_matches_one_row_succeeds() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;

        GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .filter(counters::Column::Quantity.gte(3))
            .set_expr(counters::Column::Quantity, decrement(3))
            .exec_one(&conn)
            .await
            .expect("guarded update should succeed");

        assert_eq!(fetch(&conn, 1).await.quantity, 2);
    }

    #[tokio::test]
    async fn predicate_fails_zero_rows() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 2, "pending").await;

        // Quantity 2 fails the `>= 5` guard, so no row may be touched.
        let guarded = || {
            GuardedUpdate::new(counters::Entity)
                .filter(counters::Column::Id.eq(1))
                .filter(counters::Column::Quantity.gte(5))
                .set_expr(counters::Column::Quantity, decrement(5))
        };

        let err = guarded()
            .exec_one(&conn)
            .await
            .expect_err("should fail predicate");
        assert!(matches!(err, GuardedError::NoRowsAffected));

        let updated = guarded()
            .exec_at_most_one(&conn)
            .await
            .expect("exec_at_most_one tolerates 0 rows");
        assert!(!updated);

        assert_eq!(fetch(&conn, 1).await.quantity, 2);
    }

    #[tokio::test]
    async fn predicate_matches_multiple_rows() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 10, "pending").await;
        insert_row(&conn, 2, 10, "pending").await;

        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Status.eq("pending"))
            .set_expr(counters::Column::Quantity, decrement(1))
            .exec_one(&conn)
            .await
            .expect_err("should fail with TooManyRows");
        assert!(matches!(err, GuardedError::TooManyRows { affected: 2 }));
        // The count is checked only after the UPDATE ran, and exec_one does not
        // undo it: both matched rows stay decremented despite the error.
        assert_eq!(fetch(&conn, 1).await.quantity, 9);
        assert_eq!(fetch(&conn, 2).await.quantity, 9);

        insert_row(&conn, 3, 10, "shipped").await;
        insert_row(&conn, 4, 10, "shipped").await;

        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Status.eq("shipped"))
            .set_expr(counters::Column::Quantity, decrement(1))
            .exec_at_most_one(&conn)
            .await
            .expect_err("exec_at_most_one should also fail with TooManyRows");
        assert!(matches!(err, GuardedError::TooManyRows { affected: 2 }));
        // Same post-hoc check: the write is not rolled back.
        assert_eq!(fetch(&conn, 3).await.quantity, 9);
        assert_eq!(fetch(&conn, 4).await.quantity, 9);
    }

    #[tokio::test]
    async fn empty_update_no_sets() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;

        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .exec_one(&conn)
            .await
            .expect_err("empty builder must error");
        assert!(matches!(err, GuardedError::EmptyUpdate));

        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .exec_at_most_one(&conn)
            .await
            .expect_err("empty builder must error in exec_at_most_one too");
        assert!(matches!(err, GuardedError::EmptyUpdate));

        assert_eq!(fetch(&conn, 1).await.quantity, 5);
    }

    #[tokio::test]
    async fn multi_column_set_atomic() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;

        GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .filter(counters::Column::Status.eq("pending"))
            .set_expr(counters::Column::Quantity, decrement(2))
            .set_value(
                counters::Column::Status,
                Value::String(Some(Box::new("committed".to_string()))),
            )
            .exec_one(&conn)
            .await
            .expect("multi-column guarded update");

        let row = fetch(&conn, 1).await;
        assert_eq!(row.quantity, 3);
        assert_eq!(row.status, "committed");
    }

    #[tokio::test]
    async fn transaction_rollback() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;

        let txn = conn.begin().await.expect("begin transaction");
        GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .set_expr(counters::Column::Quantity, decrement(2))
            .exec_one(&txn)
            .await
            .expect("guarded update inside transaction");
        assert_eq!(fetch(&txn, 1).await.quantity, 3);

        txn.rollback().await.expect("rollback");
        assert_eq!(fetch(&conn, 1).await.quantity, 5);
    }

    #[tokio::test]
    async fn filter_and_combine() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;
        insert_row(&conn, 2, 5, "shipped").await;
        insert_row(&conn, 3, 10, "pending").await;

        GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Status.eq("pending"))
            .filter(counters::Column::Quantity.eq(5))
            .set_value(
                counters::Column::Status,
                Value::String(Some(Box::new("matched".to_string()))),
            )
            .exec_one(&conn)
            .await
            .expect("AND-filter should match exactly row 1");

        assert_eq!(fetch(&conn, 1).await.status, "matched");
        assert_eq!(fetch(&conn, 2).await.status, "shipped");
        assert_eq!(fetch(&conn, 3).await.status, "pending");
    }
}