Skip to main content

flowscope_core/linter/rules/
am_007.rs

//! LINT_AM_007: Ambiguous set-operation columns.
//!
//! SQLFluff AM07 parity: set-operation branches should resolve to the same
//! number of output columns when wildcard expansion is deterministically known.
5
6use crate::linter::rule::{LintContext, LintRule};
7use crate::types::{issue_codes, Issue};
8use sqlparser::ast::{
9    CreateView, Query, Select, SetExpr, Statement, TableFactor, Update, UpdateTableFromKind,
10};
11use std::collections::{HashMap, HashSet};
12
13use super::column_count_helpers::{
14    build_query_cte_map, resolve_set_expr_output_columns, CteColumnCounts,
15};
16
/// Lint rule `LINT_AM_007`: reports set operations (`UNION`, `INTERSECT`,
/// `EXCEPT`, ...) whose branches resolve to different numbers of output
/// columns.
///
/// The rule is stateless, so the type is a zero-sized unit struct; it is
/// cheap to copy and can be built with `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AmbiguousSetColumns;
18
/// Column-count summary for the leaf branches of one set-operation tree.
///
/// The `Default` value is "nothing resolved": an empty `counts` set and
/// `fully_resolved == false`.
#[derive(Debug, Default)]
struct SetCountStats {
    /// Distinct output-column counts seen among the branches that resolved.
    counts: HashSet<usize>,
    /// `true` only when every leaf branch produced a known column count.
    fully_resolved: bool,
}
24
25impl LintRule for AmbiguousSetColumns {
26    fn code(&self) -> &'static str {
27        issue_codes::LINT_AM_007
28    }
29
30    fn name(&self) -> &'static str {
31        "Ambiguous set columns"
32    }
33
34    fn description(&self) -> &'static str {
35        "Queries within set query produce different numbers of columns."
36    }
37
38    fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
39        let mut violation_count = 0usize;
40        lint_statement_set_ops(statement, &HashMap::new(), &mut violation_count);
41
42        (0..violation_count)
43            .map(|_| {
44                Issue::warning(
45                    issue_codes::LINT_AM_007,
46                    "Set operation branches resolve to different column counts.",
47                )
48                .with_statement(ctx.statement_index)
49            })
50            .collect()
51    }
52}
53
54fn lint_statement_set_ops(
55    statement: &Statement,
56    outer_ctes: &CteColumnCounts,
57    violations: &mut usize,
58) {
59    match statement {
60        Statement::Query(query) => lint_query_set_ops(query, outer_ctes, violations),
61        Statement::Insert(insert) => {
62            if let Some(source) = &insert.source {
63                lint_query_set_ops(source, outer_ctes, violations);
64            }
65        }
66        Statement::CreateView(CreateView { query, .. }) => {
67            lint_query_set_ops(query, outer_ctes, violations)
68        }
69        Statement::CreateTable(create) => {
70            if let Some(query) = &create.query {
71                lint_query_set_ops(query, outer_ctes, violations);
72            }
73        }
74        Statement::Update(Update {
75            from: Some(from_kind),
76            ..
77        }) => {
78            let tables = match from_kind {
79                UpdateTableFromKind::BeforeSet(t) | UpdateTableFromKind::AfterSet(t) => t,
80            };
81            for twj in tables {
82                lint_table_factor_set_ops(&twj.relation, outer_ctes, violations);
83                for join in &twj.joins {
84                    lint_table_factor_set_ops(&join.relation, outer_ctes, violations);
85                }
86            }
87        }
88        _ => {}
89    }
90}
91
92fn lint_query_set_ops(query: &Query, outer_ctes: &CteColumnCounts, violations: &mut usize) {
93    let ctes = build_query_cte_map(query, outer_ctes);
94    lint_set_expr_set_ops(&query.body, &ctes, violations);
95}
96
97fn lint_set_expr_set_ops(set_expr: &SetExpr, ctes: &CteColumnCounts, violations: &mut usize) {
98    match set_expr {
99        SetExpr::SetOperation { left, right, .. } => {
100            let stats = collect_set_branch_counts(set_expr, ctes);
101            if stats.fully_resolved && stats.counts.len() > 1 {
102                *violations += 1;
103            }
104
105            lint_set_expr_set_ops(left, ctes, violations);
106            lint_set_expr_set_ops(right, ctes, violations);
107        }
108        SetExpr::Query(query) => lint_query_set_ops(query, ctes, violations),
109        SetExpr::Select(select) => lint_select_subqueries_set_ops(select, ctes, violations),
110        SetExpr::Insert(statement)
111        | SetExpr::Update(statement)
112        | SetExpr::Delete(statement)
113        | SetExpr::Merge(statement) => lint_statement_set_ops(statement, ctes, violations),
114        _ => {}
115    }
116}
117
118fn lint_select_subqueries_set_ops(select: &Select, ctes: &CteColumnCounts, violations: &mut usize) {
119    for table in &select.from {
120        lint_table_factor_set_ops(&table.relation, ctes, violations);
121        for join in &table.joins {
122            lint_table_factor_set_ops(&join.relation, ctes, violations);
123        }
124    }
125}
126
127fn lint_table_factor_set_ops(
128    table_factor: &TableFactor,
129    ctes: &CteColumnCounts,
130    violations: &mut usize,
131) {
132    match table_factor {
133        TableFactor::Derived { subquery, .. } => lint_query_set_ops(subquery, ctes, violations),
134        TableFactor::NestedJoin {
135            table_with_joins, ..
136        } => {
137            lint_table_factor_set_ops(&table_with_joins.relation, ctes, violations);
138            for join in &table_with_joins.joins {
139                lint_table_factor_set_ops(&join.relation, ctes, violations);
140            }
141        }
142        TableFactor::Pivot { table, .. }
143        | TableFactor::Unpivot { table, .. }
144        | TableFactor::MatchRecognize { table, .. } => {
145            lint_table_factor_set_ops(table, ctes, violations)
146        }
147        _ => {}
148    }
149}
150
151fn collect_set_branch_counts(set_expr: &SetExpr, ctes: &CteColumnCounts) -> SetCountStats {
152    match set_expr {
153        SetExpr::SetOperation { left, right, .. } => {
154            let left_stats = collect_set_branch_counts(left, ctes);
155            let right_stats = collect_set_branch_counts(right, ctes);
156
157            let mut counts = left_stats.counts;
158            counts.extend(right_stats.counts);
159
160            SetCountStats {
161                counts,
162                fully_resolved: left_stats.fully_resolved && right_stats.fully_resolved,
163            }
164        }
165        _ => {
166            if let Some(count) = resolve_set_expr_output_columns(set_expr, ctes) {
167                let mut counts = HashSet::new();
168                counts.insert(count);
169                SetCountStats {
170                    counts,
171                    fully_resolved: true,
172                }
173            } else {
174                SetCountStats {
175                    counts: HashSet::new(),
176                    fully_resolved: false,
177                }
178            }
179        }
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;

    /// Parses `sql` and runs the rule over every statement, returning all
    /// issues in statement order.
    fn run(sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        let rule = AmbiguousSetColumns;
        let mut issues = Vec::new();
        for (statement_index, statement) in statements.iter().enumerate() {
            let ctx = LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index,
            };
            issues.extend(rule.check(statement, &ctx));
        }
        issues
    }

    /// Asserts that linting `sql` produces no issues.
    #[track_caller]
    fn assert_clean(sql: &str) {
        let issues = run(sql);
        assert!(issues.is_empty());
    }

    /// Asserts that linting `sql` produces exactly one issue and returns it
    /// (inside the vector) for further checks.
    #[track_caller]
    fn assert_one_issue(sql: &str) -> Vec<Issue> {
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
        issues
    }

    /// Asserts that linting `sql` produces exactly one `LINT_AM_007` issue.
    #[track_caller]
    fn assert_one_am07(sql: &str) {
        let issues = assert_one_issue(sql);
        assert_eq!(issues[0].code, issue_codes::LINT_AM_007);
    }

    // --- Edge cases adopted from sqlfluff AM07 ---

    #[test]
    fn flags_known_set_column_count_mismatch() {
        assert_one_am07("select a from t union all select c, d from k");
    }

    #[test]
    fn allows_known_set_column_count_match() {
        assert_clean("select a, b from t union all select c, d from k");
    }

    #[test]
    fn resolves_cte_wildcard_columns_for_set_comparison() {
        assert_clean(
            "with cte as (select a, b from t) select * from cte union select c, d from t2",
        );
    }

    #[test]
    fn resolves_declared_cte_columns_for_set_comparison() {
        assert_clean(
            "with cte(a, b) as (select * from t) select * from cte union select c, d from t2",
        );
    }

    #[test]
    fn resolves_declared_derived_alias_columns_for_set_comparison() {
        assert_clean(
            "select t_alias.* from (select * from t) as t_alias(a, b) union select c, d from t2",
        );
    }

    #[test]
    fn flags_resolved_cte_wildcard_mismatch() {
        assert_one_issue(
            "with cte as (select a, b, c from t) select * from cte union select d, e from t2",
        );
    }

    #[test]
    fn flags_declared_cte_width_mismatch_for_set_comparison() {
        assert_one_issue(
            "with cte(a, b, c) as (select * from t) select * from cte union select d, e from t2",
        );
    }

    #[test]
    fn flags_declared_derived_alias_width_mismatch_for_set_comparison() {
        assert_one_issue(
            "select t_alias.* from (select * from t) as t_alias(a, b, c) union select d, e from t2",
        );
    }

    #[test]
    fn unresolved_external_wildcard_does_not_trigger() {
        assert_clean("select a from t1 union all select * from t2");
    }

    #[test]
    fn resolves_derived_alias_wildcard() {
        assert_clean(
            "select t_alias.* from t2 join (select a from t) as t_alias using (a) union select b from t3",
        );
    }

    #[test]
    fn resolves_nested_with_wildcard_for_set_comparison() {
        assert_clean(
            "SELECT * FROM (WITH cte2 AS (SELECT a, b FROM table2) SELECT * FROM cte2 as cte_al) UNION SELECT e, f FROM table3",
        );
    }

    #[test]
    fn flags_nested_with_wildcard_mismatch_for_set_comparison() {
        assert_one_am07(
            "SELECT * FROM (WITH cte2 AS (SELECT a FROM table2) SELECT * FROM cte2 as cte_al) UNION SELECT e, f FROM table3",
        );
    }

    #[test]
    fn resolves_nested_cte_chain_for_set_comparison() {
        assert_clean(
            "with a as (with b as (select 1 from c) select * from b) select * from a union all select k from t2",
        );
    }

    #[test]
    fn resolves_nested_join_alias_wildcard_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 join (select b from t2) as b1 on a1.a = b1.b) as j union all select x, y from t3",
        );
    }

    #[test]
    fn flags_nested_join_alias_wildcard_set_mismatch_when_resolved() {
        assert_one_am07(
            "select j.* from ((select a from t1) as a1 join (select b from t2) as b1 on a1.a = b1.b) as j union all select x from t3",
        );
    }

    #[test]
    fn resolves_nested_join_alias_using_width_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 join (select a from t2) as b1 using(a)) as j union all select x from t3",
        );
    }

    #[test]
    fn resolves_natural_join_nested_alias_width_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 natural join (select a from t2) as b1) as j union all select x from t3",
        );
    }

    #[test]
    fn natural_join_nested_alias_width_unknown_does_not_trigger_for_set_comparison() {
        assert_clean(
            "select j.* from ((select * from t1) as a1 natural join (select a from t2) as b1) as j union all select x from t3",
        );
    }

    #[test]
    fn update_from_with_set_column_mismatch() {
        // SQLFluff: test_fail_cte_no_select_final_statement
        assert_one_am07(
            "UPDATE sometable SET sometable.baz = mycte.bar FROM (SELECT foo, bar FROM mytable1 UNION ALL SELECT bar FROM mytable2) as k",
        );
    }
}
355}