Skip to main content

cratestack_sqlx/query/write/
upsert.rs

1//! `INSERT … ON CONFLICT (<pk>) DO UPDATE …`, but with the
2//! create/update distinction made *before* the SQL runs (via a
3//! `SELECT … FOR UPDATE` probe inside the same transaction) so we can:
4//!
5//!   * pick the right policy slot (both must allow at call time)
6//!   * emit the correct ModelEventKind (Created vs Updated)
7//!   * capture an audit `before` snapshot only on the update branch
8//!
9//! The upsert is always transactional regardless of whether the model
10//! emits events or has `@@audit`. One extra round-trip for the
//! SELECT, in exchange for clean event/audit semantics. Upsert is not
//! a hot path — callers who need raw insert/update throughput
//! should use `.create()` / `.update()` directly.
14
15use cratestack_core::{CoolContext, CoolError};
16
17use crate::{ConflictTarget, ModelDescriptor, SqlxRuntime, UpsertModelInput, sqlx};
18
19use super::upsert_exec::run_upsert_in_tx;
20
21#[derive(Debug, Clone)]
22pub struct UpsertRecord<'a, M: 'static, PK: 'static, I> {
23    pub(crate) runtime: &'a SqlxRuntime,
24    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
25    pub(crate) input: I,
26    pub(crate) conflict_target: ConflictTarget,
27}
28
29impl<'a, M: 'static, PK: 'static, I> UpsertRecord<'a, M, PK, I>
30where
31    I: UpsertModelInput<M>,
32{
33    /// Choose the conflict target. Defaults to the model's primary
34    /// key; pass [`ConflictTarget::Columns`] to upsert on a composite
35    /// unique key instead. The named columns must form a `UNIQUE`
36    /// constraint/index on the target table.
37    pub fn on_conflict(mut self, target: ConflictTarget) -> Self {
38        self.conflict_target = target;
39        self
40    }
41
42    /// Render an approximate SQL preview. The actual upsert wraps a
43    /// `SELECT … FOR UPDATE` around the `INSERT … ON CONFLICT`, but
44    /// this preview returns only the conflict-bearing statement.
45    pub fn preview_sql(&self) -> String {
46        let values = self.input.sql_values();
47        let placeholders = (1..=values.len())
48            .map(|index| format!("${index}"))
49            .collect::<Vec<_>>()
50            .join(", ");
51        let columns = values
52            .iter()
53            .map(|value| value.column)
54            .collect::<Vec<_>>()
55            .join(", ");
56        let update_assignments = self
57            .descriptor
58            .upsert_update_columns
59            .iter()
60            .map(|column| format!("{column} = EXCLUDED.{column}"))
61            .collect::<Vec<_>>()
62            .join(", ");
63        let version_bump = match self.descriptor.version_column {
64            Some(col) => format!(
65                ", {col} = {table}.{col} + 1",
66                table = self.descriptor.table_name,
67                col = col
68            ),
69            None => String::new(),
70        };
71        let conflict_tuple = match self.conflict_target {
72            ConflictTarget::PrimaryKey => self.descriptor.primary_key.to_owned(),
73            ConflictTarget::Columns(cols) => cols.join(", "),
74        };
75
76        format!(
77            "INSERT INTO {table} ({columns}) VALUES ({placeholders}) \
78             ON CONFLICT ({conflict_tuple}) DO UPDATE SET {update_assignments}{version_bump} \
79             RETURNING {projection}",
80            table = self.descriptor.table_name,
81            projection = self.descriptor.select_projection(),
82        )
83    }
84
85    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
86    where
87        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
88        PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
89    {
90        let runtime = self.runtime;
91        let mut tx = runtime
92            .pool()
93            .begin()
94            .await
95            .map_err(|error| CoolError::Database(error.to_string()))?;
96        let (record, emits_event) = run_upsert_in_tx(
97            &mut tx,
98            runtime.pool(),
99            self.descriptor,
100            self.input,
101            self.conflict_target,
102            ctx,
103        )
104        .await?;
105        tx.commit()
106            .await
107            .map_err(|error| CoolError::Database(error.to_string()))?;
108        if emits_event {
109            let _ = runtime.drain_event_outbox().await;
110        }
111        Ok(record)
112    }
113
114    /// Like [`Self::run`] but participates in a caller-supplied
115    /// transaction. The conflict probe runs against `tx`, so the row
116    /// lock is held until the caller commits. The event outbox is not
117    /// drained here.
118    pub async fn run_in_tx<'tx>(
119        self,
120        tx: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
121        ctx: &CoolContext,
122    ) -> Result<M, CoolError>
123    where
124        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
125        PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
126    {
127        let (record, _) = run_upsert_in_tx(
128            tx,
129            self.runtime.pool(),
130            self.descriptor,
131            self.input,
132            self.conflict_target,
133            ctx,
134        )
135        .await?;
136        Ok(record)
137    }
138}