datafusion_optimizer/simplify_expressions/
simplify_predicates.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Simplifies predicates by reducing redundant or overlapping conditions.
19//!
20//! This module provides functionality to optimize logical predicates used in query planning
21//! by eliminating redundant conditions, thus reducing the number of predicates to evaluate.
22//! Unlike the simplifier in `simplify_expressions/simplify_exprs.rs`, which focuses on
23//! general expression simplification (e.g., constant folding and algebraic simplifications),
24//! this module specifically targets predicate optimization by handling containment relationships.
25//! For example, it can simplify `x > 5 AND x > 6` to just `x > 6`, as the latter condition
26//! encompasses the former, resulting in fewer checks during query execution.
27
28use datafusion_common::{Column, Result, ScalarValue};
29use datafusion_expr::{BinaryExpr, Expr, Operator};
30use std::collections::BTreeMap;
31
32/// Simplifies a list of predicates by removing redundancies.
33///
34/// This function takes a vector of predicate expressions and groups them by the column they reference.
35/// Predicates that reference a single column and are comparison operations (e.g., >, >=, <, <=, =)
36/// are analyzed to remove redundant conditions. For instance, `x > 5 AND x > 6` is simplified to
37/// `x > 6`. Other predicates that do not fit this pattern are retained as-is.
38///
39/// # Arguments
40/// * `predicates` - A vector of `Expr` representing the predicates to simplify.
41///
42/// # Returns
43/// A `Result` containing a vector of simplified `Expr` predicates.
44pub fn simplify_predicates(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
45    // Early return for simple cases
46    if predicates.len() <= 1 {
47        return Ok(predicates);
48    }
49
50    // Group predicates by their column reference
51    let mut column_predicates: BTreeMap<Column, Vec<Expr>> = BTreeMap::new();
52    let mut other_predicates = Vec::new();
53
54    for pred in predicates {
55        match &pred {
56            Expr::BinaryExpr(BinaryExpr {
57                left,
58                op:
59                    Operator::Gt
60                    | Operator::GtEq
61                    | Operator::Lt
62                    | Operator::LtEq
63                    | Operator::Eq,
64                right,
65            }) => {
66                let left_col = extract_column_from_expr(left);
67                let right_col = extract_column_from_expr(right);
68                if let (Some(col), Some(_)) = (&left_col, right.as_literal()) {
69                    column_predicates.entry(col.clone()).or_default().push(pred);
70                } else if let (Some(_), Some(col)) = (left.as_literal(), &right_col) {
71                    column_predicates.entry(col.clone()).or_default().push(pred);
72                } else {
73                    other_predicates.push(pred);
74                }
75            }
76            _ => other_predicates.push(pred),
77        }
78    }
79
80    // Process each column's predicates to remove redundancies
81    let mut result = other_predicates;
82    for (_, preds) in column_predicates {
83        let simplified = simplify_column_predicates(preds)?;
84        result.extend(simplified);
85    }
86
87    Ok(result)
88}
89
90/// Simplifies predicates related to a single column.
91///
92/// This function processes a list of predicates that all reference the same column and
93/// simplifies them based on their operators. It groups predicates into greater-than (>, >=),
94/// less-than (<, <=), and equality (=) categories, then selects the most restrictive condition
95/// in each category to reduce redundancy. For example, among `x > 5` and `x > 6`, only `x > 6`
96/// is retained as it is more restrictive.
97///
98/// # Arguments
99/// * `predicates` - A vector of `Expr` representing predicates for a single column.
100///
101/// # Returns
102/// A `Result` containing a vector of simplified `Expr` predicates for the column.
103fn simplify_column_predicates(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
104    if predicates.len() <= 1 {
105        return Ok(predicates);
106    }
107
108    // Group by operator type, but combining similar operators
109    let mut greater_predicates = Vec::new(); // Combines > and >=
110    let mut less_predicates = Vec::new(); // Combines < and <=
111    let mut eq_predicates = Vec::new();
112
113    for pred in predicates {
114        match &pred {
115            Expr::BinaryExpr(BinaryExpr { left: _, op, right }) => {
116                match (op, right.as_literal().is_some()) {
117                    (Operator::Gt, true)
118                    | (Operator::Lt, false)
119                    | (Operator::GtEq, true)
120                    | (Operator::LtEq, false) => greater_predicates.push(pred),
121                    (Operator::Lt, true)
122                    | (Operator::Gt, false)
123                    | (Operator::LtEq, true)
124                    | (Operator::GtEq, false) => less_predicates.push(pred),
125                    (Operator::Eq, _) => eq_predicates.push(pred),
126                    _ => unreachable!("Unexpected operator: {}", op),
127                }
128            }
129            _ => unreachable!("Unexpected predicate {}", pred.to_string()),
130        }
131    }
132
133    let mut result = Vec::new();
134
135    if !eq_predicates.is_empty() {
136        // If there are many equality predicates, we can only keep one if they are all the same
137        if eq_predicates.len() == 1
138            || eq_predicates.iter().all(|e| e == &eq_predicates[0])
139        {
140            result.push(eq_predicates.pop().unwrap());
141        } else {
142            // If they are not the same, add a false predicate
143            result.push(Expr::Literal(ScalarValue::Boolean(Some(false)), None));
144        }
145    }
146
147    // Handle all greater-than-style predicates (keep the most restrictive - highest value)
148    if !greater_predicates.is_empty() {
149        if let Some(most_restrictive) =
150            find_most_restrictive_predicate(&greater_predicates, true)?
151        {
152            result.push(most_restrictive);
153        } else {
154            result.extend(greater_predicates);
155        }
156    }
157
158    // Handle all less-than-style predicates (keep the most restrictive - lowest value)
159    if !less_predicates.is_empty() {
160        if let Some(most_restrictive) =
161            find_most_restrictive_predicate(&less_predicates, false)?
162        {
163            result.push(most_restrictive);
164        } else {
165            result.extend(less_predicates);
166        }
167    }
168
169    Ok(result)
170}
171
172/// Finds the most restrictive predicate from a list based on literal values.
173///
174/// This function iterates through a list of predicates to identify the most restrictive one
175/// by comparing their literal values. For greater-than predicates, the highest value is most
176/// restrictive, while for less-than predicates, the lowest value is most restrictive.
177///
178/// # Arguments
179/// * `predicates` - A slice of `Expr` representing predicates to compare.
180/// * `find_greater` - A boolean indicating whether to find the highest value (true for >, >=)
181///   or the lowest value (false for <, <=).
182///
183/// # Returns
184/// A `Result` containing an `Option<Expr>` with the most restrictive predicate, if any.
185fn find_most_restrictive_predicate(
186    predicates: &[Expr],
187    find_greater: bool,
188) -> Result<Option<Expr>> {
189    if predicates.is_empty() {
190        return Ok(None);
191    }
192
193    let mut most_restrictive_idx = 0;
194    let mut best_value: Option<&ScalarValue> = None;
195
196    for (idx, pred) in predicates.iter().enumerate() {
197        if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = pred {
198            // Extract the literal value based on which side has it
199            let scalar_value = match (right.as_literal(), left.as_literal()) {
200                (Some(scalar), _) => Some(scalar),
201                (_, Some(scalar)) => Some(scalar),
202                _ => None,
203            };
204
205            if let Some(scalar) = scalar_value {
206                if let Some(current_best) = best_value {
207                    let comparison = scalar.try_cmp(current_best)?;
208                    let is_better = if find_greater {
209                        comparison == std::cmp::Ordering::Greater
210                            || (comparison == std::cmp::Ordering::Equal
211                                && op == &Operator::Gt)
212                    } else {
213                        comparison == std::cmp::Ordering::Less
214                            || (comparison == std::cmp::Ordering::Equal
215                                && op == &Operator::Lt)
216                    };
217
218                    if is_better {
219                        best_value = Some(scalar);
220                        most_restrictive_idx = idx;
221                    }
222                } else {
223                    best_value = Some(scalar);
224                    most_restrictive_idx = idx;
225                }
226            }
227        }
228    }
229
230    Ok(Some(predicates[most_restrictive_idx].clone()))
231}
232
233/// Extracts a column reference from an expression, if present.
234///
235/// This function checks if the given expression is a column reference or contains one,
236/// such as within a cast operation. It returns the `Column` if found.
237///
238/// # Arguments
239/// * `expr` - A reference to an `Expr` to inspect for a column reference.
240///
241/// # Returns
242/// An `Option<Column>` containing the column reference if found, otherwise `None`.
243fn extract_column_from_expr(expr: &Expr) -> Option<Column> {
244    match expr {
245        Expr::Column(col) => Some(col.clone()),
246        _ => None,
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{cast, col, lit};

    #[test]
    fn test_simplify_predicates_with_cast() {
        // A predicate on a cast of a column must not be merged with predicates on the
        // raw column:
        //   a < 5 AND CAST(a AS varchar) < 'abc' AND a < 6
        // is expected to become
        //   a < 5 AND CAST(a AS varchar) < 'abc'
        let input = vec![
            col("a").lt(lit(5i32)),
            cast(col("a"), DataType::Utf8).lt(lit("abc")),
            col("a").lt(lit(6i32)),
        ];

        let simplified = simplify_predicates(input).unwrap();

        // Exactly two survivors: `a < 5` and the cast predicate.
        assert_eq!(simplified.len(), 2);

        // The cast predicate is still there.
        let cast_kept = simplified.iter().any(|expr| match expr {
            Expr::BinaryExpr(BinaryExpr {
                left,
                op: Operator::Lt,
                right,
            }) => matches!(**left, Expr::Cast(_)) && **right == lit("abc"),
            _ => false,
        });
        assert!(cast_kept, "Cast predicate should be preserved");

        // The tighter bound on the raw column (a < 5) is the one that was kept.
        let tight_bound_kept = simplified.contains(&col("a").lt(lit(5i32)));
        assert!(tight_bound_kept, "Should have a < 5 predicate");
    }

    #[test]
    fn test_extract_column_ignores_cast() {
        // A column wrapped in a cast is not reported as a column reference.
        let wrapped = cast(col("a"), DataType::Utf8);
        assert_eq!(extract_column_from_expr(&wrapped), None);

        // A bare column reference is.
        let bare = col("a");
        assert_eq!(extract_column_from_expr(&bare), Some(Column::from("a")));
    }

    #[test]
    fn test_simplify_predicates_direct_columns_only() {
        // Predicates on bare columns are reduced per column.
        let input = vec![
            col("a").lt(lit(5i32)),
            col("a").lt(lit(3i32)),
            col("b").gt(lit(10i32)),
            col("b").gt(lit(20i32)),
        ];

        let simplified = simplify_predicates(input).unwrap();

        // One survivor per column: the most restrictive bound of each.
        assert_eq!(simplified.len(), 2);

        // a < 3
        let a_bound_kept = simplified.contains(&col("a").lt(lit(3i32)));
        assert!(a_bound_kept, "Should have a < 3 predicate");

        // b > 20
        let b_bound_kept = simplified.contains(&col("b").gt(lit(20i32)));
        assert!(b_bound_kept, "Should have b > 20 predicate");
    }
}