Skip to main content

fraiseql_db/where_generator/
generic.rs

1//! Generic WHERE clause generator parameterised over a SQL dialect.
2
3use std::{collections::HashSet, sync::Arc};
4
5use fraiseql_error::{FraiseQLError, Result};
6
7use super::counter::ParamCounter;
8use crate::{
9    dialect::SqlDialect,
10    where_clause::{WhereClause, WhereOperator},
11};
12
/// Turn a user-supplied string into a literal fragment for a LIKE/ILIKE
/// pattern by prefixing each LIKE metacharacter (`\`, `%`, `_`) with a
/// backslash.
///
/// The input is walked once, so a backslash that was inserted as an escape
/// is never itself re-escaped.
pub(crate) fn escape_like_literal(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
20
21/// Maximum allowed length for user-supplied regex patterns.
22///
23/// PostgreSQL has no built-in regex timeout, so excessively long patterns
24/// or patterns with nested quantifiers can cause CPU exhaustion (ReDoS).
25const MAX_REGEX_PATTERN_LEN: usize = 1_000;
26
27/// Validate a user-supplied regex pattern for obvious ReDoS risks.
28///
29/// Rejects:
30/// - Patterns exceeding `MAX_REGEX_PATTERN_LEN` bytes
31/// - Patterns containing nested quantifiers (e.g., `(a+)+`, `(a*)*`, `(a+)*`)
32///
33/// This is not a full ReDoS detector but catches the most common attack vectors.
34fn validate_regex_pattern(pattern: &str) -> Result<()> {
35    if pattern.len() > MAX_REGEX_PATTERN_LEN {
36        return Err(FraiseQLError::Validation {
37            message: format!(
38                "Regex pattern exceeds maximum length of {MAX_REGEX_PATTERN_LEN} bytes"
39            ),
40            path:    None,
41        });
42    }
43
44    // Detect nested quantifiers: a quantifier (+, *, ?, {n}) immediately after
45    // a closing paren that itself follows a quantifier. Simplified heuristic:
46    // look for `)` followed by a quantifier, where the group contains a quantifier.
47    let bytes = pattern.as_bytes();
48    let mut depth: i32 = 0;
49    let mut group_has_quantifier = Vec::new(); // stack: does current group have a quantifier?
50
51    for (i, &b) in bytes.iter().enumerate() {
52        // Skip escaped characters
53        if i > 0 && bytes[i - 1] == b'\\' {
54            continue;
55        }
56        match b {
57            b'(' => {
58                depth += 1;
59                group_has_quantifier.push(false);
60            },
61            b')' => {
62                let had_quantifier = group_has_quantifier.pop().unwrap_or(false);
63                depth -= 1;
64                // Check if a quantifier follows this closing paren
65                if had_quantifier {
66                    let next = bytes.get(i + 1).copied();
67                    if matches!(next, Some(b'+' | b'*' | b'?' | b'{')) {
68                        return Err(FraiseQLError::Validation {
69                            message: "Regex pattern contains nested quantifiers (potential \
70                                      ReDoS). Simplify the pattern to avoid `(…+)+`, \
71                                      `(…*)*`, or similar constructs."
72                                .to_string(),
73                            path:    None,
74                        });
75                    }
76                }
77            },
78            b'+' | b'*' | b'?' => {
79                if let Some(flag) = group_has_quantifier.last_mut() {
80                    *flag = true;
81                }
82            },
83            b'{' if depth > 0 => {
84                if let Some(flag) = group_has_quantifier.last_mut() {
85                    *flag = true;
86                }
87            },
88            _ => {},
89        }
90    }
91
92    Ok(())
93}
94
95/// Generic WHERE clause SQL generator.
96///
97/// Replaces `PostgresWhereGenerator`, `MySqlWhereGenerator`,
98/// `SqliteWhereGenerator`, and `SqlServerWhereGenerator` — all dialect-specific
99/// primitives are delegated to `D: SqlDialect`.
100///
101/// # Interior mutability
102///
103/// The parameter counter uses `Cell<usize>` (via `ParamCounter`).  This is
104/// safe because:
105/// - `GenericWhereGenerator` is not `Sync` — no concurrent access is possible.
106/// - `generate()` resets the counter before every call.
107///
108/// # Example
109///
110/// ```rust
111/// use fraiseql_db::dialect::PostgresDialect;
112/// use fraiseql_db::where_generator::GenericWhereGenerator;
113/// use fraiseql_db::{WhereClause, WhereOperator};
114/// use serde_json::json;
115///
116/// let gen = GenericWhereGenerator::new(PostgresDialect);
117/// let clause = WhereClause::Field {
118///     path: vec!["email".to_string()],
119///     operator: WhereOperator::Eq,
120///     value: json!("alice@example.com"),
121/// };
122/// let (sql, params) = gen.generate(&clause).unwrap();
123/// assert_eq!(sql, "data->>'email' = $1");
124/// ```
125pub struct GenericWhereGenerator<D: SqlDialect> {
126    dialect:         D,
127    counter:         ParamCounter,
128    /// Optional indexed-column set (PostgreSQL optimisation: short-circuits JSONB
129    /// extraction when a generated column covers the path).
130    indexed_columns: Option<Arc<HashSet<String>>>,
131}
132
133impl<D: SqlDialect> GenericWhereGenerator<D> {
134    /// Create a new generator for the given dialect.
135    pub const fn new(dialect: D) -> Self {
136        Self {
137            dialect,
138            counter: ParamCounter::new(),
139            indexed_columns: None,
140        }
141    }
142
143    /// Attach an indexed-columns set (PostgreSQL optimisation).
144    ///
145    /// When a WHERE path matches a column name in this set, the generator
146    /// emits `"col_name" = $N` instead of `data->>'col_name' = $N`.
147    #[must_use]
148    pub fn with_indexed_columns(mut self, cols: Arc<HashSet<String>>) -> Self {
149        self.indexed_columns = Some(cols);
150        self
151    }
152
153    /// Generate SQL WHERE clause starting parameter numbering at 1.
154    ///
155    /// # Errors
156    ///
157    /// Returns `FraiseQLError::Validation` if the clause uses an operator
158    /// not supported by the dialect.
159    pub fn generate(&self, clause: &WhereClause) -> Result<(String, Vec<serde_json::Value>)> {
160        self.generate_with_param_offset(clause, 0)
161    }
162
163    /// Generate SQL WHERE clause with parameter numbering starting after `offset`.
164    ///
165    /// Use when the WHERE clause is appended to a query that already has bound
166    /// parameters (e.g. cursor values in relay pagination).
167    ///
168    /// # Errors
169    ///
170    /// Returns `FraiseQLError::Validation` if the clause uses an unsupported
171    /// operator.
172    pub fn generate_with_param_offset(
173        &self,
174        clause: &WhereClause,
175        offset: usize,
176    ) -> Result<(String, Vec<serde_json::Value>)> {
177        self.counter.reset_to(offset);
178        let mut params = Vec::new();
179        let sql = self.visit(clause, &mut params)?;
180        Ok((sql, params))
181    }
182
183    // ── Visitor ───────────────────────────────────────────────────────────────
184
185    fn visit(&self, clause: &WhereClause, params: &mut Vec<serde_json::Value>) -> Result<String> {
186        match clause {
187            WhereClause::And(clauses) => self.visit_and(clauses, params),
188            WhereClause::Or(clauses) => self.visit_or(clauses, params),
189            WhereClause::Not(inner) => Ok(format!("NOT ({})", self.visit(inner, params)?)),
190            WhereClause::Field {
191                path,
192                operator,
193                value,
194            } => self.visit_field(path, operator, value, params),
195        }
196    }
197
198    fn visit_and(
199        &self,
200        clauses: &[WhereClause],
201        params: &mut Vec<serde_json::Value>,
202    ) -> Result<String> {
203        if clauses.is_empty() {
204            return Ok(self.dialect.always_true().to_string());
205        }
206        let parts: Result<Vec<_>> = clauses.iter().map(|c| self.visit(c, params)).collect();
207        Ok(format!("({})", parts?.join(" AND ")))
208    }
209
210    fn visit_or(
211        &self,
212        clauses: &[WhereClause],
213        params: &mut Vec<serde_json::Value>,
214    ) -> Result<String> {
215        if clauses.is_empty() {
216            return Ok(self.dialect.always_false().to_string());
217        }
218        let parts: Result<Vec<_>> = clauses.iter().map(|c| self.visit(c, params)).collect();
219        Ok(format!("({})", parts?.join(" OR ")))
220    }
221
222    // ── Field expression resolution ───────────────────────────────────────────
223
224    fn resolve_field_expr(&self, path: &[String]) -> String {
225        // PostgreSQL indexed-column optimisation.
226        if let Some(indexed) = &self.indexed_columns {
227            let col_name = path.join("__");
228            if indexed.contains(&col_name) {
229                return self.dialect.quote_identifier(&col_name);
230            }
231        }
232        self.dialect.json_extract_scalar("data", path)
233    }
234
235    // ── Push a parameter and return its placeholder ───────────────────────────
236
237    fn push_param(&self, params: &mut Vec<serde_json::Value>, v: serde_json::Value) -> String {
238        params.push(v);
239        self.dialect.placeholder(self.counter.next())
240    }
241
242    // ── Field visitor ─────────────────────────────────────────────────────────
243
244    fn visit_field(
245        &self,
246        path: &[String],
247        operator: &WhereOperator,
248        value: &serde_json::Value,
249        params: &mut Vec<serde_json::Value>,
250    ) -> Result<String> {
251        let field_expr = self.resolve_field_expr(path);
252
253        match operator {
254            // ── Comparison ────────────────────────────────────────────────────
255            WhereOperator::Eq => {
256                let p = self.push_param(params, value.clone());
257                if value.is_number() {
258                    let cast = self.dialect.cast_to_numeric(&field_expr);
259                    // Dialect-specific RHS cast: PostgreSQL uses (p::text)::numeric to
260                    // avoid wire-protocol type mismatch; other dialects pass p unchanged.
261                    let rhs = self.dialect.cast_param_numeric(&p);
262                    Ok(format!("{cast} = {rhs}"))
263                } else if value.is_boolean() {
264                    let cast = self.dialect.cast_to_boolean(&field_expr);
265                    Ok(format!("{cast} = {p}"))
266                } else {
267                    Ok(format!("{field_expr} = {p}"))
268                }
269            },
270            WhereOperator::Neq => {
271                let p = self.push_param(params, value.clone());
272                let neq = self.dialect.neq_operator();
273                if value.is_number() {
274                    let cast = self.dialect.cast_to_numeric(&field_expr);
275                    let rhs = self.dialect.cast_param_numeric(&p);
276                    Ok(format!("{cast} {neq} {rhs}"))
277                } else if value.is_boolean() {
278                    let cast = self.dialect.cast_to_boolean(&field_expr);
279                    Ok(format!("{cast} {neq} {p}"))
280                } else {
281                    Ok(format!("{field_expr} {neq} {p}"))
282                }
283            },
284            WhereOperator::Gt | WhereOperator::Gte | WhereOperator::Lt | WhereOperator::Lte => {
285                let op = match operator {
286                    WhereOperator::Gt => ">",
287                    WhereOperator::Gte => ">=",
288                    WhereOperator::Lt => "<",
289                    _ => "<=",
290                };
291                let cast = self.dialect.cast_to_numeric(&field_expr);
292                let p = self.push_param(params, value.clone());
293                let rhs = self.dialect.cast_param_numeric(&p);
294                Ok(format!("{cast} {op} {rhs}"))
295            },
296
297            // ── Containment ───────────────────────────────────────────────────
298            WhereOperator::In | WhereOperator::Nin => {
299                let arr = value.as_array().ok_or_else(|| {
300                    FraiseQLError::validation("IN operator requires an array value".to_string())
301                })?;
302                if arr.is_empty() {
303                    return Ok(if matches!(operator, WhereOperator::In) {
304                        self.dialect.always_false().to_string()
305                    } else {
306                        self.dialect.always_true().to_string()
307                    });
308                }
309                let placeholders: Vec<_> =
310                    arr.iter().map(|v| self.push_param(params, v.clone())).collect();
311                let in_list = placeholders.join(", ");
312                let sql = format!("{field_expr} IN ({in_list})");
313                Ok(if matches!(operator, WhereOperator::Nin) {
314                    format!("NOT ({sql})")
315                } else {
316                    sql
317                })
318            },
319
320            // ── NULL ──────────────────────────────────────────────────────────
321            WhereOperator::IsNull => {
322                let is_null = value.as_bool().unwrap_or(true);
323                let null_op = if is_null { "IS NULL" } else { "IS NOT NULL" };
324                Ok(format!("{field_expr} {null_op}"))
325            },
326
327            // ── String: LIKE family ───────────────────────────────────────────
328            WhereOperator::Contains => {
329                let val_str = self.require_str(value, "Contains")?;
330                let escaped = escape_like_literal(val_str);
331                let p = self.push_param(params, serde_json::Value::String(escaped));
332                let pattern = self.dialect.concat_sql(&["'%'", &p, "'%'"]);
333                Ok(self.dialect.like_sql(&field_expr, &pattern))
334            },
335            WhereOperator::Icontains => {
336                let val_str = self.require_str(value, "Icontains")?;
337                let escaped = escape_like_literal(val_str);
338                let p = self.push_param(params, serde_json::Value::String(escaped));
339                let pattern = self.dialect.concat_sql(&["'%'", &p, "'%'"]);
340                Ok(self.dialect.ilike_sql(&field_expr, &pattern))
341            },
342            WhereOperator::Startswith => {
343                let val_str = self.require_str(value, "Startswith")?;
344                let escaped = escape_like_literal(val_str);
345                let p = self.push_param(params, serde_json::Value::String(escaped));
346                let pattern = self.dialect.concat_sql(&[&p, "'%'"]);
347                Ok(self.dialect.like_sql(&field_expr, &pattern))
348            },
349            WhereOperator::Istartswith => {
350                let val_str = self.require_str(value, "Istartswith")?;
351                let escaped = escape_like_literal(val_str);
352                let p = self.push_param(params, serde_json::Value::String(escaped));
353                let pattern = self.dialect.concat_sql(&[&p, "'%'"]);
354                Ok(self.dialect.ilike_sql(&field_expr, &pattern))
355            },
356            WhereOperator::Endswith => {
357                let val_str = self.require_str(value, "Endswith")?;
358                let escaped = escape_like_literal(val_str);
359                let p = self.push_param(params, serde_json::Value::String(escaped));
360                let pattern = self.dialect.concat_sql(&["'%'", &p]);
361                Ok(self.dialect.like_sql(&field_expr, &pattern))
362            },
363            WhereOperator::Iendswith => {
364                let val_str = self.require_str(value, "Iendswith")?;
365                let escaped = escape_like_literal(val_str);
366                let p = self.push_param(params, serde_json::Value::String(escaped));
367                let pattern = self.dialect.concat_sql(&["'%'", &p]);
368                Ok(self.dialect.ilike_sql(&field_expr, &pattern))
369            },
370            WhereOperator::Like => {
371                let p = self.push_param(params, value.clone());
372                Ok(self.dialect.like_sql(&field_expr, &p))
373            },
374            WhereOperator::Ilike => {
375                let p = self.push_param(params, value.clone());
376                Ok(self.dialect.ilike_sql(&field_expr, &p))
377            },
378            WhereOperator::Nlike => {
379                let p = self.push_param(params, value.clone());
380                Ok(format!("NOT ({})", self.dialect.like_sql(&field_expr, &p)))
381            },
382            WhereOperator::Nilike => {
383                let p = self.push_param(params, value.clone());
384                Ok(format!("NOT ({})", self.dialect.ilike_sql(&field_expr, &p)))
385            },
386
387            // ── String: Regex ─────────────────────────────────────────────────
388            WhereOperator::Regex => {
389                if let Some(s) = value.as_str() {
390                    validate_regex_pattern(s)?;
391                }
392                let p = self.push_param(params, value.clone());
393                self.dialect
394                    .regex_sql(&field_expr, &p, false, false)
395                    .map_err(|e| FraiseQLError::validation(e.to_string()))
396            },
397            WhereOperator::Iregex => {
398                if let Some(s) = value.as_str() {
399                    validate_regex_pattern(s)?;
400                }
401                let p = self.push_param(params, value.clone());
402                self.dialect
403                    .regex_sql(&field_expr, &p, true, false)
404                    .map_err(|e| FraiseQLError::validation(e.to_string()))
405            },
406            WhereOperator::Nregex => {
407                if let Some(s) = value.as_str() {
408                    validate_regex_pattern(s)?;
409                }
410                let p = self.push_param(params, value.clone());
411                self.dialect
412                    .regex_sql(&field_expr, &p, false, true)
413                    .map_err(|e| FraiseQLError::validation(e.to_string()))
414            },
415            WhereOperator::Niregex => {
416                if let Some(s) = value.as_str() {
417                    validate_regex_pattern(s)?;
418                }
419                let p = self.push_param(params, value.clone());
420                self.dialect
421                    .regex_sql(&field_expr, &p, true, true)
422                    .map_err(|e| FraiseQLError::validation(e.to_string()))
423            },
424
425            // ── Array: length ─────────────────────────────────────────────────
426            WhereOperator::LenEq
427            | WhereOperator::LenNeq
428            | WhereOperator::LenGt
429            | WhereOperator::LenGte
430            | WhereOperator::LenLt
431            | WhereOperator::LenLte => {
432                let op = match operator {
433                    WhereOperator::LenEq => "=",
434                    WhereOperator::LenNeq => self.dialect.neq_operator(),
435                    WhereOperator::LenGt => ">",
436                    WhereOperator::LenGte => ">=",
437                    WhereOperator::LenLt => "<",
438                    _ => "<=",
439                };
440                let len_expr = self.dialect.json_array_length(&field_expr);
441                let p = self.push_param(params, value.clone());
442                Ok(format!("{len_expr} {op} {p}"))
443            },
444
445            // ── Array: containment ────────────────────────────────────────────
446            WhereOperator::ArrayContains | WhereOperator::StrictlyContains => {
447                // Both @> (ArrayContains) and @> (StrictlyContains, a JSONB-level
448                // strict containment) are routed to array_contains_sql.
449                let p = self.push_param(params, value.clone());
450                self.dialect
451                    .array_contains_sql(&field_expr, &p)
452                    .map_err(|e| FraiseQLError::validation(e.to_string()))
453            },
454            WhereOperator::ArrayContainedBy => {
455                let p = self.push_param(params, value.clone());
456                self.dialect
457                    .array_contained_by_sql(&field_expr, &p)
458                    .map_err(|e| FraiseQLError::validation(e.to_string()))
459            },
460            WhereOperator::ArrayOverlaps => {
461                let p = self.push_param(params, value.clone());
462                self.dialect
463                    .array_overlaps_sql(&field_expr, &p)
464                    .map_err(|e| FraiseQLError::validation(e.to_string()))
465            },
466
467            // ── Full-text search ──────────────────────────────────────────────
468            WhereOperator::Matches => {
469                let p = self.push_param(params, value.clone());
470                self.dialect
471                    .fts_matches_sql(&field_expr, &p)
472                    .map_err(|e| FraiseQLError::validation(e.to_string()))
473            },
474            WhereOperator::PlainQuery => {
475                let p = self.push_param(params, value.clone());
476                self.dialect
477                    .fts_plain_query_sql(&field_expr, &p)
478                    .map_err(|e| FraiseQLError::validation(e.to_string()))
479            },
480            WhereOperator::PhraseQuery => {
481                let p = self.push_param(params, value.clone());
482                self.dialect
483                    .fts_phrase_query_sql(&field_expr, &p)
484                    .map_err(|e| FraiseQLError::validation(e.to_string()))
485            },
486            WhereOperator::WebsearchQuery => {
487                let p = self.push_param(params, value.clone());
488                self.dialect
489                    .fts_websearch_query_sql(&field_expr, &p)
490                    .map_err(|e| FraiseQLError::validation(e.to_string()))
491            },
492
493            // ── Vector (pgvector) ─────────────────────────────────────────────
494            WhereOperator::CosineDistance => {
495                let p = self.push_param(params, value.clone());
496                self.dialect
497                    .vector_distance_sql("<=>", &field_expr, &p)
498                    .map_err(|e| FraiseQLError::validation(e.to_string()))
499            },
500            WhereOperator::L2Distance => {
501                let p = self.push_param(params, value.clone());
502                self.dialect
503                    .vector_distance_sql("<->", &field_expr, &p)
504                    .map_err(|e| FraiseQLError::validation(e.to_string()))
505            },
506            WhereOperator::L1Distance => {
507                let p = self.push_param(params, value.clone());
508                self.dialect
509                    .vector_distance_sql("<+>", &field_expr, &p)
510                    .map_err(|e| FraiseQLError::validation(e.to_string()))
511            },
512            WhereOperator::HammingDistance => {
513                let p = self.push_param(params, value.clone());
514                self.dialect
515                    .vector_distance_sql("<~>", &field_expr, &p)
516                    .map_err(|e| FraiseQLError::validation(e.to_string()))
517            },
518            WhereOperator::InnerProduct => {
519                let p = self.push_param(params, value.clone());
520                self.dialect
521                    .vector_distance_sql("<#>", &field_expr, &p)
522                    .map_err(|e| FraiseQLError::validation(e.to_string()))
523            },
524            WhereOperator::JaccardDistance => {
525                let p = self.push_param(params, value.clone());
526                self.dialect
527                    .jaccard_distance_sql(&field_expr, &p)
528                    .map_err(|e| FraiseQLError::validation(e.to_string()))
529            },
530
531            // ── Network (INET/CIDR) ───────────────────────────────────────────
532            WhereOperator::IsIPv4 => self
533                .dialect
534                .inet_check_sql(&field_expr, "IsIPv4")
535                .map_err(|e| FraiseQLError::validation(e.to_string())),
536            WhereOperator::IsIPv6 => self
537                .dialect
538                .inet_check_sql(&field_expr, "IsIPv6")
539                .map_err(|e| FraiseQLError::validation(e.to_string())),
540            WhereOperator::IsPrivate => self
541                .dialect
542                .inet_check_sql(&field_expr, "IsPrivate")
543                .map_err(|e| FraiseQLError::validation(e.to_string())),
544            WhereOperator::IsPublic => self
545                .dialect
546                .inet_check_sql(&field_expr, "IsPublic")
547                .map_err(|e| FraiseQLError::validation(e.to_string())),
548            WhereOperator::IsLoopback => self
549                .dialect
550                .inet_check_sql(&field_expr, "IsLoopback")
551                .map_err(|e| FraiseQLError::validation(e.to_string())),
552            WhereOperator::InSubnet => {
553                let p = self.push_param(params, value.clone());
554                self.dialect
555                    .inet_binary_sql("<<", &field_expr, &p)
556                    .map_err(|e| FraiseQLError::validation(e.to_string()))
557            },
558            WhereOperator::ContainsSubnet | WhereOperator::ContainsIP => {
559                let p = self.push_param(params, value.clone());
560                self.dialect
561                    .inet_binary_sql(">>", &field_expr, &p)
562                    .map_err(|e| FraiseQLError::validation(e.to_string()))
563            },
564            WhereOperator::Overlaps => {
565                let p = self.push_param(params, value.clone());
566                self.dialect
567                    .inet_binary_sql("&&", &field_expr, &p)
568                    .map_err(|e| FraiseQLError::validation(e.to_string()))
569            },
570
571            // ── LTree ─────────────────────────────────────────────────────────
572            WhereOperator::AncestorOf => {
573                let p = self.push_param(params, value.clone());
574                self.dialect
575                    .ltree_binary_sql("@>", &field_expr, &p, "ltree")
576                    .map_err(|e| FraiseQLError::validation(e.to_string()))
577            },
578            WhereOperator::DescendantOf => {
579                let p = self.push_param(params, value.clone());
580                self.dialect
581                    .ltree_binary_sql("<@", &field_expr, &p, "ltree")
582                    .map_err(|e| FraiseQLError::validation(e.to_string()))
583            },
584            WhereOperator::MatchesLquery => {
585                let p = self.push_param(params, value.clone());
586                self.dialect
587                    .ltree_binary_sql("~", &field_expr, &p, "lquery")
588                    .map_err(|e| FraiseQLError::validation(e.to_string()))
589            },
590            WhereOperator::MatchesLtxtquery => {
591                let p = self.push_param(params, value.clone());
592                self.dialect
593                    .ltree_binary_sql("@", &field_expr, &p, "ltxtquery")
594                    .map_err(|e| FraiseQLError::validation(e.to_string()))
595            },
596            WhereOperator::MatchesAnyLquery => {
597                let arr = value.as_array().ok_or_else(|| {
598                    FraiseQLError::validation(
599                        "matches_any_lquery operator requires an array value".to_string(),
600                    )
601                })?;
602                if arr.is_empty() {
603                    return Err(FraiseQLError::validation(
604                        "matches_any_lquery requires at least one lquery".to_string(),
605                    ));
606                }
607                let placeholders: Vec<_> = arr
608                    .iter()
609                    .map(|v| format!("{}::lquery", self.push_param(params, v.clone())))
610                    .collect();
611                self.dialect
612                    .ltree_any_lquery_sql(&field_expr, &placeholders)
613                    .map_err(|e| FraiseQLError::validation(e.to_string()))
614            },
615            WhereOperator::DepthEq => {
616                let p = self.push_param(params, value.clone());
617                self.dialect
618                    .ltree_depth_sql("=", &field_expr, &p)
619                    .map_err(|e| FraiseQLError::validation(e.to_string()))
620            },
621            WhereOperator::DepthNeq => {
622                let p = self.push_param(params, value.clone());
623                self.dialect
624                    .ltree_depth_sql("!=", &field_expr, &p)
625                    .map_err(|e| FraiseQLError::validation(e.to_string()))
626            },
627            WhereOperator::DepthGt => {
628                let p = self.push_param(params, value.clone());
629                self.dialect
630                    .ltree_depth_sql(">", &field_expr, &p)
631                    .map_err(|e| FraiseQLError::validation(e.to_string()))
632            },
633            WhereOperator::DepthGte => {
634                let p = self.push_param(params, value.clone());
635                self.dialect
636                    .ltree_depth_sql(">=", &field_expr, &p)
637                    .map_err(|e| FraiseQLError::validation(e.to_string()))
638            },
639            WhereOperator::DepthLt => {
640                let p = self.push_param(params, value.clone());
641                self.dialect
642                    .ltree_depth_sql("<", &field_expr, &p)
643                    .map_err(|e| FraiseQLError::validation(e.to_string()))
644            },
645            WhereOperator::DepthLte => {
646                let p = self.push_param(params, value.clone());
647                self.dialect
648                    .ltree_depth_sql("<=", &field_expr, &p)
649                    .map_err(|e| FraiseQLError::validation(e.to_string()))
650            },
651            WhereOperator::Lca => {
652                let arr = value.as_array().ok_or_else(|| {
653                    FraiseQLError::validation("lca operator requires an array value".to_string())
654                })?;
655                if arr.is_empty() {
656                    return Err(FraiseQLError::validation(
657                        "lca operator requires at least one path".to_string(),
658                    ));
659                }
660                let placeholders: Vec<_> = arr
661                    .iter()
662                    .map(|v| format!("{}::ltree", self.push_param(params, v.clone())))
663                    .collect();
664                self.dialect
665                    .ltree_lca_sql(&field_expr, &placeholders)
666                    .map_err(|e| FraiseQLError::validation(e.to_string()))
667            },
668
669            // ── Extended operators ────────────────────────────────────────────
670            WhereOperator::Extended(op) => {
671                self.dialect.generate_extended_sql(op, &field_expr, params)
672            },
673
674            // ── Unknown / future operators ────────────────────────────────────
675            // This arm is only reachable if WhereOperator gains new variants
676            // (it is #[non_exhaustive]).  Suppress the lint that fires when all
677            // current variants are already matched above.
678            #[allow(unreachable_patterns)]
679            // Reason: defensive catch-all for future non_exhaustive variants
680            _ => Err(FraiseQLError::Validation {
681                message: format!(
682                    "Operator {operator:?} is not supported by the {} dialect",
683                    self.dialect.name()
684                ),
685                path:    None,
686            }),
687        }
688    }
689
690    fn require_str<'a>(&self, value: &'a serde_json::Value, op: &'static str) -> Result<&'a str> {
691        value.as_str().ok_or_else(|| {
692            FraiseQLError::validation(format!("{op} operator requires a string value"))
693        })
694    }
695}
696
697// ── Default impl ──────────────────────────────────────────────────────────────
698
699impl<D: SqlDialect + Default> Default for GenericWhereGenerator<D> {
700    fn default() -> Self {
701        Self::new(D::default())
702    }
703}
704
705// ── ExtendedOperatorHandler — single blanket impl ─────────────────────────────
706// Delegates to `D::generate_extended_sql`, which each dialect implements.
707
708impl<D: SqlDialect> crate::filters::ExtendedOperatorHandler for GenericWhereGenerator<D> {
709    fn generate_extended_sql(
710        &self,
711        operator: &crate::filters::ExtendedOperator,
712        field_sql: &str,
713        params: &mut Vec<serde_json::Value>,
714    ) -> Result<String> {
715        self.dialect.generate_extended_sql(operator, field_sql, params)
716    }
717}
718
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
mod tests {
    use serde_json::json;

    use super::GenericWhereGenerator;
    use crate::{
        dialect::PostgresDialect,
        where_clause::{WhereClause, WhereOperator},
    };

    /// Generator under test for every PostgreSQL-flavoured case below.
    fn pg() -> GenericWhereGenerator<PostgresDialect> {
        GenericWhereGenerator::new(PostgresDialect)
    }

    /// Builds a single-segment field predicate from its three parts.
    fn field(path: &str, op: WhereOperator, val: serde_json::Value) -> WhereClause {
        let segments = vec![path.to_string()];
        WhereClause::Field {
            path:     segments,
            operator: op,
            value:    val,
        }
    }

    // ── Core comparison / logical operators ──────────────────────────

    #[test]
    fn generic_eq_postgres() {
        let generator = pg();
        let clause = field("email", WhereOperator::Eq, json!("alice@example.com"));

        let (sql, params) = generator.generate(&clause).unwrap();

        assert_eq!(sql, "data->>'email' = $1");
        assert_eq!(params, vec![json!("alice@example.com")]);
    }

    #[test]
    fn generic_and_postgres() {
        let generator = pg();
        let status = field("status", WhereOperator::Eq, json!("active"));
        let age = field("age", WhereOperator::Gte, json!(18));
        let clause = WhereClause::And(vec![status, age]);

        let (sql, params) = generator.generate(&clause).unwrap();

        assert!(sql.starts_with("(data->>'status' = $1 AND"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn generic_empty_and_returns_true() {
        let generator = pg();
        let clause = WhereClause::And(Vec::new());

        let (sql, params) = generator.generate(&clause).unwrap();

        assert_eq!(sql, "TRUE");
        assert!(params.is_empty());
    }

    #[test]
    fn generic_empty_or_returns_false() {
        let generator = pg();
        let clause = WhereClause::Or(Vec::new());

        let (sql, params) = generator.generate(&clause).unwrap();

        assert_eq!(sql, "FALSE");
        assert!(params.is_empty());
    }

    #[test]
    fn generic_not_postgres() {
        let generator = pg();
        let inner = field("deleted", WhereOperator::Eq, json!(true));
        let clause = WhereClause::Not(Box::new(inner));

        let (sql, _) = generator.generate(&clause).unwrap();

        assert!(sql.starts_with("NOT ("));
    }

    #[test]
    fn generate_resets_counter() {
        let generator = pg();
        let clause = field("x", WhereOperator::Eq, json!(1));

        let (first, _) = generator.generate(&clause).unwrap();
        let (second, _) = generator.generate(&clause).unwrap();

        assert_eq!(first, second);
        // Each call must start numbering afresh: $1 only, never $2.
        assert!(first.contains("$1"));
        assert!(!first.contains("$2"));
    }

    #[test]
    fn generate_with_param_offset() {
        let generator = pg();
        let clause = field("email", WhereOperator::Eq, json!("a@b.com"));

        let (sql, _) = generator.generate_with_param_offset(&clause, 2).unwrap();

        assert!(sql.contains("$3"), "Expected $3 (offset 2 + 1), got: {sql}");
    }

    // ── String operators ─────────────────────────────────────────────

    #[test]
    fn generic_icontains_postgres() {
        let generator = pg();
        let clause = field("email", WhereOperator::Icontains, json!("example.com"));

        let (sql, params) = generator.generate(&clause).unwrap();

        assert_eq!(sql, "data->>'email' ILIKE '%' || $1 || '%'");
        assert_eq!(params, vec![json!("example.com")]);
    }

    #[test]
    fn generic_startswith_postgres() {
        let generator = pg();
        let clause = field("name", WhereOperator::Startswith, json!("Al"));

        let (sql, params) = generator.generate(&clause).unwrap();

        assert_eq!(sql, "data->>'name' LIKE $1 || '%'");
        assert_eq!(params, vec![json!("Al")]);
    }

    #[test]
    fn generic_endswith_postgres() {
        let generator = pg();
        let clause = field("name", WhereOperator::Endswith, json!("son"));

        let (sql, params) = generator.generate(&clause).unwrap();

        assert_eq!(sql, "data->>'name' LIKE '%' || $1");
        assert_eq!(params, vec![json!("son")]);
    }

    // ── Array / IN operators ────────────────────────────────────────

    #[test]
    fn generic_in_postgres() {
        let generator = pg();
        let clause = field("status", WhereOperator::In, json!(["active", "pending"]));

        let (sql, params) = generator.generate(&clause).unwrap();

        assert_eq!(sql, "data->>'status' IN ($1, $2)");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn generic_in_empty_returns_false() {
        let generator = pg();
        let clause = field("status", WhereOperator::In, json!([]));

        let (sql, params) = generator.generate(&clause).unwrap();

        assert_eq!(sql, "FALSE");
        assert!(params.is_empty());
    }

    #[test]
    fn generic_nin_empty_returns_true() {
        let generator = pg();
        let clause = field("status", WhereOperator::Nin, json!([]));

        let (sql, params) = generator.generate(&clause).unwrap();

        assert_eq!(sql, "TRUE");
        assert!(params.is_empty());
    }

    // ── Security: no value interpolation ─────────────────────────────────────

    #[test]
    fn no_value_in_sql_string() {
        let generator = pg();
        let injection = "'; DROP TABLE users; --";
        let clause = field("email", WhereOperator::Eq, json!(injection));

        let (sql, params) = generator.generate(&clause).unwrap();

        assert!(!sql.contains(injection), "Value must not appear in SQL: {sql}");
        assert_eq!(params[0], json!(injection));
    }

    // ── PG-only operators (vector / network / ltree) ──────────────────────────

    #[test]
    fn generic_pg_cosine_distance() {
        let generator = pg();
        let clause = field("embedding", WhereOperator::CosineDistance, json!([0.1, 0.2]));

        let (sql, params) = generator.generate(&clause).unwrap();

        assert!(sql.contains("<=>"), "Expected <=> operator, got: {sql}");
        assert!(sql.contains("::vector"), "Expected ::vector cast, got: {sql}");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn generic_pg_network_ipv4() {
        let generator = pg();
        let clause = field("ip", WhereOperator::IsIPv4, json!(true));

        let (sql, _) = generator.generate(&clause).unwrap();

        assert!(sql.contains("family("), "Expected family() call, got: {sql}");
        assert!(sql.contains("= 4"), "Expected = 4, got: {sql}");
    }

    #[test]
    fn generic_pg_ltree_ancestor_of() {
        let generator = pg();
        let clause = field("path", WhereOperator::AncestorOf, json!("europe.france"));

        let (sql, params) = generator.generate(&clause).unwrap();

        assert!(sql.contains("@>") && sql.contains("ltree"), "Got: {sql}");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn non_pg_vector_op_returns_error() {
        use crate::dialect::MySqlDialect;

        let generator = GenericWhereGenerator::new(MySqlDialect);
        let clause = field("embedding", WhereOperator::CosineDistance, json!([0.1]));

        let msg = generator.generate(&clause).unwrap_err().to_string();

        assert!(msg.contains("VectorDistance") || msg.contains("not supported"), "Got: {msg}");
    }

    #[test]
    fn non_pg_network_op_returns_error() {
        use crate::dialect::SqliteDialect;

        let generator = GenericWhereGenerator::new(SqliteDialect);
        let clause = field("ip", WhereOperator::IsIPv4, json!(true));

        let msg = generator.generate(&clause).unwrap_err().to_string();

        assert!(msg.contains("Inet") || msg.contains("not supported"), "Got: {msg}");
    }

    // ── LIKE metacharacter escaping (C3 fix verification) ──────────────

    #[test]
    fn escape_like_literal_escapes_percent_and_underscore() {
        let cases = [
            ("50%", "50\\%"),
            ("user_name", "user\\_name"),
            ("a%b_c\\d", "a\\%b\\_c\\\\d"),
            ("plain", "plain"),
        ];
        for (raw, escaped) in cases {
            assert_eq!(super::escape_like_literal(raw), escaped);
        }
    }

    #[test]
    fn contains_escapes_like_metacharacters() {
        let generator = pg();
        let clause = field("name", WhereOperator::Contains, json!("50%off"));

        let (_sql, params) = generator.generate(&clause).unwrap();

        // The bound parameter must carry an escaped `%` so it matches literally.
        assert_eq!(params[0], json!("50\\%off"));
    }

    #[test]
    fn startswith_escapes_like_metacharacters() {
        let generator = pg();
        let clause = field("name", WhereOperator::Startswith, json!("user_"));

        let (_sql, params) = generator.generate(&clause).unwrap();

        assert_eq!(params[0], json!("user\\_"));
    }

    #[test]
    fn endswith_escapes_like_metacharacters() {
        let generator = pg();
        let clause = field("name", WhereOperator::Endswith, json!("100%"));

        let (_sql, params) = generator.generate(&clause).unwrap();

        assert_eq!(params[0], json!("100\\%"));
    }

    // ── Regex complexity guard (C5 fix verification) ──────────────────

    #[test]
    fn regex_rejects_nested_quantifiers() {
        let generator = pg();
        let clause = field("name", WhereOperator::Regex, json!("(a+)+$"));

        let msg = generator.generate(&clause).unwrap_err().to_string();

        assert!(msg.contains("nested quantifiers"), "Got: {msg}");
    }

    #[test]
    fn regex_rejects_star_star_pattern() {
        let generator = pg();
        let clause = field("name", WhereOperator::Regex, json!("(x*)*"));

        let msg = generator.generate(&clause).unwrap_err().to_string();

        assert!(msg.contains("nested quantifiers"));
    }

    #[test]
    fn regex_rejects_too_long_pattern() {
        let generator = pg();
        let oversized = "a".repeat(1_001);
        let clause = field("name", WhereOperator::Regex, json!(oversized));

        let msg = generator.generate(&clause).unwrap_err().to_string();

        assert!(msg.contains("maximum length"));
    }

    #[test]
    fn regex_allows_safe_patterns() {
        let generator = pg();
        let clause = field("name", WhereOperator::Regex, json!("^[a-z]+$"));

        let outcome = generator.generate(&clause);

        assert!(outcome.is_ok());
    }

    #[test]
    fn iregex_also_validates_pattern() {
        let generator = pg();
        let clause = field("name", WhereOperator::Iregex, json!("(a+)+"));

        let outcome = generator.generate(&clause);

        assert!(outcome.is_err());
    }
}