flowscope_core/linter/rules/
am_002.rs1use crate::linter::rule::{LintContext, LintRule};
6use crate::types::{issue_codes, Dialect, Issue, IssueAutofixApplicability, IssuePatchEdit};
7use sqlparser::ast::*;
8use sqlparser::keywords::Keyword;
9use sqlparser::tokenizer::{Location, Span, Token, TokenWithSpan, Tokenizer};
10
11pub struct BareUnion;
12
13impl LintRule for BareUnion {
14 fn code(&self) -> &'static str {
15 issue_codes::LINT_AM_002
16 }
17
18 fn name(&self) -> &'static str {
19 "Ambiguous UNION quantifier"
20 }
21
22 fn description(&self) -> &'static str {
23 "'UNION [DISTINCT|ALL]' is preferred over just 'UNION'."
24 }
25
26 fn check(&self, stmt: &Statement, ctx: &LintContext) -> Vec<Issue> {
27 let mut issues = Vec::new();
28 let mut unions = union_keyword_ranges_for_context(ctx);
29 match stmt {
30 Statement::Query(query) => check_query(query, &mut unions, ctx, &mut issues),
31 Statement::Insert(insert) => {
32 if let Some(ref source) = insert.source {
33 check_query(source, &mut unions, ctx, &mut issues);
34 }
35 }
36 Statement::CreateView { query, .. } => {
37 check_query(query, &mut unions, ctx, &mut issues)
38 }
39 Statement::CreateTable(create) => {
40 if let Some(ref query) = create.query {
41 check_query(query, &mut unions, ctx, &mut issues);
42 }
43 }
44 _ => {}
45 }
46 issues
47 }
48}
49
50fn union_keyword_ranges_for_context(ctx: &LintContext) -> UnionKeywordRanges {
51 let tokens = tokenized_for_context(ctx);
52 union_keyword_ranges(ctx.statement_sql(), ctx.dialect(), tokens.as_deref())
53}
54
55fn check_query(
56 query: &Query,
57 unions: &mut UnionKeywordRanges,
58 ctx: &LintContext,
59 issues: &mut Vec<Issue>,
60) {
61 if let Some(ref with) = query.with {
62 for cte in &with.cte_tables {
63 check_query(&cte.query, unions, ctx, issues);
64 }
65 }
66 check_query_body(&query.body, unions, ctx, issues);
67}
68
69fn check_query_body(
70 body: &SetExpr,
71 unions: &mut UnionKeywordRanges,
72 ctx: &LintContext,
73 issues: &mut Vec<Issue>,
74) {
75 match body {
76 SetExpr::SetOperation {
77 op: SetOperator::Union,
78 set_quantifier,
79 left,
80 right,
81 } => {
82 check_query_body(left, unions, ctx, issues);
83 let union_span = unions.next();
84
85 if matches!(set_quantifier, SetQuantifier::None | SetQuantifier::ByName)
86 && !matches!(ctx.dialect(), Dialect::Postgres)
89 {
90 let mut issue = Issue::warning(
91 issue_codes::LINT_AM_002,
92 "Use UNION DISTINCT or UNION ALL instead of bare UNION.",
93 )
94 .with_statement(ctx.statement_index);
95 if let Some((start, end)) = union_span {
96 let span = ctx.span_from_statement_offset(start, end);
97 let union_keyword = &ctx.statement_sql()[start..end];
98 let distinct = if union_keyword == union_keyword.to_ascii_lowercase() {
99 "distinct"
100 } else {
101 "DISTINCT"
102 };
103 issue = issue.with_span(span).with_autofix_edits(
104 IssueAutofixApplicability::Safe,
105 vec![IssuePatchEdit::new(
106 span,
107 format!("{union_keyword} {distinct}"),
108 )],
109 );
110 }
111 issues.push(issue);
112 }
113 check_query_body(right, unions, ctx, issues);
114 }
115 SetExpr::SetOperation { left, right, .. } => {
116 check_query_body(left, unions, ctx, issues);
117 check_query_body(right, unions, ctx, issues);
118 }
119 SetExpr::Select(_) => {}
120 SetExpr::Query(q) => {
121 check_query(q, unions, ctx, issues);
122 }
123 _ => {}
124 }
125}
126
/// Cursor over the `(start, end)` byte ranges of `UNION` keywords, handed
/// out one at a time in the order they were collected.
struct UnionKeywordRanges {
    /// Keyword ranges in collection order.
    ranges: Vec<(usize, usize)>,
    /// Position of the next range to hand out.
    index: usize,
}

impl UnionKeywordRanges {
    /// Returns the next unread range and moves past it, or `None` (leaving
    /// the cursor untouched) once all ranges have been handed out.
    fn next(&mut self) -> Option<(usize, usize)> {
        let current = *self.ranges.get(self.index)?;
        self.index += 1;
        Some(current)
    }
}
141
142fn union_keyword_ranges(
143 sql: &str,
144 dialect: Dialect,
145 tokens: Option<&[TokenWithSpan]>,
146) -> UnionKeywordRanges {
147 let owned_tokens;
148 let tokens = if let Some(tokens) = tokens {
149 tokens
150 } else {
151 owned_tokens = match tokenized(sql, dialect) {
152 Some(tokens) => tokens,
153 None => {
154 return UnionKeywordRanges {
155 ranges: Vec::new(),
156 index: 0,
157 };
158 }
159 };
160 &owned_tokens
161 };
162
163 let ranges = tokens
164 .iter()
165 .filter_map(|token| {
166 let Token::Word(word) = &token.token else {
167 return None;
168 };
169 if word.keyword != Keyword::UNION {
170 return None;
171 }
172
173 token_offsets(sql, token)
174 })
175 .collect();
176
177 UnionKeywordRanges { ranges, index: 0 }
178}
179
180fn tokenized(sql: &str, dialect: Dialect) -> Option<Vec<TokenWithSpan>> {
181 let dialect = dialect.to_sqlparser_dialect();
182 let mut tokenizer = Tokenizer::new(dialect.as_ref(), sql);
183 tokenizer.tokenize_with_location().ok()
184}
185
186fn tokenized_for_context(ctx: &LintContext) -> Option<Vec<TokenWithSpan>> {
187 let (statement_start_line, statement_start_column) =
188 offset_to_line_col(ctx.sql, ctx.statement_range.start)?;
189
190 ctx.with_document_tokens(|tokens| {
191 if tokens.is_empty() {
192 return None;
193 }
194
195 let mut out = Vec::new();
196 for token in tokens {
197 let Some((start, end)) = token_offsets(ctx.sql, token) else {
198 continue;
199 };
200 if start < ctx.statement_range.start || end > ctx.statement_range.end {
201 continue;
202 }
203
204 let Some(start_loc) = relative_location(
205 token.span.start,
206 statement_start_line,
207 statement_start_column,
208 ) else {
209 continue;
210 };
211 let Some(end_loc) =
212 relative_location(token.span.end, statement_start_line, statement_start_column)
213 else {
214 continue;
215 };
216
217 out.push(TokenWithSpan::new(
218 token.token.clone(),
219 Span::new(start_loc, end_loc),
220 ));
221 }
222
223 if out.is_empty() {
224 None
225 } else {
226 Some(out)
227 }
228 })
229}
230
231fn token_offsets(sql: &str, token: &TokenWithSpan) -> Option<(usize, usize)> {
232 let start = line_col_to_offset(
233 sql,
234 token.span.start.line as usize,
235 token.span.start.column as usize,
236 )?;
237 let end = line_col_to_offset(
238 sql,
239 token.span.end.line as usize,
240 token.span.end.column as usize,
241 )?;
242 Some((start, end))
243}
244
/// Converts a 1-based `(line, column)` position into a byte offset in `sql`.
///
/// Columns are counted in characters, and a `'\n'` starts a new line at
/// column 1. The position just past the last character maps to `sql.len()`.
/// Returns `None` for a zero line or column, or for a position that does
/// not exist in `sql`.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    let target = (line, column);
    // Position of the character about to be visited.
    let mut cursor = (1usize, 1usize);

    for (byte_index, ch) in sql.char_indices() {
        if cursor == target {
            return Some(byte_index);
        }
        cursor = match ch {
            '\n' => (cursor.0 + 1, 1),
            _ => (cursor.0, cursor.1 + 1),
        };
    }

    // The target may be the end-of-input position.
    (cursor == target).then_some(sql.len())
}
272
/// Converts a byte offset in `sql` into a 1-based `(line, column)` position.
///
/// Columns are counted in characters, and a `'\n'` starts a new line at
/// column 1. `offset == sql.len()` yields the position just past the last
/// character. Returns `None` when `offset` is beyond the end of `sql` or
/// falls inside a multi-byte character.
fn offset_to_line_col(sql: &str, offset: usize) -> Option<(usize, usize)> {
    // `is_char_boundary` is false for offsets past the end as well.
    if !sql.is_char_boundary(offset) {
        return None;
    }

    let position = sql[..offset]
        .chars()
        .fold((1usize, 1usize), |(line, column), ch| {
            if ch == '\n' {
                (line + 1, 1)
            } else {
                (line, column + 1)
            }
        });

    Some(position)
}
307
308fn relative_location(
309 location: Location,
310 statement_start_line: usize,
311 statement_start_column: usize,
312) -> Option<Location> {
313 let line = location.line as usize;
314 let column = location.column as usize;
315 if line < statement_start_line {
316 return None;
317 }
318
319 if line == statement_start_line {
320 if column < statement_start_column {
321 return None;
322 }
323 return Some(Location::new(
324 1,
325 (column - statement_start_column + 1) as u64,
326 ));
327 }
328
329 Some(Location::new(
330 (line - statement_start_line + 1) as u64,
331 column as u64,
332 ))
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linter::rule::with_active_dialect;
    use crate::parser::{parse_sql, parse_sql_with_dialect};
    use crate::types::IssueAutofixApplicability;

    /// Parses `sql` with the default parser and runs the rule over every
    /// statement, treating the whole input as a single statement range.
    fn check_sql(sql: &str) -> Vec<Issue> {
        let stmts = parse_sql(sql).unwrap();
        let rule = BareUnion;
        let ctx = LintContext {
            sql,
            statement_range: 0..sql.len(),
            statement_index: 0,
        };
        let mut issues = Vec::new();
        for stmt in &stmts {
            issues.extend(rule.check(stmt, &ctx));
        }
        issues
    }

    /// Same as `check_sql`, but parses and lints with `dialect` active.
    fn check_sql_in_dialect(sql: &str, dialect: Dialect) -> Vec<Issue> {
        let stmts = parse_sql_with_dialect(sql, dialect).unwrap();
        let rule = BareUnion;
        let mut issues = Vec::new();
        with_active_dialect(dialect, || {
            for stmt in &stmts {
                issues.extend(rule.check(
                    stmt,
                    &LintContext {
                        sql,
                        statement_range: 0..sql.len(),
                        statement_index: 0,
                    },
                ));
            }
        });
        issues
    }

    /// Applies the issue's autofix edits to `sql`, last edit first so that
    /// earlier spans stay valid. Returns `None` if the issue has no autofix.
    fn apply_issue_autofix(sql: &str, issue: &Issue) -> Option<String> {
        let autofix = issue.autofix.as_ref()?;
        let mut edits = autofix.edits.clone();
        edits.sort_by(|left, right| right.span.start.cmp(&left.span.start));

        let mut out = sql.to_string();
        for edit in edits {
            out.replace_range(edit.span.start..edit.span.end, &edit.replacement);
        }
        Some(out)
    }

    #[test]
    fn test_bare_union_detected() {
        let issues = check_sql("SELECT 1 UNION SELECT 2");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, "LINT_AM_002");
    }

    #[test]
    fn test_union_all_ok() {
        let issues = check_sql("SELECT 1 UNION ALL SELECT 2");
        assert!(issues.is_empty());
    }

    #[test]
    fn test_multiple_bare_unions() {
        let issues = check_sql("SELECT 1 UNION SELECT 2 UNION SELECT 3");
        assert_eq!(issues.len(), 2);
    }

    #[test]
    fn test_mixed_union() {
        let issues = check_sql("SELECT 1 UNION ALL SELECT 2 UNION SELECT 3");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn test_union_distinct_ok() {
        let issues = check_sql("SELECT a, b FROM t1 UNION DISTINCT SELECT a, b FROM t2");
        assert!(issues.is_empty());
    }

    #[test]
    fn test_bare_union_in_insert() {
        let issues = check_sql("INSERT INTO target SELECT 1 UNION SELECT 2");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn test_bare_union_in_create_view() {
        let issues = check_sql("CREATE VIEW v AS SELECT 1 UNION SELECT 2");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn test_bare_union_in_cte() {
        let issues = check_sql("WITH cte AS (SELECT 1 UNION SELECT 2) SELECT * FROM cte");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn test_union_all_in_cte_ok() {
        let issues = check_sql("WITH cte AS (SELECT 1 UNION ALL SELECT 2) SELECT * FROM cte");
        assert!(issues.is_empty());
    }

    // Three chained UNIONs; the two-UNION case is covered by
    // `test_multiple_bare_unions`.
    #[test]
    fn test_triple_bare_union() {
        let issues = check_sql("SELECT 1 UNION SELECT 2 UNION SELECT 3 UNION SELECT 4");
        assert_eq!(issues.len(), 3);
    }

    #[test]
    fn test_multiple_bare_unions_have_distinct_spans() {
        let issues = check_sql("SELECT 1 UNION SELECT 2 UNION SELECT 3");
        assert_eq!(issues.len(), 2);
        let first_span = issues[0].span.expect("first UNION should have span");
        let second_span = issues[1].span.expect("second UNION should have span");
        assert!(first_span.start < second_span.start);
    }

    #[test]
    fn test_except_and_intersect_ok() {
        let issues = check_sql("SELECT 1 EXCEPT SELECT 2");
        assert!(issues.is_empty());
        let issues = check_sql("SELECT 1 INTERSECT SELECT 2");
        assert!(issues.is_empty());
    }

    #[test]
    fn test_union_identifier_with_underscore_does_not_steal_span() {
        let sql = "SELECT union_col FROM t UNION SELECT 2";
        let issues = check_sql(sql);
        assert_eq!(issues.len(), 1);
        let span = issues[0].span.expect("UNION issue should include a span");
        let union_pos = sql.find("UNION").expect("query should contain UNION");
        assert_eq!(span.start, union_pos);
    }

    #[test]
    fn test_union_with_comments_keeps_keyword_span() {
        let sql = "WITH cte AS (SELECT 1 /* left */ UNION /* right */ SELECT 2) SELECT * FROM cte";
        let issues = check_sql(sql);
        assert_eq!(issues.len(), 1);
        let span = issues[0].span.expect("UNION issue should include a span");
        let union_pos = sql.find("UNION").expect("query should contain UNION");
        assert_eq!(span.start, union_pos);
    }

    #[test]
    fn postgres_bare_union_is_allowed() {
        let issues = check_sql_in_dialect(
            "select a, b from tbl1 union select c, d from tbl2",
            Dialect::Postgres,
        );
        assert!(issues.is_empty());
    }

    #[test]
    fn test_bare_union_emits_safe_autofix_patch() {
        let sql = "SELECT 1 UNION SELECT 2";
        let issues = check_sql(sql);
        assert_eq!(issues.len(), 1);

        let autofix = issues[0].autofix.as_ref().expect("autofix metadata");
        assert_eq!(autofix.applicability, IssueAutofixApplicability::Safe);
        assert_eq!(autofix.edits.len(), 1);

        let fixed = apply_issue_autofix(sql, &issues[0]).expect("apply autofix");
        assert_eq!(fixed, "SELECT 1 UNION DISTINCT SELECT 2");
    }

    // The inserted quantifier follows the casing of the keyword it extends.
    #[test]
    fn test_lowercase_union_autofix_keeps_lowercase() {
        let sql = "select 1 union select 2";
        let issues = check_sql(sql);
        assert_eq!(issues.len(), 1);

        let fixed = apply_issue_autofix(sql, &issues[0]).expect("apply autofix");
        assert_eq!(fixed, "select 1 union distinct select 2");
    }
}