elif_orm/sql/
generation.rs

1//! Query Builder SQL generation
2//! This module implements secure SQL generation with proper parameterization
3//! and identifier escaping to prevent SQL injection attacks.
4
5use crate::error::ModelError;
6use crate::query::builder::QueryBuilder;
7use crate::query::types::*;
8use crate::security::{escape_identifier, validate_identifier, validate_parameter};
9use serde_json::Value;
10
11impl<M> QueryBuilder<M> {
12    /// Generate SQL from query with parameter placeholders and return parameters
13    /// This method includes basic identifier escaping for security
14    pub fn to_sql_with_params(&self) -> (String, Vec<String>) {
15        // Note: For backward compatibility, this returns the tuple directly
16        // In the future, consider making this return Result<(String, Vec<String>), ModelError>
17        match self.query_type {
18            QueryType::Select => self.build_select_sql(),
19            QueryType::Insert => self.build_insert_sql(),
20            QueryType::Update => self.build_update_sql(),
21            QueryType::Delete => self.build_delete_sql(),
22        }
23    }
24
25    /// Generate SQL with security validation - returns Result for proper error handling
26    pub fn to_sql_with_params_secure(&self) -> Result<(String, Vec<String>), ModelError> {
27        // Validate identifiers first
28        self.validate_query_security()?;
29        Ok(self.to_sql_with_params())
30    }
31
32    /// Validate query security before SQL generation
33    fn validate_query_security(&self) -> Result<(), ModelError> {
34        // Validate table identifiers
35        for table in &self.from_tables {
36            validate_identifier(table)?;
37        }
38
39        if let Some(ref table) = self.insert_table {
40            validate_identifier(table)?;
41        }
42
43        if let Some(ref table) = self.update_table {
44            validate_identifier(table)?;
45        }
46
47        if let Some(ref table) = self.delete_table {
48            validate_identifier(table)?;
49        }
50
51        // Validate select field identifiers
52        for field in &self.select_fields {
53            // Skip wildcard and function calls for now
54            if field != "*" && !field.contains('(') {
55                validate_identifier(field)?;
56            }
57        }
58
59        // Validate column identifiers in WHERE clauses
60        for condition in &self.where_conditions {
61            if condition.column != "RAW"
62                && condition.column != "EXISTS"
63                && condition.column != "NOT EXISTS"
64            {
65                validate_identifier(&condition.column)?;
66            }
67        }
68
69        // Validate JOIN table identifiers
70        for join in &self.joins {
71            validate_identifier(&join.table)?;
72        }
73
74        // Validate parameter values
75        for condition in &self.where_conditions {
76            if let Some(ref value) = condition.value {
77                if let Value::String(s) = value {
78                    validate_parameter(s)?;
79                }
80            }
81            for value in &condition.values {
82                if let Value::String(s) = value {
83                    validate_parameter(s)?;
84                }
85            }
86        }
87
88        Ok(())
89    }
90
91    /// Build SELECT SQL with parameters
92    fn build_select_sql(&self) -> (String, Vec<String>) {
93        let mut sql = String::new();
94        let mut params = Vec::new();
95        let mut param_counter = 1;
96
97        // SELECT clause
98        if self.distinct {
99            sql.push_str("SELECT DISTINCT ");
100        } else {
101            sql.push_str("SELECT ");
102        }
103
104        if self.select_fields.is_empty() {
105            sql.push('*');
106        } else {
107            let escaped_fields: Vec<String> = self
108                .select_fields
109                .iter()
110                .map(|field| {
111                    if field == "*" || field.contains('(') {
112                        // Keep wildcards and function calls as-is
113                        field.clone()
114                    } else {
115                        escape_identifier(field)
116                    }
117                })
118                .collect();
119            sql.push_str(&escaped_fields.join(", "));
120        }
121
122        // FROM clause
123        if !self.from_tables.is_empty() {
124            sql.push_str(" FROM ");
125            let escaped_tables: Vec<String> = self
126                .from_tables
127                .iter()
128                .map(|table| escape_identifier(table))
129                .collect();
130            sql.push_str(&escaped_tables.join(", "));
131        }
132
133        // JOIN clauses
134        for join in &self.joins {
135            sql.push(' ');
136            sql.push_str(&join.join_type.to_string());
137            sql.push(' ');
138            sql.push_str(&escape_identifier(&join.table));
139            sql.push_str(" ON ");
140            for (i, (left, right)) in join.on_conditions.iter().enumerate() {
141                if i > 0 {
142                    sql.push_str(" AND ");
143                }
144                sql.push_str(&format!(
145                    "{} = {}",
146                    escape_identifier(left),
147                    escape_identifier(right)
148                ));
149            }
150        }
151
152        self.build_where_clause(&mut sql, &mut params, &mut param_counter);
153        self.build_order_limit_clause(&mut sql);
154
155        (sql, params)
156    }
157
158    /// Build INSERT SQL with parameters
159    fn build_insert_sql(&self) -> (String, Vec<String>) {
160        let mut sql = String::new();
161        let mut params = Vec::new();
162        let mut param_counter = 1;
163
164        if let Some(table) = &self.insert_table {
165            sql.push_str("INSERT INTO ");
166            sql.push_str(&escape_identifier(table));
167
168            if !self.set_clauses.is_empty() {
169                sql.push_str(" (");
170                let columns: Vec<String> = self
171                    .set_clauses
172                    .iter()
173                    .map(|clause| escape_identifier(&clause.column))
174                    .collect();
175                sql.push_str(&columns.join(", "));
176                sql.push_str(") VALUES (");
177
178                for (i, clause) in self.set_clauses.iter().enumerate() {
179                    if i > 0 {
180                        sql.push_str(", ");
181                    }
182                    if let Some(ref value) = clause.value {
183                        sql.push_str(&format!("${}", param_counter));
184                        params.push(self.json_value_to_param_string(value));
185                        param_counter += 1;
186                    } else {
187                        sql.push_str("NULL");
188                    }
189                }
190                sql.push(')');
191            }
192        }
193
194        (sql, params)
195    }
196
197    /// Build UPDATE SQL with parameters
198    fn build_update_sql(&self) -> (String, Vec<String>) {
199        let mut sql = String::new();
200        let mut params = Vec::new();
201        let mut param_counter = 1;
202
203        if let Some(table) = &self.update_table {
204            sql.push_str("UPDATE ");
205            sql.push_str(&escape_identifier(table));
206
207            if !self.set_clauses.is_empty() {
208                sql.push_str(" SET ");
209                for (i, clause) in self.set_clauses.iter().enumerate() {
210                    if i > 0 {
211                        sql.push_str(", ");
212                    }
213                    sql.push_str(&escape_identifier(&clause.column));
214                    sql.push_str(" = ");
215                    if let Some(ref value) = clause.value {
216                        sql.push_str(&format!("${}", param_counter));
217                        params.push(self.json_value_to_param_string(value));
218                        param_counter += 1;
219                    } else {
220                        sql.push_str("NULL");
221                    }
222                }
223            }
224
225            self.build_where_clause(&mut sql, &mut params, &mut param_counter);
226        }
227
228        (sql, params)
229    }
230
231    /// Build DELETE SQL with parameters
232    fn build_delete_sql(&self) -> (String, Vec<String>) {
233        let mut sql = String::new();
234        let mut params = Vec::new();
235        let mut param_counter = 1;
236
237        if let Some(table) = &self.delete_table {
238            sql.push_str("DELETE FROM ");
239            sql.push_str(&escape_identifier(table));
240            self.build_where_clause(&mut sql, &mut params, &mut param_counter);
241        }
242
243        (sql, params)
244    }
245
246    /// Helper method to build WHERE clauses
247    fn build_where_clause(
248        &self,
249        sql: &mut String,
250        params: &mut Vec<String>,
251        param_counter: &mut i32,
252    ) {
253        if !self.where_conditions.is_empty() {
254            sql.push_str(" WHERE ");
255            for (i, condition) in self.where_conditions.iter().enumerate() {
256                if i > 0 {
257                    sql.push_str(" AND ");
258                }
259
260                if condition.column == "RAW"
261                    || condition.column == "EXISTS"
262                    || condition.column == "NOT EXISTS"
263                {
264                    // Don't escape special keywords
265                    sql.push_str(&condition.column);
266                } else {
267                    sql.push_str(&escape_identifier(&condition.column));
268                }
269                sql.push(' ');
270
271                match condition.operator {
272                    QueryOperator::In | QueryOperator::NotIn => {
273                        sql.push_str(&condition.operator.to_string());
274                        sql.push_str(" (");
275                        for (j, value) in condition.values.iter().enumerate() {
276                            if j > 0 {
277                                sql.push_str(", ");
278                            }
279                            sql.push_str(&format!("${}", param_counter));
280                            params.push(self.json_value_to_param_string(value));
281                            *param_counter += 1;
282                        }
283                        sql.push(')');
284                    }
285                    QueryOperator::Between => {
286                        sql.push_str(&condition.operator.to_string());
287                        sql.push_str(&format!(" ${} AND ${}", param_counter, *param_counter + 1));
288                        if condition.values.len() >= 2 {
289                            params.push(self.json_value_to_param_string(&condition.values[0]));
290                            params.push(self.json_value_to_param_string(&condition.values[1]));
291                        }
292                        *param_counter += 2;
293                    }
294                    QueryOperator::IsNull | QueryOperator::IsNotNull => {
295                        sql.push_str(&condition.operator.to_string());
296                    }
297                    _ => {
298                        sql.push_str(&condition.operator.to_string());
299                        if let Some(ref value) = condition.value {
300                            sql.push_str(&format!(" ${}", param_counter));
301                            params.push(self.json_value_to_param_string(value));
302                            *param_counter += 1;
303                        }
304                    }
305                }
306            }
307        }
308    }
309
310    /// Helper method to build ORDER BY and LIMIT clauses
311    fn build_order_limit_clause(&self, sql: &mut String) {
312        // ORDER BY clause
313        if !self.order_by.is_empty() {
314            sql.push_str(" ORDER BY ");
315            for (i, (column, direction)) in self.order_by.iter().enumerate() {
316                if i > 0 {
317                    sql.push_str(", ");
318                }
319                sql.push_str(&format!("{} {}", escape_identifier(column), direction));
320            }
321        }
322
323        // LIMIT clause
324        if let Some(limit) = self.limit_count {
325            sql.push_str(&format!(" LIMIT {}", limit));
326        }
327
328        // OFFSET clause
329        if let Some(offset) = self.offset_value {
330            sql.push_str(&format!(" OFFSET {}", offset));
331        }
332    }
333
334    /// Generate SQL with parameters (unsafe version for backward compatibility)
335    fn to_sql_with_params_unsafe(&self) -> String {
336        self.to_sql_with_params().0
337    }
338
339    /// Convert the query to SQL string (for backwards compatibility)
340    pub fn to_sql(&self) -> String {
341        match self.query_type {
342            QueryType::Select => self.build_select_sql_simple(),
343            _ => self.to_sql_with_params_unsafe(),
344        }
345    }
346
347    /// Build SELECT SQL without parameters (for testing and simple queries)
348    fn build_select_sql_simple(&self) -> String {
349        let mut sql = String::new();
350
351        // SELECT clause
352        if self.distinct {
353            sql.push_str("SELECT DISTINCT ");
354        } else {
355            sql.push_str("SELECT ");
356        }
357
358        if self.select_fields.is_empty() {
359            sql.push('*');
360        } else {
361            sql.push_str(&self.select_fields.join(", "));
362        }
363
364        // FROM clause (no escaping for backward compatibility)
365        if !self.from_tables.is_empty() {
366            sql.push_str(" FROM ");
367            sql.push_str(&self.from_tables.join(", "));
368        }
369
370        // JOIN clauses
371        for join in &self.joins {
372            sql.push_str(&format!(" {} {}", join.join_type, join.table));
373            if !join.on_conditions.is_empty() {
374                sql.push_str(" ON ");
375                let conditions: Vec<String> = join
376                    .on_conditions
377                    .iter()
378                    .map(|(left, right)| format!("{} = {}", left, right))
379                    .collect();
380                sql.push_str(&conditions.join(" AND "));
381            }
382        }
383
384        // WHERE clause
385        if !self.where_conditions.is_empty() {
386            sql.push_str(" WHERE ");
387            let conditions = self.build_where_conditions(&self.where_conditions);
388            sql.push_str(&conditions.join(" AND "));
389        }
390
391        // GROUP BY clause
392        if !self.group_by.is_empty() {
393            sql.push_str(&format!(" GROUP BY {}", self.group_by.join(", ")));
394        }
395
396        // HAVING clause
397        if !self.having_conditions.is_empty() {
398            sql.push_str(" HAVING ");
399            let conditions = self.build_where_conditions(&self.having_conditions);
400            sql.push_str(&conditions.join(" AND "));
401        }
402
403        // ORDER BY clause
404        if !self.order_by.is_empty() {
405            sql.push_str(" ORDER BY ");
406            let order_clauses: Vec<String> = self
407                .order_by
408                .iter()
409                .map(|(column, direction)| format!("{} {}", column, direction))
410                .collect();
411            sql.push_str(&order_clauses.join(", "));
412        }
413
414        // LIMIT clause
415        if let Some(limit) = self.limit_count {
416            sql.push_str(&format!(" LIMIT {}", limit));
417        }
418
419        // OFFSET clause
420        if let Some(offset) = self.offset_value {
421            sql.push_str(&format!(" OFFSET {}", offset));
422        }
423
424        sql
425    }
426
427    /// Build WHERE condition strings
428    fn build_where_conditions(&self, conditions: &[WhereCondition]) -> Vec<String> {
429        conditions
430            .iter()
431            .map(|condition| {
432                // Handle special raw conditions
433                if condition.column == "RAW" {
434                    if let Some(Value::String(raw_sql)) = &condition.value {
435                        return raw_sql.clone();
436                    }
437                }
438
439                // Handle EXISTS and NOT EXISTS
440                if condition.column == "EXISTS" || condition.column == "NOT EXISTS" {
441                    if let Some(Value::String(subquery)) = &condition.value {
442                        return format!("{} {}", condition.column, subquery);
443                    }
444                }
445
446                match &condition.operator {
447                    QueryOperator::IsNull | QueryOperator::IsNotNull => {
448                        format!("{} {}", condition.column, condition.operator)
449                    }
450                    QueryOperator::In | QueryOperator::NotIn => {
451                        // Handle subqueries (stored in value field) vs regular IN lists (stored in values field)
452                        if let Some(Value::String(subquery)) = &condition.value {
453                            if subquery.starts_with('(') && subquery.ends_with(')') {
454                                // This is a subquery
455                                format!("{} {} {}", condition.column, condition.operator, subquery)
456                            } else {
457                                // Single value IN (unusual case)
458                                format!(
459                                    "{} {} ({})",
460                                    condition.column,
461                                    condition.operator,
462                                    self.format_value(condition.value.as_ref().unwrap())
463                                )
464                            }
465                        } else {
466                            // Regular IN with multiple values
467                            let values: Vec<String> = condition
468                                .values
469                                .iter()
470                                .map(|v| self.format_value(v))
471                                .collect();
472                            format!(
473                                "{} {} ({})",
474                                condition.column,
475                                condition.operator,
476                                values.join(", ")
477                            )
478                        }
479                    }
480                    QueryOperator::Between => {
481                        if condition.values.len() == 2 {
482                            format!(
483                                "{} BETWEEN {} AND {}",
484                                condition.column,
485                                self.format_value(&condition.values[0]),
486                                self.format_value(&condition.values[1])
487                            )
488                        } else {
489                            format!("{} = NULL", condition.column) // Invalid BETWEEN
490                        }
491                    }
492                    _ => {
493                        if let Some(value) = &condition.value {
494                            // Handle subquery values
495                            if let Value::String(val_str) = value {
496                                if val_str.starts_with('(') && val_str.ends_with(')') {
497                                    // This looks like a subquery
498                                    format!(
499                                        "{} {} {}",
500                                        condition.column, condition.operator, val_str
501                                    )
502                                } else {
503                                    format!(
504                                        "{} {} {}",
505                                        condition.column,
506                                        condition.operator,
507                                        self.format_value(value)
508                                    )
509                                }
510                            } else {
511                                format!(
512                                    "{} {} {}",
513                                    condition.column,
514                                    condition.operator,
515                                    self.format_value(value)
516                                )
517                            }
518                        } else {
519                            format!("{} = NULL", condition.column) // Fallback
520                        }
521                    }
522                }
523            })
524            .collect()
525    }
526
527    /// Format a value for SQL
528    pub(crate) fn format_value(&self, value: &Value) -> String {
529        match value {
530            Value::String(s) => format!("'{}'", s.replace('\'', "''")), // Escape single quotes
531            Value::Number(n) => n.to_string(),
532            Value::Bool(b) => b.to_string(),
533            Value::Null => "NULL".to_string(),
534            _ => "NULL".to_string(), // Arrays and objects not yet supported
535        }
536    }
537
538    /// Convert JSON value to parameter string for SQL parameters
539    fn json_value_to_param_string(&self, value: &Value) -> String {
540        match value {
541            Value::String(s) => s.clone(), // Extract the string without JSON quotes
542            Value::Number(n) => n.to_string(),
543            Value::Bool(b) => b.to_string(),
544            Value::Null => "NULL".to_string(),
545            _ => value.to_string(), // Fallback to JSON representation
546        }
547    }
548}