1use sea_orm::sea_query::{Condition, IntoCondition, SimpleExpr};
10use sea_orm::{ConnectionTrait, EntityTrait, QueryFilter, Update, Value};
11
12use crate::GuardedError;
13
14pub struct GuardedUpdate<E: EntityTrait> {
15 entity: E,
16 filters: Condition,
17 sets: Vec<(E::Column, SimpleExpr)>,
18}
19
20impl<E: EntityTrait> GuardedUpdate<E> {
21 pub fn new(entity: E) -> Self {
23 Self {
24 entity,
25 filters: Condition::all(), sets: Vec::new(),
27 }
28 }
29
30 pub fn filter<F: IntoCondition>(mut self, f: F) -> Self {
32 self.filters = self.filters.add(f.into_condition());
33 self
34 }
35
36 pub fn set_expr(mut self, col: E::Column, expr: SimpleExpr) -> Self {
38 self.sets.push((col, expr));
39 self
40 }
41
42 pub fn set_value(mut self, col: E::Column, value: Value) -> Self {
44 self.sets.push((col, SimpleExpr::Value(value)));
45 self
46 }
47
48 pub async fn exec_one<C: ConnectionTrait>(self, conn: &C) -> Result<(), GuardedError> {
63 match self.exec_raw(conn).await? {
64 0 => Err(GuardedError::NoRowsAffected),
65 1 => Ok(()),
66 n => Err(GuardedError::TooManyRows { affected: n }),
67 }
68 }
69
70 pub async fn exec_at_most_one<C: ConnectionTrait>(
76 self,
77 conn: &C,
78 ) -> Result<bool, GuardedError> {
79 match self.exec_raw(conn).await? {
80 0 => Ok(false),
81 1 => Ok(true),
82 n => Err(GuardedError::TooManyRows { affected: n }),
83 }
84 }
85
86 async fn exec_raw<C: ConnectionTrait>(self, conn: &C) -> Result<u64, GuardedError> {
87 if self.sets.is_empty() {
91 return Err(GuardedError::EmptyUpdate);
92 }
93
94 let mut stmt = Update::many(self.entity).filter(self.filters);
95 for (col, expr) in self.sets {
96 stmt = stmt.col_expr(col, expr);
97 }
98 let result = stmt.exec(conn).await?; Ok(result.rows_affected)
100 }
101}
102
#[cfg(test)]
mod tests {
    // Integration tests for `GuardedUpdate` against an in-memory SQLite database.

    use super::*;
    use sea_orm::sea_query::Expr;
    use sea_orm::{
        ColumnTrait, ConnectionTrait, Database, DatabaseBackend, EntityTrait, Schema, Set,
        TransactionTrait,
    };

    /// Minimal `counters` entity used as the update target in every test.
    mod counters {
        use sea_orm::entity::prelude::*;

        #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
        #[sea_orm(table_name = "counters")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i32,
            pub quantity: i32,
            pub status: String,
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}

        impl ActiveModelBehavior for ActiveModel {}
    }

    /// Connects to a new in-memory SQLite database and creates the `counters` table.
    ///
    /// NOTE(review): each in-memory SQLite connection owns a separate database, so this
    /// relies on the pool reusing the connection that created the table — confirm
    /// against SeaORM's SQLite pool defaults.
    async fn fresh_db() -> sea_orm::DatabaseConnection {
        let conn = Database::connect("sqlite::memory:")
            .await
            .expect("connect to in-memory sqlite");
        let schema = Schema::new(DatabaseBackend::Sqlite);
        let stmt = schema.create_table_from_entity(counters::Entity);
        conn.execute(conn.get_database_backend().build(&stmt))
            .await
            .expect("create counters table");
        conn
    }

    /// Inserts one `counters` row with an explicit primary key.
    async fn insert_row(conn: &sea_orm::DatabaseConnection, id: i32, quantity: i32, status: &str) {
        counters::Entity::insert(counters::ActiveModel {
            id: Set(id),
            quantity: Set(quantity),
            status: Set(status.to_string()),
        })
        .exec(conn)
        .await
        .expect("insert counters row");
    }

    /// Loads row `id`, panicking if the query fails or the row is missing.
    ///
    /// Generic over the connection so it works on a plain connection and inside an open
    /// transaction alike.
    async fn fetch_row<C: ConnectionTrait>(conn: &C, id: i32) -> counters::Model {
        counters::Entity::find_by_id(id)
            .one(conn)
            .await
            .expect("query counters row")
            .expect("counters row exists")
    }

    #[tokio::test]
    async fn predicate_matches_one_row_succeeds() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;

        GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .filter(counters::Column::Quantity.gte(3))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(3),
            )
            .exec_one(&conn)
            .await
            .expect("guarded update should succeed");

        assert_eq!(fetch_row(&conn, 1).await.quantity, 2);
    }

    #[tokio::test]
    async fn predicate_fails_zero_rows() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 2, "pending").await;

        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .filter(counters::Column::Quantity.gte(5))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(5),
            )
            .exec_one(&conn)
            .await
            .expect_err("should fail predicate");
        assert!(matches!(err, GuardedError::NoRowsAffected));

        let updated = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .filter(counters::Column::Quantity.gte(5))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(5),
            )
            .exec_at_most_one(&conn)
            .await
            .expect("exec_at_most_one tolerates 0 rows");
        assert!(!updated);

        assert_eq!(fetch_row(&conn, 1).await.quantity, 2);
    }

    #[tokio::test]
    async fn exec_at_most_one_single_row_returns_true() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;
        insert_row(&conn, 2, 5, "pending").await;

        let updated = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(2))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(1),
            )
            .exec_at_most_one(&conn)
            .await
            .expect("exec_at_most_one accepts exactly one row");
        assert!(updated);

        assert_eq!(fetch_row(&conn, 1).await.quantity, 5);
        assert_eq!(fetch_row(&conn, 2).await.quantity, 4);
    }

    #[tokio::test]
    async fn predicate_matches_multiple_rows() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 10, "pending").await;
        insert_row(&conn, 2, 10, "pending").await;

        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Status.eq("pending"))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(1),
            )
            .exec_one(&conn)
            .await
            .expect_err("should fail with TooManyRows");
        assert!(matches!(err, GuardedError::TooManyRows { affected: 2 }));

        // The guard inspects the count only after the UPDATE has run, and nothing rolls
        // it back: outside a transaction both matching rows stay decremented.
        assert_eq!(fetch_row(&conn, 1).await.quantity, 9);
        assert_eq!(fetch_row(&conn, 2).await.quantity, 9);

        insert_row(&conn, 3, 10, "shipped").await;
        insert_row(&conn, 4, 10, "shipped").await;
        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Status.eq("shipped"))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(1),
            )
            .exec_at_most_one(&conn)
            .await
            .expect_err("exec_at_most_one should also fail with TooManyRows");
        assert!(matches!(err, GuardedError::TooManyRows { affected: 2 }));

        assert_eq!(fetch_row(&conn, 3).await.quantity, 9);
        assert_eq!(fetch_row(&conn, 4).await.quantity, 9);
        // Rows outside the second predicate are untouched by it.
        assert_eq!(fetch_row(&conn, 1).await.quantity, 9);
    }

    #[tokio::test]
    async fn empty_update_no_sets() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;

        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .exec_one(&conn)
            .await
            .expect_err("empty builder must error");
        assert!(matches!(err, GuardedError::EmptyUpdate));

        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .exec_at_most_one(&conn)
            .await
            .expect_err("empty builder must error in exec_at_most_one too");
        assert!(matches!(err, GuardedError::EmptyUpdate));

        assert_eq!(fetch_row(&conn, 1).await.quantity, 5);
    }

    #[tokio::test]
    async fn multi_column_set_atomic() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;

        GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .filter(counters::Column::Status.eq("pending"))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(2),
            )
            .set_value(
                counters::Column::Status,
                Value::String(Some(Box::new("committed".to_string()))),
            )
            .exec_one(&conn)
            .await
            .expect("multi-column guarded update");

        let row = fetch_row(&conn, 1).await;
        assert_eq!(row.quantity, 3);
        assert_eq!(row.status, "committed");
    }

    #[tokio::test]
    async fn transaction_rollback() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;

        let txn = conn.begin().await.expect("begin transaction");

        GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(2),
            )
            .exec_one(&txn)
            .await
            .expect("guarded update inside transaction");

        assert_eq!(fetch_row(&txn, 1).await.quantity, 3);

        txn.rollback().await.expect("rollback");

        assert_eq!(fetch_row(&conn, 1).await.quantity, 5);
    }

    #[tokio::test]
    async fn filter_and_combine() {
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;
        insert_row(&conn, 2, 5, "shipped").await;
        insert_row(&conn, 3, 10, "pending").await;

        GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Status.eq("pending"))
            .filter(counters::Column::Quantity.eq(5))
            .set_value(
                counters::Column::Status,
                Value::String(Some(Box::new("matched".to_string()))),
            )
            .exec_one(&conn)
            .await
            .expect("AND-filter should match exactly row 1");

        assert_eq!(fetch_row(&conn, 1).await.status, "matched");
        assert_eq!(fetch_row(&conn, 2).await.status, "shipped");
        assert_eq!(fetch_row(&conn, 3).await.status, "pending");
    }
}