Skip to main content

cratestack_sqlx/query/batch/
upsert.rs

//! `batch_upsert` driver — rejects batches whose inputs repeat a primary key
//! (a different shape from `batch_update`, because `UpsertModelInput` exposes
//! a PK getter), then fans out to [`super::upsert_item::run_upsert_item`]
//! inside a single transaction.
4
5use cratestack_core::{BatchResponse, CoolContext, CoolError, ModelEventKind};
6
7use crate::audit::ensure_audit_table;
8use crate::descriptor::ensure_event_outbox_table;
9use crate::{ModelDescriptor, SqlValue, SqlxRuntime, UpsertModelInput, sqlx};
10
11use super::upsert_item::run_upsert_item;
12use super::validate::{reject_duplicate_sql_values, validate_batch_size};
13
14#[derive(Debug, Clone)]
15pub struct BatchUpsert<'a, M: 'static, PK: 'static, I> {
16    pub(crate) runtime: &'a SqlxRuntime,
17    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
18    pub(crate) inputs: Vec<I>,
19}
20
21impl<'a, M: 'static, PK: 'static, I> BatchUpsert<'a, M, PK, I>
22where
23    I: UpsertModelInput<M>,
24{
25    pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
26    where
27        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
28        PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
29    {
30        validate_batch_size(self.inputs.len())?;
31        // Upsert dedup runs on the per-input primary key — keeps two
32        // callers from both producing batches with the same key and
33        // ending up with surprising "second write wins" semantics.
34        let pks: Vec<SqlValue> = self
35            .inputs
36            .iter()
37            .map(UpsertModelInput::primary_key_value)
38            .collect();
39        reject_duplicate_sql_values(&pks)?;
40        if self.inputs.is_empty() {
41            return Ok(BatchResponse::from_results(vec![]));
42        }
43
44        let emits_created = self.descriptor.emits(ModelEventKind::Created);
45        let emits_updated = self.descriptor.emits(ModelEventKind::Updated);
46        let audit_enabled = self.descriptor.audit_enabled;
47
48        let mut tx = self
49            .runtime
50            .pool()
51            .begin()
52            .await
53            .map_err(|error| CoolError::Database(error.to_string()))?;
54        if emits_created || emits_updated {
55            ensure_event_outbox_table(&mut *tx).await?;
56        }
57        if audit_enabled {
58            ensure_audit_table(self.runtime.pool()).await?;
59        }
60
61        let mut per_item: Vec<Result<M, CoolError>> = Vec::with_capacity(self.inputs.len());
62        for input in self.inputs {
63            let outcome = run_upsert_item(
64                &mut tx,
65                self.runtime.pool(),
66                self.descriptor,
67                input,
68                ctx,
69                emits_created,
70                emits_updated,
71                audit_enabled,
72            )
73            .await?;
74            per_item.push(outcome);
75        }
76
77        tx.commit()
78            .await
79            .map_err(|error| CoolError::Database(error.to_string()))?;
80
81        if emits_created || emits_updated {
82            let _ = self.runtime.drain_event_outbox().await;
83        }
84
85        Ok(BatchResponse::from_results(per_item))
86    }
87}