use crate::dialects::DialectType;
use crate::expressions::{
BinaryOp, BooleanLiteral, Case, ConcatWs, DateTruncFunc, Expression, Literal, Null, Paren,
UnaryOp,
};
/// Simplifies `expression` to a fixed point using a fresh [`Simplifier`].
///
/// `dialect` is handed to the simplifier unchanged.
pub fn simplify(expression: Expression, dialect: Option<DialectType>) -> Expression {
    Simplifier::new(dialect).simplify(expression)
}
/// Returns `true` when `expr` is a constant that is truthy as a condition.
///
/// That is a `TRUE` boolean literal, or a numeric literal whose text parses as a
/// non-zero `f64`. Unparseable numbers and every other expression kind return `false`.
pub fn always_true(expr: &Expression) -> bool {
    match expr {
        Expression::Boolean(flag) => flag.value,
        Expression::Literal(lit) => match lit.as_ref() {
            Literal::Number(text) => matches!(text.parse::<f64>(), Ok(value) if value != 0.0),
            _ => false,
        },
        _ => false,
    }
}
/// Returns `true` only for the boolean literal `TRUE`.
pub fn is_boolean_true(expr: &Expression) -> bool {
    if let Expression::Boolean(flag) = expr {
        flag.value
    } else {
        false
    }
}
/// Returns `true` only for the boolean literal `FALSE`.
pub fn is_boolean_false(expr: &Expression) -> bool {
    if let Expression::Boolean(flag) = expr {
        !flag.value
    } else {
        false
    }
}
/// Returns `true` when `expr` is a constant that is falsy as a condition.
///
/// That is `NULL`, `FALSE`, or a numeric literal equal to zero.
pub fn always_false(expr: &Expression) -> bool {
    // All three predicates are pure, so their order does not matter.
    is_null(expr) || is_false(expr) || is_zero(expr)
}
pub fn is_false(expr: &Expression) -> bool {
matches!(expr, Expression::Boolean(b) if !b.value)
}
/// Returns `true` when `expr` is the bare `NULL` literal.
pub fn is_null(expr: &Expression) -> bool {
    matches!(*expr, Expression::Null(_))
}
/// Returns `true` when `expr` is a numeric literal whose text parses as an `f64` equal to zero.
///
/// Non-literals, non-numeric literals and unparseable numbers return `false`.
pub fn is_zero(expr: &Expression) -> bool {
    let Expression::Literal(lit) = expr else {
        return false;
    };
    let Literal::Number(text) = lit.as_ref() else {
        return false;
    };
    matches!(text.parse::<f64>(), Ok(value) if value == 0.0)
}
/// Returns `true` when `b` is exactly `NOT a`.
///
/// Only this direction is checked; callers that need symmetry test both argument orders.
/// Structural equality uses `Expression`'s `PartialEq`.
pub fn is_complement(a: &Expression, b: &Expression) -> bool {
    matches!(b, Expression::Not(not_op) if &not_op.this == a)
}
/// Builds the boolean literal `TRUE`.
pub fn bool_true() -> Expression {
    let literal = BooleanLiteral { value: true };
    Expression::Boolean(literal)
}
/// Builds the boolean literal `FALSE`.
pub fn bool_false() -> Expression {
    let literal = BooleanLiteral { value: false };
    Expression::Boolean(literal)
}
/// Builds the `NULL` literal.
pub fn null() -> Expression {
    let marker = Null;
    Expression::Null(marker)
}
/// Evaluates `a <op> b` for two numbers and returns the outcome as a boolean literal.
///
/// Recognised operators are `=`, `==`, `!=`, `<>`, `>`, `>=`, `<` and `<=`, using `f64`
/// comparison semantics. Any other operator yields `None`.
pub fn eval_boolean_nums(op: &str, a: f64, b: f64) -> Option<Expression> {
    let compare: fn(f64, f64) -> bool = match op {
        "=" | "==" => |x, y| x == y,
        "!=" | "<>" => |x, y| x != y,
        ">" => |x, y| x > y,
        ">=" => |x, y| x >= y,
        "<" => |x, y| x < y,
        "<=" => |x, y| x <= y,
        _ => return None,
    };
    Some(if compare(a, b) { bool_true() } else { bool_false() })
}
/// Evaluates `a <op> b` for two strings and returns the outcome as a boolean literal.
///
/// Ordering is Rust's `str` ordering (byte-wise lexicographic), not any database
/// collation. Recognised operators are `=`, `==`, `!=`, `<>`, `>`, `>=`, `<` and `<=`;
/// any other operator yields `None`.
pub fn eval_boolean_strings(op: &str, a: &str, b: &str) -> Option<Expression> {
    let compare: fn(&str, &str) -> bool = match op {
        "=" | "==" => |x, y| x == y,
        "!=" | "<>" => |x, y| x != y,
        ">" => |x, y| x > y,
        ">=" => |x, y| x >= y,
        "<" => |x, y| x < y,
        "<=" => |x, y| x <= y,
        _ => return None,
    };
    Some(if compare(a, b) { bool_true() } else { bool_false() })
}
/// Rule-based expression simplifier.
///
/// `Simplifier::simplify` applies single passes repeatedly until the expression stops
/// changing or `max_iterations` passes have run.
pub struct Simplifier {
// Dialect supplied at construction; no rule in this file reads it.
_dialect: Option<DialectType>,
// Upper bound on the number of passes in `simplify`; `new` sets it to 100.
max_iterations: usize,
}
impl Simplifier {
pub fn new(dialect: Option<DialectType>) -> Self {
Self {
_dialect: dialect,
max_iterations: 100,
}
}
/// Simplifies `expression` to a fixed point.
///
/// Runs `simplify_once` repeatedly. It stops as soon as a pass leaves the expression
/// unchanged (compared with `expressions_equal`). Otherwise it stops after
/// `max_iterations` passes and returns the latest result.
pub fn simplify(&mut self, expression: Expression) -> Expression {
    let mut current = expression;
    for _ in 0..self.max_iterations {
        // `simplify_once` consumes its input, so keep a copy to detect convergence.
        let simplified = self.simplify_once(current.clone());
        if expressions_equal(&simplified, &current) {
            return simplified;
        }
        current = simplified;
    }
    current
}
fn simplify_once(&mut self, expression: Expression) -> Expression {
match expression {
Expression::And(op) => self.simplify_and(*op),
Expression::Or(op) => self.simplify_or(*op),
Expression::Not(op) => self.simplify_not(*op),
Expression::Add(op) => self.simplify_add(*op),
Expression::Sub(op) => self.simplify_sub(*op),
Expression::Mul(op) => self.simplify_mul(*op),
Expression::Div(op) => self.simplify_div(*op),
Expression::Eq(op) => self.simplify_comparison(*op, "="),
Expression::Neq(op) => self.simplify_comparison(*op, "!="),
Expression::Gt(op) => self.simplify_comparison(*op, ">"),
Expression::Gte(op) => self.simplify_comparison(*op, ">="),
Expression::Lt(op) => self.simplify_comparison(*op, "<"),
Expression::Lte(op) => self.simplify_comparison(*op, "<="),
Expression::Neg(op) => self.simplify_neg(*op),
Expression::Case(case) => self.simplify_case(*case),
Expression::Concat(op) => self.simplify_concat(*op),
Expression::ConcatWs(concat_ws) => self.simplify_concat_ws(*concat_ws),
Expression::Paren(paren) => self.simplify_paren(*paren),
Expression::DateTrunc(dt) => self.simplify_datetrunc(*dt),
Expression::TimestampTrunc(dt) => self.simplify_datetrunc(*dt),
other => self.simplify_children(other),
}
}
/// Simplifies `left AND right`.
///
/// Both sides are simplified first. The rules below are then applied in order:
/// - a `FALSE` or numeric-zero operand makes the result `FALSE`;
/// - `NULL AND NULL`, `NULL AND TRUE` and `TRUE AND NULL` become `NULL`;
/// - `TRUE AND x` and `x AND TRUE` become `x`;
/// - `x AND NOT x` (either order) becomes `FALSE`;
/// - `x AND x` becomes `x`;
/// - anything else is handed to `absorb_and_eliminate_and`.
fn simplify_and(&mut self, op: BinaryOp) -> Expression {
    let lhs = self.simplify_once(op.left);
    let rhs = self.simplify_once(op.right);

    let falsy = |e: &Expression| is_boolean_false(e) || is_zero(e);
    if falsy(&lhs) || falsy(&rhs) {
        return bool_false();
    }

    // Both sides are NULL or TRUE, and at least one of them is NULL.
    let null_or_true = |e: &Expression| is_null(e) || is_boolean_true(e);
    if (is_null(&lhs) || is_null(&rhs)) && null_or_true(&lhs) && null_or_true(&rhs) {
        return null();
    }

    match (is_boolean_true(&lhs), is_boolean_true(&rhs)) {
        (true, _) => return rhs,
        (_, true) => return lhs,
        _ => {}
    }

    if is_complement(&lhs, &rhs) || is_complement(&rhs, &lhs) {
        return bool_false();
    }
    if expressions_equal(&lhs, &rhs) {
        return lhs;
    }
    absorb_and_eliminate_and(lhs, rhs)
}
/// Simplifies `left OR right`.
///
/// Both sides are simplified first. The rules below are then applied in order:
/// - a `TRUE` operand makes the result `TRUE`;
/// - `NULL OR NULL`, `NULL OR FALSE` and `FALSE OR NULL` become `NULL`;
/// - `FALSE OR x` and `x OR FALSE` become `x`;
/// - `x OR x` becomes `x`;
/// - anything else is handed to `absorb_and_eliminate_or`.
fn simplify_or(&mut self, op: BinaryOp) -> Expression {
    let lhs = self.simplify_once(op.left);
    let rhs = self.simplify_once(op.right);

    if is_boolean_true(&lhs) || is_boolean_true(&rhs) {
        return bool_true();
    }

    // Both sides are NULL or FALSE, and at least one of them is NULL.
    let null_or_false = |e: &Expression| is_null(e) || is_boolean_false(e);
    if (is_null(&lhs) || is_null(&rhs)) && null_or_false(&lhs) && null_or_false(&rhs) {
        return null();
    }

    match (is_boolean_false(&lhs), is_boolean_false(&rhs)) {
        (true, _) => return rhs,
        (_, true) => return lhs,
        _ => {}
    }

    if expressions_equal(&lhs, &rhs) {
        return lhs;
    }
    absorb_and_eliminate_or(lhs, rhs)
}
fn simplify_not(&mut self, op: UnaryOp) -> Expression {
match &op.this {
Expression::Eq(inner_op) => {
let left = self.simplify_once(inner_op.left.clone());
let right = self.simplify_once(inner_op.right.clone());
return Expression::Neq(Box::new(BinaryOp::new(left, right)));
}
Expression::Neq(inner_op) => {
let left = self.simplify_once(inner_op.left.clone());
let right = self.simplify_once(inner_op.right.clone());
return Expression::Eq(Box::new(BinaryOp::new(left, right)));
}
Expression::Gt(inner_op) => {
let left = self.simplify_once(inner_op.left.clone());
let right = self.simplify_once(inner_op.right.clone());
return Expression::Lte(Box::new(BinaryOp::new(left, right)));
}
Expression::Gte(inner_op) => {
let left = self.simplify_once(inner_op.left.clone());
let right = self.simplify_once(inner_op.right.clone());
return Expression::Lt(Box::new(BinaryOp::new(left, right)));
}
Expression::Lt(inner_op) => {
let left = self.simplify_once(inner_op.left.clone());
let right = self.simplify_once(inner_op.right.clone());
return Expression::Gte(Box::new(BinaryOp::new(left, right)));
}
Expression::Lte(inner_op) => {
let left = self.simplify_once(inner_op.left.clone());
let right = self.simplify_once(inner_op.right.clone());
return Expression::Gt(Box::new(BinaryOp::new(left, right)));
}
_ => {}
}
let inner = self.simplify_once(op.this);
if is_null(&inner) {
return null();
}
if is_boolean_true(&inner) {
return bool_false();
}
if is_boolean_false(&inner) {
return bool_true();
}
if let Expression::Not(inner_not) = &inner {
return inner_not.this.clone();
}
Expression::Not(Box::new(UnaryOp {
this: inner,
inferred_type: None,
}))
}
fn simplify_add(&mut self, op: BinaryOp) -> Expression {
let left = self.simplify_once(op.left);
let right = self.simplify_once(op.right);
if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
return Expression::Literal(Box::new(Literal::Number((a + b).to_string())));
}
if is_zero(&right) {
return left;
}
if is_zero(&left) {
return right;
}
Expression::Add(Box::new(BinaryOp::new(left, right)))
}
fn simplify_sub(&mut self, op: BinaryOp) -> Expression {
let left = self.simplify_once(op.left);
let right = self.simplify_once(op.right);
if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
return Expression::Literal(Box::new(Literal::Number((a - b).to_string())));
}
if is_zero(&right) {
return left;
}
if expressions_equal(&left, &right) {
if let Expression::Literal(lit) = &left {
if let Literal::Number(_) = lit.as_ref() {
return Expression::Literal(Box::new(Literal::Number("0".to_string())));
}
}
}
Expression::Sub(Box::new(BinaryOp::new(left, right)))
}
fn simplify_mul(&mut self, op: BinaryOp) -> Expression {
let left = self.simplify_once(op.left);
let right = self.simplify_once(op.right);
if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
return Expression::Literal(Box::new(Literal::Number((a * b).to_string())));
}
if is_zero(&right) {
return Expression::Literal(Box::new(Literal::Number("0".to_string())));
}
if is_zero(&left) {
return Expression::Literal(Box::new(Literal::Number("0".to_string())));
}
if is_one(&right) {
return left;
}
if is_one(&left) {
return right;
}
Expression::Mul(Box::new(BinaryOp::new(left, right)))
}
fn simplify_div(&mut self, op: BinaryOp) -> Expression {
let left = self.simplify_once(op.left);
let right = self.simplify_once(op.right);
if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
if b != 0.0 && (a.fract() != 0.0 || b.fract() != 0.0) {
return Expression::Literal(Box::new(Literal::Number((a / b).to_string())));
}
}
if is_zero(&left) && !is_zero(&right) {
return Expression::Literal(Box::new(Literal::Number("0".to_string())));
}
if is_one(&right) {
return left;
}
Expression::Div(Box::new(BinaryOp::new(left, right)))
}
fn simplify_neg(&mut self, op: UnaryOp) -> Expression {
let inner = self.simplify_once(op.this);
if let Expression::Neg(inner_neg) = inner {
return inner_neg.this;
}
if let Some(n) = get_number(&inner) {
return Expression::Literal(Box::new(Literal::Number((-n).to_string())));
}
Expression::Neg(Box::new(UnaryOp {
this: inner,
inferred_type: None,
}))
}
/// Simplifies a binary comparison whose operator text is `operator`.
///
/// Both operands are simplified first. If both are numeric literals, or both are string
/// literals, and the operator is recognised, the comparison is evaluated to a boolean
/// literal. For `=`, `simplify_equality` may then rewrite a linear equation. Otherwise
/// the comparison is rebuilt; an unrecognised operator falls back to `=`.
fn simplify_comparison(&mut self, op: BinaryOp, operator: &str) -> Expression {
    let lhs = self.simplify_once(op.left);
    let rhs = self.simplify_once(op.right);

    let folded = match (get_number(&lhs), get_number(&rhs)) {
        (Some(x), Some(y)) => eval_boolean_nums(operator, x, y),
        _ => None,
    }
    .or_else(|| match (get_string(&lhs), get_string(&rhs)) {
        (Some(x), Some(y)) => eval_boolean_strings(operator, &x, &y),
        _ => None,
    });
    if let Some(result) = folded {
        return result;
    }

    if operator == "=" {
        if let Some(rewritten) = self.simplify_equality(lhs.clone(), rhs.clone()) {
            return rewritten;
        }
    }

    let rebuilt = Box::new(BinaryOp::new(lhs, rhs));
    match operator {
        ">" => Expression::Gt(rebuilt),
        ">=" => Expression::Gte(rebuilt),
        "<" => Expression::Lt(rebuilt),
        "<=" => Expression::Lte(rebuilt),
        "!=" | "<>" => Expression::Neq(rebuilt),
        // "=" and any unrecognised operator.
        _ => Expression::Eq(rebuilt),
    }
}
/// Simplifies a `CASE` expression.
///
/// Searched form (`CASE WHEN cond THEN … END`, no operand):
/// - each condition is simplified, and branches whose condition is a constant falsy
///   value (`FALSE`, `NULL`, zero) are dropped;
/// - the first branch whose condition is a constant truthy value (`TRUE`, a non-zero
///   number) ends the scan. If no earlier branch survived, its `THEN` is the result.
///   Otherwise it becomes the `ELSE`, and later branches plus the original `ELSE`
///   (now unreachable) are discarded;
/// - with no surviving branches the result is the simplified `ELSE`, or `NULL` if there
///   is none.
///
/// Simple form (`CASE x WHEN v THEN … END`): the `WHEN` values are compared with the
/// operand rather than evaluated as conditions, so no pruning is done. The operand,
/// every branch and the `ELSE` are simplified.
///
/// The node's comments are kept on the rebuilt `CASE`.
fn simplify_case(&mut self, case: Case) -> Expression {
    let Case {
        operand,
        whens,
        else_,
        comments,
        ..
    } = case;

    if let Some(operand) = operand {
        // `CASE x WHEN 1 THEN …` must not be mistaken for `CASE WHEN TRUE THEN …`.
        let operand = self.simplify_once(operand);
        let whens = whens
            .into_iter()
            .map(|(value, then_expr)| (self.simplify_once(value), self.simplify_once(then_expr)))
            .collect();
        let else_ = else_.map(|e| self.simplify_once(e));
        return Expression::Case(Box::new(Case {
            operand: Some(operand),
            whens,
            else_,
            comments,
            inferred_type: None,
        }));
    }

    let mut new_whens = Vec::new();
    // Set when an always-true branch follows surviving branches; it replaces the ELSE.
    let mut shortcut_else = None;
    for (cond, then_expr) in whens {
        let cond = self.simplify_once(cond);
        if always_false(&cond) {
            continue;
        }
        let then_expr = self.simplify_once(then_expr);
        if always_true(&cond) {
            if new_whens.is_empty() {
                return then_expr;
            }
            shortcut_else = Some(then_expr);
            break;
        }
        new_whens.push((cond, then_expr));
    }

    let else_ = match shortcut_else {
        Some(e) => Some(e),
        None => else_.map(|e| self.simplify_once(e)),
    };
    if new_whens.is_empty() {
        return else_.unwrap_or_else(null);
    }
    Expression::Case(Box::new(Case {
        operand: None,
        whens: new_whens,
        else_,
        comments,
        inferred_type: None,
    }))
}
/// Simplifies a two-operand concatenation.
///
/// Rules, in order:
/// - two string literals are joined into one literal;
/// - an empty string literal on either side yields the other side;
/// - a `NULL` on either side yields `NULL`.
fn simplify_concat(&mut self, op: BinaryOp) -> Expression {
    let head = self.simplify_once(op.left);
    let tail = self.simplify_once(op.right);

    match (get_string(&head), get_string(&tail)) {
        (Some(a), Some(b)) => {
            return Expression::Literal(Box::new(Literal::String(a + &b)));
        }
        (Some(a), _) if a.is_empty() => return tail,
        (_, Some(b)) if b.is_empty() => return head,
        _ => {}
    }

    if is_null(&head) || is_null(&tail) {
        return null();
    }
    Expression::Concat(Box::new(BinaryOp::new(head, tail)))
}
fn simplify_concat_ws(&mut self, concat_ws: ConcatWs) -> Expression {
let separator = self.simplify_once(concat_ws.separator);
if is_null(&separator) {
return null();
}
let expressions: Vec<Expression> = concat_ws
.expressions
.into_iter()
.map(|e| self.simplify_once(e))
.filter(|e| !is_null(e)) .collect();
if expressions.is_empty() {
return Expression::Literal(Box::new(Literal::String(String::new())));
}
if let Some(sep) = get_string(&separator) {
let all_strings: Option<Vec<String>> =
expressions.iter().map(|e| get_string(e)).collect();
if let Some(strings) = all_strings {
return Expression::Literal(Box::new(Literal::String(strings.join(&sep))));
}
}
Expression::ConcatWs(Box::new(ConcatWs {
separator,
expressions,
}))
}
/// Simplifies a parenthesized expression.
///
/// Parentheses around an atomic result are removed. Atomic here means a literal,
/// boolean, `NULL`, column reference, or another parenthesized node. Otherwise the
/// parentheses are kept together with their trailing comments.
fn simplify_paren(&mut self, paren: Paren) -> Expression {
    let inner = self.simplify_once(paren.this);
    let is_atomic = matches!(
        inner,
        Expression::Literal(_)
            | Expression::Boolean(_)
            | Expression::Null(_)
            | Expression::Column(_)
            | Expression::Paren(_)
    );
    if is_atomic {
        inner
    } else {
        Expression::Paren(Box::new(Paren {
            this: inner,
            trailing_comments: paren.trailing_comments,
        }))
    }
}
/// Simplifies the truncated operand of a date-truncation call, keeping its unit.
///
/// The result is always an `Expression::DateTrunc`.
fn simplify_datetrunc(&mut self, dt: DateTruncFunc) -> Expression {
    let DateTruncFunc { this, unit } = dt;
    let this = self.simplify_once(this);
    Expression::DateTrunc(Box::new(DateTruncFunc { this, unit }))
}
fn simplify_equality(&mut self, left: Expression, right: Expression) -> Option<Expression> {
let right_val = get_number(&right)?;
match left {
Expression::Add(ref op) => {
if let Some(c) = get_number(&op.right) {
let new_right =
Expression::Literal(Box::new(Literal::Number((right_val - c).to_string())));
return Some(Expression::Eq(Box::new(BinaryOp::new(
op.left.clone(),
new_right,
))));
}
if let Some(c) = get_number(&op.left) {
let new_right =
Expression::Literal(Box::new(Literal::Number((right_val - c).to_string())));
return Some(Expression::Eq(Box::new(BinaryOp::new(
op.right.clone(),
new_right,
))));
}
}
Expression::Sub(ref op) => {
if let Some(c) = get_number(&op.right) {
let new_right =
Expression::Literal(Box::new(Literal::Number((right_val + c).to_string())));
return Some(Expression::Eq(Box::new(BinaryOp::new(
op.left.clone(),
new_right,
))));
}
if let Some(c) = get_number(&op.left) {
let new_right =
Expression::Literal(Box::new(Literal::Number((c - right_val).to_string())));
return Some(Expression::Eq(Box::new(BinaryOp::new(
op.right.clone(),
new_right,
))));
}
}
Expression::Mul(ref op) => {
if let Some(c) = get_number(&op.right) {
if c != 0.0 && right_val % c == 0.0 {
let new_right = Expression::Literal(Box::new(Literal::Number(
(right_val / c).to_string(),
)));
return Some(Expression::Eq(Box::new(BinaryOp::new(
op.left.clone(),
new_right,
))));
}
}
if let Some(c) = get_number(&op.left) {
if c != 0.0 && right_val % c == 0.0 {
let new_right = Expression::Literal(Box::new(Literal::Number(
(right_val / c).to_string(),
)));
return Some(Expression::Eq(Box::new(BinaryOp::new(
op.right.clone(),
new_right,
))));
}
}
}
_ => {}
}
None
}
/// Simplifies the children of container nodes that have no rule of their own.
///
/// Handles function calls (every argument), `IN` (the tested value and every list
/// item), `BETWEEN` (value and both bounds) and aliases (the aliased expression). Any
/// other node is returned unchanged.
fn simplify_children(&mut self, expr: Expression) -> Expression {
    match expr {
        Expression::Function(mut call) => {
            call.args = call
                .args
                .into_iter()
                .map(|arg| self.simplify_once(arg))
                .collect();
            Expression::Function(call)
        }
        Expression::In(mut membership) => {
            membership.this = self.simplify_once(membership.this);
            membership.expressions = membership
                .expressions
                .into_iter()
                .map(|item| self.simplify_once(item))
                .collect();
            Expression::In(membership)
        }
        Expression::Between(mut range) => {
            range.this = self.simplify_once(range.this);
            range.low = self.simplify_once(range.low);
            range.high = self.simplify_once(range.high);
            Expression::Between(range)
        }
        Expression::Alias(mut aliased) => {
            aliased.this = self.simplify_once(aliased.this);
            Expression::Alias(aliased)
        }
        untouched => untouched,
    }
}
}
/// Returns `true` when `expr` is a numeric literal whose text parses as an `f64` equal to one.
fn is_one(expr: &Expression) -> bool {
    let Expression::Literal(lit) = expr else {
        return false;
    };
    match lit.as_ref() {
        Literal::Number(text) => text.parse::<f64>().ok() == Some(1.0),
        _ => false,
    }
}
/// Returns the value of a numeric literal as an `f64`, when `f64` carries it faithfully.
///
/// Returns `None` in these cases:
/// - `expr` is not a numeric literal, or its text does not parse as `f64`;
/// - the parsed value is non-finite (e.g. `1e400`, which would otherwise be folded and
///   rendered back as `inf`);
/// - the value is integral with magnitude ≥ 2^53, including large float literals such
///   as `1e20`. There distinct literals (e.g. `9007199254740992` and
///   `9007199254740993`) collapse to the same float and would compare or fold wrongly.
///
/// Callers leave such literals unfolded.
fn get_number(expr: &Expression) -> Option<f64> {
    // f64 represents every integer with magnitude below 2^53 exactly.
    const EXACT_INT_LIMIT: f64 = 9_007_199_254_740_992.0;
    let Expression::Literal(lit) = expr else {
        return None;
    };
    let Literal::Number(text) = lit.as_ref() else {
        return None;
    };
    let value: f64 = text.parse().ok()?;
    let faithful = value.is_finite() && (value.fract() != 0.0 || value.abs() < EXACT_INT_LIMIT);
    faithful.then_some(value)
}
/// Returns an owned copy of a string literal's text, or `None` for anything else.
fn get_string(expr: &Expression) -> Option<String> {
    match expr {
        Expression::Literal(lit) => match lit.as_ref() {
            Literal::String(text) => Some(text.to_owned()),
            _ => None,
        },
        _ => None,
    }
}
/// Compares two expressions by their `Debug` renderings.
fn expressions_equal(a: &Expression, b: &Expression) -> bool {
    let lhs = format!("{:?}", a);
    let rhs = format!("{:?}", b);
    lhs == rhs
}
/// Collects the operands of a nested `AND` tree, left to right, as owned clones.
///
/// A non-`AND` expression yields a single-element vector.
fn flatten_and(expr: &Expression) -> Vec<Expression> {
    let mut leaves = Vec::new();
    // Explicit DFS stack; the right child is pushed first so the left is visited first.
    let mut pending = vec![expr];
    while let Some(node) = pending.pop() {
        match node {
            Expression::And(op) => {
                pending.push(&op.right);
                pending.push(&op.left);
            }
            leaf => leaves.push(leaf.clone()),
        }
    }
    leaves
}
/// Collects the operands of a nested `OR` tree, left to right, as owned clones.
///
/// A non-`OR` expression yields a single-element vector.
fn flatten_or(expr: &Expression) -> Vec<Expression> {
    let mut leaves = Vec::new();
    // Explicit DFS stack; the right child is pushed first so the left is visited first.
    let mut pending = vec![expr];
    while let Some(node) = pending.pop() {
        match node {
            Expression::Or(op) => {
                pending.push(&op.right);
                pending.push(&op.left);
            }
            leaf => leaves.push(leaf.clone()),
        }
    }
    leaves
}
fn rebuild_and(operands: Vec<Expression>) -> Expression {
if operands.is_empty() {
return bool_true(); }
let mut result = operands.into_iter();
let first = result.next().unwrap();
result.fold(first, |acc, op| {
Expression::And(Box::new(BinaryOp::new(acc, op)))
})
}
fn rebuild_or(operands: Vec<Expression>) -> Expression {
if operands.is_empty() {
return bool_false(); }
let mut result = operands.into_iter();
let first = result.next().unwrap();
result.fold(first, |acc, op| {
Expression::Or(Box::new(BinaryOp::new(acc, op)))
})
}
/// Returns the operand of a `NOT`, or `None` if `expr` is not a `NOT`.
fn get_not_inner(expr: &Expression) -> Option<&Expression> {
    if let Expression::Not(op) = expr {
        Some(&op.this)
    } else {
        None
    }
}
pub fn absorb_and_eliminate_and(left: Expression, right: Expression) -> Expression {
let left_ops = flatten_and(&left);
let right_ops = flatten_and(&right);
let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
let mut result_ops: Vec<Expression> = Vec::new();
let mut absorbed = std::collections::HashSet::new();
for (i, op) in all_ops.iter().enumerate() {
let op_str = gen(op);
if absorbed.contains(&op_str) {
continue;
}
if let Expression::Or(_) = op {
let or_operands = flatten_or(op);
let absorbed_by_existing = or_operands.iter().any(|or_op| {
let or_op_str = gen(or_op);
all_ops
.iter()
.enumerate()
.any(|(j, other)| i != j && gen(other) == or_op_str)
});
if absorbed_by_existing {
absorbed.insert(op_str);
continue;
}
let mut remaining_or_ops: Vec<Expression> = Vec::new();
let mut had_complement_absorption = false;
for or_op in or_operands {
let complement_str = if let Some(inner) = get_not_inner(&or_op) {
gen(inner)
} else {
format!("NOT {}", gen(&or_op))
};
let has_complement = all_ops
.iter()
.enumerate()
.any(|(j, other)| i != j && gen(other) == complement_str)
|| op_strings.contains(&complement_str);
if has_complement {
had_complement_absorption = true;
} else {
remaining_or_ops.push(or_op);
}
}
if had_complement_absorption {
if remaining_or_ops.is_empty() {
absorbed.insert(op_str);
continue;
} else if remaining_or_ops.len() == 1 {
result_ops.push(remaining_or_ops.into_iter().next().unwrap());
absorbed.insert(op_str);
continue;
} else {
result_ops.push(rebuild_or(remaining_or_ops));
absorbed.insert(op_str);
continue;
}
}
}
result_ops.push(op.clone());
}
let mut seen = std::collections::HashSet::new();
result_ops.retain(|op| seen.insert(gen(op)));
if result_ops.is_empty() {
bool_true()
} else {
rebuild_and(result_ops)
}
}
pub fn absorb_and_eliminate_or(left: Expression, right: Expression) -> Expression {
let left_ops = flatten_or(&left);
let right_ops = flatten_or(&right);
let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
let mut result_ops: Vec<Expression> = Vec::new();
let mut absorbed = std::collections::HashSet::new();
for (i, op) in all_ops.iter().enumerate() {
let op_str = gen(op);
if absorbed.contains(&op_str) {
continue;
}
if let Expression::And(_) = op {
let and_operands = flatten_and(op);
let absorbed_by_existing = and_operands.iter().any(|and_op| {
let and_op_str = gen(and_op);
all_ops
.iter()
.enumerate()
.any(|(j, other)| i != j && gen(other) == and_op_str)
});
if absorbed_by_existing {
absorbed.insert(op_str);
continue;
}
let mut remaining_and_ops: Vec<Expression> = Vec::new();
let mut had_complement_absorption = false;
for and_op in and_operands {
let complement_str = if let Some(inner) = get_not_inner(&and_op) {
gen(inner)
} else {
format!("NOT {}", gen(&and_op))
};
let has_complement = all_ops
.iter()
.enumerate()
.any(|(j, other)| i != j && gen(other) == complement_str)
|| op_strings.contains(&complement_str);
if has_complement {
had_complement_absorption = true;
} else {
remaining_and_ops.push(and_op);
}
}
if had_complement_absorption {
if remaining_and_ops.is_empty() {
absorbed.insert(op_str);
continue;
} else if remaining_and_ops.len() == 1 {
result_ops.push(remaining_and_ops.into_iter().next().unwrap());
absorbed.insert(op_str);
continue;
} else {
result_ops.push(rebuild_and(remaining_and_ops));
absorbed.insert(op_str);
continue;
}
}
}
result_ops.push(op.clone());
}
let mut seen = std::collections::HashSet::new();
result_ops.retain(|op| seen.insert(gen(op)));
if result_ops.is_empty() {
bool_false()
} else {
rebuild_or(result_ops)
}
}
pub fn gen(expr: &Expression) -> String {
match expr {
Expression::Literal(lit) => match lit.as_ref() {
Literal::String(s) => format!("'{}'", s),
Literal::Number(n) => n.clone(),
_ => format!("{:?}", lit),
},
Expression::Boolean(b) => if b.value { "TRUE" } else { "FALSE" }.to_string(),
Expression::Null(_) => "NULL".to_string(),
Expression::Column(col) => {
if let Some(ref table) = col.table {
format!("{}.{}", table.name, col.name.name)
} else {
col.name.name.clone()
}
}
Expression::And(op) => format!("({} AND {})", gen(&op.left), gen(&op.right)),
Expression::Or(op) => format!("({} OR {})", gen(&op.left), gen(&op.right)),
Expression::Not(op) => format!("NOT {}", gen(&op.this)),
Expression::Eq(op) => format!("{} = {}", gen(&op.left), gen(&op.right)),
Expression::Neq(op) => format!("{} <> {}", gen(&op.left), gen(&op.right)),
Expression::Gt(op) => format!("{} > {}", gen(&op.left), gen(&op.right)),
Expression::Gte(op) => format!("{} >= {}", gen(&op.left), gen(&op.right)),
Expression::Lt(op) => format!("{} < {}", gen(&op.left), gen(&op.right)),
Expression::Lte(op) => format!("{} <= {}", gen(&op.left), gen(&op.right)),
Expression::Add(op) => format!("{} + {}", gen(&op.left), gen(&op.right)),
Expression::Sub(op) => format!("{} - {}", gen(&op.left), gen(&op.right)),
Expression::Mul(op) => format!("{} * {}", gen(&op.left), gen(&op.right)),
Expression::Div(op) => format!("{} / {}", gen(&op.left), gen(&op.right)),
Expression::Function(f) => {
let args: Vec<String> = f.args.iter().map(|a| gen(a)).collect();
format!("{}({})", f.name.to_uppercase(), args.join(", "))
}
_ => format!("{:?}", expr),
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds a numeric literal holding the decimal text of `val`.
fn make_int(val: i64) -> Expression {
    let digits = val.to_string();
    Expression::Literal(Box::new(Literal::Number(digits)))
}
/// Builds a string literal holding a copy of `val`.
fn make_string(val: &str) -> Expression {
    let text = String::from(val);
    Expression::Literal(Box::new(Literal::String(text)))
}
/// Builds a boolean literal.
fn make_bool(val: bool) -> Expression {
    let literal = BooleanLiteral { value: val };
    Expression::Boolean(literal)
}
/// Builds an unqualified column reference named `name`, with no comments, span or type.
fn make_column(name: &str) -> Expression {
    use crate::expressions::{Column, Identifier};
    let column = Column {
        name: Identifier::new(name),
        table: None,
        join_mark: false,
        trailing_comments: Vec::new(),
        span: None,
        inferred_type: None,
    };
    Expression::boxed_column(column)
}
#[test]
fn test_always_true_false() {
    // always_true: TRUE and non-zero numbers qualify; FALSE and zero do not.
    assert!(always_true(&make_bool(true)));
    assert!(always_true(&make_int(1)));
    assert!(!always_true(&make_bool(false)));
    assert!(!always_true(&make_int(0)));
    // always_false: FALSE, NULL and zero qualify; TRUE does not.
    for falsy in [make_bool(false), null(), make_int(0)] {
        assert!(always_false(&falsy));
    }
    assert!(!always_false(&make_bool(true)));
}
#[test]
fn test_simplify_and_with_true() {
    let mut simplifier = Simplifier::new(None);
    let mut and = |l: Expression, r: Expression| {
        simplifier.simplify(Expression::And(Box::new(BinaryOp::new(l, r))))
    };
    assert!(always_true(&and(make_bool(true), make_bool(true))));
    assert!(always_false(&and(make_bool(true), make_bool(false))));
    // TRUE AND x collapses to x.
    let x = make_int(42);
    let result = and(make_bool(true), x.clone());
    assert_eq!(format!("{:?}", result), format!("{:?}", x));
}
#[test]
fn test_simplify_or_with_false() {
    let mut simplifier = Simplifier::new(None);
    let mut or = |l: Expression, r: Expression| {
        simplifier.simplify(Expression::Or(Box::new(BinaryOp::new(l, r))))
    };
    assert!(always_false(&or(make_bool(false), make_bool(false))));
    assert!(always_true(&or(make_bool(false), make_bool(true))));
    // FALSE OR x collapses to x.
    let x = make_int(42);
    let result = or(make_bool(false), x.clone());
    assert_eq!(format!("{:?}", result), format!("{:?}", x));
}
#[test]
fn test_simplify_not() {
    let mut simplifier = Simplifier::new(None);
    let not = |e: Expression| Expression::Not(Box::new(UnaryOp::new(e)));
    assert!(is_false(&simplifier.simplify(not(make_bool(true)))));
    assert!(always_true(&simplifier.simplify(not(make_bool(false)))));
    // Double negation cancels.
    let x = make_int(42);
    let result = simplifier.simplify(not(not(x.clone())));
    assert_eq!(format!("{:?}", result), format!("{:?}", x));
}
#[test]
fn test_simplify_demorgan_comparison() {
    let mut simplifier = Simplifier::new(None);
    let (a, b) = (make_column("a"), make_column("b"));
    let negate = |e: Expression| Expression::Not(Box::new(UnaryOp::new(e)));

    let not_eq = negate(Expression::Eq(Box::new(BinaryOp::new(a.clone(), b.clone()))));
    assert!(matches!(simplifier.simplify(not_eq), Expression::Neq(_)));

    let not_gt = negate(Expression::Gt(Box::new(BinaryOp::new(a, b))));
    assert!(matches!(simplifier.simplify(not_gt), Expression::Lte(_)));
}
#[test]
fn test_constant_folding_add() {
    let mut simplifier = Simplifier::new(None);
    let add = |l: Expression, r: Expression| Expression::Add(Box::new(BinaryOp::new(l, r)));
    let sum = simplifier.simplify(add(make_int(1), make_int(2)));
    assert_eq!(get_number(&sum), Some(3.0));
    let x = make_int(42);
    let unchanged = simplifier.simplify(add(x.clone(), make_int(0)));
    assert_eq!(format!("{:?}", unchanged), format!("{:?}", x));
}
#[test]
fn test_constant_folding_mul() {
    let mut simplifier = Simplifier::new(None);
    let mul = |l: Expression, r: Expression| Expression::Mul(Box::new(BinaryOp::new(l, r)));
    let product = simplifier.simplify(mul(make_int(3), make_int(4)));
    assert_eq!(get_number(&product), Some(12.0));
    let by_zero = simplifier.simplify(mul(make_int(42), make_int(0)));
    assert_eq!(get_number(&by_zero), Some(0.0));
    let x = make_int(42);
    let by_one = simplifier.simplify(mul(x.clone(), make_int(1)));
    assert_eq!(format!("{:?}", by_one), format!("{:?}", x));
}
#[test]
fn test_constant_folding_comparison() {
    let mut simplifier = Simplifier::new(None);
    let eq = |l: Expression, r: Expression| Expression::Eq(Box::new(BinaryOp::new(l, r)));
    assert!(always_true(&simplifier.simplify(eq(make_int(1), make_int(1)))));
    assert!(is_false(&simplifier.simplify(eq(make_int(1), make_int(2)))));
    let gt = Expression::Gt(Box::new(BinaryOp::new(make_int(3), make_int(2))));
    assert!(always_true(&simplifier.simplify(gt)));
    let same_text = eq(make_string("abc"), make_string("abc"));
    assert!(always_true(&simplifier.simplify(same_text)));
}
#[test]
fn test_simplify_negation() {
    let mut simplifier = Simplifier::new(None);
    let neg = |e: Expression| Expression::Neg(Box::new(UnaryOp::new(e)));
    let twice = simplifier.simplify(neg(neg(make_int(5))));
    assert_eq!(get_number(&twice), Some(5.0));
    let once = simplifier.simplify(neg(make_int(3)));
    assert_eq!(get_number(&once), Some(-3.0));
}
#[test]
fn test_gen_simple() {
    let cases = [
        (make_int(42), "42"),
        (make_string("hello"), "'hello'"),
        (make_bool(true), "TRUE"),
        (make_bool(false), "FALSE"),
        (null(), "NULL"),
    ];
    for (expr, expected) in &cases {
        assert_eq!(gen(expr), *expected);
    }
}
#[test]
fn test_gen_operations() {
    let sum = Expression::Add(Box::new(BinaryOp::new(make_int(1), make_int(2))));
    let conj = Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(false))));
    assert_eq!(gen(&sum), "1 + 2");
    assert_eq!(gen(&conj), "(TRUE AND FALSE)");
}
#[test]
fn test_complement_elimination() {
    let x = make_int(42);
    let not_x = Expression::Not(Box::new(UnaryOp::new(x.clone())));
    let contradiction = Expression::And(Box::new(BinaryOp::new(x, not_x)));
    let result = Simplifier::new(None).simplify(contradiction);
    assert!(is_false(&result));
}
#[test]
fn test_idempotent() {
    let mut simplifier = Simplifier::new(None);
    let x = make_int(42);
    let expected = format!("{:?}", x);
    let both_and = Expression::And(Box::new(BinaryOp::new(x.clone(), x.clone())));
    assert_eq!(format!("{:?}", simplifier.simplify(both_and)), expected);
    let both_or = Expression::Or(Box::new(BinaryOp::new(x.clone(), x)));
    assert_eq!(format!("{:?}", simplifier.simplify(both_or)), expected);
}
#[test]
fn test_absorption_and() {
    // a AND (a OR b) == a
    let a = make_column("a");
    let a_or_b = Expression::Or(Box::new(BinaryOp::new(a.clone(), make_column("b"))));
    let expr = Expression::And(Box::new(BinaryOp::new(a.clone(), a_or_b)));
    let result = Simplifier::new(None).simplify(expr);
    assert_eq!(gen(&result), gen(&a));
}
#[test]
fn test_absorption_or() {
    // a OR (a AND b) == a
    let a = make_column("a");
    let a_and_b = Expression::And(Box::new(BinaryOp::new(a.clone(), make_column("b"))));
    let expr = Expression::Or(Box::new(BinaryOp::new(a.clone(), a_and_b)));
    let result = Simplifier::new(None).simplify(expr);
    assert_eq!(gen(&result), gen(&a));
}
#[test]
fn test_absorption_with_complement_and() {
    // a AND (NOT a OR b) == a AND b
    let (a, b) = (make_column("a"), make_column("b"));
    let not_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
    let disjunction = Expression::Or(Box::new(BinaryOp::new(not_a, b.clone())));
    let expr = Expression::And(Box::new(BinaryOp::new(a.clone(), disjunction)));
    let result = Simplifier::new(None).simplify(expr);
    let expected = Expression::And(Box::new(BinaryOp::new(a, b)));
    assert_eq!(gen(&result), gen(&expected));
}
#[test]
fn test_absorption_with_complement_or() {
    // a OR (NOT a AND b) == a OR b
    let (a, b) = (make_column("a"), make_column("b"));
    let not_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
    let conjunction = Expression::And(Box::new(BinaryOp::new(not_a, b.clone())));
    let expr = Expression::Or(Box::new(BinaryOp::new(a.clone(), conjunction)));
    let result = Simplifier::new(None).simplify(expr);
    let expected = Expression::Or(Box::new(BinaryOp::new(a, b)));
    assert_eq!(gen(&result), gen(&expected));
}
#[test]
fn test_flatten_and() {
    // a AND (b AND c) flattens to [a, b, c].
    let b_and_c = Expression::And(Box::new(BinaryOp::new(make_column("b"), make_column("c"))));
    let expr = Expression::And(Box::new(BinaryOp::new(make_column("a"), b_and_c)));
    let rendered: Vec<String> = flatten_and(&expr).iter().map(gen).collect();
    assert_eq!(rendered.len(), 3);
    assert_eq!(rendered, ["a", "b", "c"]);
}
#[test]
fn test_flatten_or() {
    // a OR (b OR c) flattens to [a, b, c].
    let b_or_c = Expression::Or(Box::new(BinaryOp::new(make_column("b"), make_column("c"))));
    let expr = Expression::Or(Box::new(BinaryOp::new(make_column("a"), b_or_c)));
    let rendered: Vec<String> = flatten_or(&expr).iter().map(gen).collect();
    assert_eq!(rendered.len(), 3);
    assert_eq!(rendered, ["a", "b", "c"]);
}
/// Folding of the binary `Concat` node: two string literals merge into one, an empty-string
/// operand on either side is dropped, and a NULL left operand yields NULL.
#[test]
fn test_simplify_concat() {
    let mut simplifier = Simplifier::new(None);
    let concat = |left: Expression, right: Expression| {
        Expression::Concat(Box::new(BinaryOp::new(left, right)))
    };

    // Two literals fold into a single literal.
    let folded = simplifier.simplify(concat(make_string("hello"), make_string("world")));
    assert_eq!(get_string(&folded), Some(String::from("helloworld")));

    // The empty string acts as an identity on either side.
    let operand = make_string("test");
    let folded = simplifier.simplify(concat(make_string(""), operand.clone()));
    assert_eq!(get_string(&folded), Some(String::from("test")));
    let folded = simplifier.simplify(concat(operand, make_string("")));
    assert_eq!(get_string(&folded), Some(String::from("test")));

    // NULL on the left propagates to the result.
    let folded = simplifier.simplify(concat(null(), make_string("test")));
    assert!(is_null(&folded));
}
/// Folding of `CONCAT_WS`:
/// - literal arguments are joined with the separator;
/// - a NULL separator yields NULL;
/// - an empty argument list yields the empty string;
/// - NULL arguments are skipped.
#[test]
fn test_simplify_concat_ws() {
    let mut simplifier = Simplifier::new(None);
    let concat_ws = |separator: Expression, expressions: Vec<Expression>| {
        Expression::ConcatWs(Box::new(ConcatWs {
            separator,
            expressions,
        }))
    };

    let joined = simplifier.simplify(concat_ws(
        make_string(","),
        vec![make_string("a"), make_string("b"), make_string("c")],
    ));
    assert_eq!(get_string(&joined), Some(String::from("a,b,c")));

    let joined = simplifier.simplify(concat_ws(null(), vec![make_string("a"), make_string("b")]));
    assert!(is_null(&joined));

    let joined = simplifier.simplify(concat_ws(make_string(","), vec![]));
    assert_eq!(get_string(&joined), Some(String::from("")));

    let joined = simplifier.simplify(concat_ws(
        make_string("-"),
        vec![make_string("a"), null(), make_string("b")],
    ));
    assert_eq!(get_string(&joined), Some(String::from("a-b")));
}
/// Parentheses around a number, boolean, or NULL literal are stripped, including when the
/// parentheses are nested.
#[test]
fn test_simplify_paren() {
    let mut simplifier = Simplifier::new(None);
    let wrap = |inner: Expression| {
        Expression::Paren(Box::new(Paren {
            this: inner,
            trailing_comments: vec![],
        }))
    };

    let unwrapped = simplifier.simplify(wrap(make_int(42)));
    assert_eq!(get_number(&unwrapped), Some(42.0));

    let unwrapped = simplifier.simplify(wrap(make_bool(true)));
    assert!(is_boolean_true(&unwrapped));

    let unwrapped = simplifier.simplify(wrap(null()));
    assert!(is_null(&unwrapped));

    // Double parentheses collapse all the way down to the literal.
    let unwrapped = simplifier.simplify(wrap(wrap(make_int(10))));
    assert_eq!(get_number(&unwrapped), Some(10.0));
}
/// Equalities of the form `<x op const> = const` are solved for the column `x`.
/// The cases `x + 1 = 3`, `x - 1 = 3`, `x * 2 = 6` and `1 + x = 3` must become
/// `x = 2`, `x = 4`, `x = 3` and `x = 2` respectively.
#[test]
fn test_simplify_equality_solve() {
    let mut simplifier = Simplifier::new(None);
    let x = make_column("x");

    // Simplify the equality `lhs = rhs`.
    let mut solve = |lhs: Expression, rhs: Expression| {
        simplifier.simplify(Expression::Eq(Box::new(BinaryOp::new(lhs, rhs))))
    };
    // The simplified result must be exactly `x = value`.
    let expect_solved = |solved: Expression, value: f64| match &solved {
        Expression::Eq(op) => {
            assert_eq!(gen(&op.left), "x");
            assert_eq!(get_number(&op.right), Some(value));
        }
        _ => panic!("Expected Eq expression"),
    };

    let sum = Expression::Add(Box::new(BinaryOp::new(x.clone(), make_int(1))));
    expect_solved(solve(sum, make_int(3)), 2.0);

    let difference = Expression::Sub(Box::new(BinaryOp::new(x.clone(), make_int(1))));
    expect_solved(solve(difference, make_int(3)), 4.0);

    let product = Expression::Mul(Box::new(BinaryOp::new(x.clone(), make_int(2))));
    expect_solved(solve(product, make_int(6)), 3.0);

    // The constant may also appear on the left of the addition.
    let reversed_sum = Expression::Add(Box::new(BinaryOp::new(make_int(1), x)));
    expect_solved(solve(reversed_sum, make_int(3)), 2.0);
}
/// Checks that `DATE_TRUNC(DAY, x)` over a plain column survives simplification.
/// The result must still be a `DateTrunc` of column `x` with unit `Day`.
#[test]
fn test_simplify_datetrunc() {
    use crate::expressions::DateTimeField;

    let mut simplifier = Simplifier::new(None);
    let column = make_column("x");
    let truncated = simplifier.simplify(Expression::DateTrunc(Box::new(DateTruncFunc {
        this: column,
        unit: DateTimeField::Day,
    })));

    let Expression::DateTrunc(dt) = &truncated else {
        panic!("Expected DateTrunc expression");
    };
    assert_eq!(gen(&dt.this), "x");
    assert_eq!(dt.unit, DateTimeField::Day);
}
}