use crate::dialects::transform_recursive;
use crate::dialects::DialectType;
use crate::expressions::{
Alias, BinaryOp, Column, Expression, Identifier, Join, LateralView, Literal, Over, Paren,
Select, TableRef, VarArgFunc, With,
};
use crate::resolver::{Resolver, ResolverError};
use crate::schema::{normalize_name, Schema};
use crate::scope::{build_scope, traverse_scope, Scope};
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use thiserror::Error;
#[derive(Debug, Error, Clone)]
pub enum QualifyColumnsError {
#[error("Unknown table: {0}")]
UnknownTable(String),
#[error("Unknown column: {0}")]
UnknownColumn(String),
#[error("Ambiguous column: {0}")]
AmbiguousColumn(String),
#[error("Cannot automatically join: {0}")]
CannotAutoJoin(String),
#[error("Unknown output column: {0}")]
UnknownOutputColumn(String),
#[error("Column could not be resolved: {column}{for_table}")]
ColumnNotResolved { column: String, for_table: String },
#[error("Resolver error: {0}")]
ResolverError(#[from] ResolverError),
}
pub type QualifyColumnsResult<T> = Result<T, QualifyColumnsError>;
#[derive(Debug, Clone, Default)]
pub struct QualifyColumnsOptions {
pub expand_alias_refs: bool,
pub expand_stars: bool,
pub infer_schema: Option<bool>,
pub allow_partial_qualification: bool,
pub dialect: Option<DialectType>,
}
impl QualifyColumnsOptions {
pub fn new() -> Self {
Self {
expand_alias_refs: true,
expand_stars: true,
infer_schema: None,
allow_partial_qualification: false,
dialect: None,
}
}
pub fn with_expand_alias_refs(mut self, expand: bool) -> Self {
self.expand_alias_refs = expand;
self
}
pub fn with_expand_stars(mut self, expand: bool) -> Self {
self.expand_stars = expand;
self
}
pub fn with_dialect(mut self, dialect: DialectType) -> Self {
self.dialect = Some(dialect);
self
}
pub fn with_allow_partial(mut self, allow: bool) -> Self {
self.allow_partial_qualification = allow;
self
}
}
pub fn qualify_columns(
expression: Expression,
schema: &dyn Schema,
options: &QualifyColumnsOptions,
) -> QualifyColumnsResult<Expression> {
let infer_schema = options.infer_schema.unwrap_or(schema.is_empty());
let dialect = options.dialect.or_else(|| schema.dialect());
let first_error: RefCell<Option<QualifyColumnsError>> = RefCell::new(None);
let transformed = transform_recursive(expression, &|node| {
if first_error.borrow().is_some() {
return Ok(node);
}
match node {
Expression::Select(mut select) => {
if let Some(with) = &mut select.with {
pushdown_cte_alias_columns_with(with);
}
let scope_expr = Expression::Select(select.clone());
let scope = build_scope(&scope_expr);
let mut resolver = Resolver::new(&scope, schema, infer_schema);
let column_tables = if first_error.borrow().is_none() {
match expand_using(&mut select, &scope, &mut resolver) {
Ok(ct) => ct,
Err(err) => {
*first_error.borrow_mut() = Some(err);
HashMap::new()
}
}
} else {
HashMap::new()
};
if first_error.borrow().is_none() {
if let Err(err) = qualify_columns_in_scope(
&mut select,
&scope,
&mut resolver,
options.allow_partial_qualification,
) {
*first_error.borrow_mut() = Some(err);
}
}
if first_error.borrow().is_none() && options.expand_alias_refs {
if let Err(err) = expand_alias_refs(&mut select, &mut resolver, dialect) {
*first_error.borrow_mut() = Some(err);
}
}
if first_error.borrow().is_none() && options.expand_stars {
if let Err(err) =
expand_stars(&mut select, &scope, &mut resolver, &column_tables)
{
*first_error.borrow_mut() = Some(err);
}
}
if first_error.borrow().is_none() {
if let Err(err) = qualify_outputs_select(&mut select) {
*first_error.borrow_mut() = Some(err);
}
}
if first_error.borrow().is_none() {
if let Err(err) = expand_group_by(&mut select, dialect) {
*first_error.borrow_mut() = Some(err);
}
}
Ok(Expression::Select(select))
}
_ => Ok(node),
}
})
.map_err(|err| QualifyColumnsError::CannotAutoJoin(err.to_string()))?;
if let Some(err) = first_error.into_inner() {
return Err(err);
}
Ok(transformed)
}
pub fn validate_qualify_columns(expression: &Expression) -> QualifyColumnsResult<()> {
let mut all_unqualified = Vec::new();
for scope in traverse_scope(expression) {
if let Expression::Select(_) = &scope.expression {
let unqualified = get_unqualified_columns(&scope);
let external = get_external_columns(&scope);
if !external.is_empty() && !is_correlated_subquery(&scope) {
let first = &external[0];
let for_table = if first.table.is_some() {
format!(" for table: '{}'", first.table.as_ref().unwrap())
} else {
String::new()
};
return Err(QualifyColumnsError::ColumnNotResolved {
column: first.name.clone(),
for_table,
});
}
all_unqualified.extend(unqualified);
}
}
if !all_unqualified.is_empty() {
let first = &all_unqualified[0];
return Err(QualifyColumnsError::AmbiguousColumn(first.name.clone()));
}
Ok(())
}
fn get_source_name(expr: &Expression) -> Option<String> {
match expr {
Expression::Table(t) => Some(
t.alias
.as_ref()
.map(|a| a.name.clone())
.unwrap_or_else(|| t.name.name.clone()),
),
Expression::Subquery(sq) => sq.alias.as_ref().map(|a| a.name.clone()),
_ => None,
}
}
fn get_ordered_source_names(select: &Select) -> Vec<String> {
let mut ordered = Vec::new();
if let Some(from) = &select.from {
for expr in &from.expressions {
if let Some(name) = get_source_name(expr) {
ordered.push(name);
}
}
}
for join in &select.joins {
if let Some(name) = get_source_name(&join.this) {
ordered.push(name);
}
}
ordered
}
fn make_coalesce(column_name: &str, tables: &[String]) -> Expression {
let args: Vec<Expression> = tables
.iter()
.map(|t| Expression::qualified_column(t.as_str(), column_name))
.collect();
Expression::Coalesce(Box::new(VarArgFunc {
expressions: args,
original_name: None,
inferred_type: None,
}))
}
fn expand_using(
select: &mut Select,
_scope: &Scope,
resolver: &mut Resolver,
) -> QualifyColumnsResult<HashMap<String, Vec<String>>> {
let mut columns: HashMap<String, String> = HashMap::new();
let mut column_tables: HashMap<String, Vec<String>> = HashMap::new();
let join_names: HashSet<String> = select
.joins
.iter()
.filter_map(|j| get_source_name(&j.this))
.collect();
let all_ordered = get_ordered_source_names(select);
let mut ordered: Vec<String> = all_ordered
.iter()
.filter(|name| !join_names.contains(name.as_str()))
.cloned()
.collect();
if join_names.is_empty() {
return Ok(column_tables);
}
fn update_source_columns(
source_name: &str,
columns: &mut HashMap<String, String>,
resolver: &mut Resolver,
) {
if let Ok(source_cols) = resolver.get_source_columns(source_name) {
for col_name in source_cols {
columns
.entry(col_name)
.or_insert_with(|| source_name.to_string());
}
}
}
for source_name in &ordered {
update_source_columns(source_name, &mut columns, resolver);
}
for i in 0..select.joins.len() {
let source_table = ordered.last().cloned().unwrap_or_default();
if !source_table.is_empty() {
update_source_columns(&source_table, &mut columns, resolver);
}
let join_table = get_source_name(&select.joins[i].this).unwrap_or_default();
ordered.push(join_table.clone());
if select.joins[i].using.is_empty() {
continue;
}
let _join_columns: Vec<String> =
resolver.get_source_columns(&join_table).unwrap_or_default();
let using_identifiers: Vec<String> = select.joins[i]
.using
.iter()
.map(|id| id.name.clone())
.collect();
let using_count = using_identifiers.len();
let is_semi_or_anti = matches!(
select.joins[i].kind,
crate::expressions::JoinKind::Semi
| crate::expressions::JoinKind::Anti
| crate::expressions::JoinKind::LeftSemi
| crate::expressions::JoinKind::LeftAnti
| crate::expressions::JoinKind::RightSemi
| crate::expressions::JoinKind::RightAnti
);
let mut conditions: Vec<Expression> = Vec::new();
for identifier in &using_identifiers {
let table = columns
.get(identifier)
.cloned()
.unwrap_or_else(|| source_table.clone());
let lhs = if i == 0 || using_count == 1 {
Expression::qualified_column(table.as_str(), identifier.as_str())
} else {
let coalesce_cols: Vec<String> = ordered[..ordered.len() - 1]
.iter()
.filter(|t| {
resolver
.get_source_columns(t)
.unwrap_or_default()
.contains(identifier)
})
.cloned()
.collect();
if coalesce_cols.len() > 1 {
make_coalesce(identifier, &coalesce_cols)
} else {
Expression::qualified_column(table.as_str(), identifier.as_str())
}
};
let rhs = Expression::qualified_column(join_table.as_str(), identifier.as_str());
conditions.push(Expression::Eq(Box::new(BinaryOp::new(lhs, rhs))));
if !is_semi_or_anti {
let tables = column_tables
.entry(identifier.clone())
.or_insert_with(Vec::new);
if !tables.contains(&table) {
tables.push(table.clone());
}
if !tables.contains(&join_table) {
tables.push(join_table.clone());
}
}
}
let on_condition = conditions
.into_iter()
.reduce(|acc, cond| Expression::And(Box::new(BinaryOp::new(acc, cond))))
.expect("at least one USING column");
select.joins[i].on = Some(on_condition);
select.joins[i].using = vec![];
}
if !column_tables.is_empty() {
let mut new_expressions = Vec::with_capacity(select.expressions.len());
for expr in &select.expressions {
match expr {
Expression::Column(col)
if col.table.is_none() && column_tables.contains_key(&col.name.name) =>
{
let tables = &column_tables[&col.name.name];
let coalesce = make_coalesce(&col.name.name, tables);
new_expressions.push(Expression::Alias(Box::new(Alias {
this: coalesce,
alias: Identifier::new(&col.name.name),
column_aliases: vec![],
pre_alias_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
})));
}
_ => {
let mut rewritten = expr.clone();
rewrite_using_columns_in_expression(&mut rewritten, &column_tables);
new_expressions.push(rewritten);
}
}
}
select.expressions = new_expressions;
if let Some(where_clause) = &mut select.where_clause {
rewrite_using_columns_in_expression(&mut where_clause.this, &column_tables);
}
if let Some(group_by) = &mut select.group_by {
for expr in &mut group_by.expressions {
rewrite_using_columns_in_expression(expr, &column_tables);
}
}
if let Some(having) = &mut select.having {
rewrite_using_columns_in_expression(&mut having.this, &column_tables);
}
if let Some(qualify) = &mut select.qualify {
rewrite_using_columns_in_expression(&mut qualify.this, &column_tables);
}
if let Some(order_by) = &mut select.order_by {
for ordered in &mut order_by.expressions {
rewrite_using_columns_in_expression(&mut ordered.this, &column_tables);
}
}
}
Ok(column_tables)
}
fn rewrite_using_columns_in_expression(
expr: &mut Expression,
column_tables: &HashMap<String, Vec<String>>,
) {
let transformed = transform_recursive(expr.clone(), &|node| match node {
Expression::Column(col)
if col.table.is_none() && column_tables.contains_key(&col.name.name) =>
{
let tables = &column_tables[&col.name.name];
Ok(make_coalesce(&col.name.name, tables))
}
other => Ok(other),
});
if let Ok(next) = transformed {
*expr = next;
}
}
fn qualify_columns_in_scope(
select: &mut Select,
scope: &Scope,
resolver: &mut Resolver,
allow_partial: bool,
) -> QualifyColumnsResult<()> {
for expr in &mut select.expressions {
qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
}
if let Some(where_clause) = &mut select.where_clause {
qualify_columns_in_expression(&mut where_clause.this, scope, resolver, allow_partial)?;
}
if let Some(group_by) = &mut select.group_by {
for expr in &mut group_by.expressions {
qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
}
}
if let Some(having) = &mut select.having {
qualify_columns_in_expression(&mut having.this, scope, resolver, allow_partial)?;
}
if let Some(qualify) = &mut select.qualify {
qualify_columns_in_expression(&mut qualify.this, scope, resolver, allow_partial)?;
}
if let Some(order_by) = &mut select.order_by {
for ordered in &mut order_by.expressions {
qualify_columns_in_expression(&mut ordered.this, scope, resolver, allow_partial)?;
}
}
for join in &mut select.joins {
qualify_columns_in_expression(&mut join.this, scope, resolver, allow_partial)?;
if let Some(on) = &mut join.on {
qualify_columns_in_expression(on, scope, resolver, allow_partial)?;
}
}
Ok(())
}
fn expand_alias_refs(
select: &mut Select,
_resolver: &mut Resolver,
_dialect: Option<DialectType>,
) -> QualifyColumnsResult<()> {
let mut alias_to_expression: HashMap<String, (Expression, usize)> = HashMap::new();
for (i, expr) in select.expressions.iter_mut().enumerate() {
replace_alias_refs_in_expression(expr, &alias_to_expression, false);
if let Expression::Alias(alias) = expr {
alias_to_expression.insert(alias.alias.name.clone(), (alias.this.clone(), i + 1));
}
}
if let Some(where_clause) = &mut select.where_clause {
replace_alias_refs_in_expression(&mut where_clause.this, &alias_to_expression, false);
}
if let Some(group_by) = &mut select.group_by {
for expr in &mut group_by.expressions {
replace_alias_refs_in_expression(expr, &alias_to_expression, true);
}
}
if let Some(having) = &mut select.having {
replace_alias_refs_in_expression(&mut having.this, &alias_to_expression, false);
}
if let Some(qualify) = &mut select.qualify {
replace_alias_refs_in_expression(&mut qualify.this, &alias_to_expression, false);
}
if let Some(order_by) = &mut select.order_by {
for ordered in &mut order_by.expressions {
replace_alias_refs_in_expression(&mut ordered.this, &alias_to_expression, false);
}
}
Ok(())
}
fn expand_group_by(select: &mut Select, _dialect: Option<DialectType>) -> QualifyColumnsResult<()> {
let projections = select.expressions.clone();
if let Some(group_by) = &mut select.group_by {
for group_expr in &mut group_by.expressions {
if let Some(index) = positional_reference(group_expr) {
let replacement = select_expression_at_position(&projections, index)?;
*group_expr = replacement;
}
}
}
Ok(())
}
fn expand_stars(
select: &mut Select,
_scope: &Scope,
resolver: &mut Resolver,
column_tables: &HashMap<String, Vec<String>>,
) -> QualifyColumnsResult<()> {
let mut new_selections: Vec<Expression> = Vec::new();
let mut has_star = false;
let mut coalesced_columns: HashSet<String> = HashSet::new();
let ordered_sources = get_ordered_source_names(select);
for expr in &select.expressions {
match expr {
Expression::Star(_) => {
has_star = true;
for source_name in &ordered_sources {
if let Ok(columns) = resolver.get_source_columns(source_name) {
if columns.contains(&"*".to_string()) || columns.is_empty() {
return Ok(());
}
for col_name in &columns {
if coalesced_columns.contains(col_name) {
continue;
}
if let Some(tables) = column_tables.get(col_name) {
if tables.contains(source_name) {
coalesced_columns.insert(col_name.clone());
let coalesce = make_coalesce(col_name, tables);
new_selections.push(Expression::Alias(Box::new(Alias {
this: coalesce,
alias: Identifier::new(col_name),
column_aliases: vec![],
pre_alias_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
})));
continue;
}
}
new_selections
.push(create_qualified_column(col_name, Some(source_name)));
}
}
}
}
Expression::Column(col) if is_star_column(col) => {
has_star = true;
if let Some(table) = &col.table {
let table_name = &table.name;
if !ordered_sources.contains(table_name) {
return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
}
if let Ok(columns) = resolver.get_source_columns(table_name) {
if columns.contains(&"*".to_string()) || columns.is_empty() {
return Ok(());
}
for col_name in &columns {
if coalesced_columns.contains(col_name) {
continue;
}
if let Some(tables) = column_tables.get(col_name) {
if tables.contains(table_name) {
coalesced_columns.insert(col_name.clone());
let coalesce = make_coalesce(col_name, tables);
new_selections.push(Expression::Alias(Box::new(Alias {
this: coalesce,
alias: Identifier::new(col_name),
column_aliases: vec![],
pre_alias_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
})));
continue;
}
}
new_selections
.push(create_qualified_column(col_name, Some(table_name)));
}
}
}
}
_ => new_selections.push(expr.clone()),
}
}
if has_star {
select.expressions = new_selections;
}
Ok(())
}
pub fn qualify_outputs(scope: &Scope) -> QualifyColumnsResult<()> {
if let Expression::Select(mut select) = scope.expression.clone() {
qualify_outputs_select(&mut select)?;
}
Ok(())
}
fn qualify_outputs_select(select: &mut Select) -> QualifyColumnsResult<()> {
let mut new_selections: Vec<Expression> = Vec::new();
for (i, expr) in select.expressions.iter().enumerate() {
match expr {
Expression::Alias(_) => new_selections.push(expr.clone()),
Expression::Column(col) => {
new_selections.push(create_alias(expr.clone(), &col.name.name));
}
Expression::Star(_) => new_selections.push(expr.clone()),
_ => {
let alias_name = get_output_name(expr).unwrap_or_else(|| format!("_col_{}", i));
new_selections.push(create_alias(expr.clone(), &alias_name));
}
}
}
select.expressions = new_selections;
Ok(())
}
fn qualify_columns_in_expression(
expr: &mut Expression,
scope: &Scope,
resolver: &mut Resolver,
allow_partial: bool,
) -> QualifyColumnsResult<()> {
let first_error: RefCell<Option<QualifyColumnsError>> = RefCell::new(None);
let resolver_cell: RefCell<&mut Resolver> = RefCell::new(resolver);
let transformed = transform_recursive(expr.clone(), &|node| {
if first_error.borrow().is_some() {
return Ok(node);
}
match node {
Expression::Column(mut col) => {
if let Err(err) = qualify_single_column(
&mut col,
scope,
&mut resolver_cell.borrow_mut(),
allow_partial,
) {
*first_error.borrow_mut() = Some(err);
}
Ok(Expression::Column(col))
}
_ => Ok(node),
}
})
.map_err(|err| QualifyColumnsError::CannotAutoJoin(err.to_string()))?;
if let Some(err) = first_error.into_inner() {
return Err(err);
}
*expr = transformed;
Ok(())
}
fn qualify_single_column(
col: &mut Column,
scope: &Scope,
resolver: &mut Resolver,
allow_partial: bool,
) -> QualifyColumnsResult<()> {
if is_star_column(col) {
return Ok(());
}
if let Some(table) = &col.table {
let table_name = &table.name;
if !scope.sources.contains_key(table_name) {
if resolver.table_exists_in_schema(table_name) {
return Ok(());
}
return Err(QualifyColumnsError::UnknownTable(table_name.clone()));
}
if let Ok(source_columns) = resolver.get_source_columns(table_name) {
let normalized_column_name = normalize_column_name(&col.name.name, resolver.dialect);
if !allow_partial
&& !source_columns.is_empty()
&& !source_columns.iter().any(|column| {
normalize_column_name(column, resolver.dialect) == normalized_column_name
})
&& !source_columns.contains(&"*".to_string())
{
return Err(QualifyColumnsError::UnknownColumn(col.name.name.clone()));
}
}
return Ok(());
}
if let Some(table_name) = resolver.get_table(&col.name.name) {
col.table = Some(Identifier::new(table_name));
return Ok(());
}
if let Some(outer_table) = resolver.find_column_in_outer_schema_tables(&col.name.name) {
col.table = Some(Identifier::new(outer_table));
return Ok(());
}
if !allow_partial {
return Err(QualifyColumnsError::UnknownColumn(col.name.name.clone()));
}
Ok(())
}
fn normalize_column_name(name: &str, dialect: Option<DialectType>) -> String {
normalize_name(name, dialect, false, true)
}
fn replace_alias_refs_in_expression(
expr: &mut Expression,
alias_to_expression: &HashMap<String, (Expression, usize)>,
literal_index: bool,
) {
let transformed = transform_recursive(expr.clone(), &|node| match node {
Expression::Column(col) if col.table.is_none() => {
if let Some((alias_expr, index)) = alias_to_expression.get(&col.name.name) {
if literal_index && matches!(alias_expr, Expression::Literal(_)) {
return Ok(Expression::number(*index as i64));
}
return Ok(Expression::Paren(Box::new(Paren {
this: alias_expr.clone(),
trailing_comments: vec![],
})));
}
Ok(Expression::Column(col))
}
other => Ok(other),
});
if let Ok(next) = transformed {
*expr = next;
}
}
fn positional_reference(expr: &Expression) -> Option<usize> {
match expr {
Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)) => {
let Literal::Number(value) = lit.as_ref() else {
unreachable!()
};
value.parse::<usize>().ok()
}
_ => None,
}
}
fn select_expression_at_position(
projections: &[Expression],
index: usize,
) -> QualifyColumnsResult<Expression> {
if index == 0 || index > projections.len() {
return Err(QualifyColumnsError::UnknownOutputColumn(index.to_string()));
}
let projection = projections[index - 1].clone();
Ok(match projection {
Expression::Alias(alias) => alias.this.clone(),
other => other,
})
}
fn get_reserved_words(dialect: Option<DialectType>) -> HashSet<&'static str> {
let mut words: HashSet<&'static str> = [
"ADD",
"ALL",
"ALTER",
"AND",
"ANY",
"AS",
"ASC",
"BETWEEN",
"BY",
"CASE",
"CAST",
"CHECK",
"COLUMN",
"CONSTRAINT",
"CREATE",
"CROSS",
"CURRENT",
"CURRENT_DATE",
"CURRENT_TIME",
"CURRENT_TIMESTAMP",
"CURRENT_USER",
"DATABASE",
"DEFAULT",
"DELETE",
"DESC",
"DISTINCT",
"DROP",
"ELSE",
"END",
"ESCAPE",
"EXCEPT",
"EXISTS",
"FALSE",
"FETCH",
"FOR",
"FOREIGN",
"FROM",
"FULL",
"GRANT",
"GROUP",
"HAVING",
"IF",
"IN",
"INDEX",
"INNER",
"INSERT",
"INTERSECT",
"INTO",
"IS",
"JOIN",
"KEY",
"LEFT",
"LIKE",
"LIMIT",
"NATURAL",
"NOT",
"NULL",
"OFFSET",
"ON",
"OR",
"ORDER",
"OUTER",
"PRIMARY",
"REFERENCES",
"REPLACE",
"RETURNING",
"RIGHT",
"ROLLBACK",
"ROW",
"ROWS",
"SELECT",
"SESSION_USER",
"SET",
"SOME",
"TABLE",
"THEN",
"TO",
"TRUE",
"TRUNCATE",
"UNION",
"UNIQUE",
"UPDATE",
"USING",
"VALUES",
"VIEW",
"WHEN",
"WHERE",
"WINDOW",
"WITH",
]
.iter()
.copied()
.collect();
match dialect {
Some(DialectType::MySQL) => {
words.extend(
[
"ANALYZE",
"BOTH",
"CHANGE",
"CONDITION",
"DATABASES",
"DAY_HOUR",
"DAY_MICROSECOND",
"DAY_MINUTE",
"DAY_SECOND",
"DELAYED",
"DETERMINISTIC",
"DIV",
"DUAL",
"EACH",
"ELSEIF",
"ENCLOSED",
"EXPLAIN",
"FLOAT4",
"FLOAT8",
"FORCE",
"HOUR_MICROSECOND",
"HOUR_MINUTE",
"HOUR_SECOND",
"IGNORE",
"INFILE",
"INT1",
"INT2",
"INT3",
"INT4",
"INT8",
"ITERATE",
"KEYS",
"KILL",
"LEADING",
"LEAVE",
"LINES",
"LOAD",
"LOCK",
"LONG",
"LONGBLOB",
"LONGTEXT",
"LOOP",
"LOW_PRIORITY",
"MATCH",
"MEDIUMBLOB",
"MEDIUMINT",
"MEDIUMTEXT",
"MINUTE_MICROSECOND",
"MINUTE_SECOND",
"MOD",
"MODIFIES",
"NO_WRITE_TO_BINLOG",
"OPTIMIZE",
"OPTIONALLY",
"OUT",
"OUTFILE",
"PURGE",
"READS",
"REGEXP",
"RELEASE",
"RENAME",
"REPEAT",
"REQUIRE",
"RESIGNAL",
"RETURN",
"REVOKE",
"RLIKE",
"SCHEMA",
"SCHEMAS",
"SECOND_MICROSECOND",
"SENSITIVE",
"SEPARATOR",
"SHOW",
"SIGNAL",
"SPATIAL",
"SQL",
"SQLEXCEPTION",
"SQLSTATE",
"SQLWARNING",
"SQL_BIG_RESULT",
"SQL_CALC_FOUND_ROWS",
"SQL_SMALL_RESULT",
"SSL",
"STARTING",
"STRAIGHT_JOIN",
"TERMINATED",
"TINYBLOB",
"TINYINT",
"TINYTEXT",
"TRAILING",
"TRIGGER",
"UNDO",
"UNLOCK",
"UNSIGNED",
"USAGE",
"UTC_DATE",
"UTC_TIME",
"UTC_TIMESTAMP",
"VARBINARY",
"VARCHARACTER",
"WHILE",
"WRITE",
"XOR",
"YEAR_MONTH",
"ZEROFILL",
]
.iter()
.copied(),
);
}
Some(DialectType::PostgreSQL) | Some(DialectType::CockroachDB) => {
words.extend(
[
"ANALYSE",
"ANALYZE",
"ARRAY",
"AUTHORIZATION",
"BINARY",
"BOTH",
"COLLATE",
"CONCURRENTLY",
"DO",
"FREEZE",
"ILIKE",
"INITIALLY",
"ISNULL",
"LATERAL",
"LEADING",
"LOCALTIME",
"LOCALTIMESTAMP",
"NOTNULL",
"ONLY",
"OVERLAPS",
"PLACING",
"SIMILAR",
"SYMMETRIC",
"TABLESAMPLE",
"TRAILING",
"VARIADIC",
"VERBOSE",
]
.iter()
.copied(),
);
}
Some(DialectType::BigQuery) => {
words.extend(
[
"ASSERT_ROWS_MODIFIED",
"COLLATE",
"CONTAINS",
"CUBE",
"DEFINE",
"ENUM",
"EXTRACT",
"FOLLOWING",
"GROUPING",
"GROUPS",
"HASH",
"IGNORE",
"LATERAL",
"LOOKUP",
"MERGE",
"NEW",
"NO",
"NULLS",
"OF",
"OVER",
"PARTITION",
"PRECEDING",
"PROTO",
"RANGE",
"RECURSIVE",
"RESPECT",
"ROLLUP",
"STRUCT",
"TABLESAMPLE",
"TREAT",
"UNBOUNDED",
"WITHIN",
]
.iter()
.copied(),
);
}
Some(DialectType::Snowflake) => {
words.extend(
[
"ACCOUNT",
"BOTH",
"CONNECT",
"FOLLOWING",
"ILIKE",
"INCREMENT",
"ISSUE",
"LATERAL",
"LEADING",
"LOCALTIME",
"LOCALTIMESTAMP",
"MINUS",
"QUALIFY",
"REGEXP",
"RLIKE",
"SOME",
"START",
"TABLESAMPLE",
"TOP",
"TRAILING",
"TRY_CAST",
]
.iter()
.copied(),
);
}
Some(DialectType::TSQL) | Some(DialectType::Fabric) => {
words.extend(
[
"BACKUP",
"BREAK",
"BROWSE",
"BULK",
"CASCADE",
"CHECKPOINT",
"CLOSE",
"CLUSTERED",
"COALESCE",
"COMPUTE",
"CONTAINS",
"CONTAINSTABLE",
"CONTINUE",
"CONVERT",
"DBCC",
"DEALLOCATE",
"DENY",
"DISK",
"DISTRIBUTED",
"DUMP",
"ERRLVL",
"EXEC",
"EXECUTE",
"EXIT",
"EXTERNAL",
"FILE",
"FILLFACTOR",
"FREETEXT",
"FREETEXTTABLE",
"FUNCTION",
"GOTO",
"HOLDLOCK",
"IDENTITY",
"IDENTITYCOL",
"IDENTITY_INSERT",
"KILL",
"LINENO",
"MERGE",
"NONCLUSTERED",
"NULLIF",
"OF",
"OFF",
"OFFSETS",
"OPEN",
"OPENDATASOURCE",
"OPENQUERY",
"OPENROWSET",
"OPENXML",
"OVER",
"PERCENT",
"PIVOT",
"PLAN",
"PRINT",
"PROC",
"PROCEDURE",
"PUBLIC",
"RAISERROR",
"READ",
"READTEXT",
"RECONFIGURE",
"REPLICATION",
"RESTORE",
"RESTRICT",
"REVERT",
"ROWCOUNT",
"ROWGUIDCOL",
"RULE",
"SAVE",
"SECURITYAUDIT",
"SEMANTICKEYPHRASETABLE",
"SEMANTICSIMILARITYDETAILSTABLE",
"SEMANTICSIMILARITYTABLE",
"SETUSER",
"SHUTDOWN",
"STATISTICS",
"SYSTEM_USER",
"TEXTSIZE",
"TOP",
"TRAN",
"TRANSACTION",
"TRIGGER",
"TSEQUAL",
"UNPIVOT",
"UPDATETEXT",
"WAITFOR",
"WRITETEXT",
]
.iter()
.copied(),
);
}
Some(DialectType::ClickHouse) => {
words.extend(
[
"ANTI",
"ARRAY",
"ASOF",
"FINAL",
"FORMAT",
"GLOBAL",
"INF",
"KILL",
"MATERIALIZED",
"NAN",
"PREWHERE",
"SAMPLE",
"SEMI",
"SETTINGS",
"TOP",
]
.iter()
.copied(),
);
}
Some(DialectType::DuckDB) => {
words.extend(
[
"ANALYSE",
"ANALYZE",
"ARRAY",
"BOTH",
"LATERAL",
"LEADING",
"LOCALTIME",
"LOCALTIMESTAMP",
"PLACING",
"QUALIFY",
"SIMILAR",
"TABLESAMPLE",
"TRAILING",
]
.iter()
.copied(),
);
}
Some(DialectType::Hive) | Some(DialectType::Spark) | Some(DialectType::Databricks) => {
words.extend(
[
"BOTH",
"CLUSTER",
"DISTRIBUTE",
"EXCHANGE",
"EXTENDED",
"FUNCTION",
"LATERAL",
"LEADING",
"MACRO",
"OVER",
"PARTITION",
"PERCENT",
"RANGE",
"READS",
"REDUCE",
"REGEXP",
"REVOKE",
"RLIKE",
"ROLLUP",
"SEMI",
"SORT",
"TABLESAMPLE",
"TRAILING",
"TRANSFORM",
"UNBOUNDED",
"UNIQUEJOIN",
]
.iter()
.copied(),
);
}
Some(DialectType::Trino) | Some(DialectType::Presto) | Some(DialectType::Athena) => {
words.extend(
[
"CUBE",
"DEALLOCATE",
"DESCRIBE",
"EXECUTE",
"EXTRACT",
"GROUPING",
"LATERAL",
"LOCALTIME",
"LOCALTIMESTAMP",
"NORMALIZE",
"PREPARE",
"ROLLUP",
"SOME",
"TABLESAMPLE",
"UESCAPE",
"UNNEST",
]
.iter()
.copied(),
);
}
Some(DialectType::Oracle) => {
words.extend(
[
"ACCESS",
"AUDIT",
"CLUSTER",
"COMMENT",
"COMPRESS",
"CONNECT",
"EXCLUSIVE",
"FILE",
"IDENTIFIED",
"IMMEDIATE",
"INCREMENT",
"INITIAL",
"LEVEL",
"LOCK",
"LONG",
"MAXEXTENTS",
"MINUS",
"MODE",
"NOAUDIT",
"NOCOMPRESS",
"NOWAIT",
"NUMBER",
"OF",
"OFFLINE",
"ONLINE",
"PCTFREE",
"PRIOR",
"RAW",
"RENAME",
"RESOURCE",
"REVOKE",
"SHARE",
"SIZE",
"START",
"SUCCESSFUL",
"SYNONYM",
"SYSDATE",
"TRIGGER",
"UID",
"VALIDATE",
"VARCHAR2",
"WHENEVER",
]
.iter()
.copied(),
);
}
Some(DialectType::Redshift) => {
words.extend(
[
"AZ64",
"BZIP2",
"DELTA",
"DELTA32K",
"DISTSTYLE",
"ENCODE",
"GZIP",
"ILIKE",
"LIMIT",
"LUNS",
"LZO",
"LZOP",
"MOSTLY13",
"MOSTLY32",
"MOSTLY8",
"RAW",
"SIMILAR",
"SNAPSHOT",
"SORTKEY",
"SYSDATE",
"TOP",
"ZSTD",
]
.iter()
.copied(),
);
}
_ => {
words.extend(
[
"ANALYZE",
"ARRAY",
"BOTH",
"CUBE",
"GROUPING",
"LATERAL",
"LEADING",
"LOCALTIME",
"LOCALTIMESTAMP",
"OVER",
"PARTITION",
"QUALIFY",
"RANGE",
"ROLLUP",
"SIMILAR",
"SOME",
"TABLESAMPLE",
"TRAILING",
]
.iter()
.copied(),
);
}
}
words
}
fn needs_quoting(name: &str, reserved_words: &HashSet<&str>) -> bool {
if name.is_empty() {
return false;
}
if name.as_bytes()[0].is_ascii_digit() {
return true;
}
if !name.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'_') {
return true;
}
let upper = name.to_uppercase();
reserved_words.contains(upper.as_str())
}
fn maybe_quote(id: &mut Identifier, reserved_words: &HashSet<&str>) {
if id.quoted || id.name.is_empty() || id.name == "*" {
return;
}
if needs_quoting(&id.name, reserved_words) {
id.quoted = true;
}
}
fn quote_identifiers_recursive(expr: &mut Expression, reserved_words: &HashSet<&str>) {
match expr {
Expression::Identifier(id) => {
maybe_quote(id, reserved_words);
}
Expression::Column(col) => {
maybe_quote(&mut col.name, reserved_words);
if let Some(ref mut table) = col.table {
maybe_quote(table, reserved_words);
}
}
Expression::Table(table_ref) => {
maybe_quote(&mut table_ref.name, reserved_words);
if let Some(ref mut schema) = table_ref.schema {
maybe_quote(schema, reserved_words);
}
if let Some(ref mut catalog) = table_ref.catalog {
maybe_quote(catalog, reserved_words);
}
if let Some(ref mut alias) = table_ref.alias {
maybe_quote(alias, reserved_words);
}
for ca in &mut table_ref.column_aliases {
maybe_quote(ca, reserved_words);
}
for p in &mut table_ref.partitions {
maybe_quote(p, reserved_words);
}
for h in &mut table_ref.hints {
quote_identifiers_recursive(h, reserved_words);
}
if let Some(ref mut ver) = table_ref.version {
quote_identifiers_recursive(&mut ver.this, reserved_words);
if let Some(ref mut e) = ver.expression {
quote_identifiers_recursive(e, reserved_words);
}
}
}
Expression::Star(star) => {
if let Some(ref mut table) = star.table {
maybe_quote(table, reserved_words);
}
if let Some(ref mut except_ids) = star.except {
for id in except_ids {
maybe_quote(id, reserved_words);
}
}
if let Some(ref mut replace_aliases) = star.replace {
for alias in replace_aliases {
maybe_quote(&mut alias.alias, reserved_words);
quote_identifiers_recursive(&mut alias.this, reserved_words);
}
}
if let Some(ref mut rename_pairs) = star.rename {
for (from, to) in rename_pairs {
maybe_quote(from, reserved_words);
maybe_quote(to, reserved_words);
}
}
}
Expression::Alias(alias) => {
maybe_quote(&mut alias.alias, reserved_words);
for ca in &mut alias.column_aliases {
maybe_quote(ca, reserved_words);
}
quote_identifiers_recursive(&mut alias.this, reserved_words);
}
Expression::Select(select) => {
for e in &mut select.expressions {
quote_identifiers_recursive(e, reserved_words);
}
if let Some(ref mut from) = select.from {
for e in &mut from.expressions {
quote_identifiers_recursive(e, reserved_words);
}
}
for join in &mut select.joins {
quote_join(join, reserved_words);
}
for lv in &mut select.lateral_views {
quote_lateral_view(lv, reserved_words);
}
if let Some(ref mut prewhere) = select.prewhere {
quote_identifiers_recursive(prewhere, reserved_words);
}
if let Some(ref mut wh) = select.where_clause {
quote_identifiers_recursive(&mut wh.this, reserved_words);
}
if let Some(ref mut gb) = select.group_by {
for e in &mut gb.expressions {
quote_identifiers_recursive(e, reserved_words);
}
}
if let Some(ref mut hv) = select.having {
quote_identifiers_recursive(&mut hv.this, reserved_words);
}
if let Some(ref mut q) = select.qualify {
quote_identifiers_recursive(&mut q.this, reserved_words);
}
if let Some(ref mut ob) = select.order_by {
for o in &mut ob.expressions {
quote_identifiers_recursive(&mut o.this, reserved_words);
}
}
if let Some(ref mut lim) = select.limit {
quote_identifiers_recursive(&mut lim.this, reserved_words);
}
if let Some(ref mut off) = select.offset {
quote_identifiers_recursive(&mut off.this, reserved_words);
}
if let Some(ref mut with) = select.with {
quote_with(with, reserved_words);
}
if let Some(ref mut windows) = select.windows {
for nw in windows {
maybe_quote(&mut nw.name, reserved_words);
quote_over(&mut nw.spec, reserved_words);
}
}
if let Some(ref mut distinct_on) = select.distinct_on {
for e in distinct_on {
quote_identifiers_recursive(e, reserved_words);
}
}
if let Some(ref mut limit_by) = select.limit_by {
for e in limit_by {
quote_identifiers_recursive(e, reserved_words);
}
}
if let Some(ref mut settings) = select.settings {
for e in settings {
quote_identifiers_recursive(e, reserved_words);
}
}
if let Some(ref mut format) = select.format {
quote_identifiers_recursive(format, reserved_words);
}
}
Expression::Union(u) => {
quote_identifiers_recursive(&mut u.left, reserved_words);
quote_identifiers_recursive(&mut u.right, reserved_words);
if let Some(ref mut with) = u.with {
quote_with(with, reserved_words);
}
if let Some(ref mut ob) = u.order_by {
for o in &mut ob.expressions {
quote_identifiers_recursive(&mut o.this, reserved_words);
}
}
if let Some(ref mut lim) = u.limit {
quote_identifiers_recursive(lim, reserved_words);
}
if let Some(ref mut off) = u.offset {
quote_identifiers_recursive(off, reserved_words);
}
}
Expression::Intersect(i) => {
quote_identifiers_recursive(&mut i.left, reserved_words);
quote_identifiers_recursive(&mut i.right, reserved_words);
if let Some(ref mut with) = i.with {
quote_with(with, reserved_words);
}
if let Some(ref mut ob) = i.order_by {
for o in &mut ob.expressions {
quote_identifiers_recursive(&mut o.this, reserved_words);
}
}
}
Expression::Except(e) => {
quote_identifiers_recursive(&mut e.left, reserved_words);
quote_identifiers_recursive(&mut e.right, reserved_words);
if let Some(ref mut with) = e.with {
quote_with(with, reserved_words);
}
if let Some(ref mut ob) = e.order_by {
for o in &mut ob.expressions {
quote_identifiers_recursive(&mut o.this, reserved_words);
}
}
}
Expression::Subquery(sq) => {
quote_identifiers_recursive(&mut sq.this, reserved_words);
if let Some(ref mut alias) = sq.alias {
maybe_quote(alias, reserved_words);
}
for ca in &mut sq.column_aliases {
maybe_quote(ca, reserved_words);
}
if let Some(ref mut ob) = sq.order_by {
for o in &mut ob.expressions {
quote_identifiers_recursive(&mut o.this, reserved_words);
}
}
}
Expression::Insert(ins) => {
quote_table_ref(&mut ins.table, reserved_words);
for c in &mut ins.columns {
maybe_quote(c, reserved_words);
}
for row in &mut ins.values {
for e in row {
quote_identifiers_recursive(e, reserved_words);
}
}
if let Some(ref mut q) = ins.query {
quote_identifiers_recursive(q, reserved_words);
}
for (id, val) in &mut ins.partition {
maybe_quote(id, reserved_words);
if let Some(ref mut v) = val {
quote_identifiers_recursive(v, reserved_words);
}
}
for e in &mut ins.returning {
quote_identifiers_recursive(e, reserved_words);
}
if let Some(ref mut on_conflict) = ins.on_conflict {
quote_identifiers_recursive(on_conflict, reserved_words);
}
if let Some(ref mut with) = ins.with {
quote_with(with, reserved_words);
}
if let Some(ref mut alias) = ins.alias {
maybe_quote(alias, reserved_words);
}
if let Some(ref mut src_alias) = ins.source_alias {
maybe_quote(src_alias, reserved_words);
}
}
Expression::Update(upd) => {
quote_table_ref(&mut upd.table, reserved_words);
for tr in &mut upd.extra_tables {
quote_table_ref(tr, reserved_words);
}
for join in &mut upd.table_joins {
quote_join(join, reserved_words);
}
for (id, val) in &mut upd.set {
maybe_quote(id, reserved_words);
quote_identifiers_recursive(val, reserved_words);
}
if let Some(ref mut from) = upd.from_clause {
for e in &mut from.expressions {
quote_identifiers_recursive(e, reserved_words);
}
}
for join in &mut upd.from_joins {
quote_join(join, reserved_words);
}
if let Some(ref mut wh) = upd.where_clause {
quote_identifiers_recursive(&mut wh.this, reserved_words);
}
for e in &mut upd.returning {
quote_identifiers_recursive(e, reserved_words);
}
if let Some(ref mut with) = upd.with {
quote_with(with, reserved_words);
}
}
Expression::Delete(del) => {
quote_table_ref(&mut del.table, reserved_words);
if let Some(ref mut alias) = del.alias {
maybe_quote(alias, reserved_words);
}
for tr in &mut del.using {
quote_table_ref(tr, reserved_words);
}
if let Some(ref mut wh) = del.where_clause {
quote_identifiers_recursive(&mut wh.this, reserved_words);
}
if let Some(ref mut with) = del.with {
quote_with(with, reserved_words);
}
}
Expression::And(bin)
| Expression::Or(bin)
| Expression::Eq(bin)
| Expression::Neq(bin)
| Expression::Lt(bin)
| Expression::Lte(bin)
| Expression::Gt(bin)
| Expression::Gte(bin)
| Expression::Add(bin)
| Expression::Sub(bin)
| Expression::Mul(bin)
| Expression::Div(bin)
| Expression::Mod(bin)
| Expression::BitwiseAnd(bin)
| Expression::BitwiseOr(bin)
| Expression::BitwiseXor(bin)
| Expression::Concat(bin)
| Expression::Adjacent(bin)
| Expression::TsMatch(bin)
| Expression::PropertyEQ(bin)
| Expression::ArrayContainsAll(bin)
| Expression::ArrayContainedBy(bin)
| Expression::ArrayOverlaps(bin)
| Expression::JSONBContainsAllTopKeys(bin)
| Expression::JSONBContainsAnyTopKeys(bin)
| Expression::JSONBDeleteAtPath(bin)
| Expression::ExtendsLeft(bin)
| Expression::ExtendsRight(bin)
| Expression::Is(bin)
| Expression::NullSafeEq(bin)
| Expression::NullSafeNeq(bin)
| Expression::Glob(bin)
| Expression::Match(bin)
| Expression::MemberOf(bin)
| Expression::BitwiseLeftShift(bin)
| Expression::BitwiseRightShift(bin) => {
quote_identifiers_recursive(&mut bin.left, reserved_words);
quote_identifiers_recursive(&mut bin.right, reserved_words);
}
Expression::Like(like) | Expression::ILike(like) => {
quote_identifiers_recursive(&mut like.left, reserved_words);
quote_identifiers_recursive(&mut like.right, reserved_words);
if let Some(ref mut esc) = like.escape {
quote_identifiers_recursive(esc, reserved_words);
}
}
Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
quote_identifiers_recursive(&mut un.this, reserved_words);
}
Expression::In(in_expr) => {
quote_identifiers_recursive(&mut in_expr.this, reserved_words);
for e in &mut in_expr.expressions {
quote_identifiers_recursive(e, reserved_words);
}
if let Some(ref mut q) = in_expr.query {
quote_identifiers_recursive(q, reserved_words);
}
if let Some(ref mut un) = in_expr.unnest {
quote_identifiers_recursive(un, reserved_words);
}
}
Expression::Between(bw) => {
quote_identifiers_recursive(&mut bw.this, reserved_words);
quote_identifiers_recursive(&mut bw.low, reserved_words);
quote_identifiers_recursive(&mut bw.high, reserved_words);
}
Expression::IsNull(is_null) => {
quote_identifiers_recursive(&mut is_null.this, reserved_words);
}
Expression::IsTrue(is_tf) | Expression::IsFalse(is_tf) => {
quote_identifiers_recursive(&mut is_tf.this, reserved_words);
}
Expression::Exists(ex) => {
quote_identifiers_recursive(&mut ex.this, reserved_words);
}
Expression::Function(func) => {
for arg in &mut func.args {
quote_identifiers_recursive(arg, reserved_words);
}
}
Expression::AggregateFunction(agg) => {
for arg in &mut agg.args {
quote_identifiers_recursive(arg, reserved_words);
}
if let Some(ref mut filter) = agg.filter {
quote_identifiers_recursive(filter, reserved_words);
}
for o in &mut agg.order_by {
quote_identifiers_recursive(&mut o.this, reserved_words);
}
}
Expression::WindowFunction(wf) => {
quote_identifiers_recursive(&mut wf.this, reserved_words);
quote_over(&mut wf.over, reserved_words);
}
Expression::Case(case) => {
if let Some(ref mut operand) = case.operand {
quote_identifiers_recursive(operand, reserved_words);
}
for (when, then) in &mut case.whens {
quote_identifiers_recursive(when, reserved_words);
quote_identifiers_recursive(then, reserved_words);
}
if let Some(ref mut else_) = case.else_ {
quote_identifiers_recursive(else_, reserved_words);
}
}
Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
quote_identifiers_recursive(&mut cast.this, reserved_words);
if let Some(ref mut fmt) = cast.format {
quote_identifiers_recursive(fmt, reserved_words);
}
}
Expression::Paren(paren) => {
quote_identifiers_recursive(&mut paren.this, reserved_words);
}
Expression::Annotated(ann) => {
quote_identifiers_recursive(&mut ann.this, reserved_words);
}
Expression::With(with) => {
quote_with(with, reserved_words);
}
Expression::Cte(cte) => {
maybe_quote(&mut cte.alias, reserved_words);
for c in &mut cte.columns {
maybe_quote(c, reserved_words);
}
quote_identifiers_recursive(&mut cte.this, reserved_words);
}
Expression::From(from) => {
for e in &mut from.expressions {
quote_identifiers_recursive(e, reserved_words);
}
}
Expression::Join(join) => {
quote_join(join, reserved_words);
}
Expression::JoinedTable(jt) => {
quote_identifiers_recursive(&mut jt.left, reserved_words);
for join in &mut jt.joins {
quote_join(join, reserved_words);
}
if let Some(ref mut alias) = jt.alias {
maybe_quote(alias, reserved_words);
}
}
Expression::Where(wh) => {
quote_identifiers_recursive(&mut wh.this, reserved_words);
}
Expression::GroupBy(gb) => {
for e in &mut gb.expressions {
quote_identifiers_recursive(e, reserved_words);
}
}
Expression::Having(hv) => {
quote_identifiers_recursive(&mut hv.this, reserved_words);
}
Expression::OrderBy(ob) => {
for o in &mut ob.expressions {
quote_identifiers_recursive(&mut o.this, reserved_words);
}
}
Expression::Ordered(ord) => {
quote_identifiers_recursive(&mut ord.this, reserved_words);
}
Expression::Limit(lim) => {
quote_identifiers_recursive(&mut lim.this, reserved_words);
}
Expression::Offset(off) => {
quote_identifiers_recursive(&mut off.this, reserved_words);
}
Expression::Qualify(q) => {
quote_identifiers_recursive(&mut q.this, reserved_words);
}
Expression::Window(ws) => {
for e in &mut ws.partition_by {
quote_identifiers_recursive(e, reserved_words);
}
for o in &mut ws.order_by {
quote_identifiers_recursive(&mut o.this, reserved_words);
}
}
Expression::Over(over) => {
quote_over(over, reserved_words);
}
Expression::WithinGroup(wg) => {
quote_identifiers_recursive(&mut wg.this, reserved_words);
for o in &mut wg.order_by {
quote_identifiers_recursive(&mut o.this, reserved_words);
}
}
Expression::Pivot(piv) => {
quote_identifiers_recursive(&mut piv.this, reserved_words);
for e in &mut piv.expressions {
quote_identifiers_recursive(e, reserved_words);
}
for f in &mut piv.fields {
quote_identifiers_recursive(f, reserved_words);
}
if let Some(ref mut alias) = piv.alias {
maybe_quote(alias, reserved_words);
}
}
Expression::Unpivot(unpiv) => {
quote_identifiers_recursive(&mut unpiv.this, reserved_words);
maybe_quote(&mut unpiv.value_column, reserved_words);
maybe_quote(&mut unpiv.name_column, reserved_words);
for e in &mut unpiv.columns {
quote_identifiers_recursive(e, reserved_words);
}
if let Some(ref mut alias) = unpiv.alias {
maybe_quote(alias, reserved_words);
}
}
Expression::Values(vals) => {
for tuple in &mut vals.expressions {
for e in &mut tuple.expressions {
quote_identifiers_recursive(e, reserved_words);
}
}
if let Some(ref mut alias) = vals.alias {
maybe_quote(alias, reserved_words);
}
for ca in &mut vals.column_aliases {
maybe_quote(ca, reserved_words);
}
}
Expression::Array(arr) => {
for e in &mut arr.expressions {
quote_identifiers_recursive(e, reserved_words);
}
}
Expression::Struct(st) => {
for (_name, e) in &mut st.fields {
quote_identifiers_recursive(e, reserved_words);
}
}
Expression::Tuple(tup) => {
for e in &mut tup.expressions {
quote_identifiers_recursive(e, reserved_words);
}
}
Expression::Subscript(sub) => {
quote_identifiers_recursive(&mut sub.this, reserved_words);
quote_identifiers_recursive(&mut sub.index, reserved_words);
}
Expression::Dot(dot) => {
quote_identifiers_recursive(&mut dot.this, reserved_words);
maybe_quote(&mut dot.field, reserved_words);
}
Expression::ScopeResolution(sr) => {
if let Some(ref mut this) = sr.this {
quote_identifiers_recursive(this, reserved_words);
}
quote_identifiers_recursive(&mut sr.expression, reserved_words);
}
Expression::Lateral(lat) => {
quote_identifiers_recursive(&mut lat.this, reserved_words);
}
Expression::DPipe(dpipe) => {
quote_identifiers_recursive(&mut dpipe.this, reserved_words);
quote_identifiers_recursive(&mut dpipe.expression, reserved_words);
}
Expression::Merge(merge) => {
quote_identifiers_recursive(&mut merge.this, reserved_words);
quote_identifiers_recursive(&mut merge.using, reserved_words);
if let Some(ref mut on) = merge.on {
quote_identifiers_recursive(on, reserved_words);
}
if let Some(ref mut whens) = merge.whens {
quote_identifiers_recursive(whens, reserved_words);
}
if let Some(ref mut with) = merge.with_ {
quote_identifiers_recursive(with, reserved_words);
}
if let Some(ref mut ret) = merge.returning {
quote_identifiers_recursive(ret, reserved_words);
}
}
Expression::LateralView(lv) => {
quote_lateral_view(lv, reserved_words);
}
Expression::Anonymous(anon) => {
quote_identifiers_recursive(&mut anon.this, reserved_words);
for e in &mut anon.expressions {
quote_identifiers_recursive(e, reserved_words);
}
}
Expression::Filter(filter) => {
quote_identifiers_recursive(&mut filter.this, reserved_words);
quote_identifiers_recursive(&mut filter.expression, reserved_words);
}
Expression::Returning(ret) => {
for e in &mut ret.expressions {
quote_identifiers_recursive(e, reserved_words);
}
}
Expression::BracedWildcard(inner) => {
quote_identifiers_recursive(inner, reserved_words);
}
Expression::ReturnStmt(inner) => {
quote_identifiers_recursive(inner, reserved_words);
}
Expression::Literal(_)
| Expression::Boolean(_)
| Expression::Null(_)
| Expression::DataType(_)
| Expression::Raw(_)
| Expression::Placeholder(_)
| Expression::CurrentDate(_)
| Expression::CurrentTime(_)
| Expression::CurrentTimestamp(_)
| Expression::CurrentTimestampLTZ(_)
| Expression::SessionUser(_)
| Expression::RowNumber(_)
| Expression::Rank(_)
| Expression::DenseRank(_)
| Expression::PercentRank(_)
| Expression::CumeDist(_)
| Expression::Random(_)
| Expression::Pi(_)
| Expression::JSONPathRoot(_) => {
}
_ => {}
}
}
fn quote_join(join: &mut Join, reserved_words: &HashSet<&str>) {
quote_identifiers_recursive(&mut join.this, reserved_words);
if let Some(ref mut on) = join.on {
quote_identifiers_recursive(on, reserved_words);
}
for id in &mut join.using {
maybe_quote(id, reserved_words);
}
if let Some(ref mut mc) = join.match_condition {
quote_identifiers_recursive(mc, reserved_words);
}
for piv in &mut join.pivots {
quote_identifiers_recursive(piv, reserved_words);
}
}
fn quote_with(with: &mut With, reserved_words: &HashSet<&str>) {
for cte in &mut with.ctes {
maybe_quote(&mut cte.alias, reserved_words);
for c in &mut cte.columns {
maybe_quote(c, reserved_words);
}
for k in &mut cte.key_expressions {
maybe_quote(k, reserved_words);
}
quote_identifiers_recursive(&mut cte.this, reserved_words);
}
}
fn quote_over(over: &mut Over, reserved_words: &HashSet<&str>) {
if let Some(ref mut wn) = over.window_name {
maybe_quote(wn, reserved_words);
}
for e in &mut over.partition_by {
quote_identifiers_recursive(e, reserved_words);
}
for o in &mut over.order_by {
quote_identifiers_recursive(&mut o.this, reserved_words);
}
if let Some(ref mut alias) = over.alias {
maybe_quote(alias, reserved_words);
}
}
fn quote_table_ref(table_ref: &mut TableRef, reserved_words: &HashSet<&str>) {
maybe_quote(&mut table_ref.name, reserved_words);
if let Some(ref mut schema) = table_ref.schema {
maybe_quote(schema, reserved_words);
}
if let Some(ref mut catalog) = table_ref.catalog {
maybe_quote(catalog, reserved_words);
}
if let Some(ref mut alias) = table_ref.alias {
maybe_quote(alias, reserved_words);
}
for ca in &mut table_ref.column_aliases {
maybe_quote(ca, reserved_words);
}
for p in &mut table_ref.partitions {
maybe_quote(p, reserved_words);
}
for h in &mut table_ref.hints {
quote_identifiers_recursive(h, reserved_words);
}
}
fn quote_lateral_view(lv: &mut LateralView, reserved_words: &HashSet<&str>) {
quote_identifiers_recursive(&mut lv.this, reserved_words);
if let Some(ref mut ta) = lv.table_alias {
maybe_quote(ta, reserved_words);
}
for ca in &mut lv.column_aliases {
maybe_quote(ca, reserved_words);
}
}
pub fn quote_identifiers(expression: Expression, dialect: Option<DialectType>) -> Expression {
let reserved_words = get_reserved_words(dialect);
let mut result = expression;
quote_identifiers_recursive(&mut result, &reserved_words);
result
}
pub fn pushdown_cte_alias_columns(_scope: &Scope) {
}
fn pushdown_cte_alias_columns_with(with: &mut With) {
for cte in &mut with.ctes {
if cte.columns.is_empty() {
continue;
}
if let Expression::Select(select) = &mut cte.this {
let mut next_expressions = Vec::with_capacity(select.expressions.len());
for (i, projection) in select.expressions.iter().enumerate() {
let Some(alias_name) = cte.columns.get(i) else {
next_expressions.push(projection.clone());
continue;
};
match projection {
Expression::Alias(existing) => {
let mut aliased = existing.clone();
aliased.alias = alias_name.clone();
next_expressions.push(Expression::Alias(aliased));
}
_ => {
next_expressions.push(create_alias(projection.clone(), &alias_name.name));
}
}
}
select.expressions = next_expressions;
}
}
}
fn get_scope_columns(scope: &Scope) -> Vec<ColumnRef> {
let mut columns = Vec::new();
collect_columns(&scope.expression, &mut columns);
columns
}
#[derive(Debug, Clone)]
struct ColumnRef {
table: Option<String>,
name: String,
}
fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
match expr {
Expression::Column(col) => {
columns.push(ColumnRef {
table: col.table.as_ref().map(|t| t.name.clone()),
name: col.name.name.clone(),
});
}
Expression::Select(select) => {
for e in &select.expressions {
collect_columns(e, columns);
}
if let Some(from) = &select.from {
for e in &from.expressions {
collect_columns(e, columns);
}
}
if let Some(where_clause) = &select.where_clause {
collect_columns(&where_clause.this, columns);
}
if let Some(group_by) = &select.group_by {
for e in &group_by.expressions {
collect_columns(e, columns);
}
}
if let Some(having) = &select.having {
collect_columns(&having.this, columns);
}
if let Some(order_by) = &select.order_by {
for o in &order_by.expressions {
collect_columns(&o.this, columns);
}
}
for join in &select.joins {
collect_columns(&join.this, columns);
if let Some(on) = &join.on {
collect_columns(on, columns);
}
}
}
Expression::Alias(alias) => {
collect_columns(&alias.this, columns);
}
Expression::Function(func) => {
for arg in &func.args {
collect_columns(arg, columns);
}
}
Expression::AggregateFunction(agg) => {
for arg in &agg.args {
collect_columns(arg, columns);
}
}
Expression::And(bin)
| Expression::Or(bin)
| Expression::Eq(bin)
| Expression::Neq(bin)
| Expression::Lt(bin)
| Expression::Lte(bin)
| Expression::Gt(bin)
| Expression::Gte(bin)
| Expression::Add(bin)
| Expression::Sub(bin)
| Expression::Mul(bin)
| Expression::Div(bin) => {
collect_columns(&bin.left, columns);
collect_columns(&bin.right, columns);
}
Expression::Not(unary) | Expression::Neg(unary) => {
collect_columns(&unary.this, columns);
}
Expression::Paren(paren) => {
collect_columns(&paren.this, columns);
}
Expression::Case(case) => {
if let Some(operand) = &case.operand {
collect_columns(operand, columns);
}
for (when, then) in &case.whens {
collect_columns(when, columns);
collect_columns(then, columns);
}
if let Some(else_) = &case.else_ {
collect_columns(else_, columns);
}
}
Expression::Cast(cast) => {
collect_columns(&cast.this, columns);
}
Expression::In(in_expr) => {
collect_columns(&in_expr.this, columns);
for e in &in_expr.expressions {
collect_columns(e, columns);
}
if let Some(query) = &in_expr.query {
collect_columns(query, columns);
}
}
Expression::Between(between) => {
collect_columns(&between.this, columns);
collect_columns(&between.low, columns);
collect_columns(&between.high, columns);
}
Expression::Subquery(subquery) => {
collect_columns(&subquery.this, columns);
}
_ => {}
}
}
fn get_unqualified_columns(scope: &Scope) -> Vec<ColumnRef> {
get_scope_columns(scope)
.into_iter()
.filter(|c| c.table.is_none())
.collect()
}
fn get_external_columns(scope: &Scope) -> Vec<ColumnRef> {
let source_names: HashSet<_> = scope.sources.keys().cloned().collect();
get_scope_columns(scope)
.into_iter()
.filter(|c| {
if let Some(table) = &c.table {
!source_names.contains(table)
} else {
false
}
})
.collect()
}
fn is_correlated_subquery(scope: &Scope) -> bool {
scope.can_be_correlated && !get_external_columns(scope).is_empty()
}
fn is_star_column(col: &Column) -> bool {
col.name.name == "*"
}
fn create_qualified_column(name: &str, table: Option<&str>) -> Expression {
Expression::boxed_column(Column {
name: Identifier::new(name),
table: table.map(Identifier::new),
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
})
}
fn create_alias(expr: Expression, alias_name: &str) -> Expression {
Expression::Alias(Box::new(Alias {
this: expr,
alias: Identifier::new(alias_name),
column_aliases: vec![],
pre_alias_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
}))
}
fn get_output_name(expr: &Expression) -> Option<String> {
match expr {
Expression::Column(col) => Some(col.name.name.clone()),
Expression::Alias(alias) => Some(alias.alias.name.clone()),
Expression::Identifier(id) => Some(id.name.clone()),
_ => None,
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::expressions::DataType;
use crate::generator::Generator;
use crate::parser::Parser;
use crate::scope::build_scope;
use crate::{MappingSchema, Schema};
fn gen(expr: &Expression) -> String {
Generator::new().generate(expr).unwrap()
}
fn parse(sql: &str) -> Expression {
Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
}
#[test]
fn test_qualify_columns_options() {
let options = QualifyColumnsOptions::new()
.with_expand_alias_refs(true)
.with_expand_stars(false)
.with_dialect(DialectType::PostgreSQL)
.with_allow_partial(true);
assert!(options.expand_alias_refs);
assert!(!options.expand_stars);
assert_eq!(options.dialect, Some(DialectType::PostgreSQL));
assert!(options.allow_partial_qualification);
}
#[test]
fn test_get_scope_columns() {
let expr = parse("SELECT a, b FROM t WHERE c = 1");
let scope = build_scope(&expr);
let columns = get_scope_columns(&scope);
assert!(columns.iter().any(|c| c.name == "a"));
assert!(columns.iter().any(|c| c.name == "b"));
assert!(columns.iter().any(|c| c.name == "c"));
}
#[test]
fn test_get_unqualified_columns() {
let expr = parse("SELECT t.a, b FROM t");
let scope = build_scope(&expr);
let unqualified = get_unqualified_columns(&scope);
assert!(unqualified.iter().any(|c| c.name == "b"));
assert!(!unqualified.iter().any(|c| c.name == "a"));
}
#[test]
fn test_is_star_column() {
let col = Column {
name: Identifier::new("*"),
table: Some(Identifier::new("t")),
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
};
assert!(is_star_column(&col));
let col2 = Column {
name: Identifier::new("id"),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
};
assert!(!is_star_column(&col2));
}
#[test]
fn test_create_qualified_column() {
let expr = create_qualified_column("id", Some("users"));
let sql = gen(&expr);
assert!(sql.contains("users"));
assert!(sql.contains("id"));
}
#[test]
fn test_create_alias() {
let col = Expression::boxed_column(Column {
name: Identifier::new("value"),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
});
let aliased = create_alias(col, "total");
let sql = gen(&aliased);
assert!(sql.contains("AS") || sql.contains("total"));
}
#[test]
fn test_validate_qualify_columns_success() {
let expr = parse("SELECT t.a, t.b FROM t");
let result = validate_qualify_columns(&expr);
let _ = result;
}
#[test]
fn test_collect_columns_nested() {
let expr = parse("SELECT a + b, c FROM t WHERE d > 0 GROUP BY e HAVING f = 1");
let mut columns = Vec::new();
collect_columns(&expr, &mut columns);
let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
assert!(names.contains(&"a"));
assert!(names.contains(&"b"));
assert!(names.contains(&"c"));
assert!(names.contains(&"d"));
assert!(names.contains(&"e"));
assert!(names.contains(&"f"));
}
#[test]
fn test_collect_columns_in_case() {
let expr = parse("SELECT CASE WHEN a = 1 THEN b ELSE c END FROM t");
let mut columns = Vec::new();
collect_columns(&expr, &mut columns);
let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
assert!(names.contains(&"a"));
assert!(names.contains(&"b"));
assert!(names.contains(&"c"));
}
#[test]
fn test_collect_columns_in_subquery() {
let expr = parse("SELECT a FROM t WHERE b IN (SELECT c FROM s)");
let mut columns = Vec::new();
collect_columns(&expr, &mut columns);
let names: Vec<_> = columns.iter().map(|c| c.name.as_str()).collect();
assert!(names.contains(&"a"));
assert!(names.contains(&"b"));
assert!(names.contains(&"c"));
}
#[test]
fn test_qualify_outputs_basic() {
let expr = parse("SELECT a, b + c FROM t");
let scope = build_scope(&expr);
let result = qualify_outputs(&scope);
assert!(result.is_ok());
}
#[test]
fn test_qualify_columns_expands_star_with_schema() {
let expr = parse("SELECT * FROM users");
let mut schema = MappingSchema::new();
schema
.add_table(
"users",
&[
(
"id".to_string(),
DataType::Int {
length: None,
integer_spelling: false,
},
),
("name".to_string(), DataType::Text),
("email".to_string(), DataType::Text),
],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(!sql.contains("SELECT *"));
assert!(sql.contains("users.id"));
assert!(sql.contains("users.name"));
assert!(sql.contains("users.email"));
}
#[test]
fn test_qualify_columns_expands_group_by_positions() {
let expr = parse("SELECT a, b FROM t GROUP BY 1, 2");
let mut schema = MappingSchema::new();
schema
.add_table(
"t",
&[
(
"a".to_string(),
DataType::Int {
length: None,
integer_spelling: false,
},
),
(
"b".to_string(),
DataType::Int {
length: None,
integer_spelling: false,
},
),
],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(!sql.contains("GROUP BY 1"));
assert!(!sql.contains("GROUP BY 2"));
assert!(sql.contains("GROUP BY"));
assert!(sql.contains("t.a"));
assert!(sql.contains("t.b"));
}
#[test]
fn test_expand_using_simple() {
let expr = parse("SELECT x.b FROM x JOIN y USING (b)");
let mut schema = MappingSchema::new();
schema
.add_table(
"x",
&[
("a".to_string(), DataType::BigInt { length: None }),
("b".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
schema
.add_table(
"y",
&[
("b".to_string(), DataType::BigInt { length: None }),
("c".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
!sql.contains("USING"),
"USING should be replaced with ON: {sql}"
);
assert!(
sql.contains("ON x.b = y.b"),
"ON condition should be x.b = y.b: {sql}"
);
assert!(sql.contains("SELECT x.b"), "SELECT should keep x.b: {sql}");
}
#[test]
fn test_expand_using_unqualified_coalesce() {
let expr = parse("SELECT b FROM x JOIN y USING(b)");
let mut schema = MappingSchema::new();
schema
.add_table(
"x",
&[
("a".to_string(), DataType::BigInt { length: None }),
("b".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
schema
.add_table(
"y",
&[
("b".to_string(), DataType::BigInt { length: None }),
("c".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("COALESCE(x.b, y.b)"),
"Unqualified USING column should become COALESCE: {sql}"
);
assert!(
sql.contains("AS b"),
"COALESCE should be aliased as 'b': {sql}"
);
assert!(
sql.contains("ON x.b = y.b"),
"ON condition should be generated: {sql}"
);
}
#[test]
fn test_expand_using_with_where() {
let expr = parse("SELECT b FROM x JOIN y USING(b) WHERE b = 1");
let mut schema = MappingSchema::new();
schema
.add_table(
"x",
&[("b".to_string(), DataType::BigInt { length: None })],
None,
)
.expect("schema setup");
schema
.add_table(
"y",
&[("b".to_string(), DataType::BigInt { length: None })],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("WHERE COALESCE(x.b, y.b)"),
"WHERE should use COALESCE for USING column: {sql}"
);
}
#[test]
fn test_expand_using_multi_join() {
let expr = parse("SELECT b FROM x JOIN y USING(b) JOIN z USING(b)");
let mut schema = MappingSchema::new();
for table in &["x", "y", "z"] {
schema
.add_table(
table,
&[("b".to_string(), DataType::BigInt { length: None })],
None,
)
.expect("schema setup");
}
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("COALESCE(x.b, y.b, z.b)"),
"Should have 3-table COALESCE: {sql}"
);
assert!(
sql.contains("ON x.b = y.b"),
"First join ON condition: {sql}"
);
}
#[test]
fn test_expand_using_multi_column() {
let expr = parse("SELECT b, c FROM y JOIN z USING(b, c)");
let mut schema = MappingSchema::new();
schema
.add_table(
"y",
&[
("b".to_string(), DataType::BigInt { length: None }),
("c".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
schema
.add_table(
"z",
&[
("b".to_string(), DataType::BigInt { length: None }),
("c".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("COALESCE(y.b, z.b)"),
"column 'b' should get COALESCE: {sql}"
);
assert!(
sql.contains("COALESCE(y.c, z.c)"),
"column 'c' should get COALESCE: {sql}"
);
assert!(
sql.contains("y.b = z.b") && sql.contains("y.c = z.c"),
"ON should have both equality conditions: {sql}"
);
}
#[test]
fn test_expand_using_star() {
let expr = parse("SELECT * FROM x JOIN y USING(b)");
let mut schema = MappingSchema::new();
schema
.add_table(
"x",
&[
("a".to_string(), DataType::BigInt { length: None }),
("b".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
schema
.add_table(
"y",
&[
("b".to_string(), DataType::BigInt { length: None }),
("c".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("COALESCE(x.b, y.b) AS b"),
"USING column should be COALESCE in star expansion: {sql}"
);
assert!(sql.contains("x.a"), "non-USING column a from x: {sql}");
assert!(sql.contains("y.c"), "non-USING column c from y: {sql}");
let coalesce_count = sql.matches("COALESCE").count();
assert_eq!(
coalesce_count, 1,
"b should appear only once as COALESCE: {sql}"
);
}
#[test]
fn test_expand_using_table_star() {
let expr = parse("SELECT x.* FROM x JOIN y USING(b)");
let mut schema = MappingSchema::new();
schema
.add_table(
"x",
&[
("a".to_string(), DataType::BigInt { length: None }),
("b".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
schema
.add_table(
"y",
&[
("b".to_string(), DataType::BigInt { length: None }),
("c".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("COALESCE(x.b, y.b)"),
"USING column from x.* should become COALESCE: {sql}"
);
assert!(sql.contains("x.a"), "non-USING column a: {sql}");
}
#[test]
fn test_qualify_columns_qualified_table_name() {
let expr = parse("SELECT a FROM raw.t1");
let mut schema = MappingSchema::new();
schema
.add_table(
"raw.t1",
&[("a".to_string(), DataType::BigInt { length: None })],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("t1.a"),
"column should be qualified with table name: {sql}"
);
let expr = parse("SELECT MAX(a) FROM raw.t1");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("t1.a"),
"column in function should be qualified with table name: {sql}"
);
let expr = parse("SELECT ABS(a) FROM raw.t1");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("t1.a"),
"column in function should be qualified with table name: {sql}"
);
}
#[test]
fn test_qualify_columns_count_star() {
let expr = parse("SELECT COUNT(*) FROM t1");
let mut schema = MappingSchema::new();
schema
.add_table(
"t1",
&[("id".to_string(), DataType::BigInt { length: None })],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("COUNT(*)"),
"COUNT(*) should be preserved: {sql}"
);
}
#[test]
fn test_qualify_columns_correlated_scalar_subquery() {
let expr =
parse("SELECT id, (SELECT AVG(val) FROM t2 WHERE t2.id = t1.id) AS avg_val FROM t1");
let mut schema = MappingSchema::new();
schema
.add_table(
"t1",
&[("id".to_string(), DataType::BigInt { length: None })],
None,
)
.expect("schema setup");
schema
.add_table(
"t2",
&[
("id".to_string(), DataType::BigInt { length: None }),
("val".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("t1.id"),
"outer column should be qualified: {sql}"
);
assert!(
sql.contains("t2.id"),
"inner column should be qualified: {sql}"
);
}
#[test]
fn test_qualify_columns_correlated_scalar_subquery_unqualified() {
let expr =
parse("SELECT t1_id, (SELECT AVG(val) FROM t2 WHERE t2_id = t1_id) AS avg_val FROM t1");
let mut schema = MappingSchema::new();
schema
.add_table(
"t1",
&[("t1_id".to_string(), DataType::BigInt { length: None })],
None,
)
.expect("schema setup");
schema
.add_table(
"t2",
&[
("t2_id".to_string(), DataType::BigInt { length: None }),
("val".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("t1.t1_id"),
"outer column should be qualified: {sql}"
);
assert!(
sql.contains("t2.t2_id"),
"inner column should be qualified: {sql}"
);
assert!(
sql.contains("= t1.t1_id"),
"correlated column should be qualified: {sql}"
);
}
#[test]
fn test_qualify_columns_correlated_exists_subquery() {
let expr = parse(
"SELECT o_orderpriority FROM orders \
WHERE EXISTS (SELECT * FROM lineitem WHERE l_orderkey = o_orderkey)",
);
let mut schema = MappingSchema::new();
schema
.add_table(
"orders",
&[
("o_orderpriority".to_string(), DataType::Text),
("o_orderkey".to_string(), DataType::BigInt { length: None }),
],
None,
)
.expect("schema setup");
schema
.add_table(
"lineitem",
&[("l_orderkey".to_string(), DataType::BigInt { length: None })],
None,
)
.expect("schema setup");
let result =
qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify");
let sql = gen(&result);
assert!(
sql.contains("orders.o_orderpriority"),
"outer column should be qualified: {sql}"
);
assert!(
sql.contains("lineitem.l_orderkey"),
"inner column should be qualified: {sql}"
);
assert!(
sql.contains("orders.o_orderkey"),
"correlated outer column should be qualified: {sql}"
);
}
#[test]
fn test_qualify_columns_rejects_unknown_table() {
let expr = parse("SELECT id FROM t1 WHERE nonexistent.col = 1");
let mut schema = MappingSchema::new();
schema
.add_table(
"t1",
&[("id".to_string(), DataType::BigInt { length: None })],
None,
)
.expect("schema setup");
let result = qualify_columns(expr, &schema, &QualifyColumnsOptions::new());
assert!(
result.is_err(),
"should reject reference to table not in scope or schema"
);
}
#[test]
fn test_needs_quoting_reserved_word() {
let reserved = get_reserved_words(None);
assert!(needs_quoting("select", &reserved));
assert!(needs_quoting("SELECT", &reserved));
assert!(needs_quoting("from", &reserved));
assert!(needs_quoting("WHERE", &reserved));
assert!(needs_quoting("join", &reserved));
assert!(needs_quoting("table", &reserved));
}
#[test]
fn test_needs_quoting_normal_identifiers() {
let reserved = get_reserved_words(None);
assert!(!needs_quoting("foo", &reserved));
assert!(!needs_quoting("my_column", &reserved));
assert!(!needs_quoting("col1", &reserved));
assert!(!needs_quoting("A", &reserved));
assert!(!needs_quoting("_hidden", &reserved));
}
#[test]
fn test_needs_quoting_special_characters() {
let reserved = get_reserved_words(None);
assert!(needs_quoting("my column", &reserved)); assert!(needs_quoting("my-column", &reserved)); assert!(needs_quoting("my.column", &reserved)); assert!(needs_quoting("col@name", &reserved)); assert!(needs_quoting("col#name", &reserved)); }
#[test]
fn test_needs_quoting_starts_with_digit() {
let reserved = get_reserved_words(None);
assert!(needs_quoting("1col", &reserved));
assert!(needs_quoting("123", &reserved));
assert!(needs_quoting("0_start", &reserved));
}
#[test]
fn test_needs_quoting_empty() {
let reserved = get_reserved_words(None);
assert!(!needs_quoting("", &reserved));
}
#[test]
fn test_maybe_quote_sets_quoted_flag() {
let reserved = get_reserved_words(None);
let mut id = Identifier::new("select");
assert!(!id.quoted);
maybe_quote(&mut id, &reserved);
assert!(id.quoted);
}
#[test]
fn test_maybe_quote_skips_already_quoted() {
let reserved = get_reserved_words(None);
let mut id = Identifier::quoted("myname");
assert!(id.quoted);
maybe_quote(&mut id, &reserved);
assert!(id.quoted); assert_eq!(id.name, "myname"); }
#[test]
fn test_maybe_quote_skips_star() {
let reserved = get_reserved_words(None);
let mut id = Identifier::new("*");
maybe_quote(&mut id, &reserved);
assert!(!id.quoted); }
#[test]
fn test_maybe_quote_skips_normal() {
let reserved = get_reserved_words(None);
let mut id = Identifier::new("normal_col");
maybe_quote(&mut id, &reserved);
assert!(!id.quoted);
}
#[test]
fn test_quote_identifiers_column_with_reserved_name() {
let expr = Expression::boxed_column(Column {
name: Identifier::new("select"),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
});
let result = quote_identifiers(expr, None);
if let Expression::Column(col) = &result {
assert!(col.name.quoted, "Column named 'select' should be quoted");
} else {
panic!("Expected Column expression");
}
}
#[test]
fn test_quote_identifiers_column_with_special_chars() {
let expr = Expression::boxed_column(Column {
name: Identifier::new("my column"),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
});
let result = quote_identifiers(expr, None);
if let Expression::Column(col) = &result {
assert!(col.name.quoted, "Column with space should be quoted");
} else {
panic!("Expected Column expression");
}
}
#[test]
fn test_quote_identifiers_preserves_normal_column() {
let expr = Expression::boxed_column(Column {
name: Identifier::new("normal_col"),
table: Some(Identifier::new("my_table")),
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
});
let result = quote_identifiers(expr, None);
if let Expression::Column(col) = &result {
assert!(!col.name.quoted, "Normal column should not be quoted");
assert!(
!col.table.as_ref().unwrap().quoted,
"Normal table should not be quoted"
);
} else {
panic!("Expected Column expression");
}
}
#[test]
fn test_quote_identifiers_table_ref_reserved() {
let expr = Expression::Table(Box::new(TableRef::new("select")));
let result = quote_identifiers(expr, None);
if let Expression::Table(tr) = &result {
assert!(tr.name.quoted, "Table named 'select' should be quoted");
} else {
panic!("Expected Table expression");
}
}
#[test]
fn test_quote_identifiers_table_ref_schema_and_alias() {
let mut tr = TableRef::new("my_table");
tr.schema = Some(Identifier::new("from"));
tr.alias = Some(Identifier::new("t"));
let expr = Expression::Table(Box::new(tr));
let result = quote_identifiers(expr, None);
if let Expression::Table(tr) = &result {
assert!(!tr.name.quoted, "Normal table name should not be quoted");
assert!(
tr.schema.as_ref().unwrap().quoted,
"Schema named 'from' should be quoted"
);
assert!(
!tr.alias.as_ref().unwrap().quoted,
"Normal alias should not be quoted"
);
} else {
panic!("Expected Table expression");
}
}
#[test]
fn test_quote_identifiers_identifier_node() {
let expr = Expression::Identifier(Identifier::new("order"));
let result = quote_identifiers(expr, None);
if let Expression::Identifier(id) = &result {
assert!(id.quoted, "Identifier named 'order' should be quoted");
} else {
panic!("Expected Identifier expression");
}
}
#[test]
fn test_quote_identifiers_alias() {
let inner = Expression::boxed_column(Column {
name: Identifier::new("val"),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
});
let expr = Expression::Alias(Box::new(Alias {
this: inner,
alias: Identifier::new("select"),
column_aliases: vec![Identifier::new("from")],
pre_alias_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
}));
let result = quote_identifiers(expr, None);
if let Expression::Alias(alias) = &result {
assert!(alias.alias.quoted, "Alias named 'select' should be quoted");
assert!(
alias.column_aliases[0].quoted,
"Column alias named 'from' should be quoted"
);
if let Expression::Column(col) = &alias.this {
assert!(!col.name.quoted);
}
} else {
panic!("Expected Alias expression");
}
}
#[test]
fn test_quote_identifiers_select_recursive() {
let expr = parse("SELECT a, b FROM t WHERE c = 1");
let result = quote_identifiers(expr, None);
let sql = gen(&result);
assert!(sql.contains("a"));
assert!(sql.contains("b"));
assert!(sql.contains("t"));
}
#[test]
fn test_quote_identifiers_digit_start() {
let expr = Expression::boxed_column(Column {
name: Identifier::new("1col"),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
});
let result = quote_identifiers(expr, None);
if let Expression::Column(col) = &result {
assert!(
col.name.quoted,
"Column starting with digit should be quoted"
);
} else {
panic!("Expected Column expression");
}
}
#[test]
fn test_quote_identifiers_with_mysql_dialect() {
let reserved = get_reserved_words(Some(DialectType::MySQL));
assert!(needs_quoting("KILL", &reserved));
assert!(needs_quoting("FORCE", &reserved));
}
#[test]
fn test_quote_identifiers_with_postgresql_dialect() {
let reserved = get_reserved_words(Some(DialectType::PostgreSQL));
assert!(needs_quoting("ILIKE", &reserved));
assert!(needs_quoting("VERBOSE", &reserved));
}
#[test]
fn test_quote_identifiers_with_bigquery_dialect() {
let reserved = get_reserved_words(Some(DialectType::BigQuery));
assert!(needs_quoting("STRUCT", &reserved));
assert!(needs_quoting("PROTO", &reserved));
}
#[test]
fn test_quote_identifiers_case_insensitive_reserved() {
let reserved = get_reserved_words(None);
assert!(needs_quoting("Select", &reserved));
assert!(needs_quoting("sElEcT", &reserved));
assert!(needs_quoting("FROM", &reserved));
assert!(needs_quoting("from", &reserved));
}
#[test]
fn test_quote_identifiers_join_using() {
let mut join = crate::expressions::Join {
this: Expression::Table(Box::new(TableRef::new("other"))),
on: None,
using: vec![Identifier::new("key"), Identifier::new("value")],
kind: crate::expressions::JoinKind::Inner,
use_inner_keyword: false,
use_outer_keyword: false,
deferred_condition: false,
join_hint: None,
match_condition: None,
pivots: vec![],
comments: vec![],
nesting_group: 0,
directed: false,
};
let reserved = get_reserved_words(None);
quote_join(&mut join, &reserved);
assert!(
join.using[0].quoted,
"USING identifier 'key' should be quoted"
);
assert!(
!join.using[1].quoted,
"USING identifier 'value' should not be quoted"
);
}
#[test]
fn test_quote_identifiers_cte() {
let mut cte = crate::expressions::Cte {
alias: Identifier::new("select"),
this: Expression::boxed_column(Column {
name: Identifier::new("x"),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
}),
columns: vec![Identifier::new("from"), Identifier::new("normal")],
materialized: None,
key_expressions: vec![],
alias_first: false,
comments: Vec::new(),
};
let reserved = get_reserved_words(None);
maybe_quote(&mut cte.alias, &reserved);
for c in &mut cte.columns {
maybe_quote(c, &reserved);
}
assert!(cte.alias.quoted, "CTE alias 'select' should be quoted");
assert!(cte.columns[0].quoted, "CTE column 'from' should be quoted");
assert!(
!cte.columns[1].quoted,
"CTE column 'normal' should not be quoted"
);
}
#[test]
fn test_quote_identifiers_binary_ops_recurse() {
let expr = Expression::Add(Box::new(crate::expressions::BinaryOp::new(
Expression::boxed_column(Column {
name: Identifier::new("select"),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
}),
Expression::boxed_column(Column {
name: Identifier::new("normal"),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
}),
)));
let result = quote_identifiers(expr, None);
if let Expression::Add(bin) = &result {
if let Expression::Column(left) = &bin.left {
assert!(
left.name.quoted,
"'select' column should be quoted in binary op"
);
}
if let Expression::Column(right) = &bin.right {
assert!(!right.name.quoted, "'normal' column should not be quoted");
}
} else {
panic!("Expected Add expression");
}
}
#[test]
fn test_quote_identifiers_already_quoted_preserved() {
let expr = Expression::boxed_column(Column {
name: Identifier::quoted("normal_name"),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
});
let result = quote_identifiers(expr, None);
if let Expression::Column(col) = &result {
assert!(
col.name.quoted,
"Already-quoted identifier should remain quoted"
);
} else {
panic!("Expected Column expression");
}
}
#[test]
fn test_quote_identifiers_full_parsed_query() {
let mut select = crate::expressions::Select::new();
select.expressions.push(Expression::boxed_column(Column {
name: Identifier::new("order"),
table: Some(Identifier::new("t")),
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
}));
select.from = Some(crate::expressions::From {
expressions: vec![Expression::Table(Box::new(TableRef::new("t")))],
});
let expr = Expression::Select(Box::new(select));
let result = quote_identifiers(expr, None);
if let Expression::Select(sel) = &result {
if let Expression::Column(col) = &sel.expressions[0] {
assert!(col.name.quoted, "Column named 'order' should be quoted");
assert!(
!col.table.as_ref().unwrap().quoted,
"Table 't' should not be quoted"
);
} else {
panic!("Expected Column in SELECT list");
}
} else {
panic!("Expected Select expression");
}
}
#[test]
fn test_get_reserved_words_all_dialects() {
let dialects = [
None,
Some(DialectType::Generic),
Some(DialectType::MySQL),
Some(DialectType::PostgreSQL),
Some(DialectType::BigQuery),
Some(DialectType::Snowflake),
Some(DialectType::TSQL),
Some(DialectType::ClickHouse),
Some(DialectType::DuckDB),
Some(DialectType::Hive),
Some(DialectType::Spark),
Some(DialectType::Trino),
Some(DialectType::Oracle),
Some(DialectType::Redshift),
];
for dialect in &dialects {
let words = get_reserved_words(*dialect);
assert!(
words.contains("SELECT"),
"All dialects should have SELECT as reserved"
);
assert!(
words.contains("FROM"),
"All dialects should have FROM as reserved"
);
}
}
}