Skip to main content

pgorm/
bulk.rs

//! Bulk update and delete operations.
//!
//! This module provides [`SetExpr`], [`UpdateManyBuilder`], and [`DeleteManyBuilder`]
//! for performing batch UPDATE and DELETE operations with type-safe conditions.
//!
//! Both builders refuse to build a statement without a `.filter()` unless `.all_rows()`
//! is called explicitly, guarding against accidental full-table writes.
//!
//! # Example
//! ```ignore
//! use pgorm::prelude::*;
//! use pgorm::SetExpr;
//!
//! // Bulk update
//! let affected = pgorm::sql("users")
//!     .update_many([
//!         SetExpr::set("status", "inactive")?,
//!     ])?
//!     .filter(Condition::lt("last_login", one_year_ago)?)
//!     .execute(&client)
//!     .await?;
//!
//! // Bulk delete
//! let deleted = pgorm::sql("sessions")
//!     .delete_many()?
//!     .filter(Condition::lt("expires_at", now)?)
//!     .execute(&client)
//!     .await?;
//! ```
27
28use crate::builder::WhereExpr;
29use crate::client::{GenericClient, StreamingClient};
30use crate::error::{OrmError, OrmResult};
31use crate::ident::{Ident, IntoIdent};
32use crate::row::FromRow;
33use crate::sql::{FromRowStream, Sql};
34use std::sync::Arc;
35use tokio_postgres::types::ToSql;
36
37// ==================== SetExpr ====================
38
39/// A SET clause expression for bulk updates.
40///
41/// # Example
42/// ```ignore
43/// use pgorm::SetExpr;
44///
45/// // Simple value assignment: SET status = $1
46/// SetExpr::set("status", "inactive")?;
47///
48/// // Increment: SET view_count = view_count + 1
49/// SetExpr::increment("view_count", 1)?;
50///
51/// // Raw SQL expression: SET updated_at = NOW()
52/// SetExpr::raw("updated_at = NOW()");
53/// ```
54pub enum SetExpr {
55    /// `column = $n` (parameterized value)
56    Value {
57        column: Ident,
58        value: Arc<dyn ToSql + Send + Sync>,
59    },
60    /// `column = column + amount` (increment/decrement)
61    Increment { column: Ident, amount: i64 },
62    /// Raw SQL expression (escape hatch), e.g. `"updated_at = NOW()"`
63    Raw(String),
64}
65
66impl SetExpr {
67    /// Create a SET clause that assigns a parameterized value: `col = $n`
68    pub fn set<T: ToSql + Send + Sync + 'static>(
69        column: impl IntoIdent,
70        value: T,
71    ) -> OrmResult<Self> {
72        Ok(SetExpr::Value {
73            column: column.into_ident()?,
74            value: Arc::new(value),
75        })
76    }
77
78    /// Create a SET clause that increments a column: `col = col + amount`
79    ///
80    /// Supports negative values for decrement.
81    pub fn increment(column: impl IntoIdent, amount: i64) -> OrmResult<Self> {
82        Ok(SetExpr::Increment {
83            column: column.into_ident()?,
84            amount,
85        })
86    }
87
88    /// Create a SET clause with a raw SQL expression.
89    ///
90    /// The string should be a complete assignment expression, e.g. `"updated_at = NOW()"`.
91    ///
92    /// **Warning**: This bypasses SQL injection protection. Only use with trusted SQL.
93    pub fn raw(expr: impl Into<String>) -> Self {
94        SetExpr::Raw(expr.into())
95    }
96
97    fn append_to_sql(&self, sql: &mut Sql) {
98        match self {
99            SetExpr::Value { column, value } => {
100                sql.push_ident_ref(column);
101                sql.push(" = ");
102                sql.push_bind_value(value.clone());
103            }
104            SetExpr::Increment { column, amount } => {
105                sql.push_ident_ref(column);
106                sql.push(" = ");
107                sql.push_ident_ref(column);
108                if *amount >= 0 {
109                    let s = format!(" + {amount}");
110                    sql.push(&s);
111                } else {
112                    let s = format!(" - {}", amount.abs());
113                    sql.push(&s);
114                }
115            }
116            SetExpr::Raw(expr) => {
117                sql.push(expr);
118            }
119        }
120    }
121}
122
123// ==================== UpdateManyBuilder ====================
124
125/// Builder for bulk UPDATE operations.
126///
127/// Created via [`Sql::update_many`].
128///
129/// # Example
130/// ```ignore
131/// pgorm::sql("users")
132///     .update_many([
133///         SetExpr::set("status", "inactive")?,
134///     ])
135///     .filter(Condition::lt("last_login", one_year_ago)?)
136///     .execute(&client)
137///     .await?;
138/// ```
139#[must_use]
140pub struct UpdateManyBuilder {
141    pub(crate) table: Ident,
142    pub(crate) sets: Vec<SetExpr>,
143    pub(crate) where_clause: Option<WhereExpr>,
144    pub(crate) all_rows: bool,
145}
146
147impl UpdateManyBuilder {
148    /// Add a WHERE condition.
149    pub fn filter(mut self, condition: impl Into<WhereExpr>) -> Self {
150        let new_where = condition.into();
151        self.where_clause = Some(match self.where_clause.take() {
152            Some(existing) => existing.and_with(new_where),
153            None => new_where,
154        });
155        self
156    }
157
158    /// Explicitly allow updating all rows without a WHERE clause.
159    ///
160    /// Without this, executing without a `.filter()` returns an error.
161    pub fn all_rows(mut self) -> Self {
162        self.all_rows = true;
163        self
164    }
165
166    /// Build the SQL statement without executing it.
167    ///
168    /// Useful for inspecting the generated SQL.
169    pub fn build_sql(&self) -> OrmResult<Sql> {
170        if self.where_clause.is_none() && !self.all_rows {
171            return Err(OrmError::Validation(
172                "update_many requires a .filter() condition or .all_rows() to proceed. \
173                 This prevents accidental full-table updates."
174                    .to_string(),
175            ));
176        }
177
178        let mut sql = Sql::new("UPDATE ");
179        sql.push_ident_ref(&self.table);
180        sql.push(" SET ");
181
182        for (i, set) in self.sets.iter().enumerate() {
183            if i > 0 {
184                sql.push(", ");
185            }
186            set.append_to_sql(&mut sql);
187        }
188
189        if let Some(ref where_clause) = self.where_clause {
190            sql.push(" WHERE ");
191            where_clause.append_to_sql(&mut sql);
192        }
193
194        Ok(sql)
195    }
196
197    /// Execute the update, returning the number of affected rows.
198    pub async fn execute(self, conn: &impl GenericClient) -> OrmResult<u64> {
199        let sql = self.build_sql()?;
200        sql.execute(conn).await
201    }
202
203    /// Execute the update and return the affected rows.
204    ///
205    /// Appends `RETURNING *` to the query.
206    pub async fn returning<T: FromRow>(self, conn: &impl GenericClient) -> OrmResult<Vec<T>> {
207        let mut sql = self.build_sql()?;
208        sql.push(" RETURNING *");
209        sql.fetch_all_as(conn).await
210    }
211
212    /// Execute the update and return a stream of affected rows.
213    ///
214    /// Appends `RETURNING *` to the query.
215    pub async fn returning_stream<T: FromRow>(
216        self,
217        conn: &impl StreamingClient,
218    ) -> OrmResult<FromRowStream<T>> {
219        let mut sql = self.build_sql()?;
220        sql.push(" RETURNING *");
221        sql.stream_as(conn).await
222    }
223}
224
225// ==================== DeleteManyBuilder ====================
226
227/// Builder for bulk DELETE operations.
228///
229/// Created via [`Sql::delete_many`].
230///
231/// # Example
232/// ```ignore
233/// pgorm::sql("sessions")
234///     .delete_many()
235///     .filter(Condition::lt("expires_at", now)?)
236///     .execute(&client)
237///     .await?;
238/// ```
239#[must_use]
240pub struct DeleteManyBuilder {
241    pub(crate) table: Ident,
242    pub(crate) where_clause: Option<WhereExpr>,
243    pub(crate) all_rows: bool,
244}
245
246impl DeleteManyBuilder {
247    /// Add a WHERE condition.
248    pub fn filter(mut self, condition: impl Into<WhereExpr>) -> Self {
249        let new_where = condition.into();
250        self.where_clause = Some(match self.where_clause.take() {
251            Some(existing) => existing.and_with(new_where),
252            None => new_where,
253        });
254        self
255    }
256
257    /// Explicitly allow deleting all rows without a WHERE clause.
258    ///
259    /// Without this, executing without a `.filter()` returns an error.
260    pub fn all_rows(mut self) -> Self {
261        self.all_rows = true;
262        self
263    }
264
265    /// Build the SQL statement without executing it.
266    ///
267    /// Useful for inspecting the generated SQL.
268    pub fn build_sql(&self) -> OrmResult<Sql> {
269        if self.where_clause.is_none() && !self.all_rows {
270            return Err(OrmError::Validation(
271                "delete_many requires a .filter() condition or .all_rows() to proceed. \
272                 This prevents accidental full-table deletes."
273                    .to_string(),
274            ));
275        }
276
277        let mut sql = Sql::new("DELETE FROM ");
278        sql.push_ident_ref(&self.table);
279
280        if let Some(ref where_clause) = self.where_clause {
281            sql.push(" WHERE ");
282            where_clause.append_to_sql(&mut sql);
283        }
284
285        Ok(sql)
286    }
287
288    /// Execute the delete, returning the number of affected rows.
289    pub async fn execute(self, conn: &impl GenericClient) -> OrmResult<u64> {
290        let sql = self.build_sql()?;
291        sql.execute(conn).await
292    }
293
294    /// Execute the delete and return the deleted rows.
295    ///
296    /// Appends `RETURNING *` to the query.
297    pub async fn returning<T: FromRow>(self, conn: &impl GenericClient) -> OrmResult<Vec<T>> {
298        let mut sql = self.build_sql()?;
299        sql.push(" RETURNING *");
300        sql.fetch_all_as(conn).await
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::condition::Condition;

    // ---- fixtures ---------------------------------------------------------

    /// Parse an identifier that is known to be valid in these tests.
    fn ident(name: &str) -> Ident {
        Ident::parse(name).unwrap()
    }

    /// Wrap a single condition as a one-atom WHERE expression.
    fn atom(condition: Condition) -> Option<WhereExpr> {
        Some(WhereExpr::Atom(condition))
    }

    /// Assemble an `UpdateManyBuilder` directly from its fields.
    fn update_builder(
        table: &str,
        sets: Vec<SetExpr>,
        where_clause: Option<WhereExpr>,
        all_rows: bool,
    ) -> UpdateManyBuilder {
        UpdateManyBuilder {
            table: ident(table),
            sets,
            where_clause,
            all_rows,
        }
    }

    /// Assemble a `DeleteManyBuilder` directly from its fields.
    fn delete_builder(
        table: &str,
        where_clause: Option<WhereExpr>,
        all_rows: bool,
    ) -> DeleteManyBuilder {
        DeleteManyBuilder {
            table: ident(table),
            where_clause,
            all_rows,
        }
    }

    // ---- UPDATE, built directly -------------------------------------------

    #[test]
    fn update_many_basic_sql() {
        let builder = update_builder(
            "users",
            vec![SetExpr::set("status", "inactive").unwrap()],
            atom(Condition::eq("active", true).unwrap()),
            false,
        );
        let sql = builder.build_sql().unwrap();
        assert_eq!(sql.to_sql(), "UPDATE users SET status = $1 WHERE active = $2");
        assert_eq!(sql.params_ref().len(), 2);
    }

    #[test]
    fn update_many_multiple_sets() {
        let assignments = vec![
            SetExpr::set("status", "shipped").unwrap(),
            SetExpr::raw("shipped_at = NOW()"),
        ];
        let builder = update_builder(
            "orders",
            assignments,
            atom(Condition::eq("id", 1_i64).unwrap()),
            false,
        );
        let sql = builder.build_sql().unwrap();
        assert_eq!(
            sql.to_sql(),
            "UPDATE orders SET status = $1, shipped_at = NOW() WHERE id = $2"
        );
        assert_eq!(sql.params_ref().len(), 2);
    }

    #[test]
    fn update_many_increment() {
        let builder = update_builder(
            "products",
            vec![SetExpr::increment("view_count", 1).unwrap()],
            atom(Condition::eq("id", 42_i64).unwrap()),
            false,
        );
        let sql = builder.build_sql().unwrap();
        assert_eq!(
            sql.to_sql(),
            "UPDATE products SET view_count = view_count + 1 WHERE id = $1"
        );
        assert_eq!(sql.params_ref().len(), 1);
    }

    #[test]
    fn update_many_decrement() {
        let builder = update_builder(
            "products",
            vec![SetExpr::increment("stock", -5).unwrap()],
            atom(Condition::eq("id", 1_i64).unwrap()),
            false,
        );
        let sql = builder.build_sql().unwrap();
        assert_eq!(sql.to_sql(), "UPDATE products SET stock = stock - 5 WHERE id = $1");
    }

    #[test]
    fn update_many_all_rows() {
        let builder = update_builder(
            "temp_data",
            vec![SetExpr::set("status", "archived").unwrap()],
            None,
            true,
        );
        let sql = builder.build_sql().unwrap();
        assert_eq!(sql.to_sql(), "UPDATE temp_data SET status = $1");
    }

    #[test]
    fn update_many_rejects_no_where() {
        let builder = update_builder(
            "users",
            vec![SetExpr::set("status", "x").unwrap()],
            None,
            false,
        );
        assert!(builder.build_sql().is_err());
    }

    // ---- DELETE, built directly -------------------------------------------

    #[test]
    fn delete_many_basic_sql() {
        let builder = delete_builder(
            "sessions",
            Some(WhereExpr::raw("expires_at < NOW()")),
            false,
        );
        let sql = builder.build_sql().unwrap();
        assert_eq!(sql.to_sql(), "DELETE FROM sessions WHERE expires_at < NOW()");
    }

    #[test]
    fn delete_many_with_condition() {
        let both = WhereExpr::And(vec![
            WhereExpr::Atom(Condition::eq("level", "debug").unwrap()),
            WhereExpr::Atom(Condition::eq("archived", true).unwrap()),
        ]);
        let builder = delete_builder("audit_logs", Some(both), false);
        let sql = builder.build_sql().unwrap();
        assert_eq!(
            sql.to_sql(),
            "DELETE FROM audit_logs WHERE (level = $1 AND archived = $2)"
        );
        assert_eq!(sql.params_ref().len(), 2);
    }

    #[test]
    fn delete_many_all_rows() {
        let builder = delete_builder("temp_data", None, true);
        let sql = builder.build_sql().unwrap();
        assert_eq!(sql.to_sql(), "DELETE FROM temp_data");
    }

    #[test]
    fn delete_many_rejects_no_where() {
        let builder = delete_builder("users", None, false);
        assert!(builder.build_sql().is_err());
    }

    // ---- via the public `sql(...)` entry point ----------------------------

    #[test]
    fn update_many_via_sql_builder() {
        let builder = crate::sql("users")
            .update_many([SetExpr::set("status", "inactive").unwrap()])
            .unwrap()
            .filter(Condition::eq("active", true).unwrap());
        let sql = builder.build_sql().unwrap();
        assert_eq!(sql.to_sql(), "UPDATE users SET status = $1 WHERE active = $2");
    }

    #[test]
    fn delete_many_via_sql_builder() {
        let builder = crate::sql("sessions")
            .delete_many()
            .unwrap()
            .filter(WhereExpr::raw("expires_at < NOW()"));
        let sql = builder.build_sql().unwrap();
        assert_eq!(sql.to_sql(), "DELETE FROM sessions WHERE expires_at < NOW()");
    }

    #[test]
    fn update_many_filter_combines_with_and() {
        let builder = crate::sql("orders")
            .update_many([SetExpr::set("status", "archived").unwrap()])
            .unwrap()
            .filter(Condition::eq("status", "cancelled").unwrap())
            .filter(Condition::eq("archived", false).unwrap());
        let sql = builder.build_sql().unwrap();
        assert_eq!(
            sql.to_sql(),
            "UPDATE orders SET status = $1 WHERE (status = $2 AND archived = $3)"
        );
    }

    #[test]
    fn delete_many_filter_combines_with_and() {
        let builder = crate::sql("logs")
            .delete_many()
            .unwrap()
            .filter(Condition::eq("level", "debug").unwrap())
            .filter(Condition::eq("archived", true).unwrap());
        let sql = builder.build_sql().unwrap();
        assert_eq!(
            sql.to_sql(),
            "DELETE FROM logs WHERE (level = $1 AND archived = $2)"
        );
    }

    // ---- SetExpr validation -----------------------------------------------

    #[test]
    fn set_expr_validates_column_name() {
        assert!(SetExpr::set("valid_column", "value").is_ok());
        for bad in ["1invalid", "has space"] {
            assert!(SetExpr::set(bad, "value").is_err(), "{bad:?} should be rejected");
        }
        assert!(SetExpr::increment("valid_col", 1).is_ok());
        assert!(SetExpr::increment("bad;col", 1).is_err());
    }
}