Skip to main content

cratestack_sqlx/query/write/update.rs

1//! Single-row UPDATE with optional version locking, policy, audit + events.
2
3use cratestack_core::{AuditOperation, CoolContext, CoolError, ModelEventKind};
4
5use crate::audit::{build_audit_event, enqueue_audit_event, ensure_audit_table, fetch_for_audit};
6use crate::descriptor::{enqueue_event_outbox, ensure_event_outbox_table};
7use crate::{ModelDescriptor, SqlxRuntime, UpdateModelInput, sqlx};
8
9use super::preview::render_update_preview_sql;
10use super::update_exec::update_record_with_executor;
11
12#[derive(Debug, Clone)]
13pub struct UpdateRecord<'a, M: 'static, PK: 'static> {
14    pub(crate) runtime: &'a SqlxRuntime,
15    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
16    pub(crate) id: PK,
17}
18
19impl<'a, M: 'static, PK: 'static> UpdateRecord<'a, M, PK> {
20    pub fn set<I>(self, input: I) -> UpdateRecordSet<'a, M, PK, I> {
21        UpdateRecordSet {
22            runtime: self.runtime,
23            descriptor: self.descriptor,
24            id: self.id,
25            input,
26            if_match: None,
27        }
28    }
29}
30
31#[derive(Debug, Clone)]
32pub struct UpdateRecordSet<'a, M: 'static, PK: 'static, I> {
33    pub(crate) runtime: &'a SqlxRuntime,
34    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
35    pub(crate) id: PK,
36    pub(crate) input: I,
37    pub(crate) if_match: Option<i64>,
38}
39
40impl<'a, M: 'static, PK: 'static, I> UpdateRecordSet<'a, M, PK, I>
41where
42    I: UpdateModelInput<M>,
43{
44    /// Expected version for optimistic locking. Required on models
45    /// that declare `@version`; ignored otherwise.
46    pub fn if_match(mut self, expected: i64) -> Self {
47        self.if_match = Some(expected);
48        self
49    }
50
51    pub fn preview_sql(&self) -> String {
52        let values = self.input.sql_values();
53        let columns: Vec<&str> = values.iter().map(|v| v.column).collect();
54        render_update_preview_sql(
55            self.descriptor.table_name,
56            self.descriptor.primary_key,
57            self.descriptor.version_column,
58            &columns,
59            &self.descriptor.select_projection(),
60        )
61    }
62
63    pub async fn run_in_tx<'tx>(
64        self,
65        tx: &mut sqlx::Transaction<'tx, sqlx::Postgres>,
66        ctx: &CoolContext,
67    ) -> Result<M, CoolError>
68    where
69        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
70        PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
71    {
72        if self.descriptor.version_column.is_some() && self.if_match.is_none() {
73            return Err(CoolError::PreconditionFailed(
74                "If-Match header required for versioned model".to_owned(),
75            ));
76        }
77        let emits_event = self.descriptor.emits(ModelEventKind::Updated);
78        let audit_enabled = self.descriptor.audit_enabled;
79        if emits_event {
80            ensure_event_outbox_table(&mut **tx).await?;
81        }
82        if audit_enabled {
83            ensure_audit_table(self.runtime.pool()).await?;
84        }
85        let before_record = if audit_enabled {
86            fetch_for_audit(&mut **tx, self.descriptor, self.id.clone()).await?
87        } else {
88            None
89        };
90        let before_snapshot = before_record
91            .as_ref()
92            .and_then(|m| serde_json::to_value(m).ok());
93        let record = update_record_with_executor(
94            &mut **tx,
95            self.runtime.pool(),
96            self.descriptor,
97            self.id,
98            self.input,
99            ctx,
100            self.if_match,
101        )
102        .await?;
103        if emits_event {
104            enqueue_event_outbox(
105                &mut **tx,
106                self.descriptor.schema_name,
107                ModelEventKind::Updated,
108                &record,
109            )
110            .await?;
111        }
112        if audit_enabled {
113            let after = serde_json::to_value(&record).ok();
114            let event = build_audit_event(
115                self.descriptor,
116                AuditOperation::Update,
117                before_snapshot,
118                after,
119                ctx,
120            );
121            enqueue_audit_event(&mut **tx, &event).await?;
122        }
123        Ok(record)
124    }
125
126    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
127    where
128        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
129        PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
130    {
131        super::update_run::run_update(
132            self.runtime,
133            self.descriptor,
134            self.id,
135            self.input,
136            self.if_match,
137            ctx,
138        )
139        .await
140    }
141}