use crate::linter::rule::{LintContext, LintRule};
use crate::types::{issue_codes, Issue};
use sqlparser::ast::{BinaryOperator, Expr, Merge, Statement, Update};
use super::semantic_helpers::{visit_select_expressions, visit_selects_in_statement};
/// Lint rule ST010: reports comparisons whose outcome does not depend on the data.
///
/// Examples are `col = col` or `'a' != 'b'`.
pub struct StructureConstantExpression;

impl LintRule for StructureConstantExpression {
    fn code(&self) -> &'static str {
        issue_codes::LINT_ST_010
    }

    fn name(&self) -> &'static str {
        "Structure constant expression"
    }

    fn description(&self) -> &'static str {
        "Redundant constant expression."
    }

    /// Emits one ST010 warning for every constant comparison found in `statement`.
    ///
    /// Candidates come from two sources. The first is the statement-level predicate
    /// (UPDATE/DELETE `WHERE`, MERGE `ON`). The second is every expression that
    /// `visit_select_expressions` reports for each SELECT yielded by `visit_selects_in_statement`.
    fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
        let mut total = statement_constant_predicate_count(statement);
        visit_selects_in_statement(statement, &mut |select| {
            visit_select_expressions(select, &mut |expr| total += constant_predicate_count(expr));
        });

        std::iter::repeat_with(|| {
            Issue::warning(
                issue_codes::LINT_ST_010,
                "Constant boolean expression detected in predicate.",
            )
            .with_statement(ctx.statement_index)
        })
        .take(total)
        .collect()
    }
}
/// Counts constant comparisons in the statement-level predicate.
///
/// The predicate is the `WHERE` of an UPDATE or DELETE, or the `ON` of a MERGE. Any other
/// statement kind contributes 0 here; SELECT expressions are counted separately in `check`.
fn statement_constant_predicate_count(statement: &Statement) -> usize {
    let filter = match statement {
        Statement::Merge(Merge { on, .. }) => return constant_predicate_count(on),
        Statement::Update(Update { selection, .. }) => selection.as_ref(),
        Statement::Delete(delete) => delete.selection.as_ref(),
        _ => None,
    };
    filter.map_or(0, constant_predicate_count)
}
fn constant_predicate_count(expr: &Expr) -> usize {
match expr {
Expr::BinaryOp { left, op, right } => {
let direct_match = is_supported_expression_comparison_operator(op)
&& !contains_comparison_operator_token(left)
&& !contains_comparison_operator_token(right)
&& match (literal_key(left), literal_key(right)) {
(Some(left_literal), Some(right_literal)) => {
is_supported_literal_comparison_operator(op)
&& !is_allowed_literal_comparison(op, &left_literal, &right_literal)
}
_ => expressions_equivalent_for_constant_check(left, right),
};
usize::from(direct_match)
+ constant_predicate_count(left)
+ constant_predicate_count(right)
}
Expr::UnaryOp { expr: inner, .. }
| Expr::Nested(inner)
| Expr::IsNull(inner)
| Expr::IsNotNull(inner)
| Expr::Cast { expr: inner, .. } => constant_predicate_count(inner),
Expr::InList { expr, list, .. } => {
constant_predicate_count(expr)
+ list.iter().map(constant_predicate_count).sum::<usize>()
}
Expr::Between {
expr, low, high, ..
} => {
constant_predicate_count(expr)
+ constant_predicate_count(low)
+ constant_predicate_count(high)
}
Expr::Case {
operand,
conditions,
else_result,
..
} => {
let operand_count = operand
.as_ref()
.map_or(0, |expr| constant_predicate_count(expr));
let condition_count = conditions
.iter()
.map(|when| {
constant_predicate_count(&when.condition)
+ constant_predicate_count(&when.result)
})
.sum::<usize>();
let else_count = else_result
.as_ref()
.map_or(0, |expr| constant_predicate_count(expr));
operand_count + condition_count + else_count
}
_ => 0,
}
}
fn is_supported_expression_comparison_operator(op: &BinaryOperator) -> bool {
matches!(
op,
BinaryOperator::Eq
| BinaryOperator::NotEq
| BinaryOperator::Lt
| BinaryOperator::Gt
| BinaryOperator::LtEq
| BinaryOperator::GtEq
)
}
/// Returns `true` only for `=` and `<>`/`!=`, the operators checked for literal-vs-literal tests.
fn is_supported_literal_comparison_operator(op: &BinaryOperator) -> bool {
    *op == BinaryOperator::Eq || *op == BinaryOperator::NotEq
}
/// Reports whether `expr` contains a comparison operator anywhere in its sub-expressions.
///
/// The operators looked for are `=`, `<>`, `<`, `>`, `<=` and `>=`. The search descends
/// through the same expression kinds as `constant_predicate_count`, plus the operands of
/// `ANY`/`ALL`. The `ANY`/`ALL` operator itself is not counted as a match.
fn contains_comparison_operator_token(expr: &Expr) -> bool {
    match expr {
        Expr::BinaryOp { left, op, right } => {
            is_supported_expression_comparison_operator(op)
                || [left, right]
                    .into_iter()
                    .any(|side| contains_comparison_operator_token(side))
        }
        Expr::AnyOp { left, right, .. } | Expr::AllOp { left, right, .. } => [left, right]
            .into_iter()
            .any(|side| contains_comparison_operator_token(side)),
        Expr::UnaryOp { expr: inner, .. }
        | Expr::Nested(inner)
        | Expr::IsNull(inner)
        | Expr::IsNotNull(inner)
        | Expr::Cast { expr: inner, .. } => contains_comparison_operator_token(inner),
        Expr::InList {
            expr: target, list, ..
        } => {
            contains_comparison_operator_token(target)
                || list.iter().any(contains_comparison_operator_token)
        }
        Expr::Between {
            expr: target,
            low,
            high,
            ..
        } => [target, low, high]
            .into_iter()
            .any(|part| contains_comparison_operator_token(part)),
        Expr::Case {
            operand,
            conditions,
            else_result,
            ..
        } => {
            let in_edges = operand
                .iter()
                .chain(else_result.iter())
                .any(|part| contains_comparison_operator_token(part));
            in_edges
                || conditions.iter().any(|when| {
                    contains_comparison_operator_token(&when.condition)
                        || contains_comparison_operator_token(&when.result)
                })
        }
        _ => false,
    }
}
/// Whitelists the generated-SQL idioms `1 = 1` and `1 = 0`, given their normalized literal keys.
fn is_allowed_literal_comparison(op: &BinaryOperator, left: &str, right: &str) -> bool {
    matches!((op, left, right), (BinaryOperator::Eq, "1", "1" | "0"))
}
/// Returns a normalized text key when `expr` is a literal, otherwise `None`.
///
/// The literal may be wrapped in parentheses, unary operators and/or casts. It is rendered
/// via `Display` and upper-cased.
///
/// Parentheses and casts are transparent: `(1)` and `CAST(1 AS INT)` both key as `"1"`.
/// Unary operators are kept in the key (`-1` keys as `"-1"`), so a signed literal is never
/// mistaken for its unsigned value. Without this, `1 = -1` or `-1 = -1` would slip through
/// as the allowed `1 = 1` idiom.
fn literal_key(expr: &Expr) -> Option<String> {
    match expr {
        Expr::Value(value) => Some(value.to_string().to_ascii_uppercase()),
        Expr::UnaryOp {
            op, expr: inner, ..
        } => literal_key(inner).map(|key| format!("{op}{key}")),
        Expr::Nested(inner) | Expr::Cast { expr: inner, .. } => literal_key(inner),
        _ => None,
    }
}
/// Structurally compares two expressions for the self-comparison check.
///
/// * Identifiers, simple or compound, match when their text is equal ignoring ASCII case.
///   The quote style is not considered.
/// * Parentheses are transparent on either side: `(a)` matches `a`.
/// * `CAST(x AS ...)` is transparent on either side, whatever the target type.
/// * Unary operators must appear on both sides with the same operator. `-a` matches `-a`
///   but not `a`, `+a` or `NOT a`. Previously the operator was ignored and could be stripped
///   from one side only, so e.g. `(-a) = (a)` was reported as a constant comparison.
///
/// Any other pairing returns `false`.
fn expr_equivalent(left: &Expr, right: &Expr) -> bool {
    match (left, right) {
        (Expr::Identifier(left_ident), Expr::Identifier(right_ident)) => {
            left_ident.value.eq_ignore_ascii_case(&right_ident.value)
        }
        (Expr::CompoundIdentifier(left_parts), Expr::CompoundIdentifier(right_parts)) => {
            left_parts.len() == right_parts.len()
                && left_parts
                    .iter()
                    .zip(right_parts.iter())
                    .all(|(left, right)| left.value.eq_ignore_ascii_case(&right.value))
        }
        (Expr::Nested(left_inner), _) => expr_equivalent(left_inner, right),
        (_, Expr::Nested(right_inner)) => expr_equivalent(left, right_inner),
        (
            Expr::UnaryOp {
                op: left_op,
                expr: left_inner,
                ..
            },
            Expr::UnaryOp {
                op: right_op,
                expr: right_inner,
                ..
            },
        ) => left_op == right_op && expr_equivalent(left_inner, right_inner),
        (
            Expr::Cast {
                expr: left_inner, ..
            },
            _,
        ) => expr_equivalent(left_inner, right),
        (
            _,
            Expr::Cast {
                expr: right_inner, ..
            },
        ) => expr_equivalent(left, right_inner),
        _ => false,
    }
}
/// Decides whether `left` and `right` should be treated as the same expression.
///
/// Both must be the same `Expr` variant at the top level. They then match either
/// structurally (`expr_equivalent`) or by normalized rendered text
/// (`normalize_expr_for_compare`).
fn expressions_equivalent_for_constant_check(left: &Expr, right: &Expr) -> bool {
    let same_kind = std::mem::discriminant(left) == std::mem::discriminant(right);
    same_kind
        && (expr_equivalent(left, right)
            || normalize_expr_for_compare(left) == normalize_expr_for_compare(right))
}
/// Renders `expr` to a canonical string used for textual equivalence checks.
///
/// Outside quoted sections, whitespace is dropped and ASCII letters are upper-cased. Text
/// inside `'...'`, `"..."` or `` `...` `` is copied verbatim, quotes included. A doubled
/// quote such as `''` simply closes and reopens the section, so it is also kept.
///
/// Why: folding case and whitespace inside string literals made distinct values compare
/// equal, so `'a' || x = 'A' || x` or `'a b' || x = 'ab' || x` were reported as constant.
/// Identical renderings still normalize identically.
fn normalize_expr_for_compare(expr: &Expr) -> String {
    let rendered = expr.to_string();
    let mut normalized = String::with_capacity(rendered.len());
    // `Some(q)` while inside a section opened by quote character `q`.
    let mut open_quote: Option<char> = None;
    for ch in rendered.chars() {
        match open_quote {
            Some(quote) => {
                normalized.push(ch);
                if ch == quote {
                    open_quote = None;
                }
            }
            None => {
                if ch.is_whitespace() {
                    continue;
                }
                if matches!(ch, '\'' | '"' | '`') {
                    open_quote = Some(ch);
                }
                normalized.push(ch.to_ascii_uppercase());
            }
        }
    }
    normalized
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;

    /// Parses `sql` and runs ST010 over each resulting statement, gathering every issue raised.
    fn run(sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        let rule = StructureConstantExpression;
        let mut issues = Vec::new();
        for (statement_index, statement) in statements.iter().enumerate() {
            let ctx = LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index,
            };
            issues.extend(rule.check(statement, &ctx));
        }
        issues
    }

    /// Number of ST010 issues raised for `sql`.
    fn count(sql: &str) -> usize {
        run(sql).len()
    }

    #[test]
    fn allows_normal_where_predicate() {
        assert_eq!(count("select * from foo where col = 3"), 0);
    }

    #[test]
    fn flags_self_comparison_in_where_clause() {
        let issues = run("select * from foo where col = col");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_ST_010);
    }

    #[test]
    fn flags_self_comparison_with_inequality_operator() {
        for sql in [
            "select * from foo where col < col",
            "select * from foo where col >= col",
        ] {
            assert_eq!(count(sql), 1, "{sql}");
        }
    }

    #[test]
    fn flags_self_comparison_in_join_predicate() {
        assert_eq!(
            count("select foo.a, bar.b from foo left join bar on foo.a = foo.a"),
            1
        );
    }

    #[test]
    fn allows_expected_codegen_literals() {
        assert!(run("select col from foo where 1=1 and col = 'val'").is_empty());
        assert!(run("select col from foo where 1=0 or col = 'val'").is_empty());
    }

    #[test]
    fn flags_disallowed_literal_comparisons() {
        for sql in [
            "select col from foo where 'a'!='b' and col = 'val'",
            "select col from foo where 1 = 2 or col = 'val'",
            "select col from foo where 1 <> 1 or col = 'val'",
        ] {
            assert_eq!(count(sql), 1, "{sql}");
        }
    }

    #[test]
    fn allows_non_equality_literal_comparison() {
        assert_eq!(count("select col from foo where 1 < 2"), 0);
    }

    #[test]
    fn finds_nested_constant_predicates() {
        assert_eq!(
            count("select col from foo where cond=1 and (score=score or avg_score >= 3)"),
            1
        );
    }

    #[test]
    fn counts_multiple_constant_predicates_in_single_expression_tree() {
        assert_eq!(
            count("select * from foo where col = col and score = score"),
            2
        );
    }

    #[test]
    fn flags_equal_string_concat_expressions() {
        assert_eq!(count("select * from foo where 'A' || 'B' = 'A' || 'B'"), 1);
    }

    #[test]
    fn flags_equal_arithmetic_expressions() {
        assert_eq!(count("select * from foo where col + 1 = col + 1"), 1);
    }

    #[test]
    fn allows_non_equivalent_arithmetic_literal_comparison() {
        assert_eq!(count("select * from foo where 1 + 1 = 2"), 0);
    }

    #[test]
    fn allows_true_false_literal_predicates() {
        assert_eq!(count("select * from foo where true and x > 3"), 0);
        assert_eq!(count("select * from foo where false OR x < 1 OR y != z"), 0);
    }

    #[test]
    fn flags_constant_predicate_in_update_where() {
        assert_eq!(count("update foo set a = 1 where col = col"), 1);
    }

    #[test]
    fn flags_constant_predicate_in_delete_where() {
        assert_eq!(count("delete from foo where col = col"), 1);
    }
}