use std::collections::HashMap;
use crate::ast::*;
use crate::dialects::Dialect;
use crate::schema::{Schema, normalize_identifier};
pub fn qualify_columns<S: Schema>(statement: Statement, schema: &S) -> Statement {
let dialect = schema.dialect();
match statement {
Statement::Select(sel) => {
let qualified = qualify_select(sel, schema, dialect, &HashMap::new());
Statement::Select(qualified)
}
Statement::SetOperation(mut set_op) => {
set_op.left = Box::new(qualify_columns(*set_op.left, schema));
set_op.right = Box::new(qualify_columns(*set_op.right, schema));
Statement::SetOperation(set_op)
}
other => other,
}
}
/// The output columns exposed by one FROM/JOIN source (base table, CTE,
/// derived table, ...). Stored in a map keyed by the source's exposed
/// name (its alias when present, otherwise its table name).
#[derive(Debug, Clone)]
struct SourceColumns {
    // Column names in declaration order; empty when the columns cannot be
    // determined here (e.g. PIVOT/UNNEST/table-function sources).
    columns: Vec<String>,
}
/// Builds the map from each exposed source name in this SELECT's FROM/JOIN
/// clauses to the columns that source provides.
fn resolve_source_columns<S: Schema>(
    sel: &SelectStatement,
    schema: &S,
    dialect: Dialect,
    cte_columns: &HashMap<String, Vec<String>>,
) -> HashMap<String, SourceColumns> {
    let mut source_map = HashMap::new();
    // The optional FROM source followed by every join table, in order.
    let sources = sel
        .from
        .iter()
        .map(|from| &from.source)
        .chain(sel.joins.iter().map(|join| &join.table));
    for source in sources {
        collect_source_columns(source, schema, dialect, cte_columns, &mut source_map);
    }
    source_map
}
/// Records the columns exposed by a single table source into `source_map`.
///
/// CTEs shadow schema tables of the same name; sources whose columns cannot
/// be determined (PIVOT/UNPIVOT/UNNEST/table functions) are recorded with an
/// empty column list so the alias is at least a known source.
fn collect_source_columns<S: Schema>(
    source: &TableSource,
    schema: &S,
    dialect: Dialect,
    cte_columns: &HashMap<String, Vec<String>>,
    source_map: &mut HashMap<String, SourceColumns>,
) {
    match source {
        TableSource::Table(table_ref) => {
            // The exposed name is the alias when present, else the bare name.
            let exposed = table_ref.alias.as_deref().unwrap_or(&table_ref.name);
            let norm_key = normalize_identifier(exposed, dialect);
            // A CTE with this name takes precedence over any schema table.
            let norm_name = normalize_identifier(&table_ref.name, dialect);
            if let Some(cols) = cte_columns.get(&norm_name) {
                source_map.insert(norm_key, SourceColumns { columns: cols.clone() });
                return;
            }
            let path = build_table_path(table_ref, dialect);
            let path_refs: Vec<&str> = path.iter().map(String::as_str).collect();
            // Unknown tables are silently skipped; their columns stay unqualified.
            if let Ok(columns) = schema.column_names(&path_refs) {
                source_map.insert(norm_key, SourceColumns { columns });
            }
        }
        TableSource::Subquery { query, alias } => {
            // An unaliased derived table exposes no name, so nothing to record.
            if let Some(alias) = alias {
                let columns = extract_output_columns(query, schema, dialect, cte_columns);
                source_map.insert(
                    normalize_identifier(alias, dialect),
                    SourceColumns { columns },
                );
            }
        }
        TableSource::Lateral { source: inner } => {
            collect_source_columns(inner, schema, dialect, cte_columns, source_map);
        }
        TableSource::Pivot { source, alias, .. } | TableSource::Unpivot { source, alias, .. } => {
            collect_source_columns(source, schema, dialect, cte_columns, source_map);
            if let Some(alias) = alias {
                // Output columns of PIVOT/UNPIVOT are not derivable here.
                source_map.insert(
                    normalize_identifier(alias, dialect),
                    SourceColumns { columns: Vec::new() },
                );
            }
        }
        TableSource::Unnest { alias, .. } | TableSource::TableFunction { alias, .. } => {
            if let Some(alias) = alias {
                source_map.insert(
                    normalize_identifier(alias, dialect),
                    SourceColumns { columns: Vec::new() },
                );
            }
        }
    }
}
/// Assembles the schema-lookup path `[catalog, schema, name]`, skipping
/// absent parts, with every component normalized for the dialect.
fn build_table_path(table_ref: &TableRef, dialect: Dialect) -> Vec<String> {
    [
        table_ref.catalog.as_deref(),
        table_ref.schema.as_deref(),
        Some(table_ref.name.as_str()),
    ]
    .into_iter()
    .flatten()
    .map(|part| normalize_identifier(part, dialect))
    .collect()
}
/// Derives the output column names of a statement, used to expose the
/// columns of CTEs and derived tables to their enclosing scope.
fn extract_output_columns<S: Schema>(
    stmt: &Statement,
    schema: &S,
    dialect: Dialect,
    cte_columns: &HashMap<String, Vec<String>>,
) -> Vec<String> {
    match stmt {
        Statement::Select(sel) => {
            let sources = resolve_source_columns(sel, schema, dialect, cte_columns);
            let mut names = Vec::new();
            for item in &sel.columns {
                match item {
                    // `SELECT *`: every column of every source, in FROM/JOIN order.
                    SelectItem::Wildcard => {
                        for_each_source_ordered(sel, dialect, &sources, |sc| {
                            names.extend_from_slice(&sc.columns);
                        });
                    }
                    // `SELECT t.*`: columns of that one source, if known.
                    SelectItem::QualifiedWildcard { table } => {
                        let key = normalize_identifier(table, dialect);
                        if let Some(sc) = sources.get(&key) {
                            names.extend_from_slice(&sc.columns);
                        }
                    }
                    // Plain expression: the alias wins, else a derived name.
                    SelectItem::Expr { alias, expr } => match alias {
                        Some(a) => names.push(a.clone()),
                        None => names.push(expr_output_name(expr)),
                    },
                }
            }
            names
        }
        // A set operation's column names come from its left-most operand.
        Statement::SetOperation(set_op) => {
            extract_output_columns(&set_op.left, schema, dialect, cte_columns)
        }
        _ => Vec::new(),
    }
}
/// The implicit output-column name of a projection expression: columns and
/// function calls name their output after themselves; anything else falls
/// back to the placeholder "_col".
fn expr_output_name(expr: &Expr) -> String {
    match expr {
        Expr::Column { name, .. } | Expr::Function { name, .. } => name.clone(),
        _ => "_col".to_string(),
    }
}
/// Invokes `callback` for each known source of the SELECT in declaration
/// order (FROM first, then each join), so `*` expansion is deterministic.
fn for_each_source_ordered<F>(
    sel: &SelectStatement,
    dialect: Dialect,
    source_map: &HashMap<String, SourceColumns>,
    mut callback: F,
) where
    F: FnMut(&SourceColumns),
{
    let sources = sel
        .from
        .iter()
        .map(|from| &from.source)
        .chain(sel.joins.iter().map(|join| &join.table));
    for source in sources {
        // Sources missing from the map (unknown tables) are skipped.
        if let Some(sc) = source_map.get(&source_key_for(source, dialect)) {
            callback(sc);
        }
    }
}
/// The key under which a table source is stored in the source map:
/// its alias when present, otherwise its name (empty for alias-less
/// sources that expose no name).
fn source_key_for(source: &TableSource, dialect: Dialect) -> String {
    match source {
        TableSource::Table(tr) => {
            normalize_identifier(tr.alias.as_deref().unwrap_or(&tr.name), dialect)
        }
        // These sources are addressable only through their alias.
        TableSource::Subquery { alias, .. }
        | TableSource::Unnest { alias, .. }
        | TableSource::TableFunction { alias, .. } => match alias {
            Some(a) => normalize_identifier(a, dialect),
            None => String::new(),
        },
        TableSource::Lateral { source } => source_key_for(source, dialect),
        // PIVOT/UNPIVOT fall back to the wrapped source's key when unaliased.
        TableSource::Pivot { source, alias, .. } | TableSource::Unpivot { source, alias, .. } => {
            match alias {
                Some(a) => normalize_identifier(a, dialect),
                None => source_key_for(source, dialect),
            }
        }
    }
}
fn qualify_select<S: Schema>(
mut sel: SelectStatement,
schema: &S,
dialect: Dialect,
outer_cte_columns: &HashMap<String, Vec<String>>,
) -> SelectStatement {
let mut cte_columns = outer_cte_columns.clone();
for cte in &sel.ctes {
let cols = if !cte.columns.is_empty() {
cte.columns.clone()
} else {
extract_output_columns(&cte.query, schema, dialect, &cte_columns)
};
let norm_name = normalize_identifier(&cte.name, dialect);
cte_columns.insert(norm_name, cols);
}
sel.ctes = sel
.ctes
.into_iter()
.map(|mut cte| {
cte.query = Box::new(qualify_columns(*cte.query, schema));
cte
})
.collect();
if let Some(ref mut from) = sel.from {
qualify_table_source(&mut from.source, schema, dialect, &cte_columns);
}
for join in &mut sel.joins {
qualify_table_source(&mut join.table, schema, dialect, &cte_columns);
}
let source_map = resolve_source_columns(&sel, schema, dialect, &cte_columns);
let mut new_columns = Vec::new();
let old_columns = std::mem::take(&mut sel.columns);
for item in old_columns {
match item {
SelectItem::Wildcard => {
for_each_source_ordered(&sel, dialect, &source_map, |sc| {
for col_name in &sc.columns {
new_columns.push(SelectItem::Expr {
expr: Expr::Column {
table: None,
name: col_name.clone(),
quote_style: QuoteStyle::None,
table_quote_style: QuoteStyle::None,
},
alias: None,
});
}
});
}
SelectItem::QualifiedWildcard { table } => {
let norm_table = normalize_identifier(&table, dialect);
if let Some(sc) = source_map.get(&norm_table) {
for col_name in &sc.columns {
new_columns.push(SelectItem::Expr {
expr: Expr::Column {
table: Some(table.clone()),
name: col_name.clone(),
quote_style: QuoteStyle::None,
table_quote_style: QuoteStyle::None,
},
alias: None,
});
}
} else {
new_columns.push(SelectItem::QualifiedWildcard { table });
}
}
SelectItem::Expr { expr, alias } => {
let qualified_expr = qualify_expr(expr, &source_map, schema, dialect, &cte_columns);
new_columns.push(SelectItem::Expr {
expr: qualified_expr,
alias,
});
}
}
}
sel.columns = new_columns;
if let Some(wh) = sel.where_clause {
sel.where_clause = Some(qualify_expr(wh, &source_map, schema, dialect, &cte_columns));
}
sel.group_by = sel
.group_by
.into_iter()
.map(|e| qualify_expr(e, &source_map, schema, dialect, &cte_columns))
.collect();
if let Some(having) = sel.having {
sel.having = Some(qualify_expr(
having,
&source_map,
schema,
dialect,
&cte_columns,
));
}
sel.order_by = sel
.order_by
.into_iter()
.map(|mut item| {
item.expr = qualify_expr(item.expr, &source_map, schema, dialect, &cte_columns);
item
})
.collect();
if let Some(qualify) = sel.qualify {
sel.qualify = Some(qualify_expr(
qualify,
&source_map,
schema,
dialect,
&cte_columns,
));
}
for join in &mut sel.joins {
if let Some(on) = join.on.take() {
join.on = Some(qualify_expr(on, &source_map, schema, dialect, &cte_columns));
}
}
sel
}
/// Recursively qualifies any statement nested inside a table source
/// (derived tables, LATERAL, and the inner source of PIVOT/UNPIVOT).
fn qualify_table_source<S: Schema>(
    source: &mut TableSource,
    schema: &S,
    dialect: Dialect,
    cte_columns: &HashMap<String, Vec<String>>,
) {
    match source {
        TableSource::Subquery { query, .. } => {
            // Clone is required to move the statement out from behind `&mut`.
            let inner = (**query).clone();
            *query = Box::new(qualify_columns_inner(inner, schema, dialect, cte_columns));
        }
        TableSource::Lateral { source: inner }
        | TableSource::Pivot { source: inner, .. }
        | TableSource::Unpivot { source: inner, .. } => {
            qualify_table_source(inner, schema, dialect, cte_columns);
        }
        _ => {}
    }
}
/// Like `qualify_columns`, but threads an explicit dialect and CTE scope
/// instead of starting fresh — used for nested statements.
fn qualify_columns_inner<S: Schema>(
    statement: Statement,
    schema: &S,
    dialect: Dialect,
    cte_columns: &HashMap<String, Vec<String>>,
) -> Statement {
    match statement {
        Statement::Select(sel) => {
            Statement::Select(qualify_select(sel, schema, dialect, cte_columns))
        }
        Statement::SetOperation(mut set_op) => {
            // Both operands share the same CTE scope.
            set_op.left =
                Box::new(qualify_columns_inner(*set_op.left, schema, dialect, cte_columns));
            set_op.right =
                Box::new(qualify_columns_inner(*set_op.right, schema, dialect, cte_columns));
            Statement::SetOperation(set_op)
        }
        // Anything else (DDL, DML, ...) is passed through untouched.
        passthrough => passthrough,
    }
}
fn qualify_expr<S: Schema>(
expr: Expr,
source_map: &HashMap<String, SourceColumns>,
schema: &S,
dialect: Dialect,
cte_columns: &HashMap<String, Vec<String>>,
) -> Expr {
expr.transform(&|e| match e {
Expr::Column {
table: None,
name,
quote_style,
table_quote_style,
} => {
let norm_name = normalize_identifier(&name, dialect);
let resolved_source = resolve_column(&norm_name, source_map);
if let Some(source_name) = resolved_source {
Expr::Column {
table: Some(source_name),
name,
quote_style,
table_quote_style,
}
} else {
Expr::Column {
table: None,
name,
quote_style,
table_quote_style,
}
}
}
Expr::InSubquery {
expr,
subquery,
negated,
} => Expr::InSubquery {
expr,
subquery: Box::new(qualify_columns_inner(
*subquery,
schema,
dialect,
cte_columns,
)),
negated,
},
Expr::Subquery(stmt) => Expr::Subquery(Box::new(qualify_columns_inner(
*stmt,
schema,
dialect,
cte_columns,
))),
Expr::Exists { subquery, negated } => Expr::Exists {
subquery: Box::new(qualify_columns_inner(
*subquery,
schema,
dialect,
cte_columns,
)),
negated,
},
other => other,
})
}
/// Finds the unique source that exposes `norm_col_name`.
///
/// Returns `Some(source_name)` only when exactly one source has a matching
/// column; zero matches (unknown) or several (ambiguous) yield `None` so
/// the reference is left unqualified.
fn resolve_column(
    norm_col_name: &str,
    source_map: &HashMap<String, SourceColumns>,
) -> Option<String> {
    let mut candidates = source_map
        .iter()
        .filter(|(_, sc)| {
            sc.columns
                .iter()
                .any(|c| c.eq_ignore_ascii_case(norm_col_name))
        })
        .map(|(source_name, _)| source_name);
    // Unique iff the first candidate exists and a second one does not.
    match (candidates.next(), candidates.next()) {
        (Some(only), None) => Some(only.clone()),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::generate;
    use crate::parser::parse;
    use crate::schema::MappingSchema;
    // Fixture schema shared by all tests: users(id, name, email),
    // orders(id, user_id, amount, status), products(id, name, price).
    fn make_schema() -> MappingSchema {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["users"],
                vec![
                    ("id".to_string(), DataType::Int),
                    ("name".to_string(), DataType::Varchar(Some(255))),
                    ("email".to_string(), DataType::Text),
                ],
            )
            .unwrap();
        schema
            .add_table(
                &["orders"],
                vec![
                    ("id".to_string(), DataType::Int),
                    ("user_id".to_string(), DataType::Int),
                    (
                        "amount".to_string(),
                        DataType::Decimal {
                            precision: Some(10),
                            scale: Some(2),
                        },
                    ),
                    ("status".to_string(), DataType::Varchar(Some(50))),
                ],
            )
            .unwrap();
        schema
            .add_table(
                &["products"],
                vec![
                    ("id".to_string(), DataType::Int),
                    ("name".to_string(), DataType::Varchar(Some(255))),
                    (
                        "price".to_string(),
                        DataType::Decimal {
                            precision: Some(10),
                            scale: Some(2),
                        },
                    ),
                ],
            )
            .unwrap();
        schema
    }
    // Parse, qualify, and regenerate SQL so tests compare plain strings.
    fn qualify(sql: &str, schema: &MappingSchema) -> String {
        let stmt = parse(sql, Dialect::Ansi).unwrap();
        let qualified = qualify_columns(stmt, schema);
        generate(&qualified, Dialect::Ansi)
    }
    #[test]
    fn test_expand_star() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT * FROM users", &schema),
            "SELECT id, name, email FROM users"
        );
    }
    #[test]
    fn test_expand_qualified_wildcard() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT users.* FROM users", &schema),
            "SELECT users.id, users.name, users.email FROM users"
        );
    }
    #[test]
    fn test_expand_star_with_alias() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT * FROM users AS u", &schema),
            "SELECT id, name, email FROM users AS u"
        );
    }
    #[test]
    fn test_expand_qualified_wildcard_alias() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT u.* FROM users AS u", &schema),
            "SELECT u.id, u.name, u.email FROM users AS u"
        );
    }
    #[test]
    fn test_qualify_unqualified_single_table() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT id, name FROM users", &schema),
            "SELECT users.id, users.name FROM users"
        );
    }
    #[test]
    fn test_qualify_unqualified_single_table_alias() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT id, name FROM users AS u", &schema),
            "SELECT u.id, u.name FROM users AS u"
        );
    }
    #[test]
    fn test_qualify_already_qualified() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT users.id, users.name FROM users", &schema),
            "SELECT users.id, users.name FROM users"
        );
    }
    #[test]
    fn test_qualify_join_unambiguous() {
        let schema = make_schema();
        assert_eq!(
            qualify(
                "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
                &schema
            ),
            "SELECT users.name, orders.amount FROM users INNER JOIN orders ON users.id = orders.user_id"
        );
    }
    // `id` exists in both tables, so it must stay unqualified (ambiguous).
    #[test]
    fn test_qualify_join_ambiguous_left_unqualified() {
        let schema = make_schema();
        let result = qualify(
            "SELECT id FROM users JOIN orders ON users.id = orders.user_id",
            &schema,
        );
        assert_eq!(
            result,
            "SELECT id FROM users INNER JOIN orders ON users.id = orders.user_id"
        );
    }
    #[test]
    fn test_qualify_where_clause() {
        let schema = make_schema();
        assert_eq!(
            qualify(
                "SELECT name FROM users WHERE email = 'test@test.com'",
                &schema
            ),
            "SELECT users.name FROM users WHERE users.email = 'test@test.com'"
        );
    }
    #[test]
    fn test_qualify_order_by() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT name FROM users ORDER BY email", &schema),
            "SELECT users.name FROM users ORDER BY users.email"
        );
    }
    #[test]
    fn test_qualify_group_by_having() {
        let schema = make_schema();
        assert_eq!(
            qualify(
                "SELECT status, COUNT(*) FROM orders GROUP BY status HAVING COUNT(*) > 1",
                &schema
            ),
            "SELECT orders.status, COUNT(*) FROM orders GROUP BY orders.status HAVING COUNT(*) > 1"
        );
    }
    // `*` over a join expands in FROM-then-JOIN order, without prefixes.
    #[test]
    fn test_expand_star_join() {
        let schema = make_schema();
        let result = qualify(
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
            &schema,
        );
        assert_eq!(
            result,
            "SELECT id, name, email, id, user_id, amount, status FROM users INNER JOIN orders ON users.id = orders.user_id"
        );
    }
    // Columns selected from a CTE are qualified with the CTE's name.
    #[test]
    fn test_cte_column_resolution() {
        let schema = make_schema();
        let result = qualify(
            "WITH active AS (SELECT id, name FROM users) SELECT id, name FROM active",
            &schema,
        );
        assert_eq!(
            result,
            "WITH active AS (SELECT users.id, users.name FROM users) SELECT active.id, active.name FROM active"
        );
    }
    #[test]
    fn test_derived_table_column_resolution() {
        let schema = make_schema();
        let result = qualify(
            "SELECT id FROM (SELECT id, name FROM users) AS sub",
            &schema,
        );
        assert_eq!(
            result,
            "SELECT sub.id FROM (SELECT users.id, users.name FROM users) AS sub"
        );
    }
    #[test]
    fn test_preserve_expression_aliases() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT name AS user_name FROM users", &schema),
            "SELECT users.name AS user_name FROM users"
        );
    }
    // In the ON clause, `id` is ambiguous (left unqualified) while
    // `user_id` is unique to orders (qualified).
    #[test]
    fn test_qualify_join_on() {
        let schema = make_schema();
        assert_eq!(
            qualify(
                "SELECT name FROM users JOIN orders ON id = user_id",
                &schema
            ),
            "SELECT users.name FROM users INNER JOIN orders ON id = orders.user_id"
        );
    }
    // Tables absent from the schema leave all columns untouched.
    #[test]
    fn test_no_schema_columns_passthrough() {
        let schema = make_schema();
        assert_eq!(
            qualify("SELECT x, y FROM unknown_table", &schema),
            "SELECT x, y FROM unknown_table"
        );
    }
}