1use super::types::*;
4use crate::api_request::{
5 ApiRequest, Mutation, Payload, PreferResolution, QualifiedIdentifier,
6};
7use crate::error::{Error, Result};
8use crate::schema_cache::Table;
9use serde::{Deserialize, Serialize};
10
11#[derive(Clone, Debug, Serialize, Deserialize)]
13pub enum MutatePlan {
14 Insert {
16 target: QualifiedIdentifier,
18 columns: Vec<CoercibleField>,
20 body: Option<bytes::Bytes>,
22 on_conflict: Option<(PreferResolution, Vec<String>)>,
24 where_clauses: Vec<CoercibleLogicTree>,
26 returning: Vec<String>,
28 pk_cols: Vec<String>,
30 apply_defaults: bool,
32 },
33 Update {
35 target: QualifiedIdentifier,
37 columns: Vec<CoercibleField>,
39 body: Option<bytes::Bytes>,
41 where_clauses: Vec<CoercibleLogicTree>,
43 returning: Vec<String>,
45 apply_defaults: bool,
47 },
48 Delete {
50 target: QualifiedIdentifier,
52 where_clauses: Vec<CoercibleLogicTree>,
54 returning: Vec<String>,
56 },
57}
58
59impl MutatePlan {
60 pub fn from_request(
62 request: &ApiRequest,
63 table: &Table,
64 mutation: &Mutation,
65 ) -> Result<Self> {
66 let qi = table.qualified_identifier();
67
68 match mutation {
69 Mutation::Create => Self::create_insert(request, table, qi),
70 Mutation::Update => Self::create_update(request, table, qi),
71 Mutation::Delete => Self::create_delete(request, table, qi),
72 Mutation::SingleUpsert => Self::create_upsert(request, table, qi),
73 }
74 }
75
76 fn create_insert(
78 request: &ApiRequest,
79 table: &Table,
80 qi: QualifiedIdentifier,
81 ) -> Result<Self> {
82 let columns = get_payload_columns(request, table)?;
83 let body = get_body_bytes(request)?;
84 let returning = get_returning_columns(request, table);
85 let apply_defaults = request.preferences.missing == crate::api_request::PreferMissing::ApplyDefaults;
86
87 let on_conflict = request.query_params.on_conflict.as_ref().map(|cols| {
88 let resolution = request
89 .preferences
90 .resolution
91 .clone()
92 .unwrap_or(PreferResolution::MergeDuplicates);
93 (resolution, cols.clone())
94 });
95
96 Ok(Self::Insert {
97 target: qi,
98 columns,
99 body,
100 on_conflict,
101 where_clauses: vec![],
102 returning,
103 pk_cols: table.pk_cols.clone(),
104 apply_defaults,
105 })
106 }
107
108 fn create_update(
110 request: &ApiRequest,
111 table: &Table,
112 qi: QualifiedIdentifier,
113 ) -> Result<Self> {
114 let columns = get_payload_columns(request, table)?;
115 let body = get_body_bytes(request)?;
116 let where_clauses = build_mutation_where(request, table)?;
117 let returning = get_returning_columns(request, table);
118 let apply_defaults = request.preferences.missing == crate::api_request::PreferMissing::ApplyDefaults;
119
120 Ok(Self::Update {
121 target: qi,
122 columns,
123 body,
124 where_clauses,
125 returning,
126 apply_defaults,
127 })
128 }
129
130 fn create_delete(
132 request: &ApiRequest,
133 table: &Table,
134 qi: QualifiedIdentifier,
135 ) -> Result<Self> {
136 let where_clauses = build_mutation_where(request, table)?;
137 let returning = get_returning_columns(request, table);
138
139 Ok(Self::Delete {
140 target: qi,
141 where_clauses,
142 returning,
143 })
144 }
145
146 fn create_upsert(
148 request: &ApiRequest,
149 table: &Table,
150 qi: QualifiedIdentifier,
151 ) -> Result<Self> {
152 let columns = get_payload_columns(request, table)?;
153 let body = get_body_bytes(request)?;
154 let returning = get_returning_columns(request, table);
155
156 let on_conflict = Some((
158 PreferResolution::MergeDuplicates,
159 table.pk_cols.clone(),
160 ));
161
162 Ok(Self::Insert {
163 target: qi,
164 columns,
165 body,
166 on_conflict,
167 where_clauses: vec![],
168 returning,
169 pk_cols: table.pk_cols.clone(),
170 apply_defaults: true,
171 })
172 }
173
174 pub fn target(&self) -> &QualifiedIdentifier {
176 match self {
177 Self::Insert { target, .. } => target,
178 Self::Update { target, .. } => target,
179 Self::Delete { target, .. } => target,
180 }
181 }
182
183 pub fn has_body(&self) -> bool {
185 match self {
186 Self::Insert { body, .. } => body.is_some(),
187 Self::Update { body, .. } => body.is_some(),
188 Self::Delete { .. } => false,
189 }
190 }
191}
192
193fn get_payload_columns(
195 request: &ApiRequest,
196 table: &Table,
197) -> Result<Vec<CoercibleField>> {
198 let keys = match &request.payload {
199 Some(Payload::ProcessedJson { keys, .. }) => keys,
200 Some(Payload::ProcessedUrlEncoded { keys, .. }) => keys,
201 _ => return Ok(vec![]),
202 };
203
204 let mut columns = Vec::new();
205
206 for key in keys {
207 let column = table
208 .get_column(key)
209 .ok_or_else(|| Error::UnknownColumn(key.clone()))?;
210
211 columns.push(CoercibleField::simple(key, &column.data_type));
212 }
213
214 Ok(columns)
215}
216
217fn get_body_bytes(request: &ApiRequest) -> Result<Option<bytes::Bytes>> {
219 match &request.payload {
220 Some(Payload::ProcessedJson { raw, .. }) => Ok(Some(raw.clone())),
221 Some(Payload::RawJson(raw)) => Ok(Some(raw.clone())),
222 Some(Payload::RawPayload(raw)) => Ok(Some(raw.clone())),
223 Some(Payload::ProcessedUrlEncoded { data, .. }) => {
224 let json = serde_json::to_vec(
226 &data.iter().cloned().collect::<std::collections::HashMap<_, _>>()
227 ).map_err(|e| Error::InvalidBody(e.to_string()))?;
228 Ok(Some(bytes::Bytes::from(json)))
229 }
230 None => Ok(None),
231 }
232}
233
234fn get_returning_columns(request: &ApiRequest, table: &Table) -> Vec<String> {
236 if request.preferences.representation.needs_body() {
237 table.column_names().map(|s| s.to_string()).collect()
238 } else {
239 table.pk_cols.clone()
241 }
242}
243
244fn build_mutation_where(
246 request: &ApiRequest,
247 table: &Table,
248) -> Result<Vec<CoercibleLogicTree>> {
249 let type_resolver = |name: &str| -> String {
250 table
251 .get_column(name)
252 .map(|c| c.data_type.clone())
253 .unwrap_or_else(|| "text".to_string())
254 };
255
256 let mut clauses = Vec::new();
257
258 for filter in &request.query_params.filters_root {
259 let pg_type = type_resolver(&filter.field.name);
260 clauses.push(CoercibleLogicTree::Stmt(CoercibleFilter::from_filter(
261 filter, &pg_type,
262 )));
263 }
264
265 Ok(clauses)
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared target used by every plan in these tests.
    fn users() -> QualifiedIdentifier {
        QualifiedIdentifier::new("public", "users")
    }

    #[test]
    fn test_mutate_plan_target() {
        let delete = MutatePlan::Delete {
            target: users(),
            where_clauses: vec![],
            returning: vec!["id".into()],
        };
        assert_eq!(delete.target().name, "users");

        let update = MutatePlan::Update {
            target: users(),
            columns: vec![],
            body: None,
            where_clauses: vec![],
            returning: vec![],
            apply_defaults: false,
        };
        assert_eq!(update.target().name, "users");
    }

    #[test]
    fn test_mutate_plan_has_body() {
        let insert = MutatePlan::Insert {
            target: users(),
            columns: vec![],
            body: Some(bytes::Bytes::from("{}".as_bytes())),
            on_conflict: None,
            where_clauses: vec![],
            returning: vec![],
            pk_cols: vec![],
            apply_defaults: true,
        };
        assert!(insert.has_body());

        let bodyless_insert = MutatePlan::Insert {
            target: users(),
            columns: vec![],
            body: None,
            on_conflict: None,
            where_clauses: vec![],
            returning: vec![],
            pk_cols: vec![],
            apply_defaults: true,
        };
        assert!(!bodyless_insert.has_body());

        let update = MutatePlan::Update {
            target: users(),
            columns: vec![],
            body: Some(bytes::Bytes::from("{}".as_bytes())),
            where_clauses: vec![],
            returning: vec![],
            apply_defaults: false,
        };
        assert!(update.has_body());

        let delete = MutatePlan::Delete {
            target: users(),
            where_clauses: vec![],
            returning: vec![],
        };
        assert!(!delete.has_body());
    }
}