Skip to main content

fraiseql_db/
where_sql_generator.rs

1//! WHERE clause to SQL string generator for fraiseql-wire.
2//!
3//! Converts FraiseQL's WHERE clause AST to SQL predicates that can be used
4//! with fraiseql-wire's `where_sql()` method.
5
6use fraiseql_error::{FraiseQLError, Result};
7use serde_json::Value;
8
9use crate::{WhereClause, WhereOperator};
10
/// Upper bound, in bytes, on any single string that is embedded into
/// generated SQL text by escaping.
///
/// The bound covers fragments built by escaping rather than binding (LIKE
/// patterns, JSON path keys, serialized JSONB operands) and exists to reject
/// denial-of-service sized inputs; parameterized query paths never consult it.
/// 64 KiB comfortably exceeds any realistic filter value.
const MAX_SQL_VALUE_BYTES: usize = 64 * 1024;
17
/// Renders a WHERE clause AST as a SQL predicate string.
///
/// # Why this type is still here
///
/// Every value this generator emits is escaped and spliced into the SQL text
/// as a string literal; no bind parameters are ever produced. It is retained
/// deliberately for the **FraiseQL Wire Adapter** (`fraiseql_wire_adapter`),
/// which builds raw SQL strings for the wire protocol, where parameterized
/// queries are not available.
///
/// **New production code must not use this type.** Every other query path
/// should go through [`GenericWhereGenerator`](crate::GenericWhereGenerator),
/// which emits parameterized SQL (`$1`, `?`, etc.) and is safe by design.
#[doc(hidden)]
pub struct WhereSqlGenerator;
32
33impl WhereSqlGenerator {
34    /// Convert WHERE clause AST to SQL string.
35    ///
36    /// # Example
37    ///
38    /// ```rust,no_run
39    /// // fraiseql-db can be used directly or via `fraiseql_core::db` (re-export).
40    /// use fraiseql_db::{WhereClause, WhereOperator, where_sql_generator::WhereSqlGenerator};
41    /// use serde_json::json;
42    ///
43    /// let clause = WhereClause::Field {
44    ///     path: vec!["status".to_string()],
45    ///     operator: WhereOperator::Eq,
46    ///     value: json!("active"),
47    /// };
48    ///
49    /// let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
50    /// assert_eq!(sql, "data->>'status' = 'active'");
51    /// ```
52    ///
53    /// # Errors
54    ///
55    /// Returns `FraiseQLError::Validation` if the clause contains an unsupported
56    /// operator or an invalid value for the given operator.
57    pub fn to_sql(clause: &WhereClause) -> Result<String> {
58        match clause {
59            WhereClause::Field {
60                path,
61                operator,
62                value,
63            } => Self::generate_field_predicate(path, operator, value),
64            WhereClause::And(clauses) => {
65                if clauses.is_empty() {
66                    return Ok("TRUE".to_string());
67                }
68                let parts: Result<Vec<_>> = clauses.iter().map(Self::to_sql).collect();
69                Ok(format!("({})", parts?.join(" AND ")))
70            },
71            WhereClause::Or(clauses) => {
72                if clauses.is_empty() {
73                    return Ok("FALSE".to_string());
74                }
75                let parts: Result<Vec<_>> = clauses.iter().map(Self::to_sql).collect();
76                Ok(format!("({})", parts?.join(" OR ")))
77            },
78            WhereClause::Not(clause) => {
79                let inner = Self::to_sql(clause)?;
80                Ok(format!("NOT ({})", inner))
81            },
82            WhereClause::NativeField {
83                column,
84                operator,
85                value,
86                ..
87            } => {
88                // Wire adapter: use native column name directly with escaped literal value.
89                // Cast suffix is omitted — wire protocol assembles raw SQL without bind params.
90                let escaped_col = Self::escape_sql_string(column)?;
91                let col_expr = format!("\"{escaped_col}\"");
92                let sql_op = Self::operator_to_sql(operator)?;
93                let val_sql = Self::value_to_sql(value, operator)?;
94                Ok(format!("{col_expr} {sql_op} {val_sql}"))
95            },
96        }
97    }
98
99    fn generate_field_predicate(
100        path: &[String],
101        operator: &WhereOperator,
102        value: &Value,
103    ) -> Result<String> {
104        let json_path = Self::build_json_path(path)?;
105        let sql = if operator == &WhereOperator::IsNull {
106            let is_null = value.as_bool().unwrap_or(true);
107            if is_null {
108                format!("{json_path} IS NULL")
109            } else {
110                format!("{json_path} IS NOT NULL")
111            }
112        } else {
113            let sql_op = Self::operator_to_sql(operator)?;
114            let sql_value = Self::value_to_sql(value, operator)?;
115            format!("{json_path} {sql_op} {sql_value}")
116        };
117        Ok(sql)
118    }
119
120    fn build_json_path(path: &[String]) -> Result<String> {
121        if path.is_empty() {
122            return Ok("data".to_string());
123        }
124
125        if path.len() == 1 {
126            // Simple path: data->>'field'
127            // SECURITY: Escape field name to prevent SQL injection
128            let escaped = Self::escape_sql_string(&path[0])?;
129            Ok(format!("data->>'{}'", escaped))
130        } else {
131            // Nested path: data#>'{a,b,c}'->>'d'
132            // SECURITY: Escape all field names to prevent SQL injection
133            let nested = &path[..path.len() - 1];
134            let last = &path[path.len() - 1];
135
136            // Escape all nested components
137            let escaped_nested: Vec<String> =
138                nested.iter().map(|n| Self::escape_sql_string(n)).collect::<Result<Vec<_>>>()?;
139            let nested_path = escaped_nested.join(",");
140            let escaped_last = Self::escape_sql_string(last)?;
141            Ok(format!("data#>'{{{}}}'->>'{}'", nested_path, escaped_last))
142        }
143    }
144
145    fn operator_to_sql(operator: &WhereOperator) -> Result<&'static str> {
146        Ok(match operator {
147            // Comparison
148            WhereOperator::Eq => "=",
149            WhereOperator::Neq => "!=",
150            WhereOperator::Gt => ">",
151            WhereOperator::Gte => ">=",
152            WhereOperator::Lt => "<",
153            WhereOperator::Lte => "<=",
154
155            // Containment
156            WhereOperator::In => "= ANY",
157            WhereOperator::Nin => "!= ALL",
158
159            // String operations
160            WhereOperator::Contains => "LIKE",
161            WhereOperator::Icontains => "ILIKE",
162            WhereOperator::Startswith => "LIKE",
163            WhereOperator::Istartswith => "ILIKE",
164            WhereOperator::Endswith => "LIKE",
165            WhereOperator::Iendswith => "ILIKE",
166            WhereOperator::Like => "LIKE",
167            WhereOperator::Ilike => "ILIKE",
168            WhereOperator::Nlike => "NOT LIKE",
169            WhereOperator::Nilike => "NOT ILIKE",
170            WhereOperator::Regex => "~",
171            WhereOperator::Iregex => "~*",
172            WhereOperator::Nregex => "!~",
173            WhereOperator::Niregex => "!~*",
174
175            // Array operations
176            WhereOperator::ArrayContains => "@>",
177            WhereOperator::ArrayContainedBy => "<@",
178            WhereOperator::ArrayOverlaps => "&&",
179
180            // These operators require special handling
181            WhereOperator::IsNull => {
182                return Err(FraiseQLError::Internal {
183                    message: "IsNull should be handled separately".to_string(),
184                    source:  None,
185                });
186            },
187            WhereOperator::LenEq
188            | WhereOperator::LenGt
189            | WhereOperator::LenLt
190            | WhereOperator::LenGte
191            | WhereOperator::LenLte
192            | WhereOperator::LenNeq => {
193                return Err(FraiseQLError::Internal {
194                    message: format!(
195                        "Array length operators not yet supported in fraiseql-wire: {operator:?}"
196                    ),
197                    source:  None,
198                });
199            },
200
201            // Vector operations not supported
202            WhereOperator::L2Distance
203            | WhereOperator::CosineDistance
204            | WhereOperator::L1Distance
205            | WhereOperator::HammingDistance
206            | WhereOperator::InnerProduct
207            | WhereOperator::JaccardDistance => {
208                return Err(FraiseQLError::Internal {
209                    message: format!(
210                        "Vector operations not supported in fraiseql-wire: {operator:?}"
211                    ),
212                    source:  None,
213                });
214            },
215
216            // Full-text search operators not supported yet
217            WhereOperator::Matches
218            | WhereOperator::PlainQuery
219            | WhereOperator::PhraseQuery
220            | WhereOperator::WebsearchQuery => {
221                return Err(FraiseQLError::Internal {
222                    message: format!(
223                        "Full-text search operators not yet supported in fraiseql-wire: {operator:?}"
224                    ),
225                    source:  None,
226                });
227            },
228
229            // Network operators not supported yet
230            WhereOperator::IsIPv4
231            | WhereOperator::IsIPv6
232            | WhereOperator::IsPrivate
233            | WhereOperator::IsLoopback
234            | WhereOperator::IsMulticast
235            | WhereOperator::IsLinkLocal
236            | WhereOperator::IsDocumentation
237            | WhereOperator::IsCarrierGrade
238            | WhereOperator::InSubnet
239            | WhereOperator::ContainsSubnet
240            | WhereOperator::ContainsIP
241            | WhereOperator::Overlaps
242            | WhereOperator::StrictlyContains
243            | WhereOperator::AncestorOf
244            | WhereOperator::DescendantOf
245            | WhereOperator::MatchesLquery
246            | WhereOperator::MatchesLtxtquery
247            | WhereOperator::MatchesAnyLquery
248            | WhereOperator::DepthEq
249            | WhereOperator::DepthNeq
250            | WhereOperator::DepthGt
251            | WhereOperator::DepthGte
252            | WhereOperator::DepthLt
253            | WhereOperator::DepthLte
254            | WhereOperator::Lca
255            | WhereOperator::DescendantOfId
256            | WhereOperator::AncestorOfId
257            | WhereOperator::Extended(_) => {
258                return Err(FraiseQLError::Internal {
259                    message: format!(
260                        "Advanced operators not yet supported in fraiseql-wire: {operator:?}"
261                    ),
262                    source:  None,
263                });
264            },
265        })
266    }
267
268    fn value_to_sql(value: &Value, operator: &WhereOperator) -> Result<String> {
269        match (value, operator) {
270            (Value::Null, _) => Ok("NULL".to_string()),
271            (Value::Bool(b), _) => Ok(b.to_string()),
272            (Value::Number(n), _) => Ok(n.to_string()),
273
274            // String operators with wildcards
275            (Value::String(s), WhereOperator::Contains | WhereOperator::Icontains) => {
276                Ok(format!("'%{}%'", Self::escape_sql_string(s)?))
277            },
278            (Value::String(s), WhereOperator::Startswith | WhereOperator::Istartswith) => {
279                Ok(format!("'{}%'", Self::escape_sql_string(s)?))
280            },
281            (Value::String(s), WhereOperator::Endswith | WhereOperator::Iendswith) => {
282                Ok(format!("'%{}'", Self::escape_sql_string(s)?))
283            },
284
285            // Regular strings
286            (Value::String(s), _) => Ok(format!("'{}'", Self::escape_sql_string(s)?)),
287
288            // Arrays (for IN operator)
289            (Value::Array(arr), WhereOperator::In | WhereOperator::Nin) => {
290                let values: Result<Vec<_>> =
291                    arr.iter().map(|v| Self::value_to_sql(v, &WhereOperator::Eq)).collect();
292                Ok(format!("ARRAY[{}]", values?.join(", ")))
293            },
294
295            // Array operations
296            (
297                Value::Array(_),
298                WhereOperator::ArrayContains
299                | WhereOperator::ArrayContainedBy
300                | WhereOperator::ArrayOverlaps,
301            ) => {
302                // SECURITY: Serialize to JSON string and escape single quotes to prevent
303                // SQL injection. The serde_json serializer handles internal escaping, and
304                // we escape single quotes for the SQL string literal context.
305                let json_str =
306                    serde_json::to_string(value).map_err(|e| FraiseQLError::Internal {
307                        message: format!("Failed to serialize JSON for array operator: {e}"),
308                        source:  None,
309                    })?;
310                if json_str.len() > MAX_SQL_VALUE_BYTES {
311                    return Err(FraiseQLError::Validation {
312                        message: format!(
313                            "JSONB value exceeds maximum allowed size for SQL embedding \
314                             ({} bytes, limit is {} bytes)",
315                            json_str.len(),
316                            MAX_SQL_VALUE_BYTES
317                        ),
318                        path:    None,
319                    });
320                }
321                let escaped = json_str.replace('\'', "''");
322                Ok(format!("'{}'::jsonb", escaped))
323            },
324
325            _ => Err(FraiseQLError::Internal {
326                message: format!(
327                    "Unsupported value type for operator: {value:?} with {operator:?}"
328                ),
329                source:  None,
330            }),
331        }
332    }
333
334    fn escape_sql_string(s: &str) -> Result<String> {
335        if s.len() > MAX_SQL_VALUE_BYTES {
336            return Err(FraiseQLError::Validation {
337                message: format!(
338                    "String value exceeds maximum allowed size for SQL embedding \
339                     ({} bytes, limit is {} bytes)",
340                    s.len(),
341                    MAX_SQL_VALUE_BYTES
342                ),
343                path:    None,
344            });
345        }
346        Ok(s.replace('\'', "''"))
347    }
348}
349
// Unit tests for this module live in a separate `tests` module file; the
// `cfg(test)` attribute compiles them only for test builds.
#[cfg(test)]
mod tests;