cratestack_sqlx/query/write/
upsert.rs1use cratestack_core::{CoolContext, CoolError};
16
17use crate::{ConflictTarget, ModelDescriptor, SqlxRuntime, UpsertModelInput, sqlx};
18
19use super::upsert_exec::run_upsert_in_tx;
20
21#[derive(Debug, Clone)]
22pub struct UpsertRecord<'a, M: 'static, PK: 'static, I> {
23 pub(crate) runtime: &'a SqlxRuntime,
24 pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
25 pub(crate) input: I,
26 pub(crate) conflict_target: ConflictTarget,
27}
28
29impl<'a, M: 'static, PK: 'static, I> UpsertRecord<'a, M, PK, I>
30where
31 I: UpsertModelInput<M>,
32{
33 pub fn on_conflict(mut self, target: ConflictTarget) -> Self {
38 self.conflict_target = target;
39 self
40 }
41
42 pub fn preview_sql(&self) -> String {
46 let values = self.input.sql_values();
47 let placeholders = (1..=values.len())
48 .map(|index| format!("${index}"))
49 .collect::<Vec<_>>()
50 .join(", ");
51 let columns = values
52 .iter()
53 .map(|value| value.column)
54 .collect::<Vec<_>>()
55 .join(", ");
56 let update_assignments = self
57 .descriptor
58 .upsert_update_columns
59 .iter()
60 .map(|column| format!("{column} = EXCLUDED.{column}"))
61 .collect::<Vec<_>>()
62 .join(", ");
63 let version_bump = match self.descriptor.version_column {
64 Some(col) => format!(
65 ", {col} = {table}.{col} + 1",
66 table = self.descriptor.table_name,
67 col = col
68 ),
69 None => String::new(),
70 };
71 let conflict_tuple = match self.conflict_target {
72 ConflictTarget::PrimaryKey => self.descriptor.primary_key.to_owned(),
73 ConflictTarget::Columns(cols) => cols.join(", "),
74 };
75
76 format!(
77 "INSERT INTO {table} ({columns}) VALUES ({placeholders}) \
78 ON CONFLICT ({conflict_tuple}) DO UPDATE SET {update_assignments}{version_bump} \
79 RETURNING {projection}",
80 table = self.descriptor.table_name,
81 projection = self.descriptor.select_projection(),
82 )
83 }
84
85 pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
86 where
87 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
88 PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
89 {
90 let runtime = self.runtime;
91 let mut tx = runtime
92 .pool()
93 .begin()
94 .await
95 .map_err(|error| CoolError::Database(error.to_string()))?;
96 let (record, emits_event) = run_upsert_in_tx(
97 &mut tx,
98 runtime.pool(),
99 self.descriptor,
100 self.input,
101 self.conflict_target,
102 ctx,
103 )
104 .await?;
105 tx.commit()
106 .await
107 .map_err(|error| CoolError::Database(error.to_string()))?;
108 if emits_event {
109 let _ = runtime.drain_event_outbox().await;
110 }
111 Ok(record)
112 }
113
114 pub async fn run_in_tx<'tx>(
119 self,
120 tx: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
121 ctx: &CoolContext,
122 ) -> Result<M, CoolError>
123 where
124 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
125 PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
126 {
127 let (record, _) = run_upsert_in_tx(
128 tx,
129 self.runtime.pool(),
130 self.descriptor,
131 self.input,
132 self.conflict_target,
133 ctx,
134 )
135 .await?;
136 Ok(record)
137 }
138}