1use std::fmt;
46
/// Longest identifier, in bytes, that this module accepts for a column name
/// (reported to users as the Postgres identifier limit).
const MAX_COLUMN_LEN: usize = 63;

/// Largest number of `column op value` conditions a single where expression
/// may contain before parsing fails with `TooManyConditions`.
const MAX_CONDITIONS: usize = 256;
59
60pub fn is_safe_identifier(name: &str) -> bool {
70 !name.is_empty()
71 && name.len() <= MAX_COLUMN_LEN
72 && name.bytes().enumerate().all(|(i, b)| {
73 b == b'_'
74 || b.is_ascii_alphabetic()
75 || (i > 0 && b.is_ascii_digit())
76 })
77}
78
/// A literal value parsed out of a where expression, destined to be sent as
/// a bind parameter (or, for `Null`, folded into `IS [NOT] NULL`).
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// A quoted string literal.
    Text(String),
    /// A numeric literal that fits in an `i64`.
    Integer(i64),
    /// A finite numeric literal that is not representable as an `i64`.
    Float(f64),
    /// The bare words `true` / `false`.
    Boolean(bool),
    /// The bare word `null`.
    Null,
}

impl SqlValue {
    /// Lower-case, human-readable name of the variant, used in diagnostics.
    pub fn type_name(&self) -> &'static str {
        match *self {
            Self::Null => "null",
            Self::Boolean(_) => "boolean",
            Self::Float(_) => "float",
            Self::Integer(_) => "integer",
            Self::Text(_) => "text",
        }
    }
}
113
/// Comparison operator of a single filter condition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Operator {
    /// `=` (also spelled `==` in the input).
    Eq,
    /// `!=` (also spelled `<>` in the input).
    Ne,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `LIKE`
    Like,
}

impl Operator {
    /// Canonical SQL spelling of the operator.
    pub fn as_sql(self) -> &'static str {
        match self {
            Self::Eq => "=",
            Self::Ne => "!=",
            Self::Gt => ">",
            Self::Ge => ">=",
            Self::Lt => "<",
            Self::Le => "<=",
            Self::Like => "LIKE",
        }
    }

    /// Maps a symbolic token to its operator, accepting the `==` and `<>`
    /// aliases. Returns `None` for anything else; `LIKE` is a word rather
    /// than a symbol and is therefore not recognised here.
    fn from_symbol(sym: &str) -> Option<Operator> {
        const SYMBOLS: [(&str, Operator); 8] = [
            ("=", Operator::Eq),
            ("==", Operator::Eq),
            ("!=", Operator::Ne),
            ("<>", Operator::Ne),
            (">", Operator::Gt),
            (">=", Operator::Ge),
            ("<", Operator::Lt),
            ("<=", Operator::Le),
        ];
        SYMBOLS
            .iter()
            .find(|(symbol, _)| *symbol == sym)
            .map(|&(_, op)| op)
    }

    /// Whether the operator may be paired with a `null` value: only the
    /// equality and inequality operators can.
    fn accepts_null(self) -> bool {
        self == Operator::Eq || self == Operator::Ne
    }
}

impl fmt::Display for Operator {
    /// Writes the canonical SQL spelling (see [`Operator::as_sql`]).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.as_sql())
    }
}
180
/// Boolean connector joining two adjacent filter conditions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Connector {
    /// Logical conjunction.
    And,
    /// Logical disjunction.
    Or,
}

impl Connector {
    /// Upper-case SQL keyword for the connector.
    pub fn as_sql(self) -> &'static str {
        match self {
            Self::Or => "OR",
            Self::And => "AND",
        }
    }
}

impl fmt::Display for Connector {
    /// Writes the SQL keyword (see [`Connector::as_sql`]).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.as_sql())
    }
}
209
/// One parsed `column op value` comparison.
#[derive(Debug, Clone, PartialEq)]
pub struct FilterCondition {
    /// Column name exactly as written in the expression (unquoted).
    pub column: String,
    /// Comparison operator applied to the column.
    pub op: Operator,
    /// Right-hand-side literal.
    pub value: SqlValue,
}
224
225#[derive(Debug, Clone, PartialEq, Default)]
230pub struct Filter {
231 pub conditions: Vec<FilterCondition>,
233 pub connectors: Vec<Connector>,
235}
236
237impl Filter {
238 pub fn is_empty(&self) -> bool {
240 self.conditions.is_empty()
241 }
242}
243
/// Everything that can go wrong while tokenizing or parsing a where
/// expression. The `Display` impl renders a user-facing explanation.
#[derive(Debug, Clone, PartialEq)]
pub enum FilterError {
    /// A character outside the grammar's alphabet; `pos` is a character
    /// index into the expression.
    UnexpectedChar { ch: char, pos: usize },
    /// A string literal whose closing quote is missing; `pos` is the
    /// character index of the opening quote.
    UnterminatedString { pos: usize },
    /// A numeric token that is neither an `i64` nor a finite `f64`.
    InvalidNumber { token: String },
    /// A condition did not start with a column name.
    ExpectedColumn { found: String },
    /// A column name longer than `MAX_COLUMN_LEN` bytes.
    ColumnTooLong { column: String, len: usize },
    /// The token after a column name is not a comparison operator.
    ExpectedOperator { column: String, found: String },
    /// The expression ended right after a column name.
    MissingOperator { column: String },
    /// A symbol appeared where a value was expected.
    ExpectedValue { found: String },
    /// The expression ended right after an operator.
    MissingValue { column: String },
    /// A bare word other than `true` / `false` / `null` in value position.
    UnquotedValue { token: String },
    /// Two conditions were not separated by `AND` / `OR`.
    ExpectedConnector { found: String },
    /// The expression ended with a connector.
    DanglingConnector { connector: Connector },
    /// `null` was compared with an operator other than `=` / `!=`.
    NullWithNonEqualityOp { column: String, op: Operator },
    /// `LIKE` was given a non-text value; `found` is the value's type name.
    LikeRequiresText { column: String, found: &'static str },
    /// More than `limit` conditions in one expression.
    TooManyConditions { limit: usize },
}
284
/// User-facing messages for every [`FilterError`] variant. Each message names
/// the offending token or column and, where it helps, states the rule that
/// was violated.
impl fmt::Display for FilterError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // --- tokenizer errors ---
            FilterError::UnexpectedChar { ch, pos } => write!(
                f,
                "unexpected character {ch:?} at position {pos} — the \
                 where-grammar alphabet is identifiers, comparison \
                 symbols, quoted strings and numbers"
            ),
            FilterError::UnterminatedString { pos } => {
                write!(f, "unterminated string literal opened at position {pos}")
            }
            FilterError::InvalidNumber { token } => write!(
                f,
                "invalid numeric literal `{token}` — expected an integer \
                 or a finite decimal"
            ),
            // --- structural (parser) errors ---
            FilterError::ExpectedColumn { found } => write!(
                f,
                "expected a column name, found `{found}` — a condition \
                 is `column op value`"
            ),
            FilterError::ColumnTooLong { column, len } => write!(
                f,
                "column name `{column}` is {len} bytes — exceeds the \
                 Postgres {MAX_COLUMN_LEN}-byte identifier limit"
            ),
            FilterError::ExpectedOperator { column, found } => write!(
                f,
                "expected a comparison operator after column `{column}`, \
                 found `{found}`"
            ),
            FilterError::MissingOperator { column } => write!(
                f,
                "expected a comparison operator after column `{column}`, \
                 found end of expression"
            ),
            FilterError::ExpectedValue { found } => write!(
                f,
                "expected a value, found `{found}`"
            ),
            FilterError::MissingValue { column } => write!(
                f,
                "expected a value for column `{column}`, found end of \
                 expression"
            ),
            FilterError::UnquotedValue { token } => write!(
                f,
                "unquoted value `{token}` — string values must be quoted \
                 (`'{token}'`); only numbers and `true`/`false`/`null` \
                 are bare"
            ),
            FilterError::ExpectedConnector { found } => write!(
                f,
                "expected `AND` or `OR` between conditions, found `{found}`"
            ),
            FilterError::DanglingConnector { connector } => write!(
                f,
                "expression ends with a dangling `{connector}` — a \
                 connector must be followed by another condition"
            ),
            // --- semantic errors ---
            FilterError::NullWithNonEqualityOp { column, op } => write!(
                f,
                "`null` compared with `{op}` on column `{column}` — \
                 `null` is only valid with `=` (renders `IS NULL`) or \
                 `!=` (renders `IS NOT NULL`)"
            ),
            FilterError::LikeRequiresText { column, found } => write!(
                f,
                "`LIKE` on column `{column}` requires a text value, \
                 found {found}"
            ),
            FilterError::TooManyConditions { limit } => write!(
                f,
                "where expression exceeds the {limit}-condition limit"
            ),
        }
    }
}
364
// No underlying cause is exposed, so the default `Error` methods suffice.
impl std::error::Error for FilterError {}
366
/// Lexical unit produced by `tokenize`.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// An identifier-shaped run: ASCII letters, digits and `_`, not starting
    /// with a digit. Covers column names, keywords and bare value words.
    Word(String),
    /// A comparison symbol such as `=`, `!=` or `<>`.
    Symbol(String),
    /// Contents of a quoted string literal, with escapes already resolved.
    Str(String),
    /// Raw text of a numeric literal; not yet validated as a number.
    Num(String),
}
385
386fn describe(tok: &Token) -> String {
388 match tok {
389 Token::Word(w) => w.clone(),
390 Token::Symbol(s) => s.clone(),
391 Token::Str(s) => format!("'{s}'"),
392 Token::Num(n) => n.clone(),
393 }
394}
395
396fn tokenize(expr: &str) -> Result<Vec<Token>, FilterError> {
399 let chars: Vec<char> = expr.chars().collect();
400 let n = chars.len();
401 let mut tokens: Vec<Token> = Vec::new();
402 let mut i = 0;
403
404 while i < n {
405 let c = chars[i];
406
407 if c.is_whitespace() {
409 i += 1;
410 continue;
411 }
412
413 if c == '\'' || c == '"' {
415 let quote = c;
416 let mut buf = String::new();
417 let mut j = i + 1;
418 let mut closed = false;
419 while j < n {
420 let cj = chars[j];
421 if cj == '\\' {
422 if j + 1 < n {
425 buf.push(chars[j + 1]);
426 j += 2;
427 continue;
428 }
429 break;
430 }
431 if cj == quote {
432 closed = true;
433 j += 1;
434 break;
435 }
436 buf.push(cj);
437 j += 1;
438 }
439 if !closed {
440 return Err(FilterError::UnterminatedString { pos: i });
441 }
442 tokens.push(Token::Str(buf));
443 i = j;
444 continue;
445 }
446
447 if c == '=' || c == '!' || c == '<' || c == '>' {
449 if i + 1 < n {
450 let two = match (c, chars[i + 1]) {
451 ('=', '=') => Some("=="),
452 ('!', '=') => Some("!="),
453 ('<', '=') => Some("<="),
454 ('>', '=') => Some(">="),
455 ('<', '>') => Some("<>"),
456 _ => None,
457 };
458 if let Some(sym) = two {
459 tokens.push(Token::Symbol(sym.to_string()));
460 i += 2;
461 continue;
462 }
463 }
464 if c == '!' {
466 return Err(FilterError::UnexpectedChar { ch: '!', pos: i });
467 }
468 tokens.push(Token::Symbol(c.to_string()));
469 i += 1;
470 continue;
471 }
472
473 if c.is_ascii_digit()
475 || (c == '-' && i + 1 < n && chars[i + 1].is_ascii_digit())
476 {
477 let start = i;
478 let mut j = if c == '-' { i + 1 } else { i };
479 while j < n && (chars[j].is_ascii_digit() || chars[j] == '.') {
480 j += 1;
481 }
482 tokens.push(Token::Num(chars[start..j].iter().collect()));
483 i = j;
484 continue;
485 }
486
487 if c.is_ascii_alphabetic() || c == '_' {
489 let start = i;
490 let mut j = i;
491 while j < n && (chars[j].is_ascii_alphanumeric() || chars[j] == '_')
492 {
493 j += 1;
494 }
495 tokens.push(Token::Word(chars[start..j].iter().collect()));
496 i = j;
497 continue;
498 }
499
500 return Err(FilterError::UnexpectedChar { ch: c, pos: i });
501 }
502
503 Ok(tokens)
504}
505
506fn parse_number(raw: &str) -> Result<SqlValue, FilterError> {
509 if let Ok(n) = raw.parse::<i64>() {
510 return Ok(SqlValue::Integer(n));
511 }
512 if let Ok(x) = raw.parse::<f64>() {
513 if x.is_finite() {
514 return Ok(SqlValue::Float(x));
515 }
516 }
517 Err(FilterError::InvalidNumber { token: raw.to_string() })
518}
519
520fn parse_value(tok: &Token) -> Result<SqlValue, FilterError> {
522 match tok {
523 Token::Str(s) => Ok(SqlValue::Text(s.clone())),
524 Token::Num(raw) => parse_number(raw),
525 Token::Word(w) => match w.to_ascii_lowercase().as_str() {
526 "true" => Ok(SqlValue::Boolean(true)),
527 "false" => Ok(SqlValue::Boolean(false)),
528 "null" => Ok(SqlValue::Null),
529 _ => Err(FilterError::UnquotedValue { token: w.clone() }),
530 },
531 Token::Symbol(s) => Err(FilterError::ExpectedValue { found: s.clone() }),
532 }
533}
534
535pub fn parse_filter(
555 expr: &str,
556 bindings: &std::collections::HashMap<String, String>,
557) -> Result<Filter, FilterError> {
558 let raw_tokens = tokenize(expr)?;
559 let tokens: Vec<Token> = raw_tokens
563 .into_iter()
564 .map(|t| match t {
565 Token::Str(s) => Token::Str(
566 crate::exec_context::interpolate_vars(&s, bindings),
567 ),
568 other => other,
569 })
570 .collect();
571 let mut filter = Filter::default();
572 let mut i = 0;
573 let n = tokens.len();
574
575 while i < n {
576 let column = match &tokens[i] {
578 Token::Word(w) => w.clone(),
579 other => {
580 return Err(FilterError::ExpectedColumn {
581 found: describe(other),
582 })
583 }
584 };
585 if column.len() > MAX_COLUMN_LEN {
586 return Err(FilterError::ColumnTooLong {
587 len: column.len(),
588 column,
589 });
590 }
591 i += 1;
592
593 if i >= n {
595 return Err(FilterError::MissingOperator { column });
596 }
597 let op = match &tokens[i] {
598 Token::Symbol(sym) => Operator::from_symbol(sym).ok_or_else(|| {
599 FilterError::ExpectedOperator {
600 column: column.clone(),
601 found: sym.clone(),
602 }
603 })?,
604 Token::Word(w) if w.eq_ignore_ascii_case("like") => Operator::Like,
605 other => {
606 return Err(FilterError::ExpectedOperator {
607 column,
608 found: describe(other),
609 })
610 }
611 };
612 i += 1;
613
614 if i >= n {
616 return Err(FilterError::MissingValue { column });
617 }
618 let value = parse_value(&tokens[i])?;
619 i += 1;
620
621 if matches!(value, SqlValue::Null) && !op.accepts_null() {
623 return Err(FilterError::NullWithNonEqualityOp { column, op });
624 }
625 if op == Operator::Like && !matches!(value, SqlValue::Text(_)) {
626 return Err(FilterError::LikeRequiresText {
627 column,
628 found: value.type_name(),
629 });
630 }
631
632 filter.conditions.push(FilterCondition { column, op, value });
633 if filter.conditions.len() > MAX_CONDITIONS {
634 return Err(FilterError::TooManyConditions {
635 limit: MAX_CONDITIONS,
636 });
637 }
638
639 if i < n {
641 let connector = match &tokens[i] {
642 Token::Word(w) if w.eq_ignore_ascii_case("and") => Connector::And,
643 Token::Word(w) if w.eq_ignore_ascii_case("or") => Connector::Or,
644 other => {
645 return Err(FilterError::ExpectedConnector {
646 found: describe(other),
647 })
648 }
649 };
650 i += 1;
651 filter.connectors.push(connector);
652 if i >= n {
654 return Err(FilterError::DanglingConnector { connector });
655 }
656 }
657 }
658
659 Ok(filter)
660}
661
662pub fn build_pg_where(
701 expr: &str,
702 param_offset: usize,
703 bindings: &std::collections::HashMap<String, String>,
704 column_types: &std::collections::HashMap<String, String>,
705) -> Result<(String, Vec<SqlValue>), FilterError> {
706 if expr.trim().is_empty() {
707 return Ok(("TRUE".to_string(), Vec::new()));
708 }
709
710 let filter = parse_filter(expr, bindings)?;
711 if filter.is_empty() {
712 return Ok(("TRUE".to_string(), Vec::new()));
713 }
714
715 let mut clause = String::new();
716 let mut params: Vec<SqlValue> = Vec::new();
717 let mut idx = param_offset + 1;
718
719 for (i, cond) in filter.conditions.iter().enumerate() {
720 if i > 0 {
721 clause.push(' ');
724 clause.push_str(filter.connectors[i - 1].as_sql());
725 clause.push(' ');
726 }
727 match &cond.value {
728 SqlValue::Null => {
729 let tail = match cond.op {
732 Operator::Ne => "IS NOT NULL",
733 _ => "IS NULL",
734 };
735 clause.push_str(&format!("\"{}\" {tail}", cond.column));
736 }
737 bound => {
738 let known_udt: Option<&str> =
741 match column_types.get(&cond.column) {
742 Some(udt) if is_safe_identifier(udt) => {
743 Some(udt.as_str())
744 }
745 _ => None,
746 };
747 let (column_sql, value_cast) = match (known_udt, cond.op) {
766 (Some(udt), _) => {
767 (format!("\"{}\"", cond.column), format!("::{udt}"))
768 }
769 (None, Operator::Eq | Operator::Ne) => {
770 (format!("\"{}\"::text", cond.column), String::new())
771 }
772 (None, _) => {
773 (format!("\"{}\"", cond.column), String::new())
774 }
775 };
776 clause.push_str(&format!(
777 "{column_sql} {} ${idx}{value_cast}",
778 cond.op.as_sql()
779 ));
780 params.push(bound.clone());
781 idx += 1;
782 }
783 }
784 }
785
786 Ok((clause, params))
787}
788
/// Unit tests for the tokenizer, parser and Postgres renderer. Most tests go
/// through `build_pg_where` with empty bindings and an empty type map via
/// the `ok` / `err` helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty variable bindings ("no bindings").
    fn nb() -> std::collections::HashMap<String, String> {
        std::collections::HashMap::new()
    }

    /// Empty column-type map ("no types").
    fn nt() -> std::collections::HashMap<String, String> {
        std::collections::HashMap::new()
    }

    /// Compiles `expr` with offset 0 and no bindings/types; panics on error.
    fn ok(expr: &str) -> (String, Vec<SqlValue>) {
        build_pg_where(expr, 0, &nb(), &nt())
            .expect("expected the filter to compile")
    }

    /// Compiles `expr` with offset 0 and no bindings/types; panics on success.
    fn err(expr: &str) -> FilterError {
        build_pg_where(expr, 0, &nb(), &nt())
            .expect_err("expected a compile error")
    }

    // ----- blank input -----

    #[test]
    fn empty_expression_renders_true() {
        assert_eq!(ok(""), ("TRUE".to_string(), vec![]));
    }

    #[test]
    fn whitespace_only_renders_true() {
        assert_eq!(ok(" \t \n "), ("TRUE".to_string(), vec![]));
    }

    // ----- single conditions and value kinds -----

    #[test]
    fn single_integer_condition() {
        let (clause, params) = ok("id = 1");
        assert_eq!(clause, "\"id\"::text = $1");
        assert_eq!(params, vec![SqlValue::Integer(1)]);
    }

    #[test]
    fn single_string_condition_single_quoted() {
        let (clause, params) = ok("name = 'Alice'");
        assert_eq!(clause, "\"name\"::text = $1");
        assert_eq!(params, vec![SqlValue::Text("Alice".to_string())]);
    }

    #[test]
    fn single_string_condition_double_quoted() {
        let (_, params) = ok("name = \"Bob\"");
        assert_eq!(params, vec![SqlValue::Text("Bob".to_string())]);
    }

    #[test]
    fn negative_integer_value() {
        let (clause, params) = ok("balance >= -100");
        assert_eq!(clause, "\"balance\" >= $1");
        assert_eq!(params, vec![SqlValue::Integer(-100)]);
    }

    #[test]
    fn float_value() {
        let (_, params) = ok("score > 3.14");
        assert_eq!(params, vec![SqlValue::Float(3.14)]);
    }

    #[test]
    fn boolean_values() {
        assert_eq!(ok("active = true").1, vec![SqlValue::Boolean(true)]);
        assert_eq!(ok("active = false").1, vec![SqlValue::Boolean(false)]);
    }

    #[test]
    fn integer_overflowing_i64_falls_back_to_float() {
        let (_, params) = ok("n = 10000000000000000000000000");
        assert!(matches!(params[0], SqlValue::Float(_)));
    }

    // ----- operators -----

    #[test]
    fn every_operator_renders_canonically() {
        assert_eq!(ok("a = 1").0, "\"a\"::text = $1");
        assert_eq!(ok("a != 1").0, "\"a\"::text != $1");
        assert_eq!(ok("a > 1").0, "\"a\" > $1");
        assert_eq!(ok("a >= 1").0, "\"a\" >= $1");
        assert_eq!(ok("a < 1").0, "\"a\" < $1");
        assert_eq!(ok("a <= 1").0, "\"a\" <= $1");
        assert_eq!(ok("a LIKE 'x%'").0, "\"a\" LIKE $1");
    }

    #[test]
    fn operator_aliases_normalize() {
        assert_eq!(ok("a == 1").0, "\"a\"::text = $1");
        assert_eq!(ok("a <> 1").0, "\"a\"::text != $1");
    }

    #[test]
    fn like_is_case_insensitive_and_renders_uppercase() {
        assert_eq!(ok("a like 'x%'").0, "\"a\" LIKE $1");
        assert_eq!(ok("a LiKe 'x%'").0, "\"a\" LIKE $1");
    }

    // ----- casts driven by the column-type map -----

    // Also exercises `${tid}` interpolation inside a quoted literal: the
    // bound parameter is expected to be the value from the bindings map.
    #[test]
    fn typed_column_comparison_casts_the_value_to_the_column_type() {
        let b = std::collections::HashMap::from([(
            "tid".to_string(),
            "83d078e1-b372-42ba-9572-ff8dc521386e".to_string(),
        )]);
        let types = std::collections::HashMap::from([
            ("tid".to_string(), "uuid".to_string()),
            ("age".to_string(), "int4".to_string()),
        ]);
        let (clause, params) =
            build_pg_where("tid = '${tid}'", 0, &b, &types).expect("compiles");
        assert_eq!(
            clause, "\"tid\" = $1::uuid",
            "the value is cast to the column's introspected type"
        );
        assert_eq!(
            params,
            vec![SqlValue::Text(
                "83d078e1-b372-42ba-9572-ff8dc521386e".to_string()
            )]
        );
        let (clause, _) =
            build_pg_where("age >= 18", 0, &nb(), &types).expect("compiles");
        assert_eq!(clause, "\"age\" >= $1::int4");
    }

    #[test]
    fn d4_unknown_type_equality_casts_the_column_to_text() {
        assert_eq!(ok("id == 'x'").0, "\"id\"::text = $1");
        assert_eq!(ok("id != 'x'").0, "\"id\"::text != $1");
        assert_eq!(ok("id = 1").0, "\"id\"::text = $1");
    }

    #[test]
    fn d4_unknown_type_ordering_stays_a_bare_placeholder() {
        assert_eq!(ok("age > 18").0, "\"age\" > $1");
        assert_eq!(ok("age >= 18").0, "\"age\" >= $1");
        assert_eq!(ok("age < 18").0, "\"age\" < $1");
        assert_eq!(ok("age <= 18").0, "\"age\" <= $1");
    }

    #[test]
    fn d4_unknown_type_like_stays_a_bare_placeholder() {
        assert_eq!(ok("name LIKE 'a%'").0, "\"name\" LIKE $1");
    }

    #[test]
    fn d4_a_known_type_keeps_the_v1_36_4_value_cast() {
        let types = std::collections::HashMap::from([
            ("id".to_string(), "uuid".to_string()),
            ("n".to_string(), "int4".to_string()),
        ]);
        assert_eq!(
            build_pg_where("id == 'x'", 0, &nb(), &types).unwrap().0,
            "\"id\" = $1::uuid"
        );
        assert_eq!(
            build_pg_where("n > 5", 0, &nb(), &types).unwrap().0,
            "\"n\" > $1::int4"
        );
    }

    // A type name that fails `is_safe_identifier` must never reach the SQL.
    #[test]
    fn d4_an_unsafe_udt_is_not_spliced_and_equality_still_works() {
        let types = std::collections::HashMap::from([(
            "id".to_string(),
            "int4; DROP TABLE x".to_string(),
        )]);
        let (clause, _) =
            build_pg_where("id = 1", 0, &nb(), &types).expect("compiles");
        assert_eq!(
            clause, "\"id\"::text = $1",
            "the unsafe udt is not spliced; equality falls back to ::text"
        );
        let (clause, _) =
            build_pg_where("id > 1", 0, &nb(), &types).expect("compiles");
        assert_eq!(clause, "\"id\" > $1");
    }

    #[test]
    fn typed_column_null_fold_is_not_cast() {
        assert_eq!(ok("id = null").0, "\"id\" IS NULL");
    }

    // ----- connectors -----

    #[test]
    fn two_conditions_joined_by_and() {
        let (clause, params) = ok("id = 1 AND name = 'Alice'");
        assert_eq!(clause, "\"id\"::text = $1 AND \"name\"::text = $2");
        assert_eq!(
            params,
            vec![SqlValue::Integer(1), SqlValue::Text("Alice".to_string())]
        );
    }

    #[test]
    fn two_conditions_joined_by_or() {
        assert_eq!(
            ok("a = 1 OR b = 2").0,
            "\"a\"::text = $1 OR \"b\"::text = $2"
        );
    }

    #[test]
    fn connectors_are_case_insensitive() {
        assert_eq!(
            ok("a = 1 and b = 2").0,
            "\"a\"::text = $1 AND \"b\"::text = $2"
        );
        assert_eq!(
            ok("a = 1 Or b = 2").0,
            "\"a\"::text = $1 OR \"b\"::text = $2"
        );
    }

    #[test]
    fn three_condition_mixed_chain_preserves_order() {
        let (clause, params) = ok("a = 1 AND b = 2 OR c = 3");
        assert_eq!(
            clause,
            "\"a\"::text = $1 AND \"b\"::text = $2 OR \"c\"::text = $3"
        );
        assert_eq!(params.len(), 3);
    }

    // ----- null folding -----

    #[test]
    fn null_equality_folds_to_is_null() {
        let (clause, params) = ok("deleted_at = null");
        assert_eq!(clause, "\"deleted_at\" IS NULL");
        assert!(params.is_empty(), "IS NULL consumes no bind parameter");
    }

    #[test]
    fn null_inequality_folds_to_is_not_null() {
        let (clause, params) = ok("deleted_at != NULL");
        assert_eq!(clause, "\"deleted_at\" IS NOT NULL");
        assert!(params.is_empty());
    }

    #[test]
    fn null_does_not_occupy_a_parameter_slot() {
        let (clause, params) = ok("a = null AND b = 5");
        assert_eq!(clause, "\"a\" IS NULL AND \"b\"::text = $1");
        assert_eq!(params, vec![SqlValue::Integer(5)]);
    }

    #[test]
    fn rendered_params_never_contain_null() {
        let (_, params) = ok("a = null AND b = 1 OR c != null");
        assert!(params.iter().all(|v| !matches!(v, SqlValue::Null)));
    }

    // ----- placeholder numbering -----

    #[test]
    fn param_offset_shifts_placeholder_numbering() {
        let (clause, _) = build_pg_where("id = 1", 2, &nb(), &nt()).unwrap();
        assert_eq!(clause, "\"id\"::text = $3");
    }

    #[test]
    fn param_offset_shifts_every_placeholder() {
        let (clause, _) =
            build_pg_where("a = 1 AND b = 2", 5, &nb(), &nt()).unwrap();
        assert_eq!(clause, "\"a\"::text = $6 AND \"b\"::text = $7");
    }

    // ----- injection attempts -----

    #[test]
    fn injection_payload_inside_a_quoted_string_is_an_inert_bind_param() {
        let (clause, params) = ok("name = '; DROP TABLE users; --'");
        assert_eq!(clause, "\"name\"::text = $1");
        assert_eq!(
            params,
            vec![SqlValue::Text("; DROP TABLE users; --".to_string())]
        );
    }

    #[test]
    fn injection_via_statement_terminator_is_rejected_at_tokenize() {
        assert!(matches!(
            err("name = 'x'; DROP TABLE users"),
            FilterError::UnexpectedChar { ch: ';', .. }
        ));
    }

    #[test]
    fn injection_via_comment_marker_is_rejected() {
        assert!(matches!(
            err("a = 1 -- comment"),
            FilterError::UnexpectedChar { ch: '-', .. }
        ));
    }

    #[test]
    fn injection_via_bare_or_tautology_is_rejected() {
        assert!(matches!(
            err("name = 'x' OR 1 = 1"),
            FilterError::ExpectedColumn { .. }
        ));
    }

    // ----- string escapes -----

    #[test]
    fn escaped_quote_inside_string_is_resolved() {
        let (_, params) = ok("name = 'O\\'Brien'");
        assert_eq!(params, vec![SqlValue::Text("O'Brien".to_string())]);
    }

    #[test]
    fn escaped_backslash_is_resolved() {
        let (_, params) = ok("path = 'a\\\\b'");
        assert_eq!(params, vec![SqlValue::Text("a\\b".to_string())]);
    }

    // ----- tokenizer errors -----

    #[test]
    fn unterminated_string_errors() {
        assert!(matches!(
            err("name = 'unclosed"),
            FilterError::UnterminatedString { .. }
        ));
    }

    #[test]
    fn unexpected_character_errors() {
        assert!(matches!(
            err("a = 1 & b = 2"),
            FilterError::UnexpectedChar { ch: '&', .. }
        ));
    }

    #[test]
    fn bare_bang_is_rejected() {
        assert!(matches!(
            err("a ! 1"),
            FilterError::UnexpectedChar { ch: '!', .. }
        ));
    }

    #[test]
    fn invalid_number_errors() {
        assert!(matches!(
            err("a = 1.2.3"),
            FilterError::InvalidNumber { .. }
        ));
    }

    // ----- structural parse errors -----

    #[test]
    fn missing_operator_errors() {
        assert!(matches!(err("id"), FilterError::MissingOperator { .. }));
    }

    #[test]
    fn missing_value_errors() {
        assert!(matches!(err("id ="), FilterError::MissingValue { .. }));
    }

    #[test]
    fn unquoted_string_value_errors() {
        assert!(matches!(
            err("status = active"),
            FilterError::UnquotedValue { .. }
        ));
    }

    #[test]
    fn dangling_connector_errors() {
        assert!(matches!(
            err("id = 1 AND"),
            FilterError::DanglingConnector {
                connector: Connector::And
            }
        ));
    }

    #[test]
    fn two_conditions_without_connector_errors() {
        assert!(matches!(
            err("a = 1 b = 2"),
            FilterError::ExpectedConnector { .. }
        ));
    }

    #[test]
    fn column_position_non_identifier_errors() {
        assert!(matches!(err("1 = 1"), FilterError::ExpectedColumn { .. }));
    }

    #[test]
    fn operator_position_non_operator_errors() {
        assert!(matches!(
            err("a b c"),
            FilterError::ExpectedOperator { .. }
        ));
    }

    // ----- semantic errors -----

    #[test]
    fn null_with_ordering_operator_errors() {
        assert!(matches!(
            err("score > null"),
            FilterError::NullWithNonEqualityOp { op: Operator::Gt, .. }
        ));
    }

    #[test]
    fn null_with_like_errors() {
        assert!(matches!(
            err("name LIKE null"),
            FilterError::NullWithNonEqualityOp { op: Operator::Like, .. }
        ));
    }

    #[test]
    fn like_with_non_text_value_errors() {
        assert!(matches!(
            err("age LIKE 5"),
            FilterError::LikeRequiresText { found: "integer", .. }
        ));
    }

    // ----- limits -----

    #[test]
    fn column_at_the_length_limit_compiles() {
        let col = "c".repeat(MAX_COLUMN_LEN);
        assert!(build_pg_where(&format!("{col} = 1"), 0, &nb(), &nt()).is_ok());
    }

    #[test]
    fn column_over_the_length_limit_errors() {
        let col = "c".repeat(MAX_COLUMN_LEN + 1);
        assert!(matches!(
            build_pg_where(&format!("{col} = 1"), 0, &nb(), &nt()),
            Err(FilterError::ColumnTooLong { .. })
        ));
    }

    #[test]
    fn condition_count_at_the_limit_compiles() {
        let expr = (0..MAX_CONDITIONS)
            .map(|i| format!("c{i} = {i}"))
            .collect::<Vec<_>>()
            .join(" AND ");
        let (_, params) = build_pg_where(&expr, 0, &nb(), &nt()).unwrap();
        assert_eq!(params.len(), MAX_CONDITIONS);
    }

    #[test]
    fn condition_count_over_the_limit_errors() {
        let expr = (0..=MAX_CONDITIONS)
            .map(|i| format!("c{i} = {i}"))
            .collect::<Vec<_>>()
            .join(" AND ");
        assert!(matches!(
            build_pg_where(&expr, 0, &nb(), &nt()),
            Err(FilterError::TooManyConditions { .. })
        ));
    }

    // ----- parse_filter AST -----

    #[test]
    fn parse_filter_exposes_the_typed_ast() {
        let filter = parse_filter("id = 1 AND name LIKE 'A%'", &nb()).unwrap();
        assert_eq!(filter.conditions.len(), 2);
        assert_eq!(filter.connectors, vec![Connector::And]);
        assert_eq!(
            filter.conditions[0],
            FilterCondition {
                column: "id".to_string(),
                op: Operator::Eq,
                value: SqlValue::Integer(1),
            }
        );
        assert_eq!(filter.conditions[1].op, Operator::Like);
        assert_eq!(
            filter.conditions[1].value,
            SqlValue::Text("A%".to_string())
        );
    }

    #[test]
    fn parse_filter_invariant_connectors_plus_one_equals_conditions() {
        for expr in ["a = 1", "a = 1 AND b = 2", "a = 1 OR b = 2 AND c = 3"] {
            let f = parse_filter(expr, &nb()).unwrap();
            assert_eq!(f.connectors.len() + 1, f.conditions.len());
        }
    }

    #[test]
    fn empty_filter_is_empty() {
        assert!(parse_filter("", &nb()).unwrap().is_empty());
        assert!(parse_filter(" ", &nb()).unwrap().is_empty());
        assert!(!parse_filter("a = 1", &nb()).unwrap().is_empty());
    }

    // ----- is_safe_identifier -----

    #[test]
    fn safe_identifiers_are_accepted() {
        for name in ["users", "user_id", "_private", "Table1", "a", "_"] {
            assert!(is_safe_identifier(name), "`{name}` should be safe");
        }
    }

    #[test]
    fn unsafe_identifiers_are_rejected() {
        for name in [
            "",
            "1abc", "user-name", "a b", "drop;table", "col\"injected", "naïve", ] {
            assert!(!is_safe_identifier(name), "`{name}` should be rejected");
        }
    }

    #[test]
    fn identifier_length_boundary() {
        assert!(is_safe_identifier(&"c".repeat(MAX_COLUMN_LEN)));
        assert!(!is_safe_identifier(&"c".repeat(MAX_COLUMN_LEN + 1)));
    }

    // ----- error display -----

    #[test]
    fn every_error_has_a_non_empty_display() {
        let samples = [
            err("a = 1 ;"),
            err("a = 'x"),
            err("a = 1.2.3"),
            err("1 = 1"),
            err("id"),
            err("id ="),
            err("a = b"),
            err("a = 1 b = 2"),
            err("a = 1 AND"),
            err("a > null"),
            err("a LIKE 5"),
        ];
        for e in samples {
            assert!(!e.to_string().is_empty());
        }
    }
}