datafusion_optimizer/simplify_expressions/
guarantees.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Simplifier implementation for [`ExprSimplifier::with_guarantees()`]
19//!
20//! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees
21
22use std::{borrow::Cow, collections::HashMap};
23
24use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
25use datafusion_common::{DataFusionError, Result};
26use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
27use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr};
28
29/// Rewrite expressions to incorporate guarantees.
30///
31/// Guarantees are a mapping from an expression (which currently is always a
32/// column reference) to a [NullableInterval]. The interval represents the known
33/// possible values of the column. Using these known values, expressions are
34/// rewritten so they can be simplified using `ConstEvaluator` and `Simplifier`.
35///
36/// For example, if we know that a column is not null and has values in the
37/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`.
38///
39/// See a full example in [`ExprSimplifier::with_guarantees()`].
40///
41/// [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees
42pub struct GuaranteeRewriter<'a> {
43    guarantees: HashMap<&'a Expr, &'a NullableInterval>,
44}
45
46impl<'a> GuaranteeRewriter<'a> {
47    pub fn new(
48        guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
49    ) -> Self {
50        Self {
51            // TODO: Clippy wants the "map" call removed, but doing so generates
52            //       a compilation error. Remove the clippy directive once this
53            //       issue is fixed.
54            #[allow(clippy::map_identity)]
55            guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(),
56        }
57    }
58}
59
60impl TreeNodeRewriter for GuaranteeRewriter<'_> {
61    type Node = Expr;
62
63    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
64        if self.guarantees.is_empty() {
65            return Ok(Transformed::no(expr));
66        }
67
68        match &expr {
69            Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) {
70                Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(true))),
71                Some(NullableInterval::NotNull { .. }) => {
72                    Ok(Transformed::yes(lit(false)))
73                }
74                _ => Ok(Transformed::no(expr)),
75            },
76            Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) {
77                Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(false))),
78                Some(NullableInterval::NotNull { .. }) => Ok(Transformed::yes(lit(true))),
79                _ => Ok(Transformed::no(expr)),
80            },
81            Expr::Between(Between {
82                expr: inner,
83                negated,
84                low,
85                high,
86            }) => {
87                if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = (
88                    self.guarantees.get(inner.as_ref()),
89                    low.as_ref(),
90                    high.as_ref(),
91                ) {
92                    let expr_interval = NullableInterval::NotNull {
93                        values: Interval::try_new(low.clone(), high.clone())?,
94                    };
95
96                    let contains = expr_interval.contains(*interval)?;
97
98                    if contains.is_certainly_true() {
99                        Ok(Transformed::yes(lit(!negated)))
100                    } else if contains.is_certainly_false() {
101                        Ok(Transformed::yes(lit(*negated)))
102                    } else {
103                        Ok(Transformed::no(expr))
104                    }
105                } else {
106                    Ok(Transformed::no(expr))
107                }
108            }
109
110            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
111                // The left or right side of expression might either have a guarantee
112                // or be a literal. Either way, we can resolve them to a NullableInterval.
113                let left_interval = self
114                    .guarantees
115                    .get(left.as_ref())
116                    .map(|interval| Cow::Borrowed(*interval))
117                    .or_else(|| {
118                        if let Expr::Literal(value) = left.as_ref() {
119                            Some(Cow::Owned(value.clone().into()))
120                        } else {
121                            None
122                        }
123                    });
124                let right_interval = self
125                    .guarantees
126                    .get(right.as_ref())
127                    .map(|interval| Cow::Borrowed(*interval))
128                    .or_else(|| {
129                        if let Expr::Literal(value) = right.as_ref() {
130                            Some(Cow::Owned(value.clone().into()))
131                        } else {
132                            None
133                        }
134                    });
135
136                match (left_interval, right_interval) {
137                    (Some(left_interval), Some(right_interval)) => {
138                        let result =
139                            left_interval.apply_operator(op, right_interval.as_ref())?;
140                        if result.is_certainly_true() {
141                            Ok(Transformed::yes(lit(true)))
142                        } else if result.is_certainly_false() {
143                            Ok(Transformed::yes(lit(false)))
144                        } else {
145                            Ok(Transformed::no(expr))
146                        }
147                    }
148                    _ => Ok(Transformed::no(expr)),
149                }
150            }
151
152            // Columns (if interval is collapsed to a single value)
153            Expr::Column(_) => {
154                if let Some(interval) = self.guarantees.get(&expr) {
155                    Ok(Transformed::yes(interval.single_value().map_or(expr, lit)))
156                } else {
157                    Ok(Transformed::no(expr))
158                }
159            }
160
161            Expr::InList(InList {
162                expr: inner,
163                list,
164                negated,
165            }) => {
166                if let Some(interval) = self.guarantees.get(inner.as_ref()) {
167                    // Can remove items from the list that don't match the guarantee
168                    let new_list: Vec<Expr> = list
169                        .iter()
170                        .filter_map(|expr| {
171                            if let Expr::Literal(item) = expr {
172                                match interval
173                                    .contains(NullableInterval::from(item.clone()))
174                                {
175                                    // If we know for certain the value isn't in the column's interval,
176                                    // we can skip checking it.
177                                    Ok(interval) if interval.is_certainly_false() => None,
178                                    Ok(_) => Some(Ok(expr.clone())),
179                                    Err(e) => Some(Err(e)),
180                                }
181                            } else {
182                                Some(Ok(expr.clone()))
183                            }
184                        })
185                        .collect::<Result<_, DataFusionError>>()?;
186
187                    Ok(Transformed::yes(Expr::InList(InList {
188                        expr: inner.clone(),
189                        list: new_list,
190                        negated: *negated,
191                    })))
192                } else {
193                    Ok(Transformed::no(expr))
194                }
195            }
196
197            _ => Ok(Transformed::no(expr)),
198        }
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::datatypes::DataType;
    use datafusion_common::tree_node::{TransformedResult, TreeNode};
    use datafusion_common::ScalarValue;
    use datafusion_expr::{col, Operator};

    /// Builds a non-null `Date32` literal from its raw day number.
    fn date(days: i32) -> Expr {
        lit(ScalarValue::Date32(Some(days)))
    }

    /// Builds the binary expression `left <op> right`.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(left),
            op,
            right: Box::new(right),
        })
    }

    /// Turns a slice of integers into non-null `Int32` literals.
    fn int32_literals(values: &[i32]) -> Vec<Expr> {
        values
            .iter()
            .map(|v| lit(ScalarValue::Int32(Some(*v))))
            .collect()
    }

    #[test]
    fn test_null_handling() {
        // IS NULL / IS NOT NULL collapse to false / true for a non-null column.
        //
        // The always-null case is covered by test_column_single_value, since
        // it is a special case of a column with a single value.
        let non_null_boolean = NullableInterval::NotNull {
            values: Interval::make_unbounded(&DataType::Boolean).unwrap(),
        };
        let guarantees = vec![(col("x"), non_null_boolean)];
        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());

        // x IS NULL => guaranteed false
        let rewritten = col("x").is_null().rewrite(&mut rewriter).data().unwrap();
        assert_eq!(rewritten, lit(false));

        // x IS NOT NULL => guaranteed true
        let rewritten = col("x")
            .is_not_null()
            .rewrite(&mut rewriter)
            .data()
            .unwrap();
        assert_eq!(rewritten, lit(true));
    }

    /// Asserts that every expression in `cases` is rewritten to the literal
    /// made from the value it is paired with.
    fn validate_simplified_cases<T>(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)])
    where
        ScalarValue: From<T>,
        T: Clone,
    {
        for (original, wanted) in cases {
            let expected = lit(ScalarValue::from(wanted.clone()));
            let output = original.clone().rewrite(rewriter).data().unwrap();
            assert_eq!(
                output, expected,
                "{} simplified to {}, but expected {}",
                original, output, expected
            );
        }
    }

    /// Asserts that the rewriter leaves every expression in `cases` untouched.
    fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) {
        for original in cases {
            let output = original.clone().rewrite(rewriter).data().unwrap();
            assert_eq!(
                &output, original,
                "{} was simplified to {}, but expected it to be unchanged",
                original, output
            );
        }
    }

    #[test]
    fn test_inequalities_non_null_unbounded() {
        // x ∈ [2021-01-01, ∞) (not null)
        let from_2021 = Interval::try_new(
            ScalarValue::Date32(Some(18628)),
            ScalarValue::Date32(None),
        )
        .unwrap();
        let guarantees = vec![(col("x"), NullableInterval::NotNull { values: from_2021 })];
        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());

        // (original_expr, expected_simplification)
        let simplified_cases = &[
            (col("x").lt(date(18628)), false),
            (col("x").lt_eq(date(17000)), false),
            (col("x").gt(date(18627)), true),
            (col("x").gt_eq(date(18628)), true),
            (col("x").eq(date(17000)), false),
            (col("x").not_eq(date(17000)), true),
            (col("x").between(date(16000), date(17000)), false),
            (col("x").not_between(date(16000), date(17000)), true),
            (
                binary(col("x"), Operator::IsDistinctFrom, lit(ScalarValue::Null)),
                true,
            ),
            (
                binary(col("x"), Operator::IsDistinctFrom, date(17000)),
                true,
            ),
        ];

        validate_simplified_cases(&mut rewriter, simplified_cases);

        // Comparisons against a date inside the guaranteed range stay undecided.
        let unchanged_cases = &[
            col("x").lt(date(19000)),
            col("x").lt_eq(date(19000)),
            col("x").gt(date(19000)),
            col("x").gt_eq(date(19000)),
            col("x").eq(date(19000)),
            col("x").not_eq(date(19000)),
            col("x").between(date(18000), date(19000)),
            col("x").not_between(date(18000), date(19000)),
        ];

        validate_unchanged_cases(&mut rewriter, unchanged_cases);
    }

    #[test]
    fn test_inequalities_maybe_null() {
        // x is within the interval built from "abc" and "def", or is null.
        let abc_to_def =
            Interval::try_new(ScalarValue::from("abc"), ScalarValue::from("def"))
                .unwrap();
        let guarantees = vec![(
            col("x"),
            NullableInterval::MaybeNull { values: abc_to_def },
        )];
        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());

        // (original_expr, expected_simplification)
        let simplified_cases = &[
            (
                binary(col("x"), Operator::IsDistinctFrom, lit("z")),
                true,
            ),
            (
                binary(col("x"), Operator::IsNotDistinctFrom, lit("z")),
                false,
            ),
        ];

        validate_simplified_cases(&mut rewriter, simplified_cases);

        let unchanged_cases = &[
            col("x").lt(lit("z")),
            col("x").lt_eq(lit("z")),
            col("x").gt(lit("a")),
            col("x").gt_eq(lit("a")),
            col("x").eq(lit("abc")),
            col("x").not_eq(lit("a")),
            col("x").between(lit("a"), lit("z")),
            col("x").not_between(lit("a"), lit("z")),
            binary(col("x"), Operator::IsDistinctFrom, lit(ScalarValue::Null)),
        ];

        validate_unchanged_cases(&mut rewriter, unchanged_cases);
    }

    #[test]
    fn test_column_single_value() {
        // A guarantee built from a single scalar should turn the column
        // into exactly that scalar as a literal.
        let scalars = [
            ScalarValue::Null,
            ScalarValue::Int32(Some(1)),
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Boolean(None),
            ScalarValue::from("abc"),
            ScalarValue::LargeUtf8(Some("def".to_string())),
            ScalarValue::Date32(Some(18628)),
            ScalarValue::Date32(None),
            ScalarValue::Decimal128(Some(1000), 19, 2),
        ];

        for scalar in scalars {
            let guarantees = vec![(col("x"), NullableInterval::from(scalar.clone()))];
            let mut rewriter = GuaranteeRewriter::new(guarantees.iter());

            let output = col("x").rewrite(&mut rewriter).data().unwrap();
            assert_eq!(output, Expr::Literal(scalar));
        }
    }

    #[test]
    fn test_in_list() {
        // x ∈ [1, 10] (not null)
        let one_to_ten = Interval::try_new(
            ScalarValue::Int32(Some(1)),
            ScalarValue::Int32(Some(10)),
        )
        .unwrap();
        let guarantees = vec![(col("x"), NullableInterval::NotNull { values: one_to_ten })];
        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());

        // Each list should come back without the values the guarantee places
        // outside the range.
        // (column_name, starting_list, negated, expected_list)
        let cases = &[
            // x IN (9, 11) => x IN (9)
            ("x", vec![9, 11], false, vec![9]),
            // x IN (10, 2) => x IN (10, 2)
            ("x", vec![10, 2], false, vec![10, 2]),
            // x NOT IN (9, 11) => x NOT IN (9)
            ("x", vec![9, 11], true, vec![9]),
            // x NOT IN (0, 22) => x NOT IN ()
            ("x", vec![0, 22], true, vec![]),
        ];

        for (column_name, starting_list, negated, expected_list) in cases {
            let input = col(*column_name).in_list(int32_literals(starting_list), *negated);
            let output = input.rewrite(&mut rewriter).data().unwrap();

            let expected = Expr::InList(InList {
                expr: Box::new(col(*column_name)),
                list: int32_literals(expected_list),
                negated: *negated,
            });
            assert_eq!(output, expected);
        }
    }
}