datafusion_optimizer/simplify_expressions/simplify_predicates.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Simplifies predicates by reducing redundant or overlapping conditions.
19//!
20//! This module provides functionality to optimize logical predicates used in query planning
21//! by eliminating redundant conditions, thus reducing the number of predicates to evaluate.
22//! Unlike the simplifier in `simplify_expressions/simplify_exprs.rs`, which focuses on
23//! general expression simplification (e.g., constant folding and algebraic simplifications),
24//! this module specifically targets predicate optimization by handling containment relationships.
//! For example, it can simplify `x > 5 AND x > 6` to just `x > 6`, because the latter
//! condition implies the former, resulting in fewer checks during query execution.
27
28use datafusion_common::{Column, Result, ScalarValue};
29use datafusion_expr::{BinaryExpr, Expr, Operator};
30use std::collections::BTreeMap;
31
32/// Simplifies a list of predicates by removing redundancies.
33///
34/// This function takes a vector of predicate expressions and groups them by the column they reference.
35/// Predicates that reference a single column and are comparison operations (e.g., >, >=, <, <=, =)
36/// are analyzed to remove redundant conditions. For instance, `x > 5 AND x > 6` is simplified to
37/// `x > 6`. Other predicates that do not fit this pattern are retained as-is.
38///
39/// # Arguments
40/// * `predicates` - A vector of `Expr` representing the predicates to simplify.
41///
42/// # Returns
43/// A `Result` containing a vector of simplified `Expr` predicates.
44pub fn simplify_predicates(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
45 // Early return for simple cases
46 if predicates.len() <= 1 {
47 return Ok(predicates);
48 }
49
50 // Group predicates by their column reference
51 let mut column_predicates: BTreeMap<Column, Vec<Expr>> = BTreeMap::new();
52 let mut other_predicates = Vec::new();
53
54 for pred in predicates {
55 match &pred {
56 Expr::BinaryExpr(BinaryExpr {
57 left,
58 op:
59 Operator::Gt
60 | Operator::GtEq
61 | Operator::Lt
62 | Operator::LtEq
63 | Operator::Eq,
64 right,
65 }) => {
66 let left_col = extract_column_from_expr(left);
67 let right_col = extract_column_from_expr(right);
68 if let (Some(col), Some(_)) = (&left_col, right.as_literal()) {
69 column_predicates.entry(col.clone()).or_default().push(pred);
70 } else if let (Some(_), Some(col)) = (left.as_literal(), &right_col) {
71 column_predicates.entry(col.clone()).or_default().push(pred);
72 } else {
73 other_predicates.push(pred);
74 }
75 }
76 _ => other_predicates.push(pred),
77 }
78 }
79
80 // Process each column's predicates to remove redundancies
81 let mut result = other_predicates;
82 for (_, preds) in column_predicates {
83 let simplified = simplify_column_predicates(preds)?;
84 result.extend(simplified);
85 }
86
87 Ok(result)
88}
89
90/// Simplifies predicates related to a single column.
91///
92/// This function processes a list of predicates that all reference the same column and
93/// simplifies them based on their operators. It groups predicates into greater-than (>, >=),
94/// less-than (<, <=), and equality (=) categories, then selects the most restrictive condition
95/// in each category to reduce redundancy. For example, among `x > 5` and `x > 6`, only `x > 6`
96/// is retained as it is more restrictive.
97///
98/// # Arguments
99/// * `predicates` - A vector of `Expr` representing predicates for a single column.
100///
101/// # Returns
102/// A `Result` containing a vector of simplified `Expr` predicates for the column.
103fn simplify_column_predicates(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
104 if predicates.len() <= 1 {
105 return Ok(predicates);
106 }
107
108 // Group by operator type, but combining similar operators
109 let mut greater_predicates = Vec::new(); // Combines > and >=
110 let mut less_predicates = Vec::new(); // Combines < and <=
111 let mut eq_predicates = Vec::new();
112
113 for pred in predicates {
114 match &pred {
115 Expr::BinaryExpr(BinaryExpr { left: _, op, right }) => {
116 match (op, right.as_literal().is_some()) {
117 (Operator::Gt, true)
118 | (Operator::Lt, false)
119 | (Operator::GtEq, true)
120 | (Operator::LtEq, false) => greater_predicates.push(pred),
121 (Operator::Lt, true)
122 | (Operator::Gt, false)
123 | (Operator::LtEq, true)
124 | (Operator::GtEq, false) => less_predicates.push(pred),
125 (Operator::Eq, _) => eq_predicates.push(pred),
126 _ => unreachable!("Unexpected operator: {}", op),
127 }
128 }
129 _ => unreachable!("Unexpected predicate {}", pred.to_string()),
130 }
131 }
132
133 let mut result = Vec::new();
134
135 if !eq_predicates.is_empty() {
136 // If there are many equality predicates, we can only keep one if they are all the same
137 if eq_predicates.len() == 1
138 || eq_predicates.iter().all(|e| e == &eq_predicates[0])
139 {
140 result.push(eq_predicates.pop().unwrap());
141 } else {
142 // If they are not the same, add a false predicate
143 result.push(Expr::Literal(ScalarValue::Boolean(Some(false)), None));
144 }
145 }
146
147 // Handle all greater-than-style predicates (keep the most restrictive - highest value)
148 if !greater_predicates.is_empty() {
149 if let Some(most_restrictive) =
150 find_most_restrictive_predicate(&greater_predicates, true)?
151 {
152 result.push(most_restrictive);
153 } else {
154 result.extend(greater_predicates);
155 }
156 }
157
158 // Handle all less-than-style predicates (keep the most restrictive - lowest value)
159 if !less_predicates.is_empty() {
160 if let Some(most_restrictive) =
161 find_most_restrictive_predicate(&less_predicates, false)?
162 {
163 result.push(most_restrictive);
164 } else {
165 result.extend(less_predicates);
166 }
167 }
168
169 Ok(result)
170}
171
172/// Finds the most restrictive predicate from a list based on literal values.
173///
174/// This function iterates through a list of predicates to identify the most restrictive one
175/// by comparing their literal values. For greater-than predicates, the highest value is most
176/// restrictive, while for less-than predicates, the lowest value is most restrictive.
177///
178/// # Arguments
179/// * `predicates` - A slice of `Expr` representing predicates to compare.
180/// * `find_greater` - A boolean indicating whether to find the highest value (true for >, >=)
181/// or the lowest value (false for <, <=).
182///
183/// # Returns
184/// A `Result` containing an `Option<Expr>` with the most restrictive predicate, if any.
185fn find_most_restrictive_predicate(
186 predicates: &[Expr],
187 find_greater: bool,
188) -> Result<Option<Expr>> {
189 if predicates.is_empty() {
190 return Ok(None);
191 }
192
193 let mut most_restrictive_idx = 0;
194 let mut best_value: Option<&ScalarValue> = None;
195
196 for (idx, pred) in predicates.iter().enumerate() {
197 if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = pred {
198 // Extract the literal value based on which side has it
199 let scalar_value = match (right.as_literal(), left.as_literal()) {
200 (Some(scalar), _) => Some(scalar),
201 (_, Some(scalar)) => Some(scalar),
202 _ => None,
203 };
204
205 if let Some(scalar) = scalar_value {
206 if let Some(current_best) = best_value {
207 let comparison = scalar.try_cmp(current_best)?;
208 let is_better = if find_greater {
209 comparison == std::cmp::Ordering::Greater
210 || (comparison == std::cmp::Ordering::Equal
211 && op == &Operator::Gt)
212 } else {
213 comparison == std::cmp::Ordering::Less
214 || (comparison == std::cmp::Ordering::Equal
215 && op == &Operator::Lt)
216 };
217
218 if is_better {
219 best_value = Some(scalar);
220 most_restrictive_idx = idx;
221 }
222 } else {
223 best_value = Some(scalar);
224 most_restrictive_idx = idx;
225 }
226 }
227 }
228 }
229
230 Ok(Some(predicates[most_restrictive_idx].clone()))
231}
232
233/// Extracts a column reference from an expression, if present.
234///
235/// This function checks if the given expression is a column reference or contains one,
236/// such as within a cast operation. It returns the `Column` if found.
237///
238/// # Arguments
239/// * `expr` - A reference to an `Expr` to inspect for a column reference.
240///
241/// # Returns
242/// An `Option<Column>` containing the column reference if found, otherwise `None`.
243fn extract_column_from_expr(expr: &Expr) -> Option<Column> {
244 match expr {
245 Expr::Column(col) => Some(col.clone()),
246 _ => None,
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{cast, col, lit};

    /// Returns `true` if some predicate in `predicates` is a binary expression using `op`
    /// whose left and right operands satisfy `check`.
    fn has_binary(
        predicates: &[Expr],
        op: Operator,
        check: impl Fn(&Expr, &Expr) -> bool,
    ) -> bool {
        predicates.iter().any(|predicate| match predicate {
            Expr::BinaryExpr(BinaryExpr {
                left,
                op: actual,
                right,
            }) => *actual == op && check(left, right),
            _ => false,
        })
    }

    #[test]
    fn test_simplify_predicates_with_cast() {
        // Predicates on a cast of a column must not be grouped with predicates on the raw
        // column:
        //   a < 5 AND CAST(a AS varchar) < 'abc' AND a < 6
        // simplifies to
        //   a < 5 AND CAST(a AS varchar) < 'abc'
        let result = simplify_predicates(vec![
            col("a").lt(lit(5i32)),
            cast(col("a"), DataType::Utf8).lt(lit("abc")),
            col("a").lt(lit(6i32)),
        ])
        .unwrap();

        assert_eq!(result.len(), 2);
        assert!(
            has_binary(&result, Operator::Lt, |left, right| {
                matches!(left, Expr::Cast(_)) && *right == lit("abc")
            }),
            "Cast predicate should be preserved"
        );
        assert!(
            has_binary(&result, Operator::Lt, |left, right| {
                *left == col("a") && *right == lit(5i32)
            }),
            "Should have a < 5 predicate"
        );
    }

    #[test]
    fn test_extract_column_ignores_cast() {
        // A column wrapped in a cast is not a direct column reference...
        assert_eq!(
            extract_column_from_expr(&cast(col("a"), DataType::Utf8)),
            None
        );
        // ...while a bare column reference is.
        assert_eq!(
            extract_column_from_expr(&col("a")),
            Some(Column::from("a"))
        );
    }

    #[test]
    fn test_simplify_predicates_direct_columns_only() {
        // Only predicates on the same direct column are simplified together; the most
        // restrictive bound per column survives: a < 3 and b > 20.
        let result = simplify_predicates(vec![
            col("a").lt(lit(5i32)),
            col("a").lt(lit(3i32)),
            col("b").gt(lit(10i32)),
            col("b").gt(lit(20i32)),
        ])
        .unwrap();

        assert_eq!(result.len(), 2);
        assert!(
            has_binary(&result, Operator::Lt, |left, right| {
                *left == col("a") && *right == lit(3i32)
            }),
            "Should have a < 3 predicate"
        );
        assert!(
            has_binary(&result, Operator::Gt, |left, right| {
                *left == col("b") && *right == lit(20i32)
            }),
            "Should have b > 20 predicate"
        );
    }
}