1use std::{collections::HashSet, sync::Arc};
4
5use fraiseql_error::{FraiseQLError, Result};
6
7use super::counter::ParamCounter;
8use crate::{
9 dialect::SqlDialect,
10 where_clause::{WhereClause, WhereOperator},
11};
12
/// Escapes the LIKE/ILIKE metacharacters so `s` matches literally when it is
/// spliced into a LIKE pattern.
///
/// Every `\`, `%` and `_` in the input is prefixed with a single backslash;
/// all other characters are copied unchanged.
pub(crate) fn escape_like_literal(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
20
21const MAX_REGEX_PATTERN_LEN: usize = 1_000;
26
27fn validate_regex_pattern(pattern: &str) -> Result<()> {
35 if pattern.len() > MAX_REGEX_PATTERN_LEN {
36 return Err(FraiseQLError::Validation {
37 message: format!(
38 "Regex pattern exceeds maximum length of {MAX_REGEX_PATTERN_LEN} bytes"
39 ),
40 path: None,
41 });
42 }
43
44 let bytes = pattern.as_bytes();
48 let mut depth: i32 = 0;
49 let mut group_has_quantifier = Vec::new(); for (i, &b) in bytes.iter().enumerate() {
52 if i > 0 && bytes[i - 1] == b'\\' {
54 continue;
55 }
56 match b {
57 b'(' => {
58 depth += 1;
59 group_has_quantifier.push(false);
60 },
61 b')' => {
62 let had_quantifier = group_has_quantifier.pop().unwrap_or(false);
63 depth -= 1;
64 if had_quantifier {
66 let next = bytes.get(i + 1).copied();
67 if matches!(next, Some(b'+' | b'*' | b'?' | b'{')) {
68 return Err(FraiseQLError::Validation {
69 message: "Regex pattern contains nested quantifiers (potential \
70 ReDoS). Simplify the pattern to avoid `(…+)+`, \
71 `(…*)*`, or similar constructs."
72 .to_string(),
73 path: None,
74 });
75 }
76 }
77 },
78 b'+' | b'*' | b'?' => {
79 if let Some(flag) = group_has_quantifier.last_mut() {
80 *flag = true;
81 }
82 },
83 b'{' if depth > 0 => {
84 if let Some(flag) = group_has_quantifier.last_mut() {
85 *flag = true;
86 }
87 },
88 _ => {},
89 }
90 }
91
92 Ok(())
93}
94
/// Dialect-agnostic WHERE-clause SQL generator.
///
/// Walks a [`WhereClause`] tree and produces a SQL fragment together with the
/// ordered list of bind values, delegating all dialect-specific syntax
/// (quoting, placeholders, casts, operator SQL) to `D`.
pub struct GenericWhereGenerator<D: SqlDialect> {
    /// Supplies identifier quoting, placeholders, casts and operator SQL.
    dialect: D,
    /// Source of placeholder numbers; reset at the start of every `generate*` call.
    counter: ParamCounter,
    /// Optional set of flattened field paths (segments joined with `__`).
    /// Paths found here are emitted as quoted column identifiers instead of
    /// being extracted from the `data` JSON column.
    indexed_columns: Option<Arc<HashSet<String>>>,
}
132
133impl<D: SqlDialect> GenericWhereGenerator<D> {
134 pub const fn new(dialect: D) -> Self {
136 Self {
137 dialect,
138 counter: ParamCounter::new(),
139 indexed_columns: None,
140 }
141 }
142
143 #[must_use]
148 pub fn with_indexed_columns(mut self, cols: Arc<HashSet<String>>) -> Self {
149 self.indexed_columns = Some(cols);
150 self
151 }
152
153 pub fn generate(&self, clause: &WhereClause) -> Result<(String, Vec<serde_json::Value>)> {
160 self.generate_with_param_offset(clause, 0)
161 }
162
163 pub fn generate_with_param_offset(
173 &self,
174 clause: &WhereClause,
175 offset: usize,
176 ) -> Result<(String, Vec<serde_json::Value>)> {
177 self.counter.reset_to(offset);
178 let mut params = Vec::new();
179 let sql = self.visit(clause, &mut params)?;
180 Ok((sql, params))
181 }
182
183 fn visit(&self, clause: &WhereClause, params: &mut Vec<serde_json::Value>) -> Result<String> {
186 match clause {
187 WhereClause::And(clauses) => self.visit_and(clauses, params),
188 WhereClause::Or(clauses) => self.visit_or(clauses, params),
189 WhereClause::Not(inner) => Ok(format!("NOT ({})", self.visit(inner, params)?)),
190 WhereClause::Field {
191 path,
192 operator,
193 value,
194 } => self.visit_field(path, operator, value, params),
195 WhereClause::NativeField {
196 column,
197 pg_cast,
198 operator,
199 value,
200 } => self.visit_native_field(column, pg_cast, operator, value, params),
201 }
202 }
203
204 fn visit_native_field(
210 &self,
211 column: &str,
212 pg_cast: &str,
213 operator: &WhereOperator,
214 value: &serde_json::Value,
215 params: &mut Vec<serde_json::Value>,
216 ) -> Result<String> {
217 let col_expr = self.dialect.quote_identifier(column);
218 let p = self.push_param(params, value.clone());
219 let rhs = if pg_cast.is_empty() {
220 p
221 } else {
222 self.dialect.cast_native_param(&p, pg_cast)
223 };
224 match operator {
225 WhereOperator::Eq => Ok(format!("{col_expr} = {rhs}")),
226 WhereOperator::Neq => {
227 let neq = self.dialect.neq_operator();
228 Ok(format!("{col_expr} {neq} {rhs}"))
229 },
230 _ => Err(FraiseQLError::validation(format!(
231 "Operator {operator:?} is not supported for native column conditions"
232 ))),
233 }
234 }
235
236 fn visit_and(
237 &self,
238 clauses: &[WhereClause],
239 params: &mut Vec<serde_json::Value>,
240 ) -> Result<String> {
241 if clauses.is_empty() {
242 return Ok(self.dialect.always_true().to_string());
243 }
244 let parts: Result<Vec<_>> = clauses.iter().map(|c| self.visit(c, params)).collect();
245 Ok(format!("({})", parts?.join(" AND ")))
246 }
247
248 fn visit_or(
249 &self,
250 clauses: &[WhereClause],
251 params: &mut Vec<serde_json::Value>,
252 ) -> Result<String> {
253 if clauses.is_empty() {
254 return Ok(self.dialect.always_false().to_string());
255 }
256 let parts: Result<Vec<_>> = clauses.iter().map(|c| self.visit(c, params)).collect();
257 Ok(format!("({})", parts?.join(" OR ")))
258 }
259
260 fn resolve_field_expr(&self, path: &[String]) -> String {
263 if let Some(indexed) = &self.indexed_columns {
265 let col_name = path.join("__");
266 if indexed.contains(&col_name) {
267 return self.dialect.quote_identifier(&col_name);
268 }
269 }
270 self.dialect.json_extract_scalar("data", path)
271 }
272
273 fn push_param(&self, params: &mut Vec<serde_json::Value>, v: serde_json::Value) -> String {
276 params.push(v);
277 self.dialect.placeholder(self.counter.next())
278 }
279
280 fn visit_field(
283 &self,
284 path: &[String],
285 operator: &WhereOperator,
286 value: &serde_json::Value,
287 params: &mut Vec<serde_json::Value>,
288 ) -> Result<String> {
289 let field_expr = self.resolve_field_expr(path);
290
291 match operator {
292 WhereOperator::Eq => {
294 let p = self.push_param(params, value.clone());
295 if value.is_number() {
296 let cast = self.dialect.cast_to_numeric(&field_expr);
297 let rhs = self.dialect.cast_param_numeric(&p);
300 Ok(format!("{cast} = {rhs}"))
301 } else if value.is_boolean() {
302 let cast = self.dialect.cast_to_boolean(&field_expr);
303 Ok(format!("{cast} = {p}"))
304 } else {
305 Ok(format!("{field_expr} = {p}"))
306 }
307 },
308 WhereOperator::Neq => {
309 let p = self.push_param(params, value.clone());
310 let neq = self.dialect.neq_operator();
311 if value.is_number() {
312 let cast = self.dialect.cast_to_numeric(&field_expr);
313 let rhs = self.dialect.cast_param_numeric(&p);
314 Ok(format!("{cast} {neq} {rhs}"))
315 } else if value.is_boolean() {
316 let cast = self.dialect.cast_to_boolean(&field_expr);
317 Ok(format!("{cast} {neq} {p}"))
318 } else {
319 Ok(format!("{field_expr} {neq} {p}"))
320 }
321 },
322 WhereOperator::Gt | WhereOperator::Gte | WhereOperator::Lt | WhereOperator::Lte => {
323 let op = match operator {
324 WhereOperator::Gt => ">",
325 WhereOperator::Gte => ">=",
326 WhereOperator::Lt => "<",
327 _ => "<=",
328 };
329 let cast = self.dialect.cast_to_numeric(&field_expr);
330 let p = self.push_param(params, value.clone());
331 let rhs = self.dialect.cast_param_numeric(&p);
332 Ok(format!("{cast} {op} {rhs}"))
333 },
334
335 WhereOperator::In | WhereOperator::Nin => {
337 let arr = value.as_array().ok_or_else(|| {
338 FraiseQLError::validation("IN operator requires an array value".to_string())
339 })?;
340 if arr.is_empty() {
341 return Ok(if matches!(operator, WhereOperator::In) {
342 self.dialect.always_false().to_string()
343 } else {
344 self.dialect.always_true().to_string()
345 });
346 }
347 let placeholders: Vec<_> =
348 arr.iter().map(|v| self.push_param(params, v.clone())).collect();
349 let in_list = placeholders.join(", ");
350 let sql = format!("{field_expr} IN ({in_list})");
351 Ok(if matches!(operator, WhereOperator::Nin) {
352 format!("NOT ({sql})")
353 } else {
354 sql
355 })
356 },
357
358 WhereOperator::IsNull => {
360 let is_null = value.as_bool().unwrap_or(true);
361 let null_op = if is_null { "IS NULL" } else { "IS NOT NULL" };
362 Ok(format!("{field_expr} {null_op}"))
363 },
364
365 WhereOperator::Contains => {
367 let val_str = self.require_str(value, "Contains")?;
368 let escaped = escape_like_literal(val_str);
369 let p = self.push_param(params, serde_json::Value::String(escaped));
370 let pattern = self.dialect.concat_sql(&["'%'", &p, "'%'"]);
371 Ok(self.dialect.like_sql(&field_expr, &pattern))
372 },
373 WhereOperator::Icontains => {
374 let val_str = self.require_str(value, "Icontains")?;
375 let escaped = escape_like_literal(val_str);
376 let p = self.push_param(params, serde_json::Value::String(escaped));
377 let pattern = self.dialect.concat_sql(&["'%'", &p, "'%'"]);
378 Ok(self.dialect.ilike_sql(&field_expr, &pattern))
379 },
380 WhereOperator::Startswith => {
381 let val_str = self.require_str(value, "Startswith")?;
382 let escaped = escape_like_literal(val_str);
383 let p = self.push_param(params, serde_json::Value::String(escaped));
384 let pattern = self.dialect.concat_sql(&[&p, "'%'"]);
385 Ok(self.dialect.like_sql(&field_expr, &pattern))
386 },
387 WhereOperator::Istartswith => {
388 let val_str = self.require_str(value, "Istartswith")?;
389 let escaped = escape_like_literal(val_str);
390 let p = self.push_param(params, serde_json::Value::String(escaped));
391 let pattern = self.dialect.concat_sql(&[&p, "'%'"]);
392 Ok(self.dialect.ilike_sql(&field_expr, &pattern))
393 },
394 WhereOperator::Endswith => {
395 let val_str = self.require_str(value, "Endswith")?;
396 let escaped = escape_like_literal(val_str);
397 let p = self.push_param(params, serde_json::Value::String(escaped));
398 let pattern = self.dialect.concat_sql(&["'%'", &p]);
399 Ok(self.dialect.like_sql(&field_expr, &pattern))
400 },
401 WhereOperator::Iendswith => {
402 let val_str = self.require_str(value, "Iendswith")?;
403 let escaped = escape_like_literal(val_str);
404 let p = self.push_param(params, serde_json::Value::String(escaped));
405 let pattern = self.dialect.concat_sql(&["'%'", &p]);
406 Ok(self.dialect.ilike_sql(&field_expr, &pattern))
407 },
408 WhereOperator::Like => {
409 let p = self.push_param(params, value.clone());
410 Ok(self.dialect.like_sql(&field_expr, &p))
411 },
412 WhereOperator::Ilike => {
413 let p = self.push_param(params, value.clone());
414 Ok(self.dialect.ilike_sql(&field_expr, &p))
415 },
416 WhereOperator::Nlike => {
417 let p = self.push_param(params, value.clone());
418 Ok(format!("NOT ({})", self.dialect.like_sql(&field_expr, &p)))
419 },
420 WhereOperator::Nilike => {
421 let p = self.push_param(params, value.clone());
422 Ok(format!("NOT ({})", self.dialect.ilike_sql(&field_expr, &p)))
423 },
424
425 WhereOperator::Regex => {
427 if let Some(s) = value.as_str() {
428 validate_regex_pattern(s)?;
429 }
430 let p = self.push_param(params, value.clone());
431 self.dialect
432 .regex_sql(&field_expr, &p, false, false)
433 .map_err(|e| FraiseQLError::validation(e.to_string()))
434 },
435 WhereOperator::Iregex => {
436 if let Some(s) = value.as_str() {
437 validate_regex_pattern(s)?;
438 }
439 let p = self.push_param(params, value.clone());
440 self.dialect
441 .regex_sql(&field_expr, &p, true, false)
442 .map_err(|e| FraiseQLError::validation(e.to_string()))
443 },
444 WhereOperator::Nregex => {
445 if let Some(s) = value.as_str() {
446 validate_regex_pattern(s)?;
447 }
448 let p = self.push_param(params, value.clone());
449 self.dialect
450 .regex_sql(&field_expr, &p, false, true)
451 .map_err(|e| FraiseQLError::validation(e.to_string()))
452 },
453 WhereOperator::Niregex => {
454 if let Some(s) = value.as_str() {
455 validate_regex_pattern(s)?;
456 }
457 let p = self.push_param(params, value.clone());
458 self.dialect
459 .regex_sql(&field_expr, &p, true, true)
460 .map_err(|e| FraiseQLError::validation(e.to_string()))
461 },
462
463 WhereOperator::LenEq
465 | WhereOperator::LenNeq
466 | WhereOperator::LenGt
467 | WhereOperator::LenGte
468 | WhereOperator::LenLt
469 | WhereOperator::LenLte => {
470 let op = match operator {
471 WhereOperator::LenEq => "=",
472 WhereOperator::LenNeq => self.dialect.neq_operator(),
473 WhereOperator::LenGt => ">",
474 WhereOperator::LenGte => ">=",
475 WhereOperator::LenLt => "<",
476 _ => "<=",
477 };
478 let len_expr = self.dialect.json_array_length(&field_expr);
479 let p = self.push_param(params, value.clone());
480 Ok(format!("{len_expr} {op} {p}"))
481 },
482
483 WhereOperator::ArrayContains | WhereOperator::StrictlyContains => {
485 let p = self.push_param(params, value.clone());
488 self.dialect
489 .array_contains_sql(&field_expr, &p)
490 .map_err(|e| FraiseQLError::validation(e.to_string()))
491 },
492 WhereOperator::ArrayContainedBy => {
493 let p = self.push_param(params, value.clone());
494 self.dialect
495 .array_contained_by_sql(&field_expr, &p)
496 .map_err(|e| FraiseQLError::validation(e.to_string()))
497 },
498 WhereOperator::ArrayOverlaps => {
499 let p = self.push_param(params, value.clone());
500 self.dialect
501 .array_overlaps_sql(&field_expr, &p)
502 .map_err(|e| FraiseQLError::validation(e.to_string()))
503 },
504
505 WhereOperator::Matches => {
507 let p = self.push_param(params, value.clone());
508 self.dialect
509 .fts_matches_sql(&field_expr, &p)
510 .map_err(|e| FraiseQLError::validation(e.to_string()))
511 },
512 WhereOperator::PlainQuery => {
513 let p = self.push_param(params, value.clone());
514 self.dialect
515 .fts_plain_query_sql(&field_expr, &p)
516 .map_err(|e| FraiseQLError::validation(e.to_string()))
517 },
518 WhereOperator::PhraseQuery => {
519 let p = self.push_param(params, value.clone());
520 self.dialect
521 .fts_phrase_query_sql(&field_expr, &p)
522 .map_err(|e| FraiseQLError::validation(e.to_string()))
523 },
524 WhereOperator::WebsearchQuery => {
525 let p = self.push_param(params, value.clone());
526 self.dialect
527 .fts_websearch_query_sql(&field_expr, &p)
528 .map_err(|e| FraiseQLError::validation(e.to_string()))
529 },
530
531 WhereOperator::CosineDistance => {
533 let p = self.push_param(params, value.clone());
534 self.dialect
535 .vector_distance_sql("<=>", &field_expr, &p)
536 .map_err(|e| FraiseQLError::validation(e.to_string()))
537 },
538 WhereOperator::L2Distance => {
539 let p = self.push_param(params, value.clone());
540 self.dialect
541 .vector_distance_sql("<->", &field_expr, &p)
542 .map_err(|e| FraiseQLError::validation(e.to_string()))
543 },
544 WhereOperator::L1Distance => {
545 let p = self.push_param(params, value.clone());
546 self.dialect
547 .vector_distance_sql("<+>", &field_expr, &p)
548 .map_err(|e| FraiseQLError::validation(e.to_string()))
549 },
550 WhereOperator::HammingDistance => {
551 let p = self.push_param(params, value.clone());
552 self.dialect
553 .vector_distance_sql("<~>", &field_expr, &p)
554 .map_err(|e| FraiseQLError::validation(e.to_string()))
555 },
556 WhereOperator::InnerProduct => {
557 let p = self.push_param(params, value.clone());
558 self.dialect
559 .vector_distance_sql("<#>", &field_expr, &p)
560 .map_err(|e| FraiseQLError::validation(e.to_string()))
561 },
562 WhereOperator::JaccardDistance => {
563 let p = self.push_param(params, value.clone());
564 self.dialect
565 .jaccard_distance_sql(&field_expr, &p)
566 .map_err(|e| FraiseQLError::validation(e.to_string()))
567 },
568
569 WhereOperator::IsIPv4 => self
571 .dialect
572 .inet_check_sql(&field_expr, "IsIPv4")
573 .map_err(|e| FraiseQLError::validation(e.to_string())),
574 WhereOperator::IsIPv6 => self
575 .dialect
576 .inet_check_sql(&field_expr, "IsIPv6")
577 .map_err(|e| FraiseQLError::validation(e.to_string())),
578 WhereOperator::IsPrivate => self
579 .dialect
580 .inet_check_sql(&field_expr, "IsPrivate")
581 .map_err(|e| FraiseQLError::validation(e.to_string())),
582 WhereOperator::IsPublic => self
583 .dialect
584 .inet_check_sql(&field_expr, "IsPublic")
585 .map_err(|e| FraiseQLError::validation(e.to_string())),
586 WhereOperator::IsLoopback => self
587 .dialect
588 .inet_check_sql(&field_expr, "IsLoopback")
589 .map_err(|e| FraiseQLError::validation(e.to_string())),
590 WhereOperator::InSubnet => {
591 let p = self.push_param(params, value.clone());
592 self.dialect
593 .inet_binary_sql("<<", &field_expr, &p)
594 .map_err(|e| FraiseQLError::validation(e.to_string()))
595 },
596 WhereOperator::ContainsSubnet | WhereOperator::ContainsIP => {
597 let p = self.push_param(params, value.clone());
598 self.dialect
599 .inet_binary_sql(">>", &field_expr, &p)
600 .map_err(|e| FraiseQLError::validation(e.to_string()))
601 },
602 WhereOperator::Overlaps => {
603 let p = self.push_param(params, value.clone());
604 self.dialect
605 .inet_binary_sql("&&", &field_expr, &p)
606 .map_err(|e| FraiseQLError::validation(e.to_string()))
607 },
608
609 WhereOperator::AncestorOf => {
611 let p = self.push_param(params, value.clone());
612 self.dialect
613 .ltree_binary_sql("@>", &field_expr, &p, "ltree")
614 .map_err(|e| FraiseQLError::validation(e.to_string()))
615 },
616 WhereOperator::DescendantOf => {
617 let p = self.push_param(params, value.clone());
618 self.dialect
619 .ltree_binary_sql("<@", &field_expr, &p, "ltree")
620 .map_err(|e| FraiseQLError::validation(e.to_string()))
621 },
622 WhereOperator::MatchesLquery => {
623 let p = self.push_param(params, value.clone());
624 self.dialect
625 .ltree_binary_sql("~", &field_expr, &p, "lquery")
626 .map_err(|e| FraiseQLError::validation(e.to_string()))
627 },
628 WhereOperator::MatchesLtxtquery => {
629 let p = self.push_param(params, value.clone());
630 self.dialect
631 .ltree_binary_sql("@", &field_expr, &p, "ltxtquery")
632 .map_err(|e| FraiseQLError::validation(e.to_string()))
633 },
634 WhereOperator::MatchesAnyLquery => {
635 let arr = value.as_array().ok_or_else(|| {
636 FraiseQLError::validation(
637 "matches_any_lquery operator requires an array value".to_string(),
638 )
639 })?;
640 if arr.is_empty() {
641 return Err(FraiseQLError::validation(
642 "matches_any_lquery requires at least one lquery".to_string(),
643 ));
644 }
645 let placeholders: Vec<_> = arr
646 .iter()
647 .map(|v| format!("{}::lquery", self.push_param(params, v.clone())))
648 .collect();
649 self.dialect
650 .ltree_any_lquery_sql(&field_expr, &placeholders)
651 .map_err(|e| FraiseQLError::validation(e.to_string()))
652 },
653 WhereOperator::DepthEq => {
654 let p = self.push_param(params, value.clone());
655 self.dialect
656 .ltree_depth_sql("=", &field_expr, &p)
657 .map_err(|e| FraiseQLError::validation(e.to_string()))
658 },
659 WhereOperator::DepthNeq => {
660 let p = self.push_param(params, value.clone());
661 self.dialect
662 .ltree_depth_sql("!=", &field_expr, &p)
663 .map_err(|e| FraiseQLError::validation(e.to_string()))
664 },
665 WhereOperator::DepthGt => {
666 let p = self.push_param(params, value.clone());
667 self.dialect
668 .ltree_depth_sql(">", &field_expr, &p)
669 .map_err(|e| FraiseQLError::validation(e.to_string()))
670 },
671 WhereOperator::DepthGte => {
672 let p = self.push_param(params, value.clone());
673 self.dialect
674 .ltree_depth_sql(">=", &field_expr, &p)
675 .map_err(|e| FraiseQLError::validation(e.to_string()))
676 },
677 WhereOperator::DepthLt => {
678 let p = self.push_param(params, value.clone());
679 self.dialect
680 .ltree_depth_sql("<", &field_expr, &p)
681 .map_err(|e| FraiseQLError::validation(e.to_string()))
682 },
683 WhereOperator::DepthLte => {
684 let p = self.push_param(params, value.clone());
685 self.dialect
686 .ltree_depth_sql("<=", &field_expr, &p)
687 .map_err(|e| FraiseQLError::validation(e.to_string()))
688 },
689 WhereOperator::Lca => {
690 let arr = value.as_array().ok_or_else(|| {
691 FraiseQLError::validation("lca operator requires an array value".to_string())
692 })?;
693 if arr.is_empty() {
694 return Err(FraiseQLError::validation(
695 "lca operator requires at least one path".to_string(),
696 ));
697 }
698 let placeholders: Vec<_> = arr
699 .iter()
700 .map(|v| format!("{}::ltree", self.push_param(params, v.clone())))
701 .collect();
702 self.dialect
703 .ltree_lca_sql(&field_expr, &placeholders)
704 .map_err(|e| FraiseQLError::validation(e.to_string()))
705 },
706
707 WhereOperator::Extended(op) => {
709 self.dialect.generate_extended_sql(op, &field_expr, params)
710 },
711
712 #[allow(unreachable_patterns)]
717 _ => Err(FraiseQLError::Validation {
719 message: format!(
720 "Operator {operator:?} is not supported by the {} dialect",
721 self.dialect.name()
722 ),
723 path: None,
724 }),
725 }
726 }
727
728 fn require_str<'a>(&self, value: &'a serde_json::Value, op: &'static str) -> Result<&'a str> {
729 value.as_str().ok_or_else(|| {
730 FraiseQLError::validation(format!("{op} operator requires a string value"))
731 })
732 }
733}
734
735impl<D: SqlDialect + Default> Default for GenericWhereGenerator<D> {
738 fn default() -> Self {
739 Self::new(D::default())
740 }
741}
742
743impl<D: SqlDialect> crate::filters::ExtendedOperatorHandler for GenericWhereGenerator<D> {
747 fn generate_extended_sql(
748 &self,
749 operator: &crate::filters::ExtendedOperator,
750 field_sql: &str,
751 params: &mut Vec<serde_json::Value>,
752 ) -> Result<String> {
753 self.dialect.generate_extended_sql(operator, field_sql, params)
754 }
755}
756
#[cfg(test)]
// Tests unwrap freely: a panic is the intended failure signal.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::{collections::HashSet, sync::Arc};

    use serde_json::json;

    use super::GenericWhereGenerator;
    use crate::{
        dialect::PostgresDialect,
        where_clause::{WhereClause, WhereOperator},
    };

    /// Builds a condition on a single-segment JSON field path.
    fn field(path: &str, op: WhereOperator, val: serde_json::Value) -> WhereClause {
        WhereClause::Field {
            path: vec![path.to_string()],
            operator: op,
            value: val,
        }
    }

    // --- Basic comparisons and boolean composition (PostgreSQL) ---
    // Local generator bindings avoid `gen`, which is a reserved keyword in Rust 2024.

    #[test]
    fn generic_eq_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("email", WhereOperator::Eq, json!("alice@example.com"));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "data->>'email' = $1");
        assert_eq!(params, vec![json!("alice@example.com")]);
    }

    #[test]
    fn generic_and_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = WhereClause::And(vec![
            field("status", WhereOperator::Eq, json!("active")),
            field("age", WhereOperator::Gte, json!(18)),
        ]);
        let (sql, params) = generator.generate(&clause).unwrap();
        assert!(sql.starts_with("(data->>'status' = $1 AND"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn generic_empty_and_returns_true() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = WhereClause::And(vec![]);
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "TRUE");
        assert!(params.is_empty());
    }

    #[test]
    fn generic_empty_or_returns_false() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = WhereClause::Or(vec![]);
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "FALSE");
        assert!(params.is_empty());
    }

    #[test]
    fn generic_not_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = WhereClause::Not(Box::new(field("deleted", WhereOperator::Eq, json!(true))));
        let (sql, _) = generator.generate(&clause).unwrap();
        assert!(sql.starts_with("NOT ("));
    }

    #[test]
    fn generic_is_null_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let (sql, params) = generator
            .generate(&field("deleted_at", WhereOperator::IsNull, json!(true)))
            .unwrap();
        assert_eq!(sql, "data->>'deleted_at' IS NULL");
        assert!(params.is_empty());

        let (sql, params) = generator
            .generate(&field("deleted_at", WhereOperator::IsNull, json!(false)))
            .unwrap();
        assert_eq!(sql, "data->>'deleted_at' IS NOT NULL");
        assert!(params.is_empty());
    }

    // --- Placeholder numbering ---

    #[test]
    fn generate_resets_counter() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("x", WhereOperator::Eq, json!(1));
        let (sql1, _) = generator.generate(&clause).unwrap();
        let (sql2, _) = generator.generate(&clause).unwrap();
        assert_eq!(sql1, sql2);
        assert!(sql1.contains("$1"));
        // A second call must not continue numbering from the first.
        assert!(!sql1.contains("$2"));
    }

    #[test]
    fn generate_with_param_offset() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("email", WhereOperator::Eq, json!("a@b.com"));
        let (sql, _) = generator.generate_with_param_offset(&clause, 2).unwrap();
        assert!(sql.contains("$3"), "Expected $3 (offset 2 + 1), got: {sql}");
    }

    #[test]
    fn generate_with_param_offset_numbers_every_placeholder() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = WhereClause::And(vec![
            field("a", WhereOperator::Eq, json!("x")),
            field("b", WhereOperator::Eq, json!("y")),
        ]);
        let (sql, params) = generator.generate_with_param_offset(&clause, 2).unwrap();
        assert_eq!(sql, "(data->>'a' = $3 AND data->>'b' = $4)");
        assert_eq!(params, vec![json!("x"), json!("y")]);
    }

    #[test]
    fn indexed_column_bypasses_json_extraction() {
        let cols: HashSet<String> = ["status".to_string()].into_iter().collect();
        let generator =
            GenericWhereGenerator::new(PostgresDialect).with_indexed_columns(Arc::new(cols));
        let clause = field("status", WhereOperator::Eq, json!("active"));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert!(!sql.contains("data->>"), "Expected a direct column reference, got: {sql}");
        assert!(sql.contains("status") && sql.contains("$1"), "Got: {sql}");
        assert_eq!(params, vec![json!("active")]);
    }

    // --- String pattern operators ---

    #[test]
    fn generic_icontains_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("email", WhereOperator::Icontains, json!("example.com"));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "data->>'email' ILIKE '%' || $1 || '%'");
        assert_eq!(params, vec![json!("example.com")]);
    }

    #[test]
    fn generic_startswith_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Startswith, json!("Al"));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "data->>'name' LIKE $1 || '%'");
        assert_eq!(params, vec![json!("Al")]);
    }

    #[test]
    fn generic_endswith_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Endswith, json!("son"));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "data->>'name' LIKE '%' || $1");
        assert_eq!(params, vec![json!("son")]);
    }

    #[test]
    fn contains_requires_string_value() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Contains, json!(5));
        let err = generator.generate(&clause).unwrap_err();
        assert!(
            err.to_string().contains("Contains operator requires a string value"),
            "Got: {err}"
        );
    }

    // --- IN / NIN ---

    #[test]
    fn generic_in_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("status", WhereOperator::In, json!(["active", "pending"]));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "data->>'status' IN ($1, $2)");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn generic_nin_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("status", WhereOperator::Nin, json!(["active", "pending"]));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "NOT (data->>'status' IN ($1, $2))");
        assert_eq!(params, vec![json!("active"), json!("pending")]);
    }

    #[test]
    fn generic_in_empty_returns_false() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("status", WhereOperator::In, json!([]));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "FALSE");
        assert!(params.is_empty());
    }

    #[test]
    fn generic_nin_empty_returns_true() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("status", WhereOperator::Nin, json!([]));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "TRUE");
        assert!(params.is_empty());
    }

    #[test]
    fn in_requires_array_value() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("status", WhereOperator::In, json!("active"));
        let err = generator.generate(&clause).unwrap_err();
        assert!(err.to_string().contains("requires an array value"), "Got: {err}");
    }

    // --- Injection safety ---

    #[test]
    fn no_value_in_sql_string() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let injection = "'; DROP TABLE users; --";
        let clause = field("email", WhereOperator::Eq, json!(injection));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert!(!sql.contains(injection), "Value must not appear in SQL: {sql}");
        assert_eq!(params[0], json!(injection));
    }

    // --- PostgreSQL-specific operators and non-PG rejection ---

    #[test]
    fn generic_pg_cosine_distance() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("embedding", WhereOperator::CosineDistance, json!([0.1, 0.2]));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert!(sql.contains("<=>"), "Expected <=> operator, got: {sql}");
        assert!(sql.contains("::vector"), "Expected ::vector cast, got: {sql}");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn generic_pg_network_ipv4() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("ip", WhereOperator::IsIPv4, json!(true));
        let (sql, _) = generator.generate(&clause).unwrap();
        assert!(sql.contains("family("), "Expected family() call, got: {sql}");
        assert!(sql.contains("= 4"), "Expected = 4, got: {sql}");
    }

    #[test]
    fn generic_pg_ltree_ancestor_of() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("path", WhereOperator::AncestorOf, json!("europe.france"));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert!(sql.contains("@>") && sql.contains("ltree"), "Got: {sql}");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn non_pg_vector_op_returns_error() {
        use crate::dialect::MySqlDialect;
        let generator = GenericWhereGenerator::new(MySqlDialect);
        let clause = field("embedding", WhereOperator::CosineDistance, json!([0.1]));
        let err = generator.generate(&clause).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("VectorDistance") || msg.contains("not supported"), "Got: {msg}");
    }

    #[test]
    fn non_pg_network_op_returns_error() {
        use crate::dialect::SqliteDialect;
        let generator = GenericWhereGenerator::new(SqliteDialect);
        let clause = field("ip", WhereOperator::IsIPv4, json!(true));
        let err = generator.generate(&clause).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Inet") || msg.contains("not supported"), "Got: {msg}");
    }

    // --- LIKE metacharacter escaping ---

    #[test]
    fn escape_like_literal_escapes_percent_and_underscore() {
        assert_eq!(super::escape_like_literal("50%"), "50\\%");
        assert_eq!(super::escape_like_literal("user_name"), "user\\_name");
        assert_eq!(super::escape_like_literal("a%b_c\\d"), "a\\%b\\_c\\\\d");
        assert_eq!(super::escape_like_literal("plain"), "plain");
    }

    #[test]
    fn contains_escapes_like_metacharacters() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Contains, json!("50%off"));
        let (_sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(params[0], json!("50\\%off"));
    }

    #[test]
    fn startswith_escapes_like_metacharacters() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Startswith, json!("user_"));
        let (_sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(params[0], json!("user\\_"));
    }

    #[test]
    fn endswith_escapes_like_metacharacters() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Endswith, json!("100%"));
        let (_sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(params[0], json!("100\\%"));
    }

    // --- Regex ReDoS guard ---

    #[test]
    fn regex_rejects_nested_quantifiers() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Regex, json!("(a+)+$"));
        let err = generator.generate(&clause).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("nested quantifiers"), "Got: {msg}");
    }

    #[test]
    fn regex_rejects_star_star_pattern() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Regex, json!("(x*)*"));
        let err = generator.generate(&clause).unwrap_err();
        assert!(err.to_string().contains("nested quantifiers"));
    }

    #[test]
    fn regex_rejects_too_long_pattern() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let long_pattern = "a".repeat(1_001);
        let clause = field("name", WhereOperator::Regex, json!(long_pattern));
        let err = generator.generate(&clause).unwrap_err();
        assert!(err.to_string().contains("maximum length"));
    }

    #[test]
    fn regex_allows_safe_patterns() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Regex, json!("^[a-z]+$"));
        assert!(generator.generate(&clause).is_ok());
    }

    #[test]
    fn regex_allows_escaped_literal_parens() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        // `\(a+\)+` quantifies literal parentheses, not a group.
        let clause = field("name", WhereOperator::Regex, json!("\\(a+\\)+"));
        assert!(generator.generate(&clause).is_ok());
    }

    #[test]
    fn iregex_also_validates_pattern() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Iregex, json!("(a+)+"));
        assert!(generator.generate(&clause).is_err());
    }
}