Skip to main content

flowscope_core/linter/rules/
st_001.rs

1//! LINT_ST_001: Unnecessary ELSE NULL in CASE expressions.
2//!
3//! `CASE ... ELSE NULL END` is redundant because CASE already returns NULL
4//! when no branch matches. The ELSE NULL can be removed.
5
6use crate::linter::helpers;
7use crate::linter::rule::{LintContext, LintRule};
8use crate::linter::visit;
9use crate::types::{issue_codes, Issue, IssueAutofixApplicability, IssuePatchEdit, Span};
10use sqlparser::ast::*;
11use sqlparser::tokenizer::{Token, TokenWithSpan, Tokenizer, Whitespace};
12
13pub struct UnnecessaryElseNull;
14
15impl LintRule for UnnecessaryElseNull {
16    fn code(&self) -> &'static str {
17        issue_codes::LINT_ST_001
18    }
19
20    fn name(&self) -> &'static str {
21        "Unnecessary ELSE NULL"
22    }
23
24    fn description(&self) -> &'static str {
25        "Do not specify 'else null' in a case when statement (redundant)."
26    }
27
28    fn check(&self, stmt: &Statement, ctx: &LintContext) -> Vec<Issue> {
29        let mut violation_count = 0usize;
30        visit::visit_expressions(stmt, &mut |expr| {
31            if let Expr::Case {
32                else_result: Some(else_expr),
33                ..
34            } = expr
35            {
36                if helpers::is_null_expr(else_expr) {
37                    violation_count += 1;
38                }
39            }
40        });
41        let mut autofix_candidates = st001_else_null_candidates_for_context(ctx);
42        autofix_candidates.sort_by_key(|candidate| candidate.span.start);
43        let candidates_align = autofix_candidates.len() == violation_count;
44
45        (0..violation_count)
46            .map(|index| {
47                let mut issue = Issue::info(
48                    issue_codes::LINT_ST_001,
49                    "ELSE NULL is redundant in CASE expressions; it can be removed.",
50                )
51                .with_statement(ctx.statement_index);
52
53                if candidates_align {
54                    let candidate = &autofix_candidates[index];
55                    issue = issue.with_span(candidate.span).with_autofix_edits(
56                        IssueAutofixApplicability::Safe,
57                        candidate.edits.clone(),
58                    );
59                }
60
61                issue
62            })
63            .collect()
64    }
65}
66
67#[derive(Clone, Debug)]
68struct PositionedToken {
69    token: Token,
70    start: usize,
71    end: usize,
72}
73
74#[derive(Clone, Debug)]
75struct St001AutofixCandidate {
76    span: Span,
77    edits: Vec<IssuePatchEdit>,
78}
79
/// Scanner state for one currently open `CASE` expression.
#[derive(Debug, Clone, Copy)]
struct CaseFrame {
    /// Position (in the significant-token sequence) of the most recent `ELSE`
    /// seen while this CASE was the innermost open one, if any.
    else_sig_pos: Option<usize>,
}
84
85fn st001_else_null_candidates_for_context(ctx: &LintContext) -> Vec<St001AutofixCandidate> {
86    let tokens = statement_positioned_tokens(ctx);
87    if tokens.is_empty() {
88        return Vec::new();
89    }
90
91    st001_else_null_candidates_from_tokens(&tokens)
92}
93
94fn statement_positioned_tokens(ctx: &LintContext) -> Vec<PositionedToken> {
95    let from_document_tokens = ctx.with_document_tokens(|tokens| {
96        if tokens.is_empty() {
97            return None;
98        }
99
100        let mut positioned = Vec::new();
101        for token in tokens {
102            let (start, end) = token_with_span_offsets(ctx.sql, token)?;
103            if start < ctx.statement_range.start || end > ctx.statement_range.end {
104                continue;
105            }
106
107            positioned.push(PositionedToken {
108                token: token.token.clone(),
109                start,
110                end,
111            });
112        }
113
114        Some(positioned)
115    });
116
117    if let Some(tokens) = from_document_tokens {
118        return tokens;
119    }
120
121    let dialect = ctx.dialect().to_sqlparser_dialect();
122    let mut tokenizer = Tokenizer::new(dialect.as_ref(), ctx.statement_sql());
123    let Ok(tokens) = tokenizer.tokenize_with_location() else {
124        return Vec::new();
125    };
126
127    let mut positioned = Vec::new();
128    for token in &tokens {
129        let Some((start, end)) = token_with_span_offsets(ctx.statement_sql(), token) else {
130            continue;
131        };
132        positioned.push(PositionedToken {
133            token: token.token.clone(),
134            start: ctx.statement_range.start + start,
135            end: ctx.statement_range.start + end,
136        });
137    }
138
139    positioned
140}
141
142fn st001_else_null_candidates_from_tokens(
143    tokens: &[PositionedToken],
144) -> Vec<St001AutofixCandidate> {
145    let significant_positions: Vec<usize> = tokens
146        .iter()
147        .enumerate()
148        .filter_map(|(index, token)| (!is_trivia(&token.token)).then_some(index))
149        .collect();
150
151    let mut candidates = Vec::new();
152    let mut case_stack: Vec<CaseFrame> = Vec::new();
153
154    for (sig_pos, token_index) in significant_positions.iter().copied().enumerate() {
155        if token_word_equals(&tokens[token_index].token, "CASE") {
156            case_stack.push(CaseFrame { else_sig_pos: None });
157            continue;
158        }
159
160        if token_word_equals(&tokens[token_index].token, "ELSE") {
161            if let Some(frame) = case_stack.last_mut() {
162                frame.else_sig_pos = Some(sig_pos);
163            }
164            continue;
165        }
166
167        if !token_word_equals(&tokens[token_index].token, "END") {
168            continue;
169        }
170
171        let Some(frame) = case_stack.pop() else {
172            continue;
173        };
174        let Some(else_sig_pos) = frame.else_sig_pos else {
175            continue;
176        };
177        if else_sig_pos + 1 >= sig_pos {
178            continue;
179        }
180
181        let null_token_index = significant_positions[else_sig_pos + 1];
182        if else_sig_pos + 2 != sig_pos
183            || !token_word_equals(&tokens[null_token_index].token, "NULL")
184        {
185            continue;
186        }
187
188        if else_sig_pos == 0 {
189            continue;
190        }
191
192        let else_token_index = significant_positions[else_sig_pos];
193        let previous_sig_token_index = significant_positions[else_sig_pos - 1];
194        let removal_start_token_index = previous_sig_token_index.saturating_add(1);
195
196        if removal_start_token_index > else_token_index
197            || removal_start_token_index >= tokens.len()
198            || trivia_contains_comment(tokens, removal_start_token_index, null_token_index + 1)
199        {
200            continue;
201        }
202
203        let removal_span = Span::new(
204            tokens[removal_start_token_index].start,
205            tokens[null_token_index].end,
206        );
207        candidates.push(St001AutofixCandidate {
208            span: removal_span,
209            edits: vec![IssuePatchEdit::new(removal_span, "")],
210        });
211    }
212
213    candidates
214}
215
216fn token_word_equals(token: &Token, expected_upper: &str) -> bool {
217    matches!(token, Token::Word(word) if word.value.eq_ignore_ascii_case(expected_upper))
218}
219
220fn is_trivia(token: &Token) -> bool {
221    matches!(
222        token,
223        Token::Whitespace(
224            Whitespace::Space
225                | Whitespace::Newline
226                | Whitespace::Tab
227                | Whitespace::SingleLineComment { .. }
228                | Whitespace::MultiLineComment(_)
229        )
230    )
231}
232
233fn trivia_contains_comment(tokens: &[PositionedToken], start: usize, end: usize) -> bool {
234    if start >= end {
235        return false;
236    }
237
238    tokens[start..end].iter().any(|token| {
239        matches!(
240            token.token,
241            Token::Whitespace(
242                Whitespace::SingleLineComment { .. } | Whitespace::MultiLineComment(_)
243            )
244        )
245    })
246}
247
248fn token_with_span_offsets(sql: &str, token: &TokenWithSpan) -> Option<(usize, usize)> {
249    let start = line_col_to_offset(
250        sql,
251        token.span.start.line as usize,
252        token.span.start.column as usize,
253    )?;
254    let end = line_col_to_offset(
255        sql,
256        token.span.end.line as usize,
257        token.span.end.column as usize,
258    )?;
259    Some((start, end))
260}
261
/// Converts a 1-based `(line, column)` position into a byte offset into `sql`.
///
/// Only `'\n'` separates lines, and columns count `char`s, so a `'\r'` before
/// the newline occupies a column of its own. The column one past a line's last
/// character is accepted and maps to that line's `'\n'` (or to `sql.len()` on
/// the final line). Returns `None` when either coordinate is zero or the
/// position lies outside `sql`.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    // Byte offset where the requested line begins: just past the (line - 1)-th newline.
    let line_start = match line {
        1 => 0,
        _ => sql.match_indices('\n').nth(line - 2)?.0 + 1,
    };

    let rest = &sql[line_start..];
    let line_len = rest.find('\n').unwrap_or(rest.len());

    // Addressable positions are each char on the line followed by the line
    // terminator (the newline itself, or end of input); column N picks the N-th.
    rest[..line_len]
        .char_indices()
        .map(|(index, _)| line_start + index)
        .chain(Some(line_start + line_len))
        .nth(column - 1)
}
289
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;
    use crate::types::IssueAutofixApplicability;

    /// Parses `sql` and runs ST001 over every statement, treating the whole
    /// input as the statement range with statement index 0.
    fn check_sql(sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).unwrap();
        let ctx = LintContext {
            sql,
            statement_range: 0..sql.len(),
            statement_index: 0,
        };
        statements
            .iter()
            .flat_map(|statement| UnnecessaryElseNull.check(statement, &ctx))
            .collect()
    }

    /// Asserts that linting `sql` yields exactly `expected` ST001 issues.
    fn assert_issue_count(sql: &str, expected: usize) {
        assert_eq!(
            check_sql(sql).len(),
            expected,
            "unexpected ST001 issue count for: {}",
            sql
        );
    }

    #[test]
    fn test_else_null_detected() {
        let issues = check_sql("SELECT CASE WHEN x > 1 THEN 'a' ELSE NULL END FROM t");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, "LINT_ST_001");
    }

    #[test]
    fn test_no_else_ok() {
        assert_issue_count("SELECT CASE WHEN x > 1 THEN 'a' END FROM t", 0);
    }

    #[test]
    fn test_else_value_ok() {
        assert_issue_count("SELECT CASE WHEN x > 1 THEN 'a' ELSE 'b' END FROM t", 0);
    }

    // --- Edge cases adopted from sqlfluff ST01 (structure.else_null) ---

    #[test]
    fn test_simple_case_else_null() {
        // Simple CASE form: CASE x WHEN ... ELSE NULL END
        assert_issue_count(
            "SELECT CASE name WHEN 'cat' THEN 'meow' WHEN 'dog' THEN 'woof' ELSE NULL END FROM t",
            1,
        );
    }

    #[test]
    fn test_else_with_complex_expression_ok() {
        assert_issue_count(
            "SELECT CASE name WHEN 'cat' THEN 'meow' ELSE UPPER(name) END FROM t",
            0,
        );
    }

    #[test]
    fn test_multiple_when_branches_else_null() {
        assert_issue_count(
            "SELECT CASE WHEN x = 1 THEN 'a' WHEN x = 2 THEN 'b' WHEN x = 3 THEN 'c' ELSE NULL END FROM t",
            1,
        );
    }

    #[test]
    fn test_nested_case_else_null() {
        // Inner and outer CASE both end with ELSE NULL.
        assert_issue_count(
            "SELECT CASE WHEN x > 0 THEN CASE WHEN y > 0 THEN 'pos' ELSE NULL END ELSE NULL END FROM t",
            2,
        );
    }

    #[test]
    fn test_else_null_in_where_clause() {
        assert_issue_count(
            "SELECT * FROM t WHERE (CASE WHEN x > 0 THEN 1 ELSE NULL END) IS NOT NULL",
            1,
        );
    }

    #[test]
    fn test_else_null_in_cte() {
        assert_issue_count(
            "WITH cte AS (SELECT CASE WHEN x > 0 THEN 'yes' ELSE NULL END AS flag FROM t) SELECT * FROM cte",
            1,
        );
    }

    #[test]
    fn test_else_null_emits_safe_autofix_patch() {
        let sql = "SELECT CASE WHEN x > 1 THEN 'a' ELSE NULL END FROM t";
        let issues = check_sql(sql);
        assert_eq!(issues.len(), 1);

        let autofix = issues[0]
            .autofix
            .as_ref()
            .expect("expected ST001 core autofix metadata");
        assert_eq!(autofix.applicability, IssueAutofixApplicability::Safe);
        assert_eq!(autofix.edits.len(), 1);

        let edit = &autofix.edits[0];
        let mut rewritten = sql.to_string();
        rewritten.replace_range(edit.span.start..edit.span.end, &edit.replacement);
        assert_eq!(rewritten, "SELECT CASE WHEN x > 1 THEN 'a' END FROM t");
    }
}