1use std::collections::HashMap;
23use std::hash::Hash;
24
25use crate::sqlx;
26use sqlx_core::acquire::Acquire as _;
30
31use cratestack_core::{
32 AuditOperation, BATCH_MAX_ITEMS, BatchResponse, CoolContext, CoolError, ModelEventKind,
33 find_duplicate_position,
34};
35
36use crate::{
37 CreateModelInput, ModelDescriptor, ModelPrimaryKey, SqlValue, SqlxRuntime, UpdateModelInput,
38 UpsertModelInput,
39 audit::{build_audit_event, enqueue_audit_event, ensure_audit_table, fetch_for_audit},
40 descriptor::{enqueue_event_outbox, ensure_event_outbox_table},
41};
42
43use super::support::{
44 apply_create_defaults, evaluate_create_policies, find_column_value, push_action_policy_query,
45 push_bind_value,
46};
47
48fn validate_batch_size(len: usize) -> Result<(), CoolError> {
51 if len > BATCH_MAX_ITEMS {
52 return Err(CoolError::Validation(format!(
53 "batch size {len} exceeds maximum of {BATCH_MAX_ITEMS}",
54 )));
55 }
56 Ok(())
57}
58
59fn reject_duplicate_pks<K: Eq + Hash + Clone>(keys: &[K]) -> Result<(), CoolError> {
60 if let Some((first, dup)) = find_duplicate_position(keys.iter().cloned()) {
61 return Err(CoolError::Validation(format!(
62 "duplicate primary key in batch at positions {first} and {dup}",
63 )));
64 }
65 Ok(())
66}
67
68fn reject_duplicate_sql_values(values: &[SqlValue]) -> Result<(), CoolError> {
69 if let Some((first, dup)) = cratestack_sql::find_duplicate_sql_value(values) {
70 return Err(CoolError::Validation(format!(
71 "duplicate primary key in batch at positions {first} and {dup}",
72 )));
73 }
74 Ok(())
75}
76
77#[derive(Debug, Clone)]
80pub struct BatchGet<'a, M: 'static, PK: 'static> {
81 pub(crate) runtime: &'a SqlxRuntime,
82 pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
83 pub(crate) ids: Vec<PK>,
84}
85
86impl<'a, M: 'static, PK: 'static> BatchGet<'a, M, PK> {
87 pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
88 where
89 for<'r> M:
90 Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + ModelPrimaryKey<PK>,
91 PK: Clone
92 + Eq
93 + Hash
94 + Send
95 + sqlx::Type<sqlx::Postgres>
96 + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
97 {
98 validate_batch_size(self.ids.len())?;
99 reject_duplicate_pks(&self.ids)?;
100 if self.ids.is_empty() {
101 return Ok(BatchResponse::from_results(vec![]));
102 }
103
104 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
106 query.push(self.descriptor.select_projection());
107 query.push(" FROM ").push(self.descriptor.table_name);
108 query.push(" WHERE ");
109 if let Some(col) = self.descriptor.soft_delete_column {
110 query.push(col).push(" IS NULL AND ");
111 }
112 query.push(self.descriptor.primary_key).push(" IN (");
113 for (index, id) in self.ids.iter().enumerate() {
114 if index > 0 {
115 query.push(", ");
116 }
117 query.push_bind(id.clone());
118 }
119 query.push(") AND ");
120 push_action_policy_query(
121 &mut query,
122 self.descriptor.read_allow_policies,
123 self.descriptor.read_deny_policies,
124 ctx,
125 );
126
127 let rows: Vec<M> = query
128 .build_query_as::<M>()
129 .fetch_all(self.runtime.pool())
130 .await
131 .map_err(|error| CoolError::Database(error.to_string()))?;
132
133 let mut by_pk: HashMap<PK, M> =
136 rows.into_iter().map(|m| (m.primary_key(), m)).collect();
137 let per_item: Vec<Result<M, CoolError>> = self
138 .ids
139 .into_iter()
140 .map(|id| {
141 by_pk
142 .remove(&id)
143 .ok_or_else(|| CoolError::NotFound("no row matched".to_owned()))
144 })
145 .collect();
146
147 Ok(BatchResponse::from_results(per_item))
148 }
149}
150
/// A pending "delete many rows by primary key" request.
///
/// Models with a soft-delete column are deleted with an `UPDATE` that stamps
/// that column with `NOW()`; all other models get a hard `DELETE`. The
/// response holds one entry per requested id, in request order.
#[derive(Debug, Clone)]
pub struct BatchDelete<'a, M: 'static, PK: 'static> {
    /// Runtime whose pool provides the transaction.
    pub(crate) runtime: &'a SqlxRuntime,
    /// Static description of the model: table, columns, delete policies.
    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
    /// Primary keys to delete, in reporting order.
    pub(crate) ids: Vec<PK>,
}

impl<'a, M: 'static, PK: 'static> BatchDelete<'a, M, PK> {
    /// Deletes (or soft-deletes) every requested row with one statement inside
    /// a single transaction, enqueues outbox/audit entries for each affected
    /// row, then commits.
    ///
    /// Ids that matched no row are reported as `NotFound` in their slot. Such an
    /// id was missing, already soft-deleted, or excluded by the delete policies.
    ///
    /// # Errors
    /// - `Validation` for an oversized batch or a repeated id.
    /// - `Database` if beginning the transaction, running the statement, or
    ///   committing fails.
    /// - Errors from the outbox/audit helpers are propagated as-is.
    pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
    where
        for<'r> M: Send
            + Unpin
            + sqlx::FromRow<'r, sqlx::postgres::PgRow>
            + ModelPrimaryKey<PK>
            + serde::Serialize,
        PK: Clone
            + Eq
            + Hash
            + Send
            + sqlx::Type<sqlx::Postgres>
            + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
    {
        validate_batch_size(self.ids.len())?;
        reject_duplicate_pks(&self.ids)?;
        if self.ids.is_empty() {
            return Ok(BatchResponse::from_results(vec![]));
        }

        let emits_event = self.descriptor.emits(ModelEventKind::Deleted);
        let audit_enabled = self.descriptor.audit_enabled;

        let mut tx = self
            .runtime
            .pool()
            .begin()
            .await
            .map_err(|error| CoolError::Database(error.to_string()))?;
        if emits_event {
            ensure_event_outbox_table(&mut *tx).await?;
        }
        if audit_enabled {
            // Runs on the pool, not on `tx`.
            ensure_audit_table(self.runtime.pool()).await?;
        }

        let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("");
        match self.descriptor.soft_delete_column {
            Some(col) => {
                // Soft delete: stamp the column (and bump the version, if any),
                // touching only rows that are not already soft-deleted.
                query.push("UPDATE ").push(self.descriptor.table_name);
                query.push(" SET ").push(col).push(" = NOW()");
                if let Some(version_col) = self.descriptor.version_column {
                    query
                        .push(", ")
                        .push(version_col)
                        .push(" = ")
                        .push(version_col)
                        .push(" + 1");
                }
                query.push(" WHERE ").push(col).push(" IS NULL AND ");
            }
            None => {
                query.push("DELETE FROM ").push(self.descriptor.table_name);
                query.push(" WHERE ");
            }
        }
        query.push(self.descriptor.primary_key).push(" IN (");
        for (index, id) in self.ids.iter().enumerate() {
            if index > 0 {
                query.push(", ");
            }
            query.push_bind(id.clone());
        }
        query.push(") AND ");
        push_action_policy_query(
            &mut query,
            self.descriptor.delete_allow_policies,
            self.descriptor.delete_deny_policies,
            ctx,
        );
        query
            .push(" RETURNING ")
            .push(self.descriptor.select_projection());

        let deleted: Vec<M> = query
            .build_query_as::<M>()
            .fetch_all(&mut *tx)
            .await
            .map_err(|error| CoolError::Database(error.to_string()))?;

        // Side records are written inside the same transaction as the delete.
        for record in &deleted {
            if emits_event {
                enqueue_event_outbox(
                    &mut *tx,
                    self.descriptor.schema_name,
                    ModelEventKind::Deleted,
                    record,
                )
                .await?;
            }
            if audit_enabled {
                // NOTE(review): on the soft-delete path `record` comes from
                // `UPDATE ... RETURNING`, which yields the post-update row (soft-delete
                // column set, version incremented). The "before" snapshot therefore
                // reflects the new state, not the pre-delete state. Confirm whether
                // that is intended.
                let before = serde_json::to_value(record).ok();
                let event = build_audit_event(
                    self.descriptor,
                    AuditOperation::Delete,
                    before,
                    None,
                    ctx,
                );
                enqueue_audit_event(&mut *tx, &event).await?;
            }
        }

        tx.commit()
            .await
            .map_err(|error| CoolError::Database(error.to_string()))?;

        if emits_event {
            // Best effort: the drain result is deliberately ignored.
            let _ = self.runtime.drain_event_outbox().await;
        }

        // Re-key the affected rows so each requested id can claim its own row.
        let mut by_pk: HashMap<PK, M> =
            deleted.into_iter().map(|m| (m.primary_key(), m)).collect();
        let per_item: Vec<Result<M, CoolError>> = self
            .ids
            .into_iter()
            .map(|id| {
                by_pk
                    .remove(&id)
                    .ok_or_else(|| CoolError::NotFound("no row matched".to_owned()))
            })
            .collect();

        Ok(BatchResponse::from_results(per_item))
    }
}
295
296#[derive(Debug, Clone)]
299pub struct BatchCreate<'a, M: 'static, PK: 'static, I> {
300 pub(crate) runtime: &'a SqlxRuntime,
301 pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
302 pub(crate) inputs: Vec<I>,
303}
304
305impl<'a, M: 'static, PK: 'static, I> BatchCreate<'a, M, PK, I>
306where
307 I: CreateModelInput<M> + Send,
308{
309 pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
310 where
311 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
312 {
313 validate_batch_size(self.inputs.len())?;
314 if self.inputs.is_empty() {
320 return Ok(BatchResponse::from_results(vec![]));
321 }
322
323 let emits_event = self.descriptor.emits(ModelEventKind::Created);
324 let audit_enabled = self.descriptor.audit_enabled;
325
326 let mut tx = self
327 .runtime
328 .pool()
329 .begin()
330 .await
331 .map_err(|error| CoolError::Database(error.to_string()))?;
332 if emits_event {
333 ensure_event_outbox_table(&mut *tx).await?;
334 }
335 if audit_enabled {
336 ensure_audit_table(self.runtime.pool()).await?;
337 }
338
339 let mut per_item: Vec<Result<M, CoolError>> = Vec::with_capacity(self.inputs.len());
340 for input in self.inputs {
341 let outcome = run_create_item(
342 &mut tx,
343 self.runtime.pool(),
344 self.descriptor,
345 input,
346 ctx,
347 emits_event,
348 audit_enabled,
349 )
350 .await?;
351 per_item.push(outcome);
352 }
353
354 tx.commit()
355 .await
356 .map_err(|error| CoolError::Database(error.to_string()))?;
357
358 if emits_event {
359 let _ = self.runtime.drain_event_outbox().await;
360 }
361
362 Ok(BatchResponse::from_results(per_item))
363 }
364}
365
/// Creates one batch item inside its own savepoint of `outer`.
///
/// Returns:
/// - `Ok(Ok(record))` when the row was inserted and the savepoint committed.
/// - `Ok(Err(e))` when this item failed (validation, policy, insert, outbox or
///   audit). In that case the savepoint was rolled back, so the outer
///   transaction can continue with the next item.
/// - `Err(e)` only when opening, committing or rolling back the savepoint
///   itself fails.
async fn run_create_item<'tx, M, PK, I>(
    outer: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
    policy_pool: &sqlx::PgPool,
    descriptor: &'static ModelDescriptor<M, PK>,
    input: I,
    ctx: &CoolContext,
    emits_event: bool,
    audit_enabled: bool,
) -> Result<Result<M, CoolError>, CoolError>
where
    I: CreateModelInput<M>,
    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
{
    // `begin` on an open transaction opens a nested (savepoint) transaction.
    let mut item_tx = outer
        .begin()
        .await
        .map_err(|error| CoolError::Database(error.to_string()))?;

    // All per-item work happens here so any `?` lands in `inner` rather than
    // escaping the function; the savepoint is resolved below.
    let inner: Result<M, CoolError> = async {
        input.validate()?;
        let mut values = apply_create_defaults(input.sql_values(), descriptor.create_defaults, ctx)?;
        // Versioned models start at version 0 unless the input supplied one.
        if let Some(version_col) = descriptor.version_column
            && find_column_value(&values, version_col).is_none()
        {
            values.push(crate::SqlColumnValue {
                column: version_col,
                value: crate::SqlValue::Int(0),
            });
        }
        if values.is_empty() {
            return Err(CoolError::Validation(
                "create input must contain at least one column".to_owned(),
            ));
        }
        // Create policies are evaluated on `policy_pool`, outside the batch
        // transaction.
        if !evaluate_create_policies(
            policy_pool,
            descriptor.create_allow_policies,
            descriptor.create_deny_policies,
            &values,
            ctx,
        )
        .await?
        {
            return Err(CoolError::Forbidden(
                "create policy denied this operation".to_owned(),
            ));
        }

        let record = insert_one_into_savepoint::<M, PK>(&mut item_tx, descriptor, &values).await?;

        if emits_event {
            enqueue_event_outbox(
                &mut *item_tx,
                descriptor.schema_name,
                ModelEventKind::Created,
                &record,
            )
            .await?;
        }
        if audit_enabled {
            let after = serde_json::to_value(&record).ok();
            let event =
                build_audit_event(descriptor, AuditOperation::Create, None, after, ctx);
            enqueue_audit_event(&mut *item_tx, &event).await?;
        }
        Ok(record)
    }
    .await;

    match inner {
        Ok(record) => {
            item_tx
                .commit()
                .await
                .map_err(|error| CoolError::Database(error.to_string()))?;
            Ok(Ok(record))
        }
        Err(item_err) => {
            // Undo only this item's work; the item error is reported, not propagated.
            item_tx
                .rollback()
                .await
                .map_err(|error| CoolError::Database(error.to_string()))?;
            Ok(Err(item_err))
        }
    }
}
458
459async fn insert_one_into_savepoint<'tx, M, PK>(
460 executor: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
461 descriptor: &'static ModelDescriptor<M, PK>,
462 values: &[crate::SqlColumnValue],
463) -> Result<M, CoolError>
464where
465 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
466{
467 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("INSERT INTO ");
468 query.push(descriptor.table_name).push(" (");
469 for (index, value) in values.iter().enumerate() {
470 if index > 0 {
471 query.push(", ");
472 }
473 query.push(value.column);
474 }
475 query.push(") VALUES (");
476 for (index, value) in values.iter().enumerate() {
477 if index > 0 {
478 query.push(", ");
479 }
480 push_bind_value(&mut query, &value.value);
481 }
482 query
483 .push(") RETURNING ")
484 .push(descriptor.select_projection());
485
486 query
487 .build_query_as::<M>()
488 .fetch_one(&mut **executor)
489 .await
490 .map_err(|error| classify_insert_error(error))
491}
492
493fn classify_insert_error(error: sqlx::Error) -> CoolError {
497 if let sqlx::Error::Database(db_err) = &error
498 && let Some(code) = db_err.code()
499 && code == "23505"
500 {
501 return CoolError::Conflict(db_err.message().to_owned());
502 }
503 CoolError::Database(error.to_string())
504}
505
/// One entry of a batch update: `(primary key, update input, expected version)`.
///
/// The third element is the If-Match version used for optimistic concurrency.
/// `run_update_item` rejects an entry with `PreconditionFailed` when the model
/// has a version column and this is `None`.
pub type BatchUpdateItem<PK, I> = (PK, I, Option<i64>);
510
511#[derive(Debug, Clone)]
512pub struct BatchUpdate<'a, M: 'static, PK: 'static, I> {
513 pub(crate) runtime: &'a SqlxRuntime,
514 pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
515 pub(crate) items: Vec<BatchUpdateItem<PK, I>>,
516}
517
518impl<'a, M: 'static, PK: 'static, I> BatchUpdate<'a, M, PK, I>
519where
520 I: UpdateModelInput<M> + Send,
521{
522 pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
523 where
524 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
525 PK: Clone
526 + Eq
527 + Hash
528 + Send
529 + sqlx::Type<sqlx::Postgres>
530 + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
531 {
532 validate_batch_size(self.items.len())?;
533 let ids: Vec<PK> = self.items.iter().map(|(id, _, _)| id.clone()).collect();
534 reject_duplicate_pks(&ids)?;
535 if self.items.is_empty() {
536 return Ok(BatchResponse::from_results(vec![]));
537 }
538
539 let emits_event = self.descriptor.emits(ModelEventKind::Updated);
540 let audit_enabled = self.descriptor.audit_enabled;
541
542 let mut tx = self
543 .runtime
544 .pool()
545 .begin()
546 .await
547 .map_err(|error| CoolError::Database(error.to_string()))?;
548 if emits_event {
549 ensure_event_outbox_table(&mut *tx).await?;
550 }
551 if audit_enabled {
552 ensure_audit_table(self.runtime.pool()).await?;
553 }
554
555 let mut per_item: Vec<Result<M, CoolError>> = Vec::with_capacity(self.items.len());
556 for (id, input, if_match) in self.items {
557 let outcome = run_update_item(
558 &mut tx,
559 self.descriptor,
560 id,
561 input,
562 if_match,
563 ctx,
564 emits_event,
565 audit_enabled,
566 )
567 .await?;
568 per_item.push(outcome);
569 }
570
571 tx.commit()
572 .await
573 .map_err(|error| CoolError::Database(error.to_string()))?;
574
575 if emits_event {
576 let _ = self.runtime.drain_event_outbox().await;
577 }
578
579 Ok(BatchResponse::from_results(per_item))
580 }
581}
582
583async fn run_update_item<'tx, M, PK, I>(
584 outer: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
585 descriptor: &'static ModelDescriptor<M, PK>,
586 id: PK,
587 input: I,
588 if_match: Option<i64>,
589 ctx: &CoolContext,
590 emits_event: bool,
591 audit_enabled: bool,
592) -> Result<Result<M, CoolError>, CoolError>
593where
594 I: UpdateModelInput<M>,
595 PK: Clone + Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
596 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
597{
598 let mut item_tx = outer
599 .begin()
600 .await
601 .map_err(|error| CoolError::Database(error.to_string()))?;
602
603 let inner: Result<M, CoolError> = async {
604 if descriptor.version_column.is_some() && if_match.is_none() {
605 return Err(CoolError::PreconditionFailed(
606 "If-Match required for versioned model".to_owned(),
607 ));
608 }
609 input.validate()?;
610 let values = input.sql_values();
611 if values.is_empty() {
612 return Err(CoolError::Validation(
613 "update input must contain at least one changed column".to_owned(),
614 ));
615 }
616
617 let before = if audit_enabled {
619 fetch_for_audit(&mut *item_tx, descriptor, id.clone()).await?
620 } else {
621 None
622 };
623
624 let record = update_one_in_savepoint(
625 &mut item_tx,
626 descriptor,
627 id,
628 &values,
629 ctx,
630 if_match,
631 )
632 .await?;
633
634 if emits_event {
635 enqueue_event_outbox(
636 &mut *item_tx,
637 descriptor.schema_name,
638 ModelEventKind::Updated,
639 &record,
640 )
641 .await?;
642 }
643 if audit_enabled {
644 let before_snapshot = before.as_ref().and_then(|m| serde_json::to_value(m).ok());
645 let after = serde_json::to_value(&record).ok();
646 let event = build_audit_event(
647 descriptor,
648 AuditOperation::Update,
649 before_snapshot,
650 after,
651 ctx,
652 );
653 enqueue_audit_event(&mut *item_tx, &event).await?;
654 }
655 Ok(record)
656 }
657 .await;
658
659 match inner {
660 Ok(record) => {
661 item_tx
662 .commit()
663 .await
664 .map_err(|error| CoolError::Database(error.to_string()))?;
665 Ok(Ok(record))
666 }
667 Err(item_err) => {
668 item_tx
669 .rollback()
670 .await
671 .map_err(|error| CoolError::Database(error.to_string()))?;
672 Ok(Err(item_err))
673 }
674 }
675}
676
677async fn update_one_in_savepoint<'tx, M, PK>(
678 executor: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
679 descriptor: &'static ModelDescriptor<M, PK>,
680 id: PK,
681 values: &[crate::SqlColumnValue],
682 ctx: &CoolContext,
683 if_match: Option<i64>,
684) -> Result<M, CoolError>
685where
686 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
687 PK: Clone + Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
688{
689 let version_column = descriptor.version_column;
690 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("UPDATE ");
691 query.push(descriptor.table_name).push(" SET ");
692 for (index, value) in values.iter().enumerate() {
693 if index > 0 {
694 query.push(", ");
695 }
696 query.push(value.column).push(" = ");
697 push_bind_value(&mut query, &value.value);
698 }
699 if let Some(version_col) = version_column {
700 query
701 .push(", ")
702 .push(version_col)
703 .push(" = ")
704 .push(version_col)
705 .push(" + 1");
706 }
707 query
708 .push(" WHERE ")
709 .push(descriptor.primary_key)
710 .push(" = ");
711 query.push_bind(id);
712 if let (Some(version_col), Some(expected)) = (version_column, if_match) {
713 query.push(" AND ").push(version_col).push(" = ");
714 query.push_bind(expected);
715 }
716 query.push(" AND ");
717 push_action_policy_query(
718 &mut query,
719 descriptor.update_allow_policies,
720 descriptor.update_deny_policies,
721 ctx,
722 );
723 query
724 .push(" RETURNING ")
725 .push(descriptor.select_projection());
726
727 let outcome = query
728 .build_query_as::<M>()
729 .fetch_optional(&mut **executor)
730 .await
731 .map_err(|error| CoolError::Database(error.to_string()))?;
732 match outcome {
733 Some(record) => Ok(record),
734 None => {
735 if if_match.is_some() {
741 Err(CoolError::PreconditionFailed(
742 "version mismatch or row missing".to_owned(),
743 ))
744 } else {
745 Err(CoolError::Forbidden(
746 "update policy denied or row missing".to_owned(),
747 ))
748 }
749 }
750 }
751}
752
753#[derive(Debug, Clone)]
756pub struct BatchUpsert<'a, M: 'static, PK: 'static, I> {
757 pub(crate) runtime: &'a SqlxRuntime,
758 pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
759 pub(crate) inputs: Vec<I>,
760}
761
762impl<'a, M: 'static, PK: 'static, I> BatchUpsert<'a, M, PK, I>
763where
764 I: UpsertModelInput<M>,
765{
766 pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
767 where
768 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
769 PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
770 {
771 validate_batch_size(self.inputs.len())?;
772 let pks: Vec<SqlValue> = self
776 .inputs
777 .iter()
778 .map(UpsertModelInput::primary_key_value)
779 .collect();
780 reject_duplicate_sql_values(&pks)?;
781 if self.inputs.is_empty() {
782 return Ok(BatchResponse::from_results(vec![]));
783 }
784
785 let emits_created = self.descriptor.emits(ModelEventKind::Created);
786 let emits_updated = self.descriptor.emits(ModelEventKind::Updated);
787 let audit_enabled = self.descriptor.audit_enabled;
788
789 let mut tx = self
790 .runtime
791 .pool()
792 .begin()
793 .await
794 .map_err(|error| CoolError::Database(error.to_string()))?;
795 if emits_created || emits_updated {
796 ensure_event_outbox_table(&mut *tx).await?;
797 }
798 if audit_enabled {
799 ensure_audit_table(self.runtime.pool()).await?;
800 }
801
802 let mut per_item: Vec<Result<M, CoolError>> = Vec::with_capacity(self.inputs.len());
803 for input in self.inputs {
804 let outcome = run_upsert_item(
805 &mut tx,
806 self.runtime.pool(),
807 self.descriptor,
808 input,
809 ctx,
810 emits_created,
811 emits_updated,
812 audit_enabled,
813 )
814 .await?;
815 per_item.push(outcome);
816 }
817
818 tx.commit()
819 .await
820 .map_err(|error| CoolError::Database(error.to_string()))?;
821
822 if emits_created || emits_updated {
823 let _ = self.runtime.drain_event_outbox().await;
824 }
825
826 Ok(BatchResponse::from_results(per_item))
827 }
828}
829
/// Upserts one batch item inside its own savepoint of `outer`.
///
/// The row is first read with `SELECT ... FOR UPDATE`:
/// - If no live row exists, the item is treated as an insert (`Created` event,
///   `Create` audit).
/// - Otherwise it is an update: the update policies must also pass (`Updated`
///   event, `Update` audit with a before-snapshot).
///
/// Create policies are evaluated against the insert values in both cases.
///
/// Returns:
/// - `Ok(Ok(record))` on success.
/// - `Ok(Err(e))` when this item failed and its savepoint was rolled back.
/// - `Err(e)` only when opening, committing or rolling back the savepoint
///   fails.
// Eight parameters exceed clippy's default `too_many_arguments` threshold (7).
#[allow(clippy::too_many_arguments)]
async fn run_upsert_item<'tx, M, PK, I>(
    outer: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
    policy_pool: &sqlx::PgPool,
    descriptor: &'static ModelDescriptor<M, PK>,
    input: I,
    ctx: &CoolContext,
    emits_created: bool,
    emits_updated: bool,
    audit_enabled: bool,
) -> Result<Result<M, CoolError>, CoolError>
where
    I: UpsertModelInput<M>,
    PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
{
    // `begin` on an open transaction opens a nested (savepoint) transaction.
    let mut item_tx = outer
        .begin()
        .await
        .map_err(|error| CoolError::Database(error.to_string()))?;

    // All per-item work happens here so any `?` lands in `inner`; the
    // savepoint is resolved below.
    let inner: Result<M, CoolError> = async {
        input.validate()?;
        let mut insert_values =
            apply_create_defaults(input.sql_values(), descriptor.create_defaults, ctx)?;
        // Versioned models start at version 0 unless the input supplied one.
        if let Some(version_col) = descriptor.version_column
            && find_column_value(&insert_values, version_col).is_none()
        {
            insert_values.push(crate::SqlColumnValue {
                column: version_col,
                value: crate::SqlValue::Int(0),
            });
        }
        if insert_values.is_empty() {
            return Err(CoolError::Validation(
                "upsert input must contain at least one column".to_owned(),
            ));
        }
        // Checked unconditionally, i.e. also when the row turns out to exist.
        if !evaluate_create_policies(
            policy_pool,
            descriptor.create_allow_policies,
            descriptor.create_deny_policies,
            &insert_values,
            ctx,
        )
        .await?
        {
            return Err(CoolError::Forbidden(
                "create policy denied this upsert".to_owned(),
            ));
        }

        // Lock the existing live row (if any) to decide insert vs update.
        let pk_value = input.primary_key_value();
        let before_record =
            select_for_update_by_pk_value(&mut item_tx, descriptor, &pk_value).await?;
        let inserted = before_record.is_none();

        // The update-policy check runs on `policy_pool`, outside the batch
        // transaction.
        if !inserted
            && !row_passes_update_policy(policy_pool, descriptor, &pk_value, ctx).await?
        {
            return Err(CoolError::Forbidden(
                "update policy denied this upsert".to_owned(),
            ));
        }

        let before_snapshot = if !inserted && audit_enabled {
            before_record
                .as_ref()
                .and_then(|m| serde_json::to_value(m).ok())
        } else {
            None
        };

        let record =
            upsert_one_in_savepoint::<M, PK>(&mut item_tx, descriptor, &insert_values).await?;

        let event_kind = if inserted {
            ModelEventKind::Created
        } else {
            ModelEventKind::Updated
        };
        let audit_op = if inserted {
            AuditOperation::Create
        } else {
            AuditOperation::Update
        };
        let emits_this_event = if inserted { emits_created } else { emits_updated };

        if emits_this_event {
            enqueue_event_outbox(&mut *item_tx, descriptor.schema_name, event_kind, &record)
                .await?;
        }
        if audit_enabled {
            let after = serde_json::to_value(&record).ok();
            let event = build_audit_event(descriptor, audit_op, before_snapshot, after, ctx);
            enqueue_audit_event(&mut *item_tx, &event).await?;
        }

        Ok(record)
    }
    .await;

    match inner {
        Ok(record) => {
            item_tx
                .commit()
                .await
                .map_err(|error| CoolError::Database(error.to_string()))?;
            Ok(Ok(record))
        }
        Err(item_err) => {
            // Undo only this item's work; the item error is reported, not propagated.
            item_tx
                .rollback()
                .await
                .map_err(|error| CoolError::Database(error.to_string()))?;
            Ok(Err(item_err))
        }
    }
}
951
952async fn select_for_update_by_pk_value<'tx, M, PK>(
953 executor: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
954 descriptor: &'static ModelDescriptor<M, PK>,
955 pk_value: &SqlValue,
956) -> Result<Option<M>, CoolError>
957where
958 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
959{
960 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
961 query.push(descriptor.select_projection());
962 query.push(" FROM ").push(descriptor.table_name);
963 query
964 .push(" WHERE ")
965 .push(descriptor.primary_key)
966 .push(" = ");
967 push_bind_value(&mut query, pk_value);
968 if let Some(col) = descriptor.soft_delete_column {
969 query.push(" AND ").push(col).push(" IS NULL");
970 }
971 query.push(" FOR UPDATE");
972
973 query
974 .build_query_as::<M>()
975 .fetch_optional(&mut **executor)
976 .await
977 .map_err(|error| CoolError::Database(error.to_string()))
978}
979
980async fn row_passes_update_policy<M, PK>(
981 policy_pool: &sqlx::PgPool,
982 descriptor: &'static ModelDescriptor<M, PK>,
983 pk_value: &SqlValue,
984 ctx: &CoolContext,
985) -> Result<bool, CoolError> {
986 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT 1 FROM ");
987 query.push(descriptor.table_name);
988 query
989 .push(" WHERE ")
990 .push(descriptor.primary_key)
991 .push(" = ");
992 push_bind_value(&mut query, pk_value);
993 query.push(" AND ");
994 push_action_policy_query(
995 &mut query,
996 descriptor.update_allow_policies,
997 descriptor.update_deny_policies,
998 ctx,
999 );
1000
1001 let row: Option<(i32,)> = query
1002 .build_query_as::<(i32,)>()
1003 .fetch_optional(policy_pool)
1004 .await
1005 .map_err(|error| CoolError::Database(error.to_string()))?;
1006 Ok(row.is_some())
1007}
1008
1009async fn upsert_one_in_savepoint<'tx, M, PK>(
1010 executor: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
1011 descriptor: &'static ModelDescriptor<M, PK>,
1012 insert_values: &[crate::SqlColumnValue],
1013) -> Result<M, CoolError>
1014where
1015 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
1016{
1017 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("INSERT INTO ");
1018 query.push(descriptor.table_name).push(" (");
1019 for (index, value) in insert_values.iter().enumerate() {
1020 if index > 0 {
1021 query.push(", ");
1022 }
1023 query.push(value.column);
1024 }
1025 query.push(") VALUES (");
1026 for (index, value) in insert_values.iter().enumerate() {
1027 if index > 0 {
1028 query.push(", ");
1029 }
1030 push_bind_value(&mut query, &value.value);
1031 }
1032 query
1033 .push(") ON CONFLICT (")
1034 .push(descriptor.primary_key)
1035 .push(") DO UPDATE SET ");
1036
1037 if descriptor.upsert_update_columns.is_empty() {
1038 query.push(descriptor.primary_key);
1039 query.push(" = EXCLUDED.").push(descriptor.primary_key);
1040 } else {
1041 for (index, column) in descriptor.upsert_update_columns.iter().enumerate() {
1042 if index > 0 {
1043 query.push(", ");
1044 }
1045 query.push(*column).push(" = EXCLUDED.").push(*column);
1046 }
1047 }
1048 if let Some(version_col) = descriptor.version_column {
1049 query
1050 .push(", ")
1051 .push(version_col)
1052 .push(" = ")
1053 .push(descriptor.table_name)
1054 .push(".")
1055 .push(version_col)
1056 .push(" + 1");
1057 }
1058
1059 query
1060 .push(" RETURNING ")
1061 .push(descriptor.select_projection());
1062
1063 query
1064 .build_query_as::<M>()
1065 .fetch_one(&mut **executor)
1066 .await
1067 .map_err(|error| CoolError::Database(error.to_string()))
1068}