flowscope-core 0.7.0

Core SQL lineage analysis engine
Documentation
//! LINT_ST_010: Constant boolean predicate.
//!
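//! Detects comparisons in predicates whose outcome does not depend on the row
//! being evaluated, such as `col = col` or `'a' != 'b'`. The code-generation
//! idioms `1 = 1` and `1 = 0` are deliberately allowed.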

use crate::linter::rule::{LintContext, LintRule};
use crate::types::{issue_codes, Issue};
use sqlparser::ast::{BinaryOperator, Expr, Merge, Statement, Update};

use super::semantic_helpers::{visit_select_expressions, visit_selects_in_statement};

/// Lint rule `LINT_ST_010`: reports comparison predicates whose result is
/// fixed, such as `col = col` or `'a' != 'b'`.
pub struct StructureConstantExpression;

impl LintRule for StructureConstantExpression {
    fn code(&self) -> &'static str {
        issue_codes::LINT_ST_010
    }

    fn name(&self) -> &'static str {
        "Structure constant expression"
    }

    fn description(&self) -> &'static str {
        "Redundant constant expression."
    }

    /// Emits one warning per constant comparison found in the statement.
    ///
    /// Two sources are counted: clauses owned by the statement itself (see
    /// `statement_constant_predicate_count`) and the expressions of each SELECT
    /// reached by `visit_selects_in_statement` / `visit_select_expressions`.
    fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
        // UPDATE/DELETE `WHERE` and MERGE `ON`.
        let mut violations = statement_constant_predicate_count(statement);

        // Expressions inside SELECTs reached by the shared visitor.
        visit_selects_in_statement(statement, &mut |select| {
            visit_select_expressions(select, &mut |expr| {
                violations += constant_predicate_count(expr);
            });
        });

        let make_issue = || {
            Issue::warning(
                issue_codes::LINT_ST_010,
                "Constant boolean expression detected in predicate.",
            )
            .with_statement(ctx.statement_index)
        };
        std::iter::repeat_with(make_issue).take(violations).collect()
    }
}

fn statement_constant_predicate_count(statement: &Statement) -> usize {
    match statement {
        Statement::Update(Update { selection, .. }) => {
            selection.as_ref().map_or(0, constant_predicate_count)
        }
        Statement::Delete(delete) => delete
            .selection
            .as_ref()
            .map_or(0, constant_predicate_count),
        Statement::Merge(Merge { on, .. }) => constant_predicate_count(on),
        _ => 0,
    }
}

fn constant_predicate_count(expr: &Expr) -> usize {
    match expr {
        Expr::BinaryOp { left, op, right } => {
            let direct_match = is_supported_expression_comparison_operator(op)
                && !contains_comparison_operator_token(left)
                && !contains_comparison_operator_token(right)
                && match (literal_key(left), literal_key(right)) {
                    (Some(left_literal), Some(right_literal)) => {
                        is_supported_literal_comparison_operator(op)
                            && !is_allowed_literal_comparison(op, &left_literal, &right_literal)
                    }
                    _ => expressions_equivalent_for_constant_check(left, right),
                };

            usize::from(direct_match)
                + constant_predicate_count(left)
                + constant_predicate_count(right)
        }
        Expr::UnaryOp { expr: inner, .. }
        | Expr::Nested(inner)
        | Expr::IsNull(inner)
        | Expr::IsNotNull(inner)
        | Expr::Cast { expr: inner, .. } => constant_predicate_count(inner),
        Expr::InList { expr, list, .. } => {
            constant_predicate_count(expr)
                + list.iter().map(constant_predicate_count).sum::<usize>()
        }
        Expr::Between {
            expr, low, high, ..
        } => {
            constant_predicate_count(expr)
                + constant_predicate_count(low)
                + constant_predicate_count(high)
        }
        Expr::Case {
            operand,
            conditions,
            else_result,
            ..
        } => {
            let operand_count = operand
                .as_ref()
                .map_or(0, |expr| constant_predicate_count(expr));
            let condition_count = conditions
                .iter()
                .map(|when| {
                    constant_predicate_count(&when.condition)
                        + constant_predicate_count(&when.result)
                })
                .sum::<usize>();
            let else_count = else_result
                .as_ref()
                .map_or(0, |expr| constant_predicate_count(expr));
            operand_count + condition_count + else_count
        }
        _ => 0,
    }
}

fn is_supported_expression_comparison_operator(op: &BinaryOperator) -> bool {
    matches!(
        op,
        BinaryOperator::Eq
            | BinaryOperator::NotEq
            | BinaryOperator::Lt
            | BinaryOperator::Gt
            | BinaryOperator::LtEq
            | BinaryOperator::GtEq
    )
}

/// Returns `true` only for equality (`=`) and inequality (`<>`/`!=`), the
/// operators for which a comparison of two literals is judged at all.
fn is_supported_literal_comparison_operator(op: &BinaryOperator) -> bool {
    *op == BinaryOperator::Eq || *op == BinaryOperator::NotEq
}

/// Reports whether `expr` contains a binary comparison operator (as defined
/// by `is_supported_expression_comparison_operator`), searching through the
/// variants handled below.
///
/// `constant_predicate_count` uses this to skip comparisons whose operands
/// themselves contain comparisons. For `ANY`/`ALL` only the two operands are
/// searched. Variants not listed here report `false`.
fn contains_comparison_operator_token(expr: &Expr) -> bool {
    match expr {
        Expr::BinaryOp { left, op, right } => {
            is_supported_expression_comparison_operator(op)
                || [left, right]
                    .iter()
                    .any(|side| contains_comparison_operator_token(side))
        }
        Expr::AnyOp { left, right, .. } | Expr::AllOp { left, right, .. } => [left, right]
            .iter()
            .any(|side| contains_comparison_operator_token(side)),
        Expr::UnaryOp { expr: inner, .. }
        | Expr::Nested(inner)
        | Expr::IsNull(inner)
        | Expr::IsNotNull(inner)
        | Expr::Cast { expr: inner, .. } => contains_comparison_operator_token(inner),
        Expr::InList { expr, list, .. } => {
            let head: &Expr = expr;
            std::iter::once(head)
                .chain(list)
                .any(contains_comparison_operator_token)
        }
        Expr::Between {
            expr, low, high, ..
        } => [expr, low, high]
            .iter()
            .any(|part| contains_comparison_operator_token(part)),
        Expr::Case {
            operand,
            conditions,
            else_result,
            ..
        } => {
            let in_optional_parts = operand
                .iter()
                .chain(else_result.iter())
                .any(|boxed| contains_comparison_operator_token(boxed));
            in_optional_parts
                || conditions.iter().any(|when| {
                    contains_comparison_operator_token(&when.condition)
                        || contains_comparison_operator_token(&when.result)
                })
        }
        _ => false,
    }
}

/// Allowlist for literal comparisons that are common code-generation idioms.
///
/// Only `1 = 1` (always true) and `1 = 0` (always false) are accepted. The
/// `left`/`right` arguments are the keys produced by `literal_key`.
fn is_allowed_literal_comparison(op: &BinaryOperator, left: &str, right: &str) -> bool {
    matches!((op, left, right), (BinaryOperator::Eq, "1", "1" | "0"))
}

/// Returns a normalized textual key for `expr` when it is a literal value,
/// possibly wrapped in parentheses, casts, or unary operators; `None` otherwise.
///
/// The literal is rendered via `Display` and ASCII upper-cased. Unary
/// operators are kept in the key (`-1` becomes `"-1"`), so a signed literal
/// is never mistaken for the unsigned `1`/`0` accepted by
/// `is_allowed_literal_comparison`. This ensures `-1 = 1` is reported like any
/// other constant literal comparison instead of passing as the `1 = 1` idiom.
fn literal_key(expr: &Expr) -> Option<String> {
    match expr {
        Expr::Value(value) => Some(value.to_string().to_ascii_uppercase()),
        Expr::UnaryOp {
            op, expr: inner, ..
        } => literal_key(inner).map(|key| format!("{op}{key}")),
        Expr::Nested(inner) | Expr::Cast { expr: inner, .. } => literal_key(inner),
        _ => None,
    }
}

/// Structural equivalence check used to detect self-comparisons such as
/// `col = col` or `foo.a = foo.a`.
///
/// Identifiers, plain or compound, compare part by part, ignoring ASCII case.
/// Parentheses are transparent and may be peeled from either side
/// independently. Unary operators and casts change the value they wrap, so
/// they only match a counterpart that applies the same operator, or casts to
/// the same `data_type`, to an equivalent operand. Thus `-x` is not
/// equivalent to `x` or `+x`, and `CAST(x AS INT)` is not equivalent to
/// `CAST(x AS TEXT)` or to `x`.
///
/// Any other expression shape is reported as not equivalent.
fn expr_equivalent(left: &Expr, right: &Expr) -> bool {
    match (left, right) {
        (Expr::Identifier(left_ident), Expr::Identifier(right_ident)) => {
            left_ident.value.eq_ignore_ascii_case(&right_ident.value)
        }
        (Expr::CompoundIdentifier(left_parts), Expr::CompoundIdentifier(right_parts)) => {
            left_parts.len() == right_parts.len()
                && left_parts
                    .iter()
                    .zip(right_parts.iter())
                    .all(|(left, right)| left.value.eq_ignore_ascii_case(&right.value))
        }
        // Parentheses never change a value, so peel them one side at a time.
        (Expr::Nested(left_inner), _) => expr_equivalent(left_inner, right),
        (_, Expr::Nested(right_inner)) => expr_equivalent(left, right_inner),
        // `-x` vs `x`, `NOT a` vs `a`: the operator is part of the value.
        (
            Expr::UnaryOp {
                op: left_op,
                expr: left_inner,
                ..
            },
            Expr::UnaryOp {
                op: right_op,
                expr: right_inner,
                ..
            },
        ) => left_op == right_op && expr_equivalent(left_inner, right_inner),
        // Casting to different target types can yield different values.
        (
            Expr::Cast {
                expr: left_inner,
                data_type: left_type,
                ..
            },
            Expr::Cast {
                expr: right_inner,
                data_type: right_type,
                ..
            },
        ) => left_type == right_type && expr_equivalent(left_inner, right_inner),
        _ => false,
    }
}

/// Decides whether two comparison operands are the same expression, which
/// would make the comparison constant.
///
/// Both operands must be the same `Expr` variant at the top level. They then
/// count as equivalent if `expr_equivalent` accepts them structurally, or if
/// their normalized SQL text (`normalize_expr_for_compare`) is identical.
fn expressions_equivalent_for_constant_check(left: &Expr, right: &Expr) -> bool {
    let same_variant = std::mem::discriminant(left) == std::mem::discriminant(right);
    same_variant
        && (expr_equivalent(left, right)
            || normalize_expr_for_compare(left) == normalize_expr_for_compare(right))
}

/// Renders `expr` as SQL text normalized for a textual equality check.
///
/// Outside quoted spans, whitespace is dropped and ASCII letters are
/// upper-cased, so formatting and the casing of keywords and unquoted
/// identifiers do not matter. Text inside single quotes, double quotes, or
/// backticks is kept verbatim, including the delimiters. The contents of
/// string literals and quoted identifiers are case- and whitespace-sensitive:
/// `'a' || x` and `'A' || x` must not be treated as the same expression.
/// A doubled delimiter (`'it''s'`) closes and immediately reopens the span,
/// so its contents stay verbatim as well.
fn normalize_expr_for_compare(expr: &Expr) -> String {
    let rendered = expr.to_string();
    let mut normalized = String::with_capacity(rendered.len());
    // The delimiter of the quoted span we are currently inside, if any.
    let mut open_quote: Option<char> = None;

    for ch in rendered.chars() {
        match open_quote {
            Some(quote) => {
                normalized.push(ch);
                if ch == quote {
                    open_quote = None;
                }
            }
            None if matches!(ch, '\'' | '"' | '`') => {
                open_quote = Some(ch);
                normalized.push(ch);
            }
            None if ch.is_whitespace() => {}
            None => normalized.push(ch.to_ascii_uppercase()),
        }
    }

    normalized
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;

    /// Parses `sql` and lints every resulting statement with
    /// `StructureConstantExpression`. Each statement's position in the parsed
    /// list is used as its `statement_index`.
    fn lint_sql(sql: &str) -> Vec<Issue> {
        let rule = StructureConstantExpression;
        let statements = parse_sql(sql).expect("parse");
        let mut issues = Vec::new();
        for (statement_index, statement) in statements.iter().enumerate() {
            let ctx = LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index,
            };
            issues.extend(rule.check(statement, &ctx));
        }
        issues
    }

    /// Number of findings the rule produces for `sql`.
    fn issue_count(sql: &str) -> usize {
        lint_sql(sql).len()
    }

    // --- Edge cases adopted from sqlfluff ST10 ---

    #[test]
    fn allows_normal_where_predicate() {
        assert_eq!(issue_count("select * from foo where col = 3"), 0);
    }

    #[test]
    fn flags_self_comparison_in_where_clause() {
        let issues = lint_sql("select * from foo where col = col");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_ST_010);
    }

    #[test]
    fn flags_self_comparison_with_inequality_operator() {
        assert_eq!(issue_count("select * from foo where col < col"), 1);
        assert_eq!(issue_count("select * from foo where col >= col"), 1);
    }

    #[test]
    fn flags_self_comparison_in_join_predicate() {
        assert_eq!(
            issue_count("select foo.a, bar.b from foo left join bar on foo.a = foo.a"),
            1
        );
    }

    #[test]
    fn allows_expected_codegen_literals() {
        assert_eq!(issue_count("select col from foo where 1=1 and col = 'val'"), 0);
        assert_eq!(issue_count("select col from foo where 1=0 or col = 'val'"), 0);
    }

    #[test]
    fn flags_disallowed_literal_comparisons() {
        for sql in [
            "select col from foo where 'a'!='b' and col = 'val'",
            "select col from foo where 1 = 2 or col = 'val'",
            "select col from foo where 1 <> 1 or col = 'val'",
        ] {
            assert_eq!(issue_count(sql), 1, "{sql}");
        }
    }

    #[test]
    fn allows_non_equality_literal_comparison() {
        assert_eq!(issue_count("select col from foo where 1 < 2"), 0);
    }

    #[test]
    fn finds_nested_constant_predicates() {
        assert_eq!(
            issue_count("select col from foo where cond=1 and (score=score or avg_score >= 3)"),
            1
        );
    }

    #[test]
    fn counts_multiple_constant_predicates_in_single_expression_tree() {
        assert_eq!(
            issue_count("select * from foo where col = col and score = score"),
            2
        );
    }

    #[test]
    fn flags_equal_string_concat_expressions() {
        assert_eq!(issue_count("select * from foo where 'A' || 'B' = 'A' || 'B'"), 1);
    }

    #[test]
    fn flags_equal_arithmetic_expressions() {
        assert_eq!(issue_count("select * from foo where col + 1 = col + 1"), 1);
    }

    #[test]
    fn allows_non_equivalent_arithmetic_literal_comparison() {
        assert_eq!(issue_count("select * from foo where 1 + 1 = 2"), 0);
    }

    #[test]
    fn allows_true_false_literal_predicates() {
        assert_eq!(issue_count("select * from foo where true and x > 3"), 0);
        assert_eq!(
            issue_count("select * from foo where false OR x < 1 OR y != z"),
            0
        );
    }

    #[test]
    fn flags_constant_predicate_in_update_where() {
        assert_eq!(issue_count("update foo set a = 1 where col = col"), 1);
    }

    #[test]
    fn flags_constant_predicate_in_delete_where() {
        assert_eq!(issue_count("delete from foo where col = col"), 1);
    }
}