use arrow::datatypes::Schema;
use datafusion_common::{
tree_node::{Transformed, TreeNode, TreeNodeRewriter},
Result,
};
use std::sync::Arc;
use crate::PhysicalExpr;
pub mod unwrap_cast;
/// Simplifies physical expressions by rewriting them against a [`Schema`].
///
/// The only rewrite applied in this module is
/// `unwrap_cast::unwrap_cast_in_comparison`, which is invoked on each node
/// visited by the [`TreeNodeRewriter`] implementation of this type.
pub struct PhysicalExprSimplifier<'a> {
/// Borrowed schema handed to the rewrite rules; it is also used in test
/// builds to check that a rewrite leaves the expression's data type unchanged.
schema: &'a Schema,
}
impl<'a> PhysicalExprSimplifier<'a> {
pub fn new(schema: &'a Schema) -> Self {
Self { schema }
}
pub fn simplify(
&mut self,
expr: Arc<dyn PhysicalExpr>,
) -> Result<Arc<dyn PhysicalExpr>> {
Ok(expr.rewrite(self)?.data)
}
}
impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> {
    type Node = Arc<dyn PhysicalExpr>;

    /// Applies `unwrap_cast::unwrap_cast_in_comparison` to `node`.
    ///
    /// In test builds the data type of the node is captured before the
    /// rewrite and compared with the data type afterwards, so that a rewrite
    /// which changes the expression's type fails loudly.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `unwrap_cast_in_comparison`.
    ///
    /// # Panics
    ///
    /// Test builds only: panics if the data type cannot be computed either
    /// before or after the rewrite, or if the two data types differ.
    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        // Captured up front because `node` is moved into the rewrite below.
        #[cfg(test)]
        let type_before = node.data_type(self.schema).unwrap();

        let rewritten = unwrap_cast::unwrap_cast_in_comparison(node, self.schema)?;

        #[cfg(test)]
        {
            let type_after = rewritten.data.data_type(self.schema).unwrap();
            assert_eq!(
                type_after, type_before,
                "Simplified expression should have the same data type as the original"
            );
        }

        Ok(rewritten)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, lit, BinaryExpr, CastExpr, Literal, TryCastExpr};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::ScalarValue;
    use datafusion_expr::Operator;

    /// Schema with three non-nullable columns: `c1: Int32`, `c2: Int64`, `c3: Utf8`.
    fn test_schema() -> Schema {
        let fields: Vec<Field> = [
            ("c1", DataType::Int32),
            ("c2", DataType::Int64),
            ("c3", DataType::Utf8),
        ]
        .into_iter()
        .map(|(name, data_type)| Field::new(name, data_type, false))
        .collect();
        Schema::new(fields)
    }

    /// Builds `CAST(<column> AS <cast_to>) <op> <literal>`.
    fn cast_comparison(
        schema: &Schema,
        column: &str,
        cast_to: DataType,
        op: Operator,
        literal: ScalarValue,
    ) -> Arc<dyn PhysicalExpr> {
        let casted = Arc::new(CastExpr::new(col(column, schema).unwrap(), cast_to, None));
        Arc::new(BinaryExpr::new(casted, op, lit(literal)))
    }

    /// Downcasts `expr` to a `BinaryExpr`, panicking if it is anything else.
    fn as_binary(expr: &Arc<dyn PhysicalExpr>) -> &BinaryExpr {
        expr.as_any().downcast_ref::<BinaryExpr>().unwrap()
    }

    /// Returns the scalar held by `expr`, panicking if it is not a `Literal`.
    fn literal_value(expr: &Arc<dyn PhysicalExpr>) -> &ScalarValue {
        expr.as_any().downcast_ref::<Literal>().unwrap().value()
    }

    /// Asserts that `expr` is neither a `CastExpr` nor a `TryCastExpr`.
    fn assert_not_cast(expr: &Arc<dyn PhysicalExpr>) {
        let any = expr.as_any();
        assert!(!(any.is::<CastExpr>() || any.is::<TryCastExpr>()));
    }

    #[test]
    fn test_simplify() {
        let schema = test_schema();
        let mut simplifier = PhysicalExprSimplifier::new(&schema);

        // CAST(c2 AS Int32) != 99_i32
        let input = cast_comparison(
            &schema,
            "c2",
            DataType::Int32,
            Operator::NotEq,
            ScalarValue::Int32(Some(99)),
        );

        let optimized = simplifier.simplify(input).unwrap();
        let comparison = as_binary(&optimized);

        // The cast is gone from the column side and the literal now carries
        // the column's type.
        assert_not_cast(comparison.left());
        assert_eq!(
            literal_value(comparison.right()),
            &ScalarValue::Int64(Some(99))
        );
    }

    #[test]
    fn test_nested_expression_simplification() {
        let schema = test_schema();
        let mut simplifier = PhysicalExprSimplifier::new(&schema);

        // CAST(c1 AS Int64) > 5_i64
        let lhs = cast_comparison(
            &schema,
            "c1",
            DataType::Int64,
            Operator::Gt,
            ScalarValue::Int64(Some(5)),
        );
        // CAST(c2 AS Int32) <= 10_i32
        let rhs = cast_comparison(
            &schema,
            "c2",
            DataType::Int32,
            Operator::LtEq,
            ScalarValue::Int32(Some(10)),
        );
        let disjunction = Arc::new(BinaryExpr::new(lhs, Operator::Or, rhs));

        let optimized = simplifier.simplify(disjunction).unwrap();
        let or_node = as_binary(&optimized);

        // Left branch: cast removed, literal retyped to Int32.
        let left_cmp = as_binary(or_node.left());
        assert_not_cast(left_cmp.left());
        assert_eq!(
            literal_value(left_cmp.right()),
            &ScalarValue::Int32(Some(5))
        );

        // Right branch: cast removed, literal retyped to Int64.
        let right_cmp = as_binary(or_node.right());
        assert_not_cast(right_cmp.left());
        assert_eq!(
            literal_value(right_cmp.right()),
            &ScalarValue::Int64(Some(10))
        );
    }
}