use crate::dialects::transform_recursive;
use crate::dialects::{Dialect, DialectType};
use crate::error::Result;
use crate::expressions::{
Alias, BinaryOp, BooleanLiteral, Cast, DataType, Exists, Expression, From, Function,
Identifier, Join, JoinKind, Lateral, LateralView, Literal, NamedArgSeparator, NamedArgument,
Over, Select, StructField, Subquery, UnaryFunc, UnnestFunc, Where,
};
use std::cell::RefCell;
/// Runs `expr` through every transform in `transforms`, in slice order.
///
/// Each transform receives the output of the previous one. The first `Err`
/// stops the pipeline and is returned as-is; an empty slice yields `expr`
/// unchanged.
pub fn preprocess<F>(expr: Expression, transforms: &[F]) -> Result<Expression>
where
    F: Fn(Expression) -> Result<Expression>,
{
    transforms
        .iter()
        .try_fold(expr, |current, step| step(current))
}
/// Rewrites a bare `Unnest` node into an `Explode` node over the same first
/// argument. Any other expression is returned untouched.
pub fn unnest_to_explode(expr: Expression) -> Result<Expression> {
    let Expression::Unnest(unnest) = expr else {
        return Ok(expr);
    };
    let exploded = UnaryFunc::new(unnest.this);
    Ok(Expression::Explode(Box::new(exploded)))
}
/// Applies the SELECT-level UNNEST rewrite (`unnest_to_explode_select_inner`)
/// to `expr` through `transform_recursive`.
pub fn unnest_to_explode_select(expr: Expression) -> Result<Expression> {
    let rewritten = transform_recursive(expr, &unnest_to_explode_select_inner);
    rewritten
}
/// Builds the table-generating expression that stands in for an UNNEST.
///
/// * One array argument: `Explode(arg)`.
/// * Several array arguments: `INLINE(ARRAYS_ZIP(arg0, arg1, ...))`, with the
///   arguments kept in their original order.
fn make_udtf_expr(unnest: &UnnestFunc) -> Expression {
    if unnest.expressions.is_empty() {
        return Expression::Explode(Box::new(UnaryFunc::new(unnest.this.clone())));
    }

    // `this` is the first array; `expressions` holds the additional ones.
    let zipped_args: Vec<Expression> = std::iter::once(&unnest.this)
        .chain(unnest.expressions.iter())
        .cloned()
        .collect();
    let zipped = Expression::Function(Box::new(Function::new(
        "ARRAYS_ZIP".to_string(),
        zipped_args,
    )));
    Expression::Function(Box::new(Function::new("INLINE".to_string(), vec![zipped])))
}
fn unnest_to_explode_select_inner(expr: Expression) -> Result<Expression> {
let Expression::Select(mut select) = expr else {
return Ok(expr);
};
if let Some(ref mut from) = select.from {
if from.expressions.len() >= 1 {
let mut new_from_exprs = Vec::new();
let mut new_lateral_views = Vec::new();
let first_is_unnest = is_unnest_expr(&from.expressions[0]);
for (idx, from_item) in from.expressions.drain(..).enumerate() {
if idx == 0 && first_is_unnest {
let replaced = replace_from_unnest(from_item);
new_from_exprs.push(replaced);
} else if idx > 0 && is_unnest_expr(&from_item) {
let (alias_name, column_aliases, unnest_func) = extract_unnest_info(from_item);
let udtf = make_udtf_expr(&unnest_func);
new_lateral_views.push(LateralView {
this: udtf,
table_alias: alias_name,
column_aliases,
outer: false,
});
} else {
new_from_exprs.push(from_item);
}
}
from.expressions = new_from_exprs;
select.lateral_views.extend(new_lateral_views);
}
}
let mut remaining_joins = Vec::new();
for join in select.joins.drain(..) {
if matches!(join.kind, JoinKind::Cross | JoinKind::Inner) {
let (is_unnest, is_lateral) = check_join_unnest(&join.this);
if is_unnest {
let (lateral_alias, lateral_col_aliases, join_expr) = if is_lateral {
if let Expression::Lateral(lat) = join.this {
let alias = lat.alias.map(|s| Identifier::new(&s));
let col_aliases: Vec<Identifier> = lat
.column_aliases
.iter()
.map(|s| Identifier::new(s))
.collect();
(alias, col_aliases, *lat.this)
} else {
(None, Vec::new(), join.this)
}
} else {
(None, Vec::new(), join.this)
};
let (alias_name, column_aliases, unnest_func) = extract_unnest_info(join_expr);
let final_alias = lateral_alias.or(alias_name);
let final_col_aliases = if !lateral_col_aliases.is_empty() {
lateral_col_aliases
} else {
column_aliases
};
let table_alias = final_alias.or_else(|| Some(Identifier::new("unnest")));
let col_aliases = if final_col_aliases.is_empty() {
vec![Identifier::new("unnest")]
} else {
final_col_aliases
};
let udtf = make_udtf_expr(&unnest_func);
select.lateral_views.push(LateralView {
this: udtf,
table_alias,
column_aliases: col_aliases,
outer: false,
});
} else {
remaining_joins.push(join);
}
} else {
remaining_joins.push(join);
}
}
select.joins = remaining_joins;
Ok(Expression::Select(select))
}
/// Returns `true` when `expr` is an `Unnest` node, or an `Alias` whose direct
/// child is an `Unnest` node.
fn is_unnest_expr(expr: &Expression) -> bool {
    // Look through at most one alias layer.
    let candidate = match expr {
        Expression::Alias(aliased) => &aliased.this,
        other => other,
    };
    matches!(candidate, Expression::Unnest(_))
}
/// Classifies a join target, returning `(is_unnest, is_lateral)`.
///
/// `is_lateral` is `true` when `expr` is a `Lateral` wrapper. `is_unnest` is
/// `true` when the expression (after peeling one LATERAL wrapper, if any) is
/// an `Unnest` or an `Alias` directly over an `Unnest`.
fn check_join_unnest(expr: &Expression) -> (bool, bool) {
    let (target, is_lateral) = match expr {
        Expression::Lateral(lat) => (&*lat.this, true),
        other => (other, false),
    };
    let is_unnest = match target {
        Expression::Unnest(_) => true,
        Expression::Alias(aliased) => matches!(aliased.this, Expression::Unnest(_)),
        _ => false,
    };
    (is_unnest, is_lateral)
}
/// Replaces an UNNEST in FROM position with its UDTF equivalent.
///
/// A bare `Unnest` becomes the UDTF itself; an `Alias` over an `Unnest` keeps
/// the alias and only swaps the aliased expression. Anything else is returned
/// unchanged.
fn replace_from_unnest(from_item: Expression) -> Expression {
    match from_item {
        Expression::Unnest(unnest) => make_udtf_expr(&unnest),
        Expression::Alias(mut aliased) => {
            let udtf = match &aliased.this {
                Expression::Unnest(unnest) => Some(make_udtf_expr(unnest)),
                _ => None,
            };
            if let Some(udtf) = udtf {
                aliased.this = udtf;
            }
            Expression::Alias(aliased)
        }
        untouched => untouched,
    }
}
/// Splits an UNNEST source into `(table alias, column aliases, unnest call)`.
///
/// * `Alias` node: the alias and its column aliases are returned; if the
///   aliased expression is not an `Unnest`, it is wrapped in a synthetic one.
/// * Bare `Unnest`: its own `alias` field is used and no column aliases.
/// * Anything else: no aliases, and the expression is wrapped in a synthetic
///   `UnnestFunc`.
fn extract_unnest_info(expr: Expression) -> (Option<Identifier>, Vec<Identifier>, UnnestFunc) {
    /// Wraps `this` as a single-argument UNNEST with no ordinality or aliases.
    fn synthetic_unnest(this: Expression) -> UnnestFunc {
        UnnestFunc {
            this,
            expressions: Vec::new(),
            with_ordinality: false,
            alias: None,
            offset_alias: None,
        }
    }

    match expr {
        Expression::Unnest(unnest) => {
            let own_alias = unnest.alias.clone();
            (own_alias, Vec::new(), *unnest)
        }
        Expression::Alias(aliased) => {
            let unnest_func = match aliased.this {
                Expression::Unnest(unnest) => *unnest,
                other => synthetic_unnest(other),
            };
            (Some(aliased.alias), aliased.column_aliases, unnest_func)
        }
        other => (None, Vec::new(), synthetic_unnest(other)),
    }
}
/// Rewrites an `Explode` node into a single-argument `Unnest` node with no
/// ordinality and no aliases. Any other expression is returned untouched.
pub fn explode_to_unnest(expr: Expression) -> Result<Expression> {
    let Expression::Explode(explode) = expr else {
        return Ok(expr);
    };
    let unnest = UnnestFunc {
        this: explode.this,
        expressions: Vec::new(),
        with_ordinality: false,
        alias: None,
        offset_alias: None,
    };
    Ok(Expression::Unnest(Box::new(unnest)))
}
/// Rewrites a boolean literal into the numeric literal `1` (true) or `0`
/// (false). Any other expression is returned untouched.
pub fn replace_bool_with_int(expr: Expression) -> Result<Expression> {
    let Expression::Boolean(flag) = expr else {
        return Ok(expr);
    };
    let digit = String::from(if flag.value { "1" } else { "0" });
    Ok(Expression::Literal(Box::new(Literal::Number(digit))))
}
/// Rewrites the numeric literals `1` and `0` into boolean literals `true` and
/// `false`. Only the exact texts `"1"` and `"0"` qualify; every other literal
/// or expression is returned untouched.
pub fn replace_int_with_bool(expr: Expression) -> Result<Expression> {
    let Expression::Literal(lit) = expr else {
        return Ok(expr);
    };

    let truth = match lit.as_ref() {
        Literal::Number(digits) if digits == "1" => Some(true),
        Literal::Number(digits) if digits == "0" => Some(false),
        _ => None,
    };

    match truth {
        Some(value) => Ok(Expression::Boolean(BooleanLiteral { value })),
        None => Ok(Expression::Literal(lit)),
    }
}
/// Strips size/precision parameters from the target type of a cast-like
/// expression (see `strip_type_params_recursive`). Never fails.
pub fn remove_precision_parameterized_types(expr: Expression) -> Result<Expression> {
    let stripped = strip_type_params_recursive(expr);
    Ok(stripped)
}
/// Removes type parameters from `Cast`, `TryCast` and `SafeCast` nodes.
///
/// The target type is passed through `strip_data_type_params`, and the
/// operand is processed recursively so that directly nested casts are handled
/// too. Recursion stops at the first operand that is not a cast-like node;
/// such expressions are returned unchanged.
fn strip_type_params_recursive(expr: Expression) -> Expression {
    match expr {
        Expression::Cast(mut node) => {
            node.this = strip_type_params_recursive(node.this);
            node.to = strip_data_type_params(node.to);
            Expression::Cast(node)
        }
        Expression::TryCast(mut node) => {
            node.this = strip_type_params_recursive(node.this);
            node.to = strip_data_type_params(node.to);
            Expression::TryCast(node)
        }
        Expression::SafeCast(mut node) => {
            node.this = strip_type_params_recursive(node.this);
            node.to = strip_data_type_params(node.to);
            Expression::SafeCast(node)
        }
        not_a_cast => not_a_cast,
    }
}
/// Returns `dt` with every length / precision / scale / dimension parameter
/// removed.
///
/// Container types (`Array`, `Map`, `Struct`, `Vector`, `Object`) are
/// processed recursively so nested element types are stripped as well.
/// Types without parameters are returned unchanged.
///
/// The `timezone` flag of both `Time` and `Timestamp` is preserved: it is
/// part of the type's identity, not a precision parameter, so stripping the
/// precision of `TIME(3) WITH TIME ZONE` must not turn it into plain `TIME`.
/// The `Int` spelling flag and the `VarChar` parenthesized-length flag are
/// reset to `false`.
fn strip_data_type_params(dt: DataType) -> DataType {
    match dt {
        DataType::Decimal { .. } => DataType::Decimal {
            precision: None,
            scale: None,
        },
        DataType::TinyInt { .. } => DataType::TinyInt { length: None },
        DataType::SmallInt { .. } => DataType::SmallInt { length: None },
        DataType::Int { .. } => DataType::Int {
            length: None,
            integer_spelling: false,
        },
        DataType::BigInt { .. } => DataType::BigInt { length: None },
        DataType::Char { .. } => DataType::Char { length: None },
        DataType::VarChar { .. } => DataType::VarChar {
            length: None,
            parenthesized_length: false,
        },
        DataType::Binary { .. } => DataType::Binary { length: None },
        DataType::VarBinary { .. } => DataType::VarBinary { length: None },
        DataType::Bit { .. } => DataType::Bit { length: None },
        DataType::VarBit { .. } => DataType::VarBit { length: None },
        // Keep the time-zone flag, exactly as the Timestamp arm does.
        DataType::Time { timezone, .. } => DataType::Time {
            precision: None,
            timezone,
        },
        DataType::Timestamp { timezone, .. } => DataType::Timestamp {
            precision: None,
            timezone,
        },
        DataType::Array {
            element_type,
            dimension,
        } => DataType::Array {
            element_type: Box::new(strip_data_type_params(*element_type)),
            dimension,
        },
        DataType::Map {
            key_type,
            value_type,
        } => DataType::Map {
            key_type: Box::new(strip_data_type_params(*key_type)),
            value_type: Box::new(strip_data_type_params(*value_type)),
        },
        DataType::Struct { fields, nested } => DataType::Struct {
            fields: fields
                .into_iter()
                .map(|f| {
                    StructField::with_options(
                        f.name,
                        strip_data_type_params(f.data_type),
                        f.options,
                    )
                })
                .collect(),
            nested,
        },
        DataType::Vector { element_type, .. } => DataType::Vector {
            element_type: element_type.map(|et| Box::new(strip_data_type_params(*et))),
            dimension: None,
        },
        DataType::Object { fields, modifier } => DataType::Object {
            fields: fields
                .into_iter()
                .map(|(name, ty, not_null)| (name, strip_data_type_params(ty), not_null))
                .collect(),
            modifier,
        },
        other => other,
    }
}
pub fn eliminate_qualify(expr: Expression) -> Result<Expression> {
match expr {
Expression::Select(mut select) => {
if let Some(qualify) = select.qualify.take() {
let qualify_filter = qualify.this;
let window_alias_name = "_w".to_string();
let window_alias_ident = Identifier::new(window_alias_name.clone());
let (window_expr, outer_where) =
extract_window_from_condition(qualify_filter.clone(), &window_alias_ident);
if let Some(win_expr) = window_expr {
let window_alias_expr =
Expression::Alias(Box::new(crate::expressions::Alias {
this: win_expr,
alias: window_alias_ident.clone(),
column_aliases: vec![],
pre_alias_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
}));
let outer_exprs: Vec<Expression> = select
.expressions
.iter()
.map(|expr| {
if let Expression::Alias(a) = expr {
Expression::Column(Box::new(crate::expressions::Column {
name: a.alias.clone(),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
}))
} else {
expr.clone()
}
})
.collect();
select.expressions.push(window_alias_expr);
let inner_select = Expression::Select(select);
let subquery = Subquery {
this: inner_select,
alias: Some(Identifier::new("_t".to_string())),
column_aliases: vec![],
order_by: None,
limit: None,
offset: None,
distribute_by: None,
sort_by: None,
cluster_by: None,
lateral: false,
modifiers_inside: false,
trailing_comments: vec![],
inferred_type: None,
};
let outer_select = Select {
expressions: outer_exprs,
from: Some(From {
expressions: vec![Expression::Subquery(Box::new(subquery))],
}),
where_clause: Some(Where { this: outer_where }),
..Select::new()
};
return Ok(Expression::Select(Box::new(outer_select)));
} else {
let qualify_alias = Expression::Alias(Box::new(crate::expressions::Alias {
this: qualify_filter.clone(),
alias: window_alias_ident.clone(),
column_aliases: vec![],
pre_alias_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
}));
let original_exprs = select.expressions.clone();
select.expressions.push(qualify_alias);
let inner_select = Expression::Select(select);
let subquery = Subquery {
this: inner_select,
alias: Some(Identifier::new("_t".to_string())),
column_aliases: vec![],
order_by: None,
limit: None,
offset: None,
distribute_by: None,
sort_by: None,
cluster_by: None,
lateral: false,
modifiers_inside: false,
trailing_comments: vec![],
inferred_type: None,
};
let outer_select = Select {
expressions: original_exprs,
from: Some(From {
expressions: vec![Expression::Subquery(Box::new(subquery))],
}),
where_clause: Some(Where {
this: Expression::Column(Box::new(crate::expressions::Column {
name: window_alias_ident,
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
})),
}),
..Select::new()
};
return Ok(Expression::Select(Box::new(outer_select)));
}
}
Ok(Expression::Select(select))
}
other => Ok(other),
}
}
fn extract_window_from_condition(
condition: Expression,
alias: &Identifier,
) -> (Option<Expression>, Expression) {
let alias_col = Expression::Column(Box::new(crate::expressions::Column {
name: alias.clone(),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
}));
match condition {
Expression::Eq(ref op) => {
if is_window_expr(&op.left) {
(
Some(op.left.clone()),
Expression::Eq(Box::new(BinaryOp {
left: alias_col,
right: op.right.clone(),
..(**op).clone()
})),
)
} else if is_window_expr(&op.right) {
(
Some(op.right.clone()),
Expression::Eq(Box::new(BinaryOp {
left: op.left.clone(),
right: alias_col,
..(**op).clone()
})),
)
} else {
(None, condition)
}
}
Expression::Neq(ref op) => {
if is_window_expr(&op.left) {
(
Some(op.left.clone()),
Expression::Neq(Box::new(BinaryOp {
left: alias_col,
right: op.right.clone(),
..(**op).clone()
})),
)
} else if is_window_expr(&op.right) {
(
Some(op.right.clone()),
Expression::Neq(Box::new(BinaryOp {
left: op.left.clone(),
right: alias_col,
..(**op).clone()
})),
)
} else {
(None, condition)
}
}
Expression::Lt(ref op) => {
if is_window_expr(&op.left) {
(
Some(op.left.clone()),
Expression::Lt(Box::new(BinaryOp {
left: alias_col,
right: op.right.clone(),
..(**op).clone()
})),
)
} else if is_window_expr(&op.right) {
(
Some(op.right.clone()),
Expression::Lt(Box::new(BinaryOp {
left: op.left.clone(),
right: alias_col,
..(**op).clone()
})),
)
} else {
(None, condition)
}
}
Expression::Lte(ref op) => {
if is_window_expr(&op.left) {
(
Some(op.left.clone()),
Expression::Lte(Box::new(BinaryOp {
left: alias_col,
right: op.right.clone(),
..(**op).clone()
})),
)
} else if is_window_expr(&op.right) {
(
Some(op.right.clone()),
Expression::Lte(Box::new(BinaryOp {
left: op.left.clone(),
right: alias_col,
..(**op).clone()
})),
)
} else {
(None, condition)
}
}
Expression::Gt(ref op) => {
if is_window_expr(&op.left) {
(
Some(op.left.clone()),
Expression::Gt(Box::new(BinaryOp {
left: alias_col,
right: op.right.clone(),
..(**op).clone()
})),
)
} else if is_window_expr(&op.right) {
(
Some(op.right.clone()),
Expression::Gt(Box::new(BinaryOp {
left: op.left.clone(),
right: alias_col,
..(**op).clone()
})),
)
} else {
(None, condition)
}
}
Expression::Gte(ref op) => {
if is_window_expr(&op.left) {
(
Some(op.left.clone()),
Expression::Gte(Box::new(BinaryOp {
left: alias_col,
right: op.right.clone(),
..(**op).clone()
})),
)
} else if is_window_expr(&op.right) {
(
Some(op.right.clone()),
Expression::Gte(Box::new(BinaryOp {
left: op.left.clone(),
right: alias_col,
..(**op).clone()
})),
)
} else {
(None, condition)
}
}
_ if is_window_expr(&condition) => (Some(condition), alias_col),
_ => (None, condition),
}
}
/// Returns `true` for `Window` and `WindowFunction` nodes, `false` otherwise.
fn is_window_expr(expr: &Expression) -> bool {
    match expr {
        Expression::Window(_) | Expression::WindowFunction(_) => true,
        _ => false,
    }
}
/// Dialect-agnostic entry point for the `DISTINCT ON` rewrite: delegates to
/// `eliminate_distinct_on_for_dialect` with neither a target nor a source
/// dialect specified.
pub fn eliminate_distinct_on(expr: Expression) -> Result<Expression> {
    let (target, source) = (None, None);
    eliminate_distinct_on_for_dialect(expr, target, source)
}
pub fn eliminate_distinct_on_for_dialect(
expr: Expression,
target: Option<DialectType>,
source: Option<DialectType>,
) -> Result<Expression> {
use crate::expressions::Case;
if matches!(
target,
Some(DialectType::PostgreSQL) | Some(DialectType::DuckDB)
) {
return Ok(expr);
}
enum NullsMode {
None, NullsFirst, CaseExpr, }
let nulls_mode = match target {
Some(DialectType::MySQL)
| Some(DialectType::SingleStore)
| Some(DialectType::TSQL)
| Some(DialectType::Fabric) => NullsMode::CaseExpr,
Some(DialectType::Oracle) | Some(DialectType::Redshift) | Some(DialectType::Snowflake) => {
NullsMode::None
}
Some(DialectType::StarRocks) => {
if matches!(source, Some(DialectType::Redshift)) {
NullsMode::CaseExpr
} else {
NullsMode::None
}
}
_ => NullsMode::NullsFirst,
};
match expr {
Expression::Select(mut select) => {
if let Some(distinct_cols) = select.distinct_on.take() {
if !distinct_cols.is_empty() {
let row_number_alias = Identifier::new("_row_number".to_string());
let order_exprs = if let Some(ref order_by) = select.order_by {
let mut exprs = order_by.expressions.clone();
match nulls_mode {
NullsMode::NullsFirst => {
for ord in &mut exprs {
if ord.desc && ord.nulls_first.is_none() {
ord.nulls_first = Some(true);
}
}
}
NullsMode::CaseExpr => {
let mut new_exprs = Vec::new();
for ord in &exprs {
if ord.desc && ord.nulls_first.is_none() {
let null_check = Expression::Case(Box::new(Case {
operand: None,
whens: vec![(
Expression::IsNull(Box::new(
crate::expressions::IsNull {
this: ord.this.clone(),
not: false,
postfix_form: false,
},
)),
Expression::Literal(Box::new(Literal::Number(
"1".to_string(),
))),
)],
else_: Some(Expression::Literal(Box::new(
Literal::Number("0".to_string()),
))),
comments: Vec::new(),
inferred_type: None,
}));
new_exprs.push(crate::expressions::Ordered {
this: null_check,
desc: true,
nulls_first: None,
explicit_asc: false,
with_fill: None,
});
}
new_exprs.push(ord.clone());
}
exprs = new_exprs;
}
NullsMode::None => {}
}
exprs
} else {
distinct_cols
.iter()
.map(|e| crate::expressions::Ordered {
this: e.clone(),
desc: false,
nulls_first: None,
explicit_asc: false,
with_fill: None,
})
.collect()
};
let row_number_func =
Expression::WindowFunction(Box::new(crate::expressions::WindowFunction {
this: Expression::RowNumber(crate::expressions::RowNumber),
over: Over {
partition_by: distinct_cols,
order_by: order_exprs,
frame: None,
window_name: None,
alias: None,
},
keep: None,
inferred_type: None,
}));
let mut inner_aliased_exprs = Vec::new();
let mut outer_select_exprs = Vec::new();
for orig_expr in &select.expressions {
match orig_expr {
Expression::Alias(alias) => {
inner_aliased_exprs.push(orig_expr.clone());
outer_select_exprs.push(Expression::Column(Box::new(
crate::expressions::Column {
name: alias.alias.clone(),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
},
)));
}
Expression::Column(col) => {
inner_aliased_exprs.push(Expression::Alias(Box::new(
crate::expressions::Alias {
this: orig_expr.clone(),
alias: col.name.clone(),
column_aliases: vec![],
pre_alias_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
},
)));
outer_select_exprs.push(Expression::Column(Box::new(
crate::expressions::Column {
name: col.name.clone(),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
},
)));
}
_ => {
inner_aliased_exprs.push(orig_expr.clone());
outer_select_exprs.push(orig_expr.clone());
}
}
}
let row_number_alias_expr =
Expression::Alias(Box::new(crate::expressions::Alias {
this: row_number_func,
alias: row_number_alias.clone(),
column_aliases: vec![],
pre_alias_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
}));
inner_aliased_exprs.push(row_number_alias_expr);
select.expressions = inner_aliased_exprs;
let _inner_order_by = select.order_by.take();
select.distinct = false;
let inner_select = Expression::Select(select);
let subquery = Subquery {
this: inner_select,
alias: Some(Identifier::new("_t".to_string())),
column_aliases: vec![],
order_by: None,
limit: None,
offset: None,
distribute_by: None,
sort_by: None,
cluster_by: None,
lateral: false,
modifiers_inside: false,
trailing_comments: vec![],
inferred_type: None,
};
let outer_select = Select {
expressions: outer_select_exprs,
from: Some(From {
expressions: vec![Expression::Subquery(Box::new(subquery))],
}),
where_clause: Some(Where {
this: Expression::Eq(Box::new(BinaryOp {
left: Expression::Column(Box::new(crate::expressions::Column {
name: row_number_alias,
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
})),
right: Expression::Literal(Box::new(Literal::Number(
"1".to_string(),
))),
left_comments: vec![],
operator_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
})),
}),
..Select::new()
};
return Ok(Expression::Select(Box::new(outer_select)));
}
}
Ok(Expression::Select(select))
}
other => Ok(other),
}
}
pub fn eliminate_semi_and_anti_joins(expr: Expression) -> Result<Expression> {
match expr {
Expression::Select(mut select) => {
let mut new_joins = Vec::new();
let mut extra_where_conditions = Vec::new();
for join in select.joins.drain(..) {
match join.kind {
JoinKind::Semi | JoinKind::LeftSemi => {
if let Some(on_condition) = join.on {
let subquery_select = Select {
expressions: vec![Expression::Literal(Box::new(Literal::Number(
"1".to_string(),
)))],
from: Some(From {
expressions: vec![join.this],
}),
where_clause: Some(Where { this: on_condition }),
..Select::new()
};
let exists = Expression::Exists(Box::new(Exists {
this: Expression::Subquery(Box::new(Subquery {
this: Expression::Select(Box::new(subquery_select)),
alias: None,
column_aliases: vec![],
order_by: None,
limit: None,
offset: None,
distribute_by: None,
sort_by: None,
cluster_by: None,
lateral: false,
modifiers_inside: false,
trailing_comments: vec![],
inferred_type: None,
})),
not: false,
}));
extra_where_conditions.push(exists);
}
}
JoinKind::Anti | JoinKind::LeftAnti => {
if let Some(on_condition) = join.on {
let subquery_select = Select {
expressions: vec![Expression::Literal(Box::new(Literal::Number(
"1".to_string(),
)))],
from: Some(From {
expressions: vec![join.this],
}),
where_clause: Some(Where { this: on_condition }),
..Select::new()
};
let not_exists = Expression::Exists(Box::new(Exists {
this: Expression::Subquery(Box::new(Subquery {
this: Expression::Select(Box::new(subquery_select)),
alias: None,
column_aliases: vec![],
order_by: None,
limit: None,
offset: None,
distribute_by: None,
sort_by: None,
cluster_by: None,
lateral: false,
modifiers_inside: false,
trailing_comments: vec![],
inferred_type: None,
})),
not: true,
}));
extra_where_conditions.push(not_exists);
}
}
_ => {
new_joins.push(join);
}
}
}
select.joins = new_joins;
if !extra_where_conditions.is_empty() {
let combined = extra_where_conditions
.into_iter()
.reduce(|acc, cond| {
Expression::And(Box::new(BinaryOp {
left: acc,
right: cond,
left_comments: vec![],
operator_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
}))
})
.unwrap();
select.where_clause = match select.where_clause {
Some(Where { this: existing }) => Some(Where {
this: Expression::And(Box::new(BinaryOp {
left: existing,
right: combined,
left_comments: vec![],
operator_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
})),
}),
None => Some(Where { this: combined }),
};
}
Ok(Expression::Select(select))
}
other => Ok(other),
}
}
/// Rewrites a SELECT containing exactly one FULL join into
/// `<left-join query> UNION ALL <right-join query>`.
///
/// * Left branch: the original SELECT with the FULL join turned into LEFT
///   and its ORDER BY removed (the ORDER BY is attached to the union).
/// * Right branch: a copy of the original SELECT with the FULL join turned
///   into RIGHT and its WITH clause removed. If the join has an ON condition
///   and FROM is non-empty, `NOT EXISTS (SELECT 1 FROM <from> WHERE <on>)`
///   is AND-ed onto the branch's WHERE so that matched rows are not emitted
///   twice.
///
/// SELECTs with zero or more than one FULL join, and non-SELECT expressions,
/// are returned unchanged.
pub fn eliminate_full_outer_join(expr: Expression) -> Result<Expression> {
    let Expression::Select(mut select) = expr else {
        return Ok(expr);
    };

    let full_positions: Vec<usize> = select
        .joins
        .iter()
        .enumerate()
        .filter(|(_, join)| join.kind == JoinKind::Full)
        .map(|(position, _)| position)
        .collect();
    let &[idx] = full_positions.as_slice() else {
        return Ok(Expression::Select(select));
    };

    // Copy first, so the right branch starts from the untouched original.
    let mut right_select = select.clone();
    let join_condition = select.joins[idx].on.clone();
    select.joins[idx].kind = JoinKind::Left;
    right_select.joins[idx].kind = JoinKind::Right;

    if let (Some(from), Some(join_cond)) = (select.from.as_ref(), join_condition.as_ref()) {
        if !from.expressions.is_empty() {
            let matched_probe = Select {
                expressions: vec![Expression::Literal(Box::new(Literal::Number(
                    "1".to_string(),
                )))],
                from: Some(from.clone()),
                where_clause: Some(Where {
                    this: join_cond.clone(),
                }),
                ..Select::new()
            };
            let exists = Expression::Exists(Box::new(Exists {
                this: Expression::Subquery(Box::new(Subquery {
                    this: Expression::Select(Box::new(matched_probe)),
                    alias: None,
                    column_aliases: vec![],
                    order_by: None,
                    limit: None,
                    offset: None,
                    distribute_by: None,
                    sort_by: None,
                    cluster_by: None,
                    lateral: false,
                    modifiers_inside: false,
                    trailing_comments: vec![],
                    inferred_type: None,
                })),
                not: false,
            }));
            let not_exists = Expression::Not(Box::new(crate::expressions::UnaryOp {
                inferred_type: None,
                this: exists,
            }));

            let right_filter = match right_select.where_clause.take() {
                Some(existing) => Expression::And(Box::new(BinaryOp {
                    left: existing.this,
                    right: not_exists,
                    left_comments: vec![],
                    operator_comments: vec![],
                    trailing_comments: vec![],
                    inferred_type: None,
                })),
                None => not_exists,
            };
            right_select.where_clause = Some(Where { this: right_filter });
        }
    }

    right_select.with = None;
    let order_by = select.order_by.take();

    let union = crate::expressions::Union {
        left: Expression::Select(select),
        right: Expression::Select(right_select),
        all: true,
        distinct: false,
        with: None,
        order_by,
        limit: None,
        offset: None,
        distribute_by: None,
        sort_by: None,
        cluster_by: None,
        by_name: false,
        side: None,
        kind: None,
        corresponding: false,
        strict: false,
        on_columns: Vec::new(),
    };
    Ok(Expression::Union(Box::new(union)))
}
/// Hoists nested CTE definitions into the WITH clause of the outermost SELECT.
///
/// Two sources of nested CTEs are handled:
/// * WITH clauses on SELECTs reached through FROM items, join targets,
///   projections and the WHERE clause (see `collect_nested_ctes`);
/// * WITH clauses attached to the body of a top-level CTE
///   (see `collect_ctes_from_cte_body`).
///
/// If either source yields anything, the nested WITH clauses are removed and
/// the collected CTEs are merged into the top-level WITH, which is created if
/// absent and marked recursive if any collected WITH was recursive.
/// Non-SELECT expressions are returned unchanged.
pub fn move_ctes_to_top_level(expr: Expression) -> Result<Expression> {
match expr {
Expression::Select(mut select) => {
let mut collected_ctes: Vec<crate::expressions::Cte> = Vec::new();
let mut has_recursive = false;
// Gather CTEs declared on nested SELECTs. Passing `true` marks this call
// as the top level, so the SELECT's own WITH is not collected.
collect_nested_ctes(
&Expression::Select(select.clone()),
&mut collected_ctes,
&mut has_recursive,
true,
);
// For each top-level CTE, gather the CTEs declared inside its body,
// remembered under the name of the CTE that declared them.
let mut cte_body_collected: Vec<(String, Vec<crate::expressions::Cte>)> = Vec::new();
if let Some(ref with) = select.with {
for cte in &with.ctes {
let mut body_ctes: Vec<crate::expressions::Cte> = Vec::new();
collect_ctes_from_cte_body(&cte.this, &mut body_ctes, &mut has_recursive);
if !body_ctes.is_empty() {
cte_body_collected.push((cte.alias.name.clone(), body_ctes));
}
}
}
let has_subquery_ctes = !collected_ctes.is_empty();
let has_body_ctes = !cte_body_collected.is_empty();
if has_subquery_ctes || has_body_ctes {
// The definitions were cloned above, so the nested WITH clauses can now
// be removed from the tree.
strip_nested_with_clauses(&mut select, true);
if has_body_ctes {
if let Some(ref mut with) = select.with {
for cte in with.ctes.iter_mut() {
strip_with_from_cte_body(&mut cte.this);
}
}
}
// Reuse the existing top-level WITH, or create an empty one.
let top_with = select.with.get_or_insert_with(|| crate::expressions::With {
ctes: Vec::new(),
recursive: false,
leading_comments: vec![],
search: None,
});
if has_recursive {
top_with.recursive = true;
}
if has_body_ctes {
// Rebuild the CTE list so that CTEs hoisted out of a CTE body come
// immediately before the CTE that declared them.
let mut new_ctes: Vec<crate::expressions::Cte> = Vec::new();
for mut cte in top_with.ctes.drain(..) {
if let Some(pos) = cte_body_collected
.iter()
.position(|(name, _)| *name == cte.alias.name)
{
let (_, mut nested) = cte_body_collected.remove(pos);
// The hoisted CTEs were cloned with their own WITH still attached.
for nested_cte in nested.iter_mut() {
strip_with_from_cte_body(&mut nested_cte.this);
}
new_ctes.extend(nested);
}
strip_with_from_cte_body(&mut cte.this);
new_ctes.push(cte);
}
top_with.ctes = new_ctes;
}
// CTEs hoisted out of subqueries go after all existing top-level CTEs.
top_with.ctes.extend(collected_ctes);
}
Ok(Expression::Select(select))
}
other => Ok(other),
}
}
/// Collects the CTEs declared in the WITH clause of a CTE body.
///
/// Only a body that is a SELECT with a WITH clause contributes. CTEs are
/// appended to `collected` depth-first: the CTEs nested inside a CTE's own
/// body come before that CTE. `has_recursive` is set (never cleared) when any
/// visited WITH clause is recursive.
fn collect_ctes_from_cte_body(
    expr: &Expression,
    collected: &mut Vec<crate::expressions::Cte>,
    has_recursive: &mut bool,
) {
    let Expression::Select(select) = expr else {
        return;
    };
    let Some(with) = select.with.as_ref() else {
        return;
    };

    *has_recursive |= with.recursive;
    for cte in &with.ctes {
        collect_ctes_from_cte_body(&cte.this, collected, has_recursive);
        collected.push(cte.clone());
    }
}
/// Removes the WITH clause from `expr` when it is a SELECT; other expressions
/// are left as they are. Does not descend into the SELECT.
fn strip_with_from_cte_body(expr: &mut Expression) {
    match expr {
        Expression::Select(select) => select.with = None,
        _ => {}
    }
}
fn strip_nested_with_clauses(select: &mut Select, _is_top_level: bool) {
if let Some(ref mut from) = select.from {
for expr in from.expressions.iter_mut() {
strip_with_from_expr(expr);
}
}
for join in select.joins.iter_mut() {
strip_with_from_expr(&mut join.this);
}
for expr in select.expressions.iter_mut() {
strip_with_from_expr(expr);
}
if let Some(ref mut w) = select.where_clause {
strip_with_from_expr(&mut w.this);
}
}
/// Removes WITH clauses reachable from `expr`.
///
/// * `Select`: its WITH is cleared and its children are processed.
/// * `Alias`: the aliased expression is processed.
/// * `Subquery`: the wrapped query is processed.
/// * Anything else is left untouched.
fn strip_with_from_expr(expr: &mut Expression) {
    match expr {
        Expression::Select(select) => {
            select.with = None;
            strip_nested_with_clauses(select, false);
        }
        Expression::Alias(alias) => strip_with_from_expr(&mut alias.this),
        Expression::Subquery(subquery) => strip_with_from_inner_query(&mut subquery.this),
        _ => {}
    }
}
/// Clears the WITH clause of a subquery body and of everything nested in it.
/// Only SELECT bodies are handled; other expressions are ignored.
fn strip_with_from_inner_query(expr: &mut Expression) {
    let Expression::Select(select) = expr else {
        return;
    };
    select.with = None;
    strip_nested_with_clauses(select, false);
}
/// Collects the CTEs declared on SELECTs nested inside `expr`.
///
/// `Subquery` and `Alias` nodes are looked through. For a `Select`, its own
/// CTEs are appended to `collected` unless `is_top_level` is set, and then
/// its FROM items, join targets, projections and WHERE condition are visited
/// in that order (always as non-top-level). `has_recursive` is set (never
/// cleared) when a collected WITH clause is recursive. All other expression
/// kinds are ignored.
fn collect_nested_ctes(
    expr: &Expression,
    collected: &mut Vec<crate::expressions::Cte>,
    has_recursive: &mut bool,
    is_top_level: bool,
) {
    let select = match expr {
        Expression::Subquery(subquery) => {
            return collect_nested_ctes(&subquery.this, collected, has_recursive, false);
        }
        Expression::Alias(alias) => {
            return collect_nested_ctes(&alias.this, collected, has_recursive, false);
        }
        Expression::Select(select) => select,
        _ => return,
    };

    if !is_top_level {
        if let Some(with) = select.with.as_ref() {
            *has_recursive |= with.recursive;
            collected.extend(with.ctes.iter().cloned());
        }
    }

    let from_items = select
        .from
        .iter()
        .flat_map(|from| from.expressions.iter());
    let join_targets = select.joins.iter().map(|join| &join.this);
    let projections = select.expressions.iter();
    let filter = select.where_clause.iter().map(|clause| &clause.this);

    for child in from_items
        .chain(join_targets)
        .chain(projections)
        .chain(filter)
    {
        collect_nested_ctes(child, collected, has_recursive, false);
    }
}
/// Removes a SELECT's named WINDOW clause by inlining the named
/// specifications into the projections that reference them.
///
/// Window names are matched case-insensitively (lower-cased); if two named
/// windows share a lower-cased name, the later one wins. Only the projection
/// list is rewritten. Non-SELECT expressions are returned unchanged.
pub fn eliminate_window_clause(expr: Expression) -> Result<Expression> {
    let Expression::Select(mut select) = expr else {
        return Ok(expr);
    };

    if let Some(named_windows) = select.windows.take() {
        let mut specs_by_name: std::collections::HashMap<String, &Over> =
            std::collections::HashMap::new();
        for window in named_windows.iter() {
            specs_by_name.insert(window.name.name.to_lowercase(), &window.spec);
        }

        let projections = std::mem::take(&mut select.expressions);
        select.expressions = projections
            .into_iter()
            .map(|item| inline_window_refs(item, &specs_by_name))
            .collect();
    }

    Ok(Expression::Select(select))
}
/// Replaces a reference to a named window with the window's specification.
///
/// For a `WindowFunction` whose `OVER` names a window present in
/// `window_map` (keyed by lower-cased name), each of PARTITION BY, ORDER BY
/// and frame is copied from the named specification only where the function
/// did not specify its own, and the name reference is then cleared. Unknown
/// names are left in place. `Alias` nodes are looked through; every other
/// expression is returned unchanged.
fn inline_window_refs(
    expr: Expression,
    window_map: &std::collections::HashMap<String, &Over>,
) -> Expression {
    match expr {
        Expression::Alias(mut alias) => {
            alias.this = inline_window_refs(alias.this, window_map);
            Expression::Alias(alias)
        }
        Expression::WindowFunction(mut wf) => {
            let named_spec = wf
                .over
                .window_name
                .as_ref()
                .and_then(|name| window_map.get(&name.name.to_lowercase()))
                .copied();

            if let Some(spec) = named_spec {
                let over = &mut wf.over;
                if over.partition_by.is_empty() && !spec.partition_by.is_empty() {
                    over.partition_by = spec.partition_by.clone();
                }
                if over.order_by.is_empty() && !spec.order_by.is_empty() {
                    over.order_by = spec.order_by.clone();
                }
                if over.frame.is_none() && spec.frame.is_some() {
                    over.frame = spec.frame.clone();
                }
                over.window_name = None;
            }
            Expression::WindowFunction(wf)
        }
        untouched => untouched,
    }
}
pub fn eliminate_join_marks(expr: Expression) -> Result<Expression> {
match expr {
Expression::Select(mut select) => {
let has_join_marks = select
.where_clause
.as_ref()
.map_or(false, |w| contains_join_mark(&w.this));
if !has_join_marks {
return Ok(Expression::Select(select));
}
let from_tables: Vec<String> = select
.from
.as_ref()
.map(|f| {
f.expressions
.iter()
.filter_map(|e| get_table_name(e))
.collect()
})
.unwrap_or_default();
let mut join_conditions: std::collections::HashMap<String, Vec<Expression>> =
std::collections::HashMap::new();
let mut remaining_conditions: Vec<Expression> = Vec::new();
if let Some(ref where_clause) = select.where_clause {
extract_join_mark_conditions(
&where_clause.this,
&mut join_conditions,
&mut remaining_conditions,
);
}
let mut new_joins = select.joins.clone();
for (table_name, conditions) in join_conditions {
let table_in_from = from_tables.contains(&table_name);
if table_in_from && !conditions.is_empty() {
let combined_condition = conditions.into_iter().reduce(|a, b| {
Expression::And(Box::new(BinaryOp {
left: a,
right: b,
left_comments: vec![],
operator_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
}))
});
if let Some(ref mut from) = select.from {
if let Some(pos) = from
.expressions
.iter()
.position(|e| get_table_name(e).map_or(false, |n| n == table_name))
{
if from.expressions.len() > 1 {
let join_table = from.expressions.remove(pos);
new_joins.push(crate::expressions::Join {
this: join_table,
kind: JoinKind::Left,
on: combined_condition,
using: vec![],
use_inner_keyword: false,
use_outer_keyword: true,
deferred_condition: false,
join_hint: None,
match_condition: None,
pivots: Vec::new(),
comments: Vec::new(),
nesting_group: 0,
directed: false,
});
}
}
}
}
}
select.joins = new_joins;
if remaining_conditions.is_empty() {
select.where_clause = None;
} else {
let combined = remaining_conditions.into_iter().reduce(|a, b| {
Expression::And(Box::new(BinaryOp {
left: a,
right: b,
left_comments: vec![],
operator_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
}))
});
select.where_clause = combined.map(|c| Where { this: c });
}
clear_join_marks(&mut Expression::Select(select.clone()));
Ok(Expression::Select(select))
}
other => Ok(other),
}
}
/// Reports whether `expr` contains a column whose `join_mark` flag is set.
///
/// Only columns, `NOT`, `AND`/`OR` and the six comparison operators are inspected.
/// Every other expression kind yields `false` without being descended into.
fn contains_join_mark(expr: &Expression) -> bool {
    match expr {
        Expression::Column(column) => column.join_mark,
        Expression::Not(negation) => contains_join_mark(&negation.this),
        Expression::And(binary)
        | Expression::Or(binary)
        | Expression::Eq(binary)
        | Expression::Neq(binary)
        | Expression::Lt(binary)
        | Expression::Lte(binary)
        | Expression::Gt(binary)
        | Expression::Gte(binary) => {
            // Left operand is checked first; the right one only if the left has no mark.
            if contains_join_mark(&binary.left) {
                true
            } else {
                contains_join_mark(&binary.right)
            }
        }
        _ => false,
    }
}
/// Returns the name a FROM item is addressed by.
///
/// For a plain table this is the table's own name; for an aliased item it is the alias.
/// Any other expression kind yields `None`.
fn get_table_name(expr: &Expression) -> Option<String> {
    let name = match expr {
        Expression::Table(table) => &table.name.name,
        Expression::Alias(aliased) => &aliased.alias.name,
        _ => return None,
    };
    Some(name.clone())
}
/// Splits a predicate into join-mark conditions and everything else.
///
/// `AND` nodes are flattened recursively (left operand before right). Each remaining leaf is
/// cloned into `join_conditions` under the table reported by [`get_join_mark_table`], or
/// into `remaining` when that helper returns `None`.
fn extract_join_mark_conditions(
    expr: &Expression,
    join_conditions: &mut std::collections::HashMap<String, Vec<Expression>>,
    remaining: &mut Vec<Expression>,
) {
    if let Expression::And(conjunction) = expr {
        for operand in [&conjunction.left, &conjunction.right] {
            extract_join_mark_conditions(operand, join_conditions, remaining);
        }
        return;
    }

    match get_join_mark_table(expr) {
        Some(table) => join_conditions
            .entry(table)
            .or_default()
            .push(expr.clone()),
        None => remaining.push(expr.clone()),
    }
}
/// Finds the table qualifier of the join-marked column in a comparison.
///
/// Only the direct operands of `=`, `<>`, `<`, `<=`, `>` and `>=` are examined, left before
/// right. The first operand that is a column with `join_mark` set decides the result: its
/// table name, or `None` if that column has no table qualifier. All other inputs give `None`.
fn get_join_mark_table(expr: &Expression) -> Option<String> {
    let comparison = match expr {
        Expression::Eq(op)
        | Expression::Neq(op)
        | Expression::Lt(op)
        | Expression::Lte(op)
        | Expression::Gt(op)
        | Expression::Gte(op) => op,
        _ => return None,
    };

    for operand in [&comparison.left, &comparison.right] {
        if let Expression::Column(column) = operand {
            if column.join_mark {
                // The first marked column wins, even when it is unqualified.
                return column.table.as_ref().map(|table| table.name.clone());
            }
        }
    }
    None
}
/// Resets the `join_mark` flag on every column reachable from `expr`, in place.
///
/// Traversal covers:
/// * a `SELECT`'s WHERE clause, its projection list, and the `ON` condition of each of its
///   joins (predicates that were moved out of WHERE into a join still carry the mark);
/// * `AND` / `OR` / `NOT`;
/// * the six comparison operators (`=`, `<>`, `<`, `<=`, `>`, `>=`).
///
/// Other expression kinds are left untouched and are not descended into. The set of
/// operators visited mirrors `contains_join_mark`, so anything that function can detect
/// this function can clear.
fn clear_join_marks(expr: &mut Expression) {
    match expr {
        Expression::Column(col) => col.join_mark = false,
        Expression::Select(select) => {
            if let Some(where_clause) = select.where_clause.as_mut() {
                clear_join_marks(&mut where_clause.this);
            }
            for sel_expr in &mut select.expressions {
                clear_join_marks(sel_expr);
            }
            // Join-mark predicates converted into explicit joins live in `join.on`.
            for join in &mut select.joins {
                if let Some(on) = join.on.as_mut() {
                    clear_join_marks(on);
                }
            }
        }
        Expression::Not(op) => clear_join_marks(&mut op.this),
        Expression::And(op)
        | Expression::Or(op)
        | Expression::Eq(op)
        | Expression::Neq(op)
        | Expression::Lt(op)
        | Expression::Lte(op)
        | Expression::Gt(op)
        | Expression::Gte(op) => {
            clear_join_marks(&mut op.left);
            clear_join_marks(&mut op.right);
        }
        _ => {}
    }
}
/// Gives explicit column lists to CTEs of a `WITH RECURSIVE` clause that lack one.
///
/// Applies only when `expr` is a `SELECT` whose WITH clause has `recursive` set. For each CTE
/// with an empty column list and a `SELECT` body, one name per projection is derived:
/// an alias contributes its alias, a column its own name, anything else a generated
/// `_c_<n>` identifier. The `_c_<n>` counter is shared by all CTEs of the clause and
/// starts at 1. Every other input is returned unchanged.
pub fn add_recursive_cte_column_names(expr: Expression) -> Result<Expression> {
    let Expression::Select(mut select) = expr else {
        return Ok(expr);
    };

    if let Some(with) = select.with.as_mut().filter(|with| with.recursive) {
        let mut generated = 0;
        for cte in with.ctes.iter_mut().filter(|cte| cte.columns.is_empty()) {
            let Expression::Select(body) = &cte.this else {
                continue;
            };
            let mut names: Vec<Identifier> = Vec::with_capacity(body.expressions.len());
            for projection in body.expressions.iter() {
                let name = match projection {
                    Expression::Alias(aliased) => aliased.alias.clone(),
                    Expression::Column(column) => column.name.clone(),
                    _ => {
                        generated += 1;
                        Identifier::new(format!("_c_{}", generated))
                    }
                };
                names.push(name);
            }
            cte.columns = names;
        }
    }

    Ok(Expression::Select(select))
}
pub fn epoch_cast_to_ts(expr: Expression) -> Result<Expression> {
match expr {
Expression::Cast(mut cast) => {
if let Expression::Literal(ref lit) = cast.this {
if let Literal::String(ref s) = lit.as_ref() {
if s.to_lowercase() == "epoch" {
if is_temporal_type(&cast.to) {
cast.this = Expression::Literal(Box::new(Literal::String(
"1970-01-01 00:00:00".to_string(),
)));
}
}
}
}
Ok(Expression::Cast(cast))
}
Expression::TryCast(mut try_cast) => {
if let Expression::Literal(ref lit) = try_cast.this {
if let Literal::String(ref s) = lit.as_ref() {
if s.to_lowercase() == "epoch" {
if is_temporal_type(&try_cast.to) {
try_cast.this = Expression::Literal(Box::new(Literal::String(
"1970-01-01 00:00:00".to_string(),
)));
}
}
}
}
Ok(Expression::TryCast(try_cast))
}
other => Ok(other),
}
}
/// Returns `true` for the `Date`, `Timestamp` and `Time` data types and `false` otherwise.
fn is_temporal_type(dt: &DataType) -> bool {
    match dt {
        DataType::Date | DataType::Timestamp { .. } | DataType::Time { .. } => true,
        _ => false,
    }
}
/// Coerces conditions to explicit boolean expressions.
///
/// The input first goes through [`ensure_bools_in_case`]. After that:
/// * a `SELECT` has its WHERE and HAVING predicates passed through
///   [`ensure_bool_condition`];
/// * a top-level `AND` / `OR` / `NOT` is itself passed through [`ensure_bool_condition`];
/// * anything else is returned as produced by the first step.
pub fn ensure_bools(expr: Expression) -> Result<Expression> {
    let rewritten = ensure_bools_in_case(expr);

    if matches!(
        rewritten,
        Expression::And(_) | Expression::Or(_) | Expression::Not(_)
    ) {
        return Ok(ensure_bool_condition(rewritten));
    }

    let Expression::Select(mut select) = rewritten else {
        return Ok(rewritten);
    };
    if let Some(filter) = select.where_clause.as_mut() {
        filter.this = ensure_bool_condition(filter.this.clone());
    }
    if let Some(having) = select.having.as_mut() {
        having.this = ensure_bool_condition(having.this.clone());
    }
    Ok(Expression::Select(select))
}
/// Applies boolean coercion to the `WHEN` conditions of `CASE` expressions.
///
/// Descends through `CASE` (conditions, results and `ELSE`), the projection list of a
/// `SELECT`, aliases and parentheses. Each `WHEN` condition is additionally passed through
/// [`ensure_bool_condition`]. Other expression kinds are returned as-is.
fn ensure_bools_in_case(expr: Expression) -> Expression {
    match expr {
        Expression::Case(mut case) => {
            let coerced_whens = case
                .whens
                .into_iter()
                .map(|(when, then)| {
                    // Condition is processed before its result, as nested CASEs may occur in both.
                    let when = ensure_bool_condition(ensure_bools_in_case(when));
                    (when, ensure_bools_in_case(then))
                })
                .collect();
            case.whens = coerced_whens;
            case.else_ = case.else_.map(ensure_bools_in_case);
            Expression::Case(case)
        }
        Expression::Select(mut select) => {
            let projections = select.expressions.into_iter().map(ensure_bools_in_case);
            select.expressions = projections.collect();
            Expression::Select(select)
        }
        Expression::Alias(mut aliased) => {
            aliased.this = ensure_bools_in_case(aliased.this);
            Expression::Alias(aliased)
        }
        Expression::Paren(mut wrapped) => {
            wrapped.this = ensure_bools_in_case(wrapped.this);
            Expression::Paren(wrapped)
        }
        untouched => untouched,
    }
}
/// Returns `true` when the top-level node of `expr` is one of the predicate kinds listed
/// below (comparisons, `IS` tests, pattern matches, membership/range tests, quantifiers and
/// logical connectives). Only the outermost node is inspected.
fn is_boolean_expression(expr: &Expression) -> bool {
    match expr {
        // Comparisons
        Expression::Eq(_)
        | Expression::Neq(_)
        | Expression::Lt(_)
        | Expression::Lte(_)
        | Expression::Gt(_)
        | Expression::Gte(_)
        | Expression::EqualNull(_) => true,
        // IS-style tests
        Expression::Is(_)
        | Expression::IsNull(_)
        | Expression::IsTrue(_)
        | Expression::IsFalse(_) => true,
        // Pattern matching
        Expression::Like(_)
        | Expression::ILike(_)
        | Expression::SimilarTo(_)
        | Expression::Glob(_)
        | Expression::RegexpLike(_) => true,
        // Membership, ranges and quantifiers
        Expression::In(_)
        | Expression::Between(_)
        | Expression::Exists(_)
        | Expression::Any(_)
        | Expression::All(_) => true,
        // Logical connectives
        Expression::And(_) | Expression::Or(_) | Expression::Not(_) => true,
        _ => false,
    }
}
/// Builds the comparison `expr <> 0`, with `0` as a numeric literal and no comments attached.
fn wrap_neq_zero(expr: Expression) -> Expression {
    let zero = Expression::Literal(Box::new(Literal::Number("0".to_string())));
    let comparison = BinaryOp {
        left: expr,
        right: zero,
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    };
    Expression::Neq(Box::new(comparison))
}
fn ensure_bool_condition(expr: Expression) -> Expression {
match expr {
Expression::And(op) => {
let new_op = BinaryOp {
left: ensure_bool_condition(op.left.clone()),
right: ensure_bool_condition(op.right.clone()),
left_comments: op.left_comments.clone(),
operator_comments: op.operator_comments.clone(),
trailing_comments: op.trailing_comments.clone(),
inferred_type: None,
};
Expression::And(Box::new(new_op))
}
Expression::Or(op) => {
let new_op = BinaryOp {
left: ensure_bool_condition(op.left.clone()),
right: ensure_bool_condition(op.right.clone()),
left_comments: op.left_comments.clone(),
operator_comments: op.operator_comments.clone(),
trailing_comments: op.trailing_comments.clone(),
inferred_type: None,
};
Expression::Or(Box::new(new_op))
}
Expression::Not(op) => Expression::Not(Box::new(crate::expressions::UnaryOp {
this: ensure_bool_condition(op.this.clone()),
inferred_type: None,
})),
Expression::Paren(paren) => Expression::Paren(Box::new(crate::expressions::Paren {
this: ensure_bool_condition(paren.this.clone()),
trailing_comments: paren.trailing_comments.clone(),
})),
Expression::Boolean(BooleanLiteral { value: true }) => {
Expression::Paren(Box::new(crate::expressions::Paren {
this: Expression::Eq(Box::new(BinaryOp {
left: Expression::Literal(Box::new(Literal::Number("1".to_string()))),
right: Expression::Literal(Box::new(Literal::Number("1".to_string()))),
left_comments: vec![],
operator_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
})),
trailing_comments: vec![],
}))
}
Expression::Boolean(BooleanLiteral { value: false }) => {
Expression::Paren(Box::new(crate::expressions::Paren {
this: Expression::Eq(Box::new(BinaryOp {
left: Expression::Literal(Box::new(Literal::Number("1".to_string()))),
right: Expression::Literal(Box::new(Literal::Number("0".to_string()))),
left_comments: vec![],
operator_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
})),
trailing_comments: vec![],
}))
}
ref e if is_boolean_expression(e) => expr,
_ => wrap_neq_zero(expr),
}
}
/// Removes table qualifiers from column references by delegating to
/// [`unqualify_columns_recursive`]. This wrapper always returns `Ok`.
pub fn unqualify_columns(expr: Expression) -> Result<Expression> {
    let unqualified = unqualify_columns_recursive(expr);
    Ok(unqualified)
}
fn unqualify_columns_recursive(expr: Expression) -> Expression {
match expr {
Expression::Column(mut col) => {
col.table = None;
Expression::Column(col)
}
Expression::Select(mut select) => {
select.expressions = select
.expressions
.into_iter()
.map(unqualify_columns_recursive)
.collect();
if let Some(ref mut where_clause) = select.where_clause {
where_clause.this = unqualify_columns_recursive(where_clause.this.clone());
}
if let Some(ref mut having) = select.having {
having.this = unqualify_columns_recursive(having.this.clone());
}
if let Some(ref mut group_by) = select.group_by {
group_by.expressions = group_by
.expressions
.iter()
.cloned()
.map(unqualify_columns_recursive)
.collect();
}
if let Some(ref mut order_by) = select.order_by {
order_by.expressions = order_by
.expressions
.iter()
.map(|o| crate::expressions::Ordered {
this: unqualify_columns_recursive(o.this.clone()),
desc: o.desc,
nulls_first: o.nulls_first,
explicit_asc: o.explicit_asc,
with_fill: o.with_fill.clone(),
})
.collect();
}
for join in &mut select.joins {
if let Some(ref mut on) = join.on {
*on = unqualify_columns_recursive(on.clone());
}
}
Expression::Select(select)
}
Expression::Alias(mut alias) => {
alias.this = unqualify_columns_recursive(alias.this);
Expression::Alias(alias)
}
Expression::And(op) => Expression::And(Box::new(unqualify_binary_op(*op))),
Expression::Or(op) => Expression::Or(Box::new(unqualify_binary_op(*op))),
Expression::Eq(op) => Expression::Eq(Box::new(unqualify_binary_op(*op))),
Expression::Neq(op) => Expression::Neq(Box::new(unqualify_binary_op(*op))),
Expression::Lt(op) => Expression::Lt(Box::new(unqualify_binary_op(*op))),
Expression::Lte(op) => Expression::Lte(Box::new(unqualify_binary_op(*op))),
Expression::Gt(op) => Expression::Gt(Box::new(unqualify_binary_op(*op))),
Expression::Gte(op) => Expression::Gte(Box::new(unqualify_binary_op(*op))),
Expression::Add(op) => Expression::Add(Box::new(unqualify_binary_op(*op))),
Expression::Sub(op) => Expression::Sub(Box::new(unqualify_binary_op(*op))),
Expression::Mul(op) => Expression::Mul(Box::new(unqualify_binary_op(*op))),
Expression::Div(op) => Expression::Div(Box::new(unqualify_binary_op(*op))),
Expression::Function(mut func) => {
func.args = func
.args
.into_iter()
.map(unqualify_columns_recursive)
.collect();
Expression::Function(func)
}
Expression::AggregateFunction(mut func) => {
func.args = func
.args
.into_iter()
.map(unqualify_columns_recursive)
.collect();
Expression::AggregateFunction(func)
}
Expression::Case(mut case) => {
case.whens = case
.whens
.into_iter()
.map(|(cond, result)| {
(
unqualify_columns_recursive(cond),
unqualify_columns_recursive(result),
)
})
.collect();
if let Some(else_expr) = case.else_ {
case.else_ = Some(unqualify_columns_recursive(else_expr));
}
Expression::Case(case)
}
other => other,
}
}
/// Runs [`unqualify_columns_recursive`] on the left and then the right operand of `op`,
/// leaving every other field of the operator as it was.
fn unqualify_binary_op(mut op: BinaryOp) -> BinaryOp {
    let left = unqualify_columns_recursive(op.left);
    let right = unqualify_columns_recursive(op.right);
    op.left = left;
    op.right = right;
    op
}
/// Replaces date-array unnests in a `SELECT` with references to generated recursive CTEs.
///
/// Work is done in this order, sharing one CTE counter:
/// 1. bodies of existing CTEs (via [`process_expression_for_gda`]),
/// 2. the items of the FROM clause,
/// 3. the right-hand side of each join.
///
/// Each successful conversion (see [`try_convert_generate_date_array`]) substitutes the
/// source item and yields a new CTE. If any CTE was produced, the WITH clause is created when
/// missing, marked `recursive`, and the new CTEs are placed *before* the existing ones.
/// Non-`SELECT` input is returned unchanged.
pub fn unnest_generate_date_array_using_recursive_cte(expr: Expression) -> Result<Expression> {
    let Expression::Select(mut select) = expr else {
        return Ok(expr);
    };

    let mut cte_count = 0;
    let mut generated: Vec<crate::expressions::Cte> = Vec::new();

    if let Some(with) = select.with.as_mut() {
        for cte in with.ctes.iter_mut() {
            process_expression_for_gda(&mut cte.this, &mut cte_count, &mut generated);
        }
    }

    if let Some(from) = select.from.as_mut() {
        for source in from.expressions.iter_mut() {
            let converted = try_convert_generate_date_array(source, &mut cte_count);
            if let Some((cte, replacement)) = converted {
                generated.push(cte);
                *source = replacement;
            }
        }
    }

    for join in select.joins.iter_mut() {
        let converted = try_convert_generate_date_array(&join.this, &mut cte_count);
        if let Some((cte, replacement)) = converted {
            generated.push(cte);
            join.this = replacement;
        }
    }

    if generated.is_empty() {
        return Ok(Expression::Select(select));
    }

    let with_clause = select.with.get_or_insert_with(|| crate::expressions::With {
        ctes: Vec::new(),
        recursive: true,
        leading_comments: vec![],
        search: None,
    });
    // A pre-existing WITH clause may not have been recursive yet.
    with_clause.recursive = true;
    let existing = std::mem::take(&mut with_clause.ctes);
    with_clause.ctes = generated.into_iter().chain(existing).collect();

    Ok(Expression::Select(select))
}
/// Converts date-array unnests found inside a query body, in place.
///
/// * `SELECT`: each FROM item, then each join target, is offered to
///   [`try_convert_generate_date_array`]; on success the item is replaced and the produced CTE
///   is appended to `new_ctes`.
/// * `UNION`: left side is processed, then right side.
/// * Subquery: its inner expression is processed.
/// * Anything else is left alone.
///
/// `cte_count` is forwarded to the converter, which uses it for CTE naming.
fn process_expression_for_gda(
    expr: &mut Expression,
    cte_count: &mut usize,
    new_ctes: &mut Vec<crate::expressions::Cte>,
) {
    match expr {
        Expression::Select(select) => {
            if let Some(from) = select.from.as_mut() {
                for source in from.expressions.iter_mut() {
                    let converted = try_convert_generate_date_array(source, cte_count);
                    if let Some((cte, replacement)) = converted {
                        new_ctes.push(cte);
                        *source = replacement;
                    }
                }
            }
            for join in select.joins.iter_mut() {
                let converted = try_convert_generate_date_array(&join.this, cte_count);
                if let Some((cte, replacement)) = converted {
                    new_ctes.push(cte);
                    join.this = replacement;
                }
            }
        }
        Expression::Union(set_op) => {
            process_expression_for_gda(&mut set_op.left, cte_count, new_ctes);
            process_expression_for_gda(&mut set_op.right, cte_count, new_ctes);
        }
        Expression::Subquery(subquery) => {
            process_expression_for_gda(&mut subquery.this, cte_count, new_ctes)
        }
        _ => {}
    }
}
/// Calls [`try_convert_generate_date_array_with_name`] with no column-name override.
fn try_convert_generate_date_array(
    expr: &Expression,
    cte_count: &mut usize,
) -> Option<(crate::expressions::Cte, Expression)> {
    let no_override: Option<&str> = None;
    try_convert_generate_date_array_with_name(expr, cte_count, no_override)
}
/// Tries to turn an unnested date-array generator into a recursive CTE plus a replacement
/// FROM item.
///
/// Recognised input:
/// * `Expression::Unnest` whose operand is a `GenerateDateArray`, a `GenerateSeries`, or a
///   generic `Function` named `GENERATE_DATE_ARRAY` (ASCII case-insensitive) with at least two
///   arguments;
/// * `Expression::Alias` wrapping such an unnest (handled by recursing on the aliased
///   expression).
///
/// On success returns `(cte, replacement)`:
/// * the CTE is named `_generated_dates` when `*cte_count` is 0, otherwise
///   `_generated_dates_<count>`; `*cte_count` is then incremented;
/// * its single column is named `column_name_override`, or `date_value` when no override is
///   given;
/// * its body is `SELECT CAST(start AS DATE) AS col UNION ALL
///   SELECT CAST(DATE_ADD(col, n, unit) AS DATE) FROM cte WHERE <same cast> <= CAST(end AS DATE)`;
/// * the replacement is the subquery `(SELECT col FROM cte)` aliased with the CTE name.
///
/// Returns `None` for any other input.
fn try_convert_generate_date_array_with_name(
expr: &Expression,
cte_count: &mut usize,
column_name_override: Option<&str>,
) -> Option<(crate::expressions::Cte, Expression)> {
// Pulls (start, end, optional step) out of the supported generator node kinds.
fn extract_gda_args(
inner: &Expression,
) -> Option<(&Expression, &Expression, Option<&Expression>)> {
match inner {
Expression::GenerateDateArray(gda) => {
let start = gda.start.as_ref()?;
let end = gda.end.as_ref()?;
let step = gda.step.as_deref();
Some((start, end, step))
}
Expression::GenerateSeries(gs) => {
let start = gs.start.as_deref()?;
let end = gs.end.as_deref()?;
let step = gs.step.as_deref();
Some((start, end, step))
}
Expression::Function(f) if f.name.eq_ignore_ascii_case("GENERATE_DATE_ARRAY") => {
// Positional form: (start, end[, step]).
if f.args.len() >= 2 {
let start = &f.args[0];
let end = &f.args[1];
let step = f.args.get(2);
Some((start, end, step))
} else {
None
}
}
_ => None,
}
}
if let Expression::Unnest(unnest) = expr {
if let Some((start, end, step_opt)) = extract_gda_args(&unnest.this) {
let start = start;
let end = end;
let step: Option<&Expression> = step_opt;
// The first generated CTE gets the bare name; later ones get the counter as a suffix.
let cte_name = if *cte_count == 0 {
"_generated_dates".to_string()
} else {
format!("_generated_dates_{}", cte_count)
};
*cte_count += 1;
let column_name =
Identifier::new(column_name_override.unwrap_or("date_value").to_string());
// Produces a DATE-typed version of a bound: a date literal becomes
// CAST('<text>' AS DATE), an existing cast to DATE is reused as-is, and anything else is
// wrapped in CAST(... AS DATE).
let cast_to_date = |expr: &Expression| -> Expression {
match expr {
Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Date(_)) => {
// The guard above already established this is a date literal; the fallbacks below
// exist only to satisfy the pattern matches.
if let Expression::Literal(lit) = expr {
if let Literal::Date(d) = lit.as_ref() {
Expression::Cast(Box::new(Cast {
this: Expression::Literal(Box::new(Literal::String(d.clone()))),
to: DataType::Date,
trailing_comments: vec![],
double_colon_syntax: false,
format: None,
default: None,
inferred_type: None,
}))
} else {
expr.clone()
}
} else {
unreachable!()
}
}
Expression::Cast(c) if matches!(c.to, DataType::Date) => expr.clone(),
_ => Expression::Cast(Box::new(Cast {
this: expr.clone(),
to: DataType::Date,
trailing_comments: vec![],
double_colon_syntax: false,
format: None,
default: None,
inferred_type: None,
})),
}
};
// Anchor member of the recursive CTE: SELECT CAST(start AS DATE) AS <column>.
let base_select = Select {
expressions: vec![Expression::Alias(Box::new(crate::expressions::Alias {
this: cast_to_date(start),
alias: column_name.clone(),
column_aliases: vec![],
pre_alias_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
}))],
..Select::new()
};
// If the interval's value is a string literal that parses as a number, re-tag it as a
// numeric literal; non-interval steps are cloned unchanged.
let normalize_interval = |expr: &Expression| -> Expression {
if let Expression::Interval(ref iv) = expr {
let mut iv_clone = iv.as_ref().clone();
if let Some(Expression::Literal(ref lit)) = iv_clone.this {
if let Literal::String(ref s) = lit.as_ref() {
if s.parse::<f64>().is_ok() {
iv_clone.this =
Some(Expression::Literal(Box::new(Literal::Number(s.clone()))));
}
}
}
Expression::Interval(Box::new(iv_clone))
} else {
expr.clone()
}
};
// A missing step defaults to an interval of 1 DAY.
let normalized_step = step.map(|s| normalize_interval(s)).unwrap_or_else(|| {
Expression::Interval(Box::new(crate::expressions::Interval {
this: Some(Expression::Literal(Box::new(Literal::Number(
"1".to_string(),
)))),
unit: Some(crate::expressions::IntervalUnitSpec::Simple {
unit: crate::expressions::IntervalUnit::Day,
use_plural: false,
}),
}))
});
let (add_unit, add_count) = extract_interval_unit_and_count(&normalized_step);
// DATE_ADD(<column>, <count>, <unit>) over the CTE's own column.
let date_add_expr = Expression::DateAdd(Box::new(crate::expressions::DateAddFunc {
this: Expression::Column(Box::new(crate::expressions::Column {
name: column_name.clone(),
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
})),
interval: add_count,
unit: add_unit,
}));
let cast_date_add = Expression::Cast(Box::new(Cast {
this: date_add_expr.clone(),
to: DataType::Date,
trailing_comments: vec![],
double_colon_syntax: false,
format: None,
default: None,
inferred_type: None,
}));
// Recursive member: SELECT CAST(DATE_ADD(...) AS DATE) FROM <cte>
// WHERE CAST(DATE_ADD(...) AS DATE) <= CAST(end AS DATE).
let recursive_select = Select {
expressions: vec![cast_date_add.clone()],
from: Some(From {
expressions: vec![Expression::Table(Box::new(
crate::expressions::TableRef::new(&cte_name),
))],
}),
where_clause: Some(Where {
this: Expression::Lte(Box::new(BinaryOp {
left: cast_date_add,
right: cast_to_date(end),
left_comments: vec![],
operator_comments: vec![],
trailing_comments: vec![],
inferred_type: None,
})),
}),
..Select::new()
};
// Anchor UNION ALL recursive member.
let union = crate::expressions::Union {
left: Expression::Select(Box::new(base_select)),
right: Expression::Select(Box::new(recursive_select)),
all: true, distinct: false,
with: None,
order_by: None,
limit: None,
offset: None,
distribute_by: None,
sort_by: None,
cluster_by: None,
by_name: false,
side: None,
kind: None,
corresponding: false,
strict: false,
on_columns: Vec::new(),
};
let cte = crate::expressions::Cte {
this: Expression::Union(Box::new(union)),
alias: Identifier::new(cte_name.clone()),
columns: vec![column_name.clone()],
materialized: None,
key_expressions: Vec::new(),
alias_first: true,
comments: Vec::new(),
};
// What takes the unnest's place in the query: (SELECT <column> FROM <cte>) AS <cte>.
let replacement_select = Select {
expressions: vec![Expression::Column(Box::new(crate::expressions::Column {
name: column_name,
table: None,
join_mark: false,
trailing_comments: vec![],
span: None,
inferred_type: None,
}))],
from: Some(From {
expressions: vec![Expression::Table(Box::new(
crate::expressions::TableRef::new(&cte_name),
))],
}),
..Select::new()
};
let replacement = Expression::Subquery(Box::new(Subquery {
this: Expression::Select(Box::new(replacement_select)),
alias: Some(Identifier::new(cte_name)),
column_aliases: vec![],
order_by: None,
limit: None,
offset: None,
distribute_by: None,
sort_by: None,
cluster_by: None,
lateral: false,
modifiers_inside: false,
trailing_comments: vec![],
inferred_type: None,
}));
return Some((cte, replacement));
}
}
if let Expression::Alias(alias) = expr {
// The first column alias, when present, becomes the generated column's name.
let col_name = alias.column_aliases.first().map(|id| id.name.as_str());
if let Some((cte, replacement)) =
try_convert_generate_date_array_with_name(&alias.this, cte_count, col_name)
{
// With a column alias the replacement is returned bare: the original alias node is
// not re-applied, so the subquery keeps the CTE name as its alias.
if col_name.is_some() {
return Some((cte, replacement));
}
// Otherwise re-wrap the replacement in the caller's alias (comments preserved).
let new_alias = Expression::Alias(Box::new(crate::expressions::Alias {
this: replacement,
alias: alias.alias.clone(),
column_aliases: alias.column_aliases.clone(),
pre_alias_comments: alias.pre_alias_comments.clone(),
trailing_comments: alias.trailing_comments.clone(),
inferred_type: None,
}));
return Some((cte, new_alias));
}
}
None
}
/// Splits an interval expression into its unit and its count.
///
/// Resolution order:
/// 1. Not an interval: `(Day, 1)`.
/// 2. Interval with a `Simple` unit spec: that unit, with the interval's value as count
///    (or `1` when the interval has no value).
/// 3. Interval whose value is a string or number literal of the form `"<count> <unit>"`
///    (split at the first whitespace): the count as a numeric literal and the unit parsed
///    from its upper-cased name; unknown unit names map to `Day`.
/// 4. Same literal kinds where the whole text parses as `f64`: `(Day, <text>)`.
/// 5. Otherwise `(Day, 1)`.
fn extract_interval_unit_and_count(
    expr: &Expression,
) -> (crate::expressions::IntervalUnit, Expression) {
    use crate::expressions::{IntervalUnit, IntervalUnitSpec, Literal};

    /// Wraps `text` as a numeric literal expression.
    fn number(text: &str) -> Expression {
        Expression::Literal(Box::new(Literal::Number(text.to_string())))
    }

    let one_day = || (IntervalUnit::Day, number("1"));

    let Expression::Interval(interval) = expr else {
        return one_day();
    };

    if let Some(IntervalUnitSpec::Simple { unit, .. }) = &interval.unit {
        let count = interval.this.clone().unwrap_or_else(|| number("1"));
        return (unit.clone(), count);
    }

    // No usable unit spec: try to read both count and unit out of the literal's text.
    let text = match &interval.this {
        Some(Expression::Literal(lit)) => match lit.as_ref() {
            Literal::String(s) | Literal::Number(s) => s,
            _ => return one_day(),
        },
        _ => return one_day(),
    };

    let mut pieces = text.trim().splitn(2, char::is_whitespace);
    if let (Some(count_part), Some(unit_part)) = (pieces.next(), pieces.next()) {
        let unit = match unit_part.trim().to_uppercase().as_str() {
            "YEAR" | "YEARS" => IntervalUnit::Year,
            "QUARTER" | "QUARTERS" => IntervalUnit::Quarter,
            "MONTH" | "MONTHS" => IntervalUnit::Month,
            "WEEK" | "WEEKS" => IntervalUnit::Week,
            "DAY" | "DAYS" => IntervalUnit::Day,
            "HOUR" | "HOURS" => IntervalUnit::Hour,
            "MINUTE" | "MINUTES" => IntervalUnit::Minute,
            "SECOND" | "SECONDS" => IntervalUnit::Second,
            "MILLISECOND" | "MILLISECONDS" => IntervalUnit::Millisecond,
            "MICROSECOND" | "MICROSECONDS" => IntervalUnit::Microsecond,
            _ => IntervalUnit::Day,
        };
        return (unit, number(count_part.trim()));
    }

    // The untrimmed text is parsed here, so surrounding whitespace makes this check fail.
    if text.parse::<f64>().is_ok() {
        return (IntervalUnit::Day, number(text));
    }
    one_day()
}
/// Rewrites `a ILIKE b` as `LOWER(a) LIKE LOWER(b)`.
///
/// The `escape` and `quantifier` of the original operator are carried over to the `LIKE`.
/// Expressions other than `ILike` are returned unchanged.
pub fn no_ilike_sql(expr: Expression) -> Result<Expression> {
    /// Wraps `operand` in a call to `LOWER`.
    fn lower(operand: Expression) -> Expression {
        Expression::Function(Box::new(crate::expressions::Function {
            name: "LOWER".to_string(),
            args: vec![operand],
            distinct: false,
            trailing_comments: vec![],
            use_bracket_syntax: false,
            no_parens: false,
            quoted: false,
            span: None,
            inferred_type: None,
        }))
    }

    let Expression::ILike(ilike) = expr else {
        return Ok(expr);
    };

    let like = crate::expressions::LikeOp {
        left: lower(ilike.left),
        right: lower(ilike.right),
        escape: ilike.escape,
        quantifier: ilike.quantifier,
        inferred_type: None,
    };
    Ok(Expression::Like(Box::new(like)))
}
/// Turns a `TryCast` node into a plain `Cast` carrying the same payload; other expressions
/// are returned unchanged.
pub fn no_trycast_sql(expr: Expression) -> Result<Expression> {
    let Expression::TryCast(payload) = expr else {
        return Ok(expr);
    };
    Ok(Expression::Cast(payload))
}
/// Turns a `SafeCast` node into a plain `Cast` carrying the same payload; other expressions
/// are returned unchanged.
pub fn no_safe_cast_sql(expr: Expression) -> Result<Expression> {
    let Expression::SafeCast(payload) = expr else {
        return Ok(expr);
    };
    Ok(Expression::Cast(payload))
}
/// Currently a no-op: hands the expression back untouched.
pub fn no_comment_column_constraint(expr: Expression) -> Result<Expression> {
    let unchanged = expr;
    Ok(unchanged)
}
/// Wraps series generators in `UNNEST`.
///
/// * A table reference whose name upper-cases to `GENERATE_SERIES` becomes
///   `UNNEST(<table>) AS _u`.
/// * A `GenerateSeries` node becomes `UNNEST(<series>)` with no alias.
/// * Everything else is returned unchanged.
pub fn unnest_generate_series(expr: Expression) -> Result<Expression> {
    /// Builds a bare `UNNEST(inner)` with no extra expressions, ordinality or aliases.
    fn unnest_of(inner: Expression) -> Expression {
        Expression::Unnest(Box::new(UnnestFunc {
            this: inner,
            expressions: Vec::new(),
            with_ordinality: false,
            alias: None,
            offset_alias: None,
        }))
    }

    let is_series_table = matches!(
        &expr,
        Expression::Table(table) if table.name.name.to_uppercase() == "GENERATE_SERIES"
    );
    if is_series_table {
        let aliased = crate::expressions::Alias {
            this: unnest_of(expr),
            alias: Identifier::new("_u".to_string()),
            column_aliases: vec![],
            pre_alias_comments: vec![],
            trailing_comments: vec![],
            inferred_type: None,
        };
        return Ok(Expression::Alias(Box::new(aliased)));
    }

    if matches!(expr, Expression::GenerateSeries(_)) {
        return Ok(unnest_of(expr));
    }
    Ok(expr)
}
pub fn unwrap_unnest_generate_series_for_postgres(expr: Expression) -> Result<Expression> {
use crate::dialects::transform_recursive;
transform_recursive(expr, &unwrap_unnest_generate_series_single)
}
/// Replaces `UNNEST(GENERATE_SERIES(...))` sources of a single `SELECT`.
///
/// Each FROM item and each join target is offered to [`try_unwrap_unnest_gen_series`]; when
/// that returns a replacement, the item is swapped for it. Non-`SELECT` input is returned
/// unchanged.
fn unwrap_unnest_generate_series_single(expr: Expression) -> Result<Expression> {
    use crate::expressions::*;

    let Expression::Select(mut select) = expr else {
        return Ok(expr);
    };

    if let Some(from) = select.from.as_mut() {
        for source in from.expressions.iter_mut() {
            if let Some(unwrapped) = try_unwrap_unnest_gen_series(source) {
                *source = unwrapped;
            }
        }
    }
    for join in select.joins.iter_mut() {
        let candidate = try_unwrap_unnest_gen_series(&join.this);
        if let Some(unwrapped) = candidate {
            join.this = unwrapped;
        }
    }

    Ok(Expression::Select(select))
}
/// Builds a subquery replacement for `UNNEST(GENERATE_SERIES(...))`.
///
/// Accepts either the bare unnest or an alias wrapping it; in both cases the unnest's operand
/// must be a `GenerateSeries` node, otherwise `None` is returned. The result has the shape
///
/// ```sql
/// (SELECT CAST(value AS DATE) FROM <generate_series> AS _t(value)) AS _unnested_generate_series
/// ```
///
/// Any alias on the input is not carried over to the result.
fn try_unwrap_unnest_gen_series(expr: &Expression) -> Option<Expression> {
    use crate::expressions::*;

    // Look through an optional alias to find the unnest node.
    let unnest = match expr {
        Expression::Unnest(unnest) => unnest,
        Expression::Alias(aliased) => match &aliased.this {
            Expression::Unnest(unnest) => unnest,
            _ => return None,
        },
        _ => return None,
    };
    let Expression::GenerateSeries(series) = &unnest.this else {
        return None;
    };
    let series = series.as_ref().clone();

    // Inner FROM item: <generate_series> AS _t(value)
    let series_source = Expression::Alias(Box::new(Alias {
        this: Expression::GenerateSeries(Box::new(series)),
        alias: Identifier::new("_t".to_string()),
        column_aliases: vec![Identifier::new("value".to_string())],
        pre_alias_comments: vec![],
        trailing_comments: vec![],
        inferred_type: None,
    }));

    // Projection: CAST(value AS DATE)
    let value_as_date = Expression::Cast(Box::new(Cast {
        this: Expression::boxed_column(Column {
            name: Identifier::new("value".to_string()),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        }),
        to: DataType::Date,
        trailing_comments: vec![],
        double_colon_syntax: false,
        format: None,
        default: None,
        inferred_type: None,
    }));

    let mut inner = Select::new();
    inner.expressions = vec![value_as_date];
    inner.from = Some(From {
        expressions: vec![series_source],
    });

    let subquery = Expression::Subquery(Box::new(Subquery {
        this: Expression::Select(Box::new(inner)),
        alias: None,
        column_aliases: vec![],
        order_by: None,
        limit: None,
        offset: None,
        distribute_by: None,
        sort_by: None,
        cluster_by: None,
        lateral: false,
        modifiers_inside: false,
        trailing_comments: vec![],
        inferred_type: None,
    }));

    let wrapped = Alias {
        this: subquery,
        alias: Identifier::new("_unnested_generate_series".to_string()),
        column_aliases: vec![],
        pre_alias_comments: vec![],
        trailing_comments: vec![],
        inferred_type: None,
    };
    Some(Expression::Alias(Box::new(wrapped)))
}
/// Expands `BETWEEN` predicates in the WHERE clause of a `DELETE` statement using
/// [`expand_between_recursive`]. Other expressions, and deletes without a WHERE clause, are
/// returned unchanged.
pub fn expand_between_in_delete(expr: Expression) -> Result<Expression> {
    let Expression::Delete(mut delete) = expr else {
        return Ok(expr);
    };
    if let Some(filter) = delete.where_clause.as_mut() {
        filter.this = expand_between_recursive(filter.this.clone());
    }
    Ok(Expression::Delete(delete))
}
/// Rewrites `BETWEEN` into explicit comparisons.
///
/// * `x BETWEEN lo AND hi` becomes `x >= lo AND x <= hi`.
/// * `x NOT BETWEEN lo AND hi` becomes `x < lo OR x > hi`.
///
/// The subject and both bounds are expanded recursively first. The rewrite also descends
/// through `AND`, `OR`, `NOT` and parentheses; all other expressions are returned unchanged.
fn expand_between_recursive(expr: Expression) -> Expression {
    match expr {
        Expression::Between(between) => {
            let negated = between.not;
            let subject = expand_between_recursive(between.this);
            let lower = expand_between_recursive(between.low);
            let upper = expand_between_recursive(between.high);

            // The subject appears in both comparisons, hence the single clone.
            if negated {
                let below = Expression::Lt(Box::new(BinaryOp::new(subject.clone(), lower)));
                let above = Expression::Gt(Box::new(BinaryOp::new(subject, upper)));
                Expression::Or(Box::new(BinaryOp::new(below, above)))
            } else {
                let at_least = Expression::Gte(Box::new(BinaryOp::new(subject.clone(), lower)));
                let at_most = Expression::Lte(Box::new(BinaryOp::new(subject, upper)));
                Expression::And(Box::new(BinaryOp::new(at_least, at_most)))
            }
        }
        Expression::And(mut conjunction) => {
            conjunction.left = expand_between_recursive(conjunction.left);
            conjunction.right = expand_between_recursive(conjunction.right);
            Expression::And(conjunction)
        }
        Expression::Or(mut disjunction) => {
            disjunction.left = expand_between_recursive(disjunction.left);
            disjunction.right = expand_between_recursive(disjunction.right);
            Expression::Or(disjunction)
        }
        Expression::Not(mut negation) => {
            negation.this = expand_between_recursive(negation.this);
            Expression::Not(negation)
        }
        Expression::Paren(mut wrapped) => {
            wrapped.this = expand_between_recursive(wrapped.this);
            Expression::Paren(wrapped)
        }
        untouched => untouched,
    }
}
/// Moves CTE column lists into the CTE bodies as projection aliases.
///
/// For every CTE of a `SELECT`'s WITH clause that has a non-empty column list:
/// * if the body is a `SELECT` whose only projection is `*`, the list is simply cleared;
/// * otherwise the list is removed from the CTE and, when the body is a `SELECT`, applied
///   positionally to its projections: an existing alias is renamed, any other projection
///   is wrapped in a new alias. Projections beyond the end of the list, and positions
///   whose column name is empty, are left as they are; surplus names are dropped.
///
/// The column list is emptied even when the body is not a `SELECT`, in which case the names
/// are discarded. Non-`SELECT` input is returned unchanged.
pub fn pushdown_cte_column_names(expr: Expression) -> Result<Expression> {
    let Expression::Select(mut select) = expr else {
        return Ok(expr);
    };

    if let Some(with) = select.with.as_mut() {
        for cte in with.ctes.iter_mut().filter(|cte| !cte.columns.is_empty()) {
            let selects_only_star = match &cte.this {
                Expression::Select(body) => {
                    matches!(body.expressions.as_slice(), [Expression::Star(_)])
                }
                _ => false,
            };
            if selects_only_star {
                cte.columns.clear();
                continue;
            }

            // Taken before the body check, so the list is emptied whatever the body is.
            let mut names = std::mem::take(&mut cte.columns).into_iter();
            let Expression::Select(body) = &mut cte.this else {
                continue;
            };

            let projections = std::mem::take(&mut body.expressions);
            body.expressions = projections
                .into_iter()
                .map(|projection| match names.next() {
                    Some(name) if !name.name.is_empty() => match projection {
                        Expression::Alias(mut aliased) => {
                            aliased.alias = name;
                            Expression::Alias(aliased)
                        }
                        bare => Expression::Alias(Box::new(crate::expressions::Alias {
                            this: bare,
                            alias: name,
                            column_aliases: Vec::new(),
                            pre_alias_comments: Vec::new(),
                            trailing_comments: Vec::new(),
                            inferred_type: None,
                        })),
                    },
                    _ => projection,
                })
                .collect();
        }
    }

    Ok(Expression::Select(select))
}
pub fn simplify_nested_paren_values(expr: Expression) -> Result<Expression> {
match expr {
Expression::Select(mut select) => {
if let Some(ref mut from) = select.from {
for from_item in from.expressions.iter_mut() {
simplify_paren_values_in_from(from_item);
}
}
Ok(Expression::Select(select))
}
other => Ok(other),
}
}
fn simplify_paren_values_in_from(expr: &mut Expression) {
let replacement = match expr {
Expression::Subquery(ref subquery) => {
if let Expression::Paren(ref paren) = subquery.this {
if matches!(&paren.this, Expression::Values(_)) {
let mut new_sub = subquery.as_ref().clone();
new_sub.this = paren.this.clone();
Some(Expression::Subquery(Box::new(new_sub)))
} else {
None
}
} else {
None
}
}
Expression::Paren(ref outer_paren) => {
if let Expression::Subquery(ref subquery) = outer_paren.this {
if matches!(&subquery.this, Expression::Values(_)) {
Some(outer_paren.this.clone())
}
else if let Expression::Paren(ref paren) = subquery.this {
if matches!(&paren.this, Expression::Values(_)) {
let mut new_sub = subquery.as_ref().clone();
new_sub.this = paren.this.clone();
Some(Expression::Subquery(Box::new(new_sub)))
} else {
None
}
} else {
None
}
} else if let Expression::Paren(ref inner_paren) = outer_paren.this {
if matches!(&inner_paren.this, Expression::Values(_)) {
Some(outer_paren.this.clone())
} else {
None
}
} else {
None
}
}
_ => None,
};
if let Some(new_expr) = replacement {
*expr = new_expr;
}
}
/// Fills in missing table aliases on the FROM items of a `SELECT`.
///
/// Each item is passed to [`add_auto_alias_to_from_item`] with a counter that starts at 0
/// for this `SELECT`. Any other expression is returned unchanged.
pub fn add_auto_table_alias(expr: Expression) -> Result<Expression> {
    let Expression::Select(mut select) = expr else {
        return Ok(expr);
    };
    if let Some(from) = select.from.as_mut() {
        let mut next_index = 0usize;
        for item in from.expressions.iter_mut() {
            add_auto_alias_to_from_item(item, &mut next_index);
        }
    }
    Ok(Expression::Select(select))
}
/// Names an alias node that has column aliases but an empty alias name.
///
/// Such a node gets the alias `_t<counter>` and the counter is then incremented. Alias nodes
/// that already have a name, or that have no column aliases, and all non-alias expressions
/// are left untouched.
fn add_auto_alias_to_from_item(expr: &mut Expression, counter: &mut usize) {
    use crate::expressions::Identifier;

    let Expression::Alias(alias) = expr else {
        return;
    };
    let unnamed_with_columns = alias.alias.name.is_empty() && !alias.column_aliases.is_empty();
    if unnamed_with_columns {
        alias.alias = Identifier::new(format!("_t{}", counter));
        *counter += 1;
    }
}
pub fn propagate_struct_field_names(expr: Expression) -> Result<Expression> {
use crate::dialects::transform_recursive;
transform_recursive(expr, &propagate_struct_names_in_expr)
}
/// Copies field names from the first `STRUCT(...)` of an array to the unnamed structs after it.
///
/// Applies to `Array` and `ArrayFunc` nodes with at least two elements whose first element is
/// a `STRUCT` function call (name compared ASCII case-insensitively) with at least one
/// aliased argument. Each later element that is a `STRUCT` call with the same number of
/// arguments, none of them aliased, is rebuilt with its arguments aliased positionally
/// after the first struct; positions that are unnamed in the first struct stay unaliased.
/// All other elements, and all other expressions, are kept as they are.
fn propagate_struct_names_in_expr(expr: Expression) -> Result<Expression> {
    use crate::expressions::{Alias, ArrayConstructor, Function, Identifier};

    /// Returns the rewritten element list, or `None` when the list does not qualify.
    fn propagate_in_elements(elements: &[Expression]) -> Option<Vec<Expression>> {
        let (head, tail) = elements.split_first()?;
        if tail.is_empty() {
            return None;
        }
        let Expression::Function(template) = head else {
            return None;
        };
        if !template.name.eq_ignore_ascii_case("STRUCT") {
            return None;
        }

        let field_names: Vec<Option<String>> = template
            .args
            .iter()
            .map(|arg| match arg {
                Expression::Alias(aliased) => Some(aliased.alias.name.clone()),
                _ => None,
            })
            .collect();
        if field_names.iter().all(Option::is_none) {
            return None;
        }

        let mut rewritten = Vec::with_capacity(elements.len());
        rewritten.push(head.clone());
        for element in tail {
            let updated = match element {
                Expression::Function(candidate)
                    if candidate.name.eq_ignore_ascii_case("STRUCT")
                        && candidate.args.len() == field_names.len()
                        && !candidate
                            .args
                            .iter()
                            .any(|arg| matches!(arg, Expression::Alias(_))) =>
                {
                    let named_args: Vec<Expression> = candidate
                        .args
                        .iter()
                        .zip(field_names.iter())
                        .map(|(value, name)| match name {
                            Some(field) => Expression::Alias(Box::new(Alias::new(
                                value.clone(),
                                Identifier::new(field.clone()),
                            ))),
                            None => value.clone(),
                        })
                        .collect();
                    Expression::Function(Box::new(Function::new(
                        "STRUCT".to_string(),
                        named_args,
                    )))
                }
                _ => element.clone(),
            };
            rewritten.push(updated);
        }
        Some(rewritten)
    }

    let replacement = match &expr {
        Expression::Array(array) => propagate_in_elements(&array.expressions).map(|elements| {
            Expression::Array(Box::new(crate::expressions::Array {
                expressions: elements,
            }))
        }),
        Expression::ArrayFunc(array) => {
            propagate_in_elements(&array.expressions).map(|elements| {
                Expression::ArrayFunc(Box::new(ArrayConstructor {
                    expressions: elements,
                    bracket_notation: array.bracket_notation,
                    use_list_keyword: array.use_list_keyword,
                }))
            })
        }
        _ => None,
    };

    match replacement {
        Some(updated) => Ok(updated),
        None => Ok(expr),
    }
}
/// Applies `unnest_alias_transform_single_select` to every node of `expr` via
/// `transform_recursive`.
pub fn unnest_alias_to_column_alias(expr: Expression) -> Result<Expression> {
    use crate::dialects::transform_recursive;
    let visitor = &unnest_alias_transform_single_select;
    transform_recursive(expr, visitor)
}
/// Applies `unnest_from_to_cross_join_single_select` to every node of `expr`
/// via `transform_recursive`.
pub fn unnest_from_to_cross_join(expr: Expression) -> Result<Expression> {
    use crate::dialects::transform_recursive;
    let visitor = &unnest_from_to_cross_join_single_select;
    transform_recursive(expr, visitor)
}
/// Moves every non-leading `UNNEST` item of a comma-separated FROM list into
/// a `CROSS JOIN`.
///
/// The first FROM item always stays in place. The new cross joins are placed
/// before the joins the `SELECT` already had. Non-`SELECT` expressions and
/// FROM lists with fewer than two items are returned unchanged.
fn unnest_from_to_cross_join_single_select(expr: Expression) -> Result<Expression> {
    let Expression::Select(mut select) = expr else {
        return Ok(expr);
    };
    if let Some(from) = select.from.as_mut() {
        if from.expressions.len() > 1 {
            let mut kept = Vec::new();
            let mut unnest_joins = Vec::new();
            let items = std::mem::take(&mut from.expressions);
            for (position, item) in items.into_iter().enumerate() {
                let is_trailing_unnest = position > 0
                    && match &item {
                        Expression::Unnest(_) => true,
                        Expression::Alias(aliased) => {
                            matches!(aliased.this, Expression::Unnest(_))
                        }
                        _ => false,
                    };
                if !is_trailing_unnest {
                    kept.push(item);
                    continue;
                }
                unnest_joins.push(crate::expressions::Join {
                    this: item,
                    on: None,
                    using: Vec::new(),
                    kind: JoinKind::Cross,
                    use_inner_keyword: false,
                    use_outer_keyword: false,
                    deferred_condition: false,
                    join_hint: None,
                    match_condition: None,
                    pivots: Vec::new(),
                    comments: Vec::new(),
                    nesting_group: 0,
                    directed: false,
                });
            }
            from.expressions = kept;
            // Existing joins follow the newly created cross joins.
            unnest_joins.append(&mut select.joins);
            select.joins = unnest_joins;
        }
    }
    Ok(Expression::Select(select))
}
/// Applies `wrap_unnest_join_aliases_single` to every node of `expr` via
/// `transform_recursive`.
pub fn wrap_unnest_join_aliases(expr: Expression) -> Result<Expression> {
    use crate::dialects::transform_recursive;
    let visitor = &wrap_unnest_join_aliases_single;
    transform_recursive(expr, visitor)
}
/// Runs `wrap_unnest_alias_in_join_item` on the target of every join of a
/// `SELECT`; other expressions are returned unchanged.
fn wrap_unnest_join_aliases_single(expr: Expression) -> Result<Expression> {
    let Expression::Select(mut select) = expr else {
        return Ok(expr);
    };
    select
        .joins
        .iter_mut()
        .for_each(|join| wrap_unnest_alias_in_join_item(&mut join.this));
    Ok(Expression::Select(select))
}
/// Rewrites `UNNEST(...) AS name` into `UNNEST(...) AS _u(name)`.
///
/// Applies only to an alias that wraps a function call named `UNNEST`
/// (case-insensitive) and has no column aliases yet. Both new identifiers are
/// created unquoted, with no trailing comments and no span.
fn wrap_unnest_alias_in_join_item(expr: &mut Expression) {
    use crate::expressions::Identifier;
    let Expression::Alias(alias) = expr else {
        return;
    };
    let targets_unnest_call = matches!(
        &alias.this,
        Expression::Function(f) if f.name.eq_ignore_ascii_case("UNNEST")
    );
    if !targets_unnest_call || !alias.column_aliases.is_empty() {
        return;
    }
    let unquoted = |name: String| Identifier {
        name,
        quoted: false,
        trailing_comments: Vec::new(),
        span: None,
    };
    // The old table alias name becomes the single column alias.
    let previous = std::mem::replace(&mut alias.alias, unquoted("_u".to_string()));
    alias.column_aliases = vec![unquoted(previous.name)];
}
/// Per-`SELECT` step of `unnest_alias_to_column_alias`.
///
/// 1. Runs `convert_unnest_alias_in_from` on every FROM item.
/// 2. If the FROM list has more than one item, moves every non-leading
///    `UNNEST` item into a `CROSS JOIN` placed before the existing joins.
/// 3. Runs `convert_unnest_alias_in_from` on every join target.
///
/// One counter is shared by steps 1 and 3, so generated `_tN` names do not
/// repeat within the `SELECT`. Other expressions are returned unchanged.
fn unnest_alias_transform_single_select(expr: Expression) -> Result<Expression> {
    let Expression::Select(mut select) = expr else {
        return Ok(expr);
    };
    let mut counter = 0usize;
    if let Some(from) = select.from.as_mut() {
        from.expressions
            .iter_mut()
            .for_each(|item| convert_unnest_alias_in_from(item, &mut counter));

        if from.expressions.len() > 1 {
            let mut sources = std::mem::take(&mut from.expressions).into_iter();
            // The leading FROM item is never turned into a join.
            let mut retained: Vec<Expression> = sources.next().into_iter().collect();
            let mut promoted = Vec::new();
            for source in sources {
                let wraps_unnest = match &source {
                    Expression::Unnest(_) => true,
                    Expression::Alias(a) => matches!(a.this, Expression::Unnest(_)),
                    _ => false,
                };
                if wraps_unnest {
                    promoted.push(crate::expressions::Join {
                        this: source,
                        on: None,
                        using: Vec::new(),
                        kind: JoinKind::Cross,
                        use_inner_keyword: false,
                        use_outer_keyword: false,
                        deferred_condition: false,
                        join_hint: None,
                        match_condition: None,
                        pivots: Vec::new(),
                        comments: Vec::new(),
                        nesting_group: 0,
                        directed: false,
                    });
                } else {
                    retained.push(source);
                }
            }
            from.expressions = retained;
            promoted.append(&mut select.joins);
            select.joins = promoted;
        }
    }
    select
        .joins
        .iter_mut()
        .for_each(|join| convert_unnest_alias_in_from(&mut join.this, &mut counter));
    Ok(Expression::Select(select))
}
/// Rewrites `UNNEST(...) AS name` into `UNNEST(...) AS _t{counter}(name)`.
///
/// Applies to an alias without column aliases that wraps either an
/// `Expression::Unnest` or a function call named `EXPLODE` (case-insensitive).
/// The existing alias identifier becomes the single column alias, and
/// `counter` is advanced only when a rewrite happened.
fn convert_unnest_alias_in_from(expr: &mut Expression, counter: &mut usize) {
    use crate::expressions::Identifier;
    let Expression::Alias(alias) = expr else {
        return;
    };
    let wraps_unnest = match &alias.this {
        Expression::Unnest(_) => true,
        Expression::Function(f) => f.name.eq_ignore_ascii_case("EXPLODE"),
        _ => false,
    };
    if wraps_unnest && alias.column_aliases.is_empty() {
        let table_alias = Identifier::new(format!("_t{}", counter));
        let column_alias = std::mem::replace(&mut alias.alias, table_alias);
        alias.column_aliases = vec![column_alias];
        *counter += 1;
    }
}
pub fn expand_posexplode_duckdb(expr: Expression) -> Result<Expression> {
use crate::expressions::{Alias, Function};
match expr {
Expression::Select(mut select) => {
let mut new_expressions = Vec::new();
let mut changed = false;
for sel_expr in select.expressions.drain(..) {
if let Expression::Alias(ref alias_box) = sel_expr {
if let Expression::Function(ref func) = alias_box.this {
if func.name.eq_ignore_ascii_case("POSEXPLODE") && func.args.len() == 1 {
let arg = func.args[0].clone();
let (pos_name, col_name) = if alias_box.column_aliases.len() == 2 {
(
alias_box.column_aliases[0].name.clone(),
alias_box.column_aliases[1].name.clone(),
)
} else if !alias_box.alias.is_empty() {
("pos".to_string(), alias_box.alias.name.clone())
} else {
("pos".to_string(), "col".to_string())
};
let gen_subscripts = Expression::Function(Box::new(Function::new(
"GENERATE_SUBSCRIPTS".to_string(),
vec![
arg.clone(),
Expression::Literal(Box::new(Literal::Number("1".to_string()))),
],
)));
let sub_one = Expression::Sub(Box::new(BinaryOp::new(
gen_subscripts,
Expression::Literal(Box::new(Literal::Number("1".to_string()))),
)));
let pos_alias = Expression::Alias(Box::new(Alias {
this: sub_one,
alias: Identifier::new(pos_name),
column_aliases: Vec::new(),
pre_alias_comments: Vec::new(),
trailing_comments: Vec::new(),
inferred_type: None,
}));
let unnest = Expression::Unnest(Box::new(UnnestFunc {
this: arg,
expressions: Vec::new(),
with_ordinality: false,
alias: None,
offset_alias: None,
}));
let col_alias = Expression::Alias(Box::new(Alias {
this: unnest,
alias: Identifier::new(col_name),
column_aliases: Vec::new(),
pre_alias_comments: Vec::new(),
trailing_comments: Vec::new(),
inferred_type: None,
}));
new_expressions.push(pos_alias);
new_expressions.push(col_alias);
changed = true;
continue;
}
}
}
if let Expression::Function(ref func) = sel_expr {
if func.name.eq_ignore_ascii_case("POSEXPLODE") && func.args.len() == 1 {
let arg = func.args[0].clone();
let pos_name = "pos";
let col_name = "col";
let gen_subscripts = Expression::Function(Box::new(Function::new(
"GENERATE_SUBSCRIPTS".to_string(),
vec![
arg.clone(),
Expression::Literal(Box::new(Literal::Number("1".to_string()))),
],
)));
let sub_one = Expression::Sub(Box::new(BinaryOp::new(
gen_subscripts,
Expression::Literal(Box::new(Literal::Number("1".to_string()))),
)));
let pos_alias = Expression::Alias(Box::new(Alias {
this: sub_one,
alias: Identifier::new(pos_name),
column_aliases: Vec::new(),
pre_alias_comments: Vec::new(),
trailing_comments: Vec::new(),
inferred_type: None,
}));
let unnest = Expression::Unnest(Box::new(UnnestFunc {
this: arg,
expressions: Vec::new(),
with_ordinality: false,
alias: None,
offset_alias: None,
}));
let col_alias = Expression::Alias(Box::new(Alias {
this: unnest,
alias: Identifier::new(col_name),
column_aliases: Vec::new(),
pre_alias_comments: Vec::new(),
trailing_comments: Vec::new(),
inferred_type: None,
}));
new_expressions.push(pos_alias);
new_expressions.push(col_alias);
changed = true;
continue;
}
}
new_expressions.push(sel_expr);
}
if changed {
select.expressions = new_expressions;
} else {
select.expressions = new_expressions;
}
if let Some(ref mut from) = select.from {
expand_posexplode_in_from_duckdb(from)?;
}
Ok(Expression::Select(select))
}
other => Ok(other),
}
}
fn expand_posexplode_in_from_duckdb(from: &mut From) -> Result<()> {
use crate::expressions::{Alias, Function};
let mut new_expressions = Vec::new();
let mut _changed = false;
for table_expr in from.expressions.drain(..) {
if let Expression::Alias(ref alias_box) = table_expr {
if let Expression::Function(ref func) = alias_box.this {
if func.name.eq_ignore_ascii_case("POSEXPLODE") && func.args.len() == 1 {
let arg = func.args[0].clone();
let (pos_name, col_name) = if alias_box.column_aliases.len() == 2 {
(
alias_box.column_aliases[0].name.clone(),
alias_box.column_aliases[1].name.clone(),
)
} else {
("pos".to_string(), "col".to_string())
};
let gen_subscripts = Expression::Function(Box::new(Function::new(
"GENERATE_SUBSCRIPTS".to_string(),
vec![
arg.clone(),
Expression::Literal(Box::new(Literal::Number("1".to_string()))),
],
)));
let sub_one = Expression::Sub(Box::new(BinaryOp::new(
gen_subscripts,
Expression::Literal(Box::new(Literal::Number("1".to_string()))),
)));
let pos_alias = Expression::Alias(Box::new(Alias {
this: sub_one,
alias: Identifier::new(&pos_name),
column_aliases: Vec::new(),
pre_alias_comments: Vec::new(),
trailing_comments: Vec::new(),
inferred_type: None,
}));
let unnest = Expression::Unnest(Box::new(UnnestFunc {
this: arg,
expressions: Vec::new(),
with_ordinality: false,
alias: None,
offset_alias: None,
}));
let col_alias = Expression::Alias(Box::new(Alias {
this: unnest,
alias: Identifier::new(&col_name),
column_aliases: Vec::new(),
pre_alias_comments: Vec::new(),
trailing_comments: Vec::new(),
inferred_type: None,
}));
let mut inner_select = Select::new();
inner_select.expressions = vec![pos_alias, col_alias];
let subquery = Expression::Subquery(Box::new(Subquery {
this: Expression::Select(Box::new(inner_select)),
alias: None,
column_aliases: Vec::new(),
order_by: None,
limit: None,
offset: None,
distribute_by: None,
sort_by: None,
cluster_by: None,
lateral: false,
modifiers_inside: false,
trailing_comments: Vec::new(),
inferred_type: None,
}));
new_expressions.push(subquery);
_changed = true;
continue;
}
}
}
if let Expression::Function(ref func) = table_expr {
if func.name.eq_ignore_ascii_case("POSEXPLODE") && func.args.len() == 1 {
let arg = func.args[0].clone();
let gen_subscripts = Expression::Function(Box::new(Function::new(
"GENERATE_SUBSCRIPTS".to_string(),
vec![
arg.clone(),
Expression::Literal(Box::new(Literal::Number("1".to_string()))),
],
)));
let sub_one = Expression::Sub(Box::new(BinaryOp::new(
gen_subscripts,
Expression::Literal(Box::new(Literal::Number("1".to_string()))),
)));
let pos_alias = Expression::Alias(Box::new(Alias {
this: sub_one,
alias: Identifier::new("pos"),
column_aliases: Vec::new(),
pre_alias_comments: Vec::new(),
trailing_comments: Vec::new(),
inferred_type: None,
}));
let unnest = Expression::Unnest(Box::new(UnnestFunc {
this: arg,
expressions: Vec::new(),
with_ordinality: false,
alias: None,
offset_alias: None,
}));
let col_alias = Expression::Alias(Box::new(Alias {
this: unnest,
alias: Identifier::new("col"),
column_aliases: Vec::new(),
pre_alias_comments: Vec::new(),
trailing_comments: Vec::new(),
inferred_type: None,
}));
let mut inner_select = Select::new();
inner_select.expressions = vec![pos_alias, col_alias];
let subquery = Expression::Subquery(Box::new(Subquery {
this: Expression::Select(Box::new(inner_select)),
alias: None,
column_aliases: Vec::new(),
order_by: None,
limit: None,
offset: None,
distribute_by: None,
sort_by: None,
cluster_by: None,
lateral: false,
modifiers_inside: false,
trailing_comments: Vec::new(),
inferred_type: None,
}));
new_expressions.push(subquery);
_changed = true;
continue;
}
}
new_expressions.push(table_expr);
}
from.expressions = new_expressions;
Ok(())
}
/// Entry point for rewriting projection-level `EXPLODE`/`POSEXPLODE` calls for
/// `target`.
///
/// A top-level `SELECT` is handed to `explode_projection_to_unnest_impl`;
/// every other expression is returned unchanged.
pub fn explode_projection_to_unnest(expr: Expression, target: DialectType) -> Result<Expression> {
    let Expression::Select(select) = expr else {
        return Ok(expr);
    };
    explode_projection_to_unnest_impl(*select, target)
}
/// Entry point for rewriting projection-level `LATERAL FLATTEN(...)` calls.
///
/// A top-level `SELECT` is handed to
/// `snowflake_flatten_projection_to_unnest_impl`; every other expression is
/// returned unchanged.
pub fn snowflake_flatten_projection_to_unnest(expr: Expression) -> Result<Expression> {
    let Expression::Select(select) = expr else {
        return Ok(expr);
    };
    snowflake_flatten_projection_to_unnest_impl(*select)
}
fn snowflake_flatten_projection_to_unnest_impl(mut select: Select) -> Result<Expression> {
let mut flattened_inputs: Vec<Expression> = Vec::new();
let mut new_selects: Vec<Expression> = Vec::with_capacity(select.expressions.len());
for sel_expr in select.expressions.into_iter() {
let found_input: RefCell<Option<Expression>> = RefCell::new(None);
let rewritten = transform_recursive(sel_expr, &|e| {
if let Expression::Lateral(lat) = e {
if let Some(input_expr) = extract_flatten_input(&lat) {
if found_input.borrow().is_none() {
*found_input.borrow_mut() = Some(input_expr);
}
return Ok(Expression::Lateral(Box::new(rewrite_flatten_lateral(*lat))));
}
return Ok(Expression::Lateral(lat));
}
Ok(e)
})?;
if let Some(input) = found_input.into_inner() {
flattened_inputs.push(input);
}
new_selects.push(rewritten);
}
if flattened_inputs.is_empty() {
select.expressions = new_selects;
return Ok(Expression::Select(Box::new(select)));
}
select.expressions = new_selects;
for (idx, input_expr) in flattened_inputs.into_iter().enumerate() {
let is_first = idx == 0;
let series_alias = if is_first {
"pos".to_string()
} else {
format!("pos_{}", idx + 1)
};
let series_source_alias = if is_first {
"_u".to_string()
} else {
format!("_u_{}", idx * 2 + 1)
};
let unnest_source_alias = if is_first {
"_u_2".to_string()
} else {
format!("_u_{}", idx * 2 + 2)
};
let pos2_alias = if is_first {
"pos_2".to_string()
} else {
format!("{}_2", series_alias)
};
let entity_alias = if is_first {
"entity".to_string()
} else {
format!("entity_{}", idx + 1)
};
let array_size_call = Expression::Function(Box::new(Function::new(
"ARRAY_SIZE".to_string(),
vec![Expression::NamedArgument(Box::new(NamedArgument {
name: Identifier::new("INPUT"),
value: input_expr.clone(),
separator: NamedArgSeparator::DArrow,
}))],
)));
let greatest = Expression::Function(Box::new(Function::new(
"GREATEST".to_string(),
vec![array_size_call.clone()],
)));
let series_end = Expression::Add(Box::new(BinaryOp::new(
Expression::Paren(Box::new(crate::expressions::Paren {
this: Expression::Sub(Box::new(BinaryOp::new(
greatest,
Expression::Literal(Box::new(Literal::Number("1".to_string()))),
))),
trailing_comments: Vec::new(),
})),
Expression::Literal(Box::new(Literal::Number("1".to_string()))),
)));
let series_range = Expression::Function(Box::new(Function::new(
"ARRAY_GENERATE_RANGE".to_string(),
vec![
Expression::Literal(Box::new(Literal::Number("0".to_string()))),
series_end,
],
)));
let series_flatten = Expression::Function(Box::new(Function::new(
"FLATTEN".to_string(),
vec![Expression::NamedArgument(Box::new(NamedArgument {
name: Identifier::new("INPUT"),
value: series_range,
separator: NamedArgSeparator::DArrow,
}))],
)));
let series_table = Expression::Function(Box::new(Function::new(
"TABLE".to_string(),
vec![series_flatten],
)));
let series_alias_expr = Expression::Alias(Box::new(Alias {
this: series_table,
alias: Identifier::new(series_source_alias.clone()),
column_aliases: vec![
Identifier::new("seq"),
Identifier::new("key"),
Identifier::new("path"),
Identifier::new("index"),
Identifier::new(series_alias.clone()),
Identifier::new("this"),
],
pre_alias_comments: Vec::new(),
trailing_comments: Vec::new(),
inferred_type: None,
}));
select.joins.push(Join {
this: series_alias_expr,
on: None,
using: Vec::new(),
kind: JoinKind::Cross,
use_inner_keyword: false,
use_outer_keyword: false,
deferred_condition: false,
join_hint: None,
match_condition: None,
pivots: Vec::new(),
comments: Vec::new(),
nesting_group: 0,
directed: false,
});
let entity_flatten = Expression::Function(Box::new(Function::new(
"FLATTEN".to_string(),
vec![Expression::NamedArgument(Box::new(NamedArgument {
name: Identifier::new("INPUT"),
value: input_expr.clone(),
separator: NamedArgSeparator::DArrow,
}))],
)));
let entity_table = Expression::Function(Box::new(Function::new(
"TABLE".to_string(),
vec![entity_flatten],
)));
let entity_alias_expr = Expression::Alias(Box::new(Alias {
this: entity_table,
alias: Identifier::new(unnest_source_alias.clone()),
column_aliases: vec![
Identifier::new("seq"),
Identifier::new("key"),
Identifier::new("path"),
Identifier::new(pos2_alias.clone()),
Identifier::new(entity_alias.clone()),
Identifier::new("this"),
],
pre_alias_comments: Vec::new(),
trailing_comments: Vec::new(),
inferred_type: None,
}));
select.joins.push(Join {
this: entity_alias_expr,
on: None,
using: Vec::new(),
kind: JoinKind::Cross,
use_inner_keyword: false,
use_outer_keyword: false,
deferred_condition: false,
join_hint: None,
match_condition: None,
pivots: Vec::new(),
comments: Vec::new(),
nesting_group: 0,
directed: false,
});
let pos_col =
Expression::qualified_column(series_source_alias.clone(), series_alias.clone());
let pos2_col =
Expression::qualified_column(unnest_source_alias.clone(), pos2_alias.clone());
let eq = Expression::Eq(Box::new(BinaryOp::new(pos_col.clone(), pos2_col.clone())));
let size_minus_1 = Expression::Paren(Box::new(crate::expressions::Paren {
this: Expression::Sub(Box::new(BinaryOp::new(
array_size_call,
Expression::Literal(Box::new(Literal::Number("1".to_string()))),
))),
trailing_comments: Vec::new(),
}));
let gt = Expression::Gt(Box::new(BinaryOp::new(pos_col, size_minus_1.clone())));
let pos2_eq_size = Expression::Eq(Box::new(BinaryOp::new(pos2_col, size_minus_1)));
let and_cond = Expression::And(Box::new(BinaryOp::new(gt, pos2_eq_size)));
let or_cond = Expression::Or(Box::new(BinaryOp::new(
eq,
Expression::Paren(Box::new(crate::expressions::Paren {
this: and_cond,
trailing_comments: Vec::new(),
})),
)));
select.where_clause = Some(match select.where_clause.take() {
Some(existing) => Where {
this: Expression::And(Box::new(BinaryOp::new(existing.this, or_cond))),
},
None => Where { this: or_cond },
});
}
Ok(Expression::Select(Box::new(select)))
}
/// Returns the array expression passed to a `FLATTEN(...)` lateral.
///
/// Prefers the value of a named argument called `INPUT` (case-insensitive)
/// and falls back to the first argument. Returns `None` when the lateral is
/// not a function call named `FLATTEN` or the call has no arguments.
fn extract_flatten_input(lat: &Lateral) -> Option<Expression> {
    match lat.this.as_ref() {
        Expression::Function(f) if f.name.eq_ignore_ascii_case("FLATTEN") => f
            .args
            .iter()
            .find_map(|arg| match arg {
                Expression::NamedArgument(na) if na.name.name.eq_ignore_ascii_case("INPUT") => {
                    Some(na.value.clone())
                }
                _ => None,
            })
            .or_else(|| f.args.first().cloned()),
        _ => None,
    }
}
/// Replaces the body of a flatten lateral with
/// `IFF(_u.pos = _u_2.pos_2, _u_2.entity, NULL)`.
///
/// When the lateral has no column aliases, they are set to `SEQ`, `KEY`,
/// `PATH`, `INDEX`, `VALUE`, `THIS`.
fn rewrite_flatten_lateral(mut lat: Lateral) -> Lateral {
    if lat.column_aliases.is_empty() {
        lat.column_aliases = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"]
            .iter()
            .map(|name| name.to_string())
            .collect();
    }
    let positions_match = Expression::Eq(Box::new(BinaryOp::new(
        Expression::qualified_column("_u", "pos"),
        Expression::qualified_column("_u_2", "pos_2"),
    )));
    let guarded_entity = Expression::Function(Box::new(Function::new(
        "IFF".to_string(),
        vec![
            positions_match,
            Expression::qualified_column("_u_2", "entity"),
            Expression::Null(crate::expressions::Null),
        ],
    )));
    lat.this = Box::new(guarded_entity);
    lat
}
/// Names and SQL collected for one exploded projection while
/// `explode_projection_to_unnest_impl` rebuilds a `SELECT`.
struct ExplodeInfo {
// SQL text generated for the argument of the explode call.
arg_sql: String,
// Output name chosen for the exploded value column.
explode_alias: String,
// Output name chosen for the position column.
pos_alias: String,
// Table alias of the `CROSS JOIN UNNEST(...)` source created for this explode.
unnest_source_alias: String,
}
/// Rewrites projection-level `EXPLODE`/`POSEXPLODE` calls into a series
/// source cross-joined with one `UNNEST ... WITH ORDINALITY` per explode.
///
/// Only Presto, Trino, Athena and BigQuery targets are handled; for any other
/// target, or when no projection contains an explode call, `select` is
/// returned unchanged.
///
/// The rewrite is done on SQL text: projections and FROM items are generated
/// with the target dialect, a new `SELECT ... FROM ... WHERE ...` string is
/// assembled and then parsed with the generic dialect. For BigQuery the
/// parsed result is post-processed by `convert_unnest_presto_to_bigquery`.
/// If that parse fails or yields no statement, `select` is returned unchanged.
///
/// NOTE(review): the rebuilt statement contains only the projections and FROM
/// items of `select`. Its joins are read for alias collision checks but are
/// not re-emitted, and no other clause of `select` is copied over — confirm
/// callers only pass plain `SELECT ... FROM ...` statements.
fn explode_projection_to_unnest_impl(select: Select, target: DialectType) -> Result<Expression> {
    let is_presto = matches!(
        target,
        DialectType::Presto | DialectType::Trino | DialectType::Athena
    );
    let is_bigquery = matches!(target, DialectType::BigQuery);
    if !is_presto && !is_bigquery {
        return Ok(Expression::Select(Box::new(select)));
    }
    if !select.expressions.iter().any(|e| expr_contains_explode(e)) {
        return Ok(Expression::Select(Box::new(select)));
    }

    // Names already in use, so generated aliases never collide with them.
    let mut taken_select_names = std::collections::HashSet::new();
    let mut taken_source_names = std::collections::HashSet::new();
    for sel in &select.expressions {
        if let Some(name) = get_output_name(sel) {
            taken_select_names.insert(name);
        }
    }
    for sel in &select.expressions {
        if let Some(arg) = find_explode_in_expr(sel) {
            if let Some(name) = get_output_name(&arg) {
                taken_select_names.insert(name);
            }
        }
    }
    if let Some(ref from) = select.from {
        for from_expr in &from.expressions {
            collect_source_names(from_expr, &mut taken_source_names);
        }
    }
    for join in &select.joins {
        collect_source_names(&join.this, &mut taken_source_names);
    }

    // The series aliases are reserved first; the order of `new_name` calls
    // decides which suffixes later aliases receive.
    let series_alias = new_name(&mut taken_select_names, "pos");
    let series_source_alias = new_name(&mut taken_source_names, "_u");
    let target_dialect = Dialect::get(target);

    let mut explode_infos: Vec<ExplodeInfo> = Vec::new();
    let mut new_projections: Vec<String> = Vec::new();
    for sel_expr in &select.expressions {
        let Some((is_posexplode, arg_expr, explicit_alias, explicit_pos_alias)) =
            extract_explode_data(sel_expr)
        else {
            // Non-explode projections are passed through as generated SQL.
            let sel_sql = target_dialect
                .generate(sel_expr)
                .unwrap_or_else(|_| "*".to_string());
            new_projections.push(sel_sql);
            continue;
        };

        let arg_sql = target_dialect
            .generate(&arg_expr)
            .unwrap_or_else(|_| "NULL".to_string());
        let unnest_source_alias = new_name(&mut taken_source_names, "_u");
        // An explicit alias was registered as taken above; free it so that
        // `new_name` can hand the same name back.
        let explode_alias = match explicit_alias.as_deref() {
            Some(requested) => {
                taken_select_names.remove(requested);
                new_name(&mut taken_select_names, requested)
            }
            None => new_name(&mut taken_select_names, "col"),
        };
        let pos_alias = match explicit_pos_alias.as_deref() {
            Some(requested) => {
                taken_select_names.remove(requested);
                new_name(&mut taken_select_names, requested)
            }
            None => new_name(&mut taken_select_names, "pos"),
        };

        // Emits the value only on rows where the series position matches.
        let guarded_projection = |value: &str| {
            if is_presto {
                format!(
                    "IF({}.{} = {}.{}, {}.{}) AS {}",
                    series_source_alias,
                    series_alias,
                    unnest_source_alias,
                    pos_alias,
                    unnest_source_alias,
                    value,
                    value
                )
            } else {
                format!(
                    "IF({} = {}, {}, NULL) AS {}",
                    series_alias, pos_alias, value, value
                )
            }
        };
        new_projections.push(guarded_projection(&explode_alias));
        if is_posexplode {
            new_projections.push(guarded_projection(&pos_alias));
        }

        explode_infos.push(ExplodeInfo {
            arg_sql,
            explode_alias,
            pos_alias,
            unnest_source_alias,
        });
    }
    if explode_infos.is_empty() {
        return Ok(Expression::Select(Box::new(select)));
    }

    let from_parts: Vec<String> = select
        .from
        .iter()
        .flat_map(|from| from.expressions.iter())
        .map(|from_expr| target_dialect.generate(from_expr).unwrap_or_default())
        .collect();

    let size_fn = if is_presto { "CARDINALITY" } else { "ARRAY_LENGTH" };
    let size_exprs: Vec<String> = explode_infos
        .iter()
        .map(|info| format!("{}({})", size_fn, info.arg_sql))
        .collect();

    // The series always spans GREATEST(...) of all array sizes, including the
    // single-array case.
    let longest = format!("GREATEST({})", size_exprs.join(", "));
    let series_sql = if is_presto {
        format!(
            "UNNEST(SEQUENCE(1, {})) AS {}({})",
            longest, series_source_alias, series_alias
        )
    } else {
        format!("UNNEST(GENERATE_ARRAY(0, {} - 1)) AS {}", longest, series_alias)
    };

    let cross_joins: Vec<String> = explode_infos
        .iter()
        .map(|info| {
            format!(
                "CROSS JOIN UNNEST({}) WITH ORDINALITY AS {}({}, {})",
                info.arg_sql, info.unnest_source_alias, info.explode_alias, info.pos_alias
            )
        })
        .collect();

    // Keep rows where the positions line up, or where the series has run past
    // the end of a shorter array and that array sits on its last element.
    let where_conditions: Vec<String> = explode_infos
        .iter()
        .zip(size_exprs.iter())
        .map(|(info, size_expr)| {
            if is_presto {
                format!(
                    "{series_src}.{series_al} = {unnest_src}.{pos_al} OR ({series_src}.{series_al} > {size} AND {unnest_src}.{pos_al} = {size})",
                    series_src = series_source_alias,
                    series_al = series_alias,
                    unnest_src = info.unnest_source_alias,
                    pos_al = info.pos_alias,
                    size = size_expr
                )
            } else {
                format!(
                    "{series_al} = {pos_al} OR ({series_al} > ({size} - 1) AND {pos_al} = ({size} - 1))",
                    series_al = series_alias,
                    pos_al = info.pos_alias,
                    size = size_expr
                )
            }
        })
        .collect();
    let where_sql = if where_conditions.len() == 1 {
        where_conditions[0].clone()
    } else {
        where_conditions
            .iter()
            .map(|c| format!("({})", c))
            .collect::<Vec<_>>()
            .join(" AND ")
    };

    let select_part = new_projections.join(", ");
    let from_and_joins = if from_parts.is_empty() {
        format!("FROM {} {}", series_sql, cross_joins.join(" "))
    } else {
        format!(
            "FROM {} CROSS JOIN {} {}",
            from_parts.join(", "),
            series_sql,
            cross_joins.join(" ")
        )
    };
    let full_sql = format!(
        "SELECT {} {} WHERE {}",
        select_part, from_and_joins, where_sql
    );

    let generic_dialect = Dialect::get(DialectType::Generic);
    let parsed = generic_dialect.parse(&full_sql);
    match parsed {
        Ok(mut stmts) if !stmts.is_empty() => {
            let mut result = stmts.remove(0);
            if is_bigquery {
                convert_unnest_presto_to_bigquery(&mut result);
            }
            Ok(result)
        }
        // Best effort: hand back the untouched input.
        _ => Ok(Expression::Select(Box::new(select))),
    }
}
/// Rewrites `UNNEST(x) WITH ORDINALITY AS t(col, pos)` into a bare `UNNEST`
/// node whose `alias` is `col` and whose `offset_alias` is `pos`.
///
/// For a `SELECT` the rewrite is applied to each FROM item and each join
/// target. For an alias it is applied only when the alias wraps an `UNNEST`
/// with ordinality and has at least two column aliases; the table alias itself
/// is discarded. Any other expression is left unchanged.
fn convert_unnest_presto_to_bigquery(expr: &mut Expression) {
    let replacement = match expr {
        Expression::Select(select) => {
            if let Some(from) = select.from.as_mut() {
                for from_item in from.expressions.iter_mut() {
                    convert_unnest_presto_to_bigquery(from_item);
                }
            }
            for join in select.joins.iter_mut() {
                convert_unnest_presto_to_bigquery(&mut join.this);
            }
            return;
        }
        Expression::Alias(alias) => match &alias.this {
            Expression::Unnest(unnest)
                if unnest.with_ordinality && alias.column_aliases.len() >= 2 =>
            {
                let mut converted = unnest.as_ref().clone();
                converted.alias = Some(alias.column_aliases[0].clone());
                converted.offset_alias = Some(alias.column_aliases[1].clone());
                Some(converted)
            }
            _ => None,
        },
        _ => None,
    };
    if let Some(converted) = replacement {
        *expr = Expression::Unnest(Box::new(converted));
    }
}
/// Reserves and returns a name not yet present in `names`.
///
/// Returns `base` itself when it is free; otherwise the first free name among
/// `base_2`, `base_3`, … The returned name is inserted into `names`.
fn new_name(names: &mut std::collections::HashSet<String>, base: &str) -> String {
    let mut candidate = base.to_string();
    let mut suffix = 1;
    // `insert` returns false, leaving the set untouched, when the name is taken.
    while !names.insert(candidate.clone()) {
        suffix += 1;
        candidate = format!("{}_{}", base, suffix);
    }
    candidate
}
/// Reports whether `expr`, after peeling any alias wrappers, is an
/// `EXPLODE`/`EXPLODE_OUTER` node or a function call named `POSEXPLODE` or
/// `POSEXPLODE_OUTER` (compared after upper-casing).
fn expr_contains_explode(expr: &Expression) -> bool {
    let mut current = expr;
    loop {
        match current {
            Expression::Explode(_) | Expression::ExplodeOuter(_) => return true,
            Expression::Function(f) => {
                return matches!(
                    f.name.to_uppercase().as_str(),
                    "POSEXPLODE" | "POSEXPLODE_OUTER"
                );
            }
            Expression::Alias(a) => current = &a.this,
            _ => return false,
        }
    }
}
/// Returns a clone of the argument of the explode call in `expr`, looking
/// through any alias wrappers.
///
/// For `POSEXPLODE`/`POSEXPLODE_OUTER` function calls this is the first
/// argument; a call without arguments yields `None`, as does any expression
/// that is not an explode call.
fn find_explode_in_expr(expr: &Expression) -> Option<Expression> {
    let mut current = expr;
    loop {
        match current {
            Expression::Explode(uf) => return Some(uf.this.clone()),
            Expression::ExplodeOuter(uf) => return Some(uf.this.clone()),
            Expression::Function(f) => {
                let upper = f.name.to_uppercase();
                let is_posexplode = upper == "POSEXPLODE" || upper == "POSEXPLODE_OUTER";
                return if is_posexplode && !f.args.is_empty() {
                    Some(f.args[0].clone())
                } else {
                    None
                };
            }
            Expression::Alias(a) => current = &a.this,
            _ => return None,
        }
    }
}
/// Describes an explode projection as
/// `(is_posexplode, argument, value_alias, position_alias)`.
///
/// Recognises a bare or singly-aliased `EXPLODE`, `EXPLODE_OUTER`,
/// `POSEXPLODE(...)` or `POSEXPLODE_OUTER(...)`. For an aliased posexplode
/// with exactly two column aliases, the first is the position alias and the
/// second the value alias; otherwise a non-empty alias name is the value
/// alias. Returns `None` for anything else, including a posexplode call with
/// no arguments.
fn extract_explode_data(
    expr: &Expression,
) -> Option<(bool, Expression, Option<String>, Option<String>)> {
    /// First argument of a `POSEXPLODE`/`POSEXPLODE_OUTER` call, if any.
    fn posexplode_arg(candidate: &Expression) -> Option<Expression> {
        if let Expression::Function(f) = candidate {
            let upper = f.name.to_uppercase();
            if (upper == "POSEXPLODE" || upper == "POSEXPLODE_OUTER") && !f.args.is_empty() {
                return Some(f.args[0].clone());
            }
        }
        None
    }

    match expr {
        Expression::Explode(uf) => Some((false, uf.this.clone(), None, None)),
        Expression::ExplodeOuter(uf) => Some((false, uf.this.clone(), None, None)),
        Expression::Function(_) => posexplode_arg(expr).map(|arg| (true, arg, None, None)),
        Expression::Alias(a) => {
            let alias_name = (!a.alias.is_empty()).then(|| a.alias.name.clone());
            match &a.this {
                Expression::Explode(uf) => Some((false, uf.this.clone(), alias_name, None)),
                Expression::ExplodeOuter(uf) => Some((false, uf.this.clone(), alias_name, None)),
                Expression::Function(_) => {
                    let arg = posexplode_arg(&a.this)?;
                    if a.column_aliases.len() == 2 {
                        let pos_alias = a.column_aliases[0].name.clone();
                        let col_alias = a.column_aliases[1].name.clone();
                        Some((true, arg, Some(col_alias), Some(pos_alias)))
                    } else {
                        Some((true, arg, alias_name, None))
                    }
                }
                _ => None,
            }
        }
        _ => None,
    }
}
/// Returns the output name of a projection: a non-empty alias name, a column
/// name, or an identifier name. Anything else yields `None`.
fn get_output_name(expr: &Expression) -> Option<String> {
    let name = match expr {
        Expression::Alias(a) if !a.alias.is_empty() => &a.alias.name,
        Expression::Column(c) => &c.name.name,
        Expression::Identifier(id) => &id.name,
        _ => return None,
    };
    Some(name.clone())
}
/// Records in `names` the name under which a FROM or JOIN source is visible.
///
/// * alias: its alias name, when non-empty;
/// * subquery: its alias, when present;
/// * table: its alias when present, otherwise the table name;
/// * column or identifier: its own name.
///
/// Other expressions contribute nothing.
fn collect_source_names(expr: &Expression, names: &mut std::collections::HashSet<String>) {
    let visible_name = match expr {
        Expression::Alias(a) if !a.alias.is_empty() => &a.alias.name,
        Expression::Subquery(s) => match s.alias.as_ref() {
            Some(alias) => &alias.name,
            None => return,
        },
        Expression::Table(t) => match t.alias.as_ref() {
            Some(alias) => &alias.name,
            None => &t.name.name,
        },
        Expression::Column(c) => &c.name.name,
        Expression::Identifier(id) => &id.name,
        _ => return,
    };
    names.insert(visible_name.clone());
}
/// Applies `strip_unnest_column_refs_single` to every node of `expr` via
/// `transform_recursive`.
pub fn strip_unnest_column_refs(expr: Expression) -> Result<Expression> {
    use crate::dialects::transform_recursive;
    let visitor = &strip_unnest_column_refs_single;
    transform_recursive(expr, visitor)
}
fn strip_unnest_column_refs_single(expr: Expression) -> Result<Expression> {
if let Expression::Select(mut select) = expr {
for join in select.joins.iter_mut() {
strip_unnest_from_expr(&mut join.this);
}
if let Some(ref mut from) = select.from {
for from_item in from.expressions.iter_mut() {
strip_unnest_from_expr(from_item);
}
}
Ok(Expression::Select(select))
} else {
Ok(expr)
}
}
/// Turns `UNNEST(col) AS a` into `col AS a` when the unnest argument is a
/// column reference or a dotted path. Anything else is left unchanged.
fn strip_unnest_from_expr(expr: &mut Expression) {
    let Expression::Alias(alias) = expr else {
        return;
    };
    let unwrapped = match &alias.this {
        Expression::Unnest(unnest)
            if matches!(&unnest.this, Expression::Column(_) | Expression::Dot(_)) =>
        {
            Some(unnest.this.clone())
        }
        _ => None,
    };
    if let Some(inner) = unwrapped {
        alias.this = inner;
    }
}
/// Applies `wrap_duckdb_unnest_struct_single` to every node of `expr` via
/// `transform_recursive`.
pub fn wrap_duckdb_unnest_struct(expr: Expression) -> Result<Expression> {
    use crate::dialects::transform_recursive;
    let visitor = &wrap_duckdb_unnest_struct_single;
    transform_recursive(expr, visitor)
}
fn wrap_duckdb_unnest_struct_single(expr: Expression) -> Result<Expression> {
if let Expression::Select(mut select) = expr {
if let Some(ref mut from) = select.from {
for from_item in from.expressions.iter_mut() {
try_wrap_unnest_in_subquery(from_item);
}
}
for join in select.joins.iter_mut() {
try_wrap_unnest_in_subquery(&mut join.this);
}
Ok(Expression::Select(select))
} else {
Ok(expr)
}
}
/// Reports whether an `UNNEST` argument is an array of structs: an array
/// literal or constructor with at least one `Expression::Struct` element, or
/// a cast whose target type is an array with a struct element type.
fn is_struct_array_unnest_arg(expr: &Expression) -> bool {
    fn contains_struct(items: &[Expression]) -> bool {
        items.iter().any(|item| matches!(item, Expression::Struct(_)))
    }

    match expr {
        Expression::Array(arr) => contains_struct(&arr.expressions),
        Expression::ArrayFunc(arr) => contains_struct(&arr.expressions),
        Expression::Cast(c) => {
            matches!(&c.to, DataType::Array { element_type, .. } if matches!(**element_type, DataType::Struct { .. }))
        }
        _ => false,
    }
}
/// Replaces an `UNNEST` of a struct array (optionally aliased) with the
/// subquery built by `make_unnest_subquery`.
///
/// For an aliased unnest, the alias identifier becomes the subquery alias.
/// Expressions that do not satisfy `is_struct_array_unnest_arg` are left
/// unchanged.
fn try_wrap_unnest_in_subquery(expr: &mut Expression) {
    let replacement = match &*expr {
        Expression::Alias(alias) => match &alias.this {
            Expression::Unnest(unnest) if is_struct_array_unnest_arg(&unnest.this) => Some(
                make_unnest_subquery((**unnest).clone(), Some(alias.alias.clone())),
            ),
            _ => None,
        },
        Expression::Unnest(unnest) if is_struct_array_unnest_arg(&unnest.this) => {
            Some(make_unnest_subquery((**unnest).clone(), None))
        }
        _ => None,
    };
    if let Some(subquery) = replacement {
        *expr = subquery;
    }
}
fn make_unnest_subquery(unnest: UnnestFunc, alias: Option<Identifier>) -> Expression {
let max_depth_arg = Expression::NamedArgument(Box::new(NamedArgument {
name: Identifier::new("max_depth".to_string()),
value: Expression::Literal(Box::new(Literal::Number("2".to_string()))),
separator: NamedArgSeparator::DArrow,
}));
let mut unnest_args = vec![unnest.this];
unnest_args.extend(unnest.expressions);
unnest_args.push(max_depth_arg);
let unnest_func =
Expression::Function(Box::new(Function::new("UNNEST".to_string(), unnest_args)));
let mut inner_select = Select::new();
inner_select.expressions = vec![unnest_func];
let inner_select = Expression::Select(Box::new(inner_select));
let subquery = Subquery {
this: inner_select,
alias,
column_aliases: Vec::new(),
order_by: None,
limit: None,
offset: None,
distribute_by: None,
sort_by: None,
cluster_by: None,
lateral: false,
modifiers_inside: false,
trailing_comments: Vec::new(),
inferred_type: None,
};
Expression::Subquery(Box::new(subquery))
}
pub fn no_limit_order_by_union(expr: Expression) -> Result<Expression> {
use crate::expressions::{Limit as LimitClause, Offset as OffsetClause, OrderBy, Star};
match expr {
Expression::Union(mut u) => {
if u.order_by.is_none() && u.limit.is_none() && u.offset.is_none() {
if let Expression::Select(ref mut right_select) = u.right {
if right_select.order_by.is_some()
|| right_select.limit.is_some()
|| right_select.offset.is_some()
{
u.order_by = right_select.order_by.take();
u.limit = right_select.limit.take().map(|l| Box::new(l.this));
u.offset = right_select.offset.take().map(|o| Box::new(o.this));
}
}
}
let has_order_or_limit =
u.order_by.is_some() || u.limit.is_some() || u.offset.is_some();
if has_order_or_limit {
let order_by: Option<OrderBy> = u.order_by.take();
let union_limit: Option<Box<Expression>> = u.limit.take();
let union_offset: Option<Box<Expression>> = u.offset.take();
let select_limit: Option<LimitClause> = union_limit.map(|l| LimitClause {
this: *l,
percent: false,
comments: Vec::new(),
});
let select_offset: Option<OffsetClause> = union_offset.map(|o| OffsetClause {
this: *o,
rows: None,
});
let subquery = Subquery {
this: Expression::Union(u),
alias: Some(Identifier::new("_l_0")),
column_aliases: Vec::new(),
lateral: false,
modifiers_inside: false,
order_by: None,
limit: None,
offset: None,
distribute_by: None,
sort_by: None,
cluster_by: None,
trailing_comments: Vec::new(),
inferred_type: None,
};
let mut select = Select::default();
select.expressions = vec![Expression::Star(Star {
table: None,
except: None,
replace: None,
rename: None,
trailing_comments: Vec::new(),
span: None,
})];
select.from = Some(From {
expressions: vec![Expression::Subquery(Box::new(subquery))],
});
select.order_by = order_by;
select.limit = select_limit;
select.offset = select_offset;
Ok(Expression::Select(Box::new(select)))
} else {
Ok(Expression::Union(u))
}
}
_ => Ok(expr),
}
}
pub fn expand_like_any(expr: Expression) -> Result<Expression> {
use crate::expressions::{BinaryOp, LikeOp, Paren};
const LIKE_ALL_MARKER: &str = "__LIKE_ALL_EXPANSION__";
fn unwrap_parens(e: &Expression) -> &Expression {
match e {
Expression::Paren(p) => unwrap_parens(&p.this),
_ => e,
}
}
fn extract_tuple_values(e: &Expression) -> Option<Vec<Expression>> {
let inner = unwrap_parens(e);
match inner {
Expression::Tuple(t) => Some(t.expressions.clone()),
_ if !matches!(e, Expression::Tuple(_)) => Some(vec![inner.clone()]),
_ => None,
}
}
fn expand_like_quantifier(
op: &LikeOp,
values: Vec<Expression>,
is_ilike: bool,
combiner: fn(Expression, Expression) -> Expression,
wrap_marker: bool,
) -> Expression {
let num_values = values.len();
let mut result: Option<Expression> = None;
for val in values {
let like = if is_ilike {
Expression::ILike(Box::new(LikeOp {
left: op.left.clone(),
right: val,
escape: op.escape.clone(),
quantifier: None,
inferred_type: None,
}))
} else {
Expression::Like(Box::new(LikeOp {
left: op.left.clone(),
right: val,
escape: op.escape.clone(),
quantifier: None,
inferred_type: None,
}))
};
result = Some(match result {
None => like,
Some(prev) => combiner(prev, like),
});
}
let expanded = result.unwrap_or_else(|| unreachable!("values is non-empty"));
if wrap_marker && num_values > 1 {
Expression::Paren(Box::new(Paren {
this: expanded,
trailing_comments: vec![LIKE_ALL_MARKER.to_string()],
}))
} else {
expanded
}
}
fn or_combiner(a: Expression, b: Expression) -> Expression {
Expression::Or(Box::new(BinaryOp::new(a, b)))
}
fn and_combiner(a: Expression, b: Expression) -> Expression {
Expression::And(Box::new(BinaryOp::new(a, b)))
}
fn is_like_all_marker(p: &Paren) -> bool {
p.trailing_comments.len() == 1 && p.trailing_comments[0] == LIKE_ALL_MARKER
}
let result = transform_recursive(expr, &|e| {
match e {
Expression::Like(ref op) if op.quantifier.as_deref() == Some("ANY") => {
if let Some(values) = extract_tuple_values(&op.right) {
if values.is_empty() {
return Ok(e);
}
Ok(expand_like_quantifier(
op,
values,
false,
or_combiner,
false,
))
} else {
Ok(e)
}
}
Expression::Like(ref op) if op.quantifier.as_deref() == Some("ALL") => {
if let Some(values) = extract_tuple_values(&op.right) {
if values.is_empty() {
return Ok(e);
}
Ok(expand_like_quantifier(
op,
values,
false,
and_combiner,
true,
))
} else {
Ok(e)
}
}
Expression::ILike(ref op) if op.quantifier.as_deref() == Some("ANY") => {
if let Some(values) = extract_tuple_values(&op.right) {
if values.is_empty() {
return Ok(e);
}
Ok(expand_like_quantifier(op, values, true, or_combiner, false))
} else {
Ok(e)
}
}
Expression::ILike(ref op) if op.quantifier.as_deref() == Some("ALL") => {
if let Some(values) = extract_tuple_values(&op.right) {
if values.is_empty() {
return Ok(e);
}
Ok(expand_like_quantifier(op, values, true, and_combiner, true))
} else {
Ok(e)
}
}
Expression::And(mut op) => {
if matches!(&op.left, Expression::Or(_)) {
op.left = Expression::Paren(Box::new(Paren {
this: op.left,
trailing_comments: vec![],
}));
}
if matches!(&op.right, Expression::Or(_)) {
op.right = Expression::Paren(Box::new(Paren {
this: op.right,
trailing_comments: vec![],
}));
}
Ok(Expression::And(op))
}
Expression::Or(mut op) => {
if let Expression::Paren(ref mut p) = op.left {
if is_like_all_marker(p) {
p.trailing_comments.clear();
}
}
if let Expression::Paren(ref mut p) = op.right {
if is_like_all_marker(p) {
p.trailing_comments.clear();
}
}
Ok(Expression::Or(op))
}
_ => Ok(e),
}
})?;
transform_recursive(result, &|e| {
if let Expression::Paren(p) = &e {
if is_like_all_marker(p) {
let Expression::Paren(p) = e else {
unreachable!()
};
return Ok(p.this);
}
}
Ok(e)
})
}
/// Gives every bare column projected by an aliased derived table an explicit
/// self-alias (`col` becomes `col AS col`).
///
/// A subquery in `FROM` or in a join is rewritten only when it has an alias,
/// declares no column aliases, wraps a plain `SELECT`, and that `SELECT` has
/// no `*` projection. The walk descends through selects (their `FROM` items,
/// join targets, projections and `WHERE`), subqueries, set operations and
/// CTE nodes. For a top-level `SELECT`, the bodies of its `WITH` CTEs that
/// declare no column list are walked as well.
///
/// Each `FROM`/join source is visited exactly once: a subquery's body is
/// walked by `qualify_subquery_expr` itself, so it is not walked a second
/// time by the caller. The rewrite is idempotent, so this changes only the
/// amount of work, which would otherwise double with each nesting level.
pub fn qualify_derived_table_outputs(expr: Expression) -> Result<Expression> {
    use crate::expressions::Alias;

    /// Wraps each bare `Column` projection in an alias named after the
    /// column; all other projections are kept as they are.
    fn add_self_aliases_to_select(select: &mut Select) {
        // Move the projections out instead of cloning each one.
        let projections = std::mem::take(&mut select.expressions);
        select.expressions = projections
            .into_iter()
            .map(|projection| match projection {
                Expression::Column(col) => {
                    let alias = col.name.clone();
                    Expression::Alias(Box::new(Alias {
                        this: Expression::Column(col),
                        alias,
                        column_aliases: Vec::new(),
                        pre_alias_comments: Vec::new(),
                        trailing_comments: Vec::new(),
                        inferred_type: None,
                    }))
                }
                other => other,
            })
            .collect();
    }

    /// Handles one `FROM` item or join target.
    ///
    /// Subqueries and aliases are fully handled by `qualify_subquery_expr`
    /// (which walks a subquery's body itself, and `walk_and_qualify` does
    /// nothing for an alias); every other node only needs the generic walk.
    fn visit_source(source: &mut Expression) {
        if matches!(&*source, Expression::Subquery(_) | Expression::Alias(_)) {
            qualify_subquery_expr(source);
        } else {
            walk_and_qualify(source);
        }
    }

    /// Recursively visits the query structure looking for derived tables.
    fn walk_and_qualify(expr: &mut Expression) {
        match expr {
            Expression::Select(select) => {
                if let Some(from) = select.from.as_mut() {
                    from.expressions.iter_mut().for_each(visit_source);
                }
                for join in select.joins.iter_mut() {
                    visit_source(&mut join.this);
                }
                select.expressions.iter_mut().for_each(walk_and_qualify);
                if let Some(filter) = select.where_clause.as_mut() {
                    walk_and_qualify(&mut filter.this);
                }
            }
            Expression::Subquery(subquery) => walk_and_qualify(&mut subquery.this),
            Expression::Union(set_op) => {
                walk_and_qualify(&mut set_op.left);
                walk_and_qualify(&mut set_op.right);
            }
            Expression::Intersect(set_op) => {
                walk_and_qualify(&mut set_op.left);
                walk_and_qualify(&mut set_op.right);
            }
            Expression::Except(set_op) => {
                walk_and_qualify(&mut set_op.left);
                walk_and_qualify(&mut set_op.right);
            }
            Expression::Cte(cte) => walk_and_qualify(&mut cte.this),
            _ => {}
        }
    }

    /// Self-aliases the projections of an eligible derived table, then walks
    /// its body. Looks through aliases; ignores every other node kind.
    fn qualify_subquery_expr(expr: &mut Expression) {
        match expr {
            Expression::Subquery(subquery) => {
                let eligible = subquery.alias.is_some() && subquery.column_aliases.is_empty();
                if eligible {
                    if let Expression::Select(inner) = &mut subquery.this {
                        let has_star = inner
                            .expressions
                            .iter()
                            .any(|projection| matches!(projection, Expression::Star(_)));
                        if !has_star {
                            add_self_aliases_to_select(inner);
                        }
                    }
                }
                walk_and_qualify(&mut subquery.this);
            }
            Expression::Alias(alias) => qualify_subquery_expr(&mut alias.this),
            _ => {}
        }
    }

    let mut result = expr;
    walk_and_qualify(&mut result);

    // The generic walk does not descend into a select's WITH clause, so CTE
    // bodies of the top-level select are walked here.
    if let Expression::Select(select) = &mut result {
        if let Some(with) = select.with.as_mut() {
            for cte in with.ctes.iter_mut().filter(|cte| cte.columns.is_empty()) {
                walk_and_qualify(&mut cte.this);
            }
        }
    }

    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::{Dialect, DialectType};
    use crate::expressions::Column;

    // ---------------------------------------------------------------------
    // Fixture helpers
    // ---------------------------------------------------------------------

    /// Generates SQL for `expr` with the generic dialect.
    fn generate(expr: &Expression) -> String {
        let dialect = Dialect::get(DialectType::Generic);
        dialect.generate(expr).unwrap()
    }

    /// Unqualified column reference named `name`.
    fn col_expr(name: &str) -> Expression {
        Expression::boxed_column(Column {
            name: Identifier::new(name.to_string()),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        })
    }

    /// Numeric literal with the given source text.
    fn number_lit(text: &str) -> Expression {
        Expression::Literal(Box::new(Literal::Number(text.to_string())))
    }

    /// String literal with the given contents.
    fn string_lit(text: &str) -> Expression {
        Expression::Literal(Box::new(Literal::String(text.to_string())))
    }

    /// Plain cast payload of `this` to `to` with every optional part unset.
    fn cast_node(this: Expression, to: DataType) -> Cast {
        Cast {
            this,
            to,
            trailing_comments: vec![],
            double_colon_syntax: false,
            format: None,
            default: None,
            inferred_type: None,
        }
    }

    /// Text of a numeric literal; panics for anything else so that a wrong
    /// literal kind fails the test instead of being skipped silently.
    fn number_text(expr: &Expression) -> &str {
        match expr {
            Expression::Literal(lit) => match &**lit {
                Literal::Number(text) => text.as_str(),
                _ => panic!("Expected number literal"),
            },
            _ => panic!("Expected number literal"),
        }
    }

    /// Contents of a string literal; panics for anything else.
    fn string_text(expr: &Expression) -> &str {
        match expr {
            Expression::Literal(lit) => match &**lit {
                Literal::String(text) => text.as_str(),
                _ => panic!("Expected string literal"),
            },
            _ => panic!("Expected string literal"),
        }
    }

    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    fn decimal_10_2() -> DataType {
        DataType::Decimal {
            precision: Some(10),
            scale: Some(2),
        }
    }

    fn bare_decimal() -> DataType {
        DataType::Decimal {
            precision: None,
            scale: None,
        }
    }

    fn plain_timestamp() -> DataType {
        DataType::Timestamp {
            precision: None,
            timezone: false,
        }
    }

    // ---------------------------------------------------------------------
    // preprocess
    // ---------------------------------------------------------------------

    #[test]
    fn test_preprocess() {
        let expr = Expression::Boolean(BooleanLiteral { value: true });
        let result = preprocess(expr, &[replace_bool_with_int]).unwrap();
        assert!(
            matches!(result, Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)))
        );
    }

    #[test]
    fn test_preprocess_chain() {
        // bool -> int -> bool must round-trip back to `true`.
        let expr = Expression::Boolean(BooleanLiteral { value: true });
        let transforms: Vec<fn(Expression) -> Result<Expression>> =
            vec![replace_bool_with_int, replace_int_with_bool];
        let result = preprocess(expr, &transforms).unwrap();
        let Expression::Boolean(b) = result else {
            panic!("Expected boolean literal");
        };
        assert!(b.value);
    }

    // ---------------------------------------------------------------------
    // UNNEST <-> EXPLODE
    // ---------------------------------------------------------------------

    #[test]
    fn test_unnest_to_explode() {
        let unnest = Expression::Unnest(Box::new(UnnestFunc {
            this: col_expr("arr"),
            expressions: Vec::new(),
            with_ordinality: false,
            alias: None,
            offset_alias: None,
        }));
        let result = unnest_to_explode(unnest).unwrap();
        assert!(matches!(result, Expression::Explode(_)));
    }

    #[test]
    fn test_explode_to_unnest() {
        let explode = Expression::Explode(Box::new(UnaryFunc {
            this: col_expr("arr"),
            original_name: None,
            inferred_type: None,
        }));
        let result = explode_to_unnest(explode).unwrap();
        assert!(matches!(result, Expression::Unnest(_)));
    }

    // ---------------------------------------------------------------------
    // bool <-> int
    // ---------------------------------------------------------------------

    #[test]
    fn test_replace_bool_with_int() {
        let true_expr = Expression::Boolean(BooleanLiteral { value: true });
        let result = replace_bool_with_int(true_expr).unwrap();
        assert_eq!(number_text(&result), "1");

        let false_expr = Expression::Boolean(BooleanLiteral { value: false });
        let result = replace_bool_with_int(false_expr).unwrap();
        assert_eq!(number_text(&result), "0");
    }

    #[test]
    fn test_replace_int_with_bool() {
        let result = replace_int_with_bool(number_lit("1")).unwrap();
        let Expression::Boolean(b) = result else {
            panic!("Expected boolean true");
        };
        assert!(b.value);

        let result = replace_int_with_bool(number_lit("0")).unwrap();
        let Expression::Boolean(b) = result else {
            panic!("Expected boolean false");
        };
        assert!(!b.value);

        // Numbers other than 0 and 1 stay numeric.
        let result = replace_int_with_bool(number_lit("2")).unwrap();
        assert!(
            matches!(result, Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)))
        );
    }

    // ---------------------------------------------------------------------
    // Type parameter stripping
    // ---------------------------------------------------------------------

    #[test]
    fn test_strip_data_type_params() {
        assert_eq!(strip_data_type_params(decimal_10_2()), bare_decimal());

        let varchar = DataType::VarChar {
            length: Some(255),
            parenthesized_length: false,
        };
        assert_eq!(
            strip_data_type_params(varchar),
            DataType::VarChar {
                length: None,
                parenthesized_length: false
            }
        );

        let char_type = DataType::Char { length: Some(10) };
        assert_eq!(
            strip_data_type_params(char_type),
            DataType::Char { length: None }
        );

        // Precision is dropped, the timezone flag is kept.
        let timestamp = DataType::Timestamp {
            precision: Some(6),
            timezone: true,
        };
        assert_eq!(
            strip_data_type_params(timestamp),
            DataType::Timestamp {
                precision: None,
                timezone: true
            }
        );

        // Element types of arrays are stripped as well.
        let array = DataType::Array {
            element_type: Box::new(DataType::VarChar {
                length: Some(100),
                parenthesized_length: false,
            }),
            dimension: None,
        };
        assert_eq!(
            strip_data_type_params(array),
            DataType::Array {
                element_type: Box::new(DataType::VarChar {
                    length: None,
                    parenthesized_length: false
                }),
                dimension: None,
            }
        );

        assert_eq!(strip_data_type_params(DataType::Text), DataType::Text);
    }

    #[test]
    fn test_remove_precision_parameterized_types_cast() {
        let cast_expr = Expression::Cast(Box::new(cast_node(number_lit("1"), decimal_10_2())));
        let result = remove_precision_parameterized_types(cast_expr).unwrap();
        let Expression::Cast(cast) = result else {
            panic!("Expected Cast expression");
        };
        assert_eq!(cast.to, bare_decimal());
    }

    #[test]
    fn test_remove_precision_parameterized_types_varchar() {
        let cast_expr = Expression::Cast(Box::new(cast_node(
            string_lit("hello"),
            DataType::VarChar {
                length: Some(10),
                parenthesized_length: false,
            },
        )));
        let result = remove_precision_parameterized_types(cast_expr).unwrap();
        let Expression::Cast(cast) = result else {
            panic!("Expected Cast expression");
        };
        assert_eq!(
            cast.to,
            DataType::VarChar {
                length: None,
                parenthesized_length: false
            }
        );
    }

    #[test]
    fn test_remove_precision_direct_cast() {
        let cast = Expression::Cast(Box::new(cast_node(number_lit("1"), decimal_10_2())));
        let transformed = remove_precision_parameterized_types(cast).unwrap();
        let generated = generate(&transformed);
        assert!(generated.contains("DECIMAL"));
        assert!(!generated.contains("(10"));
    }

    // ---------------------------------------------------------------------
    // 'epoch' casts
    // ---------------------------------------------------------------------

    #[test]
    fn test_epoch_cast_to_ts() {
        let cast_expr = Expression::Cast(Box::new(cast_node(
            string_lit("epoch"),
            plain_timestamp(),
        )));
        let result = epoch_cast_to_ts(cast_expr).unwrap();
        let Expression::Cast(cast) = result else {
            panic!("Expected Cast expression");
        };
        assert_eq!(string_text(&cast.this), "1970-01-01 00:00:00");
    }

    #[test]
    fn test_epoch_cast_to_ts_preserves_non_epoch() {
        let cast_expr = Expression::Cast(Box::new(cast_node(
            string_lit("2024-01-15"),
            plain_timestamp(),
        )));
        let result = epoch_cast_to_ts(cast_expr).unwrap();
        let Expression::Cast(cast) = result else {
            panic!("Expected Cast expression");
        };
        assert_eq!(string_text(&cast.this), "2024-01-15");
    }

    // ---------------------------------------------------------------------
    // Column qualification and type predicates
    // ---------------------------------------------------------------------

    #[test]
    fn test_unqualify_columns() {
        let col = Expression::boxed_column(Column {
            name: Identifier::new("id".to_string()),
            table: Some(Identifier::new("users".to_string())),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        });
        let result = unqualify_columns(col).unwrap();
        let Expression::Column(c) = result else {
            panic!("Expected Column expression");
        };
        assert!(c.table.is_none());
        assert_eq!(c.name.name, "id");
    }

    #[test]
    fn test_is_temporal_type() {
        assert!(is_temporal_type(&DataType::Date));
        assert!(is_temporal_type(&plain_timestamp()));
        assert!(is_temporal_type(&DataType::Time {
            precision: None,
            timezone: false
        }));
        assert!(!is_temporal_type(&int_type()));
        assert!(!is_temporal_type(&DataType::VarChar {
            length: None,
            parenthesized_length: false
        }));
    }

    // ---------------------------------------------------------------------
    // Join and operator rewrites
    // ---------------------------------------------------------------------

    #[test]
    fn test_eliminate_semi_join_basic() {
        use crate::expressions::{Join, TableRef};

        // SELECT a FROM t1 SEMI JOIN t2 ON x = y
        let join_condition = Expression::Eq(Box::new(BinaryOp {
            left: col_expr("x"),
            right: col_expr("y"),
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
            inferred_type: None,
        }));
        let select = Expression::Select(Box::new(Select {
            expressions: vec![col_expr("a")],
            from: Some(From {
                expressions: vec![Expression::Table(Box::new(TableRef::new("t1")))],
            }),
            joins: vec![Join {
                this: Expression::Table(Box::new(TableRef::new("t2"))),
                kind: JoinKind::Semi,
                on: Some(join_condition),
                using: vec![],
                use_inner_keyword: false,
                use_outer_keyword: false,
                deferred_condition: false,
                join_hint: None,
                match_condition: None,
                pivots: Vec::new(),
                comments: Vec::new(),
                nesting_group: 0,
                directed: false,
            }],
            ..Select::new()
        }));

        let result = eliminate_semi_and_anti_joins(select).unwrap();
        let Expression::Select(s) = result else {
            panic!("Expected Select expression");
        };
        // The join is gone and a WHERE clause has taken its place.
        assert!(s.joins.is_empty());
        assert!(s.where_clause.is_some());
    }

    #[test]
    fn test_no_ilike_sql() {
        use crate::expressions::LikeOp;

        let ilike_expr = Expression::ILike(Box::new(LikeOp {
            left: col_expr("name"),
            right: string_lit("%test%"),
            escape: None,
            quantifier: None,
            inferred_type: None,
        }));
        let result = no_ilike_sql(ilike_expr).unwrap();
        let Expression::Like(like) = result else {
            panic!("Expected Like expression");
        };
        let Expression::Function(left) = &like.left else {
            panic!("Expected LOWER function on left");
        };
        assert_eq!(left.name, "LOWER");
        let Expression::Function(right) = &like.right else {
            panic!("Expected LOWER function on right");
        };
        assert_eq!(right.name, "LOWER");
    }

    #[test]
    fn test_no_trycast_sql() {
        let trycast_expr = Expression::TryCast(Box::new(cast_node(string_lit("123"), int_type())));
        let result = no_trycast_sql(trycast_expr).unwrap();
        assert!(matches!(result, Expression::Cast(_)));
    }

    #[test]
    fn test_no_safe_cast_sql() {
        let safe_cast_expr =
            Expression::SafeCast(Box::new(cast_node(string_lit("123"), int_type())));
        let result = no_safe_cast_sql(safe_cast_expr).unwrap();
        assert!(matches!(result, Expression::Cast(_)));
    }

    // ---------------------------------------------------------------------
    // End-to-end transpilation
    // ---------------------------------------------------------------------

    #[test]
    fn test_explode_to_unnest_presto() {
        let spark = Dialect::get(DialectType::Spark);
        let result = spark
            .transpile("SELECT EXPLODE(x) FROM tbl", DialectType::Presto)
            .unwrap();
        assert_eq!(
            result[0],
            "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM tbl CROSS JOIN UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(x) AND _u_2.pos_2 = CARDINALITY(x))"
        );
    }

    #[test]
    fn test_explode_to_unnest_bigquery() {
        let spark = Dialect::get(DialectType::Spark);
        let result = spark
            .transpile("SELECT EXPLODE(x) FROM tbl", DialectType::BigQuery)
            .unwrap();
        assert_eq!(
            result[0],
            "SELECT IF(pos = pos_2, col, NULL) AS col FROM tbl CROSS JOIN UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(x)) - 1)) AS pos CROSS JOIN UNNEST(x) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(x) - 1) AND pos_2 = (ARRAY_LENGTH(x) - 1))"
        );
    }
}