use arrow::datatypes::Schema;
use datafusion_common::{Result, tree_node::TreeNode};
use std::sync::Arc;
use crate::{
PhysicalExpr,
simplifier::{
const_evaluator::create_dummy_batch, unwrap_cast::unwrap_cast_in_comparison,
},
};
pub mod const_evaluator;
pub mod not;
pub mod unwrap_cast;
/// Maximum number of whole-tree rewrite passes `PhysicalExprSimplifier::simplify`
/// runs; once reached, the latest rewrite is returned even if it still changed.
const MAX_LOOP_COUNT: usize = 5;
/// Simplifies physical expressions by repeatedly applying a fixed pipeline of
/// rewrites (NOT simplification, cast unwrapping in comparisons, and constant
/// evaluation) to every node of the expression tree.
pub struct PhysicalExprSimplifier<'a> {
/// Schema handed to the rewrite rules and used to look up expression data types.
schema: &'a Schema,
}
impl<'a> PhysicalExprSimplifier<'a> {
    /// Builds a simplifier whose rewrites resolve types against `schema`.
    pub fn new(schema: &'a Schema) -> Self {
        Self { schema }
    }

    /// Simplifies `expr`, re-running the rewrite pipeline over the whole tree
    /// until a pass leaves it unchanged or `MAX_LOOP_COUNT` passes have run.
    ///
    /// Each node goes through three rewrites, in order: NOT simplification,
    /// cast unwrapping in comparisons, and immediate constant evaluation.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the dummy batch or from a rewrite.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions, panics if a node's data type cannot
    /// be computed or if a rewrite changes a node's data type.
    pub fn simplify(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
        let schema = self.schema;
        // Created once and shared by every constant-evaluation call below.
        let batch = create_dummy_batch()?;

        let mut current = expr;
        for _ in 0..MAX_LOOP_COUNT {
            let pass = current.transform(|node| {
                #[cfg(debug_assertions)]
                let type_before = node.data_type(schema).unwrap();

                #[expect(deprecated, reason = "`simplify_not_expr` is marked as deprecated until it's made private.")]
                let not_simplified = not::simplify_not_expr(node, schema)?;
                let cast_unwrapped = not_simplified
                    .transform_data(|inner| unwrap_cast_in_comparison(inner, schema))?;
                let rewritten = cast_unwrapped.transform_data(|inner| {
                    const_evaluator::simplify_const_expr_immediate(inner, &batch)
                })?;

                #[cfg(debug_assertions)]
                assert_eq!(
                    rewritten.data.data_type(schema).unwrap(),
                    type_before,
                    "Simplified expression should have the same data type as the original"
                );

                Ok(rewritten)
            })?;

            if pass.transformed {
                current = pass.data;
            } else {
                // Fixed point: this pass changed nothing.
                return Ok(pass.data);
            }
        }

        // Pass budget exhausted; hand back the latest rewrite.
        Ok(current)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::expressions::{
BinaryExpr, CastExpr, Literal, NotExpr, TryCastExpr, col, in_list, lit,
};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
/// Schema with non-nullable columns `c1: Int32`, `c2: Int64` and `c3: Utf8`.
fn test_schema() -> Schema {
    let columns = [
        ("c1", DataType::Int32),
        ("c2", DataType::Int64),
        ("c3", DataType::Utf8),
    ];
    Schema::new(
        columns
            .into_iter()
            .map(|(name, data_type)| Field::new(name, data_type, false))
            .collect::<Vec<_>>(),
    )
}
/// Schema with non-nullable columns `a: Boolean`, `b: Boolean` and `c: Int32`.
fn not_test_schema() -> Schema {
    let columns = [
        ("a", DataType::Boolean),
        ("b", DataType::Boolean),
        ("c", DataType::Int32),
    ];
    Schema::new(
        columns
            .into_iter()
            .map(|(name, data_type)| Field::new(name, data_type, false))
            .collect::<Vec<_>>(),
    )
}
/// Downcasts `expr` to a [`Literal`], panicking with the rendered expression
/// when it is some other kind of node.
fn as_literal(expr: &Arc<dyn PhysicalExpr>) -> &Literal {
    let Some(literal) = expr.as_any().downcast_ref::<Literal>() else {
        panic!("Expected Literal, got: {expr}");
    };
    literal
}
/// Downcasts `expr` to a [`BinaryExpr`], panicking with the rendered
/// expression when it is some other kind of node.
fn as_binary(expr: &Arc<dyn PhysicalExpr>) -> &BinaryExpr {
    let Some(binary) = expr.as_any().downcast_ref::<BinaryExpr>() else {
        panic!("Expected BinaryExpr, got: {expr}");
    };
    binary
}
/// Runs `simplifier` on `input` and asserts the outcome equals `expected`,
/// reporting all three expressions on failure.
fn assert_not_simplify(
    simplifier: &PhysicalExprSimplifier,
    input: Arc<dyn PhysicalExpr>,
    expected: Arc<dyn PhysicalExpr>,
) {
    // Simplify a clone so `input` stays available for the failure message.
    let simplified = simplifier.simplify(Arc::clone(&input)).unwrap();
    assert_eq!(
        &simplified, &expected,
        "Simplification should transform:\n input: {input}\n to: {expected}\n got: {result}",
        result = simplified,
    );
}
/// `cast(c2 as Int32) != Int32(99)` loses its cast and the literal is retyped
/// to `Int64`, the type of `c2`.
#[test]
fn test_simplify() {
    let schema = test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    let narrowed_c2 = Arc::new(CastExpr::new(
        col("c2", &schema).unwrap(),
        DataType::Int32,
        None,
    ));
    let input = Arc::new(BinaryExpr::new(
        narrowed_c2,
        Operator::NotEq,
        lit(ScalarValue::Int32(Some(99))),
    ));

    let optimized = simplifier.simplify(input).unwrap();
    let comparison = as_binary(&optimized);

    // Neither flavor of cast may remain on the column side.
    let lhs = comparison.left().as_any();
    assert!(!lhs.is::<CastExpr>());
    assert!(!lhs.is::<TryCastExpr>());

    // The literal now carries the column's own type.
    let rhs = as_literal(comparison.right());
    assert_eq!(rhs.value(), &ScalarValue::Int64(Some(99)));
}
/// Casts are unwrapped on both sides of
/// `cast(c1 as Int64) > Int64(5) OR cast(c2 as Int32) <= Int32(10)`.
#[test]
fn test_nested_expression_simplification() {
    let schema = test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    // Left branch: cast(c1 as Int64) > Int64(5)
    let widened_c1 = Arc::new(CastExpr::new(
        col("c1", &schema).unwrap(),
        DataType::Int64,
        None,
    ));
    let left_branch = Arc::new(BinaryExpr::new(
        widened_c1,
        Operator::Gt,
        lit(ScalarValue::Int64(Some(5))),
    ));

    // Right branch: cast(c2 as Int32) <= Int32(10)
    let narrowed_c2 = Arc::new(CastExpr::new(
        col("c2", &schema).unwrap(),
        DataType::Int32,
        None,
    ));
    let right_branch = Arc::new(BinaryExpr::new(
        narrowed_c2,
        Operator::LtEq,
        lit(ScalarValue::Int32(Some(10))),
    ));

    let disjunction = Arc::new(BinaryExpr::new(left_branch, Operator::Or, right_branch));
    let optimized = simplifier.simplify(disjunction).unwrap();
    let top = as_binary(&optimized);

    // True when the node is still a `CastExpr` or `TryCastExpr`.
    let is_cast = |node: &Arc<dyn PhysicalExpr>| {
        let any = node.as_any();
        any.is::<CastExpr>() || any.is::<TryCastExpr>()
    };

    let lhs = as_binary(top.left());
    assert!(!is_cast(lhs.left()));
    assert_eq!(
        as_literal(lhs.right()).value(),
        &ScalarValue::Int32(Some(5))
    );

    let rhs = as_binary(top.right());
    assert!(!is_cast(rhs.left()));
    assert_eq!(
        as_literal(rhs.right()).value(),
        &ScalarValue::Int64(Some(10))
    );
}
/// `NOT(NOT(c > 5))` simplifies to `c > 5`.
#[test]
fn test_double_negation_elimination() -> Result<()> {
    let schema = not_test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    let comparison: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        col("c", &schema)?,
        Operator::Gt,
        lit(ScalarValue::Int32(Some(5))),
    ));
    let negated_once = Arc::new(NotExpr::new(Arc::clone(&comparison)));
    let negated_twice: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(negated_once));

    assert_not_simplify(&simplifier, negated_twice, comparison);
    Ok(())
}
/// `NOT(true)` becomes `false` and `NOT(false)` becomes `true`.
#[test]
fn test_not_literal() -> Result<()> {
    let schema = not_test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    for (operand, flipped) in [(true, false), (false, true)] {
        let negated = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(operand)))));
        let expected = lit(ScalarValue::Boolean(Some(flipped)));
        assert_not_simplify(&simplifier, negated, expected);
    }
    Ok(())
}
/// `NOT(c = 5)` simplifies to `c != 5`.
#[test]
fn test_negate_comparison() -> Result<()> {
    let schema = not_test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    // Builds `c <op> 5` with a fresh column reference each time.
    let c_versus_five = |op: Operator| -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(BinaryExpr::new(
            col("c", &schema)?,
            op,
            lit(ScalarValue::Int32(Some(5))),
        )))
    };

    let negated_eq = Arc::new(NotExpr::new(c_versus_five(Operator::Eq)?));
    let expected = c_versus_five(Operator::NotEq)?;

    assert_not_simplify(&simplifier, negated_eq, expected);
    Ok(())
}
/// `NOT(a AND b)` simplifies to `NOT a OR NOT b`.
#[test]
fn test_demorgans_law_and() -> Result<()> {
    let schema = not_test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    // Builds `NOT <column>` with a fresh column reference each time.
    let negated_col = |name: &str| -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(NotExpr::new(col(name, &schema)?)))
    };

    let conjunction = Arc::new(BinaryExpr::new(
        col("a", &schema)?,
        Operator::And,
        col("b", &schema)?,
    ));
    let input: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(conjunction));

    let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        negated_col("a")?,
        Operator::Or,
        negated_col("b")?,
    ));

    assert_not_simplify(&simplifier, input, expected);
    Ok(())
}
/// `NOT(a OR b)` simplifies to `NOT a AND NOT b`.
#[test]
fn test_demorgans_law_or() -> Result<()> {
    let schema = not_test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    // Builds `NOT <column>` with a fresh column reference each time.
    let negated_col = |name: &str| -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(NotExpr::new(col(name, &schema)?)))
    };

    let disjunction = Arc::new(BinaryExpr::new(
        col("a", &schema)?,
        Operator::Or,
        col("b", &schema)?,
    ));
    let input: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(disjunction));

    let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        negated_col("a")?,
        Operator::And,
        negated_col("b")?,
    ));

    assert_not_simplify(&simplifier, input, expected);
    Ok(())
}
/// `NOT(c = 1 AND c = 2)` simplifies to `c != 1 OR c != 2`.
#[test]
fn test_demorgans_with_comparison_simplification() -> Result<()> {
    let schema = not_test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    // Builds `c <op> <value>` with a fresh column reference each time.
    let c_compare = |op: Operator, value: i32| -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(BinaryExpr::new(
            col("c", &schema)?,
            op,
            lit(ScalarValue::Int32(Some(value))),
        )))
    };

    let conjunction = Arc::new(BinaryExpr::new(
        c_compare(Operator::Eq, 1)?,
        Operator::And,
        c_compare(Operator::Eq, 2)?,
    ));
    let input: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(conjunction));

    let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        c_compare(Operator::NotEq, 1)?,
        Operator::Or,
        c_compare(Operator::NotEq, 2)?,
    ));

    assert_not_simplify(&simplifier, input, expected);
    Ok(())
}
/// `NOT(NOT a AND NOT b)` simplifies to `a OR b`.
#[test]
fn test_not_of_not_and_not() -> Result<()> {
    let schema = not_test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    let conjunction = Arc::new(BinaryExpr::new(
        Arc::new(NotExpr::new(col("a", &schema)?)),
        Operator::And,
        Arc::new(NotExpr::new(col("b", &schema)?)),
    ));
    let input: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(conjunction));

    let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        col("a", &schema)?,
        Operator::Or,
        col("b", &schema)?,
    ));

    assert_not_simplify(&simplifier, input, expected);
    Ok(())
}
/// `NOT(c IN (1, 2, 3))` simplifies to `c NOT IN (1, 2, 3)`.
#[test]
fn test_not_in_list() -> Result<()> {
    let schema = not_test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    let values: Vec<_> = (1..=3)
        .map(|v| lit(ScalarValue::Int32(Some(v))))
        .collect();

    let membership = in_list(col("c", &schema)?, values.clone(), &false, &schema)?;
    let input: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(membership));
    let expected = in_list(col("c", &schema)?, values, &true, &schema)?;

    assert_not_simplify(&simplifier, input, expected);
    Ok(())
}
/// `NOT(c NOT IN (1, 2, 3))` simplifies to `c IN (1, 2, 3)`.
#[test]
fn test_not_not_in_list() -> Result<()> {
    let schema = not_test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    let values: Vec<_> = (1..=3)
        .map(|v| lit(ScalarValue::Int32(Some(v))))
        .collect();

    let exclusion = in_list(col("c", &schema)?, values.clone(), &true, &schema)?;
    let input: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(exclusion));
    let expected = in_list(col("c", &schema)?, values, &false, &schema)?;

    assert_not_simplify(&simplifier, input, expected);
    Ok(())
}
/// `NOT(NOT(c IN (1, 2, 3)))` simplifies back to `c IN (1, 2, 3)`.
#[test]
fn test_double_not_in_list() -> Result<()> {
    let schema = not_test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    let values: Vec<_> = (1..=3)
        .map(|v| lit(ScalarValue::Int32(Some(v))))
        .collect();

    let membership = in_list(col("c", &schema)?, values.clone(), &false, &schema)?;
    let negated_once = Arc::new(NotExpr::new(membership));
    let input: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(negated_once));
    let expected = in_list(col("c", &schema)?, values, &false, &schema)?;

    assert_not_simplify(&simplifier, input, expected);
    Ok(())
}
// Wraps `c > 5` in 200 nested NOT layers and checks that simplification
// collapses the whole chain back to the inner comparison.
#[test]
fn test_deeply_nested_not() -> Result<()> {
let schema = not_test_schema();
let simplifier = PhysicalExprSimplifier::new(&schema);
// Innermost expression: c > 5
let inner_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
col("c", &schema)?,
Operator::Gt,
lit(ScalarValue::Int32(Some(5))),
));
let mut expr = Arc::clone(&inner_expr);
// Add 200 layers of NOT around the comparison.
for _ in 0..200 {
expr = Arc::new(NotExpr::new(expr));
}
// 200 is even, so the negations cancel and the original comparison is expected.
let expected = inner_expr;
assert_not_simplify(&simplifier, Arc::clone(&expr), expected);
// Peel the NOT layers off one at a time: each iteration takes a handle to the
// child before `expr` is rebound, so only one parent layer is released per step.
// NOTE(review): presumably this avoids a deeply recursive drop of the
// 200-level chain when `expr` goes out of scope — confirm.
while let Some(not_expr) = expr.as_any().downcast_ref::<NotExpr>() {
let child = Arc::clone(not_expr.arg());
expr = child;
}
Ok(())
}
/// `1 + 2` over literals folds to the literal `3`.
#[test]
fn test_simplify_literal_binary_expr() {
    let schema = Schema::empty();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    let sum: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32)));

    let folded = simplifier.simplify(sum).unwrap();
    assert_eq!(as_literal(&folded).value(), &ScalarValue::Int32(Some(3)));
}
/// Literal comparisons fold to booleans: `5 > 3` is `true`, `2 > 3` is `false`.
#[test]
fn test_simplify_literal_comparison() {
    let schema = Schema::empty();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    for (lhs, outcome) in [(5i32, true), (2i32, false)] {
        let comparison: Arc<dyn PhysicalExpr> =
            Arc::new(BinaryExpr::new(lit(lhs), Operator::Gt, lit(3i32)));
        let folded = simplifier.simplify(comparison).unwrap();
        assert_eq!(
            as_literal(&folded).value(),
            &ScalarValue::Boolean(Some(outcome))
        );
    }
}
/// `(1 + 2) * 3` over literals folds to the literal `9`.
#[test]
fn test_simplify_nested_literal_expr() {
    let schema = Schema::empty();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    let sum: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32)));
    let product: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(sum, Operator::Multiply, lit(3i32)));

    let folded = simplifier.simplify(product).unwrap();
    assert_eq!(as_literal(&folded).value(), &ScalarValue::Int32(Some(9)));
}
/// `((1 + 2) * 3) + ((4 - 1) * 2)` over literals folds to the literal `15`.
#[test]
fn test_simplify_deeply_nested_literals() {
    let schema = Schema::empty();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    // Small builder so the nesting below reads like the arithmetic it encodes.
    let binary = |lhs: Arc<dyn PhysicalExpr>,
                  op: Operator,
                  rhs: Arc<dyn PhysicalExpr>|
     -> Arc<dyn PhysicalExpr> { Arc::new(BinaryExpr::new(lhs, op, rhs)) };

    let left = binary(
        binary(lit(1i32), Operator::Plus, lit(2i32)),
        Operator::Multiply,
        lit(3i32),
    );
    let right = binary(
        binary(lit(4i32), Operator::Minus, lit(1i32)),
        Operator::Multiply,
        lit(2i32),
    );
    let total = binary(left, Operator::Plus, right);

    let folded = simplifier.simplify(total).unwrap();
    assert_eq!(as_literal(&folded).value(), &ScalarValue::Int32(Some(15)));
}
/// `c1 + 2` references a column, so the result must still be a `BinaryExpr`.
#[test]
fn test_no_simplify_with_column() {
    let schema = test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    let column_plus_two: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        col("c1", &schema).unwrap(),
        Operator::Plus,
        lit(2i32),
    ));

    let simplified = simplifier.simplify(column_plus_two).unwrap();
    assert!(simplified.as_any().is::<BinaryExpr>());
}
/// In `(1 + 2) + c1` only the literal side folds, leaving `3` on the left.
#[test]
fn test_partial_simplify_with_column() {
    let schema = test_schema();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    let constant_part: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32)));
    let input: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        constant_part,
        Operator::Plus,
        col("c1", &schema).unwrap(),
    ));

    let simplified = simplifier.simplify(input).unwrap();
    let folded_left = as_literal(as_binary(&simplified).left());
    assert_eq!(folded_left.value(), &ScalarValue::Int32(Some(3)));
}
/// Concatenating the literals `"hello"` and `" world"` folds to one `Utf8` literal.
#[test]
fn test_simplify_literal_string_concat() {
    let schema = Schema::empty();
    let simplifier = PhysicalExprSimplifier::new(&schema);

    let concat: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        lit("hello"),
        Operator::StringConcat,
        lit(" world"),
    ));

    let folded = simplifier.simplify(concat).unwrap();
    let expected = ScalarValue::Utf8(Some(String::from("hello world")));
    assert_eq!(as_literal(&folded).value(), &expected);
}
}