postgrest_parser/parser/
mutation.rs1use crate::ast::{ConflictAction, DeleteParams, InsertParams, OnConflict, UpdateParams};
2use crate::error::{Error, ParseError};
3use crate::parser::{
4 parse_json_body, parse_order, parse_select, validate_insert_body, validate_update_body,
5};
6use std::collections::HashMap;
7
8pub fn parse_insert_params(query_string: &str, body: &str) -> Result<InsertParams, Error> {
25 let json_value = parse_json_body(body)?;
27 let values = validate_insert_body(json_value)?;
28
29 let mut params = InsertParams::new(values);
30
31 let query_params = parse_query_params(query_string);
33
34 if let Some(select_str) = query_params.get("select") {
36 let returning = parse_select(select_str)?;
37 params = params.with_returning(returning);
38 } else if let Some(returning_str) = query_params.get("returning") {
39 let returning = parse_select(returning_str)?;
41 params = params.with_returning(returning);
42 }
43
44 if let Some(columns_str) = query_params.get("columns") {
46 let columns: Vec<String> = columns_str
47 .split(',')
48 .map(|s| s.trim().to_string())
49 .collect();
50 if !columns.is_empty() && !columns.iter().any(|c| c.is_empty()) {
51 params = params.with_columns(columns);
52 }
53 }
54
55 if let Some(on_conflict_str) = query_params.get("on_conflict") {
57 let on_conflict = parse_on_conflict(on_conflict_str)?;
58 params = params.with_on_conflict(on_conflict);
59 }
60
61 Ok(params)
62}
63
64pub fn parse_update_params(query_string: &str, body: &str) -> Result<UpdateParams, Error> {
81 let json_value = parse_json_body(body)?;
83 let set_values = validate_update_body(json_value)?;
84
85 let mut params = UpdateParams::new(set_values);
86
87 let query_params = parse_query_params(query_string);
89
90 let filters = crate::parse_params_from_pairs(
92 query_params
93 .iter()
94 .filter(|(k, _)| !is_reserved_key(k))
95 .map(|(k, v)| (k.clone(), v.clone()))
96 .collect(),
97 )?;
98
99 params = params.with_filters(filters.filters);
100
101 if let Some(order_str) = query_params.get("order") {
103 let order = parse_order(order_str)?;
104 params = params.with_order(order);
105 }
106
107 if let Some(limit_str) = query_params.get("limit") {
109 let limit = limit_str.parse::<u64>().map_err(|_| {
110 Error::Parse(ParseError::InvalidLimit(format!(
111 "Invalid limit value: {}",
112 limit_str
113 )))
114 })?;
115 params = params.with_limit(limit);
116 }
117
118 if let Some(select_str) = query_params.get("select") {
120 let returning = parse_select(select_str)?;
121 params = params.with_returning(returning);
122 } else if let Some(returning_str) = query_params.get("returning") {
123 let returning = parse_select(returning_str)?;
125 params = params.with_returning(returning);
126 }
127
128 Ok(params)
129}
130
131pub fn parse_delete_params(query_string: &str) -> Result<DeleteParams, Error> {
146 let mut params = DeleteParams::new();
147
148 let query_params = parse_query_params(query_string);
150
151 let filters = crate::parse_params_from_pairs(
153 query_params
154 .iter()
155 .filter(|(k, _)| !is_reserved_key(k))
156 .map(|(k, v)| (k.clone(), v.clone()))
157 .collect(),
158 )?;
159
160 params = params.with_filters(filters.filters);
161
162 if let Some(order_str) = query_params.get("order") {
164 let order = parse_order(order_str)?;
165 params = params.with_order(order);
166 }
167
168 if let Some(limit_str) = query_params.get("limit") {
170 let limit = limit_str.parse::<u64>().map_err(|_| {
171 Error::Parse(ParseError::InvalidLimit(format!(
172 "Invalid limit value: {}",
173 limit_str
174 )))
175 })?;
176 params = params.with_limit(limit);
177 }
178
179 if let Some(select_str) = query_params.get("select") {
181 let returning = parse_select(select_str)?;
182 params = params.with_returning(returning);
183 } else if let Some(returning_str) = query_params.get("returning") {
184 let returning = parse_select(returning_str)?;
186 params = params.with_returning(returning);
187 }
188
189 Ok(params)
190}
191
/// Splits a raw query string into a key → value map.
///
/// Pairs are separated by `&` and each pair is split at its first `=`, so a value may
/// itself contain `=`. Segments without any `=` are skipped. No percent-decoding is
/// performed. When a key occurs more than once, the last occurrence is the one kept.
fn parse_query_params(query_string: &str) -> HashMap<String, String> {
    let mut params = HashMap::new();
    for segment in query_string.split('&') {
        if let Some((key, value)) = segment.split_once('=') {
            params.insert(key.to_owned(), value.to_owned());
        }
    }
    params
}
205
/// Returns `true` when `key` is one of the query parameters this module treats as a
/// control parameter rather than as a column filter.
///
/// The comparison is exact and case-sensitive.
fn is_reserved_key(key: &str) -> bool {
    const RESERVED_KEYS: [&str; 7] = [
        "select",
        "order",
        "limit",
        "offset",
        "on_conflict",
        "columns",
        "returning",
    ];
    RESERVED_KEYS.contains(&key)
}
212
213fn parse_on_conflict(spec: &str) -> Result<OnConflict, Error> {
214 let parts: Vec<&str> = spec.split('.').collect();
218
219 let (columns_str, action) = match parts.len() {
220 1 => (parts[0], ConflictAction::DoNothing), 2 => {
222 let action = match parts[1].to_lowercase().as_str() {
223 "do_nothing" => ConflictAction::DoNothing,
224 "do_update" => ConflictAction::DoUpdate,
225 _ => {
226 return Err(Error::Parse(ParseError::InvalidOnConflict(format!(
227 "Invalid conflict action: '{}'. Expected 'do_nothing' or 'do_update'",
228 parts[1]
229 ))))
230 }
231 };
232 (parts[0], action)
233 }
234 _ => {
235 return Err(Error::Parse(ParseError::InvalidOnConflict(format!(
236 "Invalid on_conflict format: '{}'",
237 spec
238 ))))
239 }
240 };
241
242 let columns: Vec<String> = columns_str
243 .split(',')
244 .map(|s| s.trim().to_string())
245 .filter(|s| !s.is_empty())
246 .collect();
247
248 if columns.is_empty() {
249 return Err(Error::Parse(ParseError::InvalidOnConflict(
250 "on_conflict must specify at least one column".to_string(),
251 )));
252 }
253
254 Ok(OnConflict {
255 columns,
256 action,
257 where_clause: None,
258 update_columns: None,
259 })
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    // ---- INSERT -----------------------------------------------------------------

    /// A bare body with no query string yields no RETURNING and no ON CONFLICT.
    #[test]
    fn test_parse_insert_params_simple() {
        let params = parse_insert_params("", r#"{"name": "Alice", "age": 30}"#)
            .expect("plain insert should parse");
        assert!(params.returning.is_none());
        assert!(params.on_conflict.is_none());
    }

    /// `returning=` is parsed into one entry per listed column.
    #[test]
    fn test_parse_insert_params_with_returning() {
        let params = parse_insert_params("returning=id,created_at", r#"{"name": "Alice"}"#)
            .expect("insert with returning should parse");
        let returning = params.returning.expect("returning should be set");
        assert_eq!(returning.len(), 2);
    }

    /// `on_conflict=` without an action defaults to DO NOTHING.
    #[test]
    fn test_parse_insert_params_with_on_conflict() {
        let params =
            parse_insert_params("on_conflict=email", r#"{"email": "alice@example.com"}"#)
                .expect("insert with on_conflict should parse");
        let conflict = params.on_conflict.expect("on_conflict should be set");
        assert_eq!(conflict.columns, vec!["email"]);
        assert_eq!(conflict.action, ConflictAction::DoNothing);
    }

    /// `columns=` is split on commas and stored in order.
    #[test]
    fn test_parse_insert_params_with_columns() {
        let params = parse_insert_params(
            "columns=name,age",
            r#"{"name": "Alice", "age": 30, "extra": "ignored"}"#,
        )
        .expect("insert with columns should parse");
        let columns = params.columns.expect("columns should be set");
        assert_eq!(columns, vec!["name", "age"]);
    }

    // ---- UPDATE -----------------------------------------------------------------

    /// A non-reserved key becomes a filter; the body supplies the SET values.
    #[test]
    fn test_parse_update_params_simple() {
        let params = parse_update_params("id=eq.123", r#"{"status": "active"}"#)
            .expect("simple update should parse");
        assert!(params.has_filters());
        assert_eq!(params.set_values.len(), 1);
    }

    /// `limit=` is parsed as an unsigned integer.
    #[test]
    fn test_parse_update_params_with_limit() {
        let params =
            parse_update_params("status=eq.pending&limit=10", r#"{"status": "active"}"#)
                .expect("update with limit should parse");
        assert_eq!(params.limit, Some(10));
    }

    /// `order=` produces one ordering term per listed column.
    #[test]
    fn test_parse_update_params_with_order() {
        let params = parse_update_params(
            "status=eq.pending&order=created_at.desc",
            r#"{"status": "active"}"#,
        )
        .expect("update with order should parse");
        assert_eq!(params.order.len(), 1);
    }

    // ---- DELETE -----------------------------------------------------------------

    /// A non-reserved key becomes a filter.
    #[test]
    fn test_parse_delete_params_simple() {
        let params = parse_delete_params("id=eq.123").expect("simple delete should parse");
        assert!(params.has_filters());
    }

    /// `returning=*` sets the RETURNING clause.
    #[test]
    fn test_parse_delete_params_with_returning() {
        let params = parse_delete_params("status=eq.deleted&returning=*")
            .expect("delete with returning should parse");
        assert!(params.returning.is_some());
    }

    // ---- on_conflict ------------------------------------------------------------

    /// Columns only: the action defaults to DO NOTHING.
    #[test]
    fn test_parse_on_conflict_do_nothing() {
        let conflict = parse_on_conflict("email").expect("bare column should parse");
        assert_eq!(conflict.columns, vec!["email"]);
        assert_eq!(conflict.action, ConflictAction::DoNothing);
    }

    /// An explicit `.do_update` suffix selects DO UPDATE.
    #[test]
    fn test_parse_on_conflict_do_update() {
        let conflict = parse_on_conflict("email.do_update").expect("do_update should parse");
        assert_eq!(conflict.columns, vec!["email"]);
        assert_eq!(conflict.action, ConflictAction::DoUpdate);
    }

    /// Several comma-separated columns are all kept, in order.
    #[test]
    fn test_parse_on_conflict_multiple_columns() {
        let conflict = parse_on_conflict("email,username.do_nothing")
            .expect("multi-column spec should parse");
        assert_eq!(conflict.columns, vec!["email", "username"]);
        assert_eq!(conflict.action, ConflictAction::DoNothing);
    }

    /// An unknown action is rejected.
    #[test]
    fn test_parse_on_conflict_invalid_action() {
        assert!(parse_on_conflict("email.invalid_action").is_err());
    }

    /// An empty spec has no columns and is rejected.
    #[test]
    fn test_parse_on_conflict_empty_columns() {
        assert!(parse_on_conflict("").is_err());
    }

    // ---- helpers ----------------------------------------------------------------

    /// Each `key=value` pair ends up in the map.
    #[test]
    fn test_parse_query_params() {
        let params = parse_query_params("id=eq.123&status=eq.active&limit=10");
        assert_eq!(params.len(), 3);
        assert_eq!(params.get("id").map(String::as_str), Some("eq.123"));
        assert_eq!(params.get("limit").map(String::as_str), Some("10"));
    }

    /// Control parameters are reserved; ordinary column names are not.
    #[test]
    fn test_is_reserved_key() {
        for reserved in ["select", "order", "limit", "returning", "on_conflict"] {
            assert!(is_reserved_key(reserved));
        }
        for ordinary in ["id", "status"] {
            assert!(!is_reserved_key(ordinary));
        }
    }
}