datafusion_physical_expr/simplifier/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Simplifier for Physical Expressions
19
20use arrow::datatypes::Schema;
21use datafusion_common::{
22    tree_node::{Transformed, TreeNode, TreeNodeRewriter},
23    Result,
24};
25use std::sync::Arc;
26
27use crate::PhysicalExpr;
28
29pub mod unwrap_cast;
30
31/// Simplifies physical expressions by applying various optimizations
32///
33/// This can be useful after adapting expressions from a table schema
34/// to a file schema. For example, casts added to match the types may
35/// potentially be unwrapped.
36pub struct PhysicalExprSimplifier<'a> {
37    schema: &'a Schema,
38}
39
40impl<'a> PhysicalExprSimplifier<'a> {
41    /// Create a new physical expression simplifier
42    pub fn new(schema: &'a Schema) -> Self {
43        Self { schema }
44    }
45
46    /// Simplify a physical expression
47    pub fn simplify(
48        &mut self,
49        expr: Arc<dyn PhysicalExpr>,
50    ) -> Result<Arc<dyn PhysicalExpr>> {
51        Ok(expr.rewrite(self)?.data)
52    }
53}
54
55impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> {
56    type Node = Arc<dyn PhysicalExpr>;
57
58    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
59        // Apply unwrap cast optimization
60        #[cfg(test)]
61        let original_type = node.data_type(self.schema).unwrap();
62        let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, self.schema)?;
63        #[cfg(test)]
64        assert_eq!(
65            unwrapped.data.data_type(self.schema).unwrap(),
66            original_type,
67            "Simplified expression should have the same data type as the original"
68        );
69        Ok(unwrapped)
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, lit, BinaryExpr, CastExpr, Literal, TryCastExpr};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::ScalarValue;
    use datafusion_expr::Operator;

    /// Schema used by every test: non-nullable `c1: Int32`, `c2: Int64`
    /// and `c3: Utf8`.
    fn test_schema() -> Schema {
        let fields: Vec<Field> = [
            ("c1", DataType::Int32),
            ("c2", DataType::Int64),
            ("c3", DataType::Utf8),
        ]
        .into_iter()
        .map(|(name, data_type)| Field::new(name, data_type, false))
        .collect();
        Schema::new(fields)
    }

    /// Builds `cast(<column> as <cast_to>) <op> <literal>`.
    fn cast_comparison(
        schema: &Schema,
        column: &str,
        cast_to: DataType,
        op: Operator,
        literal: ScalarValue,
    ) -> Arc<dyn PhysicalExpr> {
        let casted = Arc::new(CastExpr::new(
            col(column, schema).unwrap(),
            cast_to,
            None,
        ));
        Arc::new(BinaryExpr::new(casted, op, lit(literal)))
    }

    /// Asserts that `expr` is a binary expression whose left operand is
    /// neither a `CastExpr` nor a `TryCastExpr`, and whose right operand is
    /// the literal `expected`.
    fn assert_cast_unwrapped(expr: &Arc<dyn PhysicalExpr>, expected: ScalarValue) {
        let comparison = expr.as_any().downcast_ref::<BinaryExpr>().unwrap();

        let lhs = comparison.left().as_any();
        assert!(!(lhs.is::<CastExpr>() || lhs.is::<TryCastExpr>()));

        let rhs = comparison
            .right()
            .as_any()
            .downcast_ref::<Literal>()
            .unwrap();
        assert_eq!(rhs.value(), &expected);
    }

    #[test]
    fn test_simplify() {
        let schema = test_schema();
        let mut simplifier = PhysicalExprSimplifier::new(&schema);

        // Input: cast(c2 as INT32) != INT32(99)
        let input = cast_comparison(
            &schema,
            "c2",
            DataType::Int32,
            Operator::NotEq,
            ScalarValue::Int32(Some(99)),
        );

        // Go through the public entry point so the TreeNodeRewriter impl is exercised
        let simplified = simplifier.simplify(input).unwrap();

        // Expected: c2 != INT64(99) (c2 is INT64, so the literal is cast to match)
        assert_cast_unwrapped(&simplified, ScalarValue::Int64(Some(99)));
    }

    #[test]
    fn test_nested_expression_simplification() {
        let schema = test_schema();
        let mut simplifier = PhysicalExprSimplifier::new(&schema);

        // Input: (cast(c1 as INT64) > INT64(5)) OR (cast(c2 as INT32) <= INT32(10))
        let lhs = cast_comparison(
            &schema,
            "c1",
            DataType::Int64,
            Operator::Gt,
            ScalarValue::Int64(Some(5)),
        );
        let rhs = cast_comparison(
            &schema,
            "c2",
            DataType::Int32,
            Operator::LtEq,
            ScalarValue::Int32(Some(10)),
        );
        let disjunction = Arc::new(BinaryExpr::new(lhs, Operator::Or, rhs));

        let simplified = simplifier.simplify(disjunction).unwrap();
        let or_expr = simplified.as_any().downcast_ref::<BinaryExpr>().unwrap();

        // Left branch becomes: c1 > INT32(5)
        assert_cast_unwrapped(or_expr.left(), ScalarValue::Int32(Some(5)));

        // Right branch becomes: c2 <= INT64(10)
        assert_cast_unwrapped(or_expr.right(), ScalarValue::Int64(Some(10)));
    }
}