Skip to main content

cratestack_sqlx/query/
write.rs

1use crate::sqlx;
2
3use cratestack_core::{AuditOperation, CoolContext, CoolError, ModelEventKind};
4
5use crate::{
6    CreateModelInput, ModelDescriptor, SqlValue, SqlxRuntime, UpdateModelInput, UpsertModelInput,
7    audit::{build_audit_event, enqueue_audit_event, ensure_audit_table, fetch_for_audit},
8    descriptor::{enqueue_event_outbox, ensure_event_outbox_table},
9};
10
11use super::support::{
12    apply_create_defaults, evaluate_create_policies, find_column_value, push_action_policy_query,
13    push_bind_value,
14};
15
/// Render the SQL string for an update. Pure helper, no I/O — separated
/// so the version-aware branch can be unit-tested without a runtime.
///
/// Placeholders are numbered in bind order:
/// - `$1..=$n` for `columns`;
/// - `$n+1` for the primary key;
/// - `$n+2` for the expected version, when `version_column` is set.
///
/// The version-aware form also appends `<version> = <version> + 1` to the
/// `SET` list. The authorization predicate that the executed update appends
/// to its `WHERE` clause is not part of this preview.
pub fn render_update_preview_sql(
    table_name: &str,
    primary_key: &str,
    version_column: Option<&str>,
    columns: &[&str],
    select_projection: &str,
) -> String {
    let pk_placeholder = columns.len() + 1;

    // Build the whole SET list (column assignments, then the optimistic-lock
    // bump) before joining. An empty `columns` slice on a versioned model
    // then renders `SET v = v + 1` instead of `SET , v = v + 1`.
    let mut assignments: Vec<String> = columns
        .iter()
        .enumerate()
        .map(|(index, column)| format!("{column} = ${}", index + 1))
        .collect();
    if let Some(version_col) = version_column {
        assignments.push(format!("{version_col} = {version_col} + 1"));
    }
    let assignments = assignments.join(", ");

    match version_column {
        Some(version_col) => format!(
            "UPDATE {table_name} SET {assignments} WHERE {primary_key} = ${pk_placeholder} \
             AND {version_col} = ${} RETURNING {select_projection}",
            pk_placeholder + 1,
        ),
        None => format!(
            "UPDATE {table_name} SET {assignments} WHERE {primary_key} = ${pk_placeholder} \
             RETURNING {select_projection}",
        ),
    }
}
55
/// Pending `INSERT` of one `M` row built from the create input `I`.
///
/// Building this value performs no I/O. `preview_sql` renders the statement
/// and `run` executes it.
#[derive(Debug, Clone)]
pub struct CreateRecord<'a, M: 'static, PK: 'static, I> {
    /// Runtime whose pool supplies connections; `run` also drains its event outbox.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static model metadata (table name, projection, policies, create defaults).
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Caller-supplied column values for the new row.
    pub(crate) input: I,
}
62
63impl<'a, M: 'static, PK: 'static, I> CreateRecord<'a, M, PK, I>
64where
65    I: CreateModelInput<M>,
66{
67    pub fn preview_sql(&self) -> String {
68        let values = self.input.sql_values();
69        let placeholders = (1..=values.len())
70            .map(|index| format!("${index}"))
71            .collect::<Vec<_>>()
72            .join(", ");
73        let columns = values
74            .iter()
75            .map(|value| value.column)
76            .collect::<Vec<_>>()
77            .join(", ");
78
79        format!(
80            "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
81            self.descriptor.table_name,
82            columns,
83            placeholders,
84            self.descriptor.select_projection(),
85        )
86    }
87
88    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
89    where
90        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
91    {
92        let emits_event = self.descriptor.emits(ModelEventKind::Created);
93        let audit_enabled = self.descriptor.audit_enabled;
94        let needs_tx = emits_event || audit_enabled;
95        let record = if needs_tx {
96            let mut tx = self
97                .runtime
98                .pool()
99                .begin()
100                .await
101                .map_err(|error| CoolError::Database(error.to_string()))?;
102            if emits_event {
103                ensure_event_outbox_table(&mut *tx).await?;
104            }
105            if audit_enabled {
106                ensure_audit_table(self.runtime.pool()).await?;
107            }
108            let record = create_record_with_executor(
109                &mut *tx,
110                self.runtime.pool(),
111                self.descriptor,
112                self.input,
113                ctx,
114            )
115            .await?;
116            if emits_event {
117                enqueue_event_outbox(
118                    &mut *tx,
119                    self.descriptor.schema_name,
120                    ModelEventKind::Created,
121                    &record,
122                )
123                .await?;
124            }
125            if audit_enabled {
126                let after = serde_json::to_value(&record).ok();
127                let event =
128                    build_audit_event(self.descriptor, AuditOperation::Create, None, after, ctx);
129                enqueue_audit_event(&mut *tx, &event).await?;
130            }
131            tx.commit()
132                .await
133                .map_err(|error| CoolError::Database(error.to_string()))?;
134            record
135        } else {
136            create_record_with_executor(
137                self.runtime.pool(),
138                self.runtime.pool(),
139                self.descriptor,
140                self.input,
141                ctx,
142            )
143            .await?
144        };
145
146        if emits_event {
147            let _ = self.runtime.drain_event_outbox().await;
148        }
149
150        Ok(record)
151    }
152}
153
/// First step of an update: the target row, identified by primary key.
///
/// Call `set` with an update input to obtain a runnable `UpdateRecordSet`.
#[derive(Debug, Clone)]
pub struct UpdateRecord<'a, M: 'static, PK: 'static> {
    /// Runtime whose pool supplies connections.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static model metadata (table, primary key, version column, policies).
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Primary-key value of the row to update.
    pub(crate) id: PK,
}
160
161impl<'a, M: 'static, PK: 'static> UpdateRecord<'a, M, PK> {
162    pub fn set<I>(self, input: I) -> UpdateRecordSet<'a, M, PK, I> {
163        UpdateRecordSet {
164            runtime: self.runtime,
165            descriptor: self.descriptor,
166            id: self.id,
167            input,
168            if_match: None,
169        }
170    }
171}
172
/// Runnable update: target row, changed columns and optional expected version.
#[derive(Debug, Clone)]
pub struct UpdateRecordSet<'a, M: 'static, PK: 'static, I> {
    /// Runtime whose pool supplies connections; `run` also drains its event outbox.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static model metadata (table, primary key, version column, policies).
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Primary-key value of the row to update.
    pub(crate) id: PK,
    /// Caller-supplied changed column values.
    pub(crate) input: I,
    /// Expected `@version` value set via `if_match`; `None` until then.
    pub(crate) if_match: Option<i64>,
}
181
182impl<'a, M: 'static, PK: 'static, I> UpdateRecordSet<'a, M, PK, I>
183where
184    I: UpdateModelInput<M>,
185{
186    /// Attach an expected version for optimistic locking. The update will only
187    /// succeed if the row's current `@version` field matches `expected`.
188    /// Required on models that declare `@version`; ignored otherwise.
189    pub fn if_match(mut self, expected: i64) -> Self {
190        self.if_match = Some(expected);
191        self
192    }
193
194    pub fn preview_sql(&self) -> String {
195        let values = self.input.sql_values();
196        let columns: Vec<&str> = values.iter().map(|v| v.column).collect();
197        render_update_preview_sql(
198            self.descriptor.table_name,
199            self.descriptor.primary_key,
200            self.descriptor.version_column,
201            &columns,
202            &self.descriptor.select_projection(),
203        )
204    }
205
206    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
207    where
208        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
209        PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
210    {
211        if self.descriptor.version_column.is_some() && self.if_match.is_none() {
212            return Err(CoolError::PreconditionFailed(
213                "If-Match header required for versioned model".to_owned(),
214            ));
215        }
216        let emits_event = self.descriptor.emits(ModelEventKind::Updated);
217        let audit_enabled = self.descriptor.audit_enabled;
218        let needs_tx = emits_event || audit_enabled;
219        let record = if needs_tx {
220            let mut tx = self
221                .runtime
222                .pool()
223                .begin()
224                .await
225                .map_err(|error| CoolError::Database(error.to_string()))?;
226            if emits_event {
227                ensure_event_outbox_table(&mut *tx).await?;
228            }
229            if audit_enabled {
230                ensure_audit_table(self.runtime.pool()).await?;
231            }
232            // Capture the BEFORE snapshot under a row-level lock so concurrent
233            // mutations can't race the audit.
234            let before_record = if audit_enabled {
235                fetch_for_audit(&mut *tx, self.descriptor, self.id.clone()).await?
236            } else {
237                None
238            };
239            let before_snapshot = before_record
240                .as_ref()
241                .and_then(|m| serde_json::to_value(m).ok());
242            let record = update_record_with_executor(
243                &mut *tx,
244                self.runtime.pool(),
245                self.descriptor,
246                self.id,
247                self.input,
248                ctx,
249                self.if_match,
250            )
251            .await?;
252            if emits_event {
253                enqueue_event_outbox(
254                    &mut *tx,
255                    self.descriptor.schema_name,
256                    ModelEventKind::Updated,
257                    &record,
258                )
259                .await?;
260            }
261            if audit_enabled {
262                let after = serde_json::to_value(&record).ok();
263                let event = build_audit_event(
264                    self.descriptor,
265                    AuditOperation::Update,
266                    before_snapshot,
267                    after,
268                    ctx,
269                );
270                enqueue_audit_event(&mut *tx, &event).await?;
271            }
272            tx.commit()
273                .await
274                .map_err(|error| CoolError::Database(error.to_string()))?;
275            record
276        } else {
277            update_record_with_executor(
278                self.runtime.pool(),
279                self.runtime.pool(),
280                self.descriptor,
281                self.id,
282                self.input,
283                ctx,
284                self.if_match,
285            )
286            .await?
287        };
288
289        if emits_event {
290            let _ = self.runtime.drain_event_outbox().await;
291        }
292
293        Ok(record)
294    }
295}
296
/// Pending delete of the row with primary key `id`.
///
/// Depending on the model, `run` issues either a hard `DELETE` or a
/// soft-delete tombstone `UPDATE`; see `delete_returning_record`.
#[derive(Debug, Clone)]
pub struct DeleteRecord<'a, M: 'static, PK: 'static> {
    /// Runtime whose pool supplies connections; `run` also drains its event outbox.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static model metadata (table, primary key, soft-delete column, policies).
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Primary-key value of the row to delete.
    pub(crate) id: PK,
}
303
304impl<'a, M: 'static, PK: 'static> DeleteRecord<'a, M, PK> {
305    pub fn preview_sql(&self) -> String {
306        format!(
307            "DELETE FROM {} WHERE {} = $1 RETURNING {}",
308            self.descriptor.table_name,
309            self.descriptor.primary_key,
310            self.descriptor.select_projection(),
311        )
312    }
313
314    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
315    where
316        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
317        PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
318    {
319        let emits_event = self.descriptor.emits(ModelEventKind::Deleted);
320        let audit_enabled = self.descriptor.audit_enabled;
321        let needs_tx = emits_event || audit_enabled;
322        let record = if needs_tx {
323            let mut tx = self
324                .runtime
325                .pool()
326                .begin()
327                .await
328                .map_err(|error| CoolError::Database(error.to_string()))?;
329            if emits_event {
330                ensure_event_outbox_table(&mut *tx).await?;
331            }
332            if audit_enabled {
333                ensure_audit_table(self.runtime.pool()).await?;
334            }
335
336            let record = delete_returning_record(&mut *tx, self.descriptor, self.id, ctx).await?;
337            if emits_event {
338                enqueue_event_outbox(
339                    &mut *tx,
340                    self.descriptor.schema_name,
341                    ModelEventKind::Deleted,
342                    &record,
343                )
344                .await?;
345            }
346            if audit_enabled {
347                // DELETE ... RETURNING yields the row's pre-delete state, so
348                // it doubles as the audit `before` snapshot.
349                let before = serde_json::to_value(&record).ok();
350                let event =
351                    build_audit_event(self.descriptor, AuditOperation::Delete, before, None, ctx);
352                enqueue_audit_event(&mut *tx, &event).await?;
353            }
354            tx.commit()
355                .await
356                .map_err(|error| CoolError::Database(error.to_string()))?;
357            record
358        } else {
359            delete_returning_record(self.runtime.pool(), self.descriptor, self.id, ctx).await?
360        };
361
362        if emits_event {
363            let _ = self.runtime.drain_event_outbox().await;
364        }
365
366        Ok(record)
367    }
368}
369
370pub async fn create_record_with_executor<'e, E, M, PK, I>(
371    executor: E,
372    policy_pool: &sqlx::PgPool,
373    descriptor: &'static ModelDescriptor<M, PK>,
374    input: I,
375    ctx: &CoolContext,
376) -> Result<M, CoolError>
377where
378    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
379    I: CreateModelInput<M>,
380    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
381{
382    input.validate()?;
383    let mut values = apply_create_defaults(input.sql_values(), descriptor.create_defaults, ctx)?;
384    // Seed the optimistic-lock column server-side. `@version` is excluded
385    // from the generated Create input so clients can't pick the initial
386    // value, and the column has no SQL `DEFAULT`. If we didn't write it
387    // here, the INSERT would either skip the column (only valid when the
388    // DB-level default is set, which we don't require) or fail under
389    // `NOT NULL`. Done after `apply_create_defaults` so @default-driven
390    // overrides still win if a schema ever lands one.
391    if let Some(version_col) = descriptor.version_column
392        && find_column_value(&values, version_col).is_none()
393    {
394        values.push(crate::SqlColumnValue {
395            column: version_col,
396            value: crate::SqlValue::Int(0),
397        });
398    }
399    if values.is_empty() {
400        return Err(CoolError::Validation(
401            "create input must contain at least one column".to_owned(),
402        ));
403    }
404    if !evaluate_create_policies(
405        policy_pool,
406        descriptor.create_allow_policies,
407        descriptor.create_deny_policies,
408        &values,
409        ctx,
410    )
411    .await?
412    {
413        return Err(CoolError::Forbidden(
414            "create policy denied this operation".to_owned(),
415        ));
416    }
417
418    insert_returning_record(executor, descriptor, &values).await
419}
420
421pub async fn update_record_with_executor<'e, E, M, PK, I>(
422    executor: E,
423    policy_pool: &sqlx::PgPool,
424    descriptor: &'static ModelDescriptor<M, PK>,
425    id: PK,
426    input: I,
427    ctx: &CoolContext,
428    if_match: Option<i64>,
429) -> Result<M, CoolError>
430where
431    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
432    I: UpdateModelInput<M>,
433    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
434    PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
435{
436    input.validate()?;
437    let values = input.sql_values();
438    if values.is_empty() {
439        return Err(CoolError::Validation(
440            "update input must contain at least one changed column".to_owned(),
441        ));
442    }
443
444    update_returning_record(
445        executor,
446        policy_pool,
447        descriptor,
448        id,
449        &values,
450        ctx,
451        if_match,
452    )
453    .await
454}
455
456async fn insert_returning_record<'e, E, M, PK>(
457    executor: E,
458    descriptor: &'static ModelDescriptor<M, PK>,
459    values: &[crate::SqlColumnValue],
460) -> Result<M, CoolError>
461where
462    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
463    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
464{
465    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("INSERT INTO ");
466    query.push(descriptor.table_name).push(" (");
467    for (index, value) in values.iter().enumerate() {
468        if index > 0 {
469            query.push(", ");
470        }
471        query.push(value.column);
472    }
473    query.push(") VALUES (");
474    for (index, value) in values.iter().enumerate() {
475        if index > 0 {
476            query.push(", ");
477        }
478        push_bind_value(&mut query, &value.value);
479    }
480    query
481        .push(") RETURNING ")
482        .push(descriptor.select_projection());
483
484    query
485        .build_query_as::<M>()
486        .fetch_one(executor)
487        .await
488        .map_err(|error| CoolError::Database(error.to_string()))
489}
490
491async fn update_returning_record<'e, E, M, PK>(
492    executor: E,
493    policy_pool: &sqlx::PgPool,
494    descriptor: &'static ModelDescriptor<M, PK>,
495    id: PK,
496    values: &[crate::SqlColumnValue],
497    ctx: &CoolContext,
498    if_match: Option<i64>,
499) -> Result<M, CoolError>
500where
501    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
502    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
503    PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
504{
505    let version_column = descriptor.version_column;
506    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("UPDATE ");
507    query.push(descriptor.table_name).push(" SET ");
508    for (index, value) in values.iter().enumerate() {
509        if index > 0 {
510            query.push(", ");
511        }
512        query.push(value.column).push(" = ");
513        push_bind_value(&mut query, &value.value);
514    }
515    if let Some(version_col) = version_column {
516        query
517            .push(", ")
518            .push(version_col)
519            .push(" = ")
520            .push(version_col)
521            .push(" + 1");
522    }
523    query
524        .push(" WHERE ")
525        .push(descriptor.primary_key)
526        .push(" = ");
527    let id_for_probe = id.clone();
528    query.push_bind(id);
529    if let (Some(version_col), Some(expected)) = (version_column, if_match) {
530        query.push(" AND ").push(version_col).push(" = ");
531        query.push_bind(expected);
532    }
533    query.push(" AND ");
534    push_action_policy_query(
535        &mut query,
536        descriptor.update_allow_policies,
537        descriptor.update_deny_policies,
538        ctx,
539    );
540    query
541        .push(" RETURNING ")
542        .push(descriptor.select_projection());
543
544    let outcome = query
545        .build_query_as::<M>()
546        .fetch_optional(executor)
547        .await
548        .map_err(|error| CoolError::Database(error.to_string()))?;
549    match outcome {
550        Some(record) => Ok(record),
551        None => {
552            // No row matched. If this is a versioned update we want to
553            // distinguish "stale version" from a true policy denial. The
554            // probe applies the read policy: if the caller cannot see the
555            // row, we keep returning Forbidden so policy denials remain
556            // indistinguishable from missing rows.
557            if let (Some(version_col), Some(expected)) = (version_column, if_match) {
558                if let Some(current) =
559                    probe_current_version(policy_pool, descriptor, id_for_probe, version_col, ctx)
560                        .await?
561                {
562                    if current != expected {
563                        return Err(CoolError::PreconditionFailed(format!(
564                            "version mismatch: expected {expected}, found {current}",
565                        )));
566                    }
567                }
568            }
569            Err(CoolError::Forbidden(
570                "update policy denied this operation".to_owned(),
571            ))
572        }
573    }
574}
575
576/// Read the current version of a row using the read policy. Returns `None` if
577/// the caller cannot see the row (so the outer code preserves the existing
578/// Forbidden-on-no-row semantics — readers can't tell a denied row from a
579/// missing one).
580async fn probe_current_version<M, PK>(
581    policy_pool: &sqlx::PgPool,
582    descriptor: &'static ModelDescriptor<M, PK>,
583    id: PK,
584    version_col: &'static str,
585    ctx: &CoolContext,
586) -> Result<Option<i64>, CoolError>
587where
588    PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
589{
590    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
591    query.push(version_col);
592    query.push(" FROM ").push(descriptor.table_name);
593    query
594        .push(" WHERE ")
595        .push(descriptor.primary_key)
596        .push(" = ");
597    query.push_bind(id);
598    query.push(" AND ");
599    push_action_policy_query(
600        &mut query,
601        descriptor.read_allow_policies,
602        descriptor.read_deny_policies,
603        ctx,
604    );
605
606    let row: Option<(i64,)> = query
607        .build_query_as::<(i64,)>()
608        .fetch_optional(policy_pool)
609        .await
610        .map_err(|error| CoolError::Database(error.to_string()))?;
611    Ok(row.map(|(v,)| v))
612}
613
614async fn delete_returning_record<'e, E, M, PK>(
615    executor: E,
616    descriptor: &'static ModelDescriptor<M, PK>,
617    id: PK,
618    ctx: &CoolContext,
619) -> Result<M, CoolError>
620where
621    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
622    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
623    PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
624{
625    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("");
626    match descriptor.soft_delete_column {
627        Some(col) => {
628            // Soft-delete: tombstone the row and bump version (if any) so
629            // optimistic-lock semantics on subsequent updates stay coherent.
630            query.push("UPDATE ").push(descriptor.table_name);
631            query.push(" SET ").push(col).push(" = NOW()");
632            if let Some(version_col) = descriptor.version_column {
633                query
634                    .push(", ")
635                    .push(version_col)
636                    .push(" = ")
637                    .push(version_col)
638                    .push(" + 1");
639            }
640            query.push(" WHERE ");
641            query.push(col).push(" IS NULL AND ");
642            query.push(descriptor.primary_key).push(" = ");
643            query.push_bind(id);
644        }
645        None => {
646            query.push("DELETE FROM ").push(descriptor.table_name);
647            query.push(" WHERE ");
648            query.push(descriptor.primary_key).push(" = ");
649            query.push_bind(id);
650        }
651    }
652    query.push(" AND ");
653    push_action_policy_query(
654        &mut query,
655        descriptor.delete_allow_policies,
656        descriptor.delete_deny_policies,
657        ctx,
658    );
659    query
660        .push(" RETURNING ")
661        .push(descriptor.select_projection());
662
663    query
664        .build_query_as::<M>()
665        .fetch_optional(executor)
666        .await
667        .map_err(|error| CoolError::Database(error.to_string()))?
668        .ok_or_else(|| CoolError::Forbidden("delete policy denied this operation".to_owned()))
669}
670
671// ───── Upsert ──────────────────────────────────────────────────────────────
672//
673// `INSERT … ON CONFLICT (<pk>) DO UPDATE …`, but with the create/update
674// distinction made *before* the SQL runs (via a `SELECT … FOR UPDATE` probe
675// inside the same transaction) so we can:
676//
677//   * pick the right policy slot (both must allow at call time — see [docs])
678//   * emit the correct ModelEventKind (Created vs Updated)
679//   * capture an audit `before` snapshot only on the update branch
680//
681// The upsert is always transactional, regardless of whether the model emits
682// events or has `@@audit`. That's a deliberate cost: one extra round-trip
683// for the SELECT, in exchange for clean event/audit semantics. Upsert is
684// not a hot read path — callers who need raw insert/update throughput
685// should use `.create()` / `.update()` directly.
686
/// Pending insert-or-update keyed on the input's primary-key value.
///
/// `run` always executes inside a transaction; `preview_sql` performs no I/O.
#[derive(Debug, Clone)]
pub struct UpsertRecord<'a, M: 'static, PK: 'static, I> {
    /// Runtime whose pool supplies connections; `run` also drains its event outbox.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static model metadata (table, primary key, upsert update columns, policies).
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Caller-supplied column values, including the primary key.
    pub(crate) input: I,
}
693
694impl<'a, M: 'static, PK: 'static, I> UpsertRecord<'a, M, PK, I>
695where
696    I: UpsertModelInput<M>,
697{
698    /// Render an approximate SQL preview. The actual upsert wraps a
699    /// `SELECT … FOR UPDATE` around the `INSERT … ON CONFLICT`, but this
700    /// preview returns only the conflict-bearing statement — sufficient
701    /// for migration tooling and the schema studio.
702    pub fn preview_sql(&self) -> String {
703        let values = self.input.sql_values();
704        let placeholders = (1..=values.len())
705            .map(|index| format!("${index}"))
706            .collect::<Vec<_>>()
707            .join(", ");
708        let columns = values
709            .iter()
710            .map(|value| value.column)
711            .collect::<Vec<_>>()
712            .join(", ");
713        let update_assignments = self
714            .descriptor
715            .upsert_update_columns
716            .iter()
717            .map(|column| format!("{column} = EXCLUDED.{column}"))
718            .collect::<Vec<_>>()
719            .join(", ");
720        let version_bump = match self.descriptor.version_column {
721            Some(col) => format!(", {col} = {table}.{col} + 1", table = self.descriptor.table_name, col = col),
722            None => String::new(),
723        };
724
725        format!(
726            "INSERT INTO {table} ({columns}) VALUES ({placeholders}) \
727             ON CONFLICT ({pk}) DO UPDATE SET {update_assignments}{version_bump} \
728             RETURNING {projection}",
729            table = self.descriptor.table_name,
730            pk = self.descriptor.primary_key,
731            projection = self.descriptor.select_projection(),
732        )
733    }
734
735    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
736    where
737        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
738        PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
739    {
740        self.input.validate()?;
741
742        // Compose the full insert value set, including auth-derived defaults
743        // and the seeded `@version` column. Mirrors `create_record_with_executor`
744        // so insert-branch semantics stay identical to `.create()`.
745        let mut insert_values =
746            apply_create_defaults(self.input.sql_values(), self.descriptor.create_defaults, ctx)?;
747        if let Some(version_col) = self.descriptor.version_column
748            && find_column_value(&insert_values, version_col).is_none()
749        {
750            insert_values.push(crate::SqlColumnValue {
751                column: version_col,
752                value: crate::SqlValue::Int(0),
753            });
754        }
755        if insert_values.is_empty() {
756            return Err(CoolError::Validation(
757                "upsert input must contain at least one column".to_owned(),
758            ));
759        }
760
761        // Both create and update policies must allow the call. Stricter than
762        // "evaluate the path that runs," but it's the only choice we can make
763        // before knowing which branch will fire — pre-flighting a read just to
764        // pick the policy slot would leak row existence to the caller.
765        if !evaluate_create_policies(
766            self.runtime.pool(),
767            self.descriptor.create_allow_policies,
768            self.descriptor.create_deny_policies,
769            &insert_values,
770            ctx,
771        )
772        .await?
773        {
774            return Err(CoolError::Forbidden(
775                "create policy denied this upsert".to_owned(),
776            ));
777        }
778
779        let pk_value = self.input.primary_key_value();
780        let emits_created = self.descriptor.emits(ModelEventKind::Created);
781        let emits_updated = self.descriptor.emits(ModelEventKind::Updated);
782        let audit_enabled = self.descriptor.audit_enabled;
783
784        let mut tx = self
785            .runtime
786            .pool()
787            .begin()
788            .await
789            .map_err(|error| CoolError::Database(error.to_string()))?;
790
791        if emits_created || emits_updated {
792            ensure_event_outbox_table(&mut *tx).await?;
793        }
794        if audit_enabled {
795            ensure_audit_table(self.runtime.pool()).await?;
796        }
797
798        // Probe the conflict target under a row-level lock. If a row exists,
799        // this is the update branch and we capture the before-snapshot for
800        // audit; otherwise it's the insert branch. The lock serializes
801        // concurrent upserts on the same key, which is what callers expect.
802        let before_record =
803            select_for_update_by_pk_value(&mut *tx, self.descriptor, &pk_value).await?;
804        let inserted = before_record.is_none();
805
806        // For the update branch we additionally have to enforce the *update*
807        // policy. The insert branch already passed `create` above; for the
808        // update branch we evaluate the update policy against the live row
809        // (using its current column values, not the input — that's how
810        // ordinary `.update()` works) by re-running the policy SQL.
811        if !inserted
812            && !row_passes_update_policy(
813                self.runtime.pool(),
814                self.descriptor,
815                &pk_value,
816                ctx,
817            )
818            .await?
819        {
820            return Err(CoolError::Forbidden(
821                "update policy denied this upsert".to_owned(),
822            ));
823        }
824
825        let before_snapshot = if !inserted && audit_enabled {
826            before_record
827                .as_ref()
828                .and_then(|m| serde_json::to_value(m).ok())
829        } else {
830            None
831        };
832
833        let record = upsert_returning_record(&mut *tx, self.descriptor, &insert_values).await?;
834
835        // Event + audit fan-out, driven off whether the SELECT-FOR-UPDATE
836        // saw a row. We don't lean on `xmax = 0`: keeping the discriminator
837        // in the runtime (not the SQL) makes the rusqlite mirror trivial.
838        let event_kind = if inserted {
839            ModelEventKind::Created
840        } else {
841            ModelEventKind::Updated
842        };
843        let audit_op = if inserted {
844            AuditOperation::Create
845        } else {
846            AuditOperation::Update
847        };
848        let emits_event = if inserted { emits_created } else { emits_updated };
849
850        if emits_event {
851            enqueue_event_outbox(
852                &mut *tx,
853                self.descriptor.schema_name,
854                event_kind,
855                &record,
856            )
857            .await?;
858        }
859        if audit_enabled {
860            let after = serde_json::to_value(&record).ok();
861            let event = build_audit_event(self.descriptor, audit_op, before_snapshot, after, ctx);
862            enqueue_audit_event(&mut *tx, &event).await?;
863        }
864
865        tx.commit()
866            .await
867            .map_err(|error| CoolError::Database(error.to_string()))?;
868
869        if emits_event {
870            let _ = self.runtime.drain_event_outbox().await;
871        }
872
873        Ok(record)
874    }
875}
876
877/// Probe-with-lock: `SELECT projection FROM <table> WHERE <pk> = $1 FOR UPDATE`.
878/// Bypasses read policies — we need the raw row to drive insert/update
879/// branching and to capture the audit before-snapshot. Returns `None` when
880/// no row exists (the insert branch).
881async fn select_for_update_by_pk_value<'e, E, M, PK>(
882    executor: E,
883    descriptor: &'static ModelDescriptor<M, PK>,
884    pk_value: &SqlValue,
885) -> Result<Option<M>, CoolError>
886where
887    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
888    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
889{
890    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
891    query.push(descriptor.select_projection());
892    query.push(" FROM ").push(descriptor.table_name);
893    query
894        .push(" WHERE ")
895        .push(descriptor.primary_key)
896        .push(" = ");
897    push_bind_value(&mut query, pk_value);
898    // Soft-deleted rows act as "no row" for upsert purposes: the INSERT
899    // branch will then fail on the PK uniqueness constraint, which is the
900    // right outcome (refuse to silently revive a tombstone).
901    if let Some(col) = descriptor.soft_delete_column {
902        query.push(" AND ").push(col).push(" IS NULL");
903    }
904    query.push(" FOR UPDATE");
905
906    query
907        .build_query_as::<M>()
908        .fetch_optional(executor)
909        .await
910        .map_err(|error| CoolError::Database(error.to_string()))
911}
912
913/// Re-evaluate the update policy against an existing row, using the read
914/// pool so the policy predicates can resolve auth/tenancy. Returns `false`
915/// when the policy denies (or when the row is not visible to the caller,
916/// which we treat as denial — same semantics as ordinary `.update()`).
917async fn row_passes_update_policy<M, PK>(
918    policy_pool: &sqlx::PgPool,
919    descriptor: &'static ModelDescriptor<M, PK>,
920    pk_value: &SqlValue,
921    ctx: &CoolContext,
922) -> Result<bool, CoolError> {
923    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT 1 FROM ");
924    query.push(descriptor.table_name);
925    query
926        .push(" WHERE ")
927        .push(descriptor.primary_key)
928        .push(" = ");
929    push_bind_value(&mut query, pk_value);
930    query.push(" AND ");
931    push_action_policy_query(
932        &mut query,
933        descriptor.update_allow_policies,
934        descriptor.update_deny_policies,
935        ctx,
936    );
937
938    let row: Option<(i32,)> = query
939        .build_query_as::<(i32,)>()
940        .fetch_optional(policy_pool)
941        .await
942        .map_err(|error| CoolError::Database(error.to_string()))?;
943    Ok(row.is_some())
944}
945
946/// Render and execute the conflict-bearing INSERT. The DO UPDATE clause
947/// references only columns in `descriptor.upsert_update_columns` — PK,
948/// `@version`, `@readonly`, `@server_only`, and `@default(...)` columns are
949/// excluded by construction (see `generate_model_descriptor`).
950async fn upsert_returning_record<'e, E, M, PK>(
951    executor: E,
952    descriptor: &'static ModelDescriptor<M, PK>,
953    insert_values: &[crate::SqlColumnValue],
954) -> Result<M, CoolError>
955where
956    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
957    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
958{
959    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("INSERT INTO ");
960    query.push(descriptor.table_name).push(" (");
961    for (index, value) in insert_values.iter().enumerate() {
962        if index > 0 {
963            query.push(", ");
964        }
965        query.push(value.column);
966    }
967    query.push(") VALUES (");
968    for (index, value) in insert_values.iter().enumerate() {
969        if index > 0 {
970            query.push(", ");
971        }
972        push_bind_value(&mut query, &value.value);
973    }
974    query.push(") ON CONFLICT (").push(descriptor.primary_key).push(") DO UPDATE SET ");
975
976    // The DO UPDATE list. If there are no eligible columns to overwrite,
977    // fall back to "DO NOTHING"-equivalent semantics via a no-op assignment:
978    // touching the PK to itself. This keeps the RETURNING clause working
979    // (PG only RETURNs from rows the statement touched), which matters for
980    // round-trips that always want the current row back.
981    if descriptor.upsert_update_columns.is_empty() {
982        query.push(descriptor.primary_key);
983        query.push(" = EXCLUDED.").push(descriptor.primary_key);
984    } else {
985        for (index, column) in descriptor.upsert_update_columns.iter().enumerate() {
986            if index > 0 {
987                query.push(", ");
988            }
989            query.push(*column).push(" = EXCLUDED.").push(*column);
990        }
991    }
992    if let Some(version_col) = descriptor.version_column {
993        query
994            .push(", ")
995            .push(version_col)
996            .push(" = ")
997            .push(descriptor.table_name)
998            .push(".")
999            .push(version_col)
1000            .push(" + 1");
1001    }
1002
1003    query
1004        .push(" RETURNING ")
1005        .push(descriptor.select_projection());
1006
1007    query
1008        .build_query_as::<M>()
1009        .fetch_one(executor)
1010        .await
1011        .map_err(|error| CoolError::Database(error.to_string()))
1012}