Skip to main content

rigsql_rules/structure/
st10.rs

1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// ST10: Constant expression in WHERE clause.
///
/// Flags WHERE clauses whose condition is built only from literals and so
/// cannot depend on row data. This covers tautologies such as `WHERE 1 = 1`
/// or `WHERE TRUE` as well as always-false conditions such as `WHERE FALSE`
/// or `WHERE 1 = 2`.
#[derive(Debug, Default)]
pub struct RuleST10;
12
13impl Rule for RuleST10 {
14    fn code(&self) -> &'static str {
15        "ST10"
16    }
17    fn name(&self) -> &'static str {
18        "structure.where_constant"
19    }
20    fn description(&self) -> &'static str {
21        "WHERE clause contains a constant/tautological expression."
22    }
23    fn explanation(&self) -> &'static str {
24        "A WHERE clause with a constant expression like WHERE 1 = 1 or WHERE TRUE \
25         is either a placeholder that should be removed, or indicates dead code. \
26         Remove the WHERE clause or replace it with a meaningful condition."
27    }
28    fn groups(&self) -> &[RuleGroup] {
29        &[RuleGroup::Structure]
30    }
31    fn is_fixable(&self) -> bool {
32        false
33    }
34
35    fn crawl_type(&self) -> CrawlType {
36        CrawlType::Segment(vec![SegmentType::WhereClause])
37    }
38
39    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
40        let children = ctx.segment.children();
41        let non_trivia: Vec<_> = children
42            .iter()
43            .filter(|s| !s.segment_type().is_trivia())
44            .collect();
45
46        // WhereClause: WHERE <expression>
47        // non_trivia[0] = Keyword(WHERE), rest = the condition
48        if non_trivia.len() < 2 {
49            return vec![];
50        }
51
52        // Check for single boolean literal: WHERE TRUE / WHERE FALSE
53        if non_trivia.len() == 2 && non_trivia[1].segment_type() == SegmentType::BooleanLiteral {
54            return vec![LintViolation::new(
55                self.code(),
56                "WHERE clause contains a constant expression.",
57                ctx.segment.span(),
58            )];
59        }
60
61        // Check for binary expression with both sides being literals (e.g., 1 = 1)
62        if non_trivia.len() == 2 {
63            if let Some(violation) = check_binary_literal(self.code(), non_trivia[1]) {
64                return vec![violation];
65            }
66        }
67
68        vec![]
69    }
70}
71
72fn check_binary_literal(code: &'static str, seg: &Segment) -> Option<LintViolation> {
73    if seg.segment_type() != SegmentType::BinaryExpression {
74        return None;
75    }
76
77    let children = seg.children();
78    let non_trivia: Vec<_> = children
79        .iter()
80        .filter(|s| !s.segment_type().is_trivia())
81        .collect();
82
83    // BinaryExpression: <left> <operator> <right>
84    if non_trivia.len() != 3 {
85        return None;
86    }
87
88    let left = non_trivia[0];
89    let right = non_trivia[2];
90
91    if is_literal(left) && is_literal(right) {
92        return Some(LintViolation::new(
93            code,
94            "WHERE clause contains a constant expression.",
95            seg.span(),
96        ));
97    }
98
99    None
100}
101
102fn is_literal(seg: &Segment) -> bool {
103    matches!(
104        seg.segment_type(),
105        SegmentType::NumericLiteral
106            | SegmentType::StringLiteral
107            | SegmentType::BooleanLiteral
108            | SegmentType::NullLiteral
109            | SegmentType::Literal
110    )
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Lints `sql` with ST10 alone and returns how many violations were reported.
    fn count_violations(sql: &str) -> usize {
        lint_sql(sql, RuleST10).len()
    }

    #[test]
    fn test_st10_flags_where_true() {
        assert_eq!(count_violations("SELECT * FROM t WHERE TRUE;"), 1);
    }

    #[test]
    fn test_st10_flags_where_false() {
        // Always-false conditions are constant too, not just tautologies.
        assert_eq!(count_violations("SELECT * FROM t WHERE FALSE;"), 1);
    }

    #[test]
    fn test_st10_flags_where_1_eq_1() {
        assert_eq!(count_violations("SELECT * FROM t WHERE 1 = 1;"), 1);
    }

    #[test]
    fn test_st10_flags_where_1_eq_2() {
        assert_eq!(count_violations("SELECT * FROM t WHERE 1 = 2;"), 1);
    }

    #[test]
    fn test_st10_accepts_normal_where() {
        assert_eq!(count_violations("SELECT * FROM t WHERE x = 1;"), 0);
    }

    #[test]
    fn test_st10_accepts_column_comparison() {
        assert_eq!(count_violations("SELECT * FROM t WHERE a = b;"), 0);
    }

    #[test]
    fn test_st10_ignores_query_without_where() {
        assert_eq!(count_violations("SELECT * FROM t;"), 0);
    }
}
135}