Skip to main content

cratestack_sqlx/query/write/
update_exec.rs

1//! Generic-over-Executor update helpers used by single-row UPDATE
2//! paths. Builds `UPDATE ... SET ... WHERE pk = $X [AND version = $Y]
3//! AND policy(...) RETURNING ...`, with version-mismatch detection via
4//! a read-policy probe.
5
6use cratestack_core::{CoolContext, CoolError};
7
8use crate::query::support::{push_action_policy_query, push_bind_value};
9use crate::{ModelDescriptor, UpdateModelInput, sqlx};
10
11pub async fn update_record_with_executor<'e, E, M, PK, I>(
12    executor: E,
13    policy_pool: &sqlx::PgPool,
14    descriptor: &'static ModelDescriptor<M, PK>,
15    id: PK,
16    input: I,
17    ctx: &CoolContext,
18    if_match: Option<i64>,
19) -> Result<M, CoolError>
20where
21    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
22    I: UpdateModelInput<M>,
23    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
24    PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
25{
26    input.validate()?;
27    let values = input.sql_values();
28    if values.is_empty() {
29        return Err(CoolError::Validation(
30            "update input must contain at least one changed column".to_owned(),
31        ));
32    }
33
34    update_returning_record(
35        executor,
36        policy_pool,
37        descriptor,
38        id,
39        &values,
40        ctx,
41        if_match,
42    )
43    .await
44}
45
46#[allow(clippy::too_many_arguments)]
47async fn update_returning_record<'e, E, M, PK>(
48    executor: E,
49    policy_pool: &sqlx::PgPool,
50    descriptor: &'static ModelDescriptor<M, PK>,
51    id: PK,
52    values: &[crate::SqlColumnValue],
53    ctx: &CoolContext,
54    if_match: Option<i64>,
55) -> Result<M, CoolError>
56where
57    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
58    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
59    PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
60{
61    let version_column = descriptor.version_column;
62    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("UPDATE ");
63    query.push(descriptor.table_name).push(" SET ");
64    for (index, value) in values.iter().enumerate() {
65        if index > 0 {
66            query.push(", ");
67        }
68        query.push(value.column).push(" = ");
69        push_bind_value(&mut query, &value.value);
70    }
71    if let Some(version_col) = version_column {
72        query
73            .push(", ")
74            .push(version_col)
75            .push(" = ")
76            .push(version_col)
77            .push(" + 1");
78    }
79    query
80        .push(" WHERE ")
81        .push(descriptor.primary_key)
82        .push(" = ");
83    let id_for_probe = id.clone();
84    query.push_bind(id);
85    if let (Some(version_col), Some(expected)) = (version_column, if_match) {
86        query.push(" AND ").push(version_col).push(" = ");
87        query.push_bind(expected);
88    }
89    query.push(" AND ");
90    push_action_policy_query(
91        &mut query,
92        descriptor.update_allow_policies,
93        descriptor.update_deny_policies,
94        ctx,
95    );
96    query
97        .push(" RETURNING ")
98        .push(descriptor.select_projection());
99
100    let outcome = query
101        .build_query_as::<M>()
102        .fetch_optional(executor)
103        .await
104        .map_err(|error| CoolError::Database(error.to_string()))?;
105    match outcome {
106        Some(record) => Ok(record),
107        None => {
108            // If this is a versioned update, distinguish "stale
109            // version" from a true policy denial via the read-policy
110            // probe. If the caller can't see the row, we keep
111            // returning Forbidden so policy denials remain
112            // indistinguishable from missing rows.
113            if let (Some(version_col), Some(expected)) = (version_column, if_match)
114                && let Some(current) =
115                    probe_current_version(policy_pool, descriptor, id_for_probe, version_col, ctx)
116                        .await?
117                && current != expected
118            {
119                return Err(CoolError::PreconditionFailed(format!(
120                    "version mismatch: expected {expected}, found {current}",
121                )));
122            }
123            Err(CoolError::Forbidden(
124                "update policy denied this operation".to_owned(),
125            ))
126        }
127    }
128}
129
130/// Read the current version of a row using the read policy. Returns
131/// `None` if the caller cannot see the row (so the outer code
132/// preserves the existing Forbidden-on-no-row semantics).
133async fn probe_current_version<M, PK>(
134    policy_pool: &sqlx::PgPool,
135    descriptor: &'static ModelDescriptor<M, PK>,
136    id: PK,
137    version_col: &'static str,
138    ctx: &CoolContext,
139) -> Result<Option<i64>, CoolError>
140where
141    PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
142{
143    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
144    query.push(version_col);
145    query.push(" FROM ").push(descriptor.table_name);
146    query
147        .push(" WHERE ")
148        .push(descriptor.primary_key)
149        .push(" = ");
150    query.push_bind(id);
151    query.push(" AND ");
152    push_action_policy_query(
153        &mut query,
154        descriptor.read_allow_policies,
155        descriptor.read_deny_policies,
156        ctx,
157    );
158
159    let row: Option<(i64,)> = query
160        .build_query_as::<(i64,)>()
161        .fetch_optional(policy_pool)
162        .await
163        .map_err(|error| CoolError::Database(error.to_string()))?;
164    Ok(row.map(|(v,)| v))
165}