Skip to main content

postgrest_parser/sql/
mutation.rs

1use crate::ast::{
2    ConflictAction, DeleteParams, InsertParams, InsertValues, OnConflict, ResolvedTable,
3    SelectItem, UpdateParams,
4};
5use crate::error::SqlError;
6use crate::sql::{QueryBuilder, QueryResult};
7
8impl QueryBuilder {
9    /// Builds an INSERT query with schema-qualified table name
10    pub fn build_insert(
11        &mut self,
12        resolved_table: &ResolvedTable,
13        params: &InsertParams,
14    ) -> Result<QueryResult, SqlError> {
15        if params.values.is_empty() {
16            return Err(SqlError::NoInsertValues);
17        }
18
19        self.tables.push(resolved_table.name.clone());
20
21        // INSERT INTO "schema"."table"
22        self.sql
23            .push_str(&format!("INSERT INTO {}", resolved_table.qualified_name()));
24
25        // Determine columns
26        let columns = if let Some(ref cols) = params.columns {
27            cols.clone()
28        } else {
29            params.values.get_columns()
30        };
31
32        // Column list
33        self.sql.push_str(" (");
34        for (i, col) in columns.iter().enumerate() {
35            if i > 0 {
36                self.sql.push_str(", ");
37            }
38            self.sql.push_str(&format!("\"{}\"", col));
39        }
40        self.sql.push(')');
41
42        // VALUES clause
43        self.build_values_clause(&params.values, &columns)?;
44
45        // ON CONFLICT clause
46        if let Some(ref on_conflict) = params.on_conflict {
47            self.build_on_conflict_clause(on_conflict)?;
48        }
49
50        // RETURNING clause
51        if let Some(ref returning) = params.returning {
52            self.build_returning_clause(returning)?;
53        }
54
55        Ok(QueryResult {
56            query: self.sql.clone(),
57            params: self.params.clone(),
58            tables: self.tables.clone(),
59        })
60    }
61
62    /// Builds an UPDATE query with schema-qualified table name and safety validation
63    pub fn build_update(
64        &mut self,
65        resolved_table: &ResolvedTable,
66        params: &UpdateParams,
67    ) -> Result<QueryResult, SqlError> {
68        // Safety validation
69        self.validate_update_safety(params)?;
70
71        if params.set_values.is_empty() {
72            return Err(SqlError::NoUpdateSet);
73        }
74
75        self.tables.push(resolved_table.name.clone());
76
77        // UPDATE "schema"."table"
78        self.sql
79            .push_str(&format!("UPDATE {}", resolved_table.qualified_name()));
80
81        // SET clause
82        self.build_set_clause(&params.set_values)?;
83
84        // WHERE clause
85        if !params.filters.is_empty() {
86            self.build_where_clause(&params.filters)?;
87        }
88
89        // ORDER BY clause
90        if !params.order.is_empty() {
91            self.build_order_clause(&params.order)?;
92        }
93
94        // LIMIT clause
95        if let Some(limit) = params.limit {
96            self.build_limit_clause(limit)?;
97        }
98
99        // RETURNING clause
100        if let Some(ref returning) = params.returning {
101            self.build_returning_clause(returning)?;
102        }
103
104        Ok(QueryResult {
105            query: self.sql.clone(),
106            params: self.params.clone(),
107            tables: self.tables.clone(),
108        })
109    }
110
111    /// Builds a DELETE query with schema-qualified table name and safety validation
112    pub fn build_delete(
113        &mut self,
114        resolved_table: &ResolvedTable,
115        params: &DeleteParams,
116    ) -> Result<QueryResult, SqlError> {
117        // Safety validation
118        self.validate_delete_safety(params)?;
119
120        self.tables.push(resolved_table.name.clone());
121
122        // DELETE FROM "schema"."table"
123        self.sql
124            .push_str(&format!("DELETE FROM {}", resolved_table.qualified_name()));
125
126        // WHERE clause
127        if !params.filters.is_empty() {
128            self.build_where_clause(&params.filters)?;
129        }
130
131        // ORDER BY clause
132        if !params.order.is_empty() {
133            self.build_order_clause(&params.order)?;
134        }
135
136        // LIMIT clause
137        if let Some(limit) = params.limit {
138            self.build_limit_clause(limit)?;
139        }
140
141        // RETURNING clause
142        if let Some(ref returning) = params.returning {
143            self.build_returning_clause(returning)?;
144        }
145
146        Ok(QueryResult {
147            query: self.sql.clone(),
148            params: self.params.clone(),
149            tables: self.tables.clone(),
150        })
151    }
152
153    fn build_values_clause(
154        &mut self,
155        values: &InsertValues,
156        columns: &[String],
157    ) -> Result<(), SqlError> {
158        self.sql.push_str(" VALUES ");
159
160        match values {
161            InsertValues::Single(map) => {
162                self.sql.push('(');
163                for (i, col) in columns.iter().enumerate() {
164                    if i > 0 {
165                        self.sql.push_str(", ");
166                    }
167                    let value = map.get(col).unwrap_or(&serde_json::Value::Null);
168                    let param = self.add_param(value.clone());
169                    self.sql.push_str(&param);
170                }
171                self.sql.push(')');
172            }
173            InsertValues::Bulk(rows) => {
174                for (row_idx, row) in rows.iter().enumerate() {
175                    if row_idx > 0 {
176                        self.sql.push_str(", ");
177                    }
178                    self.sql.push('(');
179                    for (i, col) in columns.iter().enumerate() {
180                        if i > 0 {
181                            self.sql.push_str(", ");
182                        }
183                        let value = row.get(col).unwrap_or(&serde_json::Value::Null);
184                        let param = self.add_param(value.clone());
185                        self.sql.push_str(&param);
186                    }
187                    self.sql.push(')');
188                }
189            }
190        }
191
192        Ok(())
193    }
194
195    fn build_on_conflict_clause(&mut self, on_conflict: &OnConflict) -> Result<(), SqlError> {
196        self.sql.push_str(" ON CONFLICT (");
197        for (i, col) in on_conflict.columns.iter().enumerate() {
198            if i > 0 {
199                self.sql.push_str(", ");
200            }
201            self.sql.push_str(&format!("\"{}\"", col));
202        }
203        self.sql.push(')');
204
205        // Add WHERE clause for partial unique index
206        if let Some(ref where_conditions) = on_conflict.where_clause {
207            self.sql.push_str(" WHERE ");
208            for (i, condition) in where_conditions.iter().enumerate() {
209                if i > 0 {
210                    self.sql.push_str(" AND ");
211                }
212                let condition_sql = self.build_filter(condition)?;
213                self.sql.push_str(&condition_sql);
214            }
215        }
216
217        match on_conflict.action {
218            ConflictAction::DoNothing => {
219                self.sql.push_str(" DO NOTHING");
220            }
221            ConflictAction::DoUpdate => {
222                self.sql.push_str(" DO UPDATE SET ");
223
224                // Determine which columns to update
225                let columns_to_update = if let Some(ref update_cols) = on_conflict.update_columns {
226                    // Use specified columns
227                    update_cols.clone()
228                } else {
229                    // Default: update all columns (same as conflict columns for now)
230                    on_conflict.columns.clone()
231                };
232
233                // Update specified columns
234                let mut first = true;
235                for col in columns_to_update.iter() {
236                    if !first {
237                        self.sql.push_str(", ");
238                    }
239                    self.sql
240                        .push_str(&format!("\"{}\" = EXCLUDED.\"{}\"", col, col));
241                    first = false;
242                }
243            }
244        }
245
246        Ok(())
247    }
248
249    fn build_set_clause(
250        &mut self,
251        set_values: &std::collections::HashMap<String, serde_json::Value>,
252    ) -> Result<(), SqlError> {
253        self.sql.push_str(" SET ");
254
255        let mut sorted_keys: Vec<&String> = set_values.keys().collect();
256        sorted_keys.sort(); // Sort for deterministic output
257
258        for (i, key) in sorted_keys.iter().enumerate() {
259            if i > 0 {
260                self.sql.push_str(", ");
261            }
262            let value = set_values.get(*key).unwrap();
263            let param = self.add_param(value.clone());
264            self.sql.push_str(&format!("\"{}\" = {}", key, param));
265        }
266
267        Ok(())
268    }
269
270    fn build_returning_clause(&mut self, items: &[SelectItem]) -> Result<(), SqlError> {
271        self.sql.push_str(" RETURNING ");
272
273        for (i, item) in items.iter().enumerate() {
274            if i > 0 {
275                self.sql.push_str(", ");
276            }
277
278            // Relations are not supported in RETURNING for now
279            if matches!(item.item_type, crate::ast::ItemType::Relation) {
280                return Err(SqlError::FailedToBuildSelectClause);
281            }
282
283            self.sql.push_str(&format!("\"{}\"", item.name));
284            if let Some(ref alias) = item.alias {
285                self.sql.push_str(&format!(" AS \"{}\"", alias));
286            }
287        }
288
289        Ok(())
290    }
291
292    fn build_limit_clause(&mut self, limit: u64) -> Result<(), SqlError> {
293        let param = self.add_param(serde_json::Value::Number(limit.into()));
294        self.sql.push_str(&format!(" LIMIT {}", param));
295        Ok(())
296    }
297
298    fn validate_update_safety(&self, params: &UpdateParams) -> Result<(), SqlError> {
299        if params.filters.is_empty() {
300            return Err(SqlError::UnsafeUpdate);
301        }
302
303        if params.limit.is_some() && params.order.is_empty() {
304            return Err(SqlError::LimitWithoutOrder);
305        }
306
307        Ok(())
308    }
309
310    fn validate_delete_safety(&self, params: &DeleteParams) -> Result<(), SqlError> {
311        if params.filters.is_empty() {
312            return Err(SqlError::UnsafeDelete);
313        }
314
315        if params.limit.is_some() && params.order.is_empty() {
316            return Err(SqlError::LimitWithoutOrder);
317        }
318
319        Ok(())
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Field, LogicCondition};
    use serde_json::json;
    use std::collections::HashMap;

    /// The `public.users` table shared by most tests.
    fn users() -> ResolvedTable {
        ResolvedTable::new("public", "users")
    }

    /// An `eq` filter comparing `column` with the literal `value`.
    fn eq_filter(column: &str, value: &str) -> LogicCondition {
        LogicCondition::Filter(crate::ast::Filter::new(
            Field::new(column),
            crate::ast::FilterOperator::Eq,
            crate::ast::FilterValue::Single(value.to_string()),
        ))
    }

    #[test]
    fn test_build_insert_single() {
        let values = HashMap::from([
            ("name".to_string(), json!("Alice")),
            ("age".to_string(), json!(30)),
        ]);
        let params = InsertParams::new(InsertValues::Single(values));

        let result = QueryBuilder::new().build_insert(&users(), &params).unwrap();

        for fragment in ["INSERT INTO \"public\".\"users\"", "\"age\"", "\"name\""] {
            assert!(
                result.query.contains(fragment),
                "missing {:?} in {}",
                fragment,
                result.query
            );
        }
        assert_eq!(result.params.len(), 2);
    }

    #[test]
    fn test_build_insert_bulk() {
        let rows = vec![
            HashMap::from([("name".to_string(), json!("Alice"))]),
            HashMap::from([("name".to_string(), json!("Bob"))]),
        ];
        let params = InsertParams::new(InsertValues::Bulk(rows));

        let result = QueryBuilder::new().build_insert(&users(), &params).unwrap();

        assert!(result.query.contains("VALUES"));
        assert_eq!(result.params.len(), 2);
    }

    #[test]
    fn test_build_insert_with_on_conflict() {
        let values = HashMap::from([("email".to_string(), json!("alice@example.com"))]);
        let params = InsertParams::new(InsertValues::Single(values))
            .with_on_conflict(OnConflict::do_update(vec!["email".to_string()]));

        let result = QueryBuilder::new()
            .build_insert(&ResolvedTable::new("auth", "users"), &params)
            .unwrap();

        for fragment in ["ON CONFLICT", "DO UPDATE", "EXCLUDED"] {
            assert!(
                result.query.contains(fragment),
                "missing {:?} in {}",
                fragment,
                result.query
            );
        }
    }

    #[test]
    fn test_build_update_with_filters() {
        let params = UpdateParams::new(HashMap::from([("status".to_string(), json!("active"))]))
            .with_filters(vec![eq_filter("id", "123")]);

        let result = QueryBuilder::new().build_update(&users(), &params).unwrap();

        for fragment in ["UPDATE \"public\".\"users\"", "SET", "WHERE"] {
            assert!(
                result.query.contains(fragment),
                "missing {:?} in {}",
                fragment,
                result.query
            );
        }
    }

    #[test]
    fn test_build_update_without_filters_fails() {
        let params = UpdateParams::new(HashMap::from([("status".to_string(), json!("active"))]));

        let outcome = QueryBuilder::new().build_update(&users(), &params);

        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), SqlError::UnsafeUpdate));
    }

    #[test]
    fn test_build_delete_with_filters() {
        let params = DeleteParams::new().with_filters(vec![eq_filter("status", "deleted")]);

        let result = QueryBuilder::new().build_delete(&users(), &params).unwrap();

        for fragment in ["DELETE FROM \"public\".\"users\"", "WHERE"] {
            assert!(
                result.query.contains(fragment),
                "missing {:?} in {}",
                fragment,
                result.query
            );
        }
    }

    #[test]
    fn test_build_delete_without_filters_fails() {
        let outcome = QueryBuilder::new().build_delete(&users(), &DeleteParams::new());

        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), SqlError::UnsafeDelete));
    }

    #[test]
    fn test_update_limit_without_order_fails() {
        let params = UpdateParams::new(HashMap::from([("status".to_string(), json!("active"))]))
            .with_filters(vec![eq_filter("id", "123")])
            .with_limit(10);

        let outcome = QueryBuilder::new().build_update(&users(), &params);

        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), SqlError::LimitWithoutOrder));
    }
}