rigsql_rules/ambiguous/
am02.rs1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// Lint rule `AM02` (`ambiguous.union`).
///
/// Reports every `UNION` keyword that is not immediately followed (ignoring
/// trivia) by `ALL` or `DISTINCT`. The type is a unit struct and carries no
/// configuration.
#[derive(Default, Debug)]
pub struct RuleAM02;
12
13impl Rule for RuleAM02 {
14 fn code(&self) -> &'static str {
15 "AM02"
16 }
17 fn name(&self) -> &'static str {
18 "ambiguous.union"
19 }
20 fn description(&self) -> &'static str {
21 "UNION without DISTINCT or ALL."
22 }
23 fn explanation(&self) -> &'static str {
24 "A bare UNION (without ALL or DISTINCT) implicitly deduplicates results, \
25 which is equivalent to UNION DISTINCT. This implicit behavior can be confusing. \
26 Use UNION ALL when you want all rows, or UNION DISTINCT to make the dedup explicit."
27 }
28 fn groups(&self) -> &[RuleGroup] {
29 &[RuleGroup::Ambiguous]
30 }
31 fn is_fixable(&self) -> bool {
32 false
33 }
34
35 fn crawl_type(&self) -> CrawlType {
36 CrawlType::RootOnly
37 }
38
39 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
40 let mut violations = Vec::new();
41 find_bare_unions(ctx.root, &mut violations);
42 violations
43 }
44}
45
46fn find_bare_unions(segment: &Segment, violations: &mut Vec<LintViolation>) {
47 let children = segment.children();
48
49 for (i, child) in children.iter().enumerate() {
50 if let Segment::Token(t) = child {
51 if t.segment_type == SegmentType::Keyword && t.token.text.eq_ignore_ascii_case("UNION")
52 {
53 let next = children[i + 1..]
55 .iter()
56 .find(|s| !s.segment_type().is_trivia());
57
58 let has_qualifier = next.is_some_and(|s| {
59 if let Segment::Token(nt) = s {
60 nt.token.text.eq_ignore_ascii_case("ALL")
61 || nt.token.text.eq_ignore_ascii_case("DISTINCT")
62 } else {
63 false
64 }
65 });
66
67 if !has_qualifier {
68 violations.push(LintViolation::with_msg_key(
69 "AM02",
70 "UNION without explicit DISTINCT or ALL.",
71 t.token.span,
72 "rules.AM02.msg",
73 vec![],
74 ));
75 }
76 }
77 }
78
79 find_bare_unions(child, violations);
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// A `UNION` with no qualifier produces exactly one violation whose
    /// message mentions the keyword.
    #[test]
    fn test_am02_flags_bare_union() {
        let sql = "SELECT a FROM t UNION SELECT b FROM u";
        let reported = lint_sql(sql, RuleAM02);
        assert_eq!(reported.len(), 1);
        assert!(reported[0].message.contains("UNION"));
    }

    /// `UNION ALL` is explicit and must not be reported.
    #[test]
    fn test_am02_accepts_union_all() {
        let sql = "SELECT a FROM t UNION ALL SELECT b FROM u";
        let reported = lint_sql(sql, RuleAM02);
        assert_eq!(reported.len(), 0);
    }

    /// `UNION DISTINCT` is explicit and must not be reported.
    #[test]
    fn test_am02_accepts_union_distinct() {
        let sql = "SELECT a FROM t UNION DISTINCT SELECT b FROM u";
        let reported = lint_sql(sql, RuleAM02);
        assert_eq!(reported.len(), 0);
    }
}