use std::fmt;
/// Longest column / type identifier accepted, in bytes (the Postgres
/// identifier limit quoted in `FilterError::ColumnTooLong`'s message).
const MAX_COLUMN_LEN: usize = 63;
/// Most `column op value` conditions a single where-expression may hold.
const MAX_CONDITIONS: usize = 256;

/// Reports whether `name` is a plain ASCII identifier that is safe to splice
/// verbatim into generated SQL.
///
/// A safe identifier is 1..=`MAX_COLUMN_LEN` bytes long, starts with an ASCII
/// letter or `_`, and continues with ASCII letters, digits or `_`. Anything
/// else (spaces, punctuation, quotes, non-ASCII) is rejected.
pub fn is_safe_identifier(name: &str) -> bool {
    if name.is_empty() || name.len() > MAX_COLUMN_LEN {
        return false;
    }
    let mut bytes = name.bytes();
    // A leading digit is not allowed; later positions may be digits.
    let leads_with_letter_or_underscore =
        matches!(bytes.next(), Some(b) if b == b'_' || b.is_ascii_alphabetic());
    leads_with_letter_or_underscore && bytes.all(|b| b == b'_' || b.is_ascii_alphanumeric())
}
/// A literal value appearing on the right-hand side of a condition.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// String literal.
    Text(String),
    /// Integer literal.
    Integer(i64),
    /// Floating-point literal.
    Float(f64),
    /// `true` / `false`.
    Boolean(bool),
    /// `null`.
    Null,
}

impl SqlValue {
    /// Lowercase name of the value's kind, as shown in error messages.
    pub fn type_name(&self) -> &'static str {
        match *self {
            Self::Null => "null",
            Self::Boolean(_) => "boolean",
            Self::Float(_) => "float",
            Self::Integer(_) => "integer",
            Self::Text(_) => "text",
        }
    }
}
/// A comparison operator allowed between a column and a value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Operator {
    /// `=` (also accepted as `==`).
    Eq,
    /// `!=` (also accepted as `<>`).
    Ne,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `LIKE`
    Like,
}

impl Operator {
    /// Canonical SQL spelling used when rendering a condition.
    pub fn as_sql(self) -> &'static str {
        match self {
            Self::Eq => "=",
            Self::Ne => "!=",
            Self::Gt => ">",
            Self::Ge => ">=",
            Self::Lt => "<",
            Self::Le => "<=",
            Self::Like => "LIKE",
        }
    }

    /// Maps a symbol token (including the `==` and `<>` aliases) to an
    /// operator. `LIKE` is a keyword rather than a symbol and is not handled
    /// here.
    fn from_symbol(sym: &str) -> Option<Operator> {
        match sym {
            "=" | "==" => Some(Self::Eq),
            "!=" | "<>" => Some(Self::Ne),
            ">" => Some(Self::Gt),
            ">=" => Some(Self::Ge),
            "<" => Some(Self::Lt),
            "<=" => Some(Self::Le),
            _ => None,
        }
    }

    /// Whether `null` may be the right-hand side: only `=` and `!=`, which the
    /// renderer folds into `IS NULL` / `IS NOT NULL`.
    fn accepts_null(self) -> bool {
        self == Self::Eq || self == Self::Ne
    }
}

impl fmt::Display for Operator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(Operator::as_sql(*self))
    }
}
/// Boolean keyword joining two adjacent conditions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Connector {
    /// Rendered as `AND`.
    And,
    /// Rendered as `OR`.
    Or,
}

impl Connector {
    /// Uppercase SQL keyword for this connector.
    pub fn as_sql(self) -> &'static str {
        match self {
            Self::Or => "OR",
            Self::And => "AND",
        }
    }
}

impl fmt::Display for Connector {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(Connector::as_sql(*self))
    }
}
/// One `column op value` comparison.
#[derive(Debug, Clone, PartialEq)]
pub struct FilterCondition {
    /// Column name; when produced by `parse_filter` it is an ASCII word
    /// (letters, digits, `_`) of at most `MAX_COLUMN_LEN` bytes.
    pub column: String,
    /// Comparison operator.
    pub op: Operator,
    /// Right-hand-side literal.
    pub value: SqlValue,
}

/// A parsed where-expression: conditions in source order plus the connectors
/// between them.
///
/// `parse_filter` always yields `connectors.len() + 1 == conditions.len()` for
/// a non-empty filter (connector `k` joins conditions `k` and `k + 1`); a blank
/// expression yields both vectors empty.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Filter {
    /// Conditions in the order they were written.
    pub conditions: Vec<FilterCondition>,
    /// Connectors in the order they were written.
    pub connectors: Vec<Connector>,
}

impl Filter {
    /// `true` when the filter holds no conditions.
    pub fn is_empty(&self) -> bool {
        self.conditions.first().is_none()
    }
}
/// Every way tokenizing or parsing a where-expression can fail.
///
/// `Display` renders a single-line, user-facing explanation for each variant.
#[derive(Debug, Clone, PartialEq)]
pub enum FilterError {
    /// A character outside the grammar; `pos` is a char (not byte) index.
    UnexpectedChar { ch: char, pos: usize },
    /// A quote opened at char index `pos` was never closed.
    UnterminatedString { pos: usize },
    /// A numeric token that is neither an `i64` nor a finite `f64`.
    InvalidNumber { token: String },
    /// A condition did not begin with an identifier.
    ExpectedColumn { found: String },
    /// A column name longer than `MAX_COLUMN_LEN` bytes.
    ColumnTooLong { column: String, len: usize },
    /// The token after a column was not a comparison operator.
    ExpectedOperator { column: String, found: String },
    /// The expression ended right after a column name.
    MissingOperator { column: String },
    /// An operator symbol sat where a value was expected.
    ExpectedValue { found: String },
    /// The expression ended right after an operator.
    MissingValue { column: String },
    /// A bare word other than `true` / `false` / `null` in value position.
    UnquotedValue { token: String },
    /// Something other than `AND` / `OR` followed a complete condition.
    ExpectedConnector { found: String },
    /// The expression ended with a connector.
    DanglingConnector { connector: Connector },
    /// `null` paired with an operator other than `=` / `!=`.
    NullWithNonEqualityOp { column: String, op: Operator },
    /// `LIKE` with a non-text value; `found` is the value's type name.
    LikeRequiresText { column: String, found: &'static str },
    /// More than `limit` conditions.
    TooManyConditions { limit: usize },
}

impl fmt::Display for FilterError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // ---- lexical errors -------------------------------------------
            Self::UnexpectedChar { ch, pos } => write!(
                f,
                "unexpected character {ch:?} at position {pos} — the \
                 where-grammar alphabet is identifiers, comparison \
                 symbols, quoted strings and numbers"
            ),
            Self::UnterminatedString { pos } => write!(
                f,
                "unterminated string literal opened at position {pos}"
            ),
            Self::InvalidNumber { token } => write!(
                f,
                "invalid numeric literal `{token}` — expected an integer \
                 or a finite decimal"
            ),
            // ---- structural errors ----------------------------------------
            Self::ExpectedColumn { found } => write!(
                f,
                "expected a column name, found `{found}` — a condition \
                 is `column op value`"
            ),
            Self::ColumnTooLong { column, len } => write!(
                f,
                "column name `{column}` is {len} bytes — exceeds the \
                 Postgres {MAX_COLUMN_LEN}-byte identifier limit"
            ),
            Self::ExpectedOperator { column, found } => write!(
                f,
                "expected a comparison operator after column `{column}`, \
                 found `{found}`"
            ),
            Self::MissingOperator { column } => write!(
                f,
                "expected a comparison operator after column `{column}`, \
                 found end of expression"
            ),
            Self::ExpectedValue { found } => {
                write!(f, "expected a value, found `{found}`")
            }
            Self::MissingValue { column } => write!(
                f,
                "expected a value for column `{column}`, found end of \
                 expression"
            ),
            Self::UnquotedValue { token } => write!(
                f,
                "unquoted value `{token}` — string values must be quoted \
                 (`'{token}'`); only numbers and `true`/`false`/`null` \
                 are bare"
            ),
            Self::ExpectedConnector { found } => write!(
                f,
                "expected `AND` or `OR` between conditions, found `{found}`"
            ),
            Self::DanglingConnector { connector } => write!(
                f,
                "expression ends with a dangling `{connector}` — a \
                 connector must be followed by another condition"
            ),
            // ---- semantic errors ------------------------------------------
            Self::NullWithNonEqualityOp { column, op } => write!(
                f,
                "`null` compared with `{op}` on column `{column}` — \
                 `null` is only valid with `=` (renders `IS NULL`) or \
                 `!=` (renders `IS NOT NULL`)"
            ),
            Self::LikeRequiresText { column, found } => write!(
                f,
                "`LIKE` on column `{column}` requires a text value, \
                 found {found}"
            ),
            Self::TooManyConditions { limit } => write!(
                f,
                "where expression exceeds the {limit}-condition limit"
            ),
        }
    }
}

impl std::error::Error for FilterError {}
/// A lexical unit produced by `tokenize`.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// ASCII letter or `_` followed by ASCII letters, digits or `_`; covers
    /// column names and keywords (`AND`, `OR`, `LIKE`, `true`, ...).
    Word(String),
    /// Comparison symbol: `=`, `==`, `!=`, `<>`, `<`, `<=`, `>`, `>=`.
    Symbol(String),
    /// Body of a quoted string with quotes stripped and escapes resolved.
    Str(String),
    /// Raw numeric text (optional leading `-`, digits and `.`), unvalidated.
    Num(String),
}

/// Renders a token for an error message; string bodies are re-wrapped in
/// single quotes so they read as literals.
fn describe(tok: &Token) -> String {
    match tok {
        Token::Str(text) => format!("'{text}'"),
        Token::Word(text) | Token::Symbol(text) | Token::Num(text) => text.to_owned(),
    }
}
/// Splits a where-expression into [`Token`]s, skipping whitespace.
///
/// At each position the rules are tried in this order:
/// * `'…'` or `"…"` — a string closed by the same quote character. A
///   backslash makes the following character literal (`\'` yields `'`, `\\`
///   yields `\`, and e.g. `\n` yields plain `n`). The quotes are stripped.
/// * `=`, `!`, `<`, `>` — the two-character forms `==`, `!=`, `<=`, `>=`, `<>`
///   take priority over single characters; a lone `!` is rejected.
/// * a digit, or `-` immediately followed by a digit — a number: the longest
///   following run of ASCII digits and `.`. The text is not validated here
///   (`1.2.3` is a single token; `parse_number` rejects it later).
/// * an ASCII letter or `_` — a word: the longest run of ASCII letters,
///   digits and `_`.
/// * anything else is an error.
///
/// # Errors
/// * [`FilterError::UnterminatedString`] when a quote is never closed,
///   including a literal that ends in a lone backslash.
/// * [`FilterError::UnexpectedChar`] for a character no rule accepts.
///
/// Reported `pos` values index the expression's `char`s, not its bytes.
fn tokenize(expr: &str) -> Result<Vec<Token>, FilterError> {
    let chars: Vec<char> = expr.chars().collect();
    let n = chars.len();
    let mut tokens: Vec<Token> = Vec::new();
    let mut i = 0;
    while i < n {
        let c = chars[i];
        if c.is_whitespace() {
            i += 1;
            continue;
        }
        // Quoted string literal, single- or double-quoted.
        if c == '\'' || c == '"' {
            let quote = c;
            let mut buf = String::new();
            let mut j = i + 1;
            let mut closed = false;
            while j < n {
                let cj = chars[j];
                if cj == '\\' {
                    // Escape: keep the next char verbatim, whatever it is.
                    if j + 1 < n {
                        buf.push(chars[j + 1]);
                        j += 2;
                        continue;
                    }
                    // Trailing backslash: `closed` stays false, so the
                    // literal is reported as unterminated.
                    break;
                }
                if cj == quote {
                    closed = true;
                    j += 1;
                    break;
                }
                buf.push(cj);
                j += 1;
            }
            if !closed {
                return Err(FilterError::UnterminatedString { pos: i });
            }
            tokens.push(Token::Str(buf));
            i = j;
            continue;
        }
        // Comparison symbols; the two-character forms win.
        if c == '=' || c == '!' || c == '<' || c == '>' {
            if i + 1 < n {
                let two = match (c, chars[i + 1]) {
                    ('=', '=') => Some("=="),
                    ('!', '=') => Some("!="),
                    ('<', '=') => Some("<="),
                    ('>', '=') => Some(">="),
                    ('<', '>') => Some("<>"),
                    _ => None,
                };
                if let Some(sym) = two {
                    tokens.push(Token::Symbol(sym.to_string()));
                    i += 2;
                    continue;
                }
            }
            // `!` is only meaningful as the first half of `!=`.
            if c == '!' {
                return Err(FilterError::UnexpectedChar { ch: '!', pos: i });
            }
            tokens.push(Token::Symbol(c.to_string()));
            i += 1;
            continue;
        }
        // Numbers, with an optional leading minus sign glued to a digit.
        if c.is_ascii_digit()
            || (c == '-' && i + 1 < n && chars[i + 1].is_ascii_digit())
        {
            let start = i;
            let mut j = if c == '-' { i + 1 } else { i };
            while j < n && (chars[j].is_ascii_digit() || chars[j] == '.') {
                j += 1;
            }
            // Well-formedness (e.g. a second `.`) is checked by parse_number.
            tokens.push(Token::Num(chars[start..j].iter().collect()));
            i = j;
            continue;
        }
        // Words: column names, and the keywords (AND, OR, LIKE, true, false,
        // null) that the parser interprets later.
        if c.is_ascii_alphabetic() || c == '_' {
            let start = i;
            let mut j = i;
            while j < n && (chars[j].is_ascii_alphanumeric() || chars[j] == '_')
            {
                j += 1;
            }
            tokens.push(Token::Word(chars[start..j].iter().collect()));
            i = j;
            continue;
        }
        // Everything else is outside the grammar. This includes `;`, `&`,
        // parentheses, a `-` not followed by a digit (so `--` comments) and
        // non-ASCII letters.
        return Err(FilterError::UnexpectedChar { ch: c, pos: i });
    }
    Ok(tokens)
}
/// Converts a numeric token into an integer or float literal.
///
/// The token is first tried as `i64`. Anything that does not fit, or has a
/// fractional part, is retried as `f64`, and only finite results are kept.
///
/// # Errors
/// [`FilterError::InvalidNumber`] when `raw` is neither form (e.g. `1.2.3`).
fn parse_number(raw: &str) -> Result<SqlValue, FilterError> {
    raw.parse::<i64>().map(SqlValue::Integer).or_else(|_| {
        raw.parse::<f64>()
            .ok()
            .filter(|x| x.is_finite())
            .map(SqlValue::Float)
            .ok_or_else(|| FilterError::InvalidNumber {
                token: raw.to_string(),
            })
    })
}
/// Interprets the token in value position of a condition.
///
/// Quoted strings become [`SqlValue::Text`] and numbers go through
/// `parse_number`. The bare words `true`, `false` and `null` are matched
/// ASCII case-insensitively; any other bare word is rejected so that string
/// values must always be quoted.
///
/// # Errors
/// [`FilterError::UnquotedValue`] for other bare words,
/// [`FilterError::ExpectedValue`] for operator symbols, and any
/// `parse_number` error for malformed numbers.
fn parse_value(tok: &Token) -> Result<SqlValue, FilterError> {
    match tok {
        Token::Str(text) => Ok(SqlValue::Text(text.clone())),
        Token::Num(raw) => parse_number(raw),
        Token::Word(w) if w.eq_ignore_ascii_case("true") => Ok(SqlValue::Boolean(true)),
        Token::Word(w) if w.eq_ignore_ascii_case("false") => Ok(SqlValue::Boolean(false)),
        Token::Word(w) if w.eq_ignore_ascii_case("null") => Ok(SqlValue::Null),
        Token::Word(w) => Err(FilterError::UnquotedValue { token: w.clone() }),
        Token::Symbol(sym) => Err(FilterError::ExpectedValue { found: sym.clone() }),
    }
}
pub fn parse_filter(
expr: &str,
bindings: &std::collections::HashMap<String, String>,
) -> Result<Filter, FilterError> {
let raw_tokens = tokenize(expr)?;
let tokens: Vec<Token> = raw_tokens
.into_iter()
.map(|t| match t {
Token::Str(s) => Token::Str(
crate::exec_context::interpolate_vars(&s, bindings),
),
other => other,
})
.collect();
let mut filter = Filter::default();
let mut i = 0;
let n = tokens.len();
while i < n {
let column = match &tokens[i] {
Token::Word(w) => w.clone(),
other => {
return Err(FilterError::ExpectedColumn {
found: describe(other),
})
}
};
if column.len() > MAX_COLUMN_LEN {
return Err(FilterError::ColumnTooLong {
len: column.len(),
column,
});
}
i += 1;
if i >= n {
return Err(FilterError::MissingOperator { column });
}
let op = match &tokens[i] {
Token::Symbol(sym) => Operator::from_symbol(sym).ok_or_else(|| {
FilterError::ExpectedOperator {
column: column.clone(),
found: sym.clone(),
}
})?,
Token::Word(w) if w.eq_ignore_ascii_case("like") => Operator::Like,
other => {
return Err(FilterError::ExpectedOperator {
column,
found: describe(other),
})
}
};
i += 1;
if i >= n {
return Err(FilterError::MissingValue { column });
}
let value = parse_value(&tokens[i])?;
i += 1;
if matches!(value, SqlValue::Null) && !op.accepts_null() {
return Err(FilterError::NullWithNonEqualityOp { column, op });
}
if op == Operator::Like && !matches!(value, SqlValue::Text(_)) {
return Err(FilterError::LikeRequiresText {
column,
found: value.type_name(),
});
}
filter.conditions.push(FilterCondition { column, op, value });
if filter.conditions.len() > MAX_CONDITIONS {
return Err(FilterError::TooManyConditions {
limit: MAX_CONDITIONS,
});
}
if i < n {
let connector = match &tokens[i] {
Token::Word(w) if w.eq_ignore_ascii_case("and") => Connector::And,
Token::Word(w) if w.eq_ignore_ascii_case("or") => Connector::Or,
other => {
return Err(FilterError::ExpectedConnector {
found: describe(other),
})
}
};
i += 1;
filter.connectors.push(connector);
if i >= n {
return Err(FilterError::DanglingConnector { connector });
}
}
}
Ok(filter)
}
/// Compiles a where-expression into a Postgres boolean SQL fragment and its
/// bind parameters.
///
/// * `expr` — the user expression (grammar documented on `parse_filter`). A
///   blank expression, or one with no conditions, renders as `TRUE` with no
///   parameters.
/// * `param_offset` — placeholders already used by the surrounding query. The
///   first placeholder emitted is `$param_offset + 1`.
/// * `bindings` — variables interpolated into quoted string literals.
/// * `column_types` — column name → introspected type name. A type name that
///   passes [`is_safe_identifier`] is spliced as a `::type` cast; any other
///   entry is ignored and the column is treated as untyped.
///
/// Each condition renders as one of:
/// * `"col" IS NULL` / `"col" IS NOT NULL` for `= null` / `!= null`. These
///   consume no placeholder.
/// * `"col" op $n::type` when the column's type is known and safe.
/// * `"col"::text op $n` for `=` / `!=` on an untyped column.
/// * `"col" op $n` for ordering operators and `LIKE` on an untyped column.
///
/// Column names are spliced inside double quotes. This is safe because
/// `parse_filter` only yields ASCII letters, digits and `_`.
///
/// Conditions are joined by their connectors in source order, and the fragment
/// is NOT parenthesised. SQL's `AND`-over-`OR` precedence therefore applies
/// inside it, and callers that combine it with other predicates should wrap it
/// in parentheses.
///
/// NOTE(review): untyped `=` / `!=` render `"col"::text = $n` even when the
/// bound value is an Integer, Float or Boolean. This assumes the executor binds
/// parameters in a form Postgres accepts for a text placeholder — confirm.
///
/// # Errors
/// Any [`FilterError`] produced while parsing `expr`.
pub fn build_pg_where(
    expr: &str,
    param_offset: usize,
    bindings: &std::collections::HashMap<String, String>,
    column_types: &std::collections::HashMap<String, String>,
) -> Result<(String, Vec<SqlValue>), FilterError> {
    // `write!` appends straight into `clause`. This avoids the throwaway
    // String that `push_str(&format!(..))` allocated for every condition.
    use std::fmt::Write as _;

    if expr.trim().is_empty() {
        return Ok(("TRUE".to_string(), Vec::new()));
    }
    let filter = parse_filter(expr, bindings)?;
    if filter.is_empty() {
        return Ok(("TRUE".to_string(), Vec::new()));
    }

    let mut clause = String::new();
    // Each condition binds at most one parameter; null folds bind none.
    let mut params: Vec<SqlValue> = Vec::with_capacity(filter.conditions.len());
    let mut idx = param_offset + 1;
    for (i, cond) in filter.conditions.iter().enumerate() {
        if i > 0 {
            // parse_filter guarantees connectors.len() + 1 == conditions.len(),
            // so connector i - 1 joins condition i to its predecessor.
            clause.push(' ');
            clause.push_str(filter.connectors[i - 1].as_sql());
            clause.push(' ');
        }
        match &cond.value {
            SqlValue::Null => {
                let tail = match cond.op {
                    Operator::Ne => "IS NOT NULL",
                    _ => "IS NULL",
                };
                // Writing into a String cannot fail.
                let _ = write!(clause, "\"{}\" {tail}", cond.column);
            }
            bound => {
                // Only a plain identifier is ever spliced as a type name.
                let known_udt = column_types
                    .get(&cond.column)
                    .map(String::as_str)
                    .filter(|udt| is_safe_identifier(udt));
                let (column_cast, value_cast) = match (known_udt, cond.op) {
                    // Known type: cast the placeholder to the column's type.
                    (Some(udt), _) => ("", udt),
                    // Untyped (in)equality: compare the column as text.
                    (None, Operator::Eq | Operator::Ne) => ("::text", ""),
                    // Untyped ordering / LIKE: bare placeholder.
                    (None, _) => ("", ""),
                };
                let _ = write!(
                    clause,
                    "\"{}\"{column_cast} {} ${idx}",
                    cond.column,
                    cond.op.as_sql()
                );
                // A safe identifier is never empty, so "" means "no cast".
                if !value_cast.is_empty() {
                    let _ = write!(clause, "::{value_cast}");
                }
                params.push(bound.clone());
                idx += 1;
            }
        }
    }
    Ok((clause, params))
}
#[cfg(test)]
mod tests {
    // Unit tests for the where-expression compiler. `ok` / `err` compile with
    // no variable bindings, no column types and a parameter offset of 0.
    use super::*;

    /// Empty variable bindings.
    fn nb() -> std::collections::HashMap<String, String> {
        std::collections::HashMap::new()
    }

    /// Empty column-type map, so every column's type is unknown.
    fn nt() -> std::collections::HashMap<String, String> {
        std::collections::HashMap::new()
    }

    fn ok(expr: &str) -> (String, Vec<SqlValue>) {
        build_pg_where(expr, 0, &nb(), &nt()).expect("expected the filter to compile")
    }

    fn err(expr: &str) -> FilterError {
        build_pg_where(expr, 0, &nb(), &nt()).expect_err("expected a compile error")
    }

    /// `c0 = 0 AND c1 = 1 AND …` with `count` conditions.
    fn and_chain(count: usize) -> String {
        (0..count)
            .map(|i| format!("c{i} = {i}"))
            .collect::<Vec<_>>()
            .join(" AND ")
    }

    #[test]
    fn empty_expression_renders_true() {
        assert_eq!(ok(""), ("TRUE".to_string(), vec![]));
    }

    #[test]
    fn whitespace_only_renders_true() {
        assert_eq!(ok(" \t \n "), ("TRUE".to_string(), vec![]));
    }

    #[test]
    fn single_integer_condition() {
        let (clause, params) = ok("id = 1");
        assert_eq!(clause, "\"id\"::text = $1");
        assert_eq!(params, vec![SqlValue::Integer(1)]);
    }

    #[test]
    fn single_string_condition_single_quoted() {
        let (clause, params) = ok("name = 'Alice'");
        assert_eq!(clause, "\"name\"::text = $1");
        assert_eq!(params, vec![SqlValue::Text("Alice".to_string())]);
    }

    #[test]
    fn single_string_condition_double_quoted() {
        let (_, params) = ok("name = \"Bob\"");
        assert_eq!(params, vec![SqlValue::Text("Bob".to_string())]);
    }

    #[test]
    fn negative_integer_value() {
        let (clause, params) = ok("balance >= -100");
        assert_eq!(clause, "\"balance\" >= $1");
        assert_eq!(params, vec![SqlValue::Integer(-100)]);
    }

    #[test]
    fn float_value() {
        // 2.5 rather than 3.14: clippy's `approx_constant` lint rejects float
        // literals that approximate PI.
        let (_, params) = ok("score > 2.5");
        assert_eq!(params, vec![SqlValue::Float(2.5)]);
    }

    #[test]
    fn boolean_values() {
        assert_eq!(ok("active = true").1, vec![SqlValue::Boolean(true)]);
        assert_eq!(ok("active = false").1, vec![SqlValue::Boolean(false)]);
    }

    #[test]
    fn integer_overflowing_i64_falls_back_to_float() {
        let (_, params) = ok("n = 10000000000000000000000000");
        assert!(matches!(params[0], SqlValue::Float(_)));
    }

    #[test]
    fn every_operator_renders_canonically() {
        assert_eq!(ok("a = 1").0, "\"a\"::text = $1");
        assert_eq!(ok("a != 1").0, "\"a\"::text != $1");
        assert_eq!(ok("a > 1").0, "\"a\" > $1");
        assert_eq!(ok("a >= 1").0, "\"a\" >= $1");
        assert_eq!(ok("a < 1").0, "\"a\" < $1");
        assert_eq!(ok("a <= 1").0, "\"a\" <= $1");
        assert_eq!(ok("a LIKE 'x%'").0, "\"a\" LIKE $1");
    }

    #[test]
    fn operator_aliases_normalize() {
        assert_eq!(ok("a == 1").0, "\"a\"::text = $1");
        assert_eq!(ok("a <> 1").0, "\"a\"::text != $1");
    }

    #[test]
    fn like_is_case_insensitive_and_renders_uppercase() {
        assert_eq!(ok("a like 'x%'").0, "\"a\" LIKE $1");
        assert_eq!(ok("a LiKe 'x%'").0, "\"a\" LIKE $1");
    }

    #[test]
    fn typed_column_comparison_casts_the_value_to_the_column_type() {
        let b = std::collections::HashMap::from([(
            "tid".to_string(),
            "83d078e1-b372-42ba-9572-ff8dc521386e".to_string(),
        )]);
        let types = std::collections::HashMap::from([
            ("tid".to_string(), "uuid".to_string()),
            ("age".to_string(), "int4".to_string()),
        ]);
        let (clause, params) =
            build_pg_where("tid = '${tid}'", 0, &b, &types).expect("compiles");
        assert_eq!(
            clause, "\"tid\" = $1::uuid",
            "the value is cast to the column's introspected type"
        );
        assert_eq!(
            params,
            vec![SqlValue::Text(
                "83d078e1-b372-42ba-9572-ff8dc521386e".to_string()
            )]
        );
        let (clause, _) =
            build_pg_where("age >= 18", 0, &nb(), &types).expect("compiles");
        assert_eq!(clause, "\"age\" >= $1::int4");
    }

    #[test]
    fn d4_unknown_type_equality_casts_the_column_to_text() {
        assert_eq!(ok("id == 'x'").0, "\"id\"::text = $1");
        assert_eq!(ok("id != 'x'").0, "\"id\"::text != $1");
        assert_eq!(ok("id = 1").0, "\"id\"::text = $1");
    }

    #[test]
    fn d4_unknown_type_ordering_stays_a_bare_placeholder() {
        assert_eq!(ok("age > 18").0, "\"age\" > $1");
        assert_eq!(ok("age >= 18").0, "\"age\" >= $1");
        assert_eq!(ok("age < 18").0, "\"age\" < $1");
        assert_eq!(ok("age <= 18").0, "\"age\" <= $1");
    }

    #[test]
    fn d4_unknown_type_like_stays_a_bare_placeholder() {
        assert_eq!(ok("name LIKE 'a%'").0, "\"name\" LIKE $1");
    }

    #[test]
    fn d4_a_known_type_keeps_the_v1_36_4_value_cast() {
        let types = std::collections::HashMap::from([
            ("id".to_string(), "uuid".to_string()),
            ("n".to_string(), "int4".to_string()),
        ]);
        assert_eq!(
            build_pg_where("id == 'x'", 0, &nb(), &types).unwrap().0,
            "\"id\" = $1::uuid"
        );
        assert_eq!(
            build_pg_where("n > 5", 0, &nb(), &types).unwrap().0,
            "\"n\" > $1::int4"
        );
    }

    #[test]
    fn d4_an_unsafe_udt_is_not_spliced_and_equality_still_works() {
        let types = std::collections::HashMap::from([(
            "id".to_string(),
            "int4; DROP TABLE x".to_string(),
        )]);
        let (clause, _) =
            build_pg_where("id = 1", 0, &nb(), &types).expect("compiles");
        assert_eq!(
            clause, "\"id\"::text = $1",
            "the unsafe udt is not spliced; equality falls back to ::text"
        );
        let (clause, _) =
            build_pg_where("id > 1", 0, &nb(), &types).expect("compiles");
        assert_eq!(clause, "\"id\" > $1");
    }

    #[test]
    fn typed_column_null_fold_is_not_cast() {
        // The column's type IS known here: the null fold must still render
        // without a `::type` cast and without consuming a bind parameter.
        let types =
            std::collections::HashMap::from([("id".to_string(), "uuid".to_string())]);
        let (clause, params) =
            build_pg_where("id = null", 0, &nb(), &types).expect("compiles");
        assert_eq!(clause, "\"id\" IS NULL");
        assert!(params.is_empty());
        let (clause, _) =
            build_pg_where("id != null", 0, &nb(), &types).expect("compiles");
        assert_eq!(clause, "\"id\" IS NOT NULL");
    }

    #[test]
    fn two_conditions_joined_by_and() {
        let (clause, params) = ok("id = 1 AND name = 'Alice'");
        assert_eq!(clause, "\"id\"::text = $1 AND \"name\"::text = $2");
        assert_eq!(
            params,
            vec![SqlValue::Integer(1), SqlValue::Text("Alice".to_string())]
        );
    }

    #[test]
    fn two_conditions_joined_by_or() {
        assert_eq!(
            ok("a = 1 OR b = 2").0,
            "\"a\"::text = $1 OR \"b\"::text = $2"
        );
    }

    #[test]
    fn connectors_are_case_insensitive() {
        assert_eq!(
            ok("a = 1 and b = 2").0,
            "\"a\"::text = $1 AND \"b\"::text = $2"
        );
        assert_eq!(
            ok("a = 1 Or b = 2").0,
            "\"a\"::text = $1 OR \"b\"::text = $2"
        );
    }

    #[test]
    fn three_condition_mixed_chain_preserves_order() {
        let (clause, params) = ok("a = 1 AND b = 2 OR c = 3");
        assert_eq!(
            clause,
            "\"a\"::text = $1 AND \"b\"::text = $2 OR \"c\"::text = $3"
        );
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn null_equality_folds_to_is_null() {
        let (clause, params) = ok("deleted_at = null");
        assert_eq!(clause, "\"deleted_at\" IS NULL");
        assert!(params.is_empty(), "IS NULL consumes no bind parameter");
    }

    #[test]
    fn null_inequality_folds_to_is_not_null() {
        let (clause, params) = ok("deleted_at != NULL");
        assert_eq!(clause, "\"deleted_at\" IS NOT NULL");
        assert!(params.is_empty());
    }

    #[test]
    fn null_does_not_occupy_a_parameter_slot() {
        let (clause, params) = ok("a = null AND b = 5");
        assert_eq!(clause, "\"a\" IS NULL AND \"b\"::text = $1");
        assert_eq!(params, vec![SqlValue::Integer(5)]);
    }

    #[test]
    fn rendered_params_never_contain_null() {
        let (_, params) = ok("a = null AND b = 1 OR c != null");
        assert!(params.iter().all(|v| !matches!(v, SqlValue::Null)));
    }

    #[test]
    fn param_offset_shifts_placeholder_numbering() {
        let (clause, _) = build_pg_where("id = 1", 2, &nb(), &nt()).unwrap();
        assert_eq!(clause, "\"id\"::text = $3");
    }

    #[test]
    fn param_offset_shifts_every_placeholder() {
        let (clause, _) =
            build_pg_where("a = 1 AND b = 2", 5, &nb(), &nt()).unwrap();
        assert_eq!(clause, "\"a\"::text = $6 AND \"b\"::text = $7");
    }

    #[test]
    fn injection_payload_inside_a_quoted_string_is_an_inert_bind_param() {
        let (clause, params) = ok("name = '; DROP TABLE users; --'");
        assert_eq!(clause, "\"name\"::text = $1");
        assert_eq!(
            params,
            vec![SqlValue::Text("; DROP TABLE users; --".to_string())]
        );
    }

    #[test]
    fn injection_via_statement_terminator_is_rejected_at_tokenize() {
        assert!(matches!(
            err("name = 'x'; DROP TABLE users"),
            FilterError::UnexpectedChar { ch: ';', .. }
        ));
    }

    #[test]
    fn injection_via_comment_marker_is_rejected() {
        assert!(matches!(
            err("a = 1 -- comment"),
            FilterError::UnexpectedChar { ch: '-', .. }
        ));
    }

    #[test]
    fn injection_via_bare_or_tautology_is_rejected() {
        assert!(matches!(
            err("name = 'x' OR 1 = 1"),
            FilterError::ExpectedColumn { .. }
        ));
    }

    #[test]
    fn escaped_quote_inside_string_is_resolved() {
        let (_, params) = ok("name = 'O\\'Brien'");
        assert_eq!(params, vec![SqlValue::Text("O'Brien".to_string())]);
    }

    #[test]
    fn escaped_backslash_is_resolved() {
        let (_, params) = ok("path = 'a\\\\b'");
        assert_eq!(params, vec![SqlValue::Text("a\\b".to_string())]);
    }

    #[test]
    fn unterminated_string_errors() {
        assert!(matches!(
            err("name = 'unclosed"),
            FilterError::UnterminatedString { .. }
        ));
    }

    #[test]
    fn unexpected_character_errors() {
        assert!(matches!(
            err("a = 1 & b = 2"),
            FilterError::UnexpectedChar { ch: '&', .. }
        ));
    }

    #[test]
    fn bare_bang_is_rejected() {
        assert!(matches!(
            err("a ! 1"),
            FilterError::UnexpectedChar { ch: '!', .. }
        ));
    }

    #[test]
    fn invalid_number_errors() {
        assert!(matches!(
            err("a = 1.2.3"),
            FilterError::InvalidNumber { .. }
        ));
    }

    #[test]
    fn missing_operator_errors() {
        assert!(matches!(err("id"), FilterError::MissingOperator { .. }));
    }

    #[test]
    fn missing_value_errors() {
        assert!(matches!(err("id ="), FilterError::MissingValue { .. }));
    }

    #[test]
    fn unquoted_string_value_errors() {
        assert!(matches!(
            err("status = active"),
            FilterError::UnquotedValue { .. }
        ));
    }

    #[test]
    fn dangling_connector_errors() {
        assert!(matches!(
            err("id = 1 AND"),
            FilterError::DanglingConnector {
                connector: Connector::And
            }
        ));
    }

    #[test]
    fn two_conditions_without_connector_errors() {
        assert!(matches!(
            err("a = 1 b = 2"),
            FilterError::ExpectedConnector { .. }
        ));
    }

    #[test]
    fn column_position_non_identifier_errors() {
        assert!(matches!(err("1 = 1"), FilterError::ExpectedColumn { .. }));
    }

    #[test]
    fn operator_position_non_operator_errors() {
        assert!(matches!(
            err("a b c"),
            FilterError::ExpectedOperator { .. }
        ));
    }

    #[test]
    fn null_with_ordering_operator_errors() {
        assert!(matches!(
            err("score > null"),
            FilterError::NullWithNonEqualityOp { op: Operator::Gt, .. }
        ));
    }

    #[test]
    fn null_with_like_errors() {
        assert!(matches!(
            err("name LIKE null"),
            FilterError::NullWithNonEqualityOp { op: Operator::Like, .. }
        ));
    }

    #[test]
    fn like_with_non_text_value_errors() {
        assert!(matches!(
            err("age LIKE 5"),
            FilterError::LikeRequiresText { found: "integer", .. }
        ));
    }

    #[test]
    fn column_at_the_length_limit_compiles() {
        let col = "c".repeat(MAX_COLUMN_LEN);
        assert!(build_pg_where(&format!("{col} = 1"), 0, &nb(), &nt()).is_ok());
    }

    #[test]
    fn column_over_the_length_limit_errors() {
        let col = "c".repeat(MAX_COLUMN_LEN + 1);
        assert!(matches!(
            build_pg_where(&format!("{col} = 1"), 0, &nb(), &nt()),
            Err(FilterError::ColumnTooLong { .. })
        ));
    }

    #[test]
    fn condition_count_at_the_limit_compiles() {
        let expr = and_chain(MAX_CONDITIONS);
        let (_, params) = build_pg_where(&expr, 0, &nb(), &nt()).unwrap();
        assert_eq!(params.len(), MAX_CONDITIONS);
    }

    #[test]
    fn condition_count_over_the_limit_errors() {
        let expr = and_chain(MAX_CONDITIONS + 1);
        assert!(matches!(
            build_pg_where(&expr, 0, &nb(), &nt()),
            Err(FilterError::TooManyConditions { .. })
        ));
    }

    #[test]
    fn parse_filter_exposes_the_typed_ast() {
        let filter = parse_filter("id = 1 AND name LIKE 'A%'", &nb()).unwrap();
        assert_eq!(filter.conditions.len(), 2);
        assert_eq!(filter.connectors, vec![Connector::And]);
        assert_eq!(
            filter.conditions[0],
            FilterCondition {
                column: "id".to_string(),
                op: Operator::Eq,
                value: SqlValue::Integer(1),
            }
        );
        assert_eq!(filter.conditions[1].op, Operator::Like);
        assert_eq!(
            filter.conditions[1].value,
            SqlValue::Text("A%".to_string())
        );
    }

    #[test]
    fn parse_filter_invariant_connectors_plus_one_equals_conditions() {
        for expr in ["a = 1", "a = 1 AND b = 2", "a = 1 OR b = 2 AND c = 3"] {
            let f = parse_filter(expr, &nb()).unwrap();
            assert_eq!(f.connectors.len() + 1, f.conditions.len());
        }
    }

    #[test]
    fn empty_filter_is_empty() {
        assert!(parse_filter("", &nb()).unwrap().is_empty());
        assert!(parse_filter("   ", &nb()).unwrap().is_empty());
        assert!(!parse_filter("a = 1", &nb()).unwrap().is_empty());
    }

    #[test]
    fn safe_identifiers_are_accepted() {
        for name in ["users", "user_id", "_private", "Table1", "a", "_"] {
            assert!(is_safe_identifier(name), "`{name}` should be safe");
        }
    }

    #[test]
    fn unsafe_identifiers_are_rejected() {
        for name in [
            "",
            "1abc",
            "user-name",
            "a b",
            "drop;table",
            "col\"injected",
            "naïve",
        ] {
            assert!(!is_safe_identifier(name), "`{name}` should be rejected");
        }
    }

    #[test]
    fn identifier_length_boundary() {
        assert!(is_safe_identifier(&"c".repeat(MAX_COLUMN_LEN)));
        assert!(!is_safe_identifier(&"c".repeat(MAX_COLUMN_LEN + 1)));
    }

    #[test]
    fn every_error_has_a_non_empty_display() {
        let too_long = format!("{} = 1", "c".repeat(MAX_COLUMN_LEN + 1));
        let too_many = and_chain(MAX_CONDITIONS + 1);
        // One sample per FilterError variant.
        let samples = [
            err("a = 1 ;"),     // UnexpectedChar
            err("a = 'x"),      // UnterminatedString
            err("a = 1.2.3"),   // InvalidNumber
            err("1 = 1"),       // ExpectedColumn
            err(&too_long),     // ColumnTooLong
            err("a b c"),       // ExpectedOperator
            err("id"),          // MissingOperator
            err("a = ="),       // ExpectedValue
            err("id ="),        // MissingValue
            err("a = b"),       // UnquotedValue
            err("a = 1 b = 2"), // ExpectedConnector
            err("a = 1 AND"),   // DanglingConnector
            err("a > null"),    // NullWithNonEqualityOp
            err("a LIKE 5"),    // LikeRequiresText
            err(&too_many),     // TooManyConditions
        ];
        let distinct: std::collections::HashSet<_> =
            samples.iter().map(std::mem::discriminant).collect();
        assert_eq!(
            distinct.len(),
            samples.len(),
            "each sample must exercise a different FilterError variant"
        );
        for e in samples {
            assert!(!e.to_string().is_empty());
        }
    }
}