1use crate::sqlx;
2
3use cratestack_core::{AuditOperation, CoolContext, CoolError, ModelEventKind};
4
5use crate::{
6 CreateModelInput, ModelDescriptor, SqlValue, SqlxRuntime, UpdateModelInput, UpsertModelInput,
7 audit::{build_audit_event, enqueue_audit_event, ensure_audit_table, fetch_for_audit},
8 descriptor::{enqueue_event_outbox, ensure_event_outbox_table},
9};
10
11use super::support::{
12 apply_create_defaults, evaluate_create_policies, find_column_value, push_action_policy_query,
13 push_bind_value,
14};
15
/// Renders the `UPDATE … RETURNING` statement text used for SQL previews.
///
/// Each entry of `columns` becomes a `column = $n` assignment numbered from
/// `$1`, and the primary-key predicate takes the next placeholder. When
/// `version_column` is `Some`, the statement also increments that column and
/// compares it against one further placeholder.
///
/// The text is assembled purely from the arguments; nothing is executed.
pub fn render_update_preview_sql(
    table_name: &str,
    primary_key: &str,
    version_column: Option<&str>,
    columns: &[&str],
    select_projection: &str,
) -> String {
    let mut set_clause = String::new();
    for (position, column) in columns.iter().enumerate() {
        if position > 0 {
            set_clause.push_str(", ");
        }
        set_clause.push_str(&format!("{column} = ${}", position + 1));
    }

    // Placeholders after the assignments: primary key first, then version.
    let pk_slot = columns.len() + 1;
    if let Some(version_col) = version_column {
        let version_slot = pk_slot + 1;
        format!(
            "UPDATE {table_name} SET {set_clause}, {version_col} = {version_col} + 1 \
             WHERE {primary_key} = ${pk_slot} AND {version_col} = ${version_slot} \
             RETURNING {select_projection}"
        )
    } else {
        format!(
            "UPDATE {table_name} SET {set_clause} WHERE {primary_key} = ${pk_slot} \
             RETURNING {select_projection}"
        )
    }
}
55
/// Pending insert of a single model row.
///
/// Bundles the runtime whose pool executes the statement, the model's static
/// descriptor, and the caller-supplied create input.
#[derive(Debug, Clone)]
pub struct CreateRecord<'a, M: 'static, PK: 'static, I> {
    /// Runtime whose connection pool is used when the insert runs.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static model metadata (table name, policies, projection, …).
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Input providing the column values to insert.
    pub(crate) input: I,
}
62
63impl<'a, M: 'static, PK: 'static, I> CreateRecord<'a, M, PK, I>
64where
65 I: CreateModelInput<M>,
66{
67 pub fn preview_sql(&self) -> String {
68 let values = self.input.sql_values();
69 let placeholders = (1..=values.len())
70 .map(|index| format!("${index}"))
71 .collect::<Vec<_>>()
72 .join(", ");
73 let columns = values
74 .iter()
75 .map(|value| value.column)
76 .collect::<Vec<_>>()
77 .join(", ");
78
79 format!(
80 "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
81 self.descriptor.table_name,
82 columns,
83 placeholders,
84 self.descriptor.select_projection(),
85 )
86 }
87
88 pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
89 where
90 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
91 {
92 let emits_event = self.descriptor.emits(ModelEventKind::Created);
93 let audit_enabled = self.descriptor.audit_enabled;
94 let needs_tx = emits_event || audit_enabled;
95 let record = if needs_tx {
96 let mut tx = self
97 .runtime
98 .pool()
99 .begin()
100 .await
101 .map_err(|error| CoolError::Database(error.to_string()))?;
102 if emits_event {
103 ensure_event_outbox_table(&mut *tx).await?;
104 }
105 if audit_enabled {
106 ensure_audit_table(self.runtime.pool()).await?;
107 }
108 let record = create_record_with_executor(
109 &mut *tx,
110 self.runtime.pool(),
111 self.descriptor,
112 self.input,
113 ctx,
114 )
115 .await?;
116 if emits_event {
117 enqueue_event_outbox(
118 &mut *tx,
119 self.descriptor.schema_name,
120 ModelEventKind::Created,
121 &record,
122 )
123 .await?;
124 }
125 if audit_enabled {
126 let after = serde_json::to_value(&record).ok();
127 let event =
128 build_audit_event(self.descriptor, AuditOperation::Create, None, after, ctx);
129 enqueue_audit_event(&mut *tx, &event).await?;
130 }
131 tx.commit()
132 .await
133 .map_err(|error| CoolError::Database(error.to_string()))?;
134 record
135 } else {
136 create_record_with_executor(
137 self.runtime.pool(),
138 self.runtime.pool(),
139 self.descriptor,
140 self.input,
141 ctx,
142 )
143 .await?
144 };
145
146 if emits_event {
147 let _ = self.runtime.drain_event_outbox().await;
148 }
149
150 Ok(record)
151 }
152}
153
/// First stage of an update: identifies the target row by primary key.
///
/// Call `set` to attach the update input and obtain an `UpdateRecordSet`.
#[derive(Debug, Clone)]
pub struct UpdateRecord<'a, M: 'static, PK: 'static> {
    /// Runtime whose connection pool is used when the update runs.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static model metadata (table name, policies, projection, …).
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Primary-key value of the row to update.
    pub(crate) id: PK,
}
160
161impl<'a, M: 'static, PK: 'static> UpdateRecord<'a, M, PK> {
162 pub fn set<I>(self, input: I) -> UpdateRecordSet<'a, M, PK, I> {
163 UpdateRecordSet {
164 runtime: self.runtime,
165 descriptor: self.descriptor,
166 id: self.id,
167 input,
168 if_match: None,
169 }
170 }
171}
172
/// Pending update of a single row: target id plus the changes to apply.
#[derive(Debug, Clone)]
pub struct UpdateRecordSet<'a, M: 'static, PK: 'static, I> {
    /// Runtime whose connection pool is used when the update runs.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static model metadata (table name, policies, projection, …).
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Primary-key value of the row to update.
    pub(crate) id: PK,
    /// Input providing the column values to write.
    pub(crate) input: I,
    /// Expected current version; compared against the descriptor's version
    /// column (when it has one) as part of the update statement.
    pub(crate) if_match: Option<i64>,
}
181
182impl<'a, M: 'static, PK: 'static, I> UpdateRecordSet<'a, M, PK, I>
183where
184 I: UpdateModelInput<M>,
185{
186 pub fn if_match(mut self, expected: i64) -> Self {
190 self.if_match = Some(expected);
191 self
192 }
193
194 pub fn preview_sql(&self) -> String {
195 let values = self.input.sql_values();
196 let columns: Vec<&str> = values.iter().map(|v| v.column).collect();
197 render_update_preview_sql(
198 self.descriptor.table_name,
199 self.descriptor.primary_key,
200 self.descriptor.version_column,
201 &columns,
202 &self.descriptor.select_projection(),
203 )
204 }
205
206 pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
207 where
208 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
209 PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
210 {
211 if self.descriptor.version_column.is_some() && self.if_match.is_none() {
212 return Err(CoolError::PreconditionFailed(
213 "If-Match header required for versioned model".to_owned(),
214 ));
215 }
216 let emits_event = self.descriptor.emits(ModelEventKind::Updated);
217 let audit_enabled = self.descriptor.audit_enabled;
218 let needs_tx = emits_event || audit_enabled;
219 let record = if needs_tx {
220 let mut tx = self
221 .runtime
222 .pool()
223 .begin()
224 .await
225 .map_err(|error| CoolError::Database(error.to_string()))?;
226 if emits_event {
227 ensure_event_outbox_table(&mut *tx).await?;
228 }
229 if audit_enabled {
230 ensure_audit_table(self.runtime.pool()).await?;
231 }
232 let before_record = if audit_enabled {
235 fetch_for_audit(&mut *tx, self.descriptor, self.id.clone()).await?
236 } else {
237 None
238 };
239 let before_snapshot = before_record
240 .as_ref()
241 .and_then(|m| serde_json::to_value(m).ok());
242 let record = update_record_with_executor(
243 &mut *tx,
244 self.runtime.pool(),
245 self.descriptor,
246 self.id,
247 self.input,
248 ctx,
249 self.if_match,
250 )
251 .await?;
252 if emits_event {
253 enqueue_event_outbox(
254 &mut *tx,
255 self.descriptor.schema_name,
256 ModelEventKind::Updated,
257 &record,
258 )
259 .await?;
260 }
261 if audit_enabled {
262 let after = serde_json::to_value(&record).ok();
263 let event = build_audit_event(
264 self.descriptor,
265 AuditOperation::Update,
266 before_snapshot,
267 after,
268 ctx,
269 );
270 enqueue_audit_event(&mut *tx, &event).await?;
271 }
272 tx.commit()
273 .await
274 .map_err(|error| CoolError::Database(error.to_string()))?;
275 record
276 } else {
277 update_record_with_executor(
278 self.runtime.pool(),
279 self.runtime.pool(),
280 self.descriptor,
281 self.id,
282 self.input,
283 ctx,
284 self.if_match,
285 )
286 .await?
287 };
288
289 if emits_event {
290 let _ = self.runtime.drain_event_outbox().await;
291 }
292
293 Ok(record)
294 }
295}
296
/// Pending delete of a single row identified by primary key.
#[derive(Debug, Clone)]
pub struct DeleteRecord<'a, M: 'static, PK: 'static> {
    /// Runtime whose connection pool is used when the delete runs.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static model metadata (table name, policies, projection, …).
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Primary-key value of the row to delete.
    pub(crate) id: PK,
}
303
304impl<'a, M: 'static, PK: 'static> DeleteRecord<'a, M, PK> {
305 pub fn preview_sql(&self) -> String {
306 format!(
307 "DELETE FROM {} WHERE {} = $1 RETURNING {}",
308 self.descriptor.table_name,
309 self.descriptor.primary_key,
310 self.descriptor.select_projection(),
311 )
312 }
313
314 pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
315 where
316 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
317 PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
318 {
319 let emits_event = self.descriptor.emits(ModelEventKind::Deleted);
320 let audit_enabled = self.descriptor.audit_enabled;
321 let needs_tx = emits_event || audit_enabled;
322 let record = if needs_tx {
323 let mut tx = self
324 .runtime
325 .pool()
326 .begin()
327 .await
328 .map_err(|error| CoolError::Database(error.to_string()))?;
329 if emits_event {
330 ensure_event_outbox_table(&mut *tx).await?;
331 }
332 if audit_enabled {
333 ensure_audit_table(self.runtime.pool()).await?;
334 }
335
336 let record = delete_returning_record(&mut *tx, self.descriptor, self.id, ctx).await?;
337 if emits_event {
338 enqueue_event_outbox(
339 &mut *tx,
340 self.descriptor.schema_name,
341 ModelEventKind::Deleted,
342 &record,
343 )
344 .await?;
345 }
346 if audit_enabled {
347 let before = serde_json::to_value(&record).ok();
350 let event =
351 build_audit_event(self.descriptor, AuditOperation::Delete, before, None, ctx);
352 enqueue_audit_event(&mut *tx, &event).await?;
353 }
354 tx.commit()
355 .await
356 .map_err(|error| CoolError::Database(error.to_string()))?;
357 record
358 } else {
359 delete_returning_record(self.runtime.pool(), self.descriptor, self.id, ctx).await?
360 };
361
362 if emits_event {
363 let _ = self.runtime.drain_event_outbox().await;
364 }
365
366 Ok(record)
367 }
368}
369
370pub async fn create_record_with_executor<'e, E, M, PK, I>(
371 executor: E,
372 policy_pool: &sqlx::PgPool,
373 descriptor: &'static ModelDescriptor<M, PK>,
374 input: I,
375 ctx: &CoolContext,
376) -> Result<M, CoolError>
377where
378 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
379 I: CreateModelInput<M>,
380 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
381{
382 input.validate()?;
383 let mut values = apply_create_defaults(input.sql_values(), descriptor.create_defaults, ctx)?;
384 if let Some(version_col) = descriptor.version_column
392 && find_column_value(&values, version_col).is_none()
393 {
394 values.push(crate::SqlColumnValue {
395 column: version_col,
396 value: crate::SqlValue::Int(0),
397 });
398 }
399 if values.is_empty() {
400 return Err(CoolError::Validation(
401 "create input must contain at least one column".to_owned(),
402 ));
403 }
404 if !evaluate_create_policies(
405 policy_pool,
406 descriptor.create_allow_policies,
407 descriptor.create_deny_policies,
408 &values,
409 ctx,
410 )
411 .await?
412 {
413 return Err(CoolError::Forbidden(
414 "create policy denied this operation".to_owned(),
415 ));
416 }
417
418 insert_returning_record(executor, descriptor, &values).await
419}
420
421pub async fn update_record_with_executor<'e, E, M, PK, I>(
422 executor: E,
423 policy_pool: &sqlx::PgPool,
424 descriptor: &'static ModelDescriptor<M, PK>,
425 id: PK,
426 input: I,
427 ctx: &CoolContext,
428 if_match: Option<i64>,
429) -> Result<M, CoolError>
430where
431 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
432 I: UpdateModelInput<M>,
433 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
434 PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
435{
436 input.validate()?;
437 let values = input.sql_values();
438 if values.is_empty() {
439 return Err(CoolError::Validation(
440 "update input must contain at least one changed column".to_owned(),
441 ));
442 }
443
444 update_returning_record(
445 executor,
446 policy_pool,
447 descriptor,
448 id,
449 &values,
450 ctx,
451 if_match,
452 )
453 .await
454}
455
456async fn insert_returning_record<'e, E, M, PK>(
457 executor: E,
458 descriptor: &'static ModelDescriptor<M, PK>,
459 values: &[crate::SqlColumnValue],
460) -> Result<M, CoolError>
461where
462 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
463 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
464{
465 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("INSERT INTO ");
466 query.push(descriptor.table_name).push(" (");
467 for (index, value) in values.iter().enumerate() {
468 if index > 0 {
469 query.push(", ");
470 }
471 query.push(value.column);
472 }
473 query.push(") VALUES (");
474 for (index, value) in values.iter().enumerate() {
475 if index > 0 {
476 query.push(", ");
477 }
478 push_bind_value(&mut query, &value.value);
479 }
480 query
481 .push(") RETURNING ")
482 .push(descriptor.select_projection());
483
484 query
485 .build_query_as::<M>()
486 .fetch_one(executor)
487 .await
488 .map_err(|error| CoolError::Database(error.to_string()))
489}
490
491async fn update_returning_record<'e, E, M, PK>(
492 executor: E,
493 policy_pool: &sqlx::PgPool,
494 descriptor: &'static ModelDescriptor<M, PK>,
495 id: PK,
496 values: &[crate::SqlColumnValue],
497 ctx: &CoolContext,
498 if_match: Option<i64>,
499) -> Result<M, CoolError>
500where
501 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
502 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
503 PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
504{
505 let version_column = descriptor.version_column;
506 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("UPDATE ");
507 query.push(descriptor.table_name).push(" SET ");
508 for (index, value) in values.iter().enumerate() {
509 if index > 0 {
510 query.push(", ");
511 }
512 query.push(value.column).push(" = ");
513 push_bind_value(&mut query, &value.value);
514 }
515 if let Some(version_col) = version_column {
516 query
517 .push(", ")
518 .push(version_col)
519 .push(" = ")
520 .push(version_col)
521 .push(" + 1");
522 }
523 query
524 .push(" WHERE ")
525 .push(descriptor.primary_key)
526 .push(" = ");
527 let id_for_probe = id.clone();
528 query.push_bind(id);
529 if let (Some(version_col), Some(expected)) = (version_column, if_match) {
530 query.push(" AND ").push(version_col).push(" = ");
531 query.push_bind(expected);
532 }
533 query.push(" AND ");
534 push_action_policy_query(
535 &mut query,
536 descriptor.update_allow_policies,
537 descriptor.update_deny_policies,
538 ctx,
539 );
540 query
541 .push(" RETURNING ")
542 .push(descriptor.select_projection());
543
544 let outcome = query
545 .build_query_as::<M>()
546 .fetch_optional(executor)
547 .await
548 .map_err(|error| CoolError::Database(error.to_string()))?;
549 match outcome {
550 Some(record) => Ok(record),
551 None => {
552 if let (Some(version_col), Some(expected)) = (version_column, if_match) {
558 if let Some(current) =
559 probe_current_version(policy_pool, descriptor, id_for_probe, version_col, ctx)
560 .await?
561 {
562 if current != expected {
563 return Err(CoolError::PreconditionFailed(format!(
564 "version mismatch: expected {expected}, found {current}",
565 )));
566 }
567 }
568 }
569 Err(CoolError::Forbidden(
570 "update policy denied this operation".to_owned(),
571 ))
572 }
573 }
574}
575
576async fn probe_current_version<M, PK>(
581 policy_pool: &sqlx::PgPool,
582 descriptor: &'static ModelDescriptor<M, PK>,
583 id: PK,
584 version_col: &'static str,
585 ctx: &CoolContext,
586) -> Result<Option<i64>, CoolError>
587where
588 PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
589{
590 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
591 query.push(version_col);
592 query.push(" FROM ").push(descriptor.table_name);
593 query
594 .push(" WHERE ")
595 .push(descriptor.primary_key)
596 .push(" = ");
597 query.push_bind(id);
598 query.push(" AND ");
599 push_action_policy_query(
600 &mut query,
601 descriptor.read_allow_policies,
602 descriptor.read_deny_policies,
603 ctx,
604 );
605
606 let row: Option<(i64,)> = query
607 .build_query_as::<(i64,)>()
608 .fetch_optional(policy_pool)
609 .await
610 .map_err(|error| CoolError::Database(error.to_string()))?;
611 Ok(row.map(|(v,)| v))
612}
613
614async fn delete_returning_record<'e, E, M, PK>(
615 executor: E,
616 descriptor: &'static ModelDescriptor<M, PK>,
617 id: PK,
618 ctx: &CoolContext,
619) -> Result<M, CoolError>
620where
621 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
622 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
623 PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
624{
625 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("");
626 match descriptor.soft_delete_column {
627 Some(col) => {
628 query.push("UPDATE ").push(descriptor.table_name);
631 query.push(" SET ").push(col).push(" = NOW()");
632 if let Some(version_col) = descriptor.version_column {
633 query
634 .push(", ")
635 .push(version_col)
636 .push(" = ")
637 .push(version_col)
638 .push(" + 1");
639 }
640 query.push(" WHERE ");
641 query.push(col).push(" IS NULL AND ");
642 query.push(descriptor.primary_key).push(" = ");
643 query.push_bind(id);
644 }
645 None => {
646 query.push("DELETE FROM ").push(descriptor.table_name);
647 query.push(" WHERE ");
648 query.push(descriptor.primary_key).push(" = ");
649 query.push_bind(id);
650 }
651 }
652 query.push(" AND ");
653 push_action_policy_query(
654 &mut query,
655 descriptor.delete_allow_policies,
656 descriptor.delete_deny_policies,
657 ctx,
658 );
659 query
660 .push(" RETURNING ")
661 .push(descriptor.select_projection());
662
663 query
664 .build_query_as::<M>()
665 .fetch_optional(executor)
666 .await
667 .map_err(|error| CoolError::Database(error.to_string()))?
668 .ok_or_else(|| CoolError::Forbidden("delete policy denied this operation".to_owned()))
669}
670
/// Pending insert-or-update (`INSERT … ON CONFLICT … DO UPDATE`) of one row.
#[derive(Debug, Clone)]
pub struct UpsertRecord<'a, M: 'static, PK: 'static, I> {
    /// Runtime whose connection pool is used when the upsert runs.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static model metadata (table name, policies, projection, …).
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Input providing the column values and the primary-key value.
    pub(crate) input: I,
}
693
694impl<'a, M: 'static, PK: 'static, I> UpsertRecord<'a, M, PK, I>
695where
696 I: UpsertModelInput<M>,
697{
698 pub fn preview_sql(&self) -> String {
703 let values = self.input.sql_values();
704 let placeholders = (1..=values.len())
705 .map(|index| format!("${index}"))
706 .collect::<Vec<_>>()
707 .join(", ");
708 let columns = values
709 .iter()
710 .map(|value| value.column)
711 .collect::<Vec<_>>()
712 .join(", ");
713 let update_assignments = self
714 .descriptor
715 .upsert_update_columns
716 .iter()
717 .map(|column| format!("{column} = EXCLUDED.{column}"))
718 .collect::<Vec<_>>()
719 .join(", ");
720 let version_bump = match self.descriptor.version_column {
721 Some(col) => format!(", {col} = {table}.{col} + 1", table = self.descriptor.table_name, col = col),
722 None => String::new(),
723 };
724
725 format!(
726 "INSERT INTO {table} ({columns}) VALUES ({placeholders}) \
727 ON CONFLICT ({pk}) DO UPDATE SET {update_assignments}{version_bump} \
728 RETURNING {projection}",
729 table = self.descriptor.table_name,
730 pk = self.descriptor.primary_key,
731 projection = self.descriptor.select_projection(),
732 )
733 }
734
735 pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
736 where
737 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
738 PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
739 {
740 self.input.validate()?;
741
742 let mut insert_values =
746 apply_create_defaults(self.input.sql_values(), self.descriptor.create_defaults, ctx)?;
747 if let Some(version_col) = self.descriptor.version_column
748 && find_column_value(&insert_values, version_col).is_none()
749 {
750 insert_values.push(crate::SqlColumnValue {
751 column: version_col,
752 value: crate::SqlValue::Int(0),
753 });
754 }
755 if insert_values.is_empty() {
756 return Err(CoolError::Validation(
757 "upsert input must contain at least one column".to_owned(),
758 ));
759 }
760
761 if !evaluate_create_policies(
766 self.runtime.pool(),
767 self.descriptor.create_allow_policies,
768 self.descriptor.create_deny_policies,
769 &insert_values,
770 ctx,
771 )
772 .await?
773 {
774 return Err(CoolError::Forbidden(
775 "create policy denied this upsert".to_owned(),
776 ));
777 }
778
779 let pk_value = self.input.primary_key_value();
780 let emits_created = self.descriptor.emits(ModelEventKind::Created);
781 let emits_updated = self.descriptor.emits(ModelEventKind::Updated);
782 let audit_enabled = self.descriptor.audit_enabled;
783
784 let mut tx = self
785 .runtime
786 .pool()
787 .begin()
788 .await
789 .map_err(|error| CoolError::Database(error.to_string()))?;
790
791 if emits_created || emits_updated {
792 ensure_event_outbox_table(&mut *tx).await?;
793 }
794 if audit_enabled {
795 ensure_audit_table(self.runtime.pool()).await?;
796 }
797
798 let before_record =
803 select_for_update_by_pk_value(&mut *tx, self.descriptor, &pk_value).await?;
804 let inserted = before_record.is_none();
805
806 if !inserted
812 && !row_passes_update_policy(
813 self.runtime.pool(),
814 self.descriptor,
815 &pk_value,
816 ctx,
817 )
818 .await?
819 {
820 return Err(CoolError::Forbidden(
821 "update policy denied this upsert".to_owned(),
822 ));
823 }
824
825 let before_snapshot = if !inserted && audit_enabled {
826 before_record
827 .as_ref()
828 .and_then(|m| serde_json::to_value(m).ok())
829 } else {
830 None
831 };
832
833 let record = upsert_returning_record(&mut *tx, self.descriptor, &insert_values).await?;
834
835 let event_kind = if inserted {
839 ModelEventKind::Created
840 } else {
841 ModelEventKind::Updated
842 };
843 let audit_op = if inserted {
844 AuditOperation::Create
845 } else {
846 AuditOperation::Update
847 };
848 let emits_event = if inserted { emits_created } else { emits_updated };
849
850 if emits_event {
851 enqueue_event_outbox(
852 &mut *tx,
853 self.descriptor.schema_name,
854 event_kind,
855 &record,
856 )
857 .await?;
858 }
859 if audit_enabled {
860 let after = serde_json::to_value(&record).ok();
861 let event = build_audit_event(self.descriptor, audit_op, before_snapshot, after, ctx);
862 enqueue_audit_event(&mut *tx, &event).await?;
863 }
864
865 tx.commit()
866 .await
867 .map_err(|error| CoolError::Database(error.to_string()))?;
868
869 if emits_event {
870 let _ = self.runtime.drain_event_outbox().await;
871 }
872
873 Ok(record)
874 }
875}
876
877async fn select_for_update_by_pk_value<'e, E, M, PK>(
882 executor: E,
883 descriptor: &'static ModelDescriptor<M, PK>,
884 pk_value: &SqlValue,
885) -> Result<Option<M>, CoolError>
886where
887 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
888 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
889{
890 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
891 query.push(descriptor.select_projection());
892 query.push(" FROM ").push(descriptor.table_name);
893 query
894 .push(" WHERE ")
895 .push(descriptor.primary_key)
896 .push(" = ");
897 push_bind_value(&mut query, pk_value);
898 if let Some(col) = descriptor.soft_delete_column {
902 query.push(" AND ").push(col).push(" IS NULL");
903 }
904 query.push(" FOR UPDATE");
905
906 query
907 .build_query_as::<M>()
908 .fetch_optional(executor)
909 .await
910 .map_err(|error| CoolError::Database(error.to_string()))
911}
912
913async fn row_passes_update_policy<M, PK>(
918 policy_pool: &sqlx::PgPool,
919 descriptor: &'static ModelDescriptor<M, PK>,
920 pk_value: &SqlValue,
921 ctx: &CoolContext,
922) -> Result<bool, CoolError> {
923 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT 1 FROM ");
924 query.push(descriptor.table_name);
925 query
926 .push(" WHERE ")
927 .push(descriptor.primary_key)
928 .push(" = ");
929 push_bind_value(&mut query, pk_value);
930 query.push(" AND ");
931 push_action_policy_query(
932 &mut query,
933 descriptor.update_allow_policies,
934 descriptor.update_deny_policies,
935 ctx,
936 );
937
938 let row: Option<(i32,)> = query
939 .build_query_as::<(i32,)>()
940 .fetch_optional(policy_pool)
941 .await
942 .map_err(|error| CoolError::Database(error.to_string()))?;
943 Ok(row.is_some())
944}
945
946async fn upsert_returning_record<'e, E, M, PK>(
951 executor: E,
952 descriptor: &'static ModelDescriptor<M, PK>,
953 insert_values: &[crate::SqlColumnValue],
954) -> Result<M, CoolError>
955where
956 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
957 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
958{
959 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("INSERT INTO ");
960 query.push(descriptor.table_name).push(" (");
961 for (index, value) in insert_values.iter().enumerate() {
962 if index > 0 {
963 query.push(", ");
964 }
965 query.push(value.column);
966 }
967 query.push(") VALUES (");
968 for (index, value) in insert_values.iter().enumerate() {
969 if index > 0 {
970 query.push(", ");
971 }
972 push_bind_value(&mut query, &value.value);
973 }
974 query.push(") ON CONFLICT (").push(descriptor.primary_key).push(") DO UPDATE SET ");
975
976 if descriptor.upsert_update_columns.is_empty() {
982 query.push(descriptor.primary_key);
983 query.push(" = EXCLUDED.").push(descriptor.primary_key);
984 } else {
985 for (index, column) in descriptor.upsert_update_columns.iter().enumerate() {
986 if index > 0 {
987 query.push(", ");
988 }
989 query.push(*column).push(" = EXCLUDED.").push(*column);
990 }
991 }
992 if let Some(version_col) = descriptor.version_column {
993 query
994 .push(", ")
995 .push(version_col)
996 .push(" = ")
997 .push(descriptor.table_name)
998 .push(".")
999 .push(version_col)
1000 .push(" + 1");
1001 }
1002
1003 query
1004 .push(" RETURNING ")
1005 .push(descriptor.select_projection());
1006
1007 query
1008 .build_query_as::<M>()
1009 .fetch_one(executor)
1010 .await
1011 .map_err(|error| CoolError::Database(error.to_string()))
1012}