Skip to main content

rigsql_rules/structure/
st02.rs

1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// Lint rule ST02: reports `CASE` expressions that a simpler construct can replace.
///
/// Two shapes are reported:
/// 1. Boolean wrapping — `CASE WHEN cond THEN TRUE ELSE FALSE END`, where `cond` can be
///    used directly.
/// 2. IS NULL fallback — `CASE WHEN x IS NULL THEN y ELSE x END`, which can be written as
///    `COALESCE(x, y)`.
#[derive(Debug, Default)]
pub struct RuleST02;
13
14impl Rule for RuleST02 {
15    fn code(&self) -> &'static str {
16        "ST02"
17    }
18    fn name(&self) -> &'static str {
19        "structure.simple_case"
20    }
21    fn description(&self) -> &'static str {
22        "Unnecessary CASE expression."
23    }
24    fn explanation(&self) -> &'static str {
25        "A CASE expression is unnecessary when it can be replaced by a simpler construct: \
26         (1) A single WHEN returning TRUE/FALSE (or 1/0) with an opposite ELSE can use the \
27         condition directly. (2) A CASE WHEN x IS NULL THEN y ELSE x END can be replaced \
28         with COALESCE(x, y)."
29    }
30    fn groups(&self) -> &[RuleGroup] {
31        &[RuleGroup::Structure]
32    }
33    fn is_fixable(&self) -> bool {
34        false
35    }
36
37    fn crawl_type(&self) -> CrawlType {
38        CrawlType::Segment(vec![SegmentType::CaseExpression])
39    }
40
41    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
42        let children = ctx.segment.children();
43        let non_trivia: Vec<_> = children
44            .iter()
45            .filter(|s| !s.segment_type().is_trivia())
46            .collect();
47
48        // Must have exactly one WHEN and one ELSE
49        let when_clauses: Vec<_> = non_trivia
50            .iter()
51            .filter(|s| s.segment_type() == SegmentType::WhenClause)
52            .collect();
53        let else_clauses: Vec<_> = non_trivia
54            .iter()
55            .filter(|s| s.segment_type() == SegmentType::ElseClause)
56            .collect();
57
58        if when_clauses.len() != 1 || else_clauses.len() != 1 {
59            return vec![];
60        }
61
62        let when_clause = when_clauses[0];
63        let else_clause = else_clauses[0];
64
65        // Pattern 1: Boolean wrapping (CASE WHEN cond THEN TRUE ELSE FALSE END)
66        let then_value = extract_then_value(when_clause);
67        let else_value = extract_else_value(else_clause);
68
69        if let (Some(ref then_val), Some(ref else_val)) = (then_value, else_value) {
70            let is_bool_pair = (is_truthy(then_val) && is_falsy(else_val))
71                || (is_falsy(then_val) && is_truthy(else_val));
72
73            if is_bool_pair {
74                return vec![LintViolation::with_msg_key(
75                    self.code(),
76                    "Unnecessary CASE expression. Use the boolean condition directly.",
77                    ctx.segment.span(),
78                    "rules.ST02.msg.boolean",
79                    vec![],
80                )];
81            }
82        }
83
84        // Pattern 2: IS NULL fallback (CASE WHEN x IS NULL THEN y ELSE x END)
85        if let Some(msg) = check_is_null_coalesce_pattern(when_clause, else_clause) {
86            return vec![LintViolation::with_msg_key(
87                self.code(),
88                msg,
89                ctx.segment.span(),
90                "rules.ST02.msg.coalesce",
91                vec![],
92            )];
93        }
94
95        vec![]
96    }
97}
98
99/// Check for CASE WHEN x IS NULL THEN y ELSE x END → COALESCE(x, y) pattern.
100fn check_is_null_coalesce_pattern(when_clause: &Segment, else_clause: &Segment) -> Option<String> {
101    let when_children: Vec<_> = when_clause
102        .children()
103        .iter()
104        .filter(|c| !c.segment_type().is_trivia())
105        .collect();
106
107    // Find IS NULL expression in the WHEN clause
108    let is_null_expr = when_children
109        .iter()
110        .find(|c| c.segment_type() == SegmentType::IsNullExpression)?;
111
112    // Get the subject of IS NULL (the column/expression being tested)
113    let tested_col = get_is_null_subject(is_null_expr)?;
114
115    // Get the ELSE value
116    let else_expr = get_else_expression(else_clause)?;
117
118    // If the ELSE value matches the tested column, it's the COALESCE pattern
119    if tested_col.eq_ignore_ascii_case(&else_expr) {
120        Some(
121            "Unnecessary CASE expression. Use COALESCE instead of CASE WHEN IS NULL pattern."
122                .to_string(),
123        )
124    } else {
125        None
126    }
127}
128
129/// Extract the subject of an IS NULL expression (the part before IS NULL).
130fn get_is_null_subject(segment: &Segment) -> Option<String> {
131    let children = segment.children();
132    let non_trivia: Vec<_> = children
133        .iter()
134        .filter(|c| !c.segment_type().is_trivia())
135        .collect();
136
137    non_trivia.first().map(|s| s.raw().trim().to_string())
138}
139
140/// Extract the expression from an ELSE clause (skip ELSE keyword).
141fn get_else_expression(segment: &Segment) -> Option<String> {
142    let children = segment.children();
143    let non_trivia: Vec<_> = children
144        .iter()
145        .filter(|c| !c.segment_type().is_trivia())
146        .collect();
147
148    if non_trivia.len() >= 2 {
149        let expr_parts: String = non_trivia[1..]
150            .iter()
151            .map(|s| s.raw())
152            .collect::<Vec<_>>()
153            .join("");
154        Some(expr_parts.trim().to_string())
155    } else {
156        None
157    }
158}
159
160fn extract_then_value(when_clause: &Segment) -> Option<String> {
161    let children = when_clause.children();
162    let non_trivia: Vec<_> = children
163        .iter()
164        .filter(|s| !s.segment_type().is_trivia())
165        .collect();
166
167    let mut found_then = false;
168    for seg in &non_trivia {
169        if found_then {
170            return Some(seg.raw().trim().to_uppercase());
171        }
172        if seg.segment_type() == SegmentType::Keyword && seg.raw().eq_ignore_ascii_case("THEN") {
173            found_then = true;
174        }
175    }
176    None
177}
178
179fn extract_else_value(else_clause: &Segment) -> Option<String> {
180    let children = else_clause.children();
181    let non_trivia: Vec<_> = children
182        .iter()
183        .filter(|s| !s.segment_type().is_trivia())
184        .collect();
185
186    if non_trivia.len() >= 2 {
187        return Some(non_trivia[1].raw().trim().to_uppercase());
188    }
189    None
190}
191
/// Reports whether `val` is exactly `TRUE` or `1`.
///
/// The comparison is case-sensitive and does not trim.
fn is_truthy(val: &str) -> bool {
    matches!(val, "TRUE" | "1")
}
195
/// Reports whether `val` is exactly `FALSE` or `0`.
///
/// The comparison is case-sensitive and does not trim.
fn is_falsy(val: &str) -> bool {
    matches!(val, "FALSE" | "0")
}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// A TRUE/FALSE wrapper around a condition is reported with the boolean message.
    #[test]
    fn test_st02_flags_simple_boolean_case() {
        let sql = "SELECT CASE WHEN x > 0 THEN TRUE ELSE FALSE END;";
        let found = lint_sql(sql, RuleST02);
        assert_eq!(found.len(), 1);
        assert!(found[0].message.contains("boolean"));
    }

    /// Branches that yield non-boolean values are left alone.
    #[test]
    fn test_st02_accepts_non_boolean_case() {
        let sql = "SELECT CASE WHEN x > 0 THEN 'yes' ELSE 'no' END;";
        let found = lint_sql(sql, RuleST02);
        assert!(found.is_empty());
    }

    /// More than one WHEN clause disables the rule.
    #[test]
    fn test_st02_accepts_multi_when() {
        let sql = "SELECT CASE WHEN x > 0 THEN TRUE WHEN x < 0 THEN FALSE ELSE FALSE END;";
        let found = lint_sql(sql, RuleST02);
        assert!(found.is_empty());
    }

    /// An IS NULL test whose ELSE returns the tested column is reported as COALESCE.
    #[test]
    fn test_st02_flags_is_null_coalesce_pattern() {
        let sql = "SELECT CASE WHEN col IS NULL THEN 'default' ELSE col END FROM t";
        let found = lint_sql(sql, RuleST02);
        assert_eq!(found.len(), 1);
        assert!(found[0].message.contains("COALESCE"));
    }

    /// An equality test with string results matches neither pattern.
    #[test]
    fn test_st02_accepts_complex_case() {
        let sql = "SELECT CASE WHEN x = 1 THEN 'a' ELSE 'b' END FROM t";
        let found = lint_sql(sql, RuleST02);
        assert!(found.is_empty());
    }
}