Skip to main content

ferro_orm/
guarded.rs

//! `GuardedUpdate<E>` — chainable builder for atomic conditional `UPDATE`
//! statements. Each execution compiles to exactly one `UPDATE … WHERE …`
//! SQL statement.
//!
//! The database engine's per-statement atomicity (SQLite's serial writer,
//! Postgres `READ COMMITTED`) is the entire correctness mechanism — this
//! builder only adds the chainable surface and the rows-affected →
//! `GuardedError` mapping on top.
8
9use sea_orm::sea_query::{Condition, IntoCondition, SimpleExpr};
10use sea_orm::{ConnectionTrait, EntityTrait, QueryFilter, Update, Value};
11
12use crate::GuardedError;
13
14pub struct GuardedUpdate<E: EntityTrait> {
15    entity: E,
16    filters: Condition,
17    sets: Vec<(E::Column, SimpleExpr)>,
18}
19
20impl<E: EntityTrait> GuardedUpdate<E> {
21    /// Start a new builder targeting `entity`.
22    pub fn new(entity: E) -> Self {
23        Self {
24            entity,
25            filters: Condition::all(), // AND-combiner per D-06
26            sets: Vec::new(),
27        }
28    }
29
30    /// Add a filter expression. Multiple `.filter(...)` calls AND-combine.
31    pub fn filter<F: IntoCondition>(mut self, f: F) -> Self {
32        self.filters = self.filters.add(f.into_condition());
33        self
34    }
35
36    /// Set a column to a value-derived expression.
37    pub fn set_expr(mut self, col: E::Column, expr: SimpleExpr) -> Self {
38        self.sets.push((col, expr));
39        self
40    }
41
42    /// Set a column to a literal value.
43    pub fn set_value(mut self, col: E::Column, value: Value) -> Self {
44        self.sets.push((col, SimpleExpr::Value(value)));
45        self
46    }
47
48    /// Execute the conditional UPDATE; succeed iff exactly one row matched.
49    ///
50    /// Returns `Err(GuardedError::NoRowsAffected)` on 0 rows (predicate
51    /// failure — the race-free "capacity exhausted" signal). Returns
52    /// `Err(GuardedError::TooManyRows { affected })` on `>1` rows
53    /// (filter is not unique-key-equivalent — index/uniqueness bug).
54    ///
55    /// Note on `TooManyRows`: this variant is preserved for documentation
56    /// and future-proofing. sea-orm's `UpdateMany::exec` returns
57    /// `rows_affected` unconditionally on success, so a filter matching
58    /// `>1` rows will mutate every matched row before this post-processor
59    /// surfaces the error. The variant is the right way to shout when a
60    /// supposed unique-key-equivalent filter turns out not to be — see
61    /// Pitfall 4 in 152-RESEARCH.md.
62    pub async fn exec_one<C: ConnectionTrait>(self, conn: &C) -> Result<(), GuardedError> {
63        match self.exec_raw(conn).await? {
64            0 => Err(GuardedError::NoRowsAffected),
65            1 => Ok(()),
66            n => Err(GuardedError::TooManyRows { affected: n }),
67        }
68    }
69
70    /// Execute the conditional UPDATE; tolerate 0 rows as a normal outcome.
71    ///
72    /// Returns `Ok(true)` on 1 row, `Ok(false)` on 0 rows. `>1` rows still
73    /// returns `Err(GuardedError::TooManyRows)` — the uniqueness contract
74    /// is the same.
75    pub async fn exec_at_most_one<C: ConnectionTrait>(
76        self,
77        conn: &C,
78    ) -> Result<bool, GuardedError> {
79        match self.exec_raw(conn).await? {
80            0 => Ok(false),
81            1 => Ok(true),
82            n => Err(GuardedError::TooManyRows { affected: n }),
83        }
84    }
85
86    async fn exec_raw<C: ConnectionTrait>(self, conn: &C) -> Result<u64, GuardedError> {
87        // Load-bearing — sea-orm's `Updater::is_noop()` short-circuits with
88        // `rows_affected: 0` when SET is empty, which would otherwise look
89        // like a predicate miss (Pitfall 1 in 152-RESEARCH.md).
90        if self.sets.is_empty() {
91            return Err(GuardedError::EmptyUpdate);
92        }
93
94        let mut stmt = Update::many(self.entity).filter(self.filters);
95        for (col, expr) in self.sets {
96            stmt = stmt.col_expr(col, expr);
97        }
98        let result = stmt.exec(conn).await?; // From<DbErr> via #[from] on Db variant.
99        Ok(result.rows_affected)
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use sea_orm::sea_query::Expr;
    use sea_orm::{
        ColumnTrait, ConnectionTrait, Database, DatabaseBackend, EntityTrait, Schema, Set,
        TransactionTrait,
    };

    /// Minimal test entity: `counters(id PK, quantity, status)`.
    mod counters {
        use sea_orm::entity::prelude::*;

        #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
        #[sea_orm(table_name = "counters")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i32,
            pub quantity: i32,
            pub status: String,
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}

        impl ActiveModelBehavior for ActiveModel {}
    }

    /// Open an in-memory SQLite database and create the `counters` table in it.
    async fn fresh_db() -> sea_orm::DatabaseConnection {
        let conn = Database::connect("sqlite::memory:")
            .await
            .expect("connect to in-memory sqlite");
        let schema = Schema::new(DatabaseBackend::Sqlite);
        let stmt = schema.create_table_from_entity(counters::Entity);
        conn.execute(conn.get_database_backend().build(&stmt))
            .await
            .expect("create counters table");
        conn
    }

    /// Insert one `counters` row, panicking on failure.
    async fn insert_row(conn: &sea_orm::DatabaseConnection, id: i32, quantity: i32, status: &str) {
        counters::Entity::insert(counters::ActiveModel {
            id: Set(id),
            quantity: Set(quantity),
            status: Set(status.to_string()),
        })
        .exec(conn)
        .await
        .expect("insert counters row");
    }

    /// Load the `counters` row with primary key `id` through `conn` (a plain
    /// connection or a transaction), panicking if the query fails or the row
    /// is missing.
    async fn fetch_row<C: ConnectionTrait>(conn: &C, id: i32) -> counters::Model {
        counters::Entity::find_by_id(id)
            .one(conn)
            .await
            .expect("query counters row")
            .expect("counters row exists")
    }

    #[tokio::test]
    async fn predicate_matches_one_row_succeeds() {
        // T-16-1
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;

        GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .filter(counters::Column::Quantity.gte(3))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(3),
            )
            .exec_one(&conn)
            .await
            .expect("guarded update should succeed");

        assert_eq!(fetch_row(&conn, 1).await.quantity, 2);
    }

    #[tokio::test]
    async fn predicate_fails_zero_rows() {
        // T-16-2
        let conn = fresh_db().await;
        insert_row(&conn, 1, 2, "pending").await;

        // exec_one path
        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .filter(counters::Column::Quantity.gte(5))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(5),
            )
            .exec_one(&conn)
            .await
            .expect_err("should fail predicate");
        assert!(matches!(err, GuardedError::NoRowsAffected));

        // exec_at_most_one path
        let updated = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .filter(counters::Column::Quantity.gte(5))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(5),
            )
            .exec_at_most_one(&conn)
            .await
            .expect("exec_at_most_one tolerates 0 rows");
        assert!(!updated);

        // Neither failed attempt may have touched the row.
        assert_eq!(fetch_row(&conn, 1).await.quantity, 2);
    }

    #[tokio::test]
    async fn predicate_matches_multiple_rows() {
        // T-16-3
        // Note: SeaORM's `Update::many` is multi-row by SQL semantics. The
        // failing `exec_one` call below DOES mutate both rows before our
        // post-processor surfaces TooManyRows; the assertions after it pin
        // that caveat. The exec_at_most_one check then runs against a
        // separate pair of freshly inserted rows.
        // This is the correct behavior per D-13 and Pitfall 4 — surface
        // the bug loudly rather than silently succeeding.
        let conn = fresh_db().await;
        insert_row(&conn, 1, 10, "pending").await;
        insert_row(&conn, 2, 10, "pending").await;

        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Status.eq("pending"))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(1),
            )
            .exec_one(&conn)
            .await
            .expect_err("should fail with TooManyRows");
        assert!(matches!(err, GuardedError::TooManyRows { affected: 2 }));
        // The error is reported after the fact: both matched rows were updated.
        assert_eq!(fetch_row(&conn, 1).await.quantity, 9);
        assert_eq!(fetch_row(&conn, 2).await.quantity, 9);

        insert_row(&conn, 3, 10, "shipped").await;
        insert_row(&conn, 4, 10, "shipped").await;
        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Status.eq("shipped"))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(1),
            )
            .exec_at_most_one(&conn)
            .await
            .expect_err("exec_at_most_one should also fail with TooManyRows");
        assert!(matches!(err, GuardedError::TooManyRows { affected: 2 }));
        assert_eq!(fetch_row(&conn, 3).await.quantity, 9);
        assert_eq!(fetch_row(&conn, 4).await.quantity, 9);
    }

    #[tokio::test]
    async fn empty_update_no_sets() {
        // T-16-4 — critical: this must error BEFORE any SQL fires (Pitfall 1).
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;

        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .exec_one(&conn)
            .await
            .expect_err("empty builder must error");
        assert!(matches!(err, GuardedError::EmptyUpdate));

        let err = GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .exec_at_most_one(&conn)
            .await
            .expect_err("empty builder must error in exec_at_most_one too");
        assert!(matches!(err, GuardedError::EmptyUpdate));

        assert_eq!(fetch_row(&conn, 1).await.quantity, 5);
    }

    #[tokio::test]
    async fn multi_column_set_atomic() {
        // T-16-5
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;

        GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .filter(counters::Column::Status.eq("pending"))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(2),
            )
            .set_value(
                counters::Column::Status,
                Value::String(Some(Box::new("committed".to_string()))),
            )
            .exec_one(&conn)
            .await
            .expect("multi-column guarded update");

        let row = fetch_row(&conn, 1).await;
        assert_eq!(row.quantity, 3);
        assert_eq!(row.status, "committed");
    }

    #[tokio::test]
    async fn transaction_rollback() {
        // T-16-6 — exec inside &DatabaseTransaction; rollback rolls back.
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;

        let txn = conn.begin().await.expect("begin transaction");

        GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Id.eq(1))
            .set_expr(
                counters::Column::Quantity,
                Expr::col(counters::Column::Quantity).sub(2),
            )
            .exec_one(&txn)
            .await
            .expect("guarded update inside transaction");

        // Visible inside the transaction…
        assert_eq!(fetch_row(&txn, 1).await.quantity, 3);

        txn.rollback().await.expect("rollback");

        // …and gone after rolling it back.
        assert_eq!(fetch_row(&conn, 1).await.quantity, 5);
    }

    #[tokio::test]
    async fn filter_and_combine() {
        // T-16-7
        let conn = fresh_db().await;
        insert_row(&conn, 1, 5, "pending").await;
        insert_row(&conn, 2, 5, "shipped").await;
        insert_row(&conn, 3, 10, "pending").await;

        GuardedUpdate::new(counters::Entity)
            .filter(counters::Column::Status.eq("pending"))
            .filter(counters::Column::Quantity.eq(5))
            .set_value(
                counters::Column::Status,
                Value::String(Some(Box::new("matched".to_string()))),
            )
            .exec_one(&conn)
            .await
            .expect("AND-filter should match exactly row 1");

        // Only the row satisfying BOTH predicates changed.
        assert_eq!(fetch_row(&conn, 1).await.status, "matched");
        assert_eq!(fetch_row(&conn, 2).await.status, "shipped");
        assert_eq!(fetch_row(&conn, 3).await.status, "pending");
    }
}