datafusion_optimizer/simplify_expressions/
guarantees.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Simplifier implementation for [`ExprSimplifier::with_guarantees()`]
19//!
20//! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees
21
22use std::{borrow::Cow, collections::HashMap};
23
24use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
25use datafusion_common::{DataFusionError, Result};
26use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
27use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr};
28
/// Rewrite expressions to incorporate guarantees.
///
/// Guarantees are a mapping from an expression (which currently is always a
/// column reference) to a [NullableInterval]. The interval represents the known
/// possible values of the column. Using these known values, expressions are
/// rewritten so they can be simplified using `ConstEvaluator` and `Simplifier`.
///
/// For example, if we know that a column is not null and has values in the
/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`.
///
/// The rewriter only borrows the guarantees: both the expressions and their
/// intervals must outlive it (lifetime `'a`).
///
/// See a full example in [`ExprSimplifier::with_guarantees()`].
///
/// [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees
pub struct GuaranteeRewriter<'a> {
    /// Known value ranges, keyed by the expression they describe. Lookups are
    /// done by structural equality/hash of the borrowed `Expr`.
    guarantees: HashMap<&'a Expr, &'a NullableInterval>,
}
45
46impl<'a> GuaranteeRewriter<'a> {
47    pub fn new(
48        guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
49    ) -> Self {
50        Self {
51            // TODO: Clippy wants the "map" call removed, but doing so generates
52            //       a compilation error. Remove the clippy directive once this
53            //       issue is fixed.
54            #[allow(clippy::map_identity)]
55            guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(),
56        }
57    }
58}
59
60impl TreeNodeRewriter for GuaranteeRewriter<'_> {
61    type Node = Expr;
62
63    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
64        if self.guarantees.is_empty() {
65            return Ok(Transformed::no(expr));
66        }
67
68        match &expr {
69            Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) {
70                Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(true))),
71                Some(NullableInterval::NotNull { .. }) => {
72                    Ok(Transformed::yes(lit(false)))
73                }
74                _ => Ok(Transformed::no(expr)),
75            },
76            Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) {
77                Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(false))),
78                Some(NullableInterval::NotNull { .. }) => Ok(Transformed::yes(lit(true))),
79                _ => Ok(Transformed::no(expr)),
80            },
81            Expr::Between(Between {
82                expr: inner,
83                negated,
84                low,
85                high,
86            }) => {
87                if let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = (
88                    self.guarantees.get(inner.as_ref()),
89                    low.as_ref(),
90                    high.as_ref(),
91                ) {
92                    let expr_interval = NullableInterval::NotNull {
93                        values: Interval::try_new(low.clone(), high.clone())?,
94                    };
95
96                    let contains = expr_interval.contains(*interval)?;
97
98                    if contains.is_certainly_true() {
99                        Ok(Transformed::yes(lit(!negated)))
100                    } else if contains.is_certainly_false() {
101                        Ok(Transformed::yes(lit(*negated)))
102                    } else {
103                        Ok(Transformed::no(expr))
104                    }
105                } else {
106                    Ok(Transformed::no(expr))
107                }
108            }
109
110            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
111                // The left or right side of expression might either have a guarantee
112                // or be a literal. Either way, we can resolve them to a NullableInterval.
113                let left_interval = self
114                    .guarantees
115                    .get(left.as_ref())
116                    .map(|interval| Cow::Borrowed(*interval))
117                    .or_else(|| {
118                        if let Expr::Literal(value, _) = left.as_ref() {
119                            Some(Cow::Owned(value.clone().into()))
120                        } else {
121                            None
122                        }
123                    });
124                let right_interval = self
125                    .guarantees
126                    .get(right.as_ref())
127                    .map(|interval| Cow::Borrowed(*interval))
128                    .or_else(|| {
129                        if let Expr::Literal(value, _) = right.as_ref() {
130                            Some(Cow::Owned(value.clone().into()))
131                        } else {
132                            None
133                        }
134                    });
135
136                match (left_interval, right_interval) {
137                    (Some(left_interval), Some(right_interval)) => {
138                        let result =
139                            left_interval.apply_operator(op, right_interval.as_ref())?;
140                        if result.is_certainly_true() {
141                            Ok(Transformed::yes(lit(true)))
142                        } else if result.is_certainly_false() {
143                            Ok(Transformed::yes(lit(false)))
144                        } else {
145                            Ok(Transformed::no(expr))
146                        }
147                    }
148                    _ => Ok(Transformed::no(expr)),
149                }
150            }
151
152            // Columns (if interval is collapsed to a single value)
153            Expr::Column(_) => {
154                if let Some(interval) = self.guarantees.get(&expr) {
155                    Ok(Transformed::yes(interval.single_value().map_or(expr, lit)))
156                } else {
157                    Ok(Transformed::no(expr))
158                }
159            }
160
161            Expr::InList(InList {
162                expr: inner,
163                list,
164                negated,
165            }) => {
166                if let Some(interval) = self.guarantees.get(inner.as_ref()) {
167                    // Can remove items from the list that don't match the guarantee
168                    let new_list: Vec<Expr> = list
169                        .iter()
170                        .filter_map(|expr| {
171                            if let Expr::Literal(item, _) = expr {
172                                match interval
173                                    .contains(NullableInterval::from(item.clone()))
174                                {
175                                    // If we know for certain the value isn't in the column's interval,
176                                    // we can skip checking it.
177                                    Ok(interval) if interval.is_certainly_false() => None,
178                                    Ok(_) => Some(Ok(expr.clone())),
179                                    Err(e) => Some(Err(e)),
180                                }
181                            } else {
182                                Some(Ok(expr.clone()))
183                            }
184                        })
185                        .collect::<Result<_, DataFusionError>>()?;
186
187                    Ok(Transformed::yes(Expr::InList(InList {
188                        expr: inner.clone(),
189                        list: new_list,
190                        negated: *negated,
191                    })))
192                } else {
193                    Ok(Transformed::no(expr))
194                }
195            }
196
197            _ => Ok(Transformed::no(expr)),
198        }
199    }
200}
201
202#[cfg(test)]
203mod tests {
204    use super::*;
205
206    use arrow::datatypes::DataType;
207    use datafusion_common::tree_node::{TransformedResult, TreeNode};
208    use datafusion_common::ScalarValue;
209    use datafusion_expr::{col, Operator};
210
211    #[test]
212    fn test_null_handling() {
213        // IsNull / IsNotNull can be rewritten to true / false
214        let guarantees = vec![
215            // Note: AlwaysNull case handled by test_column_single_value test,
216            // since it's a special case of a column with a single value.
217            (
218                col("x"),
219                NullableInterval::NotNull {
220                    values: Interval::make_unbounded(&DataType::Boolean).unwrap(),
221                },
222            ),
223        ];
224        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
225
226        // x IS NULL => guaranteed false
227        let expr = col("x").is_null();
228        let output = expr.rewrite(&mut rewriter).data().unwrap();
229        assert_eq!(output, lit(false));
230
231        // x IS NOT NULL => guaranteed true
232        let expr = col("x").is_not_null();
233        let output = expr.rewrite(&mut rewriter).data().unwrap();
234        assert_eq!(output, lit(true));
235    }
236
237    fn validate_simplified_cases<T>(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)])
238    where
239        ScalarValue: From<T>,
240        T: Clone,
241    {
242        for (expr, expected_value) in cases {
243            let output = expr.clone().rewrite(rewriter).data().unwrap();
244            let expected = lit(ScalarValue::from(expected_value.clone()));
245            assert_eq!(
246                output, expected,
247                "{expr} simplified to {output}, but expected {expected}"
248            );
249        }
250    }
251
252    fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) {
253        for expr in cases {
254            let output = expr.clone().rewrite(rewriter).data().unwrap();
255            assert_eq!(
256                &output, expr,
257                "{expr} was simplified to {output}, but expected it to be unchanged"
258            );
259        }
260    }
261
262    #[test]
263    fn test_inequalities_non_null_unbounded() {
264        let guarantees = vec![
265            // y ∈ [2021-01-01, ∞) (not null)
266            (
267                col("x"),
268                NullableInterval::NotNull {
269                    values: Interval::try_new(
270                        ScalarValue::Date32(Some(18628)),
271                        ScalarValue::Date32(None),
272                    )
273                    .unwrap(),
274                },
275            ),
276        ];
277        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
278
279        // (original_expr, expected_simplification)
280        let simplified_cases = &[
281            (col("x").lt(lit(ScalarValue::Date32(Some(18628)))), false),
282            (col("x").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false),
283            (col("x").gt(lit(ScalarValue::Date32(Some(18627)))), true),
284            (col("x").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true),
285            (col("x").eq(lit(ScalarValue::Date32(Some(17000)))), false),
286            (col("x").not_eq(lit(ScalarValue::Date32(Some(17000)))), true),
287            (
288                col("x").between(
289                    lit(ScalarValue::Date32(Some(16000))),
290                    lit(ScalarValue::Date32(Some(17000))),
291                ),
292                false,
293            ),
294            (
295                col("x").not_between(
296                    lit(ScalarValue::Date32(Some(16000))),
297                    lit(ScalarValue::Date32(Some(17000))),
298                ),
299                true,
300            ),
301            (
302                Expr::BinaryExpr(BinaryExpr {
303                    left: Box::new(col("x")),
304                    op: Operator::IsDistinctFrom,
305                    right: Box::new(lit(ScalarValue::Null)),
306                }),
307                true,
308            ),
309            (
310                Expr::BinaryExpr(BinaryExpr {
311                    left: Box::new(col("x")),
312                    op: Operator::IsDistinctFrom,
313                    right: Box::new(lit(ScalarValue::Date32(Some(17000)))),
314                }),
315                true,
316            ),
317        ];
318
319        validate_simplified_cases(&mut rewriter, simplified_cases);
320
321        let unchanged_cases = &[
322            col("x").lt(lit(ScalarValue::Date32(Some(19000)))),
323            col("x").lt_eq(lit(ScalarValue::Date32(Some(19000)))),
324            col("x").gt(lit(ScalarValue::Date32(Some(19000)))),
325            col("x").gt_eq(lit(ScalarValue::Date32(Some(19000)))),
326            col("x").eq(lit(ScalarValue::Date32(Some(19000)))),
327            col("x").not_eq(lit(ScalarValue::Date32(Some(19000)))),
328            col("x").between(
329                lit(ScalarValue::Date32(Some(18000))),
330                lit(ScalarValue::Date32(Some(19000))),
331            ),
332            col("x").not_between(
333                lit(ScalarValue::Date32(Some(18000))),
334                lit(ScalarValue::Date32(Some(19000))),
335            ),
336        ];
337
338        validate_unchanged_cases(&mut rewriter, unchanged_cases);
339    }
340
341    #[test]
342    fn test_inequalities_maybe_null() {
343        let guarantees = vec![
344            // x ∈ ("abc", "def"]? (maybe null)
345            (
346                col("x"),
347                NullableInterval::MaybeNull {
348                    values: Interval::try_new(
349                        ScalarValue::from("abc"),
350                        ScalarValue::from("def"),
351                    )
352                    .unwrap(),
353                },
354            ),
355        ];
356        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
357
358        // (original_expr, expected_simplification)
359        let simplified_cases = &[
360            (
361                Expr::BinaryExpr(BinaryExpr {
362                    left: Box::new(col("x")),
363                    op: Operator::IsDistinctFrom,
364                    right: Box::new(lit("z")),
365                }),
366                true,
367            ),
368            (
369                Expr::BinaryExpr(BinaryExpr {
370                    left: Box::new(col("x")),
371                    op: Operator::IsNotDistinctFrom,
372                    right: Box::new(lit("z")),
373                }),
374                false,
375            ),
376        ];
377
378        validate_simplified_cases(&mut rewriter, simplified_cases);
379
380        let unchanged_cases = &[
381            col("x").lt(lit("z")),
382            col("x").lt_eq(lit("z")),
383            col("x").gt(lit("a")),
384            col("x").gt_eq(lit("a")),
385            col("x").eq(lit("abc")),
386            col("x").not_eq(lit("a")),
387            col("x").between(lit("a"), lit("z")),
388            col("x").not_between(lit("a"), lit("z")),
389            Expr::BinaryExpr(BinaryExpr {
390                left: Box::new(col("x")),
391                op: Operator::IsDistinctFrom,
392                right: Box::new(lit(ScalarValue::Null)),
393            }),
394        ];
395
396        validate_unchanged_cases(&mut rewriter, unchanged_cases);
397    }
398
399    #[test]
400    fn test_column_single_value() {
401        let scalars = [
402            ScalarValue::Null,
403            ScalarValue::Int32(Some(1)),
404            ScalarValue::Boolean(Some(true)),
405            ScalarValue::Boolean(None),
406            ScalarValue::from("abc"),
407            ScalarValue::LargeUtf8(Some("def".to_string())),
408            ScalarValue::Date32(Some(18628)),
409            ScalarValue::Date32(None),
410            ScalarValue::Decimal128(Some(1000), 19, 2),
411        ];
412
413        for scalar in scalars {
414            let guarantees = vec![(col("x"), NullableInterval::from(scalar.clone()))];
415            let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
416
417            let output = col("x").rewrite(&mut rewriter).data().unwrap();
418            assert_eq!(output, Expr::Literal(scalar.clone(), None));
419        }
420    }
421
422    #[test]
423    fn test_in_list() {
424        let guarantees = vec![
425            // x ∈ [1, 10] (not null)
426            (
427                col("x"),
428                NullableInterval::NotNull {
429                    values: Interval::try_new(
430                        ScalarValue::Int32(Some(1)),
431                        ScalarValue::Int32(Some(10)),
432                    )
433                    .unwrap(),
434                },
435            ),
436        ];
437        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
438
439        // These cases should be simplified so the list doesn't contain any
440        // values the guarantee says are outside the range.
441        // (column_name, starting_list, negated, expected_list)
442        let cases = &[
443            // x IN (9, 11) => x IN (9)
444            ("x", vec![9, 11], false, vec![9]),
445            // x IN (10, 2) => x IN (10, 2)
446            ("x", vec![10, 2], false, vec![10, 2]),
447            // x NOT IN (9, 11) => x NOT IN (9)
448            ("x", vec![9, 11], true, vec![9]),
449            // x NOT IN (0, 22) => x NOT IN ()
450            ("x", vec![0, 22], true, vec![]),
451        ];
452
453        for (column_name, starting_list, negated, expected_list) in cases {
454            let expr = col(*column_name).in_list(
455                starting_list
456                    .iter()
457                    .map(|v| lit(ScalarValue::Int32(Some(*v))))
458                    .collect(),
459                *negated,
460            );
461            let output = expr.clone().rewrite(&mut rewriter).data().unwrap();
462            let expected_list = expected_list
463                .iter()
464                .map(|v| lit(ScalarValue::Int32(Some(*v))))
465                .collect();
466            assert_eq!(
467                output,
468                Expr::InList(InList {
469                    expr: Box::new(col(*column_name)),
470                    list: expected_list,
471                    negated: *negated,
472                })
473            );
474        }
475    }
476}