Skip to main content

fraiseql_db/where_generator/
generic.rs

1//! Generic WHERE clause generator parameterised over a SQL dialect.
2
3use std::{collections::HashSet, sync::Arc};
4
5use fraiseql_error::{FraiseQLError, Result};
6
7use super::counter::ParamCounter;
8use crate::{
9    dialect::SqlDialect,
10    where_clause::{WhereClause, WhereOperator},
11};
12
/// Neutralise the LIKE wildcards in `s` so it matches only as literal text.
///
/// Every `%`, `_` and `\` in the input is prefixed with a backslash, which makes
/// the result safe to embed inside a LIKE/ILIKE pattern as a plain substring.
///
/// The string is rewritten in a single pass: each source character is examined
/// exactly once, so a backslash that was just emitted as an escape can never be
/// escaped a second time.
pub(crate) fn escape_like_literal(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
20
/// Maximum allowed length, in bytes, for user-supplied regex patterns.
///
/// PostgreSQL has no built-in regex timeout, so excessively long patterns
/// or patterns with nested quantifiers can cause CPU exhaustion (ReDoS).
///
/// `validate_regex_pattern` compares this limit against `str::len`, i.e. the
/// UTF-8 byte length of the pattern rather than its character count.
const MAX_REGEX_PATTERN_LEN: usize = 1_000;
26
27/// Validate a user-supplied regex pattern for obvious ReDoS risks.
28///
29/// Rejects:
30/// - Patterns exceeding `MAX_REGEX_PATTERN_LEN` bytes
31/// - Patterns containing nested quantifiers (e.g., `(a+)+`, `(a*)*`, `(a+)*`)
32///
33/// This is not a full ReDoS detector but catches the most common attack vectors.
34fn validate_regex_pattern(pattern: &str) -> Result<()> {
35    if pattern.len() > MAX_REGEX_PATTERN_LEN {
36        return Err(FraiseQLError::Validation {
37            message: format!(
38                "Regex pattern exceeds maximum length of {MAX_REGEX_PATTERN_LEN} bytes"
39            ),
40            path:    None,
41        });
42    }
43
44    // Detect nested quantifiers: a quantifier (+, *, ?, {n}) immediately after
45    // a closing paren that itself follows a quantifier. Simplified heuristic:
46    // look for `)` followed by a quantifier, where the group contains a quantifier.
47    let bytes = pattern.as_bytes();
48    let mut depth: i32 = 0;
49    let mut group_has_quantifier = Vec::new(); // stack: does current group have a quantifier?
50
51    for (i, &b) in bytes.iter().enumerate() {
52        // Skip escaped characters
53        if i > 0 && bytes[i - 1] == b'\\' {
54            continue;
55        }
56        match b {
57            b'(' => {
58                depth += 1;
59                group_has_quantifier.push(false);
60            },
61            b')' => {
62                let had_quantifier = group_has_quantifier.pop().unwrap_or(false);
63                depth -= 1;
64                // Check if a quantifier follows this closing paren
65                if had_quantifier {
66                    let next = bytes.get(i + 1).copied();
67                    if matches!(next, Some(b'+' | b'*' | b'?' | b'{')) {
68                        return Err(FraiseQLError::Validation {
69                            message: "Regex pattern contains nested quantifiers (potential \
70                                      ReDoS). Simplify the pattern to avoid `(…+)+`, \
71                                      `(…*)*`, or similar constructs."
72                                .to_string(),
73                            path:    None,
74                        });
75                    }
76                }
77            },
78            b'+' | b'*' | b'?' => {
79                if let Some(flag) = group_has_quantifier.last_mut() {
80                    *flag = true;
81                }
82            },
83            b'{' if depth > 0 => {
84                if let Some(flag) = group_has_quantifier.last_mut() {
85                    *flag = true;
86                }
87            },
88            _ => {},
89        }
90    }
91
92    Ok(())
93}
94
/// Generic WHERE clause SQL generator.
///
/// Replaces `PostgresWhereGenerator`, `MySqlWhereGenerator`,
/// `SqliteWhereGenerator`, and `SqlServerWhereGenerator` — all dialect-specific
/// primitives are delegated to `D: SqlDialect`.
///
/// # Interior mutability
///
/// The parameter counter uses `Cell<usize>` (via `ParamCounter`).  This is
/// safe because:
/// - `GenericWhereGenerator` is not `Sync` — no concurrent access is possible.
/// - `generate()` resets the counter before every call.
///
/// # Example
///
/// ```rust
/// use fraiseql_db::dialect::PostgresDialect;
/// use fraiseql_db::where_generator::GenericWhereGenerator;
/// use fraiseql_db::{WhereClause, WhereOperator};
/// use serde_json::json;
///
/// // Note: `gen` is a reserved keyword in the Rust 2024 edition, so the
/// // binding is spelled out in full.
/// let generator = GenericWhereGenerator::new(PostgresDialect);
/// let clause = WhereClause::Field {
///     path: vec!["email".to_string()],
///     operator: WhereOperator::Eq,
///     value: json!("alice@example.com"),
/// };
/// let (sql, params) = generator.generate(&clause).unwrap();
/// assert_eq!(sql, "data->>'email' = $1");
/// assert_eq!(params, vec![json!("alice@example.com")]);
/// ```
pub struct GenericWhereGenerator<D: SqlDialect> {
    /// Dialect that supplies every SQL primitive (quoting, casts, placeholders, …).
    dialect:         D,
    /// Placeholder counter; reset at the start of every generation call.
    counter:         ParamCounter,
    /// Optional indexed-column set (PostgreSQL optimisation: short-circuits JSONB
    /// extraction when a generated column covers the path).
    indexed_columns: Option<Arc<HashSet<String>>>,
}
132
133impl<D: SqlDialect> GenericWhereGenerator<D> {
134    /// Create a new generator for the given dialect.
135    pub const fn new(dialect: D) -> Self {
136        Self {
137            dialect,
138            counter: ParamCounter::new(),
139            indexed_columns: None,
140        }
141    }
142
143    /// Attach an indexed-columns set (PostgreSQL optimisation).
144    ///
145    /// When a WHERE path matches a column name in this set, the generator
146    /// emits `"col_name" = $N` instead of `data->>'col_name' = $N`.
147    #[must_use]
148    pub fn with_indexed_columns(mut self, cols: Arc<HashSet<String>>) -> Self {
149        self.indexed_columns = Some(cols);
150        self
151    }
152
153    /// Generate SQL WHERE clause starting parameter numbering at 1.
154    ///
155    /// # Errors
156    ///
157    /// Returns `FraiseQLError::Validation` if the clause uses an operator
158    /// not supported by the dialect.
159    pub fn generate(&self, clause: &WhereClause) -> Result<(String, Vec<serde_json::Value>)> {
160        self.generate_with_param_offset(clause, 0)
161    }
162
163    /// Generate SQL WHERE clause with parameter numbering starting after `offset`.
164    ///
165    /// Use when the WHERE clause is appended to a query that already has bound
166    /// parameters (e.g. cursor values in relay pagination).
167    ///
168    /// # Errors
169    ///
170    /// Returns `FraiseQLError::Validation` if the clause uses an unsupported
171    /// operator.
172    pub fn generate_with_param_offset(
173        &self,
174        clause: &WhereClause,
175        offset: usize,
176    ) -> Result<(String, Vec<serde_json::Value>)> {
177        self.counter.reset_to(offset);
178        let mut params = Vec::new();
179        let sql = self.visit(clause, &mut params)?;
180        Ok((sql, params))
181    }
182
183    // ── Visitor ───────────────────────────────────────────────────────────────
184
185    fn visit(&self, clause: &WhereClause, params: &mut Vec<serde_json::Value>) -> Result<String> {
186        match clause {
187            WhereClause::And(clauses) => self.visit_and(clauses, params),
188            WhereClause::Or(clauses) => self.visit_or(clauses, params),
189            WhereClause::Not(inner) => Ok(format!("NOT ({})", self.visit(inner, params)?)),
190            WhereClause::Field {
191                path,
192                operator,
193                value,
194            } => self.visit_field(path, operator, value, params),
195            WhereClause::NativeField {
196                column,
197                pg_cast,
198                operator,
199                value,
200            } => self.visit_native_field(column, pg_cast, operator, value, params),
201        }
202    }
203
204    /// Generate SQL for a native-column condition.
205    ///
206    /// Emits `"column" = <cast>` where `<cast>` is a dialect-appropriate
207    /// expression (e.g. `$1::text::uuid` for PostgreSQL, `CAST(? AS CHAR)` for
208    /// MySQL) instead of the JSONB extraction path.
209    fn visit_native_field(
210        &self,
211        column: &str,
212        pg_cast: &str,
213        operator: &WhereOperator,
214        value: &serde_json::Value,
215        params: &mut Vec<serde_json::Value>,
216    ) -> Result<String> {
217        let col_expr = self.dialect.quote_identifier(column);
218        let p = self.push_param(params, value.clone());
219        let rhs = if pg_cast.is_empty() {
220            p
221        } else {
222            self.dialect.cast_native_param(&p, pg_cast)
223        };
224        match operator {
225            WhereOperator::Eq => Ok(format!("{col_expr} = {rhs}")),
226            WhereOperator::Neq => {
227                let neq = self.dialect.neq_operator();
228                Ok(format!("{col_expr} {neq} {rhs}"))
229            },
230            _ => Err(FraiseQLError::validation(format!(
231                "Operator {operator:?} is not supported for native column conditions"
232            ))),
233        }
234    }
235
236    fn visit_and(
237        &self,
238        clauses: &[WhereClause],
239        params: &mut Vec<serde_json::Value>,
240    ) -> Result<String> {
241        if clauses.is_empty() {
242            return Ok(self.dialect.always_true().to_string());
243        }
244        let parts: Result<Vec<_>> = clauses.iter().map(|c| self.visit(c, params)).collect();
245        Ok(format!("({})", parts?.join(" AND ")))
246    }
247
248    fn visit_or(
249        &self,
250        clauses: &[WhereClause],
251        params: &mut Vec<serde_json::Value>,
252    ) -> Result<String> {
253        if clauses.is_empty() {
254            return Ok(self.dialect.always_false().to_string());
255        }
256        let parts: Result<Vec<_>> = clauses.iter().map(|c| self.visit(c, params)).collect();
257        Ok(format!("({})", parts?.join(" OR ")))
258    }
259
260    // ── Field expression resolution ───────────────────────────────────────────
261
262    fn resolve_field_expr(&self, path: &[String]) -> String {
263        // PostgreSQL indexed-column optimisation.
264        if let Some(indexed) = &self.indexed_columns {
265            let col_name = path.join("__");
266            if indexed.contains(&col_name) {
267                return self.dialect.quote_identifier(&col_name);
268            }
269        }
270        self.dialect.json_extract_scalar("data", path)
271    }
272
273    // ── Push a parameter and return its placeholder ───────────────────────────
274
275    fn push_param(&self, params: &mut Vec<serde_json::Value>, v: serde_json::Value) -> String {
276        params.push(v);
277        self.dialect.placeholder(self.counter.next())
278    }
279
280    // ── Field visitor ─────────────────────────────────────────────────────────
281
282    fn visit_field(
283        &self,
284        path: &[String],
285        operator: &WhereOperator,
286        value: &serde_json::Value,
287        params: &mut Vec<serde_json::Value>,
288    ) -> Result<String> {
289        let field_expr = self.resolve_field_expr(path);
290
291        match operator {
292            // ── Comparison ────────────────────────────────────────────────────
293            WhereOperator::Eq => {
294                let p = self.push_param(params, value.clone());
295                if value.is_number() {
296                    let cast = self.dialect.cast_to_numeric(&field_expr);
297                    // Dialect-specific RHS cast: PostgreSQL uses (p::text)::numeric to
298                    // avoid wire-protocol type mismatch; other dialects pass p unchanged.
299                    let rhs = self.dialect.cast_param_numeric(&p);
300                    Ok(format!("{cast} = {rhs}"))
301                } else if value.is_boolean() {
302                    let cast = self.dialect.cast_to_boolean(&field_expr);
303                    Ok(format!("{cast} = {p}"))
304                } else {
305                    Ok(format!("{field_expr} = {p}"))
306                }
307            },
308            WhereOperator::Neq => {
309                let p = self.push_param(params, value.clone());
310                let neq = self.dialect.neq_operator();
311                if value.is_number() {
312                    let cast = self.dialect.cast_to_numeric(&field_expr);
313                    let rhs = self.dialect.cast_param_numeric(&p);
314                    Ok(format!("{cast} {neq} {rhs}"))
315                } else if value.is_boolean() {
316                    let cast = self.dialect.cast_to_boolean(&field_expr);
317                    Ok(format!("{cast} {neq} {p}"))
318                } else {
319                    Ok(format!("{field_expr} {neq} {p}"))
320                }
321            },
322            WhereOperator::Gt | WhereOperator::Gte | WhereOperator::Lt | WhereOperator::Lte => {
323                let op = match operator {
324                    WhereOperator::Gt => ">",
325                    WhereOperator::Gte => ">=",
326                    WhereOperator::Lt => "<",
327                    _ => "<=",
328                };
329                let cast = self.dialect.cast_to_numeric(&field_expr);
330                let p = self.push_param(params, value.clone());
331                let rhs = self.dialect.cast_param_numeric(&p);
332                Ok(format!("{cast} {op} {rhs}"))
333            },
334
335            // ── Containment ───────────────────────────────────────────────────
336            WhereOperator::In | WhereOperator::Nin => {
337                let arr = value.as_array().ok_or_else(|| {
338                    FraiseQLError::validation("IN operator requires an array value".to_string())
339                })?;
340                if arr.is_empty() {
341                    return Ok(if matches!(operator, WhereOperator::In) {
342                        self.dialect.always_false().to_string()
343                    } else {
344                        self.dialect.always_true().to_string()
345                    });
346                }
347                let placeholders: Vec<_> =
348                    arr.iter().map(|v| self.push_param(params, v.clone())).collect();
349                let in_list = placeholders.join(", ");
350                let sql = format!("{field_expr} IN ({in_list})");
351                Ok(if matches!(operator, WhereOperator::Nin) {
352                    format!("NOT ({sql})")
353                } else {
354                    sql
355                })
356            },
357
358            // ── NULL ──────────────────────────────────────────────────────────
359            WhereOperator::IsNull => {
360                let is_null = value.as_bool().unwrap_or(true);
361                let null_op = if is_null { "IS NULL" } else { "IS NOT NULL" };
362                Ok(format!("{field_expr} {null_op}"))
363            },
364
365            // ── String: LIKE family ───────────────────────────────────────────
366            WhereOperator::Contains => {
367                let val_str = self.require_str(value, "Contains")?;
368                let escaped = escape_like_literal(val_str);
369                let p = self.push_param(params, serde_json::Value::String(escaped));
370                let pattern = self.dialect.concat_sql(&["'%'", &p, "'%'"]);
371                Ok(self.dialect.like_sql(&field_expr, &pattern))
372            },
373            WhereOperator::Icontains => {
374                let val_str = self.require_str(value, "Icontains")?;
375                let escaped = escape_like_literal(val_str);
376                let p = self.push_param(params, serde_json::Value::String(escaped));
377                let pattern = self.dialect.concat_sql(&["'%'", &p, "'%'"]);
378                Ok(self.dialect.ilike_sql(&field_expr, &pattern))
379            },
380            WhereOperator::Startswith => {
381                let val_str = self.require_str(value, "Startswith")?;
382                let escaped = escape_like_literal(val_str);
383                let p = self.push_param(params, serde_json::Value::String(escaped));
384                let pattern = self.dialect.concat_sql(&[&p, "'%'"]);
385                Ok(self.dialect.like_sql(&field_expr, &pattern))
386            },
387            WhereOperator::Istartswith => {
388                let val_str = self.require_str(value, "Istartswith")?;
389                let escaped = escape_like_literal(val_str);
390                let p = self.push_param(params, serde_json::Value::String(escaped));
391                let pattern = self.dialect.concat_sql(&[&p, "'%'"]);
392                Ok(self.dialect.ilike_sql(&field_expr, &pattern))
393            },
394            WhereOperator::Endswith => {
395                let val_str = self.require_str(value, "Endswith")?;
396                let escaped = escape_like_literal(val_str);
397                let p = self.push_param(params, serde_json::Value::String(escaped));
398                let pattern = self.dialect.concat_sql(&["'%'", &p]);
399                Ok(self.dialect.like_sql(&field_expr, &pattern))
400            },
401            WhereOperator::Iendswith => {
402                let val_str = self.require_str(value, "Iendswith")?;
403                let escaped = escape_like_literal(val_str);
404                let p = self.push_param(params, serde_json::Value::String(escaped));
405                let pattern = self.dialect.concat_sql(&["'%'", &p]);
406                Ok(self.dialect.ilike_sql(&field_expr, &pattern))
407            },
408            WhereOperator::Like => {
409                let p = self.push_param(params, value.clone());
410                Ok(self.dialect.like_sql(&field_expr, &p))
411            },
412            WhereOperator::Ilike => {
413                let p = self.push_param(params, value.clone());
414                Ok(self.dialect.ilike_sql(&field_expr, &p))
415            },
416            WhereOperator::Nlike => {
417                let p = self.push_param(params, value.clone());
418                Ok(format!("NOT ({})", self.dialect.like_sql(&field_expr, &p)))
419            },
420            WhereOperator::Nilike => {
421                let p = self.push_param(params, value.clone());
422                Ok(format!("NOT ({})", self.dialect.ilike_sql(&field_expr, &p)))
423            },
424
425            // ── String: Regex ─────────────────────────────────────────────────
426            WhereOperator::Regex => {
427                if let Some(s) = value.as_str() {
428                    validate_regex_pattern(s)?;
429                }
430                let p = self.push_param(params, value.clone());
431                self.dialect
432                    .regex_sql(&field_expr, &p, false, false)
433                    .map_err(|e| FraiseQLError::validation(e.to_string()))
434            },
435            WhereOperator::Iregex => {
436                if let Some(s) = value.as_str() {
437                    validate_regex_pattern(s)?;
438                }
439                let p = self.push_param(params, value.clone());
440                self.dialect
441                    .regex_sql(&field_expr, &p, true, false)
442                    .map_err(|e| FraiseQLError::validation(e.to_string()))
443            },
444            WhereOperator::Nregex => {
445                if let Some(s) = value.as_str() {
446                    validate_regex_pattern(s)?;
447                }
448                let p = self.push_param(params, value.clone());
449                self.dialect
450                    .regex_sql(&field_expr, &p, false, true)
451                    .map_err(|e| FraiseQLError::validation(e.to_string()))
452            },
453            WhereOperator::Niregex => {
454                if let Some(s) = value.as_str() {
455                    validate_regex_pattern(s)?;
456                }
457                let p = self.push_param(params, value.clone());
458                self.dialect
459                    .regex_sql(&field_expr, &p, true, true)
460                    .map_err(|e| FraiseQLError::validation(e.to_string()))
461            },
462
463            // ── Array: length ─────────────────────────────────────────────────
464            WhereOperator::LenEq
465            | WhereOperator::LenNeq
466            | WhereOperator::LenGt
467            | WhereOperator::LenGte
468            | WhereOperator::LenLt
469            | WhereOperator::LenLte => {
470                let op = match operator {
471                    WhereOperator::LenEq => "=",
472                    WhereOperator::LenNeq => self.dialect.neq_operator(),
473                    WhereOperator::LenGt => ">",
474                    WhereOperator::LenGte => ">=",
475                    WhereOperator::LenLt => "<",
476                    _ => "<=",
477                };
478                let len_expr = self.dialect.json_array_length(&field_expr);
479                let p = self.push_param(params, value.clone());
480                Ok(format!("{len_expr} {op} {p}"))
481            },
482
483            // ── Array: containment ────────────────────────────────────────────
484            WhereOperator::ArrayContains | WhereOperator::StrictlyContains => {
485                // Both @> (ArrayContains) and @> (StrictlyContains, a JSONB-level
486                // strict containment) are routed to array_contains_sql.
487                let p = self.push_param(params, value.clone());
488                self.dialect
489                    .array_contains_sql(&field_expr, &p)
490                    .map_err(|e| FraiseQLError::validation(e.to_string()))
491            },
492            WhereOperator::ArrayContainedBy => {
493                let p = self.push_param(params, value.clone());
494                self.dialect
495                    .array_contained_by_sql(&field_expr, &p)
496                    .map_err(|e| FraiseQLError::validation(e.to_string()))
497            },
498            WhereOperator::ArrayOverlaps => {
499                let p = self.push_param(params, value.clone());
500                self.dialect
501                    .array_overlaps_sql(&field_expr, &p)
502                    .map_err(|e| FraiseQLError::validation(e.to_string()))
503            },
504
505            // ── Full-text search ──────────────────────────────────────────────
506            WhereOperator::Matches => {
507                let p = self.push_param(params, value.clone());
508                self.dialect
509                    .fts_matches_sql(&field_expr, &p)
510                    .map_err(|e| FraiseQLError::validation(e.to_string()))
511            },
512            WhereOperator::PlainQuery => {
513                let p = self.push_param(params, value.clone());
514                self.dialect
515                    .fts_plain_query_sql(&field_expr, &p)
516                    .map_err(|e| FraiseQLError::validation(e.to_string()))
517            },
518            WhereOperator::PhraseQuery => {
519                let p = self.push_param(params, value.clone());
520                self.dialect
521                    .fts_phrase_query_sql(&field_expr, &p)
522                    .map_err(|e| FraiseQLError::validation(e.to_string()))
523            },
524            WhereOperator::WebsearchQuery => {
525                let p = self.push_param(params, value.clone());
526                self.dialect
527                    .fts_websearch_query_sql(&field_expr, &p)
528                    .map_err(|e| FraiseQLError::validation(e.to_string()))
529            },
530
531            // ── Vector (pgvector) ─────────────────────────────────────────────
532            WhereOperator::CosineDistance => {
533                let p = self.push_param(params, value.clone());
534                self.dialect
535                    .vector_distance_sql("<=>", &field_expr, &p)
536                    .map_err(|e| FraiseQLError::validation(e.to_string()))
537            },
538            WhereOperator::L2Distance => {
539                let p = self.push_param(params, value.clone());
540                self.dialect
541                    .vector_distance_sql("<->", &field_expr, &p)
542                    .map_err(|e| FraiseQLError::validation(e.to_string()))
543            },
544            WhereOperator::L1Distance => {
545                let p = self.push_param(params, value.clone());
546                self.dialect
547                    .vector_distance_sql("<+>", &field_expr, &p)
548                    .map_err(|e| FraiseQLError::validation(e.to_string()))
549            },
550            WhereOperator::HammingDistance => {
551                let p = self.push_param(params, value.clone());
552                self.dialect
553                    .vector_distance_sql("<~>", &field_expr, &p)
554                    .map_err(|e| FraiseQLError::validation(e.to_string()))
555            },
556            WhereOperator::InnerProduct => {
557                let p = self.push_param(params, value.clone());
558                self.dialect
559                    .vector_distance_sql("<#>", &field_expr, &p)
560                    .map_err(|e| FraiseQLError::validation(e.to_string()))
561            },
562            WhereOperator::JaccardDistance => {
563                let p = self.push_param(params, value.clone());
564                self.dialect
565                    .jaccard_distance_sql(&field_expr, &p)
566                    .map_err(|e| FraiseQLError::validation(e.to_string()))
567            },
568
569            // ── Network (INET/CIDR) ───────────────────────────────────────────
570            WhereOperator::IsIPv4 => self
571                .dialect
572                .inet_check_sql(&field_expr, "IsIPv4")
573                .map_err(|e| FraiseQLError::validation(e.to_string())),
574            WhereOperator::IsIPv6 => self
575                .dialect
576                .inet_check_sql(&field_expr, "IsIPv6")
577                .map_err(|e| FraiseQLError::validation(e.to_string())),
578            WhereOperator::IsPrivate => self
579                .dialect
580                .inet_check_sql(&field_expr, "IsPrivate")
581                .map_err(|e| FraiseQLError::validation(e.to_string())),
582            WhereOperator::IsPublic => self
583                .dialect
584                .inet_check_sql(&field_expr, "IsPublic")
585                .map_err(|e| FraiseQLError::validation(e.to_string())),
586            WhereOperator::IsLoopback => self
587                .dialect
588                .inet_check_sql(&field_expr, "IsLoopback")
589                .map_err(|e| FraiseQLError::validation(e.to_string())),
590            WhereOperator::InSubnet => {
591                let p = self.push_param(params, value.clone());
592                self.dialect
593                    .inet_binary_sql("<<", &field_expr, &p)
594                    .map_err(|e| FraiseQLError::validation(e.to_string()))
595            },
596            WhereOperator::ContainsSubnet | WhereOperator::ContainsIP => {
597                let p = self.push_param(params, value.clone());
598                self.dialect
599                    .inet_binary_sql(">>", &field_expr, &p)
600                    .map_err(|e| FraiseQLError::validation(e.to_string()))
601            },
602            WhereOperator::Overlaps => {
603                let p = self.push_param(params, value.clone());
604                self.dialect
605                    .inet_binary_sql("&&", &field_expr, &p)
606                    .map_err(|e| FraiseQLError::validation(e.to_string()))
607            },
608
609            // ── LTree ─────────────────────────────────────────────────────────
610            WhereOperator::AncestorOf => {
611                let p = self.push_param(params, value.clone());
612                self.dialect
613                    .ltree_binary_sql("@>", &field_expr, &p, "ltree")
614                    .map_err(|e| FraiseQLError::validation(e.to_string()))
615            },
616            WhereOperator::DescendantOf => {
617                let p = self.push_param(params, value.clone());
618                self.dialect
619                    .ltree_binary_sql("<@", &field_expr, &p, "ltree")
620                    .map_err(|e| FraiseQLError::validation(e.to_string()))
621            },
622            WhereOperator::MatchesLquery => {
623                let p = self.push_param(params, value.clone());
624                self.dialect
625                    .ltree_binary_sql("~", &field_expr, &p, "lquery")
626                    .map_err(|e| FraiseQLError::validation(e.to_string()))
627            },
628            WhereOperator::MatchesLtxtquery => {
629                let p = self.push_param(params, value.clone());
630                self.dialect
631                    .ltree_binary_sql("@", &field_expr, &p, "ltxtquery")
632                    .map_err(|e| FraiseQLError::validation(e.to_string()))
633            },
634            WhereOperator::MatchesAnyLquery => {
635                let arr = value.as_array().ok_or_else(|| {
636                    FraiseQLError::validation(
637                        "matches_any_lquery operator requires an array value".to_string(),
638                    )
639                })?;
640                if arr.is_empty() {
641                    return Err(FraiseQLError::validation(
642                        "matches_any_lquery requires at least one lquery".to_string(),
643                    ));
644                }
645                let placeholders: Vec<_> = arr
646                    .iter()
647                    .map(|v| format!("{}::lquery", self.push_param(params, v.clone())))
648                    .collect();
649                self.dialect
650                    .ltree_any_lquery_sql(&field_expr, &placeholders)
651                    .map_err(|e| FraiseQLError::validation(e.to_string()))
652            },
653            WhereOperator::DepthEq => {
654                let p = self.push_param(params, value.clone());
655                self.dialect
656                    .ltree_depth_sql("=", &field_expr, &p)
657                    .map_err(|e| FraiseQLError::validation(e.to_string()))
658            },
659            WhereOperator::DepthNeq => {
660                let p = self.push_param(params, value.clone());
661                self.dialect
662                    .ltree_depth_sql("!=", &field_expr, &p)
663                    .map_err(|e| FraiseQLError::validation(e.to_string()))
664            },
665            WhereOperator::DepthGt => {
666                let p = self.push_param(params, value.clone());
667                self.dialect
668                    .ltree_depth_sql(">", &field_expr, &p)
669                    .map_err(|e| FraiseQLError::validation(e.to_string()))
670            },
671            WhereOperator::DepthGte => {
672                let p = self.push_param(params, value.clone());
673                self.dialect
674                    .ltree_depth_sql(">=", &field_expr, &p)
675                    .map_err(|e| FraiseQLError::validation(e.to_string()))
676            },
677            WhereOperator::DepthLt => {
678                let p = self.push_param(params, value.clone());
679                self.dialect
680                    .ltree_depth_sql("<", &field_expr, &p)
681                    .map_err(|e| FraiseQLError::validation(e.to_string()))
682            },
683            WhereOperator::DepthLte => {
684                let p = self.push_param(params, value.clone());
685                self.dialect
686                    .ltree_depth_sql("<=", &field_expr, &p)
687                    .map_err(|e| FraiseQLError::validation(e.to_string()))
688            },
689            WhereOperator::Lca => {
690                let arr = value.as_array().ok_or_else(|| {
691                    FraiseQLError::validation("lca operator requires an array value".to_string())
692                })?;
693                if arr.is_empty() {
694                    return Err(FraiseQLError::validation(
695                        "lca operator requires at least one path".to_string(),
696                    ));
697                }
698                let placeholders: Vec<_> = arr
699                    .iter()
700                    .map(|v| format!("{}::ltree", self.push_param(params, v.clone())))
701                    .collect();
702                self.dialect
703                    .ltree_lca_sql(&field_expr, &placeholders)
704                    .map_err(|e| FraiseQLError::validation(e.to_string()))
705            },
706
707            // ── Extended operators ────────────────────────────────────────────
708            WhereOperator::Extended(op) => {
709                self.dialect.generate_extended_sql(op, &field_expr, params)
710            },
711
712            // ── Unknown / future operators ────────────────────────────────────
713            // This arm is only reachable if WhereOperator gains new variants
714            // (it is #[non_exhaustive]).  Suppress the lint that fires when all
715            // current variants are already matched above.
716            #[allow(unreachable_patterns)]
717            // Reason: defensive catch-all for future non_exhaustive variants
718            _ => Err(FraiseQLError::Validation {
719                message: format!(
720                    "Operator {operator:?} is not supported by the {} dialect",
721                    self.dialect.name()
722                ),
723                path:    None,
724            }),
725        }
726    }
727
728    fn require_str<'a>(&self, value: &'a serde_json::Value, op: &'static str) -> Result<&'a str> {
729        value.as_str().ok_or_else(|| {
730            FraiseQLError::validation(format!("{op} operator requires a string value"))
731        })
732    }
733}
734
735// ── Default impl ──────────────────────────────────────────────────────────────
736
737impl<D: SqlDialect + Default> Default for GenericWhereGenerator<D> {
738    fn default() -> Self {
739        Self::new(D::default())
740    }
741}
742
743// ── ExtendedOperatorHandler — single blanket impl ─────────────────────────────
744// Delegates to `D::generate_extended_sql`, which each dialect implements.
745
746impl<D: SqlDialect> crate::filters::ExtendedOperatorHandler for GenericWhereGenerator<D> {
747    fn generate_extended_sql(
748        &self,
749        operator: &crate::filters::ExtendedOperator,
750        field_sql: &str,
751        params: &mut Vec<serde_json::Value>,
752    ) -> Result<String> {
753        self.dialect.generate_extended_sql(operator, field_sql, params)
754    }
755}
756
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
mod tests {
    use serde_json::json;

    use super::{escape_like_literal, GenericWhereGenerator};
    use crate::{
        dialect::{MySqlDialect, PostgresDialect, SqliteDialect},
        where_clause::{WhereClause, WhereOperator},
    };

    /// Build a generator bound to the PostgreSQL dialect.
    fn postgres() -> GenericWhereGenerator<PostgresDialect> {
        GenericWhereGenerator::new(PostgresDialect)
    }

    /// Build a single-segment field filter for `path`.
    fn field(path: &str, operator: WhereOperator, value: serde_json::Value) -> WhereClause {
        WhereClause::Field {
            path: vec![path.to_owned()],
            operator,
            value,
        }
    }

    // ── Core comparison / logical operators ──────────────────────────

    /// `Eq` renders a `$1` placeholder and binds the compared value.
    #[test]
    fn generic_eq_postgres() {
        let generator = postgres();
        let filter = field("email", WhereOperator::Eq, json!("alice@example.com"));
        let (sql, params) = generator.generate(&filter).unwrap();
        assert_eq!(sql, "data->>'email' = $1");
        assert_eq!(params, vec![json!("alice@example.com")]);
    }

    /// `And` wraps its children in parentheses and binds one param per child.
    #[test]
    fn generic_and_postgres() {
        let generator = postgres();
        let both = WhereClause::And(vec![
            field("status", WhereOperator::Eq, json!("active")),
            field("age", WhereOperator::Gte, json!(18)),
        ]);
        let (sql, params) = generator.generate(&both).unwrap();
        assert!(sql.starts_with("(data->>'status' = $1 AND"));
        assert_eq!(params.len(), 2);
    }

    /// An `And` with no children collapses to `TRUE` with no params.
    #[test]
    fn generic_empty_and_returns_true() {
        let generator = postgres();
        let empty = WhereClause::And(Vec::new());
        let (sql, params) = generator.generate(&empty).unwrap();
        assert_eq!(sql, "TRUE");
        assert!(params.is_empty());
    }

    /// An `Or` with no children collapses to `FALSE` with no params.
    #[test]
    fn generic_empty_or_returns_false() {
        let generator = postgres();
        let empty = WhereClause::Or(Vec::new());
        let (sql, params) = generator.generate(&empty).unwrap();
        assert_eq!(sql, "FALSE");
        assert!(params.is_empty());
    }

    /// `Not` prefixes the inner clause with `NOT (`.
    #[test]
    fn generic_not_postgres() {
        let generator = postgres();
        let inner = field("deleted", WhereOperator::Eq, json!(true));
        let negated = WhereClause::Not(Box::new(inner));
        let (sql, _) = generator.generate(&negated).unwrap();
        assert!(sql.starts_with("NOT ("));
    }

    /// Two consecutive `generate` calls on one generator yield identical SQL.
    #[test]
    fn generate_resets_counter() {
        let generator = postgres();
        let filter = field("x", WhereOperator::Eq, json!(1));
        let (first_sql, _) = generator.generate(&filter).unwrap();
        let (second_sql, _) = generator.generate(&filter).unwrap();
        assert_eq!(first_sql, second_sql);
        // Both must reference $1, not $1 then $2
        assert!(first_sql.contains("$1"));
        assert!(!first_sql.contains("$2"));
    }

    /// An offset of 2 makes the first placeholder `$3`.
    #[test]
    fn generate_with_param_offset() {
        let generator = postgres();
        let filter = field("email", WhereOperator::Eq, json!("a@b.com"));
        let (sql, _) = generator.generate_with_param_offset(&filter, 2).unwrap();
        assert!(sql.contains("$3"), "Expected $3 (offset 2 + 1), got: {sql}");
    }

    // ── String operators ─────────────────────────────────────────────

    /// `Icontains` renders an `ILIKE` with wildcards on both sides of the param.
    #[test]
    fn generic_icontains_postgres() {
        let generator = postgres();
        let filter = field("email", WhereOperator::Icontains, json!("example.com"));
        let (sql, params) = generator.generate(&filter).unwrap();
        assert_eq!(sql, "data->>'email' ILIKE '%' || $1 || '%'");
        assert_eq!(params, vec![json!("example.com")]);
    }

    /// `Startswith` renders a `LIKE` with a trailing wildcard.
    #[test]
    fn generic_startswith_postgres() {
        let generator = postgres();
        let filter = field("name", WhereOperator::Startswith, json!("Al"));
        let (sql, params) = generator.generate(&filter).unwrap();
        assert_eq!(sql, "data->>'name' LIKE $1 || '%'");
        assert_eq!(params, vec![json!("Al")]);
    }

    /// `Endswith` renders a `LIKE` with a leading wildcard.
    #[test]
    fn generic_endswith_postgres() {
        let generator = postgres();
        let filter = field("name", WhereOperator::Endswith, json!("son"));
        let (sql, params) = generator.generate(&filter).unwrap();
        assert_eq!(sql, "data->>'name' LIKE '%' || $1");
        assert_eq!(params, vec![json!("son")]);
    }

    // ── Array / IN operators ────────────────────────────────────────

    /// `In` renders one placeholder per array element.
    #[test]
    fn generic_in_postgres() {
        let generator = postgres();
        let filter = field("status", WhereOperator::In, json!(["active", "pending"]));
        let (sql, params) = generator.generate(&filter).unwrap();
        assert_eq!(sql, "data->>'status' IN ($1, $2)");
        assert_eq!(params.len(), 2);
    }

    /// `In` over an empty array collapses to `FALSE` with no params.
    #[test]
    fn generic_in_empty_returns_false() {
        let generator = postgres();
        let filter = field("status", WhereOperator::In, json!([]));
        let (sql, params) = generator.generate(&filter).unwrap();
        assert_eq!(sql, "FALSE");
        assert!(params.is_empty());
    }

    /// `Nin` over an empty array collapses to `TRUE` with no params.
    #[test]
    fn generic_nin_empty_returns_true() {
        let generator = postgres();
        let filter = field("status", WhereOperator::Nin, json!([]));
        let (sql, params) = generator.generate(&filter).unwrap();
        assert_eq!(sql, "TRUE");
        assert!(params.is_empty());
    }

    // ── Security: no value interpolation ─────────────────────────────────────

    /// A hostile value travels only through the params, never the SQL text.
    #[test]
    fn no_value_in_sql_string() {
        let generator = postgres();
        let injection = "'; DROP TABLE users; --";
        let filter = field("email", WhereOperator::Eq, json!(injection));
        let (sql, params) = generator.generate(&filter).unwrap();
        assert!(!sql.contains(injection), "Value must not appear in SQL: {sql}");
        assert_eq!(params[0], json!(injection));
    }

    // ── PG-only: vector, network and ltree operators ──────────────────────────

    /// Cosine distance uses `<=>` with a `::vector` cast and a single param.
    #[test]
    fn generic_pg_cosine_distance() {
        let generator = postgres();
        let filter = field("embedding", WhereOperator::CosineDistance, json!([0.1, 0.2]));
        let (sql, params) = generator.generate(&filter).unwrap();
        assert!(sql.contains("<=>"), "Expected <=> operator, got: {sql}");
        assert!(sql.contains("::vector"), "Expected ::vector cast, got: {sql}");
        assert_eq!(params.len(), 1);
    }

    /// `IsIPv4` compares a `family(` call against 4.
    #[test]
    fn generic_pg_network_ipv4() {
        let generator = postgres();
        let filter = field("ip", WhereOperator::IsIPv4, json!(true));
        let (sql, _) = generator.generate(&filter).unwrap();
        assert!(sql.contains("family("), "Expected family() call, got: {sql}");
        assert!(sql.contains("= 4"), "Expected = 4, got: {sql}");
    }

    /// `AncestorOf` uses `@>` together with an ltree reference and one param.
    #[test]
    fn generic_pg_ltree_ancestor_of() {
        let generator = postgres();
        let filter = field("path", WhereOperator::AncestorOf, json!("europe.france"));
        let (sql, params) = generator.generate(&filter).unwrap();
        assert!(sql.contains("@>") && sql.contains("ltree"), "Got: {sql}");
        assert_eq!(params.len(), 1);
    }

    /// The MySQL dialect rejects the vector distance operator.
    #[test]
    fn non_pg_vector_op_returns_error() {
        let generator = GenericWhereGenerator::new(MySqlDialect);
        let filter = field("embedding", WhereOperator::CosineDistance, json!([0.1]));
        let msg = generator.generate(&filter).unwrap_err().to_string();
        assert!(msg.contains("VectorDistance") || msg.contains("not supported"), "Got: {msg}");
    }

    /// The SQLite dialect rejects the network family operator.
    #[test]
    fn non_pg_network_op_returns_error() {
        let generator = GenericWhereGenerator::new(SqliteDialect);
        let filter = field("ip", WhereOperator::IsIPv4, json!(true));
        let msg = generator.generate(&filter).unwrap_err().to_string();
        assert!(msg.contains("Inet") || msg.contains("not supported"), "Got: {msg}");
    }

    // ── LIKE metacharacter escaping (C3 fix verification) ──────────────

    /// `%`, `_` and `\` are each prefixed with a backslash; other text is untouched.
    #[test]
    fn escape_like_literal_escapes_percent_and_underscore() {
        assert_eq!(escape_like_literal("50%"), "50\\%");
        assert_eq!(escape_like_literal("user_name"), "user\\_name");
        assert_eq!(escape_like_literal("a%b_c\\d"), "a\\%b\\_c\\\\d");
        assert_eq!(escape_like_literal("plain"), "plain");
    }

    /// `Contains` binds the value with `%` escaped.
    #[test]
    fn contains_escapes_like_metacharacters() {
        let generator = postgres();
        let filter = field("name", WhereOperator::Contains, json!("50%off"));
        let (_sql, params) = generator.generate(&filter).unwrap();
        // The param value must have % escaped so it's treated as a literal.
        assert_eq!(params[0], json!("50\\%off"));
    }

    /// `Startswith` binds the value with `_` escaped.
    #[test]
    fn startswith_escapes_like_metacharacters() {
        let generator = postgres();
        let filter = field("name", WhereOperator::Startswith, json!("user_"));
        let (_sql, params) = generator.generate(&filter).unwrap();
        assert_eq!(params[0], json!("user\\_"));
    }

    /// `Endswith` binds the value with `%` escaped.
    #[test]
    fn endswith_escapes_like_metacharacters() {
        let generator = postgres();
        let filter = field("name", WhereOperator::Endswith, json!("100%"));
        let (_sql, params) = generator.generate(&filter).unwrap();
        assert_eq!(params[0], json!("100\\%"));
    }

    // ── Regex complexity guard (C5 fix verification) ──────────────────

    /// `(a+)+$` is rejected with a "nested quantifiers" error.
    #[test]
    fn regex_rejects_nested_quantifiers() {
        let generator = postgres();
        let filter = field("name", WhereOperator::Regex, json!("(a+)+$"));
        let msg = generator.generate(&filter).unwrap_err().to_string();
        assert!(msg.contains("nested quantifiers"), "Got: {msg}");
    }

    /// `(x*)*` is rejected with a "nested quantifiers" error.
    #[test]
    fn regex_rejects_star_star_pattern() {
        let generator = postgres();
        let filter = field("name", WhereOperator::Regex, json!("(x*)*"));
        let err = generator.generate(&filter).unwrap_err();
        assert!(err.to_string().contains("nested quantifiers"));
    }

    /// A 1001-byte pattern is rejected with a "maximum length" error.
    #[test]
    fn regex_rejects_too_long_pattern() {
        let generator = postgres();
        let long_pattern = "a".repeat(1_001);
        let filter = field("name", WhereOperator::Regex, json!(long_pattern));
        let err = generator.generate(&filter).unwrap_err();
        assert!(err.to_string().contains("maximum length"));
    }

    /// A simple anchored character-class pattern is accepted.
    #[test]
    fn regex_allows_safe_patterns() {
        let generator = postgres();
        let filter = field("name", WhereOperator::Regex, json!("^[a-z]+$"));
        assert!(generator.generate(&filter).is_ok());
    }

    /// `Iregex` rejects `(a+)+` just as `Regex` does.
    #[test]
    fn iregex_also_validates_pattern() {
        let generator = postgres();
        let filter = field("name", WhereOperator::Iregex, json!("(a+)+"));
        assert!(generator.generate(&filter).is_err());
    }
}