Skip to main content

rigsql_rules/structure/
st10.rs

1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// ST10: Constant expression in WHERE clause.
///
/// Detects WHERE clauses with tautological conditions like `WHERE 1 = 1`
/// or `WHERE TRUE`.
///
/// Two shapes are reported by this rule's `eval`:
///
/// * a condition that is a single boolean literal, and
/// * a condition that is a binary expression whose left and right operands
///   are both literals (the operator itself is not inspected).
///
/// The rule carries no state, so the unit struct is constructed directly
/// (or through `Default`).
#[derive(Debug, Default)]
pub struct RuleST10;
12
13impl Rule for RuleST10 {
14    fn code(&self) -> &'static str {
15        "ST10"
16    }
17    fn name(&self) -> &'static str {
18        "structure.where_constant"
19    }
20    fn description(&self) -> &'static str {
21        "WHERE clause contains a constant/tautological expression."
22    }
23    fn explanation(&self) -> &'static str {
24        "A WHERE clause with a constant expression like WHERE 1 = 1 or WHERE TRUE \
25         is either a placeholder that should be removed, or indicates dead code. \
26         Remove the WHERE clause or replace it with a meaningful condition."
27    }
28    fn groups(&self) -> &[RuleGroup] {
29        &[RuleGroup::Structure]
30    }
31    fn is_fixable(&self) -> bool {
32        false
33    }
34
35    fn crawl_type(&self) -> CrawlType {
36        CrawlType::Segment(vec![SegmentType::WhereClause])
37    }
38
39    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
40        let children = ctx.segment.children();
41        let non_trivia: Vec<_> = children
42            .iter()
43            .filter(|s| !s.segment_type().is_trivia())
44            .collect();
45
46        // WhereClause: WHERE <expression>
47        // non_trivia[0] = Keyword(WHERE), rest = the condition
48        if non_trivia.len() < 2 {
49            return vec![];
50        }
51
52        // Check for single boolean literal: WHERE TRUE / WHERE FALSE
53        if non_trivia.len() == 2 && non_trivia[1].segment_type() == SegmentType::BooleanLiteral {
54            return vec![LintViolation::with_msg_key(
55                self.code(),
56                "WHERE clause contains a constant expression.",
57                ctx.segment.span(),
58                "rules.ST10.msg",
59                vec![],
60            )];
61        }
62
63        // Check for binary expression with both sides being literals (e.g., 1 = 1)
64        if non_trivia.len() == 2 {
65            if let Some(violation) = check_binary_literal(self.code(), non_trivia[1]) {
66                return vec![violation];
67            }
68        }
69
70        vec![]
71    }
72}
73
74fn check_binary_literal(code: &'static str, seg: &Segment) -> Option<LintViolation> {
75    if seg.segment_type() != SegmentType::BinaryExpression {
76        return None;
77    }
78
79    let children = seg.children();
80    let non_trivia: Vec<_> = children
81        .iter()
82        .filter(|s| !s.segment_type().is_trivia())
83        .collect();
84
85    // BinaryExpression: <left> <operator> <right>
86    if non_trivia.len() != 3 {
87        return None;
88    }
89
90    let left = non_trivia[0];
91    let right = non_trivia[2];
92
93    if is_literal(left) && is_literal(right) {
94        return Some(LintViolation::with_msg_key(
95            code,
96            "WHERE clause contains a constant expression.",
97            seg.span(),
98            "rules.ST10.msg",
99            vec![],
100        ));
101    }
102
103    None
104}
105
106fn is_literal(seg: &Segment) -> bool {
107    matches!(
108        seg.segment_type(),
109        SegmentType::NumericLiteral
110            | SegmentType::StringLiteral
111            | SegmentType::BooleanLiteral
112            | SegmentType::NullLiteral
113            | SegmentType::Literal
114    )
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// A lone boolean literal as the condition is reported exactly once.
    #[test]
    fn test_st10_flags_where_true() {
        let sql = "SELECT * FROM t WHERE TRUE;";
        assert_eq!(lint_sql(sql, RuleST10).len(), 1);
    }

    /// A comparison between two numeric literals is reported exactly once.
    #[test]
    fn test_st10_flags_where_1_eq_1() {
        let sql = "SELECT * FROM t WHERE 1 = 1;";
        assert_eq!(lint_sql(sql, RuleST10).len(), 1);
    }

    /// A comparison involving a column reference is left alone.
    #[test]
    fn test_st10_accepts_normal_where() {
        let sql = "SELECT * FROM t WHERE x = 1;";
        assert_eq!(lint_sql(sql, RuleST10).len(), 0);
    }
}