use rigsql_core::{Segment, SegmentType};

use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
use crate::violation::LintViolation;
5
/// Lint rule ST02 (`structure.simple_case`): reports CASE expressions that can be
/// replaced by a simpler construct.
///
/// The rule is a unit struct and carries no configuration.
#[derive(Default, Debug)]
pub struct RuleST02;
13
14impl Rule for RuleST02 {
15 fn code(&self) -> &'static str {
16 "ST02"
17 }
18 fn name(&self) -> &'static str {
19 "structure.simple_case"
20 }
21 fn description(&self) -> &'static str {
22 "Unnecessary CASE expression."
23 }
24 fn explanation(&self) -> &'static str {
25 "A CASE expression is unnecessary when it can be replaced by a simpler construct: \
26 (1) A single WHEN returning TRUE/FALSE (or 1/0) with an opposite ELSE can use the \
27 condition directly. (2) A CASE WHEN x IS NULL THEN y ELSE x END can be replaced \
28 with COALESCE(x, y)."
29 }
30 fn groups(&self) -> &[RuleGroup] {
31 &[RuleGroup::Structure]
32 }
33 fn is_fixable(&self) -> bool {
34 false
35 }
36
37 fn crawl_type(&self) -> CrawlType {
38 CrawlType::Segment(vec![SegmentType::CaseExpression])
39 }
40
41 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
42 let children = ctx.segment.children();
43 let non_trivia: Vec<_> = children
44 .iter()
45 .filter(|s| !s.segment_type().is_trivia())
46 .collect();
47
48 let when_clauses: Vec<_> = non_trivia
50 .iter()
51 .filter(|s| s.segment_type() == SegmentType::WhenClause)
52 .collect();
53 let else_clauses: Vec<_> = non_trivia
54 .iter()
55 .filter(|s| s.segment_type() == SegmentType::ElseClause)
56 .collect();
57
58 if when_clauses.len() != 1 || else_clauses.len() != 1 {
59 return vec![];
60 }
61
62 let when_clause = when_clauses[0];
63 let else_clause = else_clauses[0];
64
65 let then_value = extract_then_value(when_clause);
67 let else_value = extract_else_value(else_clause);
68
69 if let (Some(ref then_val), Some(ref else_val)) = (then_value, else_value) {
70 let is_bool_pair = (is_truthy(then_val) && is_falsy(else_val))
71 || (is_falsy(then_val) && is_truthy(else_val));
72
73 if is_bool_pair {
74 return vec![LintViolation::with_msg_key(
75 self.code(),
76 "Unnecessary CASE expression. Use the boolean condition directly.",
77 ctx.segment.span(),
78 "rules.ST02.msg.boolean",
79 vec![],
80 )];
81 }
82 }
83
84 if let Some(msg) = check_is_null_coalesce_pattern(when_clause, else_clause) {
86 return vec![LintViolation::with_msg_key(
87 self.code(),
88 msg,
89 ctx.segment.span(),
90 "rules.ST02.msg.coalesce",
91 vec![],
92 )];
93 }
94
95 vec![]
96 }
97}
98
99fn check_is_null_coalesce_pattern(when_clause: &Segment, else_clause: &Segment) -> Option<String> {
101 let when_children: Vec<_> = when_clause
102 .children()
103 .iter()
104 .filter(|c| !c.segment_type().is_trivia())
105 .collect();
106
107 let is_null_expr = when_children
109 .iter()
110 .find(|c| c.segment_type() == SegmentType::IsNullExpression)?;
111
112 let tested_col = get_is_null_subject(is_null_expr)?;
114
115 let else_expr = get_else_expression(else_clause)?;
117
118 if tested_col.eq_ignore_ascii_case(&else_expr) {
120 Some(
121 "Unnecessary CASE expression. Use COALESCE instead of CASE WHEN IS NULL pattern."
122 .to_string(),
123 )
124 } else {
125 None
126 }
127}
128
129fn get_is_null_subject(segment: &Segment) -> Option<String> {
131 let children = segment.children();
132 let non_trivia: Vec<_> = children
133 .iter()
134 .filter(|c| !c.segment_type().is_trivia())
135 .collect();
136
137 non_trivia.first().map(|s| s.raw().trim().to_string())
138}
139
140fn get_else_expression(segment: &Segment) -> Option<String> {
142 let children = segment.children();
143 let non_trivia: Vec<_> = children
144 .iter()
145 .filter(|c| !c.segment_type().is_trivia())
146 .collect();
147
148 if non_trivia.len() >= 2 {
149 let expr_parts: String = non_trivia[1..]
150 .iter()
151 .map(|s| s.raw())
152 .collect::<Vec<_>>()
153 .join("");
154 Some(expr_parts.trim().to_string())
155 } else {
156 None
157 }
158}
159
160fn extract_then_value(when_clause: &Segment) -> Option<String> {
161 let children = when_clause.children();
162 let non_trivia: Vec<_> = children
163 .iter()
164 .filter(|s| !s.segment_type().is_trivia())
165 .collect();
166
167 let mut found_then = false;
168 for seg in &non_trivia {
169 if found_then {
170 return Some(seg.raw().trim().to_uppercase());
171 }
172 if seg.segment_type() == SegmentType::Keyword && seg.raw().eq_ignore_ascii_case("THEN") {
173 found_then = true;
174 }
175 }
176 None
177}
178
179fn extract_else_value(else_clause: &Segment) -> Option<String> {
180 let children = else_clause.children();
181 let non_trivia: Vec<_> = children
182 .iter()
183 .filter(|s| !s.segment_type().is_trivia())
184 .collect();
185
186 if non_trivia.len() >= 2 {
187 return Some(non_trivia[1].raw().trim().to_uppercase());
188 }
189 None
190}
191
/// Returns `true` when `val` is exactly `TRUE` or `1`.
///
/// The comparison is case- and whitespace-sensitive, so callers pass values that
/// are already upper-cased and trimmed.
fn is_truthy(val: &str) -> bool {
    matches!(val, "TRUE" | "1")
}
195
/// Returns `true` when `val` is exactly `FALSE` or `0`.
///
/// The comparison is case- and whitespace-sensitive, so callers pass values that
/// are already upper-cased and trimmed.
fn is_falsy(val: &str) -> bool {
    matches!(val, "FALSE" | "0")
}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    // A single WHEN with TRUE/FALSE results should be replaced by the condition.
    #[test]
    fn test_st02_flags_simple_boolean_case() {
        let found = lint_sql("SELECT CASE WHEN x > 0 THEN TRUE ELSE FALSE END;", RuleST02);
        assert_eq!(found.len(), 1, "expected exactly one ST02 violation");
        assert!(found[0].message.contains("boolean"));
    }

    // Non-boolean results are a legitimate use of CASE.
    #[test]
    fn test_st02_accepts_non_boolean_case() {
        let found = lint_sql("SELECT CASE WHEN x > 0 THEN 'yes' ELSE 'no' END;", RuleST02);
        assert_eq!(found.len(), 0, "non-boolean CASE must not be flagged");
    }

    // More than one WHEN clause is out of scope for this rule.
    #[test]
    fn test_st02_accepts_multi_when() {
        let sql = "SELECT CASE WHEN x > 0 THEN TRUE WHEN x < 0 THEN FALSE ELSE FALSE END;";
        let found = lint_sql(sql, RuleST02);
        assert_eq!(found.len(), 0, "multi-WHEN CASE must not be flagged");
    }

    // CASE WHEN col IS NULL THEN d ELSE col END is COALESCE(col, d).
    #[test]
    fn test_st02_flags_is_null_coalesce_pattern() {
        let sql = "SELECT CASE WHEN col IS NULL THEN 'default' ELSE col END FROM t";
        let found = lint_sql(sql, RuleST02);
        assert_eq!(found.len(), 1, "expected exactly one ST02 violation");
        assert!(found[0].message.contains("COALESCE"));
    }

    // An ordinary equality test with string results is not simplifiable.
    #[test]
    fn test_st02_accepts_complex_case() {
        let sql = "SELECT CASE WHEN x = 1 THEN 'a' ELSE 'b' END FROM t";
        let found = lint_sql(sql, RuleST02);
        assert_eq!(found.len(), 0, "plain CASE must not be flagged");
    }
}