1use crate::ast::{
2 ConflictAction, DeleteParams, InsertParams, InsertValues, OnConflict, ResolvedTable,
3 SelectItem, UpdateParams,
4};
5use crate::error::SqlError;
6use crate::sql::{QueryBuilder, QueryResult};
7
8impl QueryBuilder {
9 pub fn build_insert(
11 &mut self,
12 resolved_table: &ResolvedTable,
13 params: &InsertParams,
14 ) -> Result<QueryResult, SqlError> {
15 if params.values.is_empty() {
16 return Err(SqlError::NoInsertValues);
17 }
18
19 self.tables.push(resolved_table.name.clone());
20
21 self.sql
23 .push_str(&format!("INSERT INTO {}", resolved_table.qualified_name()));
24
25 let columns = if let Some(ref cols) = params.columns {
27 cols.clone()
28 } else {
29 params.values.get_columns()
30 };
31
32 self.sql.push_str(" (");
34 for (i, col) in columns.iter().enumerate() {
35 if i > 0 {
36 self.sql.push_str(", ");
37 }
38 self.sql.push_str(&format!("\"{}\"", col));
39 }
40 self.sql.push(')');
41
42 self.build_values_clause(¶ms.values, &columns)?;
44
45 if let Some(ref on_conflict) = params.on_conflict {
47 self.build_on_conflict_clause(on_conflict)?;
48 }
49
50 if let Some(ref returning) = params.returning {
52 self.build_returning_clause(returning)?;
53 }
54
55 Ok(QueryResult {
56 query: self.sql.clone(),
57 params: self.params.clone(),
58 tables: self.tables.clone(),
59 })
60 }
61
62 pub fn build_update(
64 &mut self,
65 resolved_table: &ResolvedTable,
66 params: &UpdateParams,
67 ) -> Result<QueryResult, SqlError> {
68 self.validate_update_safety(params)?;
70
71 if params.set_values.is_empty() {
72 return Err(SqlError::NoUpdateSet);
73 }
74
75 self.tables.push(resolved_table.name.clone());
76
77 self.sql
79 .push_str(&format!("UPDATE {}", resolved_table.qualified_name()));
80
81 self.build_set_clause(¶ms.set_values)?;
83
84 if !params.filters.is_empty() {
86 self.build_where_clause(¶ms.filters)?;
87 }
88
89 if !params.order.is_empty() {
91 self.build_order_clause(¶ms.order)?;
92 }
93
94 if let Some(limit) = params.limit {
96 self.build_limit_clause(limit)?;
97 }
98
99 if let Some(ref returning) = params.returning {
101 self.build_returning_clause(returning)?;
102 }
103
104 Ok(QueryResult {
105 query: self.sql.clone(),
106 params: self.params.clone(),
107 tables: self.tables.clone(),
108 })
109 }
110
111 pub fn build_delete(
113 &mut self,
114 resolved_table: &ResolvedTable,
115 params: &DeleteParams,
116 ) -> Result<QueryResult, SqlError> {
117 self.validate_delete_safety(params)?;
119
120 self.tables.push(resolved_table.name.clone());
121
122 self.sql
124 .push_str(&format!("DELETE FROM {}", resolved_table.qualified_name()));
125
126 if !params.filters.is_empty() {
128 self.build_where_clause(¶ms.filters)?;
129 }
130
131 if !params.order.is_empty() {
133 self.build_order_clause(¶ms.order)?;
134 }
135
136 if let Some(limit) = params.limit {
138 self.build_limit_clause(limit)?;
139 }
140
141 if let Some(ref returning) = params.returning {
143 self.build_returning_clause(returning)?;
144 }
145
146 Ok(QueryResult {
147 query: self.sql.clone(),
148 params: self.params.clone(),
149 tables: self.tables.clone(),
150 })
151 }
152
153 fn build_values_clause(
154 &mut self,
155 values: &InsertValues,
156 columns: &[String],
157 ) -> Result<(), SqlError> {
158 self.sql.push_str(" VALUES ");
159
160 match values {
161 InsertValues::Single(map) => {
162 self.sql.push('(');
163 for (i, col) in columns.iter().enumerate() {
164 if i > 0 {
165 self.sql.push_str(", ");
166 }
167 let value = map.get(col).unwrap_or(&serde_json::Value::Null);
168 let param = self.add_param(value.clone());
169 self.sql.push_str(¶m);
170 }
171 self.sql.push(')');
172 }
173 InsertValues::Bulk(rows) => {
174 for (row_idx, row) in rows.iter().enumerate() {
175 if row_idx > 0 {
176 self.sql.push_str(", ");
177 }
178 self.sql.push('(');
179 for (i, col) in columns.iter().enumerate() {
180 if i > 0 {
181 self.sql.push_str(", ");
182 }
183 let value = row.get(col).unwrap_or(&serde_json::Value::Null);
184 let param = self.add_param(value.clone());
185 self.sql.push_str(¶m);
186 }
187 self.sql.push(')');
188 }
189 }
190 }
191
192 Ok(())
193 }
194
195 fn build_on_conflict_clause(&mut self, on_conflict: &OnConflict) -> Result<(), SqlError> {
196 self.sql.push_str(" ON CONFLICT (");
197 for (i, col) in on_conflict.columns.iter().enumerate() {
198 if i > 0 {
199 self.sql.push_str(", ");
200 }
201 self.sql.push_str(&format!("\"{}\"", col));
202 }
203 self.sql.push(')');
204
205 if let Some(ref where_conditions) = on_conflict.where_clause {
207 self.sql.push_str(" WHERE ");
208 for (i, condition) in where_conditions.iter().enumerate() {
209 if i > 0 {
210 self.sql.push_str(" AND ");
211 }
212 let condition_sql = self.build_filter(condition)?;
213 self.sql.push_str(&condition_sql);
214 }
215 }
216
217 match on_conflict.action {
218 ConflictAction::DoNothing => {
219 self.sql.push_str(" DO NOTHING");
220 }
221 ConflictAction::DoUpdate => {
222 self.sql.push_str(" DO UPDATE SET ");
223
224 let columns_to_update = if let Some(ref update_cols) = on_conflict.update_columns {
226 update_cols.clone()
228 } else {
229 on_conflict.columns.clone()
231 };
232
233 let mut first = true;
235 for col in columns_to_update.iter() {
236 if !first {
237 self.sql.push_str(", ");
238 }
239 self.sql
240 .push_str(&format!("\"{}\" = EXCLUDED.\"{}\"", col, col));
241 first = false;
242 }
243 }
244 }
245
246 Ok(())
247 }
248
249 fn build_set_clause(
250 &mut self,
251 set_values: &std::collections::HashMap<String, serde_json::Value>,
252 ) -> Result<(), SqlError> {
253 self.sql.push_str(" SET ");
254
255 let mut sorted_keys: Vec<&String> = set_values.keys().collect();
256 sorted_keys.sort(); for (i, key) in sorted_keys.iter().enumerate() {
259 if i > 0 {
260 self.sql.push_str(", ");
261 }
262 let value = set_values.get(*key).unwrap();
263 let param = self.add_param(value.clone());
264 self.sql.push_str(&format!("\"{}\" = {}", key, param));
265 }
266
267 Ok(())
268 }
269
270 fn build_returning_clause(&mut self, items: &[SelectItem]) -> Result<(), SqlError> {
271 self.sql.push_str(" RETURNING ");
272
273 for (i, item) in items.iter().enumerate() {
274 if i > 0 {
275 self.sql.push_str(", ");
276 }
277
278 if matches!(item.item_type, crate::ast::ItemType::Relation) {
280 return Err(SqlError::FailedToBuildSelectClause);
281 }
282
283 self.sql.push_str(&format!("\"{}\"", item.name));
284 if let Some(ref alias) = item.alias {
285 self.sql.push_str(&format!(" AS \"{}\"", alias));
286 }
287 }
288
289 Ok(())
290 }
291
292 fn build_limit_clause(&mut self, limit: u64) -> Result<(), SqlError> {
293 let param = self.add_param(serde_json::Value::Number(limit.into()));
294 self.sql.push_str(&format!(" LIMIT {}", param));
295 Ok(())
296 }
297
298 fn validate_update_safety(&self, params: &UpdateParams) -> Result<(), SqlError> {
299 if params.filters.is_empty() {
300 return Err(SqlError::UnsafeUpdate);
301 }
302
303 if params.limit.is_some() && params.order.is_empty() {
304 return Err(SqlError::LimitWithoutOrder);
305 }
306
307 Ok(())
308 }
309
310 fn validate_delete_safety(&self, params: &DeleteParams) -> Result<(), SqlError> {
311 if params.filters.is_empty() {
312 return Err(SqlError::UnsafeDelete);
313 }
314
315 if params.limit.is_some() && params.order.is_empty() {
316 return Err(SqlError::LimitWithoutOrder);
317 }
318
319 Ok(())
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Field, LogicCondition};
    use serde_json::json;
    use std::collections::HashMap;

    /// The `"public"."users"` table that most tests target.
    fn public_users() -> ResolvedTable {
        ResolvedTable::new("public", "users")
    }

    /// A single `field = value` filter condition.
    fn eq_filter(field: &'static str, value: &'static str) -> LogicCondition {
        LogicCondition::Filter(crate::ast::Filter::new(
            Field::new(field),
            crate::ast::FilterOperator::Eq,
            crate::ast::FilterValue::Single(value.to_string()),
        ))
    }

    /// The `status = "active"` payload shared by the UPDATE tests.
    fn status_active() -> HashMap<String, serde_json::Value> {
        std::iter::once(("status".to_string(), json!("active"))).collect()
    }

    #[test]
    fn test_build_insert_single() {
        let row: HashMap<_, _> = vec![
            ("name".to_string(), json!("Alice")),
            ("age".to_string(), json!(30)),
        ]
        .into_iter()
        .collect();
        let params = InsertParams::new(InsertValues::Single(row));

        let result = QueryBuilder::new()
            .build_insert(&public_users(), &params)
            .unwrap();

        for fragment in ["INSERT INTO \"public\".\"users\"", "\"age\"", "\"name\""] {
            assert!(result.query.contains(fragment));
        }
        assert_eq!(result.params.len(), 2);
    }

    #[test]
    fn test_build_insert_bulk() {
        let rows: Vec<HashMap<String, serde_json::Value>> = ["Alice", "Bob"]
            .iter()
            .map(|name| std::iter::once(("name".to_string(), json!(name))).collect())
            .collect();
        let params = InsertParams::new(InsertValues::Bulk(rows));

        let result = QueryBuilder::new()
            .build_insert(&public_users(), &params)
            .unwrap();

        assert!(result.query.contains("VALUES"));
        assert_eq!(result.params.len(), 2);
    }

    #[test]
    fn test_build_insert_with_on_conflict() {
        let row: HashMap<String, serde_json::Value> =
            std::iter::once(("email".to_string(), json!("alice@example.com"))).collect();
        let params = InsertParams::new(InsertValues::Single(row))
            .with_on_conflict(OnConflict::do_update(vec!["email".to_string()]));

        let table = ResolvedTable::new("auth", "users");
        let result = QueryBuilder::new().build_insert(&table, &params).unwrap();

        for fragment in ["ON CONFLICT", "DO UPDATE", "EXCLUDED"] {
            assert!(result.query.contains(fragment));
        }
    }

    #[test]
    fn test_build_update_with_filters() {
        let params = UpdateParams::new(status_active()).with_filters(vec![eq_filter("id", "123")]);

        let result = QueryBuilder::new()
            .build_update(&public_users(), &params)
            .unwrap();

        for fragment in ["UPDATE \"public\".\"users\"", "SET", "WHERE"] {
            assert!(result.query.contains(fragment));
        }
    }

    #[test]
    fn test_build_update_without_filters_fails() {
        let params = UpdateParams::new(status_active());

        let outcome = QueryBuilder::new().build_update(&public_users(), &params);

        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), SqlError::UnsafeUpdate));
    }

    #[test]
    fn test_build_delete_with_filters() {
        let params = DeleteParams::new().with_filters(vec![eq_filter("status", "deleted")]);

        let result = QueryBuilder::new()
            .build_delete(&public_users(), &params)
            .unwrap();

        assert!(result.query.contains("DELETE FROM \"public\".\"users\""));
        assert!(result.query.contains("WHERE"));
    }

    #[test]
    fn test_build_delete_without_filters_fails() {
        let params = DeleteParams::new();

        let outcome = QueryBuilder::new().build_delete(&public_users(), &params);

        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), SqlError::UnsafeDelete));
    }

    #[test]
    fn test_update_limit_without_order_fails() {
        let params = UpdateParams::new(status_active())
            .with_filters(vec![eq_filter("id", "123")])
            .with_limit(10);

        let outcome = QueryBuilder::new().build_update(&public_users(), &params);

        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), SqlError::LimitWithoutOrder));
    }
}
472}