use crate::dialects::DialectType;
use crate::expressions::{DataType, Expression, Literal, Null};
use crate::helper::{is_iso_date, is_iso_datetime};
/// Rewrites `expression` into a canonical form.
///
/// This is the public entry point of the pass; all work is delegated to
/// `canonicalize_recursive`, which rebuilds the tree bottom-up. `dialect` is
/// forwarded unchanged to the recursive walk.
pub fn canonicalize(expression: Expression, dialect: Option<DialectType>) -> Expression {
canonicalize_recursive(expression, dialect)
}
/// Recursively canonicalizes `expression`, rewriting children before parents.
///
/// Per-node rules applied once the children have been visited:
/// - `WHERE`/`HAVING` predicates and the operands of `And`, `Or` and `Not`
///   are passed through `ensure_bools`.
/// - `Eq`, `Neq`, `Lt`, `Lte`, `Gt`, `Gte` and `Sub` go through
///   `canonicalize_comparison` (which applies operand date coercion);
///   `Mul` and `Div` go through `canonicalize_binary`.
/// - `Add` is rebuilt and then passed to `add_text_to_concat`.
/// - `Cast` is rebuilt and then passed to `remove_redundant_casts`.
/// - `ORDER BY` terms are passed to `remove_ascending_order`.
///
/// For `SELECT`, only the projection list, `FROM`, `WHERE`, `HAVING`,
/// `ORDER BY` and joins (target and `ON`) are visited; every other clause is
/// kept as-is. Node kinds without an arm below are returned unchanged and
/// their children are not visited. Rebuilt nodes have `inferred_type` cleared.
/// `dialect` is threaded through the recursion but not otherwise read here.
fn canonicalize_recursive(expression: Expression, dialect: Option<DialectType>) -> Expression {
    match expression {
        Expression::Select(mut select) => {
            select.expressions = select
                .expressions
                .into_iter()
                .map(|e| canonicalize_recursive(e, dialect))
                .collect();
            select.from = select.from.map(|mut from| {
                from.expressions = from
                    .expressions
                    .into_iter()
                    .map(|e| canonicalize_recursive(e, dialect))
                    .collect();
                from
            });
            select.where_clause = select.where_clause.map(|mut clause| {
                clause.this = ensure_bools(canonicalize_recursive(clause.this, dialect));
                clause
            });
            select.having = select.having.map(|mut clause| {
                clause.this = ensure_bools(canonicalize_recursive(clause.this, dialect));
                clause
            });
            select.order_by = select.order_by.map(|mut order_by| {
                order_by.expressions = order_by
                    .expressions
                    .into_iter()
                    .map(|mut o| {
                        o.this = canonicalize_recursive(o.this, dialect);
                        remove_ascending_order(o)
                    })
                    .collect();
                order_by
            });
            select.joins = select
                .joins
                .into_iter()
                .map(|mut j| {
                    j.this = canonicalize_recursive(j.this, dialect);
                    j.on = j.on.map(|on| canonicalize_recursive(on, dialect));
                    j
                })
                .collect();
            Expression::Select(select)
        }
        // Same rebuild as `canonicalize_binary`, followed by the `+` rule.
        Expression::Add(bin) => {
            add_text_to_concat(canonicalize_binary(Expression::Add, *bin, dialect))
        }
        Expression::And(bin) => canonicalize_connector(Expression::And, *bin, dialect),
        Expression::Or(bin) => canonicalize_connector(Expression::Or, *bin, dialect),
        Expression::Not(un) => {
            let inner = ensure_bools(canonicalize_recursive(un.this, dialect));
            Expression::Not(Box::new(crate::expressions::UnaryOp {
                this: inner,
                inferred_type: None,
            }))
        }
        Expression::Eq(bin) => canonicalize_comparison(Expression::Eq, *bin, dialect),
        Expression::Neq(bin) => canonicalize_comparison(Expression::Neq, *bin, dialect),
        Expression::Lt(bin) => canonicalize_comparison(Expression::Lt, *bin, dialect),
        Expression::Lte(bin) => canonicalize_comparison(Expression::Lte, *bin, dialect),
        Expression::Gt(bin) => canonicalize_comparison(Expression::Gt, *bin, dialect),
        Expression::Gte(bin) => canonicalize_comparison(Expression::Gte, *bin, dialect),
        // `Sub` shares the comparison path so its operands get date coercion too.
        Expression::Sub(bin) => canonicalize_comparison(Expression::Sub, *bin, dialect),
        Expression::Mul(bin) => canonicalize_binary(Expression::Mul, *bin, dialect),
        Expression::Div(bin) => canonicalize_binary(Expression::Div, *bin, dialect),
        Expression::Cast(cast) => {
            let inner = canonicalize_recursive(cast.this, dialect);
            let result = Expression::Cast(Box::new(crate::expressions::Cast {
                this: inner,
                to: cast.to,
                trailing_comments: cast.trailing_comments,
                double_colon_syntax: cast.double_colon_syntax,
                format: cast.format,
                default: cast.default,
                inferred_type: None,
            }));
            remove_redundant_casts(result)
        }
        Expression::Function(func) => {
            let args = func
                .args
                .into_iter()
                .map(|e| canonicalize_recursive(e, dialect))
                .collect();
            Expression::Function(Box::new(crate::expressions::Function {
                name: func.name,
                args,
                distinct: func.distinct,
                trailing_comments: func.trailing_comments,
                use_bracket_syntax: func.use_bracket_syntax,
                no_parens: func.no_parens,
                quoted: func.quoted,
                span: None,
                inferred_type: None,
            }))
        }
        Expression::AggregateFunction(agg) => {
            let args = agg
                .args
                .into_iter()
                .map(|e| canonicalize_recursive(e, dialect))
                .collect();
            Expression::AggregateFunction(Box::new(crate::expressions::AggregateFunction {
                name: agg.name,
                args,
                distinct: agg.distinct,
                filter: agg.filter.map(|f| canonicalize_recursive(f, dialect)),
                order_by: agg.order_by,
                limit: agg.limit,
                ignore_nulls: agg.ignore_nulls,
                inferred_type: None,
            }))
        }
        Expression::Alias(alias) => {
            let inner = canonicalize_recursive(alias.this, dialect);
            Expression::Alias(Box::new(crate::expressions::Alias {
                this: inner,
                alias: alias.alias,
                column_aliases: alias.column_aliases,
                pre_alias_comments: alias.pre_alias_comments,
                trailing_comments: alias.trailing_comments,
                inferred_type: None,
            }))
        }
        Expression::Paren(paren) => {
            let inner = canonicalize_recursive(paren.this, dialect);
            Expression::Paren(Box::new(crate::expressions::Paren {
                this: inner,
                trailing_comments: paren.trailing_comments,
            }))
        }
        Expression::Case(case) => {
            let operand = case.operand.map(|e| canonicalize_recursive(e, dialect));
            let whens = case
                .whens
                .into_iter()
                .map(|(w, t)| {
                    (
                        canonicalize_recursive(w, dialect),
                        canonicalize_recursive(t, dialect),
                    )
                })
                .collect();
            let else_ = case.else_.map(|e| canonicalize_recursive(e, dialect));
            Expression::Case(Box::new(crate::expressions::Case {
                operand,
                whens,
                else_,
                // Carry the node's comments over, as every other rebuilt node does.
                comments: case.comments,
                inferred_type: None,
            }))
        }
        Expression::Between(between) => {
            let this = canonicalize_recursive(between.this, dialect);
            let low = canonicalize_recursive(between.low, dialect);
            let high = canonicalize_recursive(between.high, dialect);
            Expression::Between(Box::new(crate::expressions::Between {
                this,
                low,
                high,
                not: between.not,
                symmetric: between.symmetric,
            }))
        }
        Expression::In(in_expr) => {
            let this = canonicalize_recursive(in_expr.this, dialect);
            let expressions = in_expr
                .expressions
                .into_iter()
                .map(|e| canonicalize_recursive(e, dialect))
                .collect();
            let query = in_expr.query.map(|q| canonicalize_recursive(q, dialect));
            Expression::In(Box::new(crate::expressions::In {
                this,
                expressions,
                query,
                not: in_expr.not,
                global: in_expr.global,
                unnest: in_expr.unnest,
                is_field: in_expr.is_field,
            }))
        }
        Expression::Subquery(subquery) => {
            let this = canonicalize_recursive(subquery.this, dialect);
            Expression::Subquery(Box::new(crate::expressions::Subquery {
                this,
                alias: subquery.alias,
                column_aliases: subquery.column_aliases,
                order_by: subquery.order_by,
                limit: subquery.limit,
                offset: subquery.offset,
                distribute_by: subquery.distribute_by,
                sort_by: subquery.sort_by,
                cluster_by: subquery.cluster_by,
                lateral: subquery.lateral,
                modifiers_inside: subquery.modifiers_inside,
                trailing_comments: subquery.trailing_comments,
                inferred_type: None,
            }))
        }
        // Set operations: the operands are swapped out temporarily with a
        // `NULL` placeholder so they can be canonicalized by value.
        Expression::Union(union) => {
            let mut u = *union;
            let left = std::mem::replace(&mut u.left, Expression::Null(Null));
            u.left = canonicalize_recursive(left, dialect);
            let right = std::mem::replace(&mut u.right, Expression::Null(Null));
            u.right = canonicalize_recursive(right, dialect);
            Expression::Union(Box::new(u))
        }
        Expression::Intersect(intersect) => {
            let mut i = *intersect;
            let left = std::mem::replace(&mut i.left, Expression::Null(Null));
            i.left = canonicalize_recursive(left, dialect);
            let right = std::mem::replace(&mut i.right, Expression::Null(Null));
            i.right = canonicalize_recursive(right, dialect);
            Expression::Intersect(Box::new(i))
        }
        Expression::Except(except) => {
            let mut e = *except;
            let left = std::mem::replace(&mut e.left, Expression::Null(Null));
            e.left = canonicalize_recursive(left, dialect);
            let right = std::mem::replace(&mut e.right, Expression::Null(Null));
            e.right = canonicalize_recursive(right, dialect);
            Expression::Except(Box::new(e))
        }
        other => other,
    }
}

/// Rebuilds a boolean connective (`And`/`Or`) after canonicalizing both
/// operands and passing each through `ensure_bools`.
///
/// `constructor` is the `Expression` variant to wrap the result in. Operand
/// comments are carried over and `inferred_type` is cleared.
fn canonicalize_connector<F>(
    constructor: F,
    bin: crate::expressions::BinaryOp,
    dialect: Option<DialectType>,
) -> Expression
where
    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
{
    let left = ensure_bools(canonicalize_recursive(bin.left, dialect));
    let right = ensure_bools(canonicalize_recursive(bin.right, dialect));
    constructor(Box::new(crate::expressions::BinaryOp {
        left,
        right,
        left_comments: bin.left_comments,
        operator_comments: bin.operator_comments,
        trailing_comments: bin.trailing_comments,
        inferred_type: None,
    }))
}
/// Hook for the `+` canonicalization rule.
///
/// Currently a pass-through: the expression is returned exactly as given.
fn add_text_to_concat(expression: Expression) -> Expression {
    std::convert::identity(expression)
}
/// Unwraps casts that are redundant for their operand.
///
/// A string literal cast to `VarChar`/`Text` (e.g. `CAST('a' AS TEXT)`) is
/// replaced by the bare literal. Casts of numeric literals are left in place
/// (the former numeric branch was an empty placeholder), since casting a
/// number to another numeric type may round or truncate it. Any other
/// expression is returned unchanged.
fn remove_redundant_casts(expression: Expression) -> Expression {
    let is_redundant = match &expression {
        Expression::Cast(cast) => {
            let is_string_literal = matches!(
                &cast.this,
                Expression::Literal(lit) if matches!(lit.as_ref(), Literal::String(_))
            );
            is_string_literal && matches!(&cast.to, DataType::VarChar { .. } | DataType::Text)
        }
        _ => false,
    };
    match expression {
        // We own the cast, so move its operand out instead of cloning it.
        Expression::Cast(cast) if is_redundant => cast.this,
        other => other,
    }
}
/// Hook for coercing an operand into a boolean context (`WHERE`, `HAVING`,
/// `AND`/`OR`/`NOT` operands).
///
/// Currently a pass-through: the expression is returned exactly as given.
fn ensure_bools(expression: Expression) -> Expression {
    std::convert::identity(expression)
}
/// Clears the explicit-`ASC` marker on an ascending `ORDER BY` term.
///
/// Ascending is the default sort direction, so for any term that is not
/// `desc` the `explicit_asc` flag is reset (presumably so the generator no
/// longer prints `ASC` — confirm against the generator). Descending terms
/// are returned untouched.
fn remove_ascending_order(mut ordered: crate::expressions::Ordered) -> crate::expressions::Ordered {
    if !ordered.desc {
        // Resetting an already-false flag is harmless, so no extra check is needed.
        ordered.explicit_asc = false;
    }
    ordered
}
/// Rebuilds a comparison-style binary node (`Eq`, `Neq`, `Lt`, `Lte`, `Gt`,
/// `Gte`, and `Sub`).
///
/// Both operands are canonicalized (left first), then passed together through
/// `coerce_date_operands`. `constructor` is the `Expression` variant to wrap
/// the result in. Operand comments are carried over and `inferred_type` is
/// cleared.
fn canonicalize_comparison<F>(
    constructor: F,
    bin: crate::expressions::BinaryOp,
    dialect: Option<DialectType>,
) -> Expression
where
    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
{
    let crate::expressions::BinaryOp {
        left,
        right,
        left_comments,
        operator_comments,
        trailing_comments,
        ..
    } = bin;
    let (lhs, rhs) = coerce_date_operands(
        canonicalize_recursive(left, dialect),
        canonicalize_recursive(right, dialect),
    );
    let rebuilt = crate::expressions::BinaryOp {
        left: lhs,
        right: rhs,
        left_comments,
        operator_comments,
        trailing_comments,
        inferred_type: None,
    };
    constructor(Box::new(rebuilt))
}
/// Rebuilds a binary node (used for operators such as `Mul` and `Div`) after
/// canonicalizing both operands, left first.
///
/// `constructor` is the `Expression` variant to wrap the result in. Operand
/// comments are carried over and `inferred_type` is cleared.
fn canonicalize_binary<F>(
    constructor: F,
    bin: crate::expressions::BinaryOp,
    dialect: Option<DialectType>,
) -> Expression
where
    F: FnOnce(Box<crate::expressions::BinaryOp>) -> Expression,
{
    let crate::expressions::BinaryOp {
        left,
        right,
        left_comments,
        operator_comments,
        trailing_comments,
        ..
    } = bin;
    let rebuilt = crate::expressions::BinaryOp {
        left: canonicalize_recursive(left, dialect),
        right: canonicalize_recursive(right, dialect),
        left_comments,
        operator_comments,
        trailing_comments,
        inferred_type: None,
    };
    constructor(Box::new(rebuilt))
}
/// Applies `coerce_date_string` to both operands of a binary operation.
///
/// The left operand is coerced against the original right operand. The right
/// operand is then coerced against the already-coerced left operand.
fn coerce_date_operands(left: Expression, right: Expression) -> (Expression, Expression) {
    let coerced_left = coerce_date_string(left, &right);
    let coerced_right = coerce_date_string(right, &coerced_left);
    (coerced_left, coerced_right)
}
/// Placeholder for coercing an ISO date/datetime string literal operand.
///
/// A string literal is checked with `is_iso_date` and then, if that fails,
/// with `is_iso_datetime`. No rewrite is performed yet and `_other` is
/// ignored, so `expr` is always returned unchanged.
fn coerce_date_string(expr: Expression, _other: &Expression) -> Expression {
    if let Expression::Literal(lit) = &expr {
        if let Literal::String(s) = lit.as_ref() {
            // Neither outcome triggers a rewrite yet; `||` keeps the
            // "date first, then datetime" check order.
            let _ = is_iso_date(s) || is_iso_datetime(s);
        }
    }
    expr
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL with the default generator.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parses `sql`, canonicalizes it without a dialect, and regenerates SQL.
    fn canonical_sql(sql: &str) -> String {
        to_sql(&canonicalize(parse(sql), None))
    }

    #[test]
    fn test_canonicalize_simple() {
        assert!(canonical_sql("SELECT a FROM t").contains("SELECT"));
    }

    #[test]
    fn test_canonicalize_preserves_structure() {
        assert!(canonical_sql("SELECT a, b FROM t WHERE c = 1").contains("WHERE"));
    }

    #[test]
    fn test_canonicalize_and_or() {
        let sql = canonical_sql("SELECT 1 WHERE a AND b OR c");
        assert!(sql.contains("AND") || sql.contains("OR"));
    }

    #[test]
    fn test_canonicalize_comparison() {
        let sql = canonical_sql("SELECT 1 WHERE a = 1 AND b > 2");
        assert!(sql.contains("=") && sql.contains(">"));
    }

    #[test]
    fn test_canonicalize_case() {
        let sql = canonical_sql("SELECT CASE WHEN a = 1 THEN 'yes' ELSE 'no' END FROM t");
        assert!(sql.contains("CASE") && sql.contains("WHEN"));
    }

    #[test]
    fn test_canonicalize_subquery() {
        let sql = canonical_sql("SELECT a FROM (SELECT b FROM t) AS sub");
        assert!(sql.contains("SELECT") && sql.contains("sub"));
    }

    #[test]
    fn test_canonicalize_order_by() {
        assert!(canonical_sql("SELECT a FROM t ORDER BY a").contains("ORDER BY"));
    }

    #[test]
    fn test_canonicalize_union() {
        assert!(canonical_sql("SELECT a FROM t UNION SELECT b FROM s").contains("UNION"));
    }

    #[test]
    fn test_add_text_to_concat_passthrough() {
        assert!(canonical_sql("SELECT 1 + 2").contains("+"));
    }

    #[test]
    fn test_canonicalize_function() {
        assert!(canonical_sql("SELECT MAX(a) FROM t").contains("MAX"));
    }

    #[test]
    fn test_canonicalize_between() {
        assert!(canonical_sql("SELECT 1 WHERE a BETWEEN 1 AND 10").contains("BETWEEN"));
    }
}