1use cratestack_core::{AuditOperation, CoolContext, CoolError, ModelEventKind};
2
3use crate::{
4 CreateModelInput, ModelDescriptor, SqlxRuntime, UpdateModelInput,
5 audit::{build_audit_event, enqueue_audit_event, ensure_audit_table, fetch_for_audit},
6 descriptor::{enqueue_event_outbox, ensure_event_outbox_table},
7};
8
9use super::support::{
10 apply_create_defaults, evaluate_create_policies, find_column_value, push_action_policy_query,
11 push_bind_value,
12};
13
/// Renders the `UPDATE ... RETURNING` statement an update would issue, for inspection.
///
/// Each entry of `columns` becomes `column = $n` with placeholders numbered from 1 in
/// order. The primary key is bound right after the assignments. When `version_column`
/// is set, the SQL also bumps it by one and adds a `version = $m` guard bound after the
/// primary key. Policy predicates appended at execution time are not rendered.
///
/// An empty `columns` list on a versioned model yields just the version bump
/// (`SET v = v + 1`) rather than a dangling `SET , v = v + 1`.
pub fn render_update_preview_sql(
    table_name: &str,
    primary_key: &str,
    version_column: Option<&str>,
    columns: &[&str],
    select_projection: &str,
) -> String {
    let mut set_clauses: Vec<String> = columns
        .iter()
        .enumerate()
        .map(|(index, column)| format!("{column} = ${}", index + 1))
        .collect();
    // The primary key placeholder always follows the assignment placeholders.
    let id_placeholder = columns.len() + 1;

    match version_column {
        Some(version_col) => {
            set_clauses.push(format!("{version_col} = {version_col} + 1"));
            format!(
                "UPDATE {table_name} SET {} WHERE {primary_key} = ${id_placeholder} AND {version_col} = ${} RETURNING {select_projection}",
                set_clauses.join(", "),
                id_placeholder + 1,
            )
        }
        None => format!(
            "UPDATE {table_name} SET {} WHERE {primary_key} = ${id_placeholder} RETURNING {select_projection}",
            set_clauses.join(", "),
        ),
    }
}
53
54#[derive(Debug, Clone)]
55pub struct CreateRecord<'a, M: 'static, PK: 'static, I> {
56 pub(crate) runtime: &'a SqlxRuntime,
57 pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
58 pub(crate) input: I,
59}
60
61impl<'a, M: 'static, PK: 'static, I> CreateRecord<'a, M, PK, I>
62where
63 I: CreateModelInput<M>,
64{
65 pub fn preview_sql(&self) -> String {
66 let values = self.input.sql_values();
67 let placeholders = (1..=values.len())
68 .map(|index| format!("${index}"))
69 .collect::<Vec<_>>()
70 .join(", ");
71 let columns = values
72 .iter()
73 .map(|value| value.column)
74 .collect::<Vec<_>>()
75 .join(", ");
76
77 format!(
78 "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
79 self.descriptor.table_name,
80 columns,
81 placeholders,
82 self.descriptor.select_projection(),
83 )
84 }
85
86 pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
87 where
88 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
89 {
90 let emits_event = self.descriptor.emits(ModelEventKind::Created);
91 let audit_enabled = self.descriptor.audit_enabled;
92 let needs_tx = emits_event || audit_enabled;
93 let record = if needs_tx {
94 let mut tx = self
95 .runtime
96 .pool()
97 .begin()
98 .await
99 .map_err(|error| CoolError::Database(error.to_string()))?;
100 if emits_event {
101 ensure_event_outbox_table(&mut *tx).await?;
102 }
103 if audit_enabled {
104 ensure_audit_table(self.runtime.pool()).await?;
105 }
106 let record = create_record_with_executor(
107 &mut *tx,
108 self.runtime.pool(),
109 self.descriptor,
110 self.input,
111 ctx,
112 )
113 .await?;
114 if emits_event {
115 enqueue_event_outbox(
116 &mut *tx,
117 self.descriptor.schema_name,
118 ModelEventKind::Created,
119 &record,
120 )
121 .await?;
122 }
123 if audit_enabled {
124 let after = serde_json::to_value(&record).ok();
125 let event =
126 build_audit_event(self.descriptor, AuditOperation::Create, None, after, ctx);
127 enqueue_audit_event(&mut *tx, &event).await?;
128 }
129 tx.commit()
130 .await
131 .map_err(|error| CoolError::Database(error.to_string()))?;
132 record
133 } else {
134 create_record_with_executor(
135 self.runtime.pool(),
136 self.runtime.pool(),
137 self.descriptor,
138 self.input,
139 ctx,
140 )
141 .await?
142 };
143
144 if emits_event {
145 let _ = self.runtime.drain_event_outbox().await;
146 }
147
148 Ok(record)
149 }
150}
151
152#[derive(Debug, Clone)]
153pub struct UpdateRecord<'a, M: 'static, PK: 'static> {
154 pub(crate) runtime: &'a SqlxRuntime,
155 pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
156 pub(crate) id: PK,
157}
158
159impl<'a, M: 'static, PK: 'static> UpdateRecord<'a, M, PK> {
160 pub fn set<I>(self, input: I) -> UpdateRecordSet<'a, M, PK, I> {
161 UpdateRecordSet {
162 runtime: self.runtime,
163 descriptor: self.descriptor,
164 id: self.id,
165 input,
166 if_match: None,
167 }
168 }
169}
170
171#[derive(Debug, Clone)]
172pub struct UpdateRecordSet<'a, M: 'static, PK: 'static, I> {
173 pub(crate) runtime: &'a SqlxRuntime,
174 pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
175 pub(crate) id: PK,
176 pub(crate) input: I,
177 pub(crate) if_match: Option<i64>,
178}
179
180impl<'a, M: 'static, PK: 'static, I> UpdateRecordSet<'a, M, PK, I>
181where
182 I: UpdateModelInput<M>,
183{
184 pub fn if_match(mut self, expected: i64) -> Self {
188 self.if_match = Some(expected);
189 self
190 }
191
192 pub fn preview_sql(&self) -> String {
193 let values = self.input.sql_values();
194 let columns: Vec<&str> = values.iter().map(|v| v.column).collect();
195 render_update_preview_sql(
196 self.descriptor.table_name,
197 self.descriptor.primary_key,
198 self.descriptor.version_column,
199 &columns,
200 &self.descriptor.select_projection(),
201 )
202 }
203
204 pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
205 where
206 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
207 PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
208 {
209 if self.descriptor.version_column.is_some() && self.if_match.is_none() {
210 return Err(CoolError::PreconditionFailed(
211 "If-Match header required for versioned model".to_owned(),
212 ));
213 }
214 let emits_event = self.descriptor.emits(ModelEventKind::Updated);
215 let audit_enabled = self.descriptor.audit_enabled;
216 let needs_tx = emits_event || audit_enabled;
217 let record = if needs_tx {
218 let mut tx = self
219 .runtime
220 .pool()
221 .begin()
222 .await
223 .map_err(|error| CoolError::Database(error.to_string()))?;
224 if emits_event {
225 ensure_event_outbox_table(&mut *tx).await?;
226 }
227 if audit_enabled {
228 ensure_audit_table(self.runtime.pool()).await?;
229 }
230 let before_record = if audit_enabled {
233 fetch_for_audit(&mut *tx, self.descriptor, self.id.clone()).await?
234 } else {
235 None
236 };
237 let before_snapshot = before_record
238 .as_ref()
239 .and_then(|m| serde_json::to_value(m).ok());
240 let record = update_record_with_executor(
241 &mut *tx,
242 self.runtime.pool(),
243 self.descriptor,
244 self.id,
245 self.input,
246 ctx,
247 self.if_match,
248 )
249 .await?;
250 if emits_event {
251 enqueue_event_outbox(
252 &mut *tx,
253 self.descriptor.schema_name,
254 ModelEventKind::Updated,
255 &record,
256 )
257 .await?;
258 }
259 if audit_enabled {
260 let after = serde_json::to_value(&record).ok();
261 let event = build_audit_event(
262 self.descriptor,
263 AuditOperation::Update,
264 before_snapshot,
265 after,
266 ctx,
267 );
268 enqueue_audit_event(&mut *tx, &event).await?;
269 }
270 tx.commit()
271 .await
272 .map_err(|error| CoolError::Database(error.to_string()))?;
273 record
274 } else {
275 update_record_with_executor(
276 self.runtime.pool(),
277 self.runtime.pool(),
278 self.descriptor,
279 self.id,
280 self.input,
281 ctx,
282 self.if_match,
283 )
284 .await?
285 };
286
287 if emits_event {
288 let _ = self.runtime.drain_event_outbox().await;
289 }
290
291 Ok(record)
292 }
293}
294
295#[derive(Debug, Clone)]
296pub struct DeleteRecord<'a, M: 'static, PK: 'static> {
297 pub(crate) runtime: &'a SqlxRuntime,
298 pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
299 pub(crate) id: PK,
300}
301
302impl<'a, M: 'static, PK: 'static> DeleteRecord<'a, M, PK> {
303 pub fn preview_sql(&self) -> String {
304 format!(
305 "DELETE FROM {} WHERE {} = $1 RETURNING {}",
306 self.descriptor.table_name,
307 self.descriptor.primary_key,
308 self.descriptor.select_projection(),
309 )
310 }
311
312 pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
313 where
314 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
315 PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
316 {
317 let emits_event = self.descriptor.emits(ModelEventKind::Deleted);
318 let audit_enabled = self.descriptor.audit_enabled;
319 let needs_tx = emits_event || audit_enabled;
320 let record = if needs_tx {
321 let mut tx = self
322 .runtime
323 .pool()
324 .begin()
325 .await
326 .map_err(|error| CoolError::Database(error.to_string()))?;
327 if emits_event {
328 ensure_event_outbox_table(&mut *tx).await?;
329 }
330 if audit_enabled {
331 ensure_audit_table(self.runtime.pool()).await?;
332 }
333
334 let record = delete_returning_record(&mut *tx, self.descriptor, self.id, ctx).await?;
335 if emits_event {
336 enqueue_event_outbox(
337 &mut *tx,
338 self.descriptor.schema_name,
339 ModelEventKind::Deleted,
340 &record,
341 )
342 .await?;
343 }
344 if audit_enabled {
345 let before = serde_json::to_value(&record).ok();
348 let event =
349 build_audit_event(self.descriptor, AuditOperation::Delete, before, None, ctx);
350 enqueue_audit_event(&mut *tx, &event).await?;
351 }
352 tx.commit()
353 .await
354 .map_err(|error| CoolError::Database(error.to_string()))?;
355 record
356 } else {
357 delete_returning_record(self.runtime.pool(), self.descriptor, self.id, ctx).await?
358 };
359
360 if emits_event {
361 let _ = self.runtime.drain_event_outbox().await;
362 }
363
364 Ok(record)
365 }
366}
367
368pub async fn create_record_with_executor<'e, E, M, PK, I>(
369 executor: E,
370 policy_pool: &sqlx::PgPool,
371 descriptor: &'static ModelDescriptor<M, PK>,
372 input: I,
373 ctx: &CoolContext,
374) -> Result<M, CoolError>
375where
376 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
377 I: CreateModelInput<M>,
378 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
379{
380 input.validate()?;
381 let mut values = apply_create_defaults(input.sql_values(), descriptor.create_defaults, ctx)?;
382 if let Some(version_col) = descriptor.version_column
390 && find_column_value(&values, version_col).is_none()
391 {
392 values.push(crate::SqlColumnValue {
393 column: version_col,
394 value: crate::SqlValue::Int(0),
395 });
396 }
397 if values.is_empty() {
398 return Err(CoolError::Validation(
399 "create input must contain at least one column".to_owned(),
400 ));
401 }
402 if !evaluate_create_policies(
403 policy_pool,
404 descriptor.create_allow_policies,
405 descriptor.create_deny_policies,
406 &values,
407 ctx,
408 )
409 .await?
410 {
411 return Err(CoolError::Forbidden(
412 "create policy denied this operation".to_owned(),
413 ));
414 }
415
416 insert_returning_record(executor, descriptor, &values).await
417}
418
419pub async fn update_record_with_executor<'e, E, M, PK, I>(
420 executor: E,
421 policy_pool: &sqlx::PgPool,
422 descriptor: &'static ModelDescriptor<M, PK>,
423 id: PK,
424 input: I,
425 ctx: &CoolContext,
426 if_match: Option<i64>,
427) -> Result<M, CoolError>
428where
429 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
430 I: UpdateModelInput<M>,
431 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
432 PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
433{
434 input.validate()?;
435 let values = input.sql_values();
436 if values.is_empty() {
437 return Err(CoolError::Validation(
438 "update input must contain at least one changed column".to_owned(),
439 ));
440 }
441
442 update_returning_record(
443 executor,
444 policy_pool,
445 descriptor,
446 id,
447 &values,
448 ctx,
449 if_match,
450 )
451 .await
452}
453
454async fn insert_returning_record<'e, E, M, PK>(
455 executor: E,
456 descriptor: &'static ModelDescriptor<M, PK>,
457 values: &[crate::SqlColumnValue],
458) -> Result<M, CoolError>
459where
460 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
461 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
462{
463 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("INSERT INTO ");
464 query.push(descriptor.table_name).push(" (");
465 for (index, value) in values.iter().enumerate() {
466 if index > 0 {
467 query.push(", ");
468 }
469 query.push(value.column);
470 }
471 query.push(") VALUES (");
472 for (index, value) in values.iter().enumerate() {
473 if index > 0 {
474 query.push(", ");
475 }
476 push_bind_value(&mut query, &value.value);
477 }
478 query
479 .push(") RETURNING ")
480 .push(descriptor.select_projection());
481
482 query
483 .build_query_as::<M>()
484 .fetch_one(executor)
485 .await
486 .map_err(|error| CoolError::Database(error.to_string()))
487}
488
489async fn update_returning_record<'e, E, M, PK>(
490 executor: E,
491 policy_pool: &sqlx::PgPool,
492 descriptor: &'static ModelDescriptor<M, PK>,
493 id: PK,
494 values: &[crate::SqlColumnValue],
495 ctx: &CoolContext,
496 if_match: Option<i64>,
497) -> Result<M, CoolError>
498where
499 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
500 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
501 PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
502{
503 let version_column = descriptor.version_column;
504 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("UPDATE ");
505 query.push(descriptor.table_name).push(" SET ");
506 for (index, value) in values.iter().enumerate() {
507 if index > 0 {
508 query.push(", ");
509 }
510 query.push(value.column).push(" = ");
511 push_bind_value(&mut query, &value.value);
512 }
513 if let Some(version_col) = version_column {
514 query
515 .push(", ")
516 .push(version_col)
517 .push(" = ")
518 .push(version_col)
519 .push(" + 1");
520 }
521 query
522 .push(" WHERE ")
523 .push(descriptor.primary_key)
524 .push(" = ");
525 let id_for_probe = id.clone();
526 query.push_bind(id);
527 if let (Some(version_col), Some(expected)) = (version_column, if_match) {
528 query.push(" AND ").push(version_col).push(" = ");
529 query.push_bind(expected);
530 }
531 query.push(" AND ");
532 push_action_policy_query(
533 &mut query,
534 descriptor.update_allow_policies,
535 descriptor.update_deny_policies,
536 ctx,
537 );
538 query
539 .push(" RETURNING ")
540 .push(descriptor.select_projection());
541
542 let outcome = query
543 .build_query_as::<M>()
544 .fetch_optional(executor)
545 .await
546 .map_err(|error| CoolError::Database(error.to_string()))?;
547 match outcome {
548 Some(record) => Ok(record),
549 None => {
550 if let (Some(version_col), Some(expected)) = (version_column, if_match) {
556 if let Some(current) =
557 probe_current_version(policy_pool, descriptor, id_for_probe, version_col, ctx)
558 .await?
559 {
560 if current != expected {
561 return Err(CoolError::PreconditionFailed(format!(
562 "version mismatch: expected {expected}, found {current}",
563 )));
564 }
565 }
566 }
567 Err(CoolError::Forbidden(
568 "update policy denied this operation".to_owned(),
569 ))
570 }
571 }
572}
573
574async fn probe_current_version<M, PK>(
579 policy_pool: &sqlx::PgPool,
580 descriptor: &'static ModelDescriptor<M, PK>,
581 id: PK,
582 version_col: &'static str,
583 ctx: &CoolContext,
584) -> Result<Option<i64>, CoolError>
585where
586 PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
587{
588 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
589 query.push(version_col);
590 query.push(" FROM ").push(descriptor.table_name);
591 query
592 .push(" WHERE ")
593 .push(descriptor.primary_key)
594 .push(" = ");
595 query.push_bind(id);
596 query.push(" AND ");
597 push_action_policy_query(
598 &mut query,
599 descriptor.read_allow_policies,
600 descriptor.read_deny_policies,
601 ctx,
602 );
603
604 let row: Option<(i64,)> = query
605 .build_query_as::<(i64,)>()
606 .fetch_optional(policy_pool)
607 .await
608 .map_err(|error| CoolError::Database(error.to_string()))?;
609 Ok(row.map(|(v,)| v))
610}
611
612async fn delete_returning_record<'e, E, M, PK>(
613 executor: E,
614 descriptor: &'static ModelDescriptor<M, PK>,
615 id: PK,
616 ctx: &CoolContext,
617) -> Result<M, CoolError>
618where
619 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
620 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
621 PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
622{
623 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("");
624 match descriptor.soft_delete_column {
625 Some(col) => {
626 query.push("UPDATE ").push(descriptor.table_name);
629 query.push(" SET ").push(col).push(" = NOW()");
630 if let Some(version_col) = descriptor.version_column {
631 query
632 .push(", ")
633 .push(version_col)
634 .push(" = ")
635 .push(version_col)
636 .push(" + 1");
637 }
638 query.push(" WHERE ");
639 query.push(col).push(" IS NULL AND ");
640 query.push(descriptor.primary_key).push(" = ");
641 query.push_bind(id);
642 }
643 None => {
644 query.push("DELETE FROM ").push(descriptor.table_name);
645 query.push(" WHERE ");
646 query.push(descriptor.primary_key).push(" = ");
647 query.push_bind(id);
648 }
649 }
650 query.push(" AND ");
651 push_action_policy_query(
652 &mut query,
653 descriptor.delete_allow_policies,
654 descriptor.delete_deny_policies,
655 ctx,
656 );
657 query
658 .push(" RETURNING ")
659 .push(descriptor.select_projection());
660
661 query
662 .build_query_as::<M>()
663 .fetch_optional(executor)
664 .await
665 .map_err(|error| CoolError::Database(error.to_string()))?
666 .ok_or_else(|| CoolError::Forbidden("delete policy denied this operation".to_owned()))
667}