1use std::{collections::HashSet, sync::Arc};
4
5use fraiseql_error::{FraiseQLError, Result};
6
7use super::counter::ParamCounter;
8use crate::{
9 dialect::SqlDialect,
10 where_clause::{WhereClause, WhereOperator},
11};
12
/// Escapes `s` so it matches literally inside a backslash-escaped `LIKE` pattern.
///
/// Each `\`, `%` and `_` is prefixed with a backslash; every other character is
/// copied unchanged. The result is meant to be bound as a parameter and wrapped
/// with `%` wildcards by the caller.
pub(crate) fn escape_like_literal(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        // Backslash is the escape character itself, so it must be doubled too.
        if matches!(ch, '\\' | '%' | '_') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
20
21const MAX_REGEX_PATTERN_LEN: usize = 1_000;
26
27fn validate_regex_pattern(pattern: &str) -> Result<()> {
35 if pattern.len() > MAX_REGEX_PATTERN_LEN {
36 return Err(FraiseQLError::Validation {
37 message: format!(
38 "Regex pattern exceeds maximum length of {MAX_REGEX_PATTERN_LEN} bytes"
39 ),
40 path: None,
41 });
42 }
43
44 let bytes = pattern.as_bytes();
48 let mut depth: i32 = 0;
49 let mut group_has_quantifier = Vec::new(); for (i, &b) in bytes.iter().enumerate() {
52 if i > 0 && bytes[i - 1] == b'\\' {
54 continue;
55 }
56 match b {
57 b'(' => {
58 depth += 1;
59 group_has_quantifier.push(false);
60 },
61 b')' => {
62 let had_quantifier = group_has_quantifier.pop().unwrap_or(false);
63 depth -= 1;
64 if had_quantifier {
66 let next = bytes.get(i + 1).copied();
67 if matches!(next, Some(b'+' | b'*' | b'?' | b'{')) {
68 return Err(FraiseQLError::Validation {
69 message: "Regex pattern contains nested quantifiers (potential \
70 ReDoS). Simplify the pattern to avoid `(…+)+`, \
71 `(…*)*`, or similar constructs."
72 .to_string(),
73 path: None,
74 });
75 }
76 }
77 },
78 b'+' | b'*' | b'?' => {
79 if let Some(flag) = group_has_quantifier.last_mut() {
80 *flag = true;
81 }
82 },
83 b'{' if depth > 0 => {
84 if let Some(flag) = group_has_quantifier.last_mut() {
85 *flag = true;
86 }
87 },
88 _ => {},
89 }
90 }
91
92 Ok(())
93}
94
95pub struct GenericWhereGenerator<D: SqlDialect> {
126 dialect: D,
127 counter: ParamCounter,
128 indexed_columns: Option<Arc<HashSet<String>>>,
131}
132
133impl<D: SqlDialect> GenericWhereGenerator<D> {
134 pub const fn new(dialect: D) -> Self {
136 Self {
137 dialect,
138 counter: ParamCounter::new(),
139 indexed_columns: None,
140 }
141 }
142
143 #[must_use]
148 pub fn with_indexed_columns(mut self, cols: Arc<HashSet<String>>) -> Self {
149 self.indexed_columns = Some(cols);
150 self
151 }
152
153 pub fn generate(&self, clause: &WhereClause) -> Result<(String, Vec<serde_json::Value>)> {
160 self.generate_with_param_offset(clause, 0)
161 }
162
163 pub fn generate_with_param_offset(
173 &self,
174 clause: &WhereClause,
175 offset: usize,
176 ) -> Result<(String, Vec<serde_json::Value>)> {
177 self.counter.reset_to(offset);
178 let mut params = Vec::new();
179 let sql = self.visit(clause, &mut params)?;
180 Ok((sql, params))
181 }
182
183 fn visit(&self, clause: &WhereClause, params: &mut Vec<serde_json::Value>) -> Result<String> {
186 match clause {
187 WhereClause::And(clauses) => self.visit_and(clauses, params),
188 WhereClause::Or(clauses) => self.visit_or(clauses, params),
189 WhereClause::Not(inner) => Ok(format!("NOT ({})", self.visit(inner, params)?)),
190 WhereClause::Field {
191 path,
192 operator,
193 value,
194 } => self.visit_field(path, operator, value, params),
195 }
196 }
197
198 fn visit_and(
199 &self,
200 clauses: &[WhereClause],
201 params: &mut Vec<serde_json::Value>,
202 ) -> Result<String> {
203 if clauses.is_empty() {
204 return Ok(self.dialect.always_true().to_string());
205 }
206 let parts: Result<Vec<_>> = clauses.iter().map(|c| self.visit(c, params)).collect();
207 Ok(format!("({})", parts?.join(" AND ")))
208 }
209
210 fn visit_or(
211 &self,
212 clauses: &[WhereClause],
213 params: &mut Vec<serde_json::Value>,
214 ) -> Result<String> {
215 if clauses.is_empty() {
216 return Ok(self.dialect.always_false().to_string());
217 }
218 let parts: Result<Vec<_>> = clauses.iter().map(|c| self.visit(c, params)).collect();
219 Ok(format!("({})", parts?.join(" OR ")))
220 }
221
222 fn resolve_field_expr(&self, path: &[String]) -> String {
225 if let Some(indexed) = &self.indexed_columns {
227 let col_name = path.join("__");
228 if indexed.contains(&col_name) {
229 return self.dialect.quote_identifier(&col_name);
230 }
231 }
232 self.dialect.json_extract_scalar("data", path)
233 }
234
235 fn push_param(&self, params: &mut Vec<serde_json::Value>, v: serde_json::Value) -> String {
238 params.push(v);
239 self.dialect.placeholder(self.counter.next())
240 }
241
242 fn visit_field(
245 &self,
246 path: &[String],
247 operator: &WhereOperator,
248 value: &serde_json::Value,
249 params: &mut Vec<serde_json::Value>,
250 ) -> Result<String> {
251 let field_expr = self.resolve_field_expr(path);
252
253 match operator {
254 WhereOperator::Eq => {
256 let p = self.push_param(params, value.clone());
257 if value.is_number() {
258 let cast = self.dialect.cast_to_numeric(&field_expr);
259 let rhs = self.dialect.cast_param_numeric(&p);
262 Ok(format!("{cast} = {rhs}"))
263 } else if value.is_boolean() {
264 let cast = self.dialect.cast_to_boolean(&field_expr);
265 Ok(format!("{cast} = {p}"))
266 } else {
267 Ok(format!("{field_expr} = {p}"))
268 }
269 },
270 WhereOperator::Neq => {
271 let p = self.push_param(params, value.clone());
272 let neq = self.dialect.neq_operator();
273 if value.is_number() {
274 let cast = self.dialect.cast_to_numeric(&field_expr);
275 let rhs = self.dialect.cast_param_numeric(&p);
276 Ok(format!("{cast} {neq} {rhs}"))
277 } else if value.is_boolean() {
278 let cast = self.dialect.cast_to_boolean(&field_expr);
279 Ok(format!("{cast} {neq} {p}"))
280 } else {
281 Ok(format!("{field_expr} {neq} {p}"))
282 }
283 },
284 WhereOperator::Gt | WhereOperator::Gte | WhereOperator::Lt | WhereOperator::Lte => {
285 let op = match operator {
286 WhereOperator::Gt => ">",
287 WhereOperator::Gte => ">=",
288 WhereOperator::Lt => "<",
289 _ => "<=",
290 };
291 let cast = self.dialect.cast_to_numeric(&field_expr);
292 let p = self.push_param(params, value.clone());
293 let rhs = self.dialect.cast_param_numeric(&p);
294 Ok(format!("{cast} {op} {rhs}"))
295 },
296
297 WhereOperator::In | WhereOperator::Nin => {
299 let arr = value.as_array().ok_or_else(|| {
300 FraiseQLError::validation("IN operator requires an array value".to_string())
301 })?;
302 if arr.is_empty() {
303 return Ok(if matches!(operator, WhereOperator::In) {
304 self.dialect.always_false().to_string()
305 } else {
306 self.dialect.always_true().to_string()
307 });
308 }
309 let placeholders: Vec<_> =
310 arr.iter().map(|v| self.push_param(params, v.clone())).collect();
311 let in_list = placeholders.join(", ");
312 let sql = format!("{field_expr} IN ({in_list})");
313 Ok(if matches!(operator, WhereOperator::Nin) {
314 format!("NOT ({sql})")
315 } else {
316 sql
317 })
318 },
319
320 WhereOperator::IsNull => {
322 let is_null = value.as_bool().unwrap_or(true);
323 let null_op = if is_null { "IS NULL" } else { "IS NOT NULL" };
324 Ok(format!("{field_expr} {null_op}"))
325 },
326
327 WhereOperator::Contains => {
329 let val_str = self.require_str(value, "Contains")?;
330 let escaped = escape_like_literal(val_str);
331 let p = self.push_param(params, serde_json::Value::String(escaped));
332 let pattern = self.dialect.concat_sql(&["'%'", &p, "'%'"]);
333 Ok(self.dialect.like_sql(&field_expr, &pattern))
334 },
335 WhereOperator::Icontains => {
336 let val_str = self.require_str(value, "Icontains")?;
337 let escaped = escape_like_literal(val_str);
338 let p = self.push_param(params, serde_json::Value::String(escaped));
339 let pattern = self.dialect.concat_sql(&["'%'", &p, "'%'"]);
340 Ok(self.dialect.ilike_sql(&field_expr, &pattern))
341 },
342 WhereOperator::Startswith => {
343 let val_str = self.require_str(value, "Startswith")?;
344 let escaped = escape_like_literal(val_str);
345 let p = self.push_param(params, serde_json::Value::String(escaped));
346 let pattern = self.dialect.concat_sql(&[&p, "'%'"]);
347 Ok(self.dialect.like_sql(&field_expr, &pattern))
348 },
349 WhereOperator::Istartswith => {
350 let val_str = self.require_str(value, "Istartswith")?;
351 let escaped = escape_like_literal(val_str);
352 let p = self.push_param(params, serde_json::Value::String(escaped));
353 let pattern = self.dialect.concat_sql(&[&p, "'%'"]);
354 Ok(self.dialect.ilike_sql(&field_expr, &pattern))
355 },
356 WhereOperator::Endswith => {
357 let val_str = self.require_str(value, "Endswith")?;
358 let escaped = escape_like_literal(val_str);
359 let p = self.push_param(params, serde_json::Value::String(escaped));
360 let pattern = self.dialect.concat_sql(&["'%'", &p]);
361 Ok(self.dialect.like_sql(&field_expr, &pattern))
362 },
363 WhereOperator::Iendswith => {
364 let val_str = self.require_str(value, "Iendswith")?;
365 let escaped = escape_like_literal(val_str);
366 let p = self.push_param(params, serde_json::Value::String(escaped));
367 let pattern = self.dialect.concat_sql(&["'%'", &p]);
368 Ok(self.dialect.ilike_sql(&field_expr, &pattern))
369 },
370 WhereOperator::Like => {
371 let p = self.push_param(params, value.clone());
372 Ok(self.dialect.like_sql(&field_expr, &p))
373 },
374 WhereOperator::Ilike => {
375 let p = self.push_param(params, value.clone());
376 Ok(self.dialect.ilike_sql(&field_expr, &p))
377 },
378 WhereOperator::Nlike => {
379 let p = self.push_param(params, value.clone());
380 Ok(format!("NOT ({})", self.dialect.like_sql(&field_expr, &p)))
381 },
382 WhereOperator::Nilike => {
383 let p = self.push_param(params, value.clone());
384 Ok(format!("NOT ({})", self.dialect.ilike_sql(&field_expr, &p)))
385 },
386
387 WhereOperator::Regex => {
389 if let Some(s) = value.as_str() {
390 validate_regex_pattern(s)?;
391 }
392 let p = self.push_param(params, value.clone());
393 self.dialect
394 .regex_sql(&field_expr, &p, false, false)
395 .map_err(|e| FraiseQLError::validation(e.to_string()))
396 },
397 WhereOperator::Iregex => {
398 if let Some(s) = value.as_str() {
399 validate_regex_pattern(s)?;
400 }
401 let p = self.push_param(params, value.clone());
402 self.dialect
403 .regex_sql(&field_expr, &p, true, false)
404 .map_err(|e| FraiseQLError::validation(e.to_string()))
405 },
406 WhereOperator::Nregex => {
407 if let Some(s) = value.as_str() {
408 validate_regex_pattern(s)?;
409 }
410 let p = self.push_param(params, value.clone());
411 self.dialect
412 .regex_sql(&field_expr, &p, false, true)
413 .map_err(|e| FraiseQLError::validation(e.to_string()))
414 },
415 WhereOperator::Niregex => {
416 if let Some(s) = value.as_str() {
417 validate_regex_pattern(s)?;
418 }
419 let p = self.push_param(params, value.clone());
420 self.dialect
421 .regex_sql(&field_expr, &p, true, true)
422 .map_err(|e| FraiseQLError::validation(e.to_string()))
423 },
424
425 WhereOperator::LenEq
427 | WhereOperator::LenNeq
428 | WhereOperator::LenGt
429 | WhereOperator::LenGte
430 | WhereOperator::LenLt
431 | WhereOperator::LenLte => {
432 let op = match operator {
433 WhereOperator::LenEq => "=",
434 WhereOperator::LenNeq => self.dialect.neq_operator(),
435 WhereOperator::LenGt => ">",
436 WhereOperator::LenGte => ">=",
437 WhereOperator::LenLt => "<",
438 _ => "<=",
439 };
440 let len_expr = self.dialect.json_array_length(&field_expr);
441 let p = self.push_param(params, value.clone());
442 Ok(format!("{len_expr} {op} {p}"))
443 },
444
445 WhereOperator::ArrayContains | WhereOperator::StrictlyContains => {
447 let p = self.push_param(params, value.clone());
450 self.dialect
451 .array_contains_sql(&field_expr, &p)
452 .map_err(|e| FraiseQLError::validation(e.to_string()))
453 },
454 WhereOperator::ArrayContainedBy => {
455 let p = self.push_param(params, value.clone());
456 self.dialect
457 .array_contained_by_sql(&field_expr, &p)
458 .map_err(|e| FraiseQLError::validation(e.to_string()))
459 },
460 WhereOperator::ArrayOverlaps => {
461 let p = self.push_param(params, value.clone());
462 self.dialect
463 .array_overlaps_sql(&field_expr, &p)
464 .map_err(|e| FraiseQLError::validation(e.to_string()))
465 },
466
467 WhereOperator::Matches => {
469 let p = self.push_param(params, value.clone());
470 self.dialect
471 .fts_matches_sql(&field_expr, &p)
472 .map_err(|e| FraiseQLError::validation(e.to_string()))
473 },
474 WhereOperator::PlainQuery => {
475 let p = self.push_param(params, value.clone());
476 self.dialect
477 .fts_plain_query_sql(&field_expr, &p)
478 .map_err(|e| FraiseQLError::validation(e.to_string()))
479 },
480 WhereOperator::PhraseQuery => {
481 let p = self.push_param(params, value.clone());
482 self.dialect
483 .fts_phrase_query_sql(&field_expr, &p)
484 .map_err(|e| FraiseQLError::validation(e.to_string()))
485 },
486 WhereOperator::WebsearchQuery => {
487 let p = self.push_param(params, value.clone());
488 self.dialect
489 .fts_websearch_query_sql(&field_expr, &p)
490 .map_err(|e| FraiseQLError::validation(e.to_string()))
491 },
492
493 WhereOperator::CosineDistance => {
495 let p = self.push_param(params, value.clone());
496 self.dialect
497 .vector_distance_sql("<=>", &field_expr, &p)
498 .map_err(|e| FraiseQLError::validation(e.to_string()))
499 },
500 WhereOperator::L2Distance => {
501 let p = self.push_param(params, value.clone());
502 self.dialect
503 .vector_distance_sql("<->", &field_expr, &p)
504 .map_err(|e| FraiseQLError::validation(e.to_string()))
505 },
506 WhereOperator::L1Distance => {
507 let p = self.push_param(params, value.clone());
508 self.dialect
509 .vector_distance_sql("<+>", &field_expr, &p)
510 .map_err(|e| FraiseQLError::validation(e.to_string()))
511 },
512 WhereOperator::HammingDistance => {
513 let p = self.push_param(params, value.clone());
514 self.dialect
515 .vector_distance_sql("<~>", &field_expr, &p)
516 .map_err(|e| FraiseQLError::validation(e.to_string()))
517 },
518 WhereOperator::InnerProduct => {
519 let p = self.push_param(params, value.clone());
520 self.dialect
521 .vector_distance_sql("<#>", &field_expr, &p)
522 .map_err(|e| FraiseQLError::validation(e.to_string()))
523 },
524 WhereOperator::JaccardDistance => {
525 let p = self.push_param(params, value.clone());
526 self.dialect
527 .jaccard_distance_sql(&field_expr, &p)
528 .map_err(|e| FraiseQLError::validation(e.to_string()))
529 },
530
531 WhereOperator::IsIPv4 => self
533 .dialect
534 .inet_check_sql(&field_expr, "IsIPv4")
535 .map_err(|e| FraiseQLError::validation(e.to_string())),
536 WhereOperator::IsIPv6 => self
537 .dialect
538 .inet_check_sql(&field_expr, "IsIPv6")
539 .map_err(|e| FraiseQLError::validation(e.to_string())),
540 WhereOperator::IsPrivate => self
541 .dialect
542 .inet_check_sql(&field_expr, "IsPrivate")
543 .map_err(|e| FraiseQLError::validation(e.to_string())),
544 WhereOperator::IsPublic => self
545 .dialect
546 .inet_check_sql(&field_expr, "IsPublic")
547 .map_err(|e| FraiseQLError::validation(e.to_string())),
548 WhereOperator::IsLoopback => self
549 .dialect
550 .inet_check_sql(&field_expr, "IsLoopback")
551 .map_err(|e| FraiseQLError::validation(e.to_string())),
552 WhereOperator::InSubnet => {
553 let p = self.push_param(params, value.clone());
554 self.dialect
555 .inet_binary_sql("<<", &field_expr, &p)
556 .map_err(|e| FraiseQLError::validation(e.to_string()))
557 },
558 WhereOperator::ContainsSubnet | WhereOperator::ContainsIP => {
559 let p = self.push_param(params, value.clone());
560 self.dialect
561 .inet_binary_sql(">>", &field_expr, &p)
562 .map_err(|e| FraiseQLError::validation(e.to_string()))
563 },
564 WhereOperator::Overlaps => {
565 let p = self.push_param(params, value.clone());
566 self.dialect
567 .inet_binary_sql("&&", &field_expr, &p)
568 .map_err(|e| FraiseQLError::validation(e.to_string()))
569 },
570
571 WhereOperator::AncestorOf => {
573 let p = self.push_param(params, value.clone());
574 self.dialect
575 .ltree_binary_sql("@>", &field_expr, &p, "ltree")
576 .map_err(|e| FraiseQLError::validation(e.to_string()))
577 },
578 WhereOperator::DescendantOf => {
579 let p = self.push_param(params, value.clone());
580 self.dialect
581 .ltree_binary_sql("<@", &field_expr, &p, "ltree")
582 .map_err(|e| FraiseQLError::validation(e.to_string()))
583 },
584 WhereOperator::MatchesLquery => {
585 let p = self.push_param(params, value.clone());
586 self.dialect
587 .ltree_binary_sql("~", &field_expr, &p, "lquery")
588 .map_err(|e| FraiseQLError::validation(e.to_string()))
589 },
590 WhereOperator::MatchesLtxtquery => {
591 let p = self.push_param(params, value.clone());
592 self.dialect
593 .ltree_binary_sql("@", &field_expr, &p, "ltxtquery")
594 .map_err(|e| FraiseQLError::validation(e.to_string()))
595 },
596 WhereOperator::MatchesAnyLquery => {
597 let arr = value.as_array().ok_or_else(|| {
598 FraiseQLError::validation(
599 "matches_any_lquery operator requires an array value".to_string(),
600 )
601 })?;
602 if arr.is_empty() {
603 return Err(FraiseQLError::validation(
604 "matches_any_lquery requires at least one lquery".to_string(),
605 ));
606 }
607 let placeholders: Vec<_> = arr
608 .iter()
609 .map(|v| format!("{}::lquery", self.push_param(params, v.clone())))
610 .collect();
611 self.dialect
612 .ltree_any_lquery_sql(&field_expr, &placeholders)
613 .map_err(|e| FraiseQLError::validation(e.to_string()))
614 },
615 WhereOperator::DepthEq => {
616 let p = self.push_param(params, value.clone());
617 self.dialect
618 .ltree_depth_sql("=", &field_expr, &p)
619 .map_err(|e| FraiseQLError::validation(e.to_string()))
620 },
621 WhereOperator::DepthNeq => {
622 let p = self.push_param(params, value.clone());
623 self.dialect
624 .ltree_depth_sql("!=", &field_expr, &p)
625 .map_err(|e| FraiseQLError::validation(e.to_string()))
626 },
627 WhereOperator::DepthGt => {
628 let p = self.push_param(params, value.clone());
629 self.dialect
630 .ltree_depth_sql(">", &field_expr, &p)
631 .map_err(|e| FraiseQLError::validation(e.to_string()))
632 },
633 WhereOperator::DepthGte => {
634 let p = self.push_param(params, value.clone());
635 self.dialect
636 .ltree_depth_sql(">=", &field_expr, &p)
637 .map_err(|e| FraiseQLError::validation(e.to_string()))
638 },
639 WhereOperator::DepthLt => {
640 let p = self.push_param(params, value.clone());
641 self.dialect
642 .ltree_depth_sql("<", &field_expr, &p)
643 .map_err(|e| FraiseQLError::validation(e.to_string()))
644 },
645 WhereOperator::DepthLte => {
646 let p = self.push_param(params, value.clone());
647 self.dialect
648 .ltree_depth_sql("<=", &field_expr, &p)
649 .map_err(|e| FraiseQLError::validation(e.to_string()))
650 },
651 WhereOperator::Lca => {
652 let arr = value.as_array().ok_or_else(|| {
653 FraiseQLError::validation("lca operator requires an array value".to_string())
654 })?;
655 if arr.is_empty() {
656 return Err(FraiseQLError::validation(
657 "lca operator requires at least one path".to_string(),
658 ));
659 }
660 let placeholders: Vec<_> = arr
661 .iter()
662 .map(|v| format!("{}::ltree", self.push_param(params, v.clone())))
663 .collect();
664 self.dialect
665 .ltree_lca_sql(&field_expr, &placeholders)
666 .map_err(|e| FraiseQLError::validation(e.to_string()))
667 },
668
669 WhereOperator::Extended(op) => {
671 self.dialect.generate_extended_sql(op, &field_expr, params)
672 },
673
674 #[allow(unreachable_patterns)]
679 _ => Err(FraiseQLError::Validation {
681 message: format!(
682 "Operator {operator:?} is not supported by the {} dialect",
683 self.dialect.name()
684 ),
685 path: None,
686 }),
687 }
688 }
689
690 fn require_str<'a>(&self, value: &'a serde_json::Value, op: &'static str) -> Result<&'a str> {
691 value.as_str().ok_or_else(|| {
692 FraiseQLError::validation(format!("{op} operator requires a string value"))
693 })
694 }
695}
696
697impl<D: SqlDialect + Default> Default for GenericWhereGenerator<D> {
700 fn default() -> Self {
701 Self::new(D::default())
702 }
703}
704
705impl<D: SqlDialect> crate::filters::ExtendedOperatorHandler for GenericWhereGenerator<D> {
709 fn generate_extended_sql(
710 &self,
711 operator: &crate::filters::ExtendedOperator,
712 field_sql: &str,
713 params: &mut Vec<serde_json::Value>,
714 ) -> Result<String> {
715 self.dialect.generate_extended_sql(operator, field_sql, params)
716 }
717}
718
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    //! Unit tests for `GenericWhereGenerator`, mostly against `PostgresDialect`.
    //!
    //! Locals are named `generator` rather than `gen`, which is a reserved keyword
    //! in the Rust 2024 edition.

    use serde_json::json;

    use super::GenericWhereGenerator;
    use crate::{
        dialect::PostgresDialect,
        where_clause::{WhereClause, WhereOperator},
    };

    /// Builds a single-segment field predicate.
    fn field(path: &str, op: WhereOperator, val: serde_json::Value) -> WhereClause {
        WhereClause::Field {
            path: vec![path.to_string()],
            operator: op,
            value: val,
        }
    }

    #[test]
    fn generic_eq_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("email", WhereOperator::Eq, json!("alice@example.com"));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "data->>'email' = $1");
        assert_eq!(params, vec![json!("alice@example.com")]);
    }

    #[test]
    fn generic_and_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = WhereClause::And(vec![
            field("status", WhereOperator::Eq, json!("active")),
            field("age", WhereOperator::Gte, json!(18)),
        ]);
        let (sql, params) = generator.generate(&clause).unwrap();
        assert!(sql.starts_with("(data->>'status' = $1 AND"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn generic_empty_and_returns_true() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = WhereClause::And(vec![]);
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "TRUE");
        assert!(params.is_empty());
    }

    #[test]
    fn generic_empty_or_returns_false() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = WhereClause::Or(vec![]);
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "FALSE");
        assert!(params.is_empty());
    }

    #[test]
    fn generic_not_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = WhereClause::Not(Box::new(field("deleted", WhereOperator::Eq, json!(true))));
        let (sql, _) = generator.generate(&clause).unwrap();
        assert!(sql.starts_with("NOT ("));
    }

    #[test]
    fn generate_resets_counter() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("x", WhereOperator::Eq, json!(1));
        let (sql1, _) = generator.generate(&clause).unwrap();
        let (sql2, _) = generator.generate(&clause).unwrap();
        assert_eq!(sql1, sql2);
        assert!(sql1.contains("$1"));
        // A counter that was not reset would have produced `$2` on the second call.
        assert!(!sql1.contains("$2"));
    }

    #[test]
    fn generate_with_param_offset() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("email", WhereOperator::Eq, json!("a@b.com"));
        let (sql, _) = generator.generate_with_param_offset(&clause, 2).unwrap();
        assert!(sql.contains("$3"), "Expected $3 (offset 2 + 1), got: {sql}");
    }

    #[test]
    fn generic_icontains_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("email", WhereOperator::Icontains, json!("example.com"));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "data->>'email' ILIKE '%' || $1 || '%'");
        assert_eq!(params, vec![json!("example.com")]);
    }

    #[test]
    fn generic_startswith_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Startswith, json!("Al"));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "data->>'name' LIKE $1 || '%'");
        assert_eq!(params, vec![json!("Al")]);
    }

    #[test]
    fn generic_endswith_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Endswith, json!("son"));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "data->>'name' LIKE '%' || $1");
        assert_eq!(params, vec![json!("son")]);
    }

    #[test]
    fn generic_in_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("status", WhereOperator::In, json!(["active", "pending"]));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "data->>'status' IN ($1, $2)");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn generic_nin_postgres_wraps_in_not() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("status", WhereOperator::Nin, json!(["archived"]));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "NOT (data->>'status' IN ($1))");
        assert_eq!(params, vec![json!("archived")]);
    }

    #[test]
    fn generic_in_requires_array() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("status", WhereOperator::In, json!("active"));
        let err = generator.generate(&clause).unwrap_err();
        assert!(err.to_string().contains("requires an array value"), "Got: {err}");
    }

    #[test]
    fn generic_in_empty_returns_false() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("status", WhereOperator::In, json!([]));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "FALSE");
        assert!(params.is_empty());
    }

    #[test]
    fn generic_nin_empty_returns_true() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("status", WhereOperator::Nin, json!([]));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "TRUE");
        assert!(params.is_empty());
    }

    #[test]
    fn generic_is_null_false_postgres() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("deleted_at", WhereOperator::IsNull, json!(false));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(sql, "data->>'deleted_at' IS NOT NULL");
        assert!(params.is_empty());
    }

    #[test]
    fn no_value_in_sql_string() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let injection = "'; DROP TABLE users; --";
        let clause = field("email", WhereOperator::Eq, json!(injection));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert!(!sql.contains(injection), "Value must not appear in SQL: {sql}");
        assert_eq!(params[0], json!(injection));
    }

    #[test]
    fn generic_pg_cosine_distance() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("embedding", WhereOperator::CosineDistance, json!([0.1, 0.2]));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert!(sql.contains("<=>"), "Expected <=> operator, got: {sql}");
        assert!(sql.contains("::vector"), "Expected ::vector cast, got: {sql}");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn generic_pg_network_ipv4() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("ip", WhereOperator::IsIPv4, json!(true));
        let (sql, _) = generator.generate(&clause).unwrap();
        assert!(sql.contains("family("), "Expected family() call, got: {sql}");
        assert!(sql.contains("= 4"), "Expected = 4, got: {sql}");
    }

    #[test]
    fn generic_pg_ltree_ancestor_of() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("path", WhereOperator::AncestorOf, json!("europe.france"));
        let (sql, params) = generator.generate(&clause).unwrap();
        assert!(sql.contains("@>") && sql.contains("ltree"), "Got: {sql}");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn non_pg_vector_op_returns_error() {
        use crate::dialect::MySqlDialect;
        let generator = GenericWhereGenerator::new(MySqlDialect);
        let clause = field("embedding", WhereOperator::CosineDistance, json!([0.1]));
        let err = generator.generate(&clause).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("VectorDistance") || msg.contains("not supported"), "Got: {msg}");
    }

    #[test]
    fn non_pg_network_op_returns_error() {
        use crate::dialect::SqliteDialect;
        let generator = GenericWhereGenerator::new(SqliteDialect);
        let clause = field("ip", WhereOperator::IsIPv4, json!(true));
        let err = generator.generate(&clause).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Inet") || msg.contains("not supported"), "Got: {msg}");
    }

    #[test]
    fn escape_like_literal_escapes_percent_and_underscore() {
        assert_eq!(super::escape_like_literal("50%"), "50\\%");
        assert_eq!(super::escape_like_literal("user_name"), "user\\_name");
        assert_eq!(super::escape_like_literal("a%b_c\\d"), "a\\%b\\_c\\\\d");
        assert_eq!(super::escape_like_literal("plain"), "plain");
    }

    #[test]
    fn contains_escapes_like_metacharacters() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Contains, json!("50%off"));
        let (_sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(params[0], json!("50\\%off"));
    }

    #[test]
    fn startswith_escapes_like_metacharacters() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Startswith, json!("user_"));
        let (_sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(params[0], json!("user\\_"));
    }

    #[test]
    fn endswith_escapes_like_metacharacters() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Endswith, json!("100%"));
        let (_sql, params) = generator.generate(&clause).unwrap();
        assert_eq!(params[0], json!("100\\%"));
    }

    #[test]
    fn regex_rejects_nested_quantifiers() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Regex, json!("(a+)+$"));
        let err = generator.generate(&clause).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("nested quantifiers"), "Got: {msg}");
    }

    #[test]
    fn regex_rejects_star_star_pattern() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Regex, json!("(x*)*"));
        let err = generator.generate(&clause).unwrap_err();
        assert!(err.to_string().contains("nested quantifiers"));
    }

    #[test]
    fn regex_rejects_too_long_pattern() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let long_pattern = "a".repeat(1_001);
        let clause = field("name", WhereOperator::Regex, json!(long_pattern));
        let err = generator.generate(&clause).unwrap_err();
        assert!(err.to_string().contains("maximum length"));
    }

    #[test]
    fn regex_allows_safe_patterns() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Regex, json!("^[a-z]+$"));
        assert!(generator.generate(&clause).is_ok());
    }

    #[test]
    fn regex_allows_escaped_parentheses() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        // `\(a+\)+` quantifies only literal characters, never a group.
        let clause = field("name", WhereOperator::Regex, json!("\\(a+\\)+"));
        assert!(generator.generate(&clause).is_ok());
    }

    #[test]
    fn iregex_also_validates_pattern() {
        let generator = GenericWhereGenerator::new(PostgresDialect);
        let clause = field("name", WhereOperator::Iregex, json!("(a+)+"));
        assert!(generator.generate(&clause).is_err());
    }
}
1003}