datafusion_physical_expr/simplifier/
mod.rs1use arrow::datatypes::Schema;
21use datafusion_common::{
22 tree_node::{Transformed, TreeNode, TreeNodeRewriter},
23 Result,
24};
25use std::sync::Arc;
26
27use crate::PhysicalExpr;
28
29pub mod unwrap_cast;
30
31pub struct PhysicalExprSimplifier<'a> {
37 schema: &'a Schema,
38}
39
40impl<'a> PhysicalExprSimplifier<'a> {
41 pub fn new(schema: &'a Schema) -> Self {
43 Self { schema }
44 }
45
46 pub fn simplify(
48 &mut self,
49 expr: Arc<dyn PhysicalExpr>,
50 ) -> Result<Arc<dyn PhysicalExpr>> {
51 Ok(expr.rewrite(self)?.data)
52 }
53}
54
55impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> {
56 type Node = Arc<dyn PhysicalExpr>;
57
58 fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
59 #[cfg(test)]
61 let original_type = node.data_type(self.schema).unwrap();
62 let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, self.schema)?;
63 #[cfg(test)]
64 assert_eq!(
65 unwrapped.data.data_type(self.schema).unwrap(),
66 original_type,
67 "Simplified expression should have the same data type as the original"
68 );
69 Ok(unwrapped)
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, lit, BinaryExpr, CastExpr, Literal, TryCastExpr};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::ScalarValue;
    use datafusion_expr::Operator;

    /// Three non-nullable columns: `c1: Int32`, `c2: Int64`, `c3: Utf8`.
    fn test_schema() -> Schema {
        let fields = vec![
            Field::new("c1", DataType::Int32, false),
            Field::new("c2", DataType::Int64, false),
            Field::new("c3", DataType::Utf8, false),
        ];
        Schema::new(fields)
    }

    /// Builds `CAST(<column> AS <cast_type>) <op> <value>` against `schema`.
    fn cast_comparison(
        schema: &Schema,
        column: &str,
        cast_type: DataType,
        op: Operator,
        value: ScalarValue,
    ) -> Arc<dyn PhysicalExpr> {
        let casted: Arc<dyn PhysicalExpr> =
            Arc::new(CastExpr::new(col(column, schema).unwrap(), cast_type, None));
        Arc::new(BinaryExpr::new(casted, op, lit(value)))
    }

    /// Downcasts `expr` to a [`BinaryExpr`], panicking if it is anything else.
    fn as_binary(expr: &Arc<dyn PhysicalExpr>) -> &BinaryExpr {
        expr.as_any().downcast_ref::<BinaryExpr>().unwrap()
    }

    /// Returns the value of `expr`, which must be a [`Literal`].
    fn literal_value(expr: &Arc<dyn PhysicalExpr>) -> &ScalarValue {
        expr.as_any().downcast_ref::<Literal>().unwrap().value()
    }

    /// True when `expr` is still a `CAST` or `TRY_CAST` node.
    fn is_cast(expr: &Arc<dyn PhysicalExpr>) -> bool {
        let any = expr.as_any();
        any.is::<CastExpr>() || any.is::<TryCastExpr>()
    }

    /// Checks a simplified comparison: the cast on the left side is gone and
    /// the right side is a literal equal to `expected`.
    fn assert_unwrapped(comparison: &BinaryExpr, expected: &ScalarValue) {
        assert!(!is_cast(comparison.left()));
        assert_eq!(literal_value(comparison.right()), expected);
    }

    #[test]
    fn test_simplify() {
        let schema = test_schema();
        let mut simplifier = PhysicalExprSimplifier::new(&schema);

        // CAST(c2 AS Int32) != 99_i32  ==>  c2 != 99_i64
        let predicate = cast_comparison(
            &schema,
            "c2",
            DataType::Int32,
            Operator::NotEq,
            ScalarValue::Int32(Some(99)),
        );
        let simplified = simplifier.simplify(predicate).unwrap();

        assert_unwrapped(as_binary(&simplified), &ScalarValue::Int64(Some(99)));
    }

    #[test]
    fn test_nested_expression_simplification() {
        let schema = test_schema();
        let mut simplifier = PhysicalExprSimplifier::new(&schema);

        // (CAST(c1 AS Int64) > 5_i64) OR (CAST(c2 AS Int32) <= 10_i32)
        //   ==>  (c1 > 5_i32) OR (c2 <= 10_i64)
        let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            cast_comparison(
                &schema,
                "c1",
                DataType::Int64,
                Operator::Gt,
                ScalarValue::Int64(Some(5)),
            ),
            Operator::Or,
            cast_comparison(
                &schema,
                "c2",
                DataType::Int32,
                Operator::LtEq,
                ScalarValue::Int32(Some(10)),
            ),
        ));
        let simplified = simplifier.simplify(predicate).unwrap();
        let disjunction = as_binary(&simplified);

        assert_unwrapped(
            as_binary(disjunction.left()),
            &ScalarValue::Int32(Some(5)),
        );
        assert_unwrapped(
            as_binary(disjunction.right()),
            &ScalarValue::Int64(Some(10)),
        );
    }
}