cratestack_sqlx/query/write/
update_exec.rs1use cratestack_core::{CoolContext, CoolError};
7
8use crate::query::support::{push_action_policy_query, push_bind_value};
9use crate::{ModelDescriptor, UpdateModelInput, sqlx};
10
11pub async fn update_record_with_executor<'e, E, M, PK, I>(
12 executor: E,
13 policy_pool: &sqlx::PgPool,
14 descriptor: &'static ModelDescriptor<M, PK>,
15 id: PK,
16 input: I,
17 ctx: &CoolContext,
18 if_match: Option<i64>,
19) -> Result<M, CoolError>
20where
21 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
22 I: UpdateModelInput<M>,
23 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
24 PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
25{
26 input.validate()?;
27 let values = input.sql_values();
28 if values.is_empty() {
29 return Err(CoolError::Validation(
30 "update input must contain at least one changed column".to_owned(),
31 ));
32 }
33
34 update_returning_record(
35 executor,
36 policy_pool,
37 descriptor,
38 id,
39 &values,
40 ctx,
41 if_match,
42 )
43 .await
44}
45
46#[allow(clippy::too_many_arguments)]
47async fn update_returning_record<'e, E, M, PK>(
48 executor: E,
49 policy_pool: &sqlx::PgPool,
50 descriptor: &'static ModelDescriptor<M, PK>,
51 id: PK,
52 values: &[crate::SqlColumnValue],
53 ctx: &CoolContext,
54 if_match: Option<i64>,
55) -> Result<M, CoolError>
56where
57 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
58 for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
59 PK: Send + Clone + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
60{
61 let version_column = descriptor.version_column;
62 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("UPDATE ");
63 query.push(descriptor.table_name).push(" SET ");
64 for (index, value) in values.iter().enumerate() {
65 if index > 0 {
66 query.push(", ");
67 }
68 query.push(value.column).push(" = ");
69 push_bind_value(&mut query, &value.value);
70 }
71 if let Some(version_col) = version_column {
72 query
73 .push(", ")
74 .push(version_col)
75 .push(" = ")
76 .push(version_col)
77 .push(" + 1");
78 }
79 query
80 .push(" WHERE ")
81 .push(descriptor.primary_key)
82 .push(" = ");
83 let id_for_probe = id.clone();
84 query.push_bind(id);
85 if let (Some(version_col), Some(expected)) = (version_column, if_match) {
86 query.push(" AND ").push(version_col).push(" = ");
87 query.push_bind(expected);
88 }
89 query.push(" AND ");
90 push_action_policy_query(
91 &mut query,
92 descriptor.update_allow_policies,
93 descriptor.update_deny_policies,
94 ctx,
95 );
96 query
97 .push(" RETURNING ")
98 .push(descriptor.select_projection());
99
100 let outcome = query
101 .build_query_as::<M>()
102 .fetch_optional(executor)
103 .await
104 .map_err(|error| CoolError::Database(error.to_string()))?;
105 match outcome {
106 Some(record) => Ok(record),
107 None => {
108 if let (Some(version_col), Some(expected)) = (version_column, if_match)
114 && let Some(current) =
115 probe_current_version(policy_pool, descriptor, id_for_probe, version_col, ctx)
116 .await?
117 && current != expected
118 {
119 return Err(CoolError::PreconditionFailed(format!(
120 "version mismatch: expected {expected}, found {current}",
121 )));
122 }
123 Err(CoolError::Forbidden(
124 "update policy denied this operation".to_owned(),
125 ))
126 }
127 }
128}
129
130async fn probe_current_version<M, PK>(
134 policy_pool: &sqlx::PgPool,
135 descriptor: &'static ModelDescriptor<M, PK>,
136 id: PK,
137 version_col: &'static str,
138 ctx: &CoolContext,
139) -> Result<Option<i64>, CoolError>
140where
141 PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
142{
143 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
144 query.push(version_col);
145 query.push(" FROM ").push(descriptor.table_name);
146 query
147 .push(" WHERE ")
148 .push(descriptor.primary_key)
149 .push(" = ");
150 query.push_bind(id);
151 query.push(" AND ");
152 push_action_policy_query(
153 &mut query,
154 descriptor.read_allow_policies,
155 descriptor.read_deny_policies,
156 ctx,
157 );
158
159 let row: Option<(i64,)> = query
160 .build_query_as::<(i64,)>()
161 .fetch_optional(policy_pool)
162 .await
163 .map_err(|error| CoolError::Database(error.to_string()))?;
164 Ok(row.map(|(v,)| v))
165}