Skip to main content

flowscope_core/linter/rules/
am_002.rs

1//! LINT_AM_002: Bare UNION quantifier.
2//!
3//! `UNION` should be explicit (`UNION DISTINCT` or `UNION ALL`) to avoid ambiguous implicit behavior.
4
5use crate::linter::rule::{LintContext, LintRule};
6use crate::types::{issue_codes, Dialect, Issue, IssueAutofixApplicability, IssuePatchEdit};
7use sqlparser::ast::*;
8use sqlparser::keywords::Keyword;
9use sqlparser::tokenizer::{Location, Span, Token, TokenWithSpan, Tokenizer};
10
11pub struct BareUnion;
12
13impl LintRule for BareUnion {
14    fn code(&self) -> &'static str {
15        issue_codes::LINT_AM_002
16    }
17
18    fn name(&self) -> &'static str {
19        "Ambiguous UNION quantifier"
20    }
21
22    fn description(&self) -> &'static str {
23        "'UNION [DISTINCT|ALL]' is preferred over just 'UNION'."
24    }
25
26    fn check(&self, stmt: &Statement, ctx: &LintContext) -> Vec<Issue> {
27        let mut issues = Vec::new();
28        let mut unions = union_keyword_ranges_for_context(ctx);
29        match stmt {
30            Statement::Query(query) => check_query(query, &mut unions, ctx, &mut issues),
31            Statement::Insert(insert) => {
32                if let Some(ref source) = insert.source {
33                    check_query(source, &mut unions, ctx, &mut issues);
34                }
35            }
36            Statement::CreateView { query, .. } => {
37                check_query(query, &mut unions, ctx, &mut issues)
38            }
39            Statement::CreateTable(create) => {
40                if let Some(ref query) = create.query {
41                    check_query(query, &mut unions, ctx, &mut issues);
42                }
43            }
44            _ => {}
45        }
46        issues
47    }
48}
49
50fn union_keyword_ranges_for_context(ctx: &LintContext) -> UnionKeywordRanges {
51    let tokens = tokenized_for_context(ctx);
52    union_keyword_ranges(ctx.statement_sql(), ctx.dialect(), tokens.as_deref())
53}
54
55fn check_query(
56    query: &Query,
57    unions: &mut UnionKeywordRanges,
58    ctx: &LintContext,
59    issues: &mut Vec<Issue>,
60) {
61    if let Some(ref with) = query.with {
62        for cte in &with.cte_tables {
63            check_query(&cte.query, unions, ctx, issues);
64        }
65    }
66    check_query_body(&query.body, unions, ctx, issues);
67}
68
69fn check_query_body(
70    body: &SetExpr,
71    unions: &mut UnionKeywordRanges,
72    ctx: &LintContext,
73    issues: &mut Vec<Issue>,
74) {
75    match body {
76        SetExpr::SetOperation {
77            op: SetOperator::Union,
78            set_quantifier,
79            left,
80            right,
81        } => {
82            check_query_body(left, unions, ctx, issues);
83            let union_span = unions.next();
84
85            if matches!(set_quantifier, SetQuantifier::None | SetQuantifier::ByName)
86                // PostgreSQL treats bare UNION as UNION DISTINCT unambiguously,
87                // so flagging it would be noise.
88                && !matches!(ctx.dialect(), Dialect::Postgres)
89            {
90                let mut issue = Issue::warning(
91                    issue_codes::LINT_AM_002,
92                    "Use UNION DISTINCT or UNION ALL instead of bare UNION.",
93                )
94                .with_statement(ctx.statement_index);
95                if let Some((start, end)) = union_span {
96                    let span = ctx.span_from_statement_offset(start, end);
97                    let union_keyword = &ctx.statement_sql()[start..end];
98                    let distinct = if union_keyword == union_keyword.to_ascii_lowercase() {
99                        "distinct"
100                    } else {
101                        "DISTINCT"
102                    };
103                    issue = issue.with_span(span).with_autofix_edits(
104                        IssueAutofixApplicability::Safe,
105                        vec![IssuePatchEdit::new(
106                            span,
107                            format!("{union_keyword} {distinct}"),
108                        )],
109                    );
110                }
111                issues.push(issue);
112            }
113            check_query_body(right, unions, ctx, issues);
114        }
115        SetExpr::SetOperation { left, right, .. } => {
116            check_query_body(left, unions, ctx, issues);
117            check_query_body(right, unions, ctx, issues);
118        }
119        SetExpr::Select(_) => {}
120        SetExpr::Query(q) => {
121            check_query(q, unions, ctx, issues);
122        }
123        _ => {}
124    }
125}
126
/// Forward-only cursor over the byte ranges of `UNION` keywords.
struct UnionKeywordRanges {
    /// `(start, end)` byte offsets of each `UNION` keyword, in token order.
    ranges: Vec<(usize, usize)>,
    /// Position of the next range to hand out.
    index: usize,
}

impl UnionKeywordRanges {
    /// Returns the next unconsumed range and advances the cursor.
    ///
    /// Once the ranges are exhausted this keeps returning `None` and the
    /// cursor no longer moves.
    fn next(&mut self) -> Option<(usize, usize)> {
        let &range = self.ranges.get(self.index)?;
        self.index += 1;
        Some(range)
    }
}
141
142fn union_keyword_ranges(
143    sql: &str,
144    dialect: Dialect,
145    tokens: Option<&[TokenWithSpan]>,
146) -> UnionKeywordRanges {
147    let owned_tokens;
148    let tokens = if let Some(tokens) = tokens {
149        tokens
150    } else {
151        owned_tokens = match tokenized(sql, dialect) {
152            Some(tokens) => tokens,
153            None => {
154                return UnionKeywordRanges {
155                    ranges: Vec::new(),
156                    index: 0,
157                };
158            }
159        };
160        &owned_tokens
161    };
162
163    let ranges = tokens
164        .iter()
165        .filter_map(|token| {
166            let Token::Word(word) = &token.token else {
167                return None;
168            };
169            if word.keyword != Keyword::UNION {
170                return None;
171            }
172
173            token_offsets(sql, token)
174        })
175        .collect();
176
177    UnionKeywordRanges { ranges, index: 0 }
178}
179
180fn tokenized(sql: &str, dialect: Dialect) -> Option<Vec<TokenWithSpan>> {
181    let dialect = dialect.to_sqlparser_dialect();
182    let mut tokenizer = Tokenizer::new(dialect.as_ref(), sql);
183    tokenizer.tokenize_with_location().ok()
184}
185
186fn tokenized_for_context(ctx: &LintContext) -> Option<Vec<TokenWithSpan>> {
187    let (statement_start_line, statement_start_column) =
188        offset_to_line_col(ctx.sql, ctx.statement_range.start)?;
189
190    ctx.with_document_tokens(|tokens| {
191        if tokens.is_empty() {
192            return None;
193        }
194
195        let mut out = Vec::new();
196        for token in tokens {
197            let Some((start, end)) = token_offsets(ctx.sql, token) else {
198                continue;
199            };
200            if start < ctx.statement_range.start || end > ctx.statement_range.end {
201                continue;
202            }
203
204            let Some(start_loc) = relative_location(
205                token.span.start,
206                statement_start_line,
207                statement_start_column,
208            ) else {
209                continue;
210            };
211            let Some(end_loc) =
212                relative_location(token.span.end, statement_start_line, statement_start_column)
213            else {
214                continue;
215            };
216
217            out.push(TokenWithSpan::new(
218                token.token.clone(),
219                Span::new(start_loc, end_loc),
220            ));
221        }
222
223        if out.is_empty() {
224            None
225        } else {
226            Some(out)
227        }
228    })
229}
230
231fn token_offsets(sql: &str, token: &TokenWithSpan) -> Option<(usize, usize)> {
232    let start = line_col_to_offset(
233        sql,
234        token.span.start.line as usize,
235        token.span.start.column as usize,
236    )?;
237    let end = line_col_to_offset(
238        sql,
239        token.span.end.line as usize,
240        token.span.end.column as usize,
241    )?;
242    Some((start, end))
243}
244
/// Maps a 1-based `(line, column)` position to a byte offset in `sql`.
///
/// Columns count characters, not bytes, and only `'\n'` starts a new line.
/// The position one past the final character maps to `sql.len()`. Returns
/// `None` for a zero line or column and for positions that do not exist.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    let target = (line, column);
    // (line, column) of the character about to be visited.
    let mut cursor = (1usize, 1usize);

    for (offset, ch) in sql.char_indices() {
        if cursor == target {
            return Some(offset);
        }
        cursor = match ch {
            '\n' => (cursor.0 + 1, 1),
            _ => (cursor.0, cursor.1 + 1),
        };
    }

    // The end-of-input position is addressable too.
    (cursor == target).then_some(sql.len())
}
272
/// Maps a byte offset in `sql` to its 1-based `(line, column)` position.
///
/// Columns count characters and only `'\n'` starts a new line. The offset
/// `sql.len()` maps to the position just past the final character. Returns
/// `None` when the offset is past the end or falls inside a multi-byte
/// character.
fn offset_to_line_col(sql: &str, offset: usize) -> Option<(usize, usize)> {
    if offset > sql.len() {
        return None;
    }

    // (line, column) of the character about to be visited.
    let mut position = (1usize, 1usize);
    for (index, ch) in sql.char_indices() {
        if index == offset {
            return Some(position);
        }
        position = if ch == '\n' {
            (position.0 + 1, 1)
        } else {
            (position.0, position.1 + 1)
        };
    }

    // Only the end-of-input offset can still match after the scan.
    (offset == sql.len()).then_some(position)
}
307
308fn relative_location(
309    location: Location,
310    statement_start_line: usize,
311    statement_start_column: usize,
312) -> Option<Location> {
313    let line = location.line as usize;
314    let column = location.column as usize;
315    if line < statement_start_line {
316        return None;
317    }
318
319    if line == statement_start_line {
320        if column < statement_start_column {
321            return None;
322        }
323        return Some(Location::new(
324            1,
325            (column - statement_start_column + 1) as u64,
326        ));
327    }
328
329    Some(Location::new(
330        (line - statement_start_line + 1) as u64,
331        column as u64,
332    ))
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linter::rule::with_active_dialect;
    use crate::parser::{parse_sql, parse_sql_with_dialect};
    use crate::types::IssueAutofixApplicability;

    /// Parses `sql` with the default parser and runs the rule on every
    /// statement, treating the whole input as one statement range.
    fn check_sql(sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).unwrap();
        let ctx = LintContext {
            sql,
            statement_range: 0..sql.len(),
            statement_index: 0,
        };
        statements
            .iter()
            .flat_map(|stmt| BareUnion.check(stmt, &ctx))
            .collect()
    }

    /// Same as `check_sql`, but parses and lints with `dialect` active.
    fn check_sql_in_dialect(sql: &str, dialect: Dialect) -> Vec<Issue> {
        let statements = parse_sql_with_dialect(sql, dialect).unwrap();
        let mut found = Vec::new();
        with_active_dialect(dialect, || {
            let ctx = LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            };
            found.extend(
                statements
                    .iter()
                    .flat_map(|stmt| BareUnion.check(stmt, &ctx)),
            );
        });
        found
    }

    /// Applies the issue's autofix edits to `sql`, last edit first so earlier
    /// offsets stay valid. Returns `None` when the issue carries no autofix.
    fn apply_issue_autofix(sql: &str, issue: &Issue) -> Option<String> {
        let mut edits = issue.autofix.as_ref()?.edits.clone();
        edits.sort_by_key(|edit| std::cmp::Reverse(edit.span.start));

        let mut patched = sql.to_string();
        for edit in edits {
            patched.replace_range(edit.span.start..edit.span.end, &edit.replacement);
        }
        Some(patched)
    }

    #[test]
    fn test_bare_union_detected() {
        let found = check_sql("SELECT 1 UNION SELECT 2");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].code, "LINT_AM_002");
    }

    #[test]
    fn test_union_all_ok() {
        let found = check_sql("SELECT 1 UNION ALL SELECT 2");
        assert!(found.is_empty());
    }

    #[test]
    fn test_multiple_bare_unions() {
        let found = check_sql("SELECT 1 UNION SELECT 2 UNION SELECT 3");
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_mixed_union() {
        let found = check_sql("SELECT 1 UNION ALL SELECT 2 UNION SELECT 3");
        assert_eq!(found.len(), 1);
    }

    // --- Edge cases adopted from sqlfluff AM02 ---

    #[test]
    fn test_union_distinct_ok() {
        let found = check_sql("SELECT a, b FROM t1 UNION DISTINCT SELECT a, b FROM t2");
        assert!(found.is_empty());
    }

    #[test]
    fn test_bare_union_in_insert() {
        let found = check_sql("INSERT INTO target SELECT 1 UNION SELECT 2");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_bare_union_in_create_view() {
        let found = check_sql("CREATE VIEW v AS SELECT 1 UNION SELECT 2");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_bare_union_in_cte() {
        let found = check_sql("WITH cte AS (SELECT 1 UNION SELECT 2) SELECT * FROM cte");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_union_all_in_cte_ok() {
        let found = check_sql("WITH cte AS (SELECT 1 UNION ALL SELECT 2) SELECT * FROM cte");
        assert!(found.is_empty());
    }

    #[test]
    fn test_triple_bare_union() {
        let found = check_sql("SELECT 1 UNION SELECT 2 UNION SELECT 3");
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_multiple_bare_unions_have_distinct_spans() {
        let found = check_sql("SELECT 1 UNION SELECT 2 UNION SELECT 3");
        assert_eq!(found.len(), 2);
        let first_span = found[0].span.expect("first UNION should have span");
        let second_span = found[1].span.expect("second UNION should have span");
        assert!(first_span.start < second_span.start);
    }

    #[test]
    fn test_except_and_intersect_ok() {
        let except_issues = check_sql("SELECT 1 EXCEPT SELECT 2");
        assert!(except_issues.is_empty());
        let intersect_issues = check_sql("SELECT 1 INTERSECT SELECT 2");
        assert!(intersect_issues.is_empty());
    }

    #[test]
    fn test_union_identifier_with_underscore_does_not_steal_span() {
        let sql = "SELECT union_col FROM t UNION SELECT 2";
        let found = check_sql(sql);
        assert_eq!(found.len(), 1);
        let span = found[0].span.expect("UNION issue should include a span");
        let union_pos = sql.find("UNION").expect("query should contain UNION");
        assert_eq!(span.start, union_pos);
    }

    #[test]
    fn test_union_with_comments_keeps_keyword_span() {
        let sql = "WITH cte AS (SELECT 1 /* left */ UNION /* right */ SELECT 2) SELECT * FROM cte";
        let found = check_sql(sql);
        assert_eq!(found.len(), 1);
        let span = found[0].span.expect("UNION issue should include a span");
        let union_pos = sql.find("UNION").expect("query should contain UNION");
        assert_eq!(span.start, union_pos);
    }

    #[test]
    fn postgres_bare_union_is_allowed() {
        // SQLFluff: test_postgres — PostgreSQL treats bare UNION as UNION DISTINCT.
        let found = check_sql_in_dialect(
            "select a, b from tbl1 union select c, d from tbl2",
            Dialect::Postgres,
        );
        assert!(found.is_empty());
    }

    #[test]
    fn test_bare_union_emits_safe_autofix_patch() {
        let sql = "SELECT 1 UNION SELECT 2";
        let found = check_sql(sql);
        assert_eq!(found.len(), 1);

        let autofix = found[0].autofix.as_ref().expect("autofix metadata");
        assert_eq!(autofix.applicability, IssueAutofixApplicability::Safe);
        assert_eq!(autofix.edits.len(), 1);

        let fixed = apply_issue_autofix(sql, &found[0]).expect("apply autofix");
        assert_eq!(fixed, "SELECT 1 UNION DISTINCT SELECT 2");
    }
}