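// elif_orm/query/sql_generation.rs

//! Query Builder SQL generation.
//!
//! `to_sql_with_params` / `to_sql_with_params_secure` produce parameterized SQL
//! (`$n` placeholders plus a parameter list) with escaped identifiers, to guard
//! against SQL injection. `to_sql` is kept for backward compatibility: for SELECT
//! it inlines quoted values and leaves identifiers unescaped, and for other
//! statements it returns the placeholder SQL without its parameters.
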
5use serde_json::Value;
6use super::builder::QueryBuilder;
7use super::types::*;
8use crate::security::{escape_identifier, validate_identifier, validate_parameter};
9use crate::error::ModelError;
10
11impl<M> QueryBuilder<M> {
12    /// Generate SQL from query with parameter placeholders and return parameters
13    /// This method includes basic identifier escaping for security
14    pub fn to_sql_with_params(&self) -> (String, Vec<String>) {
15        // Note: For backward compatibility, this returns the tuple directly
16        // In the future, consider making this return Result<(String, Vec<String>), ModelError>
17        match self.query_type {
18            QueryType::Select => self.build_select_sql(),
19            QueryType::Insert => self.build_insert_sql(),
20            QueryType::Update => self.build_update_sql(),
21            QueryType::Delete => self.build_delete_sql(),
22        }
23    }
24
25    /// Generate SQL with security validation - returns Result for proper error handling
26    pub fn to_sql_with_params_secure(&self) -> Result<(String, Vec<String>), ModelError> {
27        // Validate identifiers first
28        self.validate_query_security()?;
29        Ok(self.to_sql_with_params())
30    }
31
32    /// Validate query security before SQL generation
33    fn validate_query_security(&self) -> Result<(), ModelError> {
34        // Validate table identifiers
35        for table in &self.from_tables {
36            validate_identifier(table)?;
37        }
38        
39        if let Some(ref table) = self.insert_table {
40            validate_identifier(table)?;
41        }
42        
43        if let Some(ref table) = self.update_table {
44            validate_identifier(table)?;
45        }
46        
47        if let Some(ref table) = self.delete_table {
48            validate_identifier(table)?;
49        }
50        
51        // Validate select field identifiers
52        for field in &self.select_fields {
53            // Skip wildcard and function calls for now
54            if field != "*" && !field.contains('(') {
55                validate_identifier(field)?;
56            }
57        }
58        
59        // Validate column identifiers in WHERE clauses
60        for condition in &self.where_conditions {
61            if condition.column != "RAW" && condition.column != "EXISTS" && condition.column != "NOT EXISTS" {
62                validate_identifier(&condition.column)?;
63            }
64        }
65        
66        // Validate JOIN table identifiers
67        for join in &self.joins {
68            validate_identifier(&join.table)?;
69        }
70        
71        // Validate parameter values
72        for condition in &self.where_conditions {
73            if let Some(ref value) = condition.value {
74                if let Value::String(s) = value {
75                    validate_parameter(s)?;
76                }
77            }
78            for value in &condition.values {
79                if let Value::String(s) = value {
80                    validate_parameter(s)?;
81                }
82            }
83        }
84        
85        Ok(())
86    }
87
88    /// Build SELECT SQL with parameters
89    fn build_select_sql(&self) -> (String, Vec<String>) {
90        let mut sql = String::new();
91        let mut params = Vec::new();
92        let mut param_counter = 1;
93
94        // SELECT clause
95        if self.distinct {
96            sql.push_str("SELECT DISTINCT ");
97        } else {
98            sql.push_str("SELECT ");
99        }
100        
101        if self.select_fields.is_empty() {
102            sql.push('*');
103        } else {
104            let escaped_fields: Vec<String> = self.select_fields.iter()
105                .map(|field| {
106                    if field == "*" || field.contains('(') {
107                        // Keep wildcards and function calls as-is
108                        field.clone()
109                    } else {
110                        escape_identifier(field)
111                    }
112                })
113                .collect();
114            sql.push_str(&escaped_fields.join(", "));
115        }
116
117        // FROM clause
118        if !self.from_tables.is_empty() {
119            sql.push_str(" FROM ");
120            let escaped_tables: Vec<String> = self.from_tables.iter()
121                .map(|table| escape_identifier(table))
122                .collect();
123            sql.push_str(&escaped_tables.join(", "));
124        }
125
126        // JOIN clauses
127        for join in &self.joins {
128            sql.push(' ');
129            sql.push_str(&join.join_type.to_string());
130            sql.push(' ');
131            sql.push_str(&escape_identifier(&join.table));
132            sql.push_str(" ON ");
133            for (i, (left, right)) in join.on_conditions.iter().enumerate() {
134                if i > 0 {
135                    sql.push_str(" AND ");
136                }
137                sql.push_str(&format!("{} = {}", escape_identifier(left), escape_identifier(right)));
138            }
139        }
140
141        self.build_where_clause(&mut sql, &mut params, &mut param_counter);
142        self.build_order_limit_clause(&mut sql);
143
144        (sql, params)
145    }
146
147    /// Build INSERT SQL with parameters
148    fn build_insert_sql(&self) -> (String, Vec<String>) {
149        let mut sql = String::new();
150        let mut params = Vec::new();
151        let mut param_counter = 1;
152
153        if let Some(table) = &self.insert_table {
154            sql.push_str("INSERT INTO ");
155            sql.push_str(&escape_identifier(table));
156            
157            if !self.set_clauses.is_empty() {
158                sql.push_str(" (");
159                let columns: Vec<String> = self.set_clauses.iter()
160                    .map(|clause| escape_identifier(&clause.column))
161                    .collect();
162                sql.push_str(&columns.join(", "));
163                sql.push_str(") VALUES (");
164                
165                for (i, clause) in self.set_clauses.iter().enumerate() {
166                    if i > 0 {
167                        sql.push_str(", ");
168                    }
169                    if let Some(ref value) = clause.value {
170                        sql.push_str(&format!("${}", param_counter));
171                        params.push(self.json_value_to_param_string(value));
172                        param_counter += 1;
173                    } else {
174                        sql.push_str("NULL");
175                    }
176                }
177                sql.push(')');
178            }
179        }
180
181        (sql, params)
182    }
183
184    /// Build UPDATE SQL with parameters
185    fn build_update_sql(&self) -> (String, Vec<String>) {
186        let mut sql = String::new();
187        let mut params = Vec::new();
188        let mut param_counter = 1;
189
190        if let Some(table) = &self.update_table {
191            sql.push_str("UPDATE ");
192            sql.push_str(&escape_identifier(table));
193            
194            if !self.set_clauses.is_empty() {
195                sql.push_str(" SET ");
196                for (i, clause) in self.set_clauses.iter().enumerate() {
197                    if i > 0 {
198                        sql.push_str(", ");
199                    }
200                    sql.push_str(&escape_identifier(&clause.column));
201                    sql.push_str(" = ");
202                    if let Some(ref value) = clause.value {
203                        sql.push_str(&format!("${}", param_counter));
204                        params.push(self.json_value_to_param_string(value));
205                        param_counter += 1;
206                    } else {
207                        sql.push_str("NULL");
208                    }
209                }
210            }
211
212            self.build_where_clause(&mut sql, &mut params, &mut param_counter);
213        }
214
215        (sql, params)
216    }
217
218    /// Build DELETE SQL with parameters
219    fn build_delete_sql(&self) -> (String, Vec<String>) {
220        let mut sql = String::new();
221        let mut params = Vec::new();
222        let mut param_counter = 1;
223
224        if let Some(table) = &self.delete_table {
225            sql.push_str("DELETE FROM ");
226            sql.push_str(&escape_identifier(table));
227            self.build_where_clause(&mut sql, &mut params, &mut param_counter);
228        }
229
230        (sql, params)
231    }
232
233    /// Helper method to build WHERE clauses
234    fn build_where_clause(&self, sql: &mut String, params: &mut Vec<String>, param_counter: &mut i32) {
235        if !self.where_conditions.is_empty() {
236            sql.push_str(" WHERE ");
237            for (i, condition) in self.where_conditions.iter().enumerate() {
238                if i > 0 {
239                    sql.push_str(" AND ");
240                }
241                
242                if condition.column == "RAW" || condition.column == "EXISTS" || condition.column == "NOT EXISTS" {
243                    // Don't escape special keywords
244                    sql.push_str(&condition.column);
245                } else {
246                    sql.push_str(&escape_identifier(&condition.column));
247                }
248                sql.push(' ');
249
250                match condition.operator {
251                    QueryOperator::In | QueryOperator::NotIn => {
252                        sql.push_str(&condition.operator.to_string());
253                        sql.push_str(" (");
254                        for (j, value) in condition.values.iter().enumerate() {
255                            if j > 0 {
256                                sql.push_str(", ");
257                            }
258                            sql.push_str(&format!("${}", param_counter));
259                            params.push(self.json_value_to_param_string(value));
260                            *param_counter += 1;
261                        }
262                        sql.push(')');
263                    }
264                    QueryOperator::Between => {
265                        sql.push_str(&condition.operator.to_string());
266                        sql.push_str(&format!(" ${} AND ${}", param_counter, *param_counter + 1));
267                        if condition.values.len() >= 2 {
268                            params.push(self.json_value_to_param_string(&condition.values[0]));
269                            params.push(self.json_value_to_param_string(&condition.values[1]));
270                        }
271                        *param_counter += 2;
272                    }
273                    QueryOperator::IsNull | QueryOperator::IsNotNull => {
274                        sql.push_str(&condition.operator.to_string());
275                    }
276                    _ => {
277                        sql.push_str(&condition.operator.to_string());
278                        if let Some(ref value) = condition.value {
279                            sql.push_str(&format!(" ${}", param_counter));
280                            params.push(self.json_value_to_param_string(value));
281                            *param_counter += 1;
282                        }
283                    }
284                }
285            }
286        }
287    }
288
289    /// Helper method to build ORDER BY and LIMIT clauses
290    fn build_order_limit_clause(&self, sql: &mut String) {
291        // ORDER BY clause
292        if !self.order_by.is_empty() {
293            sql.push_str(" ORDER BY ");
294            for (i, (column, direction)) in self.order_by.iter().enumerate() {
295                if i > 0 {
296                    sql.push_str(", ");
297                }
298                sql.push_str(&format!("{} {}", escape_identifier(column), direction));
299            }
300        }
301
302        // LIMIT clause
303        if let Some(limit) = self.limit_count {
304            sql.push_str(&format!(" LIMIT {}", limit));
305        }
306
307        // OFFSET clause
308        if let Some(offset) = self.offset_value {
309            sql.push_str(&format!(" OFFSET {}", offset));
310        }
311    }
312
313    /// Generate SQL with parameters (unsafe version for backward compatibility)
314    fn to_sql_with_params_unsafe(&self) -> String {
315        self.to_sql_with_params().0
316    }
317
318    /// Convert the query to SQL string (for backwards compatibility)
319    pub fn to_sql(&self) -> String {
320        match self.query_type {
321            QueryType::Select => self.build_select_sql_simple(),
322            _ => self.to_sql_with_params_unsafe(),
323        }
324    }
325
326    /// Build SELECT SQL without parameters (for testing and simple queries)
327    fn build_select_sql_simple(&self) -> String {
328        let mut sql = String::new();
329
330        // SELECT clause
331        if self.distinct {
332            sql.push_str("SELECT DISTINCT ");
333        } else {
334            sql.push_str("SELECT ");
335        }
336
337        if self.select_fields.is_empty() {
338            sql.push('*');
339        } else {
340            sql.push_str(&self.select_fields.join(", "));
341        }
342
343        // FROM clause (no escaping for backward compatibility)
344        if !self.from_tables.is_empty() {
345            sql.push_str(" FROM ");
346            sql.push_str(&self.from_tables.join(", "));
347        }
348
349        // JOIN clauses
350        for join in &self.joins {
351            sql.push_str(&format!(" {} {}", join.join_type, join.table));
352            if !join.on_conditions.is_empty() {
353                sql.push_str(" ON ");
354                let conditions: Vec<String> = join
355                    .on_conditions
356                    .iter()
357                    .map(|(left, right)| format!("{} = {}", left, right))
358                    .collect();
359                sql.push_str(&conditions.join(" AND "));
360            }
361        }
362
363        // WHERE clause
364        if !self.where_conditions.is_empty() {
365            sql.push_str(" WHERE ");
366            let conditions = self.build_where_conditions(&self.where_conditions);
367            sql.push_str(&conditions.join(" AND "));
368        }
369
370        // GROUP BY clause
371        if !self.group_by.is_empty() {
372            sql.push_str(&format!(" GROUP BY {}", self.group_by.join(", ")));
373        }
374
375        // HAVING clause
376        if !self.having_conditions.is_empty() {
377            sql.push_str(" HAVING ");
378            let conditions = self.build_where_conditions(&self.having_conditions);
379            sql.push_str(&conditions.join(" AND "));
380        }
381
382        // ORDER BY clause
383        if !self.order_by.is_empty() {
384            sql.push_str(" ORDER BY ");
385            let order_clauses: Vec<String> = self
386                .order_by
387                .iter()
388                .map(|(column, direction)| format!("{} {}", column, direction))
389                .collect();
390            sql.push_str(&order_clauses.join(", "));
391        }
392
393        // LIMIT clause
394        if let Some(limit) = self.limit_count {
395            sql.push_str(&format!(" LIMIT {}", limit));
396        }
397
398        // OFFSET clause
399        if let Some(offset) = self.offset_value {
400            sql.push_str(&format!(" OFFSET {}", offset));
401        }
402
403        sql
404    }
405
406    /// Build WHERE condition strings
407    fn build_where_conditions(&self, conditions: &[WhereCondition]) -> Vec<String> {
408        conditions
409            .iter()
410            .map(|condition| {
411                // Handle special raw conditions
412                if condition.column == "RAW" {
413                    if let Some(Value::String(raw_sql)) = &condition.value {
414                        return raw_sql.clone();
415                    }
416                }
417                
418                // Handle EXISTS and NOT EXISTS
419                if condition.column == "EXISTS" || condition.column == "NOT EXISTS" {
420                    if let Some(Value::String(subquery)) = &condition.value {
421                        return format!("{} {}", condition.column, subquery);
422                    }
423                }
424                
425                match &condition.operator {
426                    QueryOperator::IsNull | QueryOperator::IsNotNull => {
427                        format!("{} {}", condition.column, condition.operator)
428                    }
429                    QueryOperator::In | QueryOperator::NotIn => {
430                        // Handle subqueries (stored in value field) vs regular IN lists (stored in values field)
431                        if let Some(Value::String(subquery)) = &condition.value {
432                            if subquery.starts_with('(') && subquery.ends_with(')') {
433                                // This is a subquery
434                                format!("{} {} {}", condition.column, condition.operator, subquery)
435                            } else {
436                                // Single value IN (unusual case)
437                                format!("{} {} ({})", condition.column, condition.operator, self.format_value(&condition.value.as_ref().unwrap()))
438                            }
439                        } else {
440                            // Regular IN with multiple values
441                            let values: Vec<String> = condition
442                                .values
443                                .iter()
444                                .map(|v| self.format_value(v))
445                                .collect();
446                            format!("{} {} ({})", condition.column, condition.operator, values.join(", "))
447                        }
448                    }
449                    QueryOperator::Between => {
450                        if condition.values.len() == 2 {
451                            format!(
452                                "{} BETWEEN {} AND {}",
453                                condition.column,
454                                self.format_value(&condition.values[0]),
455                                self.format_value(&condition.values[1])
456                            )
457                        } else {
458                            format!("{} = NULL", condition.column) // Invalid BETWEEN
459                        }
460                    }
461                    _ => {
462                        if let Some(value) = &condition.value {
463                            // Handle subquery values
464                            if let Value::String(val_str) = value {
465                                if val_str.starts_with('(') && val_str.ends_with(')') {
466                                    // This looks like a subquery
467                                    format!("{} {} {}", condition.column, condition.operator, val_str)
468                                } else {
469                                    format!("{} {} {}", condition.column, condition.operator, self.format_value(value))
470                                }
471                            } else {
472                                format!("{} {} {}", condition.column, condition.operator, self.format_value(value))
473                            }
474                        } else {
475                            format!("{} = NULL", condition.column) // Fallback
476                        }
477                    }
478                }
479            })
480            .collect()
481    }
482
483    /// Format a value for SQL
484    pub(crate) fn format_value(&self, value: &Value) -> String {
485        match value {
486            Value::String(s) => format!("'{}'", s.replace('\'', "''")), // Escape single quotes
487            Value::Number(n) => n.to_string(),
488            Value::Bool(b) => b.to_string(),
489            Value::Null => "NULL".to_string(),
490            _ => "NULL".to_string(), // Arrays and objects not yet supported
491        }
492    }
493
494    /// Convert JSON value to parameter string for SQL parameters
495    fn json_value_to_param_string(&self, value: &Value) -> String {
496        match value {
497            Value::String(s) => s.clone(), // Extract the string without JSON quotes
498            Value::Number(n) => n.to_string(),
499            Value::Bool(b) => b.to_string(),
500            Value::Null => "NULL".to_string(),
501            _ => value.to_string(), // Fallback to JSON representation
502        }
503    }
504}