Skip to main content

flowscope_core/linter/rules/
st_004.rs

1//! LINT_ST_004: Flattenable nested CASE in ELSE.
2//!
3//! SQLFluff ST04 parity: flag `CASE ... ELSE CASE ... END END` patterns where
4//! the nested ELSE-case can be flattened into the outer CASE.
5
6use crate::linter::rule::{LintContext, LintRule};
7use crate::linter::visit;
8use crate::types::{issue_codes, Issue, IssueAutofixApplicability, IssuePatchEdit, Span};
9use sqlparser::ast::{Expr, Spanned, Statement};
10use sqlparser::tokenizer::{Token, TokenWithSpan, Tokenizer, Whitespace};
11
12pub struct FlattenableNestedCase;
13
14impl LintRule for FlattenableNestedCase {
15    fn code(&self) -> &'static str {
16        issue_codes::LINT_ST_004
17    }
18
19    fn name(&self) -> &'static str {
20        "Flattenable nested CASE"
21    }
22
23    fn description(&self) -> &'static str {
24        "Nested 'CASE' statement in 'ELSE' clause could be flattened."
25    }
26
27    fn check(&self, stmt: &Statement, ctx: &LintContext) -> Vec<Issue> {
28        let mut issues = Vec::new();
29
30        visit::visit_expressions(stmt, &mut |expr| {
31            if !is_flattenable_nested_else_case(expr) {
32                return;
33            }
34
35            let mut issue = Issue::warning(
36                issue_codes::LINT_ST_004,
37                "Nested CASE in ELSE clause can be flattened.",
38            )
39            .with_statement(ctx.statement_index);
40
41            if let Some((span, edits)) = build_flatten_autofix(ctx, expr) {
42                issue = issue.with_span(span);
43                if !edits.is_empty() {
44                    issue = issue.with_autofix_edits(IssueAutofixApplicability::Unsafe, edits);
45                }
46            }
47
48            issues.push(issue);
49        });
50
51        // Parser fallback path for unparsable CASE syntax and templated inner
52        // branches: detect at token level using statement SQL.
53        if issues.is_empty()
54            && (is_synthetic_select_one(stmt) || contains_template_tags(ctx.statement_sql()))
55        {
56            if let Some((span, edits)) = build_flatten_autofix_from_sql(ctx) {
57                let mut issue = Issue::warning(
58                    issue_codes::LINT_ST_004,
59                    "Nested CASE in ELSE clause can be flattened.",
60                )
61                .with_statement(ctx.statement_index)
62                .with_span(span);
63                if !edits.is_empty() {
64                    issue = issue.with_autofix_edits(IssueAutofixApplicability::Unsafe, edits);
65                }
66                issues.push(issue);
67            }
68        }
69
70        issues
71    }
72}
73
74fn is_flattenable_nested_else_case(expr: &Expr) -> bool {
75    let Expr::Case {
76        operand: outer_operand,
77        conditions: outer_conditions,
78        else_result: Some(outer_else),
79        ..
80    } = expr
81    else {
82        return false;
83    };
84
85    // SQLFluff ST04 only applies when there is at least one WHEN in the outer CASE.
86    if outer_conditions.is_empty() {
87        return false;
88    }
89
90    let Some((inner_operand, _inner_conditions, _inner_else)) = case_parts(outer_else) else {
91        return false;
92    };
93
94    case_operands_match(outer_operand.as_deref(), inner_operand)
95}
96
97fn case_parts(
98    case_expr: &Expr,
99) -> Option<(Option<&Expr>, &[sqlparser::ast::CaseWhen], Option<&Expr>)> {
100    match case_expr {
101        Expr::Case {
102            operand,
103            conditions,
104            else_result,
105            ..
106        } => Some((
107            operand.as_deref(),
108            conditions.as_slice(),
109            else_result.as_deref(),
110        )),
111        Expr::Nested(inner) => case_parts(inner),
112        _ => None,
113    }
114}
115
116fn case_operands_match(outer: Option<&Expr>, inner: Option<&Expr>) -> bool {
117    match (outer, inner) {
118        (None, None) => true,
119        (Some(left), Some(right)) => exprs_equal(left, right),
120        _ => false,
121    }
122}
123
124fn exprs_equal(left: &Expr, right: &Expr) -> bool {
125    format!("{left}") == format!("{right}")
126}
127
/// True if `sql` contains a Jinja-style opener: `{{`, `{%` or `{#`.
fn contains_template_tags(sql: &str) -> bool {
    ["{{", "{%", "{#"].iter().any(|&opener| sql.contains(opener))
}
131
132fn is_synthetic_select_one(stmt: &Statement) -> bool {
133    let normalized = stmt
134        .to_string()
135        .split_whitespace()
136        .collect::<Vec<_>>()
137        .join(" ");
138    normalized.eq_ignore_ascii_case("SELECT 1")
139}
140
141// ---------------------------------------------------------------------------
142// Autofix: flatten nested CASE in ELSE clause
143// ---------------------------------------------------------------------------
144
145/// Build autofix edits that flatten a nested CASE in ELSE into the outer CASE.
146///
147/// The transformation removes the ELSE...CASE...END wrapper and promotes the
148/// inner CASE's WHEN/ELSE clauses to the outer CASE, preserving comments.
149fn build_flatten_autofix(
150    ctx: &LintContext,
151    outer_expr: &Expr,
152) -> Option<(Span, Vec<IssuePatchEdit>)> {
153    let Expr::Case {
154        else_result: Some(outer_else),
155        ..
156    } = outer_expr
157    else {
158        return None;
159    };
160
161    let inner_case = unwrap_nested(outer_else);
162    let Expr::Case { .. } = inner_case else {
163        return None;
164    };
165
166    let sql = ctx.statement_sql();
167
168    // Get the outer CASE span in statement coordinates.
169    let (outer_start, outer_end) = expr_statement_offsets(ctx, outer_expr)?;
170
171    // Tokenize the CASE expression region to find key positions.
172    let tokens = tokenize_with_spans(sql, ctx.dialect())?;
173    let positioned: Vec<PositionedToken> = tokens
174        .iter()
175        .filter_map(|token| {
176            let (start, end) = token_with_span_offsets(sql, token)?;
177            Some(PositionedToken {
178                token: token.token.clone(),
179                start,
180                end,
181            })
182        })
183        .filter(|token| token.start >= outer_start && token.end <= outer_end)
184        .collect();
185
186    // Find the ELSE keyword that begins the nested CASE, and the inner CASE/END tokens.
187    let flatten_info = find_flatten_positions(&positioned)?;
188
189    build_flatten_edit_from_positions(ctx, sql, &positioned, &flatten_info)
190}
191
192fn build_flatten_autofix_from_sql(ctx: &LintContext) -> Option<(Span, Vec<IssuePatchEdit>)> {
193    let sql = ctx.statement_sql();
194    let masked_sql = contains_template_tags(sql).then(|| mask_templated_areas(sql));
195    let scan_sql = masked_sql.as_deref().unwrap_or(sql);
196    let tokens = tokenize_with_spans(scan_sql, ctx.dialect())?;
197    let positioned: Vec<PositionedToken> = tokens
198        .iter()
199        .filter_map(|token| {
200            let (start, end) = token_with_span_offsets(scan_sql, token)?;
201            Some(PositionedToken {
202                token: token.token.clone(),
203                start,
204                end,
205            })
206        })
207        .collect();
208
209    let flatten_info = find_flatten_positions(&positioned)?;
210    build_flatten_edit_from_positions(ctx, sql, &positioned, &flatten_info)
211}
212
213fn mask_templated_areas(sql: &str) -> String {
214    let mut out = String::with_capacity(sql.len());
215    let mut index = 0usize;
216
217    while let Some((open_index, close_marker)) = find_next_template_open(sql, index) {
218        out.push_str(&sql[index..open_index]);
219        let marker_start = open_index + 2;
220        if let Some(close_offset) = sql[marker_start..].find(close_marker) {
221            let close_index = marker_start + close_offset + close_marker.len();
222            out.push_str(&mask_non_newlines(&sql[open_index..close_index]));
223            index = close_index;
224        } else {
225            out.push_str(&mask_non_newlines(&sql[open_index..]));
226            return out;
227        }
228    }
229
230    out.push_str(&sql[index..]);
231    out
232}
233
/// Find the earliest template opener (`{{`, `{%`, `{#`) at or after byte
/// `from`. Returns its absolute byte offset and the matching closer, or
/// `None` if there is none or `from` is not a valid slice start.
fn find_next_template_open(sql: &str, from: usize) -> Option<(usize, &'static str)> {
    const DELIMITERS: [(&str, &str); 3] = [("{{", "}}"), ("{%", "%}"), ("{#", "#}")];

    let tail = sql.get(from..)?;
    let mut earliest: Option<(usize, &'static str)> = None;
    for (opener, closer) in DELIMITERS {
        if let Some(rel) = tail.find(opener) {
            let at = from + rel;
            let is_earlier = match earliest {
                Some((best, _)) => at < best,
                None => true,
            };
            if is_earlier {
                earliest = Some((at, closer));
            }
        }
    }
    earliest
}
243
/// Replace every character of `segment` except `\n` with ASCII spaces.
///
/// Each character becomes as many spaces as it occupies in UTF-8, so the
/// result has exactly the same byte length and newline positions as the
/// input. This matters because callers tokenize the masked text and then
/// apply the resulting *byte* offsets to the original SQL. Mapping a
/// multi-byte character to a single space would shift every later offset,
/// and could make them land inside a character of the original text.
fn mask_non_newlines(segment: &str) -> String {
    let mut masked = String::with_capacity(segment.len());
    for ch in segment.chars() {
        if ch == '\n' {
            masked.push('\n');
        } else {
            // One space per UTF-8 byte keeps the masked text byte-aligned.
            for _ in 0..ch.len_utf8() {
                masked.push(' ');
            }
        }
    }
    masked
}
250
251fn build_flatten_edit_from_positions(
252    ctx: &LintContext,
253    sql: &str,
254    positioned: &[PositionedToken],
255    flatten_info: &FlattenPositions,
256) -> Option<(Span, Vec<IssuePatchEdit>)> {
257    let else_start = flatten_info.else_start;
258    let inner_case_body_start = flatten_info.inner_body_start;
259    let inner_end_start = flatten_info.inner_end_start;
260    let inner_end_end = flatten_info.inner_end_end;
261    let outer_end_start = flatten_info.outer_end_start;
262    let outer_case_start = flatten_info.outer_case_start;
263    let outer_end_end = flatten_info.outer_end_end;
264    let else_end = flatten_info.else_end;
265    let inner_case_start = flatten_info.inner_case_start;
266    let inner_case_end = flatten_info.inner_case_end;
267
268    let issue_span = ctx.span_from_statement_offset(outer_case_start, outer_end_end);
269
270    // Replace from the start of the ELSE line up to (but not including) the
271    // outer END token.
272    let replace_start = line_start_offset(sql, else_start);
273    let replace_end = outer_end_start;
274    if replace_end <= replace_start {
275        return Some((issue_span, Vec::new()));
276    }
277
278    // Collect comments between ELSE and inner CASE body.
279    let else_line_start = line_start_offset(sql, else_start);
280    let mut comments_before_body =
281        collect_comments_in_range(positioned, else_line_start, else_start);
282    comments_before_body.extend(collect_comments_in_range(
283        positioned,
284        else_start,
285        inner_case_body_start,
286    ));
287
288    // Collect comments between inner END and outer END.
289    let comments_after_inner_end =
290        collect_comments_in_range(positioned, inner_end_end, outer_end_start);
291
292    // If the rewrite region touches template tags, report only (no autofix).
293    if contains_template_tags(sql.get(replace_start..replace_end)?) {
294        return Some((issue_span, Vec::new()));
295    }
296
297    // In comment-heavy regions, avoid editing comment bytes directly (blocked
298    // by the fix planner). Remove only CASE/ELSE/END wrapper keywords.
299    if has_comments_in_range(positioned, replace_start, replace_end) {
300        let mut edits = Vec::new();
301        edits.push(IssuePatchEdit::new(
302            ctx.span_from_statement_offset(else_start, else_end),
303            String::new(),
304        ));
305        edits.push(IssuePatchEdit::new(
306            ctx.span_from_statement_offset(inner_case_start, inner_case_end),
307            String::new(),
308        ));
309        if inner_case_end < inner_case_body_start
310            && !has_comments_in_range(positioned, inner_case_end, inner_case_body_start)
311        {
312            edits.push(IssuePatchEdit::new(
313                ctx.span_from_statement_offset(inner_case_end, inner_case_body_start),
314                String::new(),
315            ));
316        }
317        edits.push(IssuePatchEdit::new(
318            ctx.span_from_statement_offset(inner_end_start, inner_end_end),
319            String::new(),
320        ));
321        return Some((issue_span, edits));
322    }
323
324    // Get the inner body text (WHEN/ELSE clauses).
325    let inner_body_text = sql.get(inner_case_body_start..inner_end_start)?;
326
327    // Determine indentation levels.
328    let outer_indent = find_indent_of_else(sql, else_start);
329    let inner_body_indent = find_line_prefix(sql, inner_case_body_start);
330
331    // Build the replacement text.
332    let mut replacement = String::new();
333
334    // Add collected comments from between ELSE and inner CASE body.
335    for comment in &comments_before_body {
336        replacement.push_str(&outer_indent);
337        replacement.push_str(comment.trim());
338        replacement.push('\n');
339    }
340
341    // Add inner body lines, re-indented to match outer CASE indentation.
342    let inner_body_trimmed = inner_body_text.trim();
343    if !inner_body_trimmed.is_empty() {
344        for line in inner_body_trimmed.lines() {
345            let stripped = strip_indent(line, &inner_body_indent);
346            replacement.push_str(&outer_indent);
347            replacement.push_str(&stripped);
348            replacement.push('\n');
349        }
350    }
351
352    // Add comments from after inner END.
353    for comment in &comments_after_inner_end {
354        replacement.push_str(&outer_indent);
355        replacement.push_str(comment.trim());
356        replacement.push('\n');
357    }
358
359    // Trim trailing newline from replacement.
360    while replacement.ends_with('\n') {
361        replacement.pop();
362    }
363
364    // Keep the outer END token in place by restoring its indentation prefix.
365    let end_prefix = find_line_prefix(sql, outer_end_start);
366    replacement.push('\n');
367    replacement.push_str(&end_prefix);
368
369    let edit_span = ctx.span_from_statement_offset(replace_start, replace_end);
370    Some((
371        issue_span,
372        vec![IssuePatchEdit::new(edit_span, replacement)],
373    ))
374}
375
376fn has_comments_in_range(tokens: &[PositionedToken], start: usize, end: usize) -> bool {
377    tokens
378        .iter()
379        .any(|t| t.start >= start && t.end <= end && is_comment(&t.token))
380}
381
/// Byte offsets of the pieces of a `CASE ... ELSE CASE ... END ... END`
/// pattern, as located by `find_else_with_nested_case`.
///
/// All offsets index into the SQL text whose tokens were scanned.
#[derive(Debug)]
struct FlattenPositions {
    /// Byte offset of the outer CASE keyword.
    outer_case_start: usize,
    /// Byte offset of the outer ELSE keyword that contains the nested CASE.
    else_start: usize,
    /// Byte offset after the outer ELSE keyword.
    else_end: usize,
    /// Byte offset of the inner CASE keyword.
    inner_case_start: usize,
    /// Byte offset after the inner CASE keyword.
    inner_case_end: usize,
    /// Byte offset where the inner CASE body starts.
    ///
    /// This is the first WHEN or ELSE belonging to the inner CASE, i.e. after
    /// any operand. When no such token is found, it is the byte just after
    /// the inner CASE keyword.
    inner_body_start: usize,
    /// Byte offset of the inner END keyword.
    inner_end_start: usize,
    /// Byte offset after the inner END keyword.
    inner_end_end: usize,
    /// Byte offset of the outer END keyword.
    outer_end_start: usize,
    /// Byte offset after the outer END keyword.
    outer_end_end: usize,
}
405
406fn find_flatten_positions(tokens: &[PositionedToken]) -> Option<FlattenPositions> {
407    let significant: Vec<(usize, &PositionedToken)> = tokens
408        .iter()
409        .enumerate()
410        .filter(|(_, t)| !is_trivia(&t.token))
411        .collect();
412
413    if significant.is_empty() {
414        return None;
415    }
416
417    // Track CASE/END nesting depth.
418    let mut depth = 0usize;
419    let mut outer_case_idx = None;
420    for (sig_idx, (_tok_idx, token)) in significant.iter().enumerate() {
421        if token_word_equals(&token.token, "CASE") {
422            if depth == 0 {
423                outer_case_idx = Some(sig_idx);
424            }
425            depth += 1;
426        } else if token_word_equals(&token.token, "END") {
427            depth = depth.saturating_sub(1);
428            if depth == 0 {
429                // This is the outer END.
430                // Find the ELSE at depth 1 that precedes the inner CASE.
431                let else_info =
432                    find_else_with_nested_case(&significant, outer_case_idx?, sig_idx, tokens)?;
433                return Some(else_info);
434            }
435        }
436    }
437
438    None
439}
440
/// Search one top-level CASE for an ELSE whose next significant token is CASE.
///
/// The CASE spans `outer_case_sig_idx..=outer_end_sig_idx`, both indices into
/// `significant`. Only an ELSE that belongs to that CASE itself is considered.
///
/// `significant` holds only non-trivia tokens (no whitespace or comments), so
/// "next" skips over any comments between ELSE and CASE. On success, the
/// returned [`FlattenPositions`] records the offsets of:
/// - the outer CASE and the ELSE;
/// - the inner CASE keyword and the inner body start;
/// - the inner END and the outer END.
///
/// Returns `None` when no such ELSE exists, or when the inner END is not found
/// before the outer END.
///
/// NOTE(review): indexes `significant[outer_end_sig_idx]` directly, so
/// callers must pass an in-bounds index. `_tokens` is currently unused.
fn find_else_with_nested_case(
    significant: &[(usize, &PositionedToken)],
    outer_case_sig_idx: usize,
    outer_end_sig_idx: usize,
    _tokens: &[PositionedToken],
) -> Option<FlattenPositions> {
    // Walk from the outer CASE to outer END tracking depth. Depth 1 means
    // "directly inside the outer CASE" rather than inside a CASE nested in
    // one of its WHEN/THEN expressions.
    let mut depth = 0usize;
    let outer_case_start = significant.get(outer_case_sig_idx)?.1.start;

    for sig_idx in outer_case_sig_idx..=outer_end_sig_idx {
        let (_, token) = &significant[sig_idx];

        if token_word_equals(&token.token, "CASE") {
            depth += 1;
        }

        if token_word_equals(&token.token, "ELSE") && depth == 1 {
            // Check if the next significant token after ELSE is CASE.
            let next_sig = sig_idx + 1;
            if next_sig < significant.len() {
                let (_, next_token) = &significant[next_sig];
                if token_word_equals(&next_token.token, "CASE") {
                    // Found ELSE followed by CASE.
                    let else_start = token.start;
                    let else_end = token.end;
                    let inner_case_start = next_token.start;
                    let inner_case_end = next_token.end;

                    // Find where the inner CASE body starts (after CASE keyword and optional operand).
                    let inner_body_start =
                        find_inner_body_start(significant, next_sig, outer_end_sig_idx)?;

                    // Find the inner END: the END that brings the nesting
                    // counted from the inner CASE back to zero.
                    let mut inner_depth = 0usize;
                    let mut inner_end_start = None;
                    let mut inner_end_end = None;
                    for (_, inner_token) in
                        significant.iter().take(outer_end_sig_idx).skip(next_sig)
                    {
                        if token_word_equals(&inner_token.token, "CASE") {
                            inner_depth += 1;
                        } else if token_word_equals(&inner_token.token, "END") {
                            inner_depth = inner_depth.saturating_sub(1);
                            if inner_depth == 0 {
                                inner_end_start = Some(inner_token.start);
                                inner_end_end = Some(inner_token.end);
                                break;
                            }
                        }
                    }

                    let outer_end_start = significant[outer_end_sig_idx].1.start;
                    let outer_end_end = significant[outer_end_sig_idx].1.end;

                    return Some(FlattenPositions {
                        outer_case_start,
                        else_start,
                        else_end,
                        inner_case_start,
                        inner_case_end,
                        inner_body_start,
                        inner_end_start: inner_end_start?,
                        inner_end_end: inner_end_end?,
                        outer_end_start,
                        outer_end_end,
                    });
                }
            }
        }

        // Decrement after the ELSE check so an END is still counted at the
        // depth of the CASE it closes.
        if token_word_equals(&token.token, "END") {
            depth = depth.saturating_sub(1);
        }
    }

    None
}
519
520fn find_inner_body_start(
521    significant: &[(usize, &PositionedToken)],
522    inner_case_sig_idx: usize,
523    outer_end_sig_idx: usize,
524) -> Option<usize> {
525    // After CASE, skip optional operand until we find WHEN or ELSE.
526    let mut depth = 0usize;
527    for (_, token) in significant
528        .iter()
529        .take(outer_end_sig_idx)
530        .skip(inner_case_sig_idx)
531    {
532        if token_word_equals(&token.token, "CASE") {
533            depth += 1;
534        } else if token_word_equals(&token.token, "END") {
535            depth = depth.saturating_sub(1);
536        }
537
538        if depth == 1
539            && (token_word_equals(&token.token, "WHEN") || token_word_equals(&token.token, "ELSE"))
540        {
541            return Some(token.start);
542        }
543    }
544    // Template-heavy or parser-fallback SQL may not expose explicit WHEN/ELSE
545    // tokens inside the inner CASE body. Fall back to the byte immediately
546    // after the CASE keyword so we can still emit a detection-only issue.
547    Some(significant.get(inner_case_sig_idx)?.1.end)
548}
549
550fn collect_comments_in_range(tokens: &[PositionedToken], start: usize, end: usize) -> Vec<String> {
551    tokens
552        .iter()
553        .filter(|t| t.start >= start && t.end <= end && is_comment(&t.token))
554        .map(|t| comment_text(&t.token))
555        .collect()
556}
557
558fn comment_text(token: &Token) -> String {
559    match token {
560        Token::Whitespace(Whitespace::SingleLineComment { comment, prefix }) => {
561            format!("{prefix}{comment}")
562        }
563        Token::Whitespace(Whitespace::MultiLineComment(comment)) => {
564            format!("/*{comment}*/")
565        }
566        _ => String::new(),
567    }
568}
569
/// Byte offset of the start of the line containing byte `offset`.
fn line_start_offset(sql: &str, offset: usize) -> usize {
    sql[..offset].rfind('\n').map_or(0, |newline| newline + 1)
}
577
/// Leading whitespace of the line containing the ELSE keyword at `else_offset`.
fn find_indent_of_else(sql: &str, else_offset: usize) -> String {
    let head = &sql[..else_offset];
    let current_line = head.rfind('\n').map_or(head, |newline| &head[newline + 1..]);
    current_line.chars().take_while(|c| c.is_whitespace()).collect()
}
581
/// Leading whitespace of the line containing byte `offset`, limited to the
/// text before `offset`. This also applies on the first line.
fn find_line_prefix(sql: &str, offset: usize) -> String {
    let head = &sql[..offset];
    let current_line = match head.rfind('\n') {
        Some(newline) => &head[newline + 1..],
        None => head,
    };
    current_line.chars().take_while(|c| c.is_whitespace()).collect()
}
595
/// Remove `indent` from the front of `line`. If `line` does not start with
/// it, strip all leading whitespace instead.
fn strip_indent(line: &str, indent: &str) -> String {
    line.strip_prefix(indent)
        .unwrap_or_else(|| line.trim_start())
        .to_owned()
}
603
604fn unwrap_nested(expr: &Expr) -> &Expr {
605    match expr {
606        Expr::Nested(inner) => unwrap_nested(inner),
607        _ => expr,
608    }
609}
610
611// ---------------------------------------------------------------------------
612// Span and offset utilities
613// ---------------------------------------------------------------------------
614
/// A tokenizer token paired with its byte range in the SQL text it was
/// tokenized from.
#[derive(Clone, Debug)]
struct PositionedToken {
    /// The token as produced by the sqlparser tokenizer.
    token: Token,
    /// Byte offset where the token starts.
    start: usize,
    /// Byte offset just past the token's end.
    end: usize,
}
621
622fn expr_statement_offsets(ctx: &LintContext, expr: &Expr) -> Option<(usize, usize)> {
623    if let Some((start, end)) = expr_span_offsets(ctx.statement_sql(), expr) {
624        return Some((start, end));
625    }
626
627    let (start, end) = expr_span_offsets(ctx.sql, expr)?;
628    if start < ctx.statement_range.start || end > ctx.statement_range.end {
629        return None;
630    }
631
632    Some((
633        start - ctx.statement_range.start,
634        end - ctx.statement_range.start,
635    ))
636}
637
638fn expr_span_offsets(sql: &str, expr: &Expr) -> Option<(usize, usize)> {
639    let span = expr.span();
640    if span.start.line == 0 || span.start.column == 0 || span.end.line == 0 || span.end.column == 0
641    {
642        return None;
643    }
644
645    let start = line_col_to_offset(sql, span.start.line as usize, span.start.column as usize)?;
646    let end = line_col_to_offset(sql, span.end.line as usize, span.end.column as usize)?;
647    (end >= start).then_some((start, end))
648}
649
650fn tokenize_with_spans(sql: &str, dialect: crate::types::Dialect) -> Option<Vec<TokenWithSpan>> {
651    let dialect = dialect.to_sqlparser_dialect();
652    let mut tokenizer = Tokenizer::new(dialect.as_ref(), sql);
653    tokenizer.tokenize_with_location().ok()
654}
655
656fn token_with_span_offsets(sql: &str, token: &TokenWithSpan) -> Option<(usize, usize)> {
657    let start = line_col_to_offset(
658        sql,
659        token.span.start.line as usize,
660        token.span.start.column as usize,
661    )?;
662    let end = line_col_to_offset(
663        sql,
664        token.span.end.line as usize,
665        token.span.end.column as usize,
666    )?;
667    Some((start, end))
668}
669
/// Convert a 1-based (line, column) position into a byte offset in `sql`.
///
/// Columns count characters, not bytes. One column past the last character
/// of a line is valid: it maps to that line's `\n`, or to `sql.len()` on the
/// final line. Zero coordinates and positions that do not exist yield `None`.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    // Byte offset just after the (line - 1)-th newline.
    let line_start = if line == 1 {
        0
    } else {
        sql.match_indices('\n').nth(line - 2).map(|(idx, _)| idx + 1)?
    };

    let rest = &sql[line_start..];
    let line_text = rest.find('\n').map_or(rest, |newline| &rest[..newline]);

    // Candidate offsets: each character start, plus the end-of-line position.
    line_text
        .char_indices()
        .map(|(rel, _)| rel)
        .chain(std::iter::once(line_text.len()))
        .nth(column - 1)
        .map(|rel| line_start + rel)
}
708
709fn token_word_equals(token: &Token, expected_upper: &str) -> bool {
710    matches!(token, Token::Word(word) if word.value.eq_ignore_ascii_case(expected_upper))
711}
712
713fn is_trivia(token: &Token) -> bool {
714    matches!(
715        token,
716        Token::Whitespace(
717            Whitespace::Space
718                | Whitespace::Newline
719                | Whitespace::Tab
720                | Whitespace::SingleLineComment { .. }
721                | Whitespace::MultiLineComment(_)
722        )
723    )
724}
725
726fn is_comment(token: &Token) -> bool {
727    matches!(
728        token,
729        Token::Whitespace(Whitespace::SingleLineComment { .. } | Whitespace::MultiLineComment(_))
730    )
731}
732
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;
    use crate::types::IssuePatchEdit;

    // Parse `sql` and run the rule over every parsed statement. Each
    // statement gets the whole input as its `statement_range`, so these
    // helpers presumably assume single-statement inputs.
    fn run(sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        let rule = FlattenableNestedCase;
        statements
            .iter()
            .enumerate()
            .flat_map(|(index, statement)| {
                rule.check(
                    statement,
                    &LintContext {
                        sql,
                        statement_range: 0..sql.len(),
                        statement_index: index,
                    },
                )
            })
            .collect()
    }

    // Apply autofix edits to `sql`. Edits are sorted by start offset and
    // applied back-to-front, so earlier offsets stay valid after each
    // replacement.
    fn apply_edits(sql: &str, edits: &[IssuePatchEdit]) -> String {
        let mut output = sql.to_string();
        let mut ordered = edits.iter().collect::<Vec<_>>();
        ordered.sort_by_key(|edit| edit.span.start);

        for edit in ordered.into_iter().rev() {
            output.replace_range(edit.span.start..edit.span.end, &edit.replacement);
        }

        output
    }

    // --- Pass cases from SQLFluff ST04 fixture ---

    // The nested CASE sits under a WHEN/THEN, not in ELSE.
    #[test]
    fn passes_nested_case_under_when_clause() {
        let sql = "SELECT CASE WHEN species = 'Rat' THEN CASE WHEN colour = 'Black' THEN 'Growl' WHEN colour = 'Grey' THEN 'Squeak' END END AS sound FROM mytable";
        let issues = run(sql);
        assert!(issues.is_empty());
    }

    // The ELSE branch is a larger expression that merely contains a CASE.
    #[test]
    fn passes_nested_case_inside_larger_else_expression() {
        let sql = "SELECT CASE WHEN flag = 1 THEN TRUE ELSE score > 10 + CASE WHEN kind = 'b' THEN 8 WHEN kind = 'c' THEN 9 END END AS test FROM t";
        let issues = run(sql);
        assert!(issues.is_empty());
    }

    // Searched outer CASE vs. simple inner CASE: operands differ.
    #[test]
    fn passes_when_outer_and_inner_case_operands_differ() {
        let sql = "SELECT CASE WHEN day_of_month IN (11, 12, 13) THEN 'TH' ELSE CASE MOD(day_of_month, 10) WHEN 1 THEN 'ST' WHEN 2 THEN 'ND' WHEN 3 THEN 'RD' ELSE 'TH' END END AS ordinal_suffix FROM calendar";
        let issues = run(sql);
        assert!(issues.is_empty());
    }

    // Both CASEs are simple CASEs, but with different operands.
    #[test]
    fn passes_different_case_expressions2() {
        let sql = "SELECT CASE DayOfMonth WHEN 11 THEN 'TH' WHEN 12 THEN 'TH' WHEN 13 THEN 'TH' ELSE CASE MOD(DayOfMonth, 10) WHEN 1 THEN 'ST' WHEN 2 THEN 'ND' WHEN 3 THEN 'RD' ELSE 'TH' END END AS OrdinalSuffix FROM Calendar";
        let issues = run(sql);
        assert!(issues.is_empty());
    }

    // --- Fail + detection cases ---

    #[test]
    fn flags_simple_flattenable_else_case() {
        let sql = "SELECT CASE WHEN species = 'Rat' THEN 'Squeak' ELSE CASE WHEN species = 'Dog' THEN 'Woof' END END AS sound FROM mytable";
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_ST_004);
    }

    #[test]
    fn flags_nested_else_case_with_multiple_when_clauses() {
        let sql = "SELECT CASE WHEN species = 'Rat' THEN 'Squeak' ELSE CASE WHEN species = 'Dog' THEN 'Woof' WHEN species = 'Mouse' THEN 'Squeak' END END AS sound FROM mytable";
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
    }

    // Simple CASEs with an identical operand (`x`) are flattenable.
    #[test]
    fn flags_when_outer_and_inner_simple_case_operands_match() {
        let sql = "SELECT CASE x WHEN 0 THEN 'zero' WHEN 5 THEN 'five' ELSE CASE x WHEN 10 THEN 'ten' WHEN 20 THEN 'twenty' ELSE 'other' END END FROM tab_a";
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
    }

    // --- Autofix tests matching SQLFluff fixture fix_str ---

    #[test]
    fn autofix_simple_flatten() {
        let sql = "\
SELECT
    c1,
    CASE
        WHEN species = 'Rat' THEN 'Squeak'
        ELSE
            CASE
                WHEN species = 'Dog' THEN 'Woof'
            END
    END AS sound
FROM mytable";
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
        let autofix = issues[0].autofix.as_ref().expect("expected autofix");
        let fixed = apply_edits(sql, &autofix.edits);

        let expected = "\
SELECT
    c1,
    CASE
        WHEN species = 'Rat' THEN 'Squeak'
        WHEN species = 'Dog' THEN 'Woof'
    END AS sound
FROM mytable";
        assert_eq!(fixed, expected);
    }

    #[test]
    fn autofix_flatten_multiple_whens() {
        let sql = "\
SELECT
    c1,
    CASE
        WHEN species = 'Rat' THEN 'Squeak'
        ELSE
            CASE
                WHEN species = 'Dog' THEN 'Woof'
                WHEN species = 'Mouse' THEN 'Squeak'
            END
    END AS sound
FROM mytable";
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
        let autofix = issues[0].autofix.as_ref().expect("expected autofix");
        let fixed = apply_edits(sql, &autofix.edits);

        let expected = "\
SELECT
    c1,
    CASE
        WHEN species = 'Rat' THEN 'Squeak'
        WHEN species = 'Dog' THEN 'Woof'
        WHEN species = 'Mouse' THEN 'Squeak'
    END AS sound
FROM mytable";
        assert_eq!(fixed, expected);
    }

    // The inner CASE's own ELSE is promoted to become the outer ELSE.
    #[test]
    fn autofix_flatten_with_else() {
        let sql = "\
SELECT
    c1,
    CASE
        WHEN species = 'Rat' THEN 'Squeak'
        ELSE
            CASE
                WHEN species = 'Dog' THEN 'Woof'
                WHEN species = 'Mouse' THEN 'Squeak'
                ELSE \"Whaa\"
            END
    END AS sound
FROM mytable";
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
        let autofix = issues[0].autofix.as_ref().expect("expected autofix");
        let fixed = apply_edits(sql, &autofix.edits);

        let expected = "\
SELECT
    c1,
    CASE
        WHEN species = 'Rat' THEN 'Squeak'
        WHEN species = 'Dog' THEN 'Woof'
        WHEN species = 'Mouse' THEN 'Squeak'
        ELSE \"Whaa\"
    END AS sound
FROM mytable";
        assert_eq!(fixed, expected);
    }

    // The inner operand (`x`) is dropped because the outer CASE already has it.
    #[test]
    fn autofix_flatten_same_simple_case_operand() {
        let sql = "\
SELECT
    CASE x
        WHEN 0 THEN 'zero'
        WHEN 5 THEN 'five'
        ELSE
            CASE x
                WHEN 10 THEN 'ten'
                WHEN 20 THEN 'twenty'
                ELSE 'other'
            END
    END
FROM tab_a;";
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
        let autofix = issues[0].autofix.as_ref().expect("expected autofix");
        let fixed = apply_edits(sql, &autofix.edits);

        let expected = "\
SELECT
    CASE x
        WHEN 0 THEN 'zero'
        WHEN 5 THEN 'five'
        WHEN 10 THEN 'ten'
        WHEN 20 THEN 'twenty'
        ELSE 'other'
    END
FROM tab_a;";
        assert_eq!(fixed, expected);
    }
}