cratestack_sqlx/query/write/
update_exec.rs1use cratestack_core::{CoolContext, CoolError};
7
8use crate::query::support::{push_action_policy_query, push_bind_value};
9use crate::{ModelDescriptor, UpdateModelInput, sqlx};
10
11pub async fn update_record_with_executor<'e, E, M, PK, I>(
12 executor: E,
13 policy_pool: &sqlx::PgPool,
14 descriptor: &'static ModelDescriptor<M, PK>,
15 id: PK,
16 input: I,
17 ctx: &CoolContext,
18 if_match: Option<i64>,
19) -> Result<M, CoolError>
20where
21 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
22 I: UpdateModelInput<M>,
23 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
24 PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
25{
26 input.validate()?;
27 let values = input.sql_values();
28 if values.is_empty() {
29 return Err(CoolError::Validation(
30 "update input must contain at least one changed column".to_owned(),
31 ));
32 }
33
34 update_returning_record(executor, policy_pool, descriptor, id, &values, ctx, if_match).await
35}
36
37#[allow(clippy::too_many_arguments)]
38async fn update_returning_record<'e, E, M, PK>(
39 executor: E,
40 policy_pool: &sqlx::PgPool,
41 descriptor: &'static ModelDescriptor<M, PK>,
42 id: PK,
43 values: &[crate::SqlColumnValue],
44 ctx: &CoolContext,
45 if_match: Option<i64>,
46) -> Result<M, CoolError>
47where
48 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
49 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
50 PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
51{
52 let version_column = descriptor.version_column;
53 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("UPDATE ");
54 query.push(descriptor.table_name).push(" SET ");
55 for (index, value) in values.iter().enumerate() {
56 if index > 0 {
57 query.push(", ");
58 }
59 query.push(value.column).push(" = ");
60 push_bind_value(&mut query, &value.value);
61 }
62 if let Some(version_col) = version_column {
63 query
64 .push(", ")
65 .push(version_col)
66 .push(" = ")
67 .push(version_col)
68 .push(" + 1");
69 }
70 query
71 .push(" WHERE ")
72 .push(descriptor.primary_key)
73 .push(" = ");
74 let id_for_probe = id.clone();
75 query.push_bind(id);
76 if let (Some(version_col), Some(expected)) = (version_column, if_match) {
77 query.push(" AND ").push(version_col).push(" = ");
78 query.push_bind(expected);
79 }
80 query.push(" AND ");
81 push_action_policy_query(
82 &mut query,
83 descriptor.update_allow_policies,
84 descriptor.update_deny_policies,
85 ctx,
86 );
87 query
88 .push(" RETURNING ")
89 .push(descriptor.select_projection());
90
91 let outcome = query
92 .build_query_as::<M>()
93 .fetch_optional(executor)
94 .await
95 .map_err(|error| CoolError::Database(error.to_string()))?;
96 match outcome {
97 Some(record) => Ok(record),
98 None => {
99 if let (Some(version_col), Some(expected)) = (version_column, if_match)
105 && let Some(current) =
106 probe_current_version(policy_pool, descriptor, id_for_probe, version_col, ctx)
107 .await?
108 && current != expected
109 {
110 return Err(CoolError::PreconditionFailed(format!(
111 "version mismatch: expected {expected}, found {current}",
112 )));
113 }
114 Err(CoolError::Forbidden(
115 "update policy denied this operation".to_owned(),
116 ))
117 }
118 }
119}
120
121async fn probe_current_version<M, PK>(
125 policy_pool: &sqlx::PgPool,
126 descriptor: &'static ModelDescriptor<M, PK>,
127 id: PK,
128 version_col: &'static str,
129 ctx: &CoolContext,
130) -> Result<Option<i64>, CoolError>
131where
132 PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
133{
134 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
135 query.push(version_col);
136 query.push(" FROM ").push(descriptor.table_name);
137 query
138 .push(" WHERE ")
139 .push(descriptor.primary_key)
140 .push(" = ");
141 query.push_bind(id);
142 query.push(" AND ");
143 push_action_policy_query(
144 &mut query,
145 descriptor.read_allow_policies,
146 descriptor.read_deny_policies,
147 ctx,
148 );
149
150 let row: Option<(i64,)> = query
151 .build_query_as::<(i64,)>()
152 .fetch_optional(policy_pool)
153 .await
154 .map_err(|error| CoolError::Database(error.to_string()))?;
155 Ok(row.map(|(v,)| v))
156}