flowscope_core/linter/rules/
am_007.rs1use crate::linter::rule::{LintContext, LintRule};
7use crate::types::{issue_codes, Issue};
8use sqlparser::ast::{
9 CreateView, Query, Select, SetExpr, Statement, TableFactor, Update, UpdateTableFromKind,
10};
11use std::collections::{HashMap, HashSet};
12
13use super::column_count_helpers::{
14 build_query_cte_map, resolve_set_expr_output_columns, CteColumnCounts,
15};
16
/// Lint rule AM07: flags set queries (e.g. `UNION`) whose branches resolve to
/// different numbers of output columns.
///
/// Stateless: all configuration lives in the [`LintContext`] passed to `check`.
#[derive(Debug, Clone, Copy)]
pub struct AmbiguousSetColumns;
18
/// Column widths gathered from the leaf branches of one set-operation tree.
///
/// The derived default (no counts, not resolved) is what an unresolvable
/// branch contributes.
#[derive(Default)]
struct SetCountStats {
    /// `true` only when every leaf branch's width was resolved.
    fully_resolved: bool,
    /// Distinct output column counts seen across the resolved leaf branches.
    counts: HashSet<usize>,
}
24
25impl LintRule for AmbiguousSetColumns {
26 fn code(&self) -> &'static str {
27 issue_codes::LINT_AM_007
28 }
29
30 fn name(&self) -> &'static str {
31 "Ambiguous set columns"
32 }
33
34 fn description(&self) -> &'static str {
35 "Queries within set query produce different numbers of columns."
36 }
37
38 fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
39 let mut violation_count = 0usize;
40 lint_statement_set_ops(statement, &HashMap::new(), &mut violation_count);
41
42 (0..violation_count)
43 .map(|_| {
44 Issue::warning(
45 issue_codes::LINT_AM_007,
46 "Set operation branches resolve to different column counts.",
47 )
48 .with_statement(ctx.statement_index)
49 })
50 .collect()
51 }
52}
53
54fn lint_statement_set_ops(
55 statement: &Statement,
56 outer_ctes: &CteColumnCounts,
57 violations: &mut usize,
58) {
59 match statement {
60 Statement::Query(query) => lint_query_set_ops(query, outer_ctes, violations),
61 Statement::Insert(insert) => {
62 if let Some(source) = &insert.source {
63 lint_query_set_ops(source, outer_ctes, violations);
64 }
65 }
66 Statement::CreateView(CreateView { query, .. }) => {
67 lint_query_set_ops(query, outer_ctes, violations)
68 }
69 Statement::CreateTable(create) => {
70 if let Some(query) = &create.query {
71 lint_query_set_ops(query, outer_ctes, violations);
72 }
73 }
74 Statement::Update(Update {
75 from: Some(from_kind),
76 ..
77 }) => {
78 let tables = match from_kind {
79 UpdateTableFromKind::BeforeSet(t) | UpdateTableFromKind::AfterSet(t) => t,
80 };
81 for twj in tables {
82 lint_table_factor_set_ops(&twj.relation, outer_ctes, violations);
83 for join in &twj.joins {
84 lint_table_factor_set_ops(&join.relation, outer_ctes, violations);
85 }
86 }
87 }
88 _ => {}
89 }
90}
91
92fn lint_query_set_ops(query: &Query, outer_ctes: &CteColumnCounts, violations: &mut usize) {
93 let ctes = build_query_cte_map(query, outer_ctes);
94 lint_set_expr_set_ops(&query.body, &ctes, violations);
95}
96
97fn lint_set_expr_set_ops(set_expr: &SetExpr, ctes: &CteColumnCounts, violations: &mut usize) {
98 match set_expr {
99 SetExpr::SetOperation { left, right, .. } => {
100 let stats = collect_set_branch_counts(set_expr, ctes);
101 if stats.fully_resolved && stats.counts.len() > 1 {
102 *violations += 1;
103 }
104
105 lint_set_expr_set_ops(left, ctes, violations);
106 lint_set_expr_set_ops(right, ctes, violations);
107 }
108 SetExpr::Query(query) => lint_query_set_ops(query, ctes, violations),
109 SetExpr::Select(select) => lint_select_subqueries_set_ops(select, ctes, violations),
110 SetExpr::Insert(statement)
111 | SetExpr::Update(statement)
112 | SetExpr::Delete(statement)
113 | SetExpr::Merge(statement) => lint_statement_set_ops(statement, ctes, violations),
114 _ => {}
115 }
116}
117
118fn lint_select_subqueries_set_ops(select: &Select, ctes: &CteColumnCounts, violations: &mut usize) {
119 for table in &select.from {
120 lint_table_factor_set_ops(&table.relation, ctes, violations);
121 for join in &table.joins {
122 lint_table_factor_set_ops(&join.relation, ctes, violations);
123 }
124 }
125}
126
127fn lint_table_factor_set_ops(
128 table_factor: &TableFactor,
129 ctes: &CteColumnCounts,
130 violations: &mut usize,
131) {
132 match table_factor {
133 TableFactor::Derived { subquery, .. } => lint_query_set_ops(subquery, ctes, violations),
134 TableFactor::NestedJoin {
135 table_with_joins, ..
136 } => {
137 lint_table_factor_set_ops(&table_with_joins.relation, ctes, violations);
138 for join in &table_with_joins.joins {
139 lint_table_factor_set_ops(&join.relation, ctes, violations);
140 }
141 }
142 TableFactor::Pivot { table, .. }
143 | TableFactor::Unpivot { table, .. }
144 | TableFactor::MatchRecognize { table, .. } => {
145 lint_table_factor_set_ops(table, ctes, violations)
146 }
147 _ => {}
148 }
149}
150
151fn collect_set_branch_counts(set_expr: &SetExpr, ctes: &CteColumnCounts) -> SetCountStats {
152 match set_expr {
153 SetExpr::SetOperation { left, right, .. } => {
154 let left_stats = collect_set_branch_counts(left, ctes);
155 let right_stats = collect_set_branch_counts(right, ctes);
156
157 let mut counts = left_stats.counts;
158 counts.extend(right_stats.counts);
159
160 SetCountStats {
161 counts,
162 fully_resolved: left_stats.fully_resolved && right_stats.fully_resolved,
163 }
164 }
165 _ => {
166 if let Some(count) = resolve_set_expr_output_columns(set_expr, ctes) {
167 let mut counts = HashSet::new();
168 counts.insert(count);
169 SetCountStats {
170 counts,
171 fully_resolved: true,
172 }
173 } else {
174 SetCountStats {
175 counts: HashSet::new(),
176 fully_resolved: false,
177 }
178 }
179 }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;

    /// Parses `sql` and runs AM07 over every statement, collecting all issues.
    fn run(sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        let rule = AmbiguousSetColumns;
        let mut issues = Vec::new();
        for (statement_index, statement) in statements.iter().enumerate() {
            let ctx = LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index,
            };
            issues.extend(rule.check(statement, &ctx));
        }
        issues
    }

    /// Asserts that `sql` yields no AM07 issues.
    fn assert_clean(sql: &str) {
        assert!(run(sql).is_empty());
    }

    /// Asserts that `sql` yields exactly one issue and returns it.
    fn single_issue(sql: &str) -> Issue {
        let mut issues = run(sql);
        assert_eq!(issues.len(), 1);
        issues.remove(0)
    }

    #[test]
    fn flags_known_set_column_count_mismatch() {
        let issue = single_issue("select a from t union all select c, d from k");
        assert_eq!(issue.code, issue_codes::LINT_AM_007);
    }

    #[test]
    fn allows_known_set_column_count_match() {
        assert_clean("select a, b from t union all select c, d from k");
    }

    #[test]
    fn resolves_cte_wildcard_columns_for_set_comparison() {
        assert_clean("with cte as (select a, b from t) select * from cte union select c, d from t2");
    }

    #[test]
    fn resolves_declared_cte_columns_for_set_comparison() {
        assert_clean("with cte(a, b) as (select * from t) select * from cte union select c, d from t2");
    }

    #[test]
    fn resolves_declared_derived_alias_columns_for_set_comparison() {
        assert_clean(
            "select t_alias.* from (select * from t) as t_alias(a, b) union select c, d from t2",
        );
    }

    #[test]
    fn flags_resolved_cte_wildcard_mismatch() {
        single_issue("with cte as (select a, b, c from t) select * from cte union select d, e from t2");
    }

    #[test]
    fn flags_declared_cte_width_mismatch_for_set_comparison() {
        single_issue(
            "with cte(a, b, c) as (select * from t) select * from cte union select d, e from t2",
        );
    }

    #[test]
    fn flags_declared_derived_alias_width_mismatch_for_set_comparison() {
        single_issue(
            "select t_alias.* from (select * from t) as t_alias(a, b, c) union select d, e from t2",
        );
    }

    #[test]
    fn unresolved_external_wildcard_does_not_trigger() {
        assert_clean("select a from t1 union all select * from t2");
    }

    #[test]
    fn resolves_derived_alias_wildcard() {
        assert_clean(
            "select t_alias.* from t2 join (select a from t) as t_alias using (a) union select b from t3",
        );
    }

    #[test]
    fn resolves_nested_with_wildcard_for_set_comparison() {
        assert_clean(
            "SELECT * FROM (WITH cte2 AS (SELECT a, b FROM table2) SELECT * FROM cte2 as cte_al) UNION SELECT e, f FROM table3",
        );
    }

    #[test]
    fn flags_nested_with_wildcard_mismatch_for_set_comparison() {
        let issue = single_issue(
            "SELECT * FROM (WITH cte2 AS (SELECT a FROM table2) SELECT * FROM cte2 as cte_al) UNION SELECT e, f FROM table3",
        );
        assert_eq!(issue.code, issue_codes::LINT_AM_007);
    }

    #[test]
    fn resolves_nested_cte_chain_for_set_comparison() {
        assert_clean(
            "with a as (with b as (select 1 from c) select * from b) select * from a union all select k from t2",
        );
    }

    #[test]
    fn resolves_nested_join_alias_wildcard_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 join (select b from t2) as b1 on a1.a = b1.b) as j union all select x, y from t3",
        );
    }

    #[test]
    fn flags_nested_join_alias_wildcard_set_mismatch_when_resolved() {
        let issue = single_issue(
            "select j.* from ((select a from t1) as a1 join (select b from t2) as b1 on a1.a = b1.b) as j union all select x from t3",
        );
        assert_eq!(issue.code, issue_codes::LINT_AM_007);
    }

    #[test]
    fn resolves_nested_join_alias_using_width_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 join (select a from t2) as b1 using(a)) as j union all select x from t3",
        );
    }

    #[test]
    fn resolves_natural_join_nested_alias_width_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 natural join (select a from t2) as b1) as j union all select x from t3",
        );
    }

    #[test]
    fn natural_join_nested_alias_width_unknown_does_not_trigger_for_set_comparison() {
        assert_clean(
            "select j.* from ((select * from t1) as a1 natural join (select a from t2) as b1) as j union all select x from t3",
        );
    }

    #[test]
    fn update_from_with_set_column_mismatch() {
        let issue = single_issue(
            "UPDATE sometable SET sometable.baz = mycte.bar FROM (SELECT foo, bar FROM mytable1 UNION ALL SELECT bar FROM mytable2) as k",
        );
        assert_eq!(issue.code, issue_codes::LINT_AM_007);
    }
}