Skip to main content

flowscope_core/linter/rules/
st_010.rs

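//! LINT_ST_010: Constant boolean predicate.
//!
//! Detects predicates that always evaluate to the same boolean value, such as
//! comparing an expression with itself (`col = col`) or comparing two literals
//! (`'a' != 'b'`). The common code-generation idioms `1 = 1` and `1 = 0` are allowed.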
4
5use crate::linter::rule::{LintContext, LintRule};
6use crate::types::{issue_codes, Issue};
7use sqlparser::ast::{BinaryOperator, Expr, Statement};
8
9use super::semantic_helpers::{visit_select_expressions, visit_selects_in_statement};
10
11pub struct StructureConstantExpression;
12
13impl LintRule for StructureConstantExpression {
14    fn code(&self) -> &'static str {
15        issue_codes::LINT_ST_010
16    }
17
18    fn name(&self) -> &'static str {
19        "Structure constant expression"
20    }
21
22    fn description(&self) -> &'static str {
23        "Redundant constant expression."
24    }
25
26    fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
27        let mut violation_count = statement_constant_predicate_count(statement);
28
29        visit_selects_in_statement(statement, &mut |select| {
30            visit_select_expressions(select, &mut |expr| {
31                violation_count += constant_predicate_count(expr);
32            });
33        });
34
35        (0..violation_count)
36            .map(|_| {
37                Issue::warning(
38                    issue_codes::LINT_ST_010,
39                    "Constant boolean expression detected in predicate.",
40                )
41                .with_statement(ctx.statement_index)
42            })
43            .collect()
44    }
45}
46
47fn statement_constant_predicate_count(statement: &Statement) -> usize {
48    match statement {
49        Statement::Update { selection, .. } => {
50            selection.as_ref().map_or(0, constant_predicate_count)
51        }
52        Statement::Delete(delete) => delete
53            .selection
54            .as_ref()
55            .map_or(0, constant_predicate_count),
56        Statement::Merge { on, .. } => constant_predicate_count(on),
57        _ => 0,
58    }
59}
60
61fn constant_predicate_count(expr: &Expr) -> usize {
62    match expr {
63        Expr::BinaryOp { left, op, right } => {
64            let direct_match = is_supported_expression_comparison_operator(op)
65                && !contains_comparison_operator_token(left)
66                && !contains_comparison_operator_token(right)
67                && match (literal_key(left), literal_key(right)) {
68                    (Some(left_literal), Some(right_literal)) => {
69                        is_supported_literal_comparison_operator(op)
70                            && !is_allowed_literal_comparison(op, &left_literal, &right_literal)
71                    }
72                    _ => expressions_equivalent_for_constant_check(left, right),
73                };
74
75            usize::from(direct_match)
76                + constant_predicate_count(left)
77                + constant_predicate_count(right)
78        }
79        Expr::UnaryOp { expr: inner, .. }
80        | Expr::Nested(inner)
81        | Expr::IsNull(inner)
82        | Expr::IsNotNull(inner)
83        | Expr::Cast { expr: inner, .. } => constant_predicate_count(inner),
84        Expr::InList { expr, list, .. } => {
85            constant_predicate_count(expr)
86                + list.iter().map(constant_predicate_count).sum::<usize>()
87        }
88        Expr::Between {
89            expr, low, high, ..
90        } => {
91            constant_predicate_count(expr)
92                + constant_predicate_count(low)
93                + constant_predicate_count(high)
94        }
95        Expr::Case {
96            operand,
97            conditions,
98            else_result,
99            ..
100        } => {
101            let operand_count = operand
102                .as_ref()
103                .map_or(0, |expr| constant_predicate_count(expr));
104            let condition_count = conditions
105                .iter()
106                .map(|when| {
107                    constant_predicate_count(&when.condition)
108                        + constant_predicate_count(&when.result)
109                })
110                .sum::<usize>();
111            let else_count = else_result
112                .as_ref()
113                .map_or(0, |expr| constant_predicate_count(expr));
114            operand_count + condition_count + else_count
115        }
116        _ => 0,
117    }
118}
119
120fn is_supported_expression_comparison_operator(op: &BinaryOperator) -> bool {
121    matches!(
122        op,
123        BinaryOperator::Eq
124            | BinaryOperator::NotEq
125            | BinaryOperator::Lt
126            | BinaryOperator::Gt
127            | BinaryOperator::LtEq
128            | BinaryOperator::GtEq
129    )
130}
131
132fn is_supported_literal_comparison_operator(op: &BinaryOperator) -> bool {
133    matches!(op, BinaryOperator::Eq | BinaryOperator::NotEq)
134}
135
136fn contains_comparison_operator_token(expr: &Expr) -> bool {
137    match expr {
138        Expr::BinaryOp { left, op, right } => {
139            is_supported_expression_comparison_operator(op)
140                || contains_comparison_operator_token(left)
141                || contains_comparison_operator_token(right)
142        }
143        Expr::AnyOp { left, right, .. } | Expr::AllOp { left, right, .. } => {
144            contains_comparison_operator_token(left) || contains_comparison_operator_token(right)
145        }
146        Expr::UnaryOp { expr: inner, .. }
147        | Expr::Nested(inner)
148        | Expr::IsNull(inner)
149        | Expr::IsNotNull(inner)
150        | Expr::Cast { expr: inner, .. } => contains_comparison_operator_token(inner),
151        Expr::InList { expr, list, .. } => {
152            contains_comparison_operator_token(expr)
153                || list.iter().any(contains_comparison_operator_token)
154        }
155        Expr::Between {
156            expr, low, high, ..
157        } => {
158            contains_comparison_operator_token(expr)
159                || contains_comparison_operator_token(low)
160                || contains_comparison_operator_token(high)
161        }
162        Expr::Case {
163            operand,
164            conditions,
165            else_result,
166            ..
167        } => {
168            operand
169                .as_ref()
170                .is_some_and(|expr| contains_comparison_operator_token(expr))
171                || conditions.iter().any(|when| {
172                    contains_comparison_operator_token(&when.condition)
173                        || contains_comparison_operator_token(&when.result)
174                })
175                || else_result
176                    .as_ref()
177                    .is_some_and(|expr| contains_comparison_operator_token(expr))
178        }
179        _ => false,
180    }
181}
182
183fn is_allowed_literal_comparison(op: &BinaryOperator, left: &str, right: &str) -> bool {
184    *op == BinaryOperator::Eq && left == "1" && (right == "1" || right == "0")
185}
186
187fn literal_key(expr: &Expr) -> Option<String> {
188    match expr {
189        Expr::Value(value) => Some(value.to_string().to_ascii_uppercase()),
190        Expr::Nested(inner)
191        | Expr::UnaryOp { expr: inner, .. }
192        | Expr::Cast { expr: inner, .. } => literal_key(inner),
193        _ => None,
194    }
195}
196
197fn expr_equivalent(left: &Expr, right: &Expr) -> bool {
198    match (left, right) {
199        (Expr::Identifier(left_ident), Expr::Identifier(right_ident)) => {
200            left_ident.value.eq_ignore_ascii_case(&right_ident.value)
201        }
202        (Expr::CompoundIdentifier(left_parts), Expr::CompoundIdentifier(right_parts)) => {
203            left_parts.len() == right_parts.len()
204                && left_parts
205                    .iter()
206                    .zip(right_parts.iter())
207                    .all(|(left, right)| left.value.eq_ignore_ascii_case(&right.value))
208        }
209        (Expr::Nested(left_inner), _) => expr_equivalent(left_inner, right),
210        (_, Expr::Nested(right_inner)) => expr_equivalent(left, right_inner),
211        (
212            Expr::UnaryOp {
213                expr: left_inner, ..
214            },
215            _,
216        ) => expr_equivalent(left_inner, right),
217        (
218            _,
219            Expr::UnaryOp {
220                expr: right_inner, ..
221            },
222        ) => expr_equivalent(left, right_inner),
223        (
224            Expr::Cast {
225                expr: left_inner, ..
226            },
227            _,
228        ) => expr_equivalent(left_inner, right),
229        (
230            _,
231            Expr::Cast {
232                expr: right_inner, ..
233            },
234        ) => expr_equivalent(left, right_inner),
235        _ => false,
236    }
237}
238
239fn expressions_equivalent_for_constant_check(left: &Expr, right: &Expr) -> bool {
240    if std::mem::discriminant(left) != std::mem::discriminant(right) {
241        return false;
242    }
243
244    expr_equivalent(left, right)
245        || normalize_expr_for_compare(left) == normalize_expr_for_compare(right)
246}
247
248fn normalize_expr_for_compare(expr: &Expr) -> String {
249    expr.to_string()
250        .chars()
251        .filter(|ch| !ch.is_whitespace())
252        .collect::<String>()
253        .to_ascii_uppercase()
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;

    /// Parses `sql` and collects the rule's issues for every statement in it.
    ///
    /// Each statement is checked with its own index; the context range always
    /// spans the whole input.
    fn run(sql: &str) -> Vec<Issue> {
        let parsed = parse_sql(sql).expect("parse");
        let rule = StructureConstantExpression;
        let mut collected = Vec::new();
        for (statement_index, statement) in parsed.iter().enumerate() {
            let ctx = LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index,
            };
            collected.extend(rule.check(statement, &ctx));
        }
        collected
    }

    /// Asserts that linting `sql` yields exactly `expected` issues.
    fn assert_issue_count(sql: &str, expected: usize) {
        assert_eq!(run(sql).len(), expected, "issue count for `{sql}`");
    }

    // --- Edge cases adopted from sqlfluff ST10 ---

    #[test]
    fn allows_normal_where_predicate() {
        assert_issue_count("select * from foo where col = 3", 0);
    }

    #[test]
    fn flags_self_comparison_in_where_clause() {
        let issues = run("select * from foo where col = col");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_ST_010);
    }

    #[test]
    fn flags_self_comparison_with_inequality_operator() {
        for sql in [
            "select * from foo where col < col",
            "select * from foo where col >= col",
        ] {
            assert_issue_count(sql, 1);
        }
    }

    #[test]
    fn flags_self_comparison_in_join_predicate() {
        assert_issue_count(
            "select foo.a, bar.b from foo left join bar on foo.a = foo.a",
            1,
        );
    }

    #[test]
    fn allows_expected_codegen_literals() {
        assert_issue_count("select col from foo where 1=1 and col = 'val'", 0);
        assert_issue_count("select col from foo where 1=0 or col = 'val'", 0);
    }

    #[test]
    fn flags_disallowed_literal_comparisons() {
        for sql in [
            "select col from foo where 'a'!='b' and col = 'val'",
            "select col from foo where 1 = 2 or col = 'val'",
            "select col from foo where 1 <> 1 or col = 'val'",
        ] {
            assert_issue_count(sql, 1);
        }
    }

    #[test]
    fn allows_non_equality_literal_comparison() {
        assert_issue_count("select col from foo where 1 < 2", 0);
    }

    #[test]
    fn finds_nested_constant_predicates() {
        assert_issue_count(
            "select col from foo where cond=1 and (score=score or avg_score >= 3)",
            1,
        );
    }

    #[test]
    fn counts_multiple_constant_predicates_in_single_expression_tree() {
        assert_issue_count("select * from foo where col = col and score = score", 2);
    }

    #[test]
    fn flags_equal_string_concat_expressions() {
        assert_issue_count("select * from foo where 'A' || 'B' = 'A' || 'B'", 1);
    }

    #[test]
    fn flags_equal_arithmetic_expressions() {
        assert_issue_count("select * from foo where col + 1 = col + 1", 1);
    }

    #[test]
    fn allows_non_equivalent_arithmetic_literal_comparison() {
        assert_issue_count("select * from foo where 1 + 1 = 2", 0);
    }

    #[test]
    fn allows_true_false_literal_predicates() {
        assert_issue_count("select * from foo where true and x > 3", 0);
        assert_issue_count("select * from foo where false OR x < 1 OR y != z", 0);
    }

    #[test]
    fn flags_constant_predicate_in_update_where() {
        assert_issue_count("update foo set a = 1 where col = col", 1);
    }

    #[test]
    fn flags_constant_predicate_in_delete_where() {
        assert_issue_count("delete from foo where col = col", 1);
    }
}
383    fn flags_constant_predicate_in_delete_where() {
384        let issues = run("delete from foo where col = col");
385        assert_eq!(issues.len(), 1);
386    }
387}