use crate::linter::rule::{LintContext, LintRule};
use crate::types::{issue_codes, Issue};
use sqlparser::ast::{
CreateView, Query, Select, SetExpr, Statement, TableFactor, Update, UpdateTableFromKind,
};
use std::collections::{HashMap, HashSet};
use super::column_count_helpers::{
build_query_cte_map, resolve_set_expr_output_columns, CteColumnCounts,
};
/// Lint rule reported under `issue_codes::LINT_AM_007`: warns when the branches of a
/// set operation resolve to different numbers of output columns.
pub struct AmbiguousSetColumns;
/// Summary of the output-column counts found across the leaf branches of a
/// set-operation tree.
#[derive(Default)]
struct SetCountStats {
/// Distinct column counts of the branches whose width could be resolved.
counts: HashSet<usize>,
/// `true` only when the width of every branch was resolved; the rule reports a
/// mismatch only in that case.
fully_resolved: bool,
}
impl LintRule for AmbiguousSetColumns {
    /// Issue code attached to every finding of this rule.
    fn code(&self) -> &'static str {
        issue_codes::LINT_AM_007
    }

    /// Short human-readable rule name.
    fn name(&self) -> &'static str {
        "Ambiguous set columns"
    }

    /// One-sentence summary of what the rule detects.
    fn description(&self) -> &'static str {
        "Queries within set query produce different numbers of columns."
    }

    /// Lints one statement and returns one warning per violation counted by
    /// `lint_statement_set_ops`.
    ///
    /// The walk starts with an empty CTE map; every warning is tagged with
    /// `ctx.statement_index`.
    fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
        let no_outer_ctes = HashMap::new();
        let mut mismatches = 0usize;
        lint_statement_set_ops(statement, &no_outer_ctes, &mut mismatches);

        let mut issues = Vec::with_capacity(mismatches);
        for _ in 0..mismatches {
            let issue = Issue::warning(
                issue_codes::LINT_AM_007,
                "Set operation branches resolve to different column counts.",
            )
            .with_statement(ctx.statement_index);
            issues.push(issue);
        }
        issues
    }
}
/// Dispatches on the statement kind and lints every query it embeds.
///
/// Handled statements: plain queries, `INSERT` with a source query, `CREATE VIEW`,
/// `CREATE TABLE` with a query, and `UPDATE` with a `FROM` list (only the `FROM`
/// relations and their joins are inspected). All other statements are ignored.
///
/// `outer_ctes` holds the CTE column counts visible from the enclosing scope and
/// `violations` is incremented for each mismatch found.
fn lint_statement_set_ops(
    statement: &Statement,
    outer_ctes: &CteColumnCounts,
    violations: &mut usize,
) {
    match statement {
        Statement::Query(query) => lint_query_set_ops(query, outer_ctes, violations),
        Statement::CreateView(CreateView { query, .. }) => {
            lint_query_set_ops(query, outer_ctes, violations)
        }
        Statement::Insert(insert) => {
            if let Some(source) = insert.source.as_ref() {
                lint_query_set_ops(source, outer_ctes, violations);
            }
        }
        Statement::CreateTable(create) => {
            if let Some(query) = create.query.as_ref() {
                lint_query_set_ops(query, outer_ctes, violations);
            }
        }
        Statement::Update(Update {
            from: Some(from_kind),
            ..
        }) => {
            // Both placements of FROM carry the same list of tables.
            let (UpdateTableFromKind::BeforeSet(tables) | UpdateTableFromKind::AfterSet(tables)) =
                from_kind;
            // Visit each table's own relation first, then the relations it joins.
            tables
                .iter()
                .flat_map(|source| {
                    std::iter::once(&source.relation)
                        .chain(source.joins.iter().map(|join| &join.relation))
                })
                .for_each(|factor| lint_table_factor_set_ops(factor, outer_ctes, violations));
        }
        _ => {}
    }
}
/// Lints the body of `query` using the CTE map produced by `build_query_cte_map`
/// from `query` and the enclosing scope's `outer_ctes`.
fn lint_query_set_ops(query: &Query, outer_ctes: &CteColumnCounts, violations: &mut usize) {
    lint_set_expr_set_ops(
        &query.body,
        &build_query_cte_map(query, outer_ctes),
        violations,
    );
}
/// Walks a query body and counts set operations whose branches resolve to
/// different numbers of output columns.
///
/// A chain of set operations is stored as `SetExpr::SetOperation` nodes nested
/// directly inside one another, and `collect_set_branch_counts` already gathers
/// the widths of every leaf below a node. Each chain is therefore reported at most
/// once: after a node has been flagged, the `SetOperation` nodes nested directly
/// beneath it are not re-checked, because their leaves were part of the flagged
/// comparison and would only yield duplicate, indistinguishable issues.
///
/// * `set_expr` - query body to inspect.
/// * `ctes` - column counts of the CTEs visible at this point.
/// * `violations` - incremented once per mismatching set-operation chain.
fn lint_set_expr_set_ops(set_expr: &SetExpr, ctes: &CteColumnCounts, violations: &mut usize) {
    lint_set_expr_chain(set_expr, ctes, violations, false);
}

/// Recursive worker for [`lint_set_expr_set_ops`].
///
/// `chain_already_flagged` is `true` when an enclosing `SetOperation` of the same
/// chain has already been counted as a violation.
fn lint_set_expr_chain(
    set_expr: &SetExpr,
    ctes: &CteColumnCounts,
    violations: &mut usize,
    chain_already_flagged: bool,
) {
    match set_expr {
        SetExpr::SetOperation { left, right, .. } => {
            let mut flagged = chain_already_flagged;
            if !flagged {
                let stats = collect_set_branch_counts(set_expr, ctes);
                // Only report when every branch width is known; an unresolved
                // branch could still make the widths agree.
                if stats.fully_resolved && stats.counts.len() > 1 {
                    *violations += 1;
                    flagged = true;
                }
            }
            // Keep descending either way: a nested sub-chain may be resolvable when
            // this node is not, and leaf branches may contain their own subqueries.
            lint_set_expr_chain(left, ctes, violations, flagged);
            lint_set_expr_chain(right, ctes, violations, flagged);
        }
        // Everything below starts a fresh scope, so the chain flag does not carry over.
        SetExpr::Query(query) => lint_query_set_ops(query, ctes, violations),
        SetExpr::Select(select) => lint_select_subqueries_set_ops(select, ctes, violations),
        SetExpr::Insert(statement)
        | SetExpr::Update(statement)
        | SetExpr::Delete(statement)
        | SetExpr::Merge(statement) => lint_statement_set_ops(statement, ctes, violations),
        _ => {}
    }
}
/// Lints every table factor in the `FROM` clause of `select`: for each entry, its
/// own relation first and then the relation of each of its joins.
fn lint_select_subqueries_set_ops(select: &Select, ctes: &CteColumnCounts, violations: &mut usize) {
    select
        .from
        .iter()
        .flat_map(|source| {
            std::iter::once(&source.relation).chain(source.joins.iter().map(|join| &join.relation))
        })
        .for_each(|factor| lint_table_factor_set_ops(factor, ctes, violations));
}
/// Lints the queries reachable through a single table factor.
///
/// Derived tables are linted as queries, nested joins are traversed relation by
/// relation, and `PIVOT` / `UNPIVOT` / `MATCH_RECOGNIZE` wrappers are unwrapped to
/// the table they wrap. Any other factor is ignored.
fn lint_table_factor_set_ops(
    table_factor: &TableFactor,
    ctes: &CteColumnCounts,
    violations: &mut usize,
) {
    match table_factor {
        TableFactor::Derived { subquery, .. } => lint_query_set_ops(subquery, ctes, violations),
        TableFactor::NestedJoin {
            table_with_joins, ..
        } => {
            // Base relation first, then each joined relation in order.
            let inner_factors = std::iter::once(&table_with_joins.relation)
                .chain(table_with_joins.joins.iter().map(|join| &join.relation));
            for factor in inner_factors {
                lint_table_factor_set_ops(factor, ctes, violations);
            }
        }
        TableFactor::Pivot { table: wrapped, .. }
        | TableFactor::Unpivot { table: wrapped, .. }
        | TableFactor::MatchRecognize { table: wrapped, .. } => {
            lint_table_factor_set_ops(wrapped, ctes, violations)
        }
        _ => {}
    }
}
fn collect_set_branch_counts(set_expr: &SetExpr, ctes: &CteColumnCounts) -> SetCountStats {
match set_expr {
SetExpr::SetOperation { left, right, .. } => {
let left_stats = collect_set_branch_counts(left, ctes);
let right_stats = collect_set_branch_counts(right, ctes);
let mut counts = left_stats.counts;
counts.extend(right_stats.counts);
SetCountStats {
counts,
fully_resolved: left_stats.fully_resolved && right_stats.fully_resolved,
}
}
_ => {
if let Some(count) = resolve_set_expr_output_columns(set_expr, ctes) {
let mut counts = HashSet::new();
counts.insert(count);
SetCountStats {
counts,
fully_resolved: true,
}
} else {
SetCountStats {
counts: HashSet::new(),
fully_resolved: false,
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;

    /// Parses `sql` and runs the rule over each statement, collecting every issue.
    fn run(sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        let rule = AmbiguousSetColumns;
        let mut collected = Vec::new();
        for (index, statement) in statements.iter().enumerate() {
            let ctx = LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: index,
            };
            collected.extend(rule.check(statement, &ctx));
        }
        collected
    }

    /// Asserts that linting `sql` produces no issues.
    fn assert_clean(sql: &str) {
        let issues = run(sql);
        assert!(issues.is_empty());
    }

    /// Asserts that linting `sql` produces exactly one issue and returns the issues.
    fn assert_single_issue(sql: &str) -> Vec<Issue> {
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
        issues
    }

    /// Asserts that linting `sql` produces exactly one issue carrying this rule's code.
    fn assert_single_am007(sql: &str) {
        let issues = assert_single_issue(sql);
        assert_eq!(issues[0].code, issue_codes::LINT_AM_007);
    }

    #[test]
    fn flags_known_set_column_count_mismatch() {
        assert_single_am007("select a from t union all select c, d from k");
    }

    #[test]
    fn allows_known_set_column_count_match() {
        assert_clean("select a, b from t union all select c, d from k");
    }

    #[test]
    fn resolves_cte_wildcard_columns_for_set_comparison() {
        assert_clean(
            "with cte as (select a, b from t) select * from cte union select c, d from t2",
        );
    }

    #[test]
    fn resolves_declared_cte_columns_for_set_comparison() {
        assert_clean(
            "with cte(a, b) as (select * from t) select * from cte union select c, d from t2",
        );
    }

    #[test]
    fn resolves_declared_derived_alias_columns_for_set_comparison() {
        assert_clean(
            "select t_alias.* from (select * from t) as t_alias(a, b) union select c, d from t2",
        );
    }

    #[test]
    fn flags_resolved_cte_wildcard_mismatch() {
        assert_single_issue(
            "with cte as (select a, b, c from t) select * from cte union select d, e from t2",
        );
    }

    #[test]
    fn flags_declared_cte_width_mismatch_for_set_comparison() {
        assert_single_issue(
            "with cte(a, b, c) as (select * from t) select * from cte union select d, e from t2",
        );
    }

    #[test]
    fn flags_declared_derived_alias_width_mismatch_for_set_comparison() {
        assert_single_issue(
            "select t_alias.* from (select * from t) as t_alias(a, b, c) union select d, e from t2",
        );
    }

    #[test]
    fn unresolved_external_wildcard_does_not_trigger() {
        assert_clean("select a from t1 union all select * from t2");
    }

    #[test]
    fn resolves_derived_alias_wildcard() {
        assert_clean(
            "select t_alias.* from t2 join (select a from t) as t_alias using (a) union select b from t3",
        );
    }

    #[test]
    fn resolves_nested_with_wildcard_for_set_comparison() {
        assert_clean(
            "SELECT * FROM (WITH cte2 AS (SELECT a, b FROM table2) SELECT * FROM cte2 as cte_al) UNION SELECT e, f FROM table3",
        );
    }

    #[test]
    fn flags_nested_with_wildcard_mismatch_for_set_comparison() {
        assert_single_am007(
            "SELECT * FROM (WITH cte2 AS (SELECT a FROM table2) SELECT * FROM cte2 as cte_al) UNION SELECT e, f FROM table3",
        );
    }

    #[test]
    fn resolves_nested_cte_chain_for_set_comparison() {
        assert_clean(
            "with a as (with b as (select 1 from c) select * from b) select * from a union all select k from t2",
        );
    }

    #[test]
    fn resolves_nested_join_alias_wildcard_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 join (select b from t2) as b1 on a1.a = b1.b) as j union all select x, y from t3",
        );
    }

    #[test]
    fn flags_nested_join_alias_wildcard_set_mismatch_when_resolved() {
        assert_single_am007(
            "select j.* from ((select a from t1) as a1 join (select b from t2) as b1 on a1.a = b1.b) as j union all select x from t3",
        );
    }

    #[test]
    fn resolves_nested_join_alias_using_width_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 join (select a from t2) as b1 using(a)) as j union all select x from t3",
        );
    }

    #[test]
    fn resolves_natural_join_nested_alias_width_for_set_comparison() {
        assert_clean(
            "select j.* from ((select a from t1) as a1 natural join (select a from t2) as b1) as j union all select x from t3",
        );
    }

    #[test]
    fn natural_join_nested_alias_width_unknown_does_not_trigger_for_set_comparison() {
        assert_clean(
            "select j.* from ((select * from t1) as a1 natural join (select a from t2) as b1) as j union all select x from t3",
        );
    }

    #[test]
    fn update_from_with_set_column_mismatch() {
        let sql = "UPDATE sometable SET sometable.baz = mycte.bar FROM (SELECT foo, bar FROM mytable1 UNION ALL SELECT bar FROM mytable2) as k";
        assert_single_am007(sql);
    }
}