rigsql_rules/ambiguous/
am07.rs1use rigsql_core::{Segment, SegmentType};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::LintViolation;
5
6#[derive(Debug, Default)]
10pub struct RuleAM07;
11
12impl Rule for RuleAM07 {
13 fn code(&self) -> &'static str {
14 "AM07"
15 }
16 fn name(&self) -> &'static str {
17 "ambiguous.set_column_count"
18 }
19 fn description(&self) -> &'static str {
20 "Set operation column count mismatch."
21 }
22 fn explanation(&self) -> &'static str {
23 "UNION, INTERSECT, and EXCEPT operations require each branch to have the same \
24 number of columns. A mismatch will cause a runtime error in most databases. \
25 This rule checks that each branch has a consistent number of select items."
26 }
27 fn groups(&self) -> &[RuleGroup] {
28 &[RuleGroup::Ambiguous]
29 }
30 fn is_fixable(&self) -> bool {
31 false
32 }
33
34 fn crawl_type(&self) -> CrawlType {
35 CrawlType::RootOnly
36 }
37
38 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
39 let mut violations = Vec::new();
40 check_set_operations(ctx.root, &mut violations);
41 violations
42 }
43}
44
45fn check_set_operations(segment: &Segment, violations: &mut Vec<LintViolation>) {
46 let children = segment.children();
47
48 let has_set_op = children.iter().any(|c| {
50 if let Segment::Token(t) = c {
51 t.segment_type == SegmentType::Keyword
52 && (t.token.text.eq_ignore_ascii_case("UNION")
53 || t.token.text.eq_ignore_ascii_case("INTERSECT")
54 || t.token.text.eq_ignore_ascii_case("EXCEPT"))
55 } else {
56 false
57 }
58 });
59
60 if has_set_op {
61 let mut select_item_counts = Vec::new();
63
64 for child in children {
65 if child.segment_type() == SegmentType::SelectStatement
66 || child.segment_type() == SegmentType::SelectClause
67 {
68 if let Some(count) = count_select_items(child) {
69 select_item_counts.push((child.span(), count));
70 }
71 }
72 }
73
74 if segment.segment_type() == SegmentType::SelectStatement {
77 let direct_clause = children
78 .iter()
79 .find(|c| c.segment_type() == SegmentType::SelectClause);
80 if let Some(clause) = direct_clause {
81 let count = count_clause_items(clause);
82 if count > 0 {
83 select_item_counts.insert(0, (clause.span(), count));
84 }
85 }
86 }
87
88 if select_item_counts.len() >= 2 {
89 let first_count = select_item_counts[0].1;
90 for (span, count) in &select_item_counts[1..] {
91 if *count != first_count {
92 violations.push(LintViolation::with_msg_key(
93 "AM07",
94 format!(
95 "Set operation column count mismatch: expected {} but found {}.",
96 first_count, count
97 ),
98 *span,
99 "rules.AM07.msg",
100 vec![
101 ("expected".to_string(), first_count.to_string()),
102 ("found".to_string(), count.to_string()),
103 ],
104 ));
105 }
106 }
107 }
108 }
109
110 for child in children {
112 check_set_operations(child, violations);
113 }
114}
115
116fn count_select_items(segment: &Segment) -> Option<usize> {
118 if segment.segment_type() == SegmentType::SelectClause {
119 return Some(count_clause_items(segment));
120 }
121
122 for child in segment.children() {
123 if child.segment_type() == SegmentType::SelectClause {
124 return Some(count_clause_items(child));
125 }
126 }
127 None
128}
129
130fn count_clause_items(clause: &Segment) -> usize {
132 let commas = clause
133 .children()
134 .iter()
135 .filter(|c| c.segment_type() == SegmentType::Comma)
136 .count();
137 commas + 1
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Lints `sql` with AM07 alone and asserts that nothing was reported.
    fn assert_no_am07(sql: &str) {
        let reported = lint_sql(sql, RuleAM07).len();
        assert_eq!(reported, 0, "AM07 unexpectedly fired on: {sql}");
    }

    #[test]
    fn test_am07_accepts_matching_columns() {
        assert_no_am07("SELECT a, b FROM t UNION ALL SELECT c, d FROM u");
    }

    #[test]
    fn test_am07_accepts_single_select() {
        assert_no_am07("SELECT a, b FROM t");
    }
}