flowscope_core/linter/rules/
am_007.rs1use crate::linter::rule::{LintContext, LintRule};
7use crate::types::{issue_codes, Issue};
8use sqlparser::ast::{Query, Select, SetExpr, Statement, TableFactor, UpdateTableFromKind};
9use std::collections::{HashMap, HashSet};
10
11use super::column_count_helpers::{
12 build_query_cte_map, resolve_set_expr_output_columns, CteColumnCounts,
13};
14
/// Lint rule AM_007: reports set operations whose branches resolve to different
/// numbers of output columns.
///
/// The type carries no state; all of the work happens in its [`LintRule`]
/// implementation. The common traits are derived so the rule can be logged,
/// copied and default-constructed like any other plain value.
#[derive(Debug, Clone, Copy, Default)]
pub struct AmbiguousSetColumns;
16
/// Column-count summary gathered over the branches of a set operation.
///
/// The derived `Default` is the "nothing known" value: no counts and
/// `fully_resolved == false`.
#[derive(Default)]
struct SetCountStats {
    /// `true` only when the column count of every visited branch was resolved.
    fully_resolved: bool,
    /// Distinct column counts seen across the branches that could be resolved.
    counts: HashSet<usize>,
}
22
23impl LintRule for AmbiguousSetColumns {
24 fn code(&self) -> &'static str {
25 issue_codes::LINT_AM_007
26 }
27
28 fn name(&self) -> &'static str {
29 "Ambiguous set columns"
30 }
31
32 fn description(&self) -> &'static str {
33 "Queries within set query produce different numbers of columns."
34 }
35
36 fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
37 let mut violation_count = 0usize;
38 lint_statement_set_ops(statement, &HashMap::new(), &mut violation_count);
39
40 (0..violation_count)
41 .map(|_| {
42 Issue::warning(
43 issue_codes::LINT_AM_007,
44 "Set operation branches resolve to different column counts.",
45 )
46 .with_statement(ctx.statement_index)
47 })
48 .collect()
49 }
50}
51
52fn lint_statement_set_ops(
53 statement: &Statement,
54 outer_ctes: &CteColumnCounts,
55 violations: &mut usize,
56) {
57 match statement {
58 Statement::Query(query) => lint_query_set_ops(query, outer_ctes, violations),
59 Statement::Insert(insert) => {
60 if let Some(source) = &insert.source {
61 lint_query_set_ops(source, outer_ctes, violations);
62 }
63 }
64 Statement::CreateView { query, .. } => lint_query_set_ops(query, outer_ctes, violations),
65 Statement::CreateTable(create) => {
66 if let Some(query) = &create.query {
67 lint_query_set_ops(query, outer_ctes, violations);
68 }
69 }
70 Statement::Update {
71 from: Some(from_kind),
72 ..
73 } => {
74 let tables = match from_kind {
75 UpdateTableFromKind::BeforeSet(t) | UpdateTableFromKind::AfterSet(t) => t,
76 };
77 for twj in tables {
78 lint_table_factor_set_ops(&twj.relation, outer_ctes, violations);
79 for join in &twj.joins {
80 lint_table_factor_set_ops(&join.relation, outer_ctes, violations);
81 }
82 }
83 }
84 _ => {}
85 }
86}
87
88fn lint_query_set_ops(query: &Query, outer_ctes: &CteColumnCounts, violations: &mut usize) {
89 let ctes = build_query_cte_map(query, outer_ctes);
90 lint_set_expr_set_ops(&query.body, &ctes, violations);
91}
92
93fn lint_set_expr_set_ops(set_expr: &SetExpr, ctes: &CteColumnCounts, violations: &mut usize) {
94 match set_expr {
95 SetExpr::SetOperation { left, right, .. } => {
96 let stats = collect_set_branch_counts(set_expr, ctes);
97 if stats.fully_resolved && stats.counts.len() > 1 {
98 *violations += 1;
99 }
100
101 lint_set_expr_set_ops(left, ctes, violations);
102 lint_set_expr_set_ops(right, ctes, violations);
103 }
104 SetExpr::Query(query) => lint_query_set_ops(query, ctes, violations),
105 SetExpr::Select(select) => lint_select_subqueries_set_ops(select, ctes, violations),
106 SetExpr::Insert(statement)
107 | SetExpr::Update(statement)
108 | SetExpr::Delete(statement)
109 | SetExpr::Merge(statement) => lint_statement_set_ops(statement, ctes, violations),
110 _ => {}
111 }
112}
113
114fn lint_select_subqueries_set_ops(select: &Select, ctes: &CteColumnCounts, violations: &mut usize) {
115 for table in &select.from {
116 lint_table_factor_set_ops(&table.relation, ctes, violations);
117 for join in &table.joins {
118 lint_table_factor_set_ops(&join.relation, ctes, violations);
119 }
120 }
121}
122
123fn lint_table_factor_set_ops(
124 table_factor: &TableFactor,
125 ctes: &CteColumnCounts,
126 violations: &mut usize,
127) {
128 match table_factor {
129 TableFactor::Derived { subquery, .. } => lint_query_set_ops(subquery, ctes, violations),
130 TableFactor::NestedJoin {
131 table_with_joins, ..
132 } => {
133 lint_table_factor_set_ops(&table_with_joins.relation, ctes, violations);
134 for join in &table_with_joins.joins {
135 lint_table_factor_set_ops(&join.relation, ctes, violations);
136 }
137 }
138 TableFactor::Pivot { table, .. }
139 | TableFactor::Unpivot { table, .. }
140 | TableFactor::MatchRecognize { table, .. } => {
141 lint_table_factor_set_ops(table, ctes, violations)
142 }
143 _ => {}
144 }
145}
146
147fn collect_set_branch_counts(set_expr: &SetExpr, ctes: &CteColumnCounts) -> SetCountStats {
148 match set_expr {
149 SetExpr::SetOperation { left, right, .. } => {
150 let left_stats = collect_set_branch_counts(left, ctes);
151 let right_stats = collect_set_branch_counts(right, ctes);
152
153 let mut counts = left_stats.counts;
154 counts.extend(right_stats.counts);
155
156 SetCountStats {
157 counts,
158 fully_resolved: left_stats.fully_resolved && right_stats.fully_resolved,
159 }
160 }
161 _ => {
162 if let Some(count) = resolve_set_expr_output_columns(set_expr, ctes) {
163 let mut counts = HashSet::new();
164 counts.insert(count);
165 SetCountStats {
166 counts,
167 fully_resolved: true,
168 }
169 } else {
170 SetCountStats {
171 counts: HashSet::new(),
172 fully_resolved: false,
173 }
174 }
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;

    /// Parses `sql` and runs the rule on each statement, returning all issues.
    fn run(sql: &str) -> Vec<Issue> {
        let rule = AmbiguousSetColumns;
        let statements = parse_sql(sql).expect("parse");

        let mut issues = Vec::new();
        for (statement_index, statement) in statements.iter().enumerate() {
            let ctx = LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index,
            };
            issues.extend(rule.check(statement, &ctx));
        }
        issues
    }

    /// Asserts that the rule reports nothing for `sql`.
    fn assert_clean(sql: &str) {
        assert!(run(sql).is_empty());
    }

    /// Asserts that the rule reports exactly one issue for `sql` and returns it.
    fn single_issue(sql: &str) -> Issue {
        let mut issues = run(sql);
        assert_eq!(issues.len(), 1);
        issues.remove(0)
    }

    #[test]
    fn flags_known_set_column_count_mismatch() {
        let issue = single_issue("select a from t union all select c, d from k");
        assert_eq!(issue.code, issue_codes::LINT_AM_007);
    }

    #[test]
    fn allows_known_set_column_count_match() {
        assert_clean("select a, b from t union all select c, d from k");
    }

    #[test]
    fn resolves_cte_wildcard_columns_for_set_comparison() {
        assert_clean(
            "with cte as (select a, b from t) select * from cte union select c, d from t2",
        );
    }

    #[test]
    fn resolves_declared_cte_columns_for_set_comparison() {
        assert_clean(
            "with cte(a, b) as (select * from t) select * from cte union select c, d from t2",
        );
    }

    #[test]
    fn resolves_declared_derived_alias_columns_for_set_comparison() {
        assert_clean(
            "select t_alias.* from (select * from t) as t_alias(a, b) union select c, d from t2",
        );
    }

    #[test]
    fn flags_resolved_cte_wildcard_mismatch() {
        single_issue(
            "with cte as (select a, b, c from t) select * from cte union select d, e from t2",
        );
    }

    #[test]
    fn flags_declared_cte_width_mismatch_for_set_comparison() {
        single_issue(
            "with cte(a, b, c) as (select * from t) select * from cte union select d, e from t2",
        );
    }

    #[test]
    fn flags_declared_derived_alias_width_mismatch_for_set_comparison() {
        single_issue(
            "select t_alias.* from (select * from t) as t_alias(a, b, c) union select d, e from t2",
        );
    }

    #[test]
    fn unresolved_external_wildcard_does_not_trigger() {
        assert_clean("select a from t1 union all select * from t2");
    }

    #[test]
    fn resolves_derived_alias_wildcard() {
        assert_clean(
            "select t_alias.* from t2 join (select a from t) as t_alias using (a) union select b from t3",
        );
    }

    #[test]
    fn resolves_nested_with_wildcard_for_set_comparison() {
        assert_clean(
            "SELECT * FROM (WITH cte2 AS (SELECT a, b FROM table2) SELECT * FROM cte2 as cte_al) UNION SELECT e, f FROM table3",
        );
    }

    #[test]
    fn flags_nested_with_wildcard_mismatch_for_set_comparison() {
        let issue = single_issue(
            "SELECT * FROM (WITH cte2 AS (SELECT a FROM table2) SELECT * FROM cte2 as cte_al) UNION SELECT e, f FROM table3",
        );
        assert_eq!(issue.code, issue_codes::LINT_AM_007);
    }

    #[test]
    fn resolves_nested_cte_chain_for_set_comparison() {
        assert_clean(
            "with a as (with b as (select 1 from c) select * from b) select * from a union all select k from t2",
        );
    }

    #[test]
    fn resolves_nested_join_alias_wildcard_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 join (select b from t2) as b1 on a1.a = b1.b) as j union all select x, y from t3",
        );
    }

    #[test]
    fn flags_nested_join_alias_wildcard_set_mismatch_when_resolved() {
        let issue = single_issue(
            "select j.* from ((select a from t1) as a1 join (select b from t2) as b1 on a1.a = b1.b) as j union all select x from t3",
        );
        assert_eq!(issue.code, issue_codes::LINT_AM_007);
    }

    #[test]
    fn resolves_nested_join_alias_using_width_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 join (select a from t2) as b1 using(a)) as j union all select x from t3",
        );
    }

    #[test]
    fn resolves_natural_join_nested_alias_width_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 natural join (select a from t2) as b1) as j union all select x from t3",
        );
    }

    #[test]
    fn natural_join_nested_alias_width_unknown_does_not_trigger_for_set_comparison() {
        assert_clean(
            "select j.* from ((select * from t1) as a1 natural join (select a from t2) as b1) as j union all select x from t3",
        );
    }

    #[test]
    fn update_from_with_set_column_mismatch() {
        let sql = "UPDATE sometable SET sometable.baz = mycte.bar FROM (SELECT foo, bar FROM mytable1 UNION ALL SELECT bar FROM mytable2) as k";
        let issue = single_issue(sql);
        assert_eq!(issue.code, issue_codes::LINT_AM_007);
    }
}