Skip to main content

rigsql_rules/ambiguous/
am02.rs

1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// Lint rule AM02 (`ambiguous.union`).
///
/// Reports set operations written as a bare `UNION`. With no qualifier the
/// operator deduplicates rows exactly as `UNION DISTINCT` would, so the rule
/// asks authors to write `UNION ALL` or `UNION DISTINCT` explicitly instead of
/// relying on that implicit behaviour.
#[derive(Default, Debug)]
pub struct RuleAM02;
12
13impl Rule for RuleAM02 {
14    fn code(&self) -> &'static str {
15        "AM02"
16    }
17    fn name(&self) -> &'static str {
18        "ambiguous.union"
19    }
20    fn description(&self) -> &'static str {
21        "UNION without DISTINCT or ALL."
22    }
23    fn explanation(&self) -> &'static str {
24        "A bare UNION (without ALL or DISTINCT) implicitly deduplicates results, \
25         which is equivalent to UNION DISTINCT. This implicit behavior can be confusing. \
26         Use UNION ALL when you want all rows, or UNION DISTINCT to make the dedup explicit."
27    }
28    fn groups(&self) -> &[RuleGroup] {
29        &[RuleGroup::Ambiguous]
30    }
31    fn is_fixable(&self) -> bool {
32        false
33    }
34
35    fn crawl_type(&self) -> CrawlType {
36        CrawlType::RootOnly
37    }
38
39    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
40        let mut violations = Vec::new();
41        find_bare_unions(ctx.root, &mut violations);
42        violations
43    }
44}
45
46fn find_bare_unions(segment: &Segment, violations: &mut Vec<LintViolation>) {
47    let children = segment.children();
48
49    for (i, child) in children.iter().enumerate() {
50        if let Segment::Token(t) = child {
51            if t.segment_type == SegmentType::Keyword && t.token.text.eq_ignore_ascii_case("UNION")
52            {
53                // Check if the next non-trivia sibling is ALL or DISTINCT
54                let next = children[i + 1..]
55                    .iter()
56                    .find(|s| !s.segment_type().is_trivia());
57
58                let has_qualifier = next.is_some_and(|s| {
59                    if let Segment::Token(nt) = s {
60                        nt.token.text.eq_ignore_ascii_case("ALL")
61                            || nt.token.text.eq_ignore_ascii_case("DISTINCT")
62                    } else {
63                        false
64                    }
65                });
66
67                if !has_qualifier {
68                    violations.push(LintViolation::with_msg_key(
69                        "AM02",
70                        "UNION without explicit DISTINCT or ALL.",
71                        t.token.span,
72                        "rules.AM02.msg",
73                        vec![],
74                    ));
75                }
76            }
77        }
78
79        // Recurse into children
80        find_bare_unions(child, violations);
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    #[test]
    fn test_am02_flags_bare_union() {
        let violations = lint_sql("SELECT a FROM t UNION SELECT b FROM u", RuleAM02);
        assert_eq!(violations.len(), 1);
        assert!(violations[0].message.contains("UNION"));
    }

    #[test]
    fn test_am02_accepts_union_all() {
        let violations = lint_sql("SELECT a FROM t UNION ALL SELECT b FROM u", RuleAM02);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_am02_accepts_union_distinct() {
        let violations = lint_sql("SELECT a FROM t UNION DISTINCT SELECT b FROM u", RuleAM02);
        assert_eq!(violations.len(), 0);
    }

    /// Keyword matching uses `eq_ignore_ascii_case`, so lowercase SQL must be
    /// handled the same way as uppercase.
    #[test]
    fn test_am02_is_case_insensitive() {
        let violations = lint_sql("select a from t union select b from u", RuleAM02);
        assert_eq!(violations.len(), 1);

        let violations = lint_sql("select a from t union all select b from u", RuleAM02);
        assert_eq!(violations.len(), 0);
    }

    /// In a chain of set operations only the unqualified UNION is reported.
    #[test]
    fn test_am02_flags_only_bare_union_in_chain() {
        let violations = lint_sql(
            "SELECT a FROM t UNION ALL SELECT b FROM u UNION SELECT c FROM v",
            RuleAM02,
        );
        assert_eq!(violations.len(), 1);
    }
}