Skip to main content

cratestack_sqlx/query/
batch.rs

1//! Batch primitives — `batch_get`, `batch_create`, `batch_update`,
2//! `batch_delete`, `batch_upsert`.
3//!
4//! Wire shape is the tRPC-style envelope from `cratestack-core`: every
5//! request returns `Vec<BatchItemResult<M>>` where each item carries an
6//! independent `Ok(M)` or `Err(BatchItemError)`. The outer `Result` is
7//! reserved for whole-batch infrastructure failures (size cap exceeded,
8//! duplicate-input rejection, DB connection lost).
9//!
//! Transactional model: one outer `BEGIN`, with each mutating item running
//! in a nested `SAVEPOINT`. Successful items commit together when the outer
//! transaction commits; per-item failures roll back to their savepoint, so
//! failed items leave no row, no audit row, and no event outbox entry. The
//! non-mutating op (`batch_get`) and the single-statement op
//! (`batch_delete`) don't need savepoints — the WHERE clause already
//! filters out policy-denied / missing rows, and we walk the returned set
//! to produce the per-item envelope.
18//!
19//! Sizing: every request is capped at `BATCH_MAX_ITEMS` (1000) at the
20//! outer guard. Duplicate-input keys are loud-failed at the same guard.
21
22use std::collections::HashMap;
23use std::hash::Hash;
24
25use crate::sqlx;
26// `Acquire::begin` is what gives us a nested transaction (= SAVEPOINT) on
27// `&mut Transaction`. Without it in scope, `.begin()` resolves to the
28// inherent `Transaction::begin` constructor and rustc rightly complains.
29use sqlx_core::acquire::Acquire as _;
30
31use cratestack_core::{
32    AuditOperation, BATCH_MAX_ITEMS, BatchResponse, CoolContext, CoolError, ModelEventKind,
33    find_duplicate_position,
34};
35
36use crate::{
37    CreateModelInput, ModelDescriptor, ModelPrimaryKey, SqlValue, SqlxRuntime, UpdateModelInput,
38    UpsertModelInput,
39    audit::{build_audit_event, enqueue_audit_event, ensure_audit_table, fetch_for_audit},
40    descriptor::{enqueue_event_outbox, ensure_event_outbox_table},
41};
42
43use super::support::{
44    apply_create_defaults, evaluate_create_policies, find_column_value, push_action_policy_query,
45    push_bind_value,
46};
47
48// ───── outer guards ─────────────────────────────────────────────────────────
49
50fn validate_batch_size(len: usize) -> Result<(), CoolError> {
51    if len > BATCH_MAX_ITEMS {
52        return Err(CoolError::Validation(format!(
53            "batch size {len} exceeds maximum of {BATCH_MAX_ITEMS}",
54        )));
55    }
56    Ok(())
57}
58
59fn reject_duplicate_pks<K: Eq + Hash + Clone>(keys: &[K]) -> Result<(), CoolError> {
60    if let Some((first, dup)) = find_duplicate_position(keys.iter().cloned()) {
61        return Err(CoolError::Validation(format!(
62            "duplicate primary key in batch at positions {first} and {dup}",
63        )));
64    }
65    Ok(())
66}
67
68fn reject_duplicate_sql_values(values: &[SqlValue]) -> Result<(), CoolError> {
69    if let Some((first, dup)) = cratestack_sql::find_duplicate_sql_value(values) {
70        return Err(CoolError::Validation(format!(
71            "duplicate primary key in batch at positions {first} and {dup}",
72        )));
73    }
74    Ok(())
75}
76
77// ───── BatchGet ─────────────────────────────────────────────────────────────
78
79#[derive(Debug, Clone)]
80pub struct BatchGet<'a, M: 'static, PK: 'static> {
81    pub(crate) runtime: &'a SqlxRuntime,
82    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
83    pub(crate) ids: Vec<PK>,
84}
85
86impl<'a, M: 'static, PK: 'static> BatchGet<'a, M, PK> {
87    pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
88    where
89        for<'r> M:
90            Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + ModelPrimaryKey<PK>,
91        PK: Clone
92            + Eq
93            + Hash
94            + Send
95            + sqlx::Type<sqlx::Postgres>
96            + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
97    {
98        validate_batch_size(self.ids.len())?;
99        reject_duplicate_pks(&self.ids)?;
100        if self.ids.is_empty() {
101            return Ok(BatchResponse::from_results(vec![]));
102        }
103
104        // Single SELECT with IN-list + read policy + soft-delete filter.
105        let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
106        query.push(self.descriptor.select_projection());
107        query.push(" FROM ").push(self.descriptor.table_name);
108        query.push(" WHERE ");
109        if let Some(col) = self.descriptor.soft_delete_column {
110            query.push(col).push(" IS NULL AND ");
111        }
112        query.push(self.descriptor.primary_key).push(" IN (");
113        for (index, id) in self.ids.iter().enumerate() {
114            if index > 0 {
115                query.push(", ");
116            }
117            query.push_bind(id.clone());
118        }
119        query.push(") AND ");
120        push_action_policy_query(
121            &mut query,
122            self.descriptor.read_allow_policies,
123            self.descriptor.read_deny_policies,
124            ctx,
125        );
126
127        let rows: Vec<M> = query
128            .build_query_as::<M>()
129            .fetch_all(self.runtime.pool())
130            .await
131            .map_err(|error| CoolError::Database(error.to_string()))?;
132
133        // Walk-and-match: pair each input PK back to its row, or NotFound
134        // when the read policy / soft-delete filter excluded it.
135        let mut by_pk: HashMap<PK, M> =
136            rows.into_iter().map(|m| (m.primary_key(), m)).collect();
137        let per_item: Vec<Result<M, CoolError>> = self
138            .ids
139            .into_iter()
140            .map(|id| {
141                by_pk
142                    .remove(&id)
143                    .ok_or_else(|| CoolError::NotFound("no row matched".to_owned()))
144            })
145            .collect();
146
147        Ok(BatchResponse::from_results(per_item))
148    }
149}
150
151// ───── BatchDelete ──────────────────────────────────────────────────────────
152
153#[derive(Debug, Clone)]
154pub struct BatchDelete<'a, M: 'static, PK: 'static> {
155    pub(crate) runtime: &'a SqlxRuntime,
156    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
157    pub(crate) ids: Vec<PK>,
158}
159
160impl<'a, M: 'static, PK: 'static> BatchDelete<'a, M, PK> {
161    pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
162    where
163        for<'r> M: Send
164            + Unpin
165            + sqlx::FromRow<'r, sqlx::postgres::PgRow>
166            + ModelPrimaryKey<PK>
167            + serde::Serialize,
168        PK: Clone
169            + Eq
170            + Hash
171            + Send
172            + sqlx::Type<sqlx::Postgres>
173            + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
174    {
175        validate_batch_size(self.ids.len())?;
176        reject_duplicate_pks(&self.ids)?;
177        if self.ids.is_empty() {
178            return Ok(BatchResponse::from_results(vec![]));
179        }
180
181        let emits_event = self.descriptor.emits(ModelEventKind::Deleted);
182        let audit_enabled = self.descriptor.audit_enabled;
183
184        let mut tx = self
185            .runtime
186            .pool()
187            .begin()
188            .await
189            .map_err(|error| CoolError::Database(error.to_string()))?;
190        if emits_event {
191            ensure_event_outbox_table(&mut *tx).await?;
192        }
193        if audit_enabled {
194            ensure_audit_table(self.runtime.pool()).await?;
195        }
196
197        // Build the DELETE-or-soft-delete statement with the policy
198        // predicate baked into the WHERE.
199        let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("");
200        match self.descriptor.soft_delete_column {
201            Some(col) => {
202                query.push("UPDATE ").push(self.descriptor.table_name);
203                query.push(" SET ").push(col).push(" = NOW()");
204                if let Some(version_col) = self.descriptor.version_column {
205                    query
206                        .push(", ")
207                        .push(version_col)
208                        .push(" = ")
209                        .push(version_col)
210                        .push(" + 1");
211                }
212                query.push(" WHERE ").push(col).push(" IS NULL AND ");
213            }
214            None => {
215                query.push("DELETE FROM ").push(self.descriptor.table_name);
216                query.push(" WHERE ");
217            }
218        }
219        query.push(self.descriptor.primary_key).push(" IN (");
220        for (index, id) in self.ids.iter().enumerate() {
221            if index > 0 {
222                query.push(", ");
223            }
224            query.push_bind(id.clone());
225        }
226        query.push(") AND ");
227        push_action_policy_query(
228            &mut query,
229            self.descriptor.delete_allow_policies,
230            self.descriptor.delete_deny_policies,
231            ctx,
232        );
233        query
234            .push(" RETURNING ")
235            .push(self.descriptor.select_projection());
236
237        let deleted: Vec<M> = query
238            .build_query_as::<M>()
239            .fetch_all(&mut *tx)
240            .await
241            .map_err(|error| CoolError::Database(error.to_string()))?;
242
243        // Fan-out one audit + one outbox entry per actually-deleted row.
244        // The RETURNING row IS the "before" snapshot — DELETE/soft-delete
245        // returns the pre-mutation state.
246        for record in &deleted {
247            if emits_event {
248                enqueue_event_outbox(
249                    &mut *tx,
250                    self.descriptor.schema_name,
251                    ModelEventKind::Deleted,
252                    record,
253                )
254                .await?;
255            }
256            if audit_enabled {
257                let before = serde_json::to_value(record).ok();
258                let event = build_audit_event(
259                    self.descriptor,
260                    AuditOperation::Delete,
261                    before,
262                    None,
263                    ctx,
264                );
265                enqueue_audit_event(&mut *tx, &event).await?;
266            }
267        }
268
269        tx.commit()
270            .await
271            .map_err(|error| CoolError::Database(error.to_string()))?;
272
273        if emits_event {
274            let _ = self.runtime.drain_event_outbox().await;
275        }
276
277        // Walk-and-match: any input id whose row isn't in `deleted` failed
278        // the WHERE clause (already tombstoned, policy denied, or never
279        // existed). All three collapse to NotFound on the wire.
280        let mut by_pk: HashMap<PK, M> =
281            deleted.into_iter().map(|m| (m.primary_key(), m)).collect();
282        let per_item: Vec<Result<M, CoolError>> = self
283            .ids
284            .into_iter()
285            .map(|id| {
286                by_pk
287                    .remove(&id)
288                    .ok_or_else(|| CoolError::NotFound("no row matched".to_owned()))
289            })
290            .collect();
291
292        Ok(BatchResponse::from_results(per_item))
293    }
294}
295
296// ───── BatchCreate ──────────────────────────────────────────────────────────
297
298#[derive(Debug, Clone)]
299pub struct BatchCreate<'a, M: 'static, PK: 'static, I> {
300    pub(crate) runtime: &'a SqlxRuntime,
301    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
302    pub(crate) inputs: Vec<I>,
303}
304
305impl<'a, M: 'static, PK: 'static, I> BatchCreate<'a, M, PK, I>
306where
307    I: CreateModelInput<M> + Send,
308{
309    pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
310    where
311        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
312    {
313        validate_batch_size(self.inputs.len())?;
314        // No PK dedup here — `CreateModelInput` doesn't expose the PK
315        // generically (and server-generated PKs make duplicates impossible).
316        // Client-supplied PK collisions trip the DB uniqueness constraint
317        // and surface per-item as `CoolError::Database`. The right primitive
318        // for idempotent client-PK ingestion is `.batch_upsert(...)`.
319        if self.inputs.is_empty() {
320            return Ok(BatchResponse::from_results(vec![]));
321        }
322
323        let emits_event = self.descriptor.emits(ModelEventKind::Created);
324        let audit_enabled = self.descriptor.audit_enabled;
325
326        let mut tx = self
327            .runtime
328            .pool()
329            .begin()
330            .await
331            .map_err(|error| CoolError::Database(error.to_string()))?;
332        if emits_event {
333            ensure_event_outbox_table(&mut *tx).await?;
334        }
335        if audit_enabled {
336            ensure_audit_table(self.runtime.pool()).await?;
337        }
338
339        let mut per_item: Vec<Result<M, CoolError>> = Vec::with_capacity(self.inputs.len());
340        for input in self.inputs {
341            let outcome = run_create_item(
342                &mut tx,
343                self.runtime.pool(),
344                self.descriptor,
345                input,
346                ctx,
347                emits_event,
348                audit_enabled,
349            )
350            .await?;
351            per_item.push(outcome);
352        }
353
354        tx.commit()
355            .await
356            .map_err(|error| CoolError::Database(error.to_string()))?;
357
358        if emits_event {
359            let _ = self.runtime.drain_event_outbox().await;
360        }
361
362        Ok(BatchResponse::from_results(per_item))
363    }
364}
365
366async fn run_create_item<'tx, M, PK, I>(
367    outer: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
368    policy_pool: &sqlx::PgPool,
369    descriptor: &'static ModelDescriptor<M, PK>,
370    input: I,
371    ctx: &CoolContext,
372    emits_event: bool,
373    audit_enabled: bool,
374) -> Result<Result<M, CoolError>, CoolError>
375where
376    I: CreateModelInput<M>,
377    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
378{
379    let mut item_tx = outer
380        .begin()
381        .await
382        .map_err(|error| CoolError::Database(error.to_string()))?;
383
384    // All per-item failures funnel through this inner closure so the
385    // savepoint commit/rollback decision is centralized below.
386    let inner: Result<M, CoolError> = async {
387        input.validate()?;
388        let mut values = apply_create_defaults(input.sql_values(), descriptor.create_defaults, ctx)?;
389        if let Some(version_col) = descriptor.version_column
390            && find_column_value(&values, version_col).is_none()
391        {
392            values.push(crate::SqlColumnValue {
393                column: version_col,
394                value: crate::SqlValue::Int(0),
395            });
396        }
397        if values.is_empty() {
398            return Err(CoolError::Validation(
399                "create input must contain at least one column".to_owned(),
400            ));
401        }
402        if !evaluate_create_policies(
403            policy_pool,
404            descriptor.create_allow_policies,
405            descriptor.create_deny_policies,
406            &values,
407            ctx,
408        )
409        .await?
410        {
411            return Err(CoolError::Forbidden(
412                "create policy denied this operation".to_owned(),
413            ));
414        }
415
416        let record = insert_one_into_savepoint::<M, PK>(&mut item_tx, descriptor, &values).await?;
417
418        if emits_event {
419            enqueue_event_outbox(
420                &mut *item_tx,
421                descriptor.schema_name,
422                ModelEventKind::Created,
423                &record,
424            )
425            .await?;
426        }
427        if audit_enabled {
428            let after = serde_json::to_value(&record).ok();
429            let event =
430                build_audit_event(descriptor, AuditOperation::Create, None, after, ctx);
431            enqueue_audit_event(&mut *item_tx, &event).await?;
432        }
433        Ok(record)
434    }
435    .await;
436
437    match inner {
438        Ok(record) => {
439            item_tx
440                .commit()
441                .await
442                .map_err(|error| CoolError::Database(error.to_string()))?;
443            Ok(Ok(record))
444        }
445        Err(item_err) => {
446            // ROLLBACK TO SAVEPOINT brings the outer tx back to its
447            // pre-savepoint state. If THAT fails, the outer tx is dead and
448            // we propagate as the outer `Result::Err` — no point trying to
449            // continue.
450            item_tx
451                .rollback()
452                .await
453                .map_err(|error| CoolError::Database(error.to_string()))?;
454            Ok(Err(item_err))
455        }
456    }
457}
458
459async fn insert_one_into_savepoint<'tx, M, PK>(
460    executor: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
461    descriptor: &'static ModelDescriptor<M, PK>,
462    values: &[crate::SqlColumnValue],
463) -> Result<M, CoolError>
464where
465    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
466{
467    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("INSERT INTO ");
468    query.push(descriptor.table_name).push(" (");
469    for (index, value) in values.iter().enumerate() {
470        if index > 0 {
471            query.push(", ");
472        }
473        query.push(value.column);
474    }
475    query.push(") VALUES (");
476    for (index, value) in values.iter().enumerate() {
477        if index > 0 {
478            query.push(", ");
479        }
480        push_bind_value(&mut query, &value.value);
481    }
482    query
483        .push(") RETURNING ")
484        .push(descriptor.select_projection());
485
486    query
487        .build_query_as::<M>()
488        .fetch_one(&mut **executor)
489        .await
490        .map_err(|error| classify_insert_error(error))
491}
492
493/// Map a sqlx error from a per-item INSERT into the right `CoolError`
494/// variant. Unique-constraint violations become `Conflict` so the envelope
495/// surfaces the right code; everything else stays `Database`.
496fn classify_insert_error(error: sqlx::Error) -> CoolError {
497    if let sqlx::Error::Database(db_err) = &error
498        && let Some(code) = db_err.code()
499        && code == "23505"
500    {
501        return CoolError::Conflict(db_err.message().to_owned());
502    }
503    CoolError::Database(error.to_string())
504}
505
506// ───── BatchUpdate ──────────────────────────────────────────────────────────
507
508/// One per-item update: `(id, patch, optional expected version)`.
509pub type BatchUpdateItem<PK, I> = (PK, I, Option<i64>);
510
511#[derive(Debug, Clone)]
512pub struct BatchUpdate<'a, M: 'static, PK: 'static, I> {
513    pub(crate) runtime: &'a SqlxRuntime,
514    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
515    pub(crate) items: Vec<BatchUpdateItem<PK, I>>,
516}
517
518impl<'a, M: 'static, PK: 'static, I> BatchUpdate<'a, M, PK, I>
519where
520    I: UpdateModelInput<M> + Send,
521{
522    pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
523    where
524        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
525        PK: Clone
526            + Eq
527            + Hash
528            + Send
529            + sqlx::Type<sqlx::Postgres>
530            + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
531    {
532        validate_batch_size(self.items.len())?;
533        let ids: Vec<PK> = self.items.iter().map(|(id, _, _)| id.clone()).collect();
534        reject_duplicate_pks(&ids)?;
535        if self.items.is_empty() {
536            return Ok(BatchResponse::from_results(vec![]));
537        }
538
539        let emits_event = self.descriptor.emits(ModelEventKind::Updated);
540        let audit_enabled = self.descriptor.audit_enabled;
541
542        let mut tx = self
543            .runtime
544            .pool()
545            .begin()
546            .await
547            .map_err(|error| CoolError::Database(error.to_string()))?;
548        if emits_event {
549            ensure_event_outbox_table(&mut *tx).await?;
550        }
551        if audit_enabled {
552            ensure_audit_table(self.runtime.pool()).await?;
553        }
554
555        let mut per_item: Vec<Result<M, CoolError>> = Vec::with_capacity(self.items.len());
556        for (id, input, if_match) in self.items {
557            let outcome = run_update_item(
558                &mut tx,
559                self.descriptor,
560                id,
561                input,
562                if_match,
563                ctx,
564                emits_event,
565                audit_enabled,
566            )
567            .await?;
568            per_item.push(outcome);
569        }
570
571        tx.commit()
572            .await
573            .map_err(|error| CoolError::Database(error.to_string()))?;
574
575        if emits_event {
576            let _ = self.runtime.drain_event_outbox().await;
577        }
578
579        Ok(BatchResponse::from_results(per_item))
580    }
581}
582
583async fn run_update_item<'tx, M, PK, I>(
584    outer: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
585    descriptor: &'static ModelDescriptor<M, PK>,
586    id: PK,
587    input: I,
588    if_match: Option<i64>,
589    ctx: &CoolContext,
590    emits_event: bool,
591    audit_enabled: bool,
592) -> Result<Result<M, CoolError>, CoolError>
593where
594    I: UpdateModelInput<M>,
595    PK: Clone + Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
596    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
597{
598    let mut item_tx = outer
599        .begin()
600        .await
601        .map_err(|error| CoolError::Database(error.to_string()))?;
602
603    let inner: Result<M, CoolError> = async {
604        if descriptor.version_column.is_some() && if_match.is_none() {
605            return Err(CoolError::PreconditionFailed(
606                "If-Match required for versioned model".to_owned(),
607            ));
608        }
609        input.validate()?;
610        let values = input.sql_values();
611        if values.is_empty() {
612            return Err(CoolError::Validation(
613                "update input must contain at least one changed column".to_owned(),
614            ));
615        }
616
617        // Capture before-snapshot under FOR UPDATE for clean audit timing.
618        let before = if audit_enabled {
619            fetch_for_audit(&mut *item_tx, descriptor, id.clone()).await?
620        } else {
621            None
622        };
623
624        let record = update_one_in_savepoint(
625            &mut item_tx,
626            descriptor,
627            id,
628            &values,
629            ctx,
630            if_match,
631        )
632        .await?;
633
634        if emits_event {
635            enqueue_event_outbox(
636                &mut *item_tx,
637                descriptor.schema_name,
638                ModelEventKind::Updated,
639                &record,
640            )
641            .await?;
642        }
643        if audit_enabled {
644            let before_snapshot = before.as_ref().and_then(|m| serde_json::to_value(m).ok());
645            let after = serde_json::to_value(&record).ok();
646            let event = build_audit_event(
647                descriptor,
648                AuditOperation::Update,
649                before_snapshot,
650                after,
651                ctx,
652            );
653            enqueue_audit_event(&mut *item_tx, &event).await?;
654        }
655        Ok(record)
656    }
657    .await;
658
659    match inner {
660        Ok(record) => {
661            item_tx
662                .commit()
663                .await
664                .map_err(|error| CoolError::Database(error.to_string()))?;
665            Ok(Ok(record))
666        }
667        Err(item_err) => {
668            item_tx
669                .rollback()
670                .await
671                .map_err(|error| CoolError::Database(error.to_string()))?;
672            Ok(Err(item_err))
673        }
674    }
675}
676
677async fn update_one_in_savepoint<'tx, M, PK>(
678    executor: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
679    descriptor: &'static ModelDescriptor<M, PK>,
680    id: PK,
681    values: &[crate::SqlColumnValue],
682    ctx: &CoolContext,
683    if_match: Option<i64>,
684) -> Result<M, CoolError>
685where
686    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
687    PK: Clone + Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
688{
689    let version_column = descriptor.version_column;
690    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("UPDATE ");
691    query.push(descriptor.table_name).push(" SET ");
692    for (index, value) in values.iter().enumerate() {
693        if index > 0 {
694            query.push(", ");
695        }
696        query.push(value.column).push(" = ");
697        push_bind_value(&mut query, &value.value);
698    }
699    if let Some(version_col) = version_column {
700        query
701            .push(", ")
702            .push(version_col)
703            .push(" = ")
704            .push(version_col)
705            .push(" + 1");
706    }
707    query
708        .push(" WHERE ")
709        .push(descriptor.primary_key)
710        .push(" = ");
711    query.push_bind(id);
712    if let (Some(version_col), Some(expected)) = (version_column, if_match) {
713        query.push(" AND ").push(version_col).push(" = ");
714        query.push_bind(expected);
715    }
716    query.push(" AND ");
717    push_action_policy_query(
718        &mut query,
719        descriptor.update_allow_policies,
720        descriptor.update_deny_policies,
721        ctx,
722    );
723    query
724        .push(" RETURNING ")
725        .push(descriptor.select_projection());
726
727    let outcome = query
728        .build_query_as::<M>()
729        .fetch_optional(&mut **executor)
730        .await
731        .map_err(|error| CoolError::Database(error.to_string()))?;
732    match outcome {
733        Some(record) => Ok(record),
734        None => {
735            // Could be: row missing, policy denied, version mismatch, soft-
736            // deleted. Probing to discriminate adds round-trips; we report
737            // Forbidden for batches (matches single-update behavior) when
738            // there's no `if_match`, and PreconditionFailed when there is.
739            // Either way the caller's recovery is the same: refetch & retry.
740            if if_match.is_some() {
741                Err(CoolError::PreconditionFailed(
742                    "version mismatch or row missing".to_owned(),
743                ))
744            } else {
745                Err(CoolError::Forbidden(
746                    "update policy denied or row missing".to_owned(),
747                ))
748            }
749        }
750    }
751}
752
753// ───── BatchUpsert ──────────────────────────────────────────────────────────
754
755#[derive(Debug, Clone)]
756pub struct BatchUpsert<'a, M: 'static, PK: 'static, I> {
757    pub(crate) runtime: &'a SqlxRuntime,
758    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
759    pub(crate) inputs: Vec<I>,
760}
761
762impl<'a, M: 'static, PK: 'static, I> BatchUpsert<'a, M, PK, I>
763where
764    I: UpsertModelInput<M>,
765{
766    pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
767    where
768        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
769        PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
770    {
771        validate_batch_size(self.inputs.len())?;
772        // Upsert dedup runs on the per-input primary key — this is what
773        // keeps two callers from both producing batches with the same key
774        // and ending up with surprising "second write wins" semantics.
775        let pks: Vec<SqlValue> = self
776            .inputs
777            .iter()
778            .map(UpsertModelInput::primary_key_value)
779            .collect();
780        reject_duplicate_sql_values(&pks)?;
781        if self.inputs.is_empty() {
782            return Ok(BatchResponse::from_results(vec![]));
783        }
784
785        let emits_created = self.descriptor.emits(ModelEventKind::Created);
786        let emits_updated = self.descriptor.emits(ModelEventKind::Updated);
787        let audit_enabled = self.descriptor.audit_enabled;
788
789        let mut tx = self
790            .runtime
791            .pool()
792            .begin()
793            .await
794            .map_err(|error| CoolError::Database(error.to_string()))?;
795        if emits_created || emits_updated {
796            ensure_event_outbox_table(&mut *tx).await?;
797        }
798        if audit_enabled {
799            ensure_audit_table(self.runtime.pool()).await?;
800        }
801
802        let mut per_item: Vec<Result<M, CoolError>> = Vec::with_capacity(self.inputs.len());
803        for input in self.inputs {
804            let outcome = run_upsert_item(
805                &mut tx,
806                self.runtime.pool(),
807                self.descriptor,
808                input,
809                ctx,
810                emits_created,
811                emits_updated,
812                audit_enabled,
813            )
814            .await?;
815            per_item.push(outcome);
816        }
817
818        tx.commit()
819            .await
820            .map_err(|error| CoolError::Database(error.to_string()))?;
821
822        if emits_created || emits_updated {
823            let _ = self.runtime.drain_event_outbox().await;
824        }
825
826        Ok(BatchResponse::from_results(per_item))
827    }
828}
829
830#[allow(clippy::too_many_arguments)]
831async fn run_upsert_item<'tx, M, PK, I>(
832    outer: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
833    policy_pool: &sqlx::PgPool,
834    descriptor: &'static ModelDescriptor<M, PK>,
835    input: I,
836    ctx: &CoolContext,
837    emits_created: bool,
838    emits_updated: bool,
839    audit_enabled: bool,
840) -> Result<Result<M, CoolError>, CoolError>
841where
842    I: UpsertModelInput<M>,
843    PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
844    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
845{
846    let mut item_tx = outer
847        .begin()
848        .await
849        .map_err(|error| CoolError::Database(error.to_string()))?;
850
851    let inner: Result<M, CoolError> = async {
852        input.validate()?;
853        let mut insert_values =
854            apply_create_defaults(input.sql_values(), descriptor.create_defaults, ctx)?;
855        if let Some(version_col) = descriptor.version_column
856            && find_column_value(&insert_values, version_col).is_none()
857        {
858            insert_values.push(crate::SqlColumnValue {
859                column: version_col,
860                value: crate::SqlValue::Int(0),
861            });
862        }
863        if insert_values.is_empty() {
864            return Err(CoolError::Validation(
865                "upsert input must contain at least one column".to_owned(),
866            ));
867        }
868        if !evaluate_create_policies(
869            policy_pool,
870            descriptor.create_allow_policies,
871            descriptor.create_deny_policies,
872            &insert_values,
873            ctx,
874        )
875        .await?
876        {
877            return Err(CoolError::Forbidden(
878                "create policy denied this upsert".to_owned(),
879            ));
880        }
881
882        let pk_value = input.primary_key_value();
883        // Probe under FOR UPDATE so the audit before-snapshot is consistent
884        // with the row state at the moment of the upsert.
885        let before_record =
886            select_for_update_by_pk_value(&mut item_tx, descriptor, &pk_value).await?;
887        let inserted = before_record.is_none();
888
889        if !inserted
890            && !row_passes_update_policy(policy_pool, descriptor, &pk_value, ctx).await?
891        {
892            return Err(CoolError::Forbidden(
893                "update policy denied this upsert".to_owned(),
894            ));
895        }
896
897        let before_snapshot = if !inserted && audit_enabled {
898            before_record
899                .as_ref()
900                .and_then(|m| serde_json::to_value(m).ok())
901        } else {
902            None
903        };
904
905        let record =
906            upsert_one_in_savepoint::<M, PK>(&mut item_tx, descriptor, &insert_values).await?;
907
908        let event_kind = if inserted {
909            ModelEventKind::Created
910        } else {
911            ModelEventKind::Updated
912        };
913        let audit_op = if inserted {
914            AuditOperation::Create
915        } else {
916            AuditOperation::Update
917        };
918        let emits_this_event = if inserted { emits_created } else { emits_updated };
919
920        if emits_this_event {
921            enqueue_event_outbox(&mut *item_tx, descriptor.schema_name, event_kind, &record)
922                .await?;
923        }
924        if audit_enabled {
925            let after = serde_json::to_value(&record).ok();
926            let event = build_audit_event(descriptor, audit_op, before_snapshot, after, ctx);
927            enqueue_audit_event(&mut *item_tx, &event).await?;
928        }
929
930        Ok(record)
931    }
932    .await;
933
934    match inner {
935        Ok(record) => {
936            item_tx
937                .commit()
938                .await
939                .map_err(|error| CoolError::Database(error.to_string()))?;
940            Ok(Ok(record))
941        }
942        Err(item_err) => {
943            item_tx
944                .rollback()
945                .await
946                .map_err(|error| CoolError::Database(error.to_string()))?;
947            Ok(Err(item_err))
948        }
949    }
950}
951
952async fn select_for_update_by_pk_value<'tx, M, PK>(
953    executor: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
954    descriptor: &'static ModelDescriptor<M, PK>,
955    pk_value: &SqlValue,
956) -> Result<Option<M>, CoolError>
957where
958    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
959{
960    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
961    query.push(descriptor.select_projection());
962    query.push(" FROM ").push(descriptor.table_name);
963    query
964        .push(" WHERE ")
965        .push(descriptor.primary_key)
966        .push(" = ");
967    push_bind_value(&mut query, pk_value);
968    if let Some(col) = descriptor.soft_delete_column {
969        query.push(" AND ").push(col).push(" IS NULL");
970    }
971    query.push(" FOR UPDATE");
972
973    query
974        .build_query_as::<M>()
975        .fetch_optional(&mut **executor)
976        .await
977        .map_err(|error| CoolError::Database(error.to_string()))
978}
979
980async fn row_passes_update_policy<M, PK>(
981    policy_pool: &sqlx::PgPool,
982    descriptor: &'static ModelDescriptor<M, PK>,
983    pk_value: &SqlValue,
984    ctx: &CoolContext,
985) -> Result<bool, CoolError> {
986    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT 1 FROM ");
987    query.push(descriptor.table_name);
988    query
989        .push(" WHERE ")
990        .push(descriptor.primary_key)
991        .push(" = ");
992    push_bind_value(&mut query, pk_value);
993    query.push(" AND ");
994    push_action_policy_query(
995        &mut query,
996        descriptor.update_allow_policies,
997        descriptor.update_deny_policies,
998        ctx,
999    );
1000
1001    let row: Option<(i32,)> = query
1002        .build_query_as::<(i32,)>()
1003        .fetch_optional(policy_pool)
1004        .await
1005        .map_err(|error| CoolError::Database(error.to_string()))?;
1006    Ok(row.is_some())
1007}
1008
1009async fn upsert_one_in_savepoint<'tx, M, PK>(
1010    executor: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
1011    descriptor: &'static ModelDescriptor<M, PK>,
1012    insert_values: &[crate::SqlColumnValue],
1013) -> Result<M, CoolError>
1014where
1015    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
1016{
1017    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("INSERT INTO ");
1018    query.push(descriptor.table_name).push(" (");
1019    for (index, value) in insert_values.iter().enumerate() {
1020        if index > 0 {
1021            query.push(", ");
1022        }
1023        query.push(value.column);
1024    }
1025    query.push(") VALUES (");
1026    for (index, value) in insert_values.iter().enumerate() {
1027        if index > 0 {
1028            query.push(", ");
1029        }
1030        push_bind_value(&mut query, &value.value);
1031    }
1032    query
1033        .push(") ON CONFLICT (")
1034        .push(descriptor.primary_key)
1035        .push(") DO UPDATE SET ");
1036
1037    if descriptor.upsert_update_columns.is_empty() {
1038        query.push(descriptor.primary_key);
1039        query.push(" = EXCLUDED.").push(descriptor.primary_key);
1040    } else {
1041        for (index, column) in descriptor.upsert_update_columns.iter().enumerate() {
1042            if index > 0 {
1043                query.push(", ");
1044            }
1045            query.push(*column).push(" = EXCLUDED.").push(*column);
1046        }
1047    }
1048    if let Some(version_col) = descriptor.version_column {
1049        query
1050            .push(", ")
1051            .push(version_col)
1052            .push(" = ")
1053            .push(descriptor.table_name)
1054            .push(".")
1055            .push(version_col)
1056            .push(" + 1");
1057    }
1058
1059    query
1060        .push(" RETURNING ")
1061        .push(descriptor.select_projection());
1062
1063    query
1064        .build_query_as::<M>()
1065        .fetch_one(&mut **executor)
1066        .await
1067        .map_err(|error| CoolError::Database(error.to_string()))
1068}