Skip to main content

cratestack_sqlx/query/batch/
update.rs

1//! `batch_update` driver — opens the outer tx, fans the items out to
2//! [`super::update_item::run_update_item`] (one savepoint per item),
3//! commits, drains the outbox.
4
5use std::hash::Hash;
6
7use cratestack_core::{BatchResponse, CoolContext, CoolError, ModelEventKind};
8
9use crate::audit::ensure_audit_table;
10use crate::descriptor::ensure_event_outbox_table;
11use crate::{ModelDescriptor, SqlxRuntime, UpdateModelInput, sqlx};
12
13use super::update_item::run_update_item;
14use super::validate::{reject_duplicate_pks, validate_batch_size};
15
16/// One per-item update: `(id, patch, optional expected version)`.
17pub type BatchUpdateItem<PK, I> = (PK, I, Option<i64>);
18
19#[derive(Debug, Clone)]
20pub struct BatchUpdate<'a, M: 'static, PK: 'static, I> {
21    pub(crate) runtime: &'a SqlxRuntime,
22    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
23    pub(crate) items: Vec<BatchUpdateItem<PK, I>>,
24}
25
26impl<'a, M: 'static, PK: 'static, I> BatchUpdate<'a, M, PK, I>
27where
28    I: UpdateModelInput<M> + Send,
29{
30    pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
31    where
32        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
33        PK: Clone
34            + Eq
35            + Hash
36            + Send
37            + sqlx::Type<sqlx::Postgres>
38            + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
39    {
40        validate_batch_size(self.items.len())?;
41        let ids: Vec<PK> = self.items.iter().map(|(id, _, _)| id.clone()).collect();
42        reject_duplicate_pks(&ids)?;
43        if self.items.is_empty() {
44            return Ok(BatchResponse::from_results(vec![]));
45        }
46
47        let emits_event = self.descriptor.emits(ModelEventKind::Updated);
48        let audit_enabled = self.descriptor.audit_enabled;
49
50        let mut tx = self
51            .runtime
52            .pool()
53            .begin()
54            .await
55            .map_err(|error| CoolError::Database(error.to_string()))?;
56        if emits_event {
57            ensure_event_outbox_table(&mut *tx).await?;
58        }
59        if audit_enabled {
60            ensure_audit_table(self.runtime.pool()).await?;
61        }
62
63        let mut per_item: Vec<Result<M, CoolError>> = Vec::with_capacity(self.items.len());
64        for (id, input, if_match) in self.items {
65            let outcome = run_update_item(
66                &mut tx,
67                self.descriptor,
68                id,
69                input,
70                if_match,
71                ctx,
72                emits_event,
73                audit_enabled,
74            )
75            .await?;
76            per_item.push(outcome);
77        }
78
79        tx.commit()
80            .await
81            .map_err(|error| CoolError::Database(error.to_string()))?;
82
83        if emits_event {
84            let _ = self.runtime.drain_event_outbox().await;
85        }
86
87        Ok(BatchResponse::from_results(per_item))
88    }
89}