Skip to main content

cratestack_sqlx/query/write/
update_many.rs

1//! Bulk UPDATE-by-predicate: emit one statement that mutates every
2//! row the filter matches AND the update policy admits, in one
3//! round-trip.
4//!
5//! Differences from per-row `.update(id).set(input)`:
6//!   * No `if_match` slot — bulk updates aren't an optimistic-locking
7//!     idiom. `@version` is auto-incremented for every matched row;
8//!     the caller does NOT supply an expected version.
9//!   * Requires at least one filter — predicate-less bulk updates
10//!     should be raw SQL so the intent is obvious at review.
11
12use cratestack_core::{BatchSummary, CoolContext, CoolError};
13
14use crate::{FilterExpr, ModelDescriptor, SqlxRuntime, UpdateModelInput, sqlx};
15
16use super::preview::render_update_many_preview_sql;
17use super::update_many_exec::run_update_many_in_tx;
18
19#[derive(Debug, Clone)]
20pub struct UpdateMany<'a, M: 'static, PK: 'static> {
21    pub(crate) runtime: &'a SqlxRuntime,
22    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
23    pub(crate) filters: Vec<FilterExpr>,
24}
25
26impl<'a, M: 'static, PK: 'static> UpdateMany<'a, M, PK> {
27    pub fn where_(mut self, filter: crate::Filter) -> Self {
28        self.filters.push(FilterExpr::from(filter));
29        self
30    }
31
32    pub fn where_expr(mut self, filter: FilterExpr) -> Self {
33        self.filters.push(filter);
34        self
35    }
36
37    pub fn where_any(mut self, filters: impl IntoIterator<Item = FilterExpr>) -> Self {
38        self.filters.push(FilterExpr::any(filters));
39        self
40    }
41
42    /// Conditionally append a filter — `None` is a no-op.
43    pub fn where_optional<F>(mut self, filter: Option<F>) -> Self
44    where
45        F: Into<FilterExpr>,
46    {
47        if let Some(filter) = filter {
48            self.filters.push(filter.into());
49        }
50        self
51    }
52
53    /// Supply the patch values. Returns a builder ready to `.run(ctx)`.
54    pub fn set<I>(self, input: I) -> UpdateManySet<'a, M, PK, I> {
55        UpdateManySet {
56            runtime: self.runtime,
57            descriptor: self.descriptor,
58            filters: self.filters,
59            input,
60        }
61    }
62}
63
64#[derive(Debug, Clone)]
65pub struct UpdateManySet<'a, M: 'static, PK: 'static, I> {
66    pub(crate) runtime: &'a SqlxRuntime,
67    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
68    pub(crate) filters: Vec<FilterExpr>,
69    pub(crate) input: I,
70}
71
72impl<'a, M: 'static, PK: 'static, I> UpdateManySet<'a, M, PK, I>
73where
74    I: UpdateModelInput<M>,
75{
76    pub fn preview_sql(&self) -> String {
77        let values = self.input.sql_values();
78        let columns: Vec<&str> = values.iter().map(|v| v.column).collect();
79        render_update_many_preview_sql(
80            self.descriptor.table_name,
81            self.descriptor.soft_delete_column.is_some(),
82            self.descriptor.version_column,
83            &columns,
84            &self.descriptor.select_projection(),
85        )
86    }
87
88    /// Returns `BatchSummary { total, ok, err }` where
89    /// `total = ok = rows actually updated` and `err = 0`.
90    /// Statement-level failures surface as the outer `Err`.
91    pub async fn run(self, ctx: &CoolContext) -> Result<BatchSummary, CoolError>
92    where
93        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
94    {
95        let runtime = self.runtime;
96        let descriptor = self.descriptor;
97        let mut tx = runtime
98            .pool()
99            .begin()
100            .await
101            .map_err(|error| CoolError::Database(error.to_string()))?;
102        let (summary, emits_event) = run_update_many_in_tx(
103            &mut tx,
104            runtime.pool(),
105            descriptor,
106            &self.filters,
107            self.input,
108            ctx,
109        )
110        .await?;
111        tx.commit()
112            .await
113            .map_err(|error| CoolError::Database(error.to_string()))?;
114        if emits_event {
115            let _ = runtime.drain_event_outbox().await;
116        }
117        Ok(summary)
118    }
119
120    /// Run inside a caller-supplied transaction. Audit + outbox
121    /// writes land in `tx`; caller commits.
122    pub async fn run_in_tx<'tx>(
123        self,
124        tx: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
125        ctx: &CoolContext,
126    ) -> Result<BatchSummary, CoolError>
127    where
128        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
129    {
130        let (summary, _) = run_update_many_in_tx(
131            tx,
132            self.runtime.pool(),
133            self.descriptor,
134            &self.filters,
135            self.input,
136            ctx,
137        )
138        .await?;
139        Ok(summary)
140    }
141}