Skip to main content

rigsql_rules/ambiguous/am07.rs

1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
/// AM07: every branch of a `UNION` / `INTERSECT` / `EXCEPT` should yield the
/// same number of columns.
///
/// The rule compares the number of select items in each branch of a set
/// operation and reports branches that disagree with the first one.
#[derive(Default, Debug)]
pub struct RuleAM07;
11
12impl Rule for RuleAM07 {
13    fn code(&self) -> &'static str {
14        "AM07"
15    }
16    fn name(&self) -> &'static str {
17        "ambiguous.set_column_count"
18    }
19    fn description(&self) -> &'static str {
20        "Set operation column count mismatch."
21    }
22    fn explanation(&self) -> &'static str {
23        "UNION, INTERSECT, and EXCEPT operations require each branch to have the same \
24         number of columns. A mismatch will cause a runtime error in most databases. \
25         This rule checks that each branch has a consistent number of select items."
26    }
27    fn groups(&self) -> &[RuleGroup] {
28        &[RuleGroup::Ambiguous]
29    }
30    fn is_fixable(&self) -> bool {
31        false
32    }
33
34    fn crawl_type(&self) -> CrawlType {
35        CrawlType::RootOnly
36    }
37
38    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
39        let mut violations = Vec::new();
40        check_set_operations(ctx.root, &mut violations);
41        violations
42    }
43}
44
45fn check_set_operations(segment: &Segment, violations: &mut Vec<LintViolation>) {
46    let children = segment.children();
47
48    // Check if this node has set operation keywords among its children
49    let has_set_op = children.iter().any(|c| {
50        if let Segment::Token(t) = c {
51            t.segment_type == SegmentType::Keyword
52                && (t.token.text.eq_ignore_ascii_case("UNION")
53                    || t.token.text.eq_ignore_ascii_case("INTERSECT")
54                    || t.token.text.eq_ignore_ascii_case("EXCEPT"))
55        } else {
56            false
57        }
58    });
59
60    if has_set_op {
61        // Collect SELECT clauses from SelectStatement children at this level
62        let mut select_item_counts = Vec::new();
63
64        for child in children {
65            if child.segment_type() == SegmentType::SelectStatement
66                || child.segment_type() == SegmentType::SelectClause
67            {
68                if let Some(count) = count_select_items(child) {
69                    select_item_counts.push((child.span(), count));
70                }
71            }
72        }
73
74        // Also check if this segment itself starts with a SelectClause
75        // (for the first branch before the UNION keyword)
76        if segment.segment_type() == SegmentType::SelectStatement {
77            let direct_clause = children
78                .iter()
79                .find(|c| c.segment_type() == SegmentType::SelectClause);
80            if let Some(clause) = direct_clause {
81                let count = count_clause_items(clause);
82                if count > 0 {
83                    select_item_counts.insert(0, (clause.span(), count));
84                }
85            }
86        }
87
88        if select_item_counts.len() >= 2 {
89            let first_count = select_item_counts[0].1;
90            for (span, count) in &select_item_counts[1..] {
91                if *count != first_count {
92                    violations.push(LintViolation::with_msg_key(
93                        "AM07",
94                        format!(
95                            "Set operation column count mismatch: expected {} but found {}.",
96                            first_count, count
97                        ),
98                        *span,
99                        "rules.AM07.msg",
100                        vec![
101                            ("expected".to_string(), first_count.to_string()),
102                            ("found".to_string(), count.to_string()),
103                        ],
104                    ));
105                }
106            }
107        }
108    }
109
110    // Recurse
111    for child in children {
112        check_set_operations(child, violations);
113    }
114}
115
116/// Count select items in a SelectStatement by finding its SelectClause.
117fn count_select_items(segment: &Segment) -> Option<usize> {
118    if segment.segment_type() == SegmentType::SelectClause {
119        return Some(count_clause_items(segment));
120    }
121
122    for child in segment.children() {
123        if child.segment_type() == SegmentType::SelectClause {
124            return Some(count_clause_items(child));
125        }
126    }
127    None
128}
129
130/// Count items in a SelectClause by counting commas + 1.
131fn count_clause_items(clause: &Segment) -> usize {
132    let commas = clause
133        .children()
134        .iter()
135        .filter(|c| c.segment_type() == SegmentType::Comma)
136        .count();
137    commas + 1
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Both branches of the UNION ALL select two columns: no violation.
    #[test]
    fn test_am07_accepts_matching_columns() {
        let sql = "SELECT a, b FROM t UNION ALL SELECT c, d FROM u";
        let found = lint_sql(sql, RuleAM07).len();
        assert_eq!(found, 0, "matching branch widths must not be flagged");
    }

    /// A plain SELECT has no set operation, so the rule has nothing to compare.
    #[test]
    fn test_am07_accepts_single_select() {
        let sql = "SELECT a, b FROM t";
        let found = lint_sql(sql, RuleAM07).len();
        assert_eq!(found, 0, "a statement without set operators must not be flagged");
    }
}