use cratestack_core::{AuditOperation, CoolContext, CoolError, ModelEventKind};
use crate::audit::{build_audit_event, enqueue_audit_event, ensure_audit_table};
use crate::descriptor::{enqueue_event_outbox, ensure_event_outbox_table};
use crate::query::support::{apply_create_defaults, evaluate_create_policies, find_column_value};
use crate::{ConflictTarget, ModelDescriptor, SqlValue, UpsertModelInput, sqlx};
use super::upsert_sql::{
row_passes_update_policy, select_for_update_by_conflict_target, upsert_returning_record,
};
/// Inserts or updates a single row of `descriptor`'s model inside the caller-owned
/// transaction `tx`.
///
/// The work happens in this order:
/// 1. `input` is validated and turned into a column list: create defaults are applied,
///    and when the descriptor has a version column that the list lacks, it is seeded
///    with `0`.
/// 2. The conflict key is resolved. For `ConflictTarget::PrimaryKey` it is the primary
///    key paired with `input.primary_key_value()`. For `ConflictTarget::Columns` each
///    listed column is looked up in the column list.
/// 3. Create policies are evaluated against the column list. This always happens,
///    before knowing whether the row exists.
/// 4. The outbox table is ensured on `tx` when the model emits `Created` or `Updated`
///    events, and the audit table is ensured on `policy_pool` when auditing is enabled.
/// 5. Any existing row is fetched via `select_for_update_by_conflict_target`. If one is
///    found, the update policy must also pass.
/// 6. The upsert runs. Then, inside `tx`, an outbox event (`Created` for a new row,
///    `Updated` otherwise) is enqueued if the model emits that kind. An audit record
///    with before/after JSON snapshots is enqueued if auditing is enabled.
///
/// Returns the stored record together with a flag telling whether an outbox event was
/// enqueued for it.
///
/// # Errors
/// - `CoolError::Validation` when the column list is empty, or when an `on_conflict`
///   column is not present in it.
/// - `CoolError::Forbidden` when the create policy rejects the values, or, for an
///   existing row, when the update policy rejects it.
/// - Errors from `input.validate()`, default application, policy evaluation, table
///   setup and the SQL helpers are propagated unchanged.
///
/// NOTE(review): policy checks and `ensure_audit_table` run on `policy_pool` rather than
/// `tx`. If that is a separate connection, they presumably cannot see this transaction's
/// uncommitted writes. Confirm this is intended.
///
/// NOTE(review): when no row exists yet, the `FOR UPDATE` lookup has nothing to lock.
/// A concurrent insert between that lookup and the upsert could therefore make the
/// created/updated classification (and the skipped update-policy check) stale. Verify
/// against `upsert_returning_record`.
pub(super) async fn run_upsert_in_tx<'tx, M, PK, I>(
    tx: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
    policy_pool: &sqlx::PgPool,
    descriptor: &'static ModelDescriptor<M, PK>,
    input: I,
    conflict_target: ConflictTarget,
    ctx: &CoolContext,
) -> Result<(M, bool), CoolError>
where
    I: UpsertModelInput<M>,
    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
    PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
{
    input.validate()?;

    let mut column_values =
        apply_create_defaults(input.sql_values(), descriptor.create_defaults, ctx)?;

    // Seed the version column with 0 only when neither the input nor a default set it.
    let unseeded_version = descriptor
        .version_column
        .filter(|&column| find_column_value(&column_values, column).is_none());
    if let Some(column) = unseeded_version {
        column_values.push(crate::SqlColumnValue {
            column,
            value: crate::SqlValue::Int(0),
        });
    }

    if column_values.is_empty() {
        return Err(CoolError::Validation(
            "upsert input must contain at least one column".to_owned(),
        ));
    }

    // The conflict key is what identifies "the same row" for both the lock and the upsert.
    let pk_value = input.primary_key_value();
    let conflict_columns: Vec<(&'static str, SqlValue)> = match conflict_target {
        ConflictTarget::PrimaryKey => vec![(descriptor.primary_key, pk_value)],
        ConflictTarget::Columns(cols) => cols
            .iter()
            .map(|col| match find_column_value(&column_values, col) {
                Some(value) => Ok((*col, value.clone())),
                None => Err(CoolError::Validation(format!(
                    "upsert on_conflict references column `{col}` which is not present in the input",
                ))),
            })
            .collect::<Result<Vec<_>, CoolError>>()?,
    };

    let create_permitted = evaluate_create_policies(
        policy_pool,
        descriptor.create_allow_policies,
        descriptor.create_deny_policies,
        &column_values,
        ctx,
    )
    .await?;
    if !create_permitted {
        return Err(CoolError::Forbidden(
            "create policy denied this upsert".to_owned(),
        ));
    }

    let emits_created = descriptor.emits(ModelEventKind::Created);
    let emits_updated = descriptor.emits(ModelEventKind::Updated);
    let audit_enabled = descriptor.audit_enabled;

    let needs_outbox = emits_created || emits_updated;
    if needs_outbox {
        ensure_event_outbox_table(&mut **tx).await?;
    }
    if audit_enabled {
        ensure_audit_table(policy_pool).await?;
    }

    let existing_row =
        select_for_update_by_conflict_target(&mut **tx, descriptor, &conflict_columns).await?;
    let is_insert = existing_row.is_none();

    // An existing row turns this upsert into an update, which needs its own permission.
    if !is_insert {
        let update_permitted =
            row_passes_update_policy(policy_pool, descriptor, &conflict_columns, ctx).await?;
        if !update_permitted {
            return Err(CoolError::Forbidden(
                "update policy denied this upsert".to_owned(),
            ));
        }
    }

    // Serialization failure is tolerated: the audit entry then simply lacks a "before".
    let before_snapshot = match &existing_row {
        Some(row) if audit_enabled => serde_json::to_value(row).ok(),
        _ => None,
    };

    let record =
        upsert_returning_record(&mut **tx, descriptor, &column_values, conflict_target).await?;

    let (event_kind, audit_op, emits_event) = if is_insert {
        (ModelEventKind::Created, AuditOperation::Create, emits_created)
    } else {
        (ModelEventKind::Updated, AuditOperation::Update, emits_updated)
    };

    if emits_event {
        enqueue_event_outbox(&mut **tx, descriptor.schema_name, event_kind, &record).await?;
    }
    if audit_enabled {
        let event = build_audit_event(
            descriptor,
            audit_op,
            before_snapshot,
            serde_json::to_value(&record).ok(),
            ctx,
        );
        enqueue_audit_event(&mut **tx, &event).await?;
    }

    Ok((record, emits_event))
}