Skip to main content

fraiseql_core/db/
where_sql_generator.rs

1//! WHERE clause to SQL string generator for fraiseql-wire.
2//!
3//! Converts FraiseQL's WHERE clause AST to SQL predicates that can be used
4//! with fraiseql-wire's `where_sql()` method.
5
6use serde_json::Value;
7
8use crate::{
9    db::{WhereClause, WhereOperator},
10    error::{FraiseQLError, Result},
11};
12
/// Generates SQL WHERE clause strings from AST.
///
/// This is a stateless unit struct; its visible API is the associated function
/// `WhereSqlGenerator::to_sql`.
#[derive(Debug, Clone, Copy)]
pub struct WhereSqlGenerator;
15
16impl WhereSqlGenerator {
17    /// Convert WHERE clause AST to SQL string.
18    ///
19    /// # Example
20    ///
21    /// ```rust
22    /// use fraiseql_core::db::{WhereClause, WhereOperator, where_sql_generator::WhereSqlGenerator};
23    /// use serde_json::json;
24    ///
25    /// let clause = WhereClause::Field {
26    ///     path: vec!["status".to_string()],
27    ///     operator: WhereOperator::Eq,
28    ///     value: json!("active"),
29    /// };
30    ///
31    /// let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
32    /// assert_eq!(sql, "data->>'status' = 'active'");
33    /// ```
34    pub fn to_sql(clause: &WhereClause) -> Result<String> {
35        match clause {
36            WhereClause::Field {
37                path,
38                operator,
39                value,
40            } => Self::generate_field_predicate(path, operator, value),
41            WhereClause::And(clauses) => {
42                if clauses.is_empty() {
43                    return Ok("TRUE".to_string());
44                }
45                let parts: Result<Vec<_>> = clauses.iter().map(Self::to_sql).collect();
46                Ok(format!("({})", parts?.join(" AND ")))
47            },
48            WhereClause::Or(clauses) => {
49                if clauses.is_empty() {
50                    return Ok("FALSE".to_string());
51                }
52                let parts: Result<Vec<_>> = clauses.iter().map(Self::to_sql).collect();
53                Ok(format!("({})", parts?.join(" OR ")))
54            },
55            WhereClause::Not(clause) => {
56                let inner = Self::to_sql(clause)?;
57                Ok(format!("NOT ({})", inner))
58            },
59        }
60    }
61
62    fn generate_field_predicate(
63        path: &[String],
64        operator: &WhereOperator,
65        value: &Value,
66    ) -> Result<String> {
67        let json_path = Self::build_json_path(path);
68        let sql = match operator {
69            // Null checks
70            WhereOperator::IsNull => {
71                let is_null = value.as_bool().unwrap_or(true);
72                if is_null {
73                    format!("{json_path} IS NULL")
74                } else {
75                    format!("{json_path} IS NOT NULL")
76                }
77            },
78            // All other operators
79            _ => {
80                let sql_op = Self::operator_to_sql(operator)?;
81                let sql_value = Self::value_to_sql(value, operator)?;
82                format!("{json_path} {sql_op} {sql_value}")
83            },
84        };
85        Ok(sql)
86    }
87
88    fn build_json_path(path: &[String]) -> String {
89        if path.is_empty() {
90            return "data".to_string();
91        }
92
93        if path.len() == 1 {
94            // Simple path: data->>'field'
95            // SECURITY: Escape field name to prevent SQL injection
96            let escaped = Self::escape_sql_string(&path[0]);
97            format!("data->>'{}'", escaped)
98        } else {
99            // Nested path: data#>'{a,b,c}'->>'d'
100            // SECURITY: Escape all field names to prevent SQL injection
101            let nested = &path[..path.len() - 1];
102            let last = &path[path.len() - 1];
103
104            // Escape all nested components
105            let escaped_nested: Vec<String> =
106                nested.iter().map(|n| Self::escape_sql_string(n)).collect();
107            let nested_path = escaped_nested.join(",");
108            let escaped_last = Self::escape_sql_string(last);
109            format!("data#>'{{{}}}'->>'{}'", nested_path, escaped_last)
110        }
111    }
112
113    fn operator_to_sql(operator: &WhereOperator) -> Result<&'static str> {
114        Ok(match operator {
115            // Comparison
116            WhereOperator::Eq => "=",
117            WhereOperator::Neq => "!=",
118            WhereOperator::Gt => ">",
119            WhereOperator::Gte => ">=",
120            WhereOperator::Lt => "<",
121            WhereOperator::Lte => "<=",
122
123            // Containment
124            WhereOperator::In => "= ANY",
125            WhereOperator::Nin => "!= ALL",
126
127            // String operations
128            WhereOperator::Contains => "LIKE",
129            WhereOperator::Icontains => "ILIKE",
130            WhereOperator::Startswith => "LIKE",
131            WhereOperator::Istartswith => "ILIKE",
132            WhereOperator::Endswith => "LIKE",
133            WhereOperator::Iendswith => "ILIKE",
134            WhereOperator::Like => "LIKE",
135            WhereOperator::Ilike => "ILIKE",
136
137            // Array operations
138            WhereOperator::ArrayContains => "@>",
139            WhereOperator::ArrayContainedBy => "<@",
140            WhereOperator::ArrayOverlaps => "&&",
141
142            // These operators require special handling
143            WhereOperator::IsNull => {
144                return Err(FraiseQLError::Internal {
145                    message: "IsNull should be handled separately".to_string(),
146                    source:  None,
147                });
148            },
149            WhereOperator::LenEq
150            | WhereOperator::LenGt
151            | WhereOperator::LenLt
152            | WhereOperator::LenGte
153            | WhereOperator::LenLte
154            | WhereOperator::LenNeq => {
155                return Err(FraiseQLError::Internal {
156                    message: format!(
157                        "Array length operators not yet supported in fraiseql-wire: {operator:?}"
158                    ),
159                    source:  None,
160                });
161            },
162
163            // Vector operations not supported
164            WhereOperator::L2Distance
165            | WhereOperator::CosineDistance
166            | WhereOperator::L1Distance
167            | WhereOperator::HammingDistance
168            | WhereOperator::InnerProduct
169            | WhereOperator::JaccardDistance => {
170                return Err(FraiseQLError::Internal {
171                    message: format!(
172                        "Vector operations not supported in fraiseql-wire: {operator:?}"
173                    ),
174                    source:  None,
175                });
176            },
177
178            // Full-text search operators not supported yet
179            WhereOperator::Matches
180            | WhereOperator::PlainQuery
181            | WhereOperator::PhraseQuery
182            | WhereOperator::WebsearchQuery => {
183                return Err(FraiseQLError::Internal {
184                    message: format!(
185                        "Full-text search operators not yet supported in fraiseql-wire: {operator:?}"
186                    ),
187                    source:  None,
188                });
189            },
190
191            // Network operators not supported yet
192            WhereOperator::IsIPv4
193            | WhereOperator::IsIPv6
194            | WhereOperator::IsPrivate
195            | WhereOperator::IsPublic
196            | WhereOperator::IsLoopback
197            | WhereOperator::InSubnet
198            | WhereOperator::ContainsSubnet
199            | WhereOperator::ContainsIP
200            | WhereOperator::Overlaps
201            | WhereOperator::StrictlyContains
202            | WhereOperator::AncestorOf
203            | WhereOperator::DescendantOf
204            | WhereOperator::MatchesLquery
205            | WhereOperator::MatchesLtxtquery
206            | WhereOperator::MatchesAnyLquery
207            | WhereOperator::DepthEq
208            | WhereOperator::DepthNeq
209            | WhereOperator::DepthGt
210            | WhereOperator::DepthGte
211            | WhereOperator::DepthLt
212            | WhereOperator::DepthLte
213            | WhereOperator::Lca
214            | WhereOperator::Extended(_) => {
215                return Err(FraiseQLError::Internal {
216                    message: format!(
217                        "Advanced operators not yet supported in fraiseql-wire: {operator:?}"
218                    ),
219                    source:  None,
220                });
221            },
222        })
223    }
224
225    fn value_to_sql(value: &Value, operator: &WhereOperator) -> Result<String> {
226        match (value, operator) {
227            (Value::Null, _) => Ok("NULL".to_string()),
228            (Value::Bool(b), _) => Ok(b.to_string()),
229            (Value::Number(n), _) => Ok(n.to_string()),
230
231            // String operators with wildcards
232            (Value::String(s), WhereOperator::Contains | WhereOperator::Icontains) => {
233                Ok(format!("'%{}%'", Self::escape_sql_string(s)))
234            },
235            (Value::String(s), WhereOperator::Startswith | WhereOperator::Istartswith) => {
236                Ok(format!("'{}%'", Self::escape_sql_string(s)))
237            },
238            (Value::String(s), WhereOperator::Endswith | WhereOperator::Iendswith) => {
239                Ok(format!("'%{}'", Self::escape_sql_string(s)))
240            },
241
242            // Regular strings
243            (Value::String(s), _) => Ok(format!("'{}'", Self::escape_sql_string(s))),
244
245            // Arrays (for IN operator)
246            (Value::Array(arr), WhereOperator::In | WhereOperator::Nin) => {
247                let values: Result<Vec<_>> =
248                    arr.iter().map(|v| Self::value_to_sql(v, &WhereOperator::Eq)).collect();
249                Ok(format!("ARRAY[{}]", values?.join(", ")))
250            },
251
252            // Array operations
253            (
254                Value::Array(_),
255                WhereOperator::ArrayContains
256                | WhereOperator::ArrayContainedBy
257                | WhereOperator::ArrayOverlaps,
258            ) => {
259                // SECURITY: Serialize to JSON string and escape single quotes to prevent
260                // SQL injection. The serde_json serializer handles internal escaping, and
261                // we escape single quotes for the SQL string literal context.
262                let json_str =
263                    serde_json::to_string(value).map_err(|e| FraiseQLError::Internal {
264                        message: format!("Failed to serialize JSON for array operator: {e}"),
265                        source:  None,
266                    })?;
267                let escaped = json_str.replace('\'', "''");
268                Ok(format!("'{}'::jsonb", escaped))
269            },
270
271            _ => Err(FraiseQLError::Internal {
272                message: format!(
273                    "Unsupported value type for operator: {value:?} with {operator:?}"
274                ),
275                source:  None,
276            }),
277        }
278    }
279
280    fn escape_sql_string(s: &str) -> String {
281        s.replace('\'', "''")
282    }
283}
284
285#[cfg(test)]
286mod tests {
287    use serde_json::json;
288
289    use super::*;
290
291    #[test]
292    fn test_simple_equality() {
293        let clause = WhereClause::Field {
294            path:     vec!["status".to_string()],
295            operator: WhereOperator::Eq,
296            value:    json!("active"),
297        };
298
299        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
300        assert_eq!(sql, "data->>'status' = 'active'");
301    }
302
303    #[test]
304    fn test_nested_path() {
305        let clause = WhereClause::Field {
306            path:     vec!["user".to_string(), "email".to_string()],
307            operator: WhereOperator::Eq,
308            value:    json!("test@example.com"),
309        };
310
311        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
312        assert_eq!(sql, "data#>'{user}'->>'email' = 'test@example.com'");
313    }
314
315    #[test]
316    fn test_icontains() {
317        let clause = WhereClause::Field {
318            path:     vec!["name".to_string()],
319            operator: WhereOperator::Icontains,
320            value:    json!("john"),
321        };
322
323        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
324        assert_eq!(sql, "data->>'name' ILIKE '%john%'");
325    }
326
327    #[test]
328    fn test_startswith() {
329        let clause = WhereClause::Field {
330            path:     vec!["email".to_string()],
331            operator: WhereOperator::Startswith,
332            value:    json!("admin"),
333        };
334
335        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
336        assert_eq!(sql, "data->>'email' LIKE 'admin%'");
337    }
338
339    #[test]
340    fn test_and_clause() {
341        let clause = WhereClause::And(vec![
342            WhereClause::Field {
343                path:     vec!["status".to_string()],
344                operator: WhereOperator::Eq,
345                value:    json!("active"),
346            },
347            WhereClause::Field {
348                path:     vec!["age".to_string()],
349                operator: WhereOperator::Gte,
350                value:    json!(18),
351            },
352        ]);
353
354        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
355        assert_eq!(sql, "(data->>'status' = 'active' AND data->>'age' >= 18)");
356    }
357
358    #[test]
359    fn test_or_clause() {
360        let clause = WhereClause::Or(vec![
361            WhereClause::Field {
362                path:     vec!["type".to_string()],
363                operator: WhereOperator::Eq,
364                value:    json!("admin"),
365            },
366            WhereClause::Field {
367                path:     vec!["type".to_string()],
368                operator: WhereOperator::Eq,
369                value:    json!("moderator"),
370            },
371        ]);
372
373        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
374        assert_eq!(sql, "(data->>'type' = 'admin' OR data->>'type' = 'moderator')");
375    }
376
377    #[test]
378    fn test_not_clause() {
379        let clause = WhereClause::Not(Box::new(WhereClause::Field {
380            path:     vec!["deleted".to_string()],
381            operator: WhereOperator::Eq,
382            value:    json!(true),
383        }));
384
385        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
386        assert_eq!(sql, "NOT (data->>'deleted' = true)");
387    }
388
389    #[test]
390    fn test_is_null() {
391        let clause = WhereClause::Field {
392            path:     vec!["deleted_at".to_string()],
393            operator: WhereOperator::IsNull,
394            value:    json!(true),
395        };
396
397        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
398        assert_eq!(sql, "data->>'deleted_at' IS NULL");
399    }
400
401    #[test]
402    fn test_is_not_null() {
403        let clause = WhereClause::Field {
404            path:     vec!["updated_at".to_string()],
405            operator: WhereOperator::IsNull,
406            value:    json!(false),
407        };
408
409        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
410        assert_eq!(sql, "data->>'updated_at' IS NOT NULL");
411    }
412
413    #[test]
414    fn test_in_operator() {
415        let clause = WhereClause::Field {
416            path:     vec!["status".to_string()],
417            operator: WhereOperator::In,
418            value:    json!(["active", "pending", "approved"]),
419        };
420
421        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
422        assert_eq!(sql, "data->>'status' = ANY ARRAY['active', 'pending', 'approved']");
423    }
424
425    #[test]
426    fn test_sql_injection_prevention() {
427        let clause = WhereClause::Field {
428            path:     vec!["name".to_string()],
429            operator: WhereOperator::Eq,
430            value:    json!("'; DROP TABLE users; --"),
431        };
432
433        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
434        assert_eq!(sql, "data->>'name' = '''; DROP TABLE users; --'");
435        // Single quotes are escaped to ''
436    }
437
438    #[test]
439    fn test_numeric_comparison() {
440        let clause = WhereClause::Field {
441            path:     vec!["price".to_string()],
442            operator: WhereOperator::Gt,
443            value:    json!(99.99),
444        };
445
446        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
447        assert_eq!(sql, "data->>'price' > 99.99");
448    }
449
450    #[test]
451    fn test_boolean_value() {
452        let clause = WhereClause::Field {
453            path:     vec!["published".to_string()],
454            operator: WhereOperator::Eq,
455            value:    json!(true),
456        };
457
458        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
459        assert_eq!(sql, "data->>'published' = true");
460    }
461
462    #[test]
463    fn test_empty_and_clause() {
464        let clause = WhereClause::And(vec![]);
465        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
466        assert_eq!(sql, "TRUE");
467    }
468
469    #[test]
470    fn test_empty_or_clause() {
471        let clause = WhereClause::Or(vec![]);
472        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
473        assert_eq!(sql, "FALSE");
474    }
475
476    #[test]
477    fn test_complex_nested_condition() {
478        let clause = WhereClause::And(vec![
479            WhereClause::Field {
480                path:     vec!["type".to_string()],
481                operator: WhereOperator::Eq,
482                value:    json!("article"),
483            },
484            WhereClause::Or(vec![
485                WhereClause::Field {
486                    path:     vec!["status".to_string()],
487                    operator: WhereOperator::Eq,
488                    value:    json!("published"),
489                },
490                WhereClause::And(vec![
491                    WhereClause::Field {
492                        path:     vec!["status".to_string()],
493                        operator: WhereOperator::Eq,
494                        value:    json!("draft"),
495                    },
496                    WhereClause::Field {
497                        path:     vec!["author".to_string(), "role".to_string()],
498                        operator: WhereOperator::Eq,
499                        value:    json!("admin"),
500                    },
501                ]),
502            ]),
503        ]);
504
505        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
506        assert_eq!(
507            sql,
508            "(data->>'type' = 'article' AND (data->>'status' = 'published' OR (data->>'status' = 'draft' AND data#>'{author}'->>'role' = 'admin')))"
509        );
510    }
511
512    #[test]
513    fn test_sql_injection_in_field_name_simple() {
514        // Test that malicious field names are escaped to prevent SQL injection
515        let clause = WhereClause::Field {
516            path:     vec!["name'; DROP TABLE users; --".to_string()],
517            operator: WhereOperator::Eq,
518            value:    json!("value"),
519        };
520
521        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
522        // Field name should be escaped with doubled single quotes
523        // Result: data->>'name''; DROP TABLE users; --' = 'value'
524        // The doubled '' prevents the quote from closing the string
525        assert!(sql.contains("''")); // Escaped quotes present
526        // The SQL structure should be: identifier->>'field' operator value
527        // With escaping, DROP TABLE becomes part of the field string, not executable
528        assert!(sql.contains("data->>'"));
529        assert!(sql.contains("= 'value'")); // Proper value comparison
530    }
531
532    #[test]
533    fn test_sql_injection_prevention_in_array_operator() {
534        // SECURITY: Ensure JSON injection in array operators is escaped
535        let clause = WhereClause::Field {
536            path:     vec!["tags".to_string()],
537            operator: WhereOperator::ArrayContains,
538            value:    json!(["normal", "'; DROP TABLE users; --"]),
539        };
540
541        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
542        // The JSON serializer will escape the inner quotes, and we escape SQL single quotes.
543        // The result should be a properly escaped JSONB literal, not executable SQL.
544        assert!(sql.contains("::jsonb"), "Must produce valid JSONB cast");
545        // Verify the value is inside a JSON string (double-quoted), not a raw SQL string.
546        // serde_json serializes this as: ["normal","'; DROP TABLE users; --"]
547        // After SQL escaping: [\"normal\",\"''; DROP TABLE users; --\"]
548        // The single quote inside the JSON value is doubled for SQL safety.
549        assert!(
550            sql.contains("''"),
551            "Single quotes inside JSON values must be doubled for SQL safety"
552        );
553    }
554
555    #[test]
556    fn test_sql_injection_in_nested_field_name() {
557        // Test that malicious nested field names are also escaped
558        let clause = WhereClause::Field {
559            path:     vec![
560                "user".to_string(),
561                "role'; DROP TABLE users; --".to_string(),
562            ],
563            operator: WhereOperator::Eq,
564            value:    json!("admin"),
565        };
566
567        let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
568        // Both simple and nested path components should be escaped
569        assert!(sql.contains("''")); // Escaped quotes present
570        assert!(sql.contains("data#>'{")); // Nested path syntax
571    }
572}