Skip to main content

cratestack_sqlx/query/
write.rs

1use cratestack_core::{AuditOperation, CoolContext, CoolError, ModelEventKind};
2
3use crate::{
4    CreateModelInput, ModelDescriptor, SqlxRuntime, UpdateModelInput,
5    audit::{build_audit_event, enqueue_audit_event, ensure_audit_table, fetch_for_audit},
6    descriptor::{enqueue_event_outbox, ensure_event_outbox_table},
7};
8
9use super::support::{
10    apply_create_defaults, evaluate_create_policies, find_column_value, push_action_policy_query,
11    push_bind_value,
12};
13
/// Render the SQL string for an update. Pure helper, no I/O — separated
/// so the version-aware branch can be unit-tested without a runtime.
///
/// Each entry of `columns` becomes a `SET` assignment bound to `$1..=$n`
/// (in order), and the primary key is bound to `$n + 1`. When
/// `version_column` is set, the statement also bumps the version
/// (`v = v + 1`) and guards on the expected version bound to `$n + 2`.
/// The policy predicate that the runtime appends is not rendered.
///
/// The version bump is treated as just another assignment. As a result, an
/// empty `columns` slice on a versioned model renders `SET v = v + 1` rather
/// than the malformed `SET , v = v + 1`.
pub fn render_update_preview_sql(
    table_name: &str,
    primary_key: &str,
    version_column: Option<&str>,
    columns: &[&str],
    select_projection: &str,
) -> String {
    let mut assignments: Vec<String> = columns
        .iter()
        .enumerate()
        .map(|(index, column)| format!("{column} = ${}", index + 1))
        .collect();
    // The primary key placeholder always follows the SET bindings.
    let id_param = columns.len() + 1;

    match version_column {
        Some(version_col) => {
            assignments.push(format!("{version_col} = {version_col} + 1"));
            format!(
                "UPDATE {table_name} SET {} WHERE {primary_key} = ${id_param} \
                 AND {version_col} = ${} RETURNING {select_projection}",
                assignments.join(", "),
                id_param + 1,
            )
        }
        None => format!(
            "UPDATE {table_name} SET {} WHERE {primary_key} = ${id_param} \
             RETURNING {select_projection}",
            assignments.join(", "),
        ),
    }
}
53
54#[derive(Debug, Clone)]
55pub struct CreateRecord<'a, M: 'static, PK: 'static, I> {
56    pub(crate) runtime: &'a SqlxRuntime,
57    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
58    pub(crate) input: I,
59}
60
61impl<'a, M: 'static, PK: 'static, I> CreateRecord<'a, M, PK, I>
62where
63    I: CreateModelInput<M>,
64{
65    pub fn preview_sql(&self) -> String {
66        let values = self.input.sql_values();
67        let placeholders = (1..=values.len())
68            .map(|index| format!("${index}"))
69            .collect::<Vec<_>>()
70            .join(", ");
71        let columns = values
72            .iter()
73            .map(|value| value.column)
74            .collect::<Vec<_>>()
75            .join(", ");
76
77        format!(
78            "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
79            self.descriptor.table_name,
80            columns,
81            placeholders,
82            self.descriptor.select_projection(),
83        )
84    }
85
86    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
87    where
88        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
89    {
90        let emits_event = self.descriptor.emits(ModelEventKind::Created);
91        let audit_enabled = self.descriptor.audit_enabled;
92        let needs_tx = emits_event || audit_enabled;
93        let record = if needs_tx {
94            let mut tx = self
95                .runtime
96                .pool()
97                .begin()
98                .await
99                .map_err(|error| CoolError::Database(error.to_string()))?;
100            if emits_event {
101                ensure_event_outbox_table(&mut *tx).await?;
102            }
103            if audit_enabled {
104                ensure_audit_table(self.runtime.pool()).await?;
105            }
106            let record = create_record_with_executor(
107                &mut *tx,
108                self.runtime.pool(),
109                self.descriptor,
110                self.input,
111                ctx,
112            )
113            .await?;
114            if emits_event {
115                enqueue_event_outbox(
116                    &mut *tx,
117                    self.descriptor.schema_name,
118                    ModelEventKind::Created,
119                    &record,
120                )
121                .await?;
122            }
123            if audit_enabled {
124                let after = serde_json::to_value(&record).ok();
125                let event =
126                    build_audit_event(self.descriptor, AuditOperation::Create, None, after, ctx);
127                enqueue_audit_event(&mut *tx, &event).await?;
128            }
129            tx.commit()
130                .await
131                .map_err(|error| CoolError::Database(error.to_string()))?;
132            record
133        } else {
134            create_record_with_executor(
135                self.runtime.pool(),
136                self.runtime.pool(),
137                self.descriptor,
138                self.input,
139                ctx,
140            )
141            .await?
142        };
143
144        if emits_event {
145            let _ = self.runtime.drain_event_outbox().await;
146        }
147
148        Ok(record)
149    }
150}
151
152#[derive(Debug, Clone)]
153pub struct UpdateRecord<'a, M: 'static, PK: 'static> {
154    pub(crate) runtime: &'a SqlxRuntime,
155    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
156    pub(crate) id: PK,
157}
158
159impl<'a, M: 'static, PK: 'static> UpdateRecord<'a, M, PK> {
160    pub fn set<I>(self, input: I) -> UpdateRecordSet<'a, M, PK, I> {
161        UpdateRecordSet {
162            runtime: self.runtime,
163            descriptor: self.descriptor,
164            id: self.id,
165            input,
166            if_match: None,
167        }
168    }
169}
170
171#[derive(Debug, Clone)]
172pub struct UpdateRecordSet<'a, M: 'static, PK: 'static, I> {
173    pub(crate) runtime: &'a SqlxRuntime,
174    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
175    pub(crate) id: PK,
176    pub(crate) input: I,
177    pub(crate) if_match: Option<i64>,
178}
179
180impl<'a, M: 'static, PK: 'static, I> UpdateRecordSet<'a, M, PK, I>
181where
182    I: UpdateModelInput<M>,
183{
184    /// Attach an expected version for optimistic locking. The update will only
185    /// succeed if the row's current `@version` field matches `expected`.
186    /// Required on models that declare `@version`; ignored otherwise.
187    pub fn if_match(mut self, expected: i64) -> Self {
188        self.if_match = Some(expected);
189        self
190    }
191
192    pub fn preview_sql(&self) -> String {
193        let values = self.input.sql_values();
194        let columns: Vec<&str> = values.iter().map(|v| v.column).collect();
195        render_update_preview_sql(
196            self.descriptor.table_name,
197            self.descriptor.primary_key,
198            self.descriptor.version_column,
199            &columns,
200            &self.descriptor.select_projection(),
201        )
202    }
203
204    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
205    where
206        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
207        PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
208    {
209        if self.descriptor.version_column.is_some() && self.if_match.is_none() {
210            return Err(CoolError::PreconditionFailed(
211                "If-Match header required for versioned model".to_owned(),
212            ));
213        }
214        let emits_event = self.descriptor.emits(ModelEventKind::Updated);
215        let audit_enabled = self.descriptor.audit_enabled;
216        let needs_tx = emits_event || audit_enabled;
217        let record = if needs_tx {
218            let mut tx = self
219                .runtime
220                .pool()
221                .begin()
222                .await
223                .map_err(|error| CoolError::Database(error.to_string()))?;
224            if emits_event {
225                ensure_event_outbox_table(&mut *tx).await?;
226            }
227            if audit_enabled {
228                ensure_audit_table(self.runtime.pool()).await?;
229            }
230            // Capture the BEFORE snapshot under a row-level lock so concurrent
231            // mutations can't race the audit.
232            let before_record = if audit_enabled {
233                fetch_for_audit(&mut *tx, self.descriptor, self.id.clone()).await?
234            } else {
235                None
236            };
237            let before_snapshot = before_record
238                .as_ref()
239                .and_then(|m| serde_json::to_value(m).ok());
240            let record = update_record_with_executor(
241                &mut *tx,
242                self.runtime.pool(),
243                self.descriptor,
244                self.id,
245                self.input,
246                ctx,
247                self.if_match,
248            )
249            .await?;
250            if emits_event {
251                enqueue_event_outbox(
252                    &mut *tx,
253                    self.descriptor.schema_name,
254                    ModelEventKind::Updated,
255                    &record,
256                )
257                .await?;
258            }
259            if audit_enabled {
260                let after = serde_json::to_value(&record).ok();
261                let event = build_audit_event(
262                    self.descriptor,
263                    AuditOperation::Update,
264                    before_snapshot,
265                    after,
266                    ctx,
267                );
268                enqueue_audit_event(&mut *tx, &event).await?;
269            }
270            tx.commit()
271                .await
272                .map_err(|error| CoolError::Database(error.to_string()))?;
273            record
274        } else {
275            update_record_with_executor(
276                self.runtime.pool(),
277                self.runtime.pool(),
278                self.descriptor,
279                self.id,
280                self.input,
281                ctx,
282                self.if_match,
283            )
284            .await?
285        };
286
287        if emits_event {
288            let _ = self.runtime.drain_event_outbox().await;
289        }
290
291        Ok(record)
292    }
293}
294
295#[derive(Debug, Clone)]
296pub struct DeleteRecord<'a, M: 'static, PK: 'static> {
297    pub(crate) runtime: &'a SqlxRuntime,
298    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
299    pub(crate) id: PK,
300}
301
302impl<'a, M: 'static, PK: 'static> DeleteRecord<'a, M, PK> {
303    pub fn preview_sql(&self) -> String {
304        format!(
305            "DELETE FROM {} WHERE {} = $1 RETURNING {}",
306            self.descriptor.table_name,
307            self.descriptor.primary_key,
308            self.descriptor.select_projection(),
309        )
310    }
311
312    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
313    where
314        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
315        PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
316    {
317        let emits_event = self.descriptor.emits(ModelEventKind::Deleted);
318        let audit_enabled = self.descriptor.audit_enabled;
319        let needs_tx = emits_event || audit_enabled;
320        let record = if needs_tx {
321            let mut tx = self
322                .runtime
323                .pool()
324                .begin()
325                .await
326                .map_err(|error| CoolError::Database(error.to_string()))?;
327            if emits_event {
328                ensure_event_outbox_table(&mut *tx).await?;
329            }
330            if audit_enabled {
331                ensure_audit_table(self.runtime.pool()).await?;
332            }
333
334            let record = delete_returning_record(&mut *tx, self.descriptor, self.id, ctx).await?;
335            if emits_event {
336                enqueue_event_outbox(
337                    &mut *tx,
338                    self.descriptor.schema_name,
339                    ModelEventKind::Deleted,
340                    &record,
341                )
342                .await?;
343            }
344            if audit_enabled {
345                // DELETE ... RETURNING yields the row's pre-delete state, so
346                // it doubles as the audit `before` snapshot.
347                let before = serde_json::to_value(&record).ok();
348                let event =
349                    build_audit_event(self.descriptor, AuditOperation::Delete, before, None, ctx);
350                enqueue_audit_event(&mut *tx, &event).await?;
351            }
352            tx.commit()
353                .await
354                .map_err(|error| CoolError::Database(error.to_string()))?;
355            record
356        } else {
357            delete_returning_record(self.runtime.pool(), self.descriptor, self.id, ctx).await?
358        };
359
360        if emits_event {
361            let _ = self.runtime.drain_event_outbox().await;
362        }
363
364        Ok(record)
365    }
366}
367
368pub async fn create_record_with_executor<'e, E, M, PK, I>(
369    executor: E,
370    policy_pool: &sqlx::PgPool,
371    descriptor: &'static ModelDescriptor<M, PK>,
372    input: I,
373    ctx: &CoolContext,
374) -> Result<M, CoolError>
375where
376    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
377    I: CreateModelInput<M>,
378    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
379{
380    input.validate()?;
381    let mut values = apply_create_defaults(input.sql_values(), descriptor.create_defaults, ctx)?;
382    // Seed the optimistic-lock column server-side. `@version` is excluded
383    // from the generated Create input so clients can't pick the initial
384    // value, and the column has no SQL `DEFAULT`. If we didn't write it
385    // here, the INSERT would either skip the column (only valid when the
386    // DB-level default is set, which we don't require) or fail under
387    // `NOT NULL`. Done after `apply_create_defaults` so @default-driven
388    // overrides still win if a schema ever lands one.
389    if let Some(version_col) = descriptor.version_column
390        && find_column_value(&values, version_col).is_none()
391    {
392        values.push(crate::SqlColumnValue {
393            column: version_col,
394            value: crate::SqlValue::Int(0),
395        });
396    }
397    if values.is_empty() {
398        return Err(CoolError::Validation(
399            "create input must contain at least one column".to_owned(),
400        ));
401    }
402    if !evaluate_create_policies(
403        policy_pool,
404        descriptor.create_allow_policies,
405        descriptor.create_deny_policies,
406        &values,
407        ctx,
408    )
409    .await?
410    {
411        return Err(CoolError::Forbidden(
412            "create policy denied this operation".to_owned(),
413        ));
414    }
415
416    insert_returning_record(executor, descriptor, &values).await
417}
418
419pub async fn update_record_with_executor<'e, E, M, PK, I>(
420    executor: E,
421    policy_pool: &sqlx::PgPool,
422    descriptor: &'static ModelDescriptor<M, PK>,
423    id: PK,
424    input: I,
425    ctx: &CoolContext,
426    if_match: Option<i64>,
427) -> Result<M, CoolError>
428where
429    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
430    I: UpdateModelInput<M>,
431    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
432    PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
433{
434    input.validate()?;
435    let values = input.sql_values();
436    if values.is_empty() {
437        return Err(CoolError::Validation(
438            "update input must contain at least one changed column".to_owned(),
439        ));
440    }
441
442    update_returning_record(
443        executor,
444        policy_pool,
445        descriptor,
446        id,
447        &values,
448        ctx,
449        if_match,
450    )
451    .await
452}
453
454async fn insert_returning_record<'e, E, M, PK>(
455    executor: E,
456    descriptor: &'static ModelDescriptor<M, PK>,
457    values: &[crate::SqlColumnValue],
458) -> Result<M, CoolError>
459where
460    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
461    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
462{
463    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("INSERT INTO ");
464    query.push(descriptor.table_name).push(" (");
465    for (index, value) in values.iter().enumerate() {
466        if index > 0 {
467            query.push(", ");
468        }
469        query.push(value.column);
470    }
471    query.push(") VALUES (");
472    for (index, value) in values.iter().enumerate() {
473        if index > 0 {
474            query.push(", ");
475        }
476        push_bind_value(&mut query, &value.value);
477    }
478    query
479        .push(") RETURNING ")
480        .push(descriptor.select_projection());
481
482    query
483        .build_query_as::<M>()
484        .fetch_one(executor)
485        .await
486        .map_err(|error| CoolError::Database(error.to_string()))
487}
488
489async fn update_returning_record<'e, E, M, PK>(
490    executor: E,
491    policy_pool: &sqlx::PgPool,
492    descriptor: &'static ModelDescriptor<M, PK>,
493    id: PK,
494    values: &[crate::SqlColumnValue],
495    ctx: &CoolContext,
496    if_match: Option<i64>,
497) -> Result<M, CoolError>
498where
499    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
500    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
501    PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
502{
503    let version_column = descriptor.version_column;
504    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("UPDATE ");
505    query.push(descriptor.table_name).push(" SET ");
506    for (index, value) in values.iter().enumerate() {
507        if index > 0 {
508            query.push(", ");
509        }
510        query.push(value.column).push(" = ");
511        push_bind_value(&mut query, &value.value);
512    }
513    if let Some(version_col) = version_column {
514        query
515            .push(", ")
516            .push(version_col)
517            .push(" = ")
518            .push(version_col)
519            .push(" + 1");
520    }
521    query
522        .push(" WHERE ")
523        .push(descriptor.primary_key)
524        .push(" = ");
525    let id_for_probe = id.clone();
526    query.push_bind(id);
527    if let (Some(version_col), Some(expected)) = (version_column, if_match) {
528        query.push(" AND ").push(version_col).push(" = ");
529        query.push_bind(expected);
530    }
531    query.push(" AND ");
532    push_action_policy_query(
533        &mut query,
534        descriptor.update_allow_policies,
535        descriptor.update_deny_policies,
536        ctx,
537    );
538    query
539        .push(" RETURNING ")
540        .push(descriptor.select_projection());
541
542    let outcome = query
543        .build_query_as::<M>()
544        .fetch_optional(executor)
545        .await
546        .map_err(|error| CoolError::Database(error.to_string()))?;
547    match outcome {
548        Some(record) => Ok(record),
549        None => {
550            // No row matched. If this is a versioned update we want to
551            // distinguish "stale version" from a true policy denial. The
552            // probe applies the read policy: if the caller cannot see the
553            // row, we keep returning Forbidden so policy denials remain
554            // indistinguishable from missing rows.
555            if let (Some(version_col), Some(expected)) = (version_column, if_match) {
556                if let Some(current) =
557                    probe_current_version(policy_pool, descriptor, id_for_probe, version_col, ctx)
558                        .await?
559                {
560                    if current != expected {
561                        return Err(CoolError::PreconditionFailed(format!(
562                            "version mismatch: expected {expected}, found {current}",
563                        )));
564                    }
565                }
566            }
567            Err(CoolError::Forbidden(
568                "update policy denied this operation".to_owned(),
569            ))
570        }
571    }
572}
573
574/// Read the current version of a row using the read policy. Returns `None` if
575/// the caller cannot see the row (so the outer code preserves the existing
576/// Forbidden-on-no-row semantics — readers can't tell a denied row from a
577/// missing one).
578async fn probe_current_version<M, PK>(
579    policy_pool: &sqlx::PgPool,
580    descriptor: &'static ModelDescriptor<M, PK>,
581    id: PK,
582    version_col: &'static str,
583    ctx: &CoolContext,
584) -> Result<Option<i64>, CoolError>
585where
586    PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
587{
588    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
589    query.push(version_col);
590    query.push(" FROM ").push(descriptor.table_name);
591    query
592        .push(" WHERE ")
593        .push(descriptor.primary_key)
594        .push(" = ");
595    query.push_bind(id);
596    query.push(" AND ");
597    push_action_policy_query(
598        &mut query,
599        descriptor.read_allow_policies,
600        descriptor.read_deny_policies,
601        ctx,
602    );
603
604    let row: Option<(i64,)> = query
605        .build_query_as::<(i64,)>()
606        .fetch_optional(policy_pool)
607        .await
608        .map_err(|error| CoolError::Database(error.to_string()))?;
609    Ok(row.map(|(v,)| v))
610}
611
612async fn delete_returning_record<'e, E, M, PK>(
613    executor: E,
614    descriptor: &'static ModelDescriptor<M, PK>,
615    id: PK,
616    ctx: &CoolContext,
617) -> Result<M, CoolError>
618where
619    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
620    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
621    PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
622{
623    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("");
624    match descriptor.soft_delete_column {
625        Some(col) => {
626            // Soft-delete: tombstone the row and bump version (if any) so
627            // optimistic-lock semantics on subsequent updates stay coherent.
628            query.push("UPDATE ").push(descriptor.table_name);
629            query.push(" SET ").push(col).push(" = NOW()");
630            if let Some(version_col) = descriptor.version_column {
631                query
632                    .push(", ")
633                    .push(version_col)
634                    .push(" = ")
635                    .push(version_col)
636                    .push(" + 1");
637            }
638            query.push(" WHERE ");
639            query.push(col).push(" IS NULL AND ");
640            query.push(descriptor.primary_key).push(" = ");
641            query.push_bind(id);
642        }
643        None => {
644            query.push("DELETE FROM ").push(descriptor.table_name);
645            query.push(" WHERE ");
646            query.push(descriptor.primary_key).push(" = ");
647            query.push_bind(id);
648        }
649    }
650    query.push(" AND ");
651    push_action_policy_query(
652        &mut query,
653        descriptor.delete_allow_policies,
654        descriptor.delete_deny_policies,
655        ctx,
656    );
657    query
658        .push(" RETURNING ")
659        .push(descriptor.select_projection());
660
661    query
662        .build_query_as::<M>()
663        .fetch_optional(executor)
664        .await
665        .map_err(|error| CoolError::Database(error.to_string()))?
666        .ok_or_else(|| CoolError::Forbidden("delete policy denied this operation".to_owned()))
667}