Skip to main content

postgrest_parser/parser/mutation.rs

1use crate::ast::{ConflictAction, DeleteParams, InsertParams, OnConflict, UpdateParams};
2use crate::error::{Error, ParseError};
3use crate::parser::{
4    parse_json_body, parse_order, parse_select, validate_insert_body, validate_update_body,
5};
6use std::collections::HashMap;
7
8/// Parses INSERT operation parameters from query string and body
9///
10/// # Arguments
11///
12/// * `query_string` - Query parameters (e.g., "returning=id&on_conflict=email")
13/// * `body` - JSON body with values to insert
14///
15/// # Examples
16///
17/// ```
18/// use postgrest_parser::parser::parse_insert_params;
19///
20/// let body = r#"{"name": "Alice", "age": 30}"#;
21/// let params = parse_insert_params("returning=id,created_at", body).unwrap();
22/// assert!(params.returning.is_some());
23/// ```
24pub fn parse_insert_params(query_string: &str, body: &str) -> Result<InsertParams, Error> {
25    // Parse and validate body
26    let json_value = parse_json_body(body)?;
27    let values = validate_insert_body(json_value)?;
28
29    let mut params = InsertParams::new(values);
30
31    // Parse query parameters
32    let query_params = parse_query_params(query_string);
33
34    // Parse returning clause (PostgREST uses 'select' parameter for mutations)
35    if let Some(select_str) = query_params.get("select") {
36        let returning = parse_select(select_str)?;
37        params = params.with_returning(returning);
38    } else if let Some(returning_str) = query_params.get("returning") {
39        // Also support 'returning' for backwards compatibility
40        let returning = parse_select(returning_str)?;
41        params = params.with_returning(returning);
42    }
43
44    // Parse columns specification
45    if let Some(columns_str) = query_params.get("columns") {
46        let columns: Vec<String> = columns_str
47            .split(',')
48            .map(|s| s.trim().to_string())
49            .collect();
50        if !columns.is_empty() && !columns.iter().any(|c| c.is_empty()) {
51            params = params.with_columns(columns);
52        }
53    }
54
55    // Parse on_conflict specification
56    if let Some(on_conflict_str) = query_params.get("on_conflict") {
57        let on_conflict = parse_on_conflict(on_conflict_str)?;
58        params = params.with_on_conflict(on_conflict);
59    }
60
61    Ok(params)
62}
63
64/// Parses UPDATE operation parameters from query string and body
65///
66/// # Arguments
67///
68/// * `query_string` - Query parameters with filters (e.g., "id=eq.123&returning=*")
69/// * `body` - JSON body with values to update
70///
71/// # Examples
72///
73/// ```
74/// use postgrest_parser::parser::parse_update_params;
75///
76/// let body = r#"{"status": "active"}"#;
77/// let params = parse_update_params("id=eq.123", body).unwrap();
78/// assert!(params.has_filters());
79/// ```
80pub fn parse_update_params(query_string: &str, body: &str) -> Result<UpdateParams, Error> {
81    // Parse and validate body
82    let json_value = parse_json_body(body)?;
83    let set_values = validate_update_body(json_value)?;
84
85    let mut params = UpdateParams::new(set_values);
86
87    // Parse query parameters
88    let query_params = parse_query_params(query_string);
89
90    // Parse filters (anything that's not a reserved key)
91    let filters = crate::parse_params_from_pairs(
92        query_params
93            .iter()
94            .filter(|(k, _)| !is_reserved_key(k))
95            .map(|(k, v)| (k.clone(), v.clone()))
96            .collect(),
97    )?;
98
99    params = params.with_filters(filters.filters);
100
101    // Parse order
102    if let Some(order_str) = query_params.get("order") {
103        let order = parse_order(order_str)?;
104        params = params.with_order(order);
105    }
106
107    // Parse limit
108    if let Some(limit_str) = query_params.get("limit") {
109        let limit = limit_str.parse::<u64>().map_err(|_| {
110            Error::Parse(ParseError::InvalidLimit(format!(
111                "Invalid limit value: {}",
112                limit_str
113            )))
114        })?;
115        params = params.with_limit(limit);
116    }
117
118    // Parse returning clause (PostgREST uses 'select' parameter for mutations)
119    if let Some(select_str) = query_params.get("select") {
120        let returning = parse_select(select_str)?;
121        params = params.with_returning(returning);
122    } else if let Some(returning_str) = query_params.get("returning") {
123        // Also support 'returning' for backwards compatibility
124        let returning = parse_select(returning_str)?;
125        params = params.with_returning(returning);
126    }
127
128    Ok(params)
129}
130
131/// Parses DELETE operation parameters from query string
132///
133/// # Arguments
134///
135/// * `query_string` - Query parameters with filters (e.g., "id=eq.123&returning=*")
136///
137/// # Examples
138///
139/// ```
140/// use postgrest_parser::parser::parse_delete_params;
141///
142/// let params = parse_delete_params("status=eq.deleted").unwrap();
143/// assert!(params.has_filters());
144/// ```
145pub fn parse_delete_params(query_string: &str) -> Result<DeleteParams, Error> {
146    let mut params = DeleteParams::new();
147
148    // Parse query parameters
149    let query_params = parse_query_params(query_string);
150
151    // Parse filters
152    let filters = crate::parse_params_from_pairs(
153        query_params
154            .iter()
155            .filter(|(k, _)| !is_reserved_key(k))
156            .map(|(k, v)| (k.clone(), v.clone()))
157            .collect(),
158    )?;
159
160    params = params.with_filters(filters.filters);
161
162    // Parse order
163    if let Some(order_str) = query_params.get("order") {
164        let order = parse_order(order_str)?;
165        params = params.with_order(order);
166    }
167
168    // Parse limit
169    if let Some(limit_str) = query_params.get("limit") {
170        let limit = limit_str.parse::<u64>().map_err(|_| {
171            Error::Parse(ParseError::InvalidLimit(format!(
172                "Invalid limit value: {}",
173                limit_str
174            )))
175        })?;
176        params = params.with_limit(limit);
177    }
178
179    // Parse returning clause (PostgREST uses 'select' parameter for mutations)
180    if let Some(select_str) = query_params.get("select") {
181        let returning = parse_select(select_str)?;
182        params = params.with_returning(returning);
183    } else if let Some(returning_str) = query_params.get("returning") {
184        // Also support 'returning' for backwards compatibility
185        let returning = parse_select(returning_str)?;
186        params = params.with_returning(returning);
187    }
188
189    Ok(params)
190}
191
/// Splits a raw query string into a key -> value map.
///
/// Each `&`-separated segment is split at its first `=`; segments that contain no `=`
/// are skipped. Keys and values are kept verbatim (no percent-decoding is done here).
/// When a key occurs more than once, the last occurrence wins.
fn parse_query_params(query_string: &str) -> HashMap<String, String> {
    let mut map = HashMap::new();
    for segment in query_string.split('&') {
        if let Some((key, value)) = segment.split_once('=') {
            map.insert(key.to_owned(), value.to_owned());
        }
    }
    map
}
205
/// Returns `true` when `key` names a query parameter that configures the mutation itself
/// (returning clause, ordering, paging, conflict handling, column list) and therefore
/// must not be treated as a column filter. Matching is exact and case-sensitive.
fn is_reserved_key(key: &str) -> bool {
    const RESERVED: [&str; 7] = [
        "select",
        "order",
        "limit",
        "offset",
        "on_conflict",
        "columns",
        "returning",
    ];
    RESERVED.contains(&key)
}
212
213fn parse_on_conflict(spec: &str) -> Result<OnConflict, Error> {
214    // Format: "column1,column2" or "column1,column2.action"
215    // where action can be "do_nothing" or "do_update" (default: do_nothing)
216
217    let parts: Vec<&str> = spec.split('.').collect();
218
219    let (columns_str, action) = match parts.len() {
220        1 => (parts[0], ConflictAction::DoNothing), // Default to DO NOTHING
221        2 => {
222            let action = match parts[1].to_lowercase().as_str() {
223                "do_nothing" => ConflictAction::DoNothing,
224                "do_update" => ConflictAction::DoUpdate,
225                _ => {
226                    return Err(Error::Parse(ParseError::InvalidOnConflict(format!(
227                        "Invalid conflict action: '{}'. Expected 'do_nothing' or 'do_update'",
228                        parts[1]
229                    ))))
230                }
231            };
232            (parts[0], action)
233        }
234        _ => {
235            return Err(Error::Parse(ParseError::InvalidOnConflict(format!(
236                "Invalid on_conflict format: '{}'",
237                spec
238            ))))
239        }
240    };
241
242    let columns: Vec<String> = columns_str
243        .split(',')
244        .map(|s| s.trim().to_string())
245        .filter(|s| !s.is_empty())
246        .collect();
247
248    if columns.is_empty() {
249        return Err(Error::Parse(ParseError::InvalidOnConflict(
250            "on_conflict must specify at least one column".to_string(),
251        )));
252    }
253
254    Ok(OnConflict {
255        columns,
256        action,
257        where_clause: None,
258        update_columns: None,
259    })
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    // ---- INSERT ----

    #[test]
    fn test_parse_insert_params_simple() {
        let body = r#"{"name": "Alice", "age": 30}"#;
        let params = parse_insert_params("", body).expect("plain insert should parse");
        assert!(params.returning.is_none());
        assert!(params.on_conflict.is_none());
    }

    #[test]
    fn test_parse_insert_params_with_returning() {
        let body = r#"{"name": "Alice"}"#;
        let params = parse_insert_params("returning=id,created_at", body)
            .expect("insert with returning should parse");
        assert_eq!(params.returning.map(|r| r.len()), Some(2));
    }

    #[test]
    fn test_parse_insert_params_with_on_conflict() {
        let body = r#"{"email": "alice@example.com"}"#;
        let params = parse_insert_params("on_conflict=email", body)
            .expect("insert with on_conflict should parse");
        let conflict = params.on_conflict.expect("on_conflict should be set");
        assert_eq!(conflict.columns, vec!["email"]);
        assert_eq!(conflict.action, ConflictAction::DoNothing);
    }

    #[test]
    fn test_parse_insert_params_with_columns() {
        let body = r#"{"name": "Alice", "age": 30, "extra": "ignored"}"#;
        let params = parse_insert_params("columns=name,age", body)
            .expect("insert with columns should parse");
        let columns = params.columns.expect("columns should be set");
        assert_eq!(columns, vec!["name", "age"]);
    }

    // ---- UPDATE ----

    #[test]
    fn test_parse_update_params_simple() {
        let body = r#"{"status": "active"}"#;
        let params = parse_update_params("id=eq.123", body).expect("update should parse");
        assert!(params.has_filters());
        assert_eq!(params.set_values.len(), 1);
    }

    #[test]
    fn test_parse_update_params_with_limit() {
        let body = r#"{"status": "active"}"#;
        let params = parse_update_params("status=eq.pending&limit=10", body)
            .expect("update with limit should parse");
        assert_eq!(params.limit, Some(10));
    }

    #[test]
    fn test_parse_update_params_with_order() {
        let body = r#"{"status": "active"}"#;
        let params = parse_update_params("status=eq.pending&order=created_at.desc", body)
            .expect("update with order should parse");
        assert_eq!(params.order.len(), 1);
    }

    // ---- DELETE ----

    #[test]
    fn test_parse_delete_params_simple() {
        let params = parse_delete_params("id=eq.123").expect("delete should parse");
        assert!(params.has_filters());
    }

    #[test]
    fn test_parse_delete_params_with_returning() {
        let params = parse_delete_params("status=eq.deleted&returning=*")
            .expect("delete with returning should parse");
        assert!(params.returning.is_some());
    }

    // ---- ON CONFLICT ----

    #[test]
    fn test_parse_on_conflict_do_nothing() {
        let conflict = parse_on_conflict("email").expect("bare column should parse");
        assert_eq!(conflict.columns, vec!["email"]);
        assert_eq!(conflict.action, ConflictAction::DoNothing);
    }

    #[test]
    fn test_parse_on_conflict_do_update() {
        let conflict = parse_on_conflict("email.do_update").expect("do_update should parse");
        assert_eq!(conflict.columns, vec!["email"]);
        assert_eq!(conflict.action, ConflictAction::DoUpdate);
    }

    #[test]
    fn test_parse_on_conflict_multiple_columns() {
        let conflict = parse_on_conflict("email,username.do_nothing")
            .expect("multiple columns should parse");
        assert_eq!(conflict.columns, vec!["email", "username"]);
        assert_eq!(conflict.action, ConflictAction::DoNothing);
    }

    #[test]
    fn test_parse_on_conflict_invalid_action() {
        assert!(parse_on_conflict("email.invalid_action").is_err());
    }

    #[test]
    fn test_parse_on_conflict_empty_columns() {
        assert!(parse_on_conflict("").is_err());
    }

    // ---- helpers ----

    #[test]
    fn test_parse_query_params() {
        let params = parse_query_params("id=eq.123&status=eq.active&limit=10");
        assert_eq!(params.len(), 3);
        assert_eq!(params.get("id").map(String::as_str), Some("eq.123"));
        assert_eq!(params.get("limit").map(String::as_str), Some("10"));
    }

    #[test]
    fn test_is_reserved_key() {
        for key in ["select", "order", "limit", "returning", "on_conflict"].iter() {
            assert!(is_reserved_key(key), "{} should be reserved", key);
        }
        for key in ["id", "status"].iter() {
            assert!(!is_reserved_key(key), "{} should not be reserved", key);
        }
    }
}