use cratestack_core::{CoolContext, CoolError};
use crate::{ConflictTarget, ModelDescriptor, SqlxRuntime, UpsertModelInput, sqlx};
use super::upsert_exec::run_upsert_in_tx;
/// Pending upsert (`INSERT ... ON CONFLICT ... DO UPDATE`) of a single record
/// of model `M`, built against a [`ModelDescriptor`] and executed through a
/// [`SqlxRuntime`].
///
/// Nothing touches the database until `run` or `run_in_tx` is awaited;
/// `preview_sql` only renders the statement text.
///
/// NOTE(review): `#[derive(Clone)]` / `#[derive(Debug)]` add `M: Clone`,
/// `PK: Clone` (and `Debug`) bounds to the generated impls even though `M` and
/// `PK` only appear behind references here; a manual impl would be needed to
/// drop those bounds if that ever matters.
#[derive(Debug, Clone)]
pub struct UpsertRecord<'a, M: 'static, PK: 'static, I> {
/// Runtime whose connection pool is used to open or assist the transaction.
pub(crate) runtime: &'a SqlxRuntime,
/// Static model metadata (table name, columns, primary key, version column).
pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
/// Values to insert; also the source of the `EXCLUDED` values on conflict.
pub(crate) input: I,
/// Which unique constraint the `ON CONFLICT` clause targets.
pub(crate) conflict_target: ConflictTarget,
}
impl<'a, M: 'static, PK: 'static, I> UpsertRecord<'a, M, PK, I>
where
    I: UpsertModelInput<M>,
{
    /// Replaces the `ON CONFLICT` target (builder style) and returns `self`.
    pub fn on_conflict(mut self, target: ConflictTarget) -> Self {
        self.conflict_target = target;
        self
    }

    /// Renders the upsert statement as SQL text without executing it.
    ///
    /// The statement has the shape
    /// `INSERT INTO <table> (<cols>) VALUES ($1, ..., $n)
    ///  ON CONFLICT (<target>) DO UPDATE SET <assignments> RETURNING <projection>`,
    /// where `<assignments>` contains `col = EXCLUDED.col` for every entry of
    /// `upsert_update_columns` followed, when the model has a version column, by
    /// `version = <table>.version + 1`.
    ///
    /// NOTE(review): presumably this mirrors the statement built by
    /// `run_upsert_in_tx`; verify the two stay in sync. If there are no update
    /// columns *and* no version column, the `SET` list is empty and the
    /// rendered SQL is not valid Postgres.
    pub fn preview_sql(&self) -> String {
        let values = self.input.sql_values();
        let columns = values
            .iter()
            .map(|value| value.column)
            .collect::<Vec<_>>()
            .join(", ");
        // Positional placeholders `$1..$n`, one per inserted value.
        let placeholders = (1..=values.len())
            .map(|index| format!("${index}"))
            .collect::<Vec<_>>()
            .join(", ");

        let table = &self.descriptor.table_name;

        // Build every SET assignment into one list so the separator is only
        // emitted *between* items. Previously the version bump was prefixed
        // with ", " unconditionally, producing `SET , version = ...` when
        // `upsert_update_columns` was empty.
        let mut assignments = self
            .descriptor
            .upsert_update_columns
            .iter()
            .map(|column| format!("{column} = EXCLUDED.{column}"))
            .collect::<Vec<_>>();
        if let Some(col) = self.descriptor.version_column {
            // Qualified with the table name so it reads the existing row's value,
            // not the proposed `EXCLUDED` one.
            assignments.push(format!("{col} = {table}.{col} + 1"));
        }
        let update_assignments = assignments.join(", ");

        // Match by reference so this keeps compiling even if the `Columns`
        // payload is not `Copy`.
        let conflict_tuple = match &self.conflict_target {
            ConflictTarget::PrimaryKey => self.descriptor.primary_key.to_owned(),
            ConflictTarget::Columns(cols) => cols.join(", "),
        };

        format!(
            "INSERT INTO {table} ({columns}) VALUES ({placeholders}) \
             ON CONFLICT ({conflict_tuple}) DO UPDATE SET {update_assignments} \
             RETURNING {projection}",
            projection = self.descriptor.select_projection(),
        )
    }

    /// Executes the upsert in a fresh transaction and commits it.
    ///
    /// Opens a transaction on the runtime's pool, delegates to
    /// `run_upsert_in_tx`, and commits. If the upsert reports that it emitted an
    /// event, the runtime's event outbox is drained after the commit.
    ///
    /// # Errors
    ///
    /// * Failures to begin or commit the transaction are returned as
    ///   [`CoolError::Database`] carrying the driver's error message.
    /// * Any error from `run_upsert_in_tx` is returned unchanged. In that case
    ///   the transaction is dropped uncommitted, which sqlx rolls back.
    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
    where
        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
        PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
    {
        let runtime = self.runtime;
        let mut tx = runtime
            .pool()
            .begin()
            .await
            .map_err(|error| CoolError::Database(error.to_string()))?;
        let (record, emits_event) = run_upsert_in_tx(
            &mut tx,
            runtime.pool(),
            self.descriptor,
            self.input,
            self.conflict_target,
            ctx,
        )
        .await?;
        tx.commit()
            .await
            .map_err(|error| CoolError::Database(error.to_string()))?;
        if emits_event {
            // Best-effort: the record is already committed, so a failed drain
            // must not turn a successful upsert into an error. Its result is
            // deliberately ignored.
            let _ = runtime.drain_event_outbox().await;
        }
        Ok(record)
    }

    /// Executes the upsert inside a caller-owned transaction.
    ///
    /// The transaction is neither committed nor rolled back here. The
    /// "emitted an event" flag returned by `run_upsert_in_tx` is discarded, so
    /// the event outbox is *not* drained.
    /// NOTE(review): presumably the caller drains it after committing — confirm.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `run_upsert_in_tx` unchanged.
    pub async fn run_in_tx<'tx>(
        self,
        tx: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
        ctx: &CoolContext,
    ) -> Result<M, CoolError>
    where
        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
        PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
    {
        let (record, _) = run_upsert_in_tx(
            tx,
            self.runtime.pool(),
            self.descriptor,
            self.input,
            self.conflict_target,
            ctx,
        )
        .await?;
        Ok(record)
    }
}