Skip to main content

flowscope_core/linter/rules/
am_007.rs

1//! LINT_AM_007: Ambiguous set-operation columns.
2//!
3//! SQLFluff AM07 parity: set-operation branches should resolve to the same
4//! number of output columns when wildcard expansion is deterministically known.
5
6use crate::linter::rule::{LintContext, LintRule};
7use crate::types::{issue_codes, Issue};
8use sqlparser::ast::{Query, Select, SetExpr, Statement, TableFactor, UpdateTableFromKind};
9use std::collections::{HashMap, HashSet};
10
11use super::column_count_helpers::{
12    build_query_cte_map, resolve_set_expr_output_columns, CteColumnCounts,
13};
14
/// Lint rule `LINT_AM_007`: reports set operations whose branches resolve to
/// different numbers of output columns.
pub struct AmbiguousSetColumns;
16
/// Column counts gathered from the leaf branches of a set operation.
///
/// The derived `Default` is an empty `counts` set with `fully_resolved == false`.
#[derive(Default)]
struct SetCountStats {
    /// Distinct output-column counts seen across the branches that resolved.
    counts: HashSet<usize>,
    /// `true` only when every branch's column count could be resolved.
    fully_resolved: bool,
}
22
23impl LintRule for AmbiguousSetColumns {
24    fn code(&self) -> &'static str {
25        issue_codes::LINT_AM_007
26    }
27
28    fn name(&self) -> &'static str {
29        "Ambiguous set columns"
30    }
31
32    fn description(&self) -> &'static str {
33        "Queries within set query produce different numbers of columns."
34    }
35
36    fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
37        let mut violation_count = 0usize;
38        lint_statement_set_ops(statement, &HashMap::new(), &mut violation_count);
39
40        (0..violation_count)
41            .map(|_| {
42                Issue::warning(
43                    issue_codes::LINT_AM_007,
44                    "Set operation branches resolve to different column counts.",
45                )
46                .with_statement(ctx.statement_index)
47            })
48            .collect()
49    }
50}
51
52fn lint_statement_set_ops(
53    statement: &Statement,
54    outer_ctes: &CteColumnCounts,
55    violations: &mut usize,
56) {
57    match statement {
58        Statement::Query(query) => lint_query_set_ops(query, outer_ctes, violations),
59        Statement::Insert(insert) => {
60            if let Some(source) = &insert.source {
61                lint_query_set_ops(source, outer_ctes, violations);
62            }
63        }
64        Statement::CreateView { query, .. } => lint_query_set_ops(query, outer_ctes, violations),
65        Statement::CreateTable(create) => {
66            if let Some(query) = &create.query {
67                lint_query_set_ops(query, outer_ctes, violations);
68            }
69        }
70        Statement::Update {
71            from: Some(from_kind),
72            ..
73        } => {
74            let tables = match from_kind {
75                UpdateTableFromKind::BeforeSet(t) | UpdateTableFromKind::AfterSet(t) => t,
76            };
77            for twj in tables {
78                lint_table_factor_set_ops(&twj.relation, outer_ctes, violations);
79                for join in &twj.joins {
80                    lint_table_factor_set_ops(&join.relation, outer_ctes, violations);
81                }
82            }
83        }
84        _ => {}
85    }
86}
87
88fn lint_query_set_ops(query: &Query, outer_ctes: &CteColumnCounts, violations: &mut usize) {
89    let ctes = build_query_cte_map(query, outer_ctes);
90    lint_set_expr_set_ops(&query.body, &ctes, violations);
91}
92
93fn lint_set_expr_set_ops(set_expr: &SetExpr, ctes: &CteColumnCounts, violations: &mut usize) {
94    match set_expr {
95        SetExpr::SetOperation { left, right, .. } => {
96            let stats = collect_set_branch_counts(set_expr, ctes);
97            if stats.fully_resolved && stats.counts.len() > 1 {
98                *violations += 1;
99            }
100
101            lint_set_expr_set_ops(left, ctes, violations);
102            lint_set_expr_set_ops(right, ctes, violations);
103        }
104        SetExpr::Query(query) => lint_query_set_ops(query, ctes, violations),
105        SetExpr::Select(select) => lint_select_subqueries_set_ops(select, ctes, violations),
106        SetExpr::Insert(statement)
107        | SetExpr::Update(statement)
108        | SetExpr::Delete(statement)
109        | SetExpr::Merge(statement) => lint_statement_set_ops(statement, ctes, violations),
110        _ => {}
111    }
112}
113
114fn lint_select_subqueries_set_ops(select: &Select, ctes: &CteColumnCounts, violations: &mut usize) {
115    for table in &select.from {
116        lint_table_factor_set_ops(&table.relation, ctes, violations);
117        for join in &table.joins {
118            lint_table_factor_set_ops(&join.relation, ctes, violations);
119        }
120    }
121}
122
123fn lint_table_factor_set_ops(
124    table_factor: &TableFactor,
125    ctes: &CteColumnCounts,
126    violations: &mut usize,
127) {
128    match table_factor {
129        TableFactor::Derived { subquery, .. } => lint_query_set_ops(subquery, ctes, violations),
130        TableFactor::NestedJoin {
131            table_with_joins, ..
132        } => {
133            lint_table_factor_set_ops(&table_with_joins.relation, ctes, violations);
134            for join in &table_with_joins.joins {
135                lint_table_factor_set_ops(&join.relation, ctes, violations);
136            }
137        }
138        TableFactor::Pivot { table, .. }
139        | TableFactor::Unpivot { table, .. }
140        | TableFactor::MatchRecognize { table, .. } => {
141            lint_table_factor_set_ops(table, ctes, violations)
142        }
143        _ => {}
144    }
145}
146
147fn collect_set_branch_counts(set_expr: &SetExpr, ctes: &CteColumnCounts) -> SetCountStats {
148    match set_expr {
149        SetExpr::SetOperation { left, right, .. } => {
150            let left_stats = collect_set_branch_counts(left, ctes);
151            let right_stats = collect_set_branch_counts(right, ctes);
152
153            let mut counts = left_stats.counts;
154            counts.extend(right_stats.counts);
155
156            SetCountStats {
157                counts,
158                fully_resolved: left_stats.fully_resolved && right_stats.fully_resolved,
159            }
160        }
161        _ => {
162            if let Some(count) = resolve_set_expr_output_columns(set_expr, ctes) {
163                let mut counts = HashSet::new();
164                counts.insert(count);
165                SetCountStats {
166                    counts,
167                    fully_resolved: true,
168                }
169            } else {
170                SetCountStats {
171                    counts: HashSet::new(),
172                    fully_resolved: false,
173                }
174            }
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;

    /// Parses `sql` and runs the rule over every statement, returning all issues.
    fn lint(sql: &str) -> Vec<Issue> {
        let rule = AmbiguousSetColumns;
        let statements = parse_sql(sql).expect("parse");

        let mut found = Vec::new();
        for (statement_index, statement) in statements.iter().enumerate() {
            let ctx = LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index,
            };
            found.extend(rule.check(statement, &ctx));
        }
        found
    }

    /// Asserts that the rule reports nothing for `sql`.
    fn assert_clean(sql: &str) {
        let issues = lint(sql);
        assert!(issues.is_empty());
    }

    /// Asserts that the rule reports exactly one issue for `sql` and returns the issues.
    fn assert_one_issue(sql: &str) -> Vec<Issue> {
        let issues = lint(sql);
        assert_eq!(issues.len(), 1);
        issues
    }

    /// Asserts exactly one issue for `sql`, carrying the AM_007 code.
    fn assert_one_am007(sql: &str) {
        let issues = assert_one_issue(sql);
        assert_eq!(issues[0].code, issue_codes::LINT_AM_007);
    }

    // --- Edge cases adopted from sqlfluff AM07 ---

    #[test]
    fn flags_known_set_column_count_mismatch() {
        assert_one_am007("select a from t union all select c, d from k");
    }

    #[test]
    fn allows_known_set_column_count_match() {
        assert_clean("select a, b from t union all select c, d from k");
    }

    #[test]
    fn resolves_cte_wildcard_columns_for_set_comparison() {
        assert_clean(
            "with cte as (select a, b from t) select * from cte union select c, d from t2",
        );
    }

    #[test]
    fn resolves_declared_cte_columns_for_set_comparison() {
        assert_clean(
            "with cte(a, b) as (select * from t) select * from cte union select c, d from t2",
        );
    }

    #[test]
    fn resolves_declared_derived_alias_columns_for_set_comparison() {
        assert_clean(
            "select t_alias.* from (select * from t) as t_alias(a, b) union select c, d from t2",
        );
    }

    #[test]
    fn flags_resolved_cte_wildcard_mismatch() {
        assert_one_issue(
            "with cte as (select a, b, c from t) select * from cte union select d, e from t2",
        );
    }

    #[test]
    fn flags_declared_cte_width_mismatch_for_set_comparison() {
        assert_one_issue(
            "with cte(a, b, c) as (select * from t) select * from cte union select d, e from t2",
        );
    }

    #[test]
    fn flags_declared_derived_alias_width_mismatch_for_set_comparison() {
        assert_one_issue(
            "select t_alias.* from (select * from t) as t_alias(a, b, c) union select d, e from t2",
        );
    }

    #[test]
    fn unresolved_external_wildcard_does_not_trigger() {
        assert_clean("select a from t1 union all select * from t2");
    }

    #[test]
    fn resolves_derived_alias_wildcard() {
        assert_clean(
            "select t_alias.* from t2 join (select a from t) as t_alias using (a) union select b from t3",
        );
    }

    #[test]
    fn resolves_nested_with_wildcard_for_set_comparison() {
        assert_clean(
            "SELECT * FROM (WITH cte2 AS (SELECT a, b FROM table2) SELECT * FROM cte2 as cte_al) UNION SELECT e, f FROM table3",
        );
    }

    #[test]
    fn flags_nested_with_wildcard_mismatch_for_set_comparison() {
        assert_one_am007(
            "SELECT * FROM (WITH cte2 AS (SELECT a FROM table2) SELECT * FROM cte2 as cte_al) UNION SELECT e, f FROM table3",
        );
    }

    #[test]
    fn resolves_nested_cte_chain_for_set_comparison() {
        assert_clean(
            "with a as (with b as (select 1 from c) select * from b) select * from a union all select k from t2",
        );
    }

    #[test]
    fn resolves_nested_join_alias_wildcard_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 join (select b from t2) as b1 on a1.a = b1.b) as j union all select x, y from t3",
        );
    }

    #[test]
    fn flags_nested_join_alias_wildcard_set_mismatch_when_resolved() {
        assert_one_am007(
            "select j.* from ((select a from t1) as a1 join (select b from t2) as b1 on a1.a = b1.b) as j union all select x from t3",
        );
    }

    #[test]
    fn resolves_nested_join_alias_using_width_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 join (select a from t2) as b1 using(a)) as j union all select x from t3",
        );
    }

    #[test]
    fn resolves_natural_join_nested_alias_width_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 natural join (select a from t2) as b1) as j union all select x from t3",
        );
    }

    #[test]
    fn natural_join_nested_alias_width_unknown_does_not_trigger_for_set_comparison() {
        assert_clean(
            "select j.* from ((select * from t1) as a1 natural join (select a from t2) as b1) as j union all select x from t3",
        );
    }

    #[test]
    fn update_from_with_set_column_mismatch() {
        // SQLFluff: test_fail_cte_no_select_final_statement
        let sql = "UPDATE sometable SET sometable.baz = mycte.bar FROM (SELECT foo, bar FROM mytable1 UNION ALL SELECT bar FROM mytable2) as k";
        assert_one_am007(sql);
    }
}
351}