datafusion_expr/expr_rewriter/
guarantees.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Rewrite expressions based on external expression value range guarantees.
19
20use crate::{Between, BinaryExpr, Expr, expr::InList, lit};
21use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
22use datafusion_common::{DataFusionError, HashMap, Result, ScalarValue};
23use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval};
24use std::borrow::Cow;
25
26/// Rewrite expressions to incorporate guarantees.
27///
28/// See [`rewrite_with_guarantees`] for more information
29pub struct GuaranteeRewriter<'a> {
30    guarantees: HashMap<&'a Expr, &'a NullableInterval>,
31}
32
33impl<'a> GuaranteeRewriter<'a> {
34    pub fn new(
35        guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
36    ) -> Self {
37        Self {
38            guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(),
39        }
40    }
41}
42
43/// Rewrite expressions to incorporate guarantees.
44///
45/// Guarantees are a mapping from an expression (which currently is always a
46/// column reference) to a [NullableInterval] that represents the known possible
47/// values of the expression.
48///
49/// Rewriting expressions using this type of guarantee can make the work of other expression
50/// simplifications, like const evaluation, easier.
51///
52/// For example, if we know that a column is not null and has values in the
53/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`.
54///
55/// If the set of guarantees will be used to rewrite more than one expression, consider using
56/// [rewrite_with_guarantees_map] instead.
57///
58/// A full example of using this rewrite rule can be found in
59/// [`ExprSimplifier::with_guarantees()`](https://docs.rs/datafusion/latest/datafusion/optimizer/simplify_expressions/struct.ExprSimplifier.html#method.with_guarantees).
60pub fn rewrite_with_guarantees<'a>(
61    expr: Expr,
62    guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
63) -> Result<Transformed<Expr>> {
64    let guarantees_map: HashMap<&Expr, &NullableInterval> =
65        guarantees.into_iter().map(|(k, v)| (k, v)).collect();
66    rewrite_with_guarantees_map(expr, &guarantees_map)
67}
68
69/// Rewrite expressions to incorporate guarantees.
70///
71/// Guarantees are a mapping from an expression (which currently is always a
72/// column reference) to a [NullableInterval]. The interval represents the known
73/// possible values of the column.
74///
75/// For example, if we know that a column is not null and has values in the
76/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`.
77pub fn rewrite_with_guarantees_map<'a>(
78    expr: Expr,
79    guarantees: &'a HashMap<&'a Expr, &'a NullableInterval>,
80) -> Result<Transformed<Expr>> {
81    if guarantees.is_empty() {
82        return Ok(Transformed::no(expr));
83    }
84
85    expr.transform_up(|e| rewrite_expr(e, guarantees))
86}
87
88impl TreeNodeRewriter for GuaranteeRewriter<'_> {
89    type Node = Expr;
90
91    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
92        if self.guarantees.is_empty() {
93            return Ok(Transformed::no(expr));
94        }
95
96        rewrite_expr(expr, &self.guarantees)
97    }
98}
99
100fn rewrite_expr(
101    expr: Expr,
102    guarantees: &HashMap<&Expr, &NullableInterval>,
103) -> Result<Transformed<Expr>> {
104    // If an expression collapses to a single value, replace it with a literal
105    if let Some(interval) = guarantees.get(&expr)
106        && let Some(value) = interval.single_value()
107    {
108        return Ok(Transformed::yes(lit(value)));
109    }
110
111    let result = match expr {
112        Expr::IsNull(inner) => match guarantees.get(inner.as_ref()) {
113            Some(NullableInterval::Null { .. }) => Transformed::yes(lit(true)),
114            Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(false)),
115            _ => Transformed::no(Expr::IsNull(inner)),
116        },
117        Expr::IsNotNull(inner) => match guarantees.get(inner.as_ref()) {
118            Some(NullableInterval::Null { .. }) => Transformed::yes(lit(false)),
119            Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(true)),
120            _ => Transformed::no(Expr::IsNotNull(inner)),
121        },
122        Expr::Between(b) => rewrite_between(b, guarantees)?,
123        Expr::BinaryExpr(b) => rewrite_binary_expr(b, guarantees)?,
124        Expr::InList(i) => rewrite_inlist(i, guarantees)?,
125        expr => Transformed::no(expr),
126    };
127    Ok(result)
128}
129
130fn rewrite_between(
131    between: Between,
132    guarantees: &HashMap<&Expr, &NullableInterval>,
133) -> Result<Transformed<Expr>> {
134    let (Some(expr_interval), Expr::Literal(low, _), Expr::Literal(high, _)) = (
135        guarantees.get(between.expr.as_ref()),
136        between.low.as_ref(),
137        between.high.as_ref(),
138    ) else {
139        return Ok(Transformed::no(Expr::Between(between)));
140    };
141
142    // Ensure that, if low or high are null, their type matches the other bound
143    let low = ensure_typed_null(low, high)?;
144    let high = ensure_typed_null(high, &low)?;
145
146    let Ok(between_interval) = Interval::try_new(low, high) else {
147        // If we can't create an interval from the literals, be conservative and simply leave
148        // the expression unmodified.
149        return Ok(Transformed::no(Expr::Between(between)));
150    };
151
152    if between_interval.lower().is_null() && between_interval.upper().is_null() {
153        return Ok(Transformed::yes(lit(between_interval.lower().clone())));
154    }
155
156    let expr_interval = match expr_interval {
157        NullableInterval::Null { datatype } => {
158            // Value is guaranteed to be null, so we can simplify to null.
159            return Ok(Transformed::yes(lit(
160                ScalarValue::try_new_null(datatype).unwrap_or(ScalarValue::Null)
161            )));
162        }
163        NullableInterval::MaybeNull { .. } => {
164            // Value may or may not be null, so we can't simplify the expression.
165            return Ok(Transformed::no(Expr::Between(between)));
166        }
167        NullableInterval::NotNull { values } => values,
168    };
169
170    let result = if between_interval.lower().is_null() {
171        // <expr> (NOT) BETWEEN NULL AND <high>
172        let upper_bound = Interval::from(between_interval.upper().clone());
173        if expr_interval.gt(&upper_bound)?.eq(&Interval::TRUE) {
174            // if <expr> > high, then certainly false
175            Transformed::yes(lit(between.negated))
176        } else if expr_interval.lt_eq(&upper_bound)?.eq(&Interval::TRUE) {
177            // if <expr> <= high, then certainly null
178            Transformed::yes(lit(ScalarValue::try_new_null(&expr_interval.data_type())
179                .unwrap_or(ScalarValue::Null)))
180        } else {
181            // otherwise unknown
182            Transformed::no(Expr::Between(between))
183        }
184    } else if between_interval.upper().is_null() {
185        // <expr> (NOT) BETWEEN <low> AND NULL
186        let lower_bound = Interval::from(between_interval.lower().clone());
187        if expr_interval.lt(&lower_bound)?.eq(&Interval::TRUE) {
188            // if <expr> < low, then certainly false
189            Transformed::yes(lit(between.negated))
190        } else if expr_interval.gt_eq(&lower_bound)?.eq(&Interval::TRUE) {
191            // if <expr> >= low, then certainly null
192            Transformed::yes(lit(ScalarValue::try_new_null(&expr_interval.data_type())
193                .unwrap_or(ScalarValue::Null)))
194        } else {
195            // otherwise unknown
196            Transformed::no(Expr::Between(between))
197        }
198    } else {
199        let contains = between_interval.contains(expr_interval)?;
200        if contains.eq(&Interval::TRUE) {
201            Transformed::yes(lit(!between.negated))
202        } else if contains.eq(&Interval::FALSE) {
203            Transformed::yes(lit(between.negated))
204        } else {
205            Transformed::no(Expr::Between(between))
206        }
207    };
208    Ok(result)
209}
210
211fn ensure_typed_null(
212    value: &ScalarValue,
213    other: &ScalarValue,
214) -> Result<ScalarValue, DataFusionError> {
215    Ok(
216        if value.data_type().is_null() && !other.data_type().is_null() {
217            ScalarValue::try_new_null(&other.data_type())?
218        } else {
219            value.clone()
220        },
221    )
222}
223
224fn rewrite_binary_expr(
225    binary: BinaryExpr,
226    guarantees: &HashMap<&Expr, &NullableInterval>,
227) -> Result<Transformed<Expr>, DataFusionError> {
228    // The left or right side of expression might either have a guarantee
229    // or be a literal. Either way, we can resolve them to a NullableInterval.
230    let left_interval = guarantees
231        .get(binary.left.as_ref())
232        .map(|interval| Cow::Borrowed(*interval))
233        .or_else(|| {
234            if let Expr::Literal(value, _) = binary.left.as_ref() {
235                Some(Cow::Owned(value.clone().into()))
236            } else {
237                None
238            }
239        });
240    let right_interval = guarantees
241        .get(binary.right.as_ref())
242        .map(|interval| Cow::Borrowed(*interval))
243        .or_else(|| {
244            if let Expr::Literal(value, _) = binary.right.as_ref() {
245                Some(Cow::Owned(value.clone().into()))
246            } else {
247                None
248            }
249        });
250
251    if let (Some(left_interval), Some(right_interval)) = (left_interval, right_interval) {
252        let result = left_interval.apply_operator(&binary.op, right_interval.as_ref())?;
253        if result.is_certainly_true() {
254            return Ok(Transformed::yes(lit(true)));
255        } else if result.is_certainly_false() {
256            return Ok(Transformed::yes(lit(false)));
257        }
258    }
259    Ok(Transformed::no(Expr::BinaryExpr(binary)))
260}
261
262fn rewrite_inlist(
263    inlist: InList,
264    guarantees: &HashMap<&Expr, &NullableInterval>,
265) -> Result<Transformed<Expr>, DataFusionError> {
266    let Some(interval) = guarantees.get(inlist.expr.as_ref()) else {
267        return Ok(Transformed::no(Expr::InList(inlist)));
268    };
269
270    let InList {
271        expr,
272        list,
273        negated,
274    } = inlist;
275
276    // Can remove items from the list that don't match the guarantee
277    let list: Vec<Expr> = list
278        .into_iter()
279        .filter_map(|expr| {
280            if let Expr::Literal(item, _) = &expr {
281                match interval.contains(NullableInterval::from(item.clone())) {
282                    // If we know for certain the value isn't in the column's interval,
283                    // we can skip checking it.
284                    Ok(interval) if interval.is_certainly_false() => None,
285                    Ok(_) => Some(Ok(expr)),
286                    Err(e) => Some(Err(e)),
287                }
288            } else {
289                Some(Ok(expr))
290            }
291        })
292        .collect::<Result<_, DataFusionError>>()?;
293
294    Ok(Transformed::yes(Expr::InList(InList {
295        expr,
296        list,
297        negated,
298    })))
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Operator, col};
    use datafusion_common::ScalarValue;
    use datafusion_common::tree_node::TransformedResult;

    #[test]
    fn test_not_null_guarantee() {
        // IsNull / IsNotNull can be rewritten to true / false, and BETWEEN to
        // true / false / NULL, when the column is known to be non-null.
        let guarantees = [
            // Note: the `NullableInterval::Null` case is handled by the
            // test_column_single_value test, since it's a special case of a
            // column with a single value.
            (
                col("x"),
                NullableInterval::NotNull {
                    values: Interval::make(Some(1), Some(3)).unwrap(),
                },
            ),
        ];

        let is_null_cases = vec![
            // x IS NULL => guaranteed false
            (col("x").is_null(), Some(lit(false))),
            // x IS NOT NULL => guaranteed true
            (col("x").is_not_null(), Some(lit(true))),
            // [1, 3] BETWEEN 0 AND 10 => guaranteed true
            (col("x").between(lit(0), lit(10)), Some(lit(true))),
            // x BETWEEN 1 AND -2 => unchanged (actually always false, but the
            // bounds do not form a valid interval)
            (col("x").between(lit(1), lit(-2)), None),
            // [1, 3] BETWEEN NULL AND 0 => guaranteed false
            (
                col("x").between(lit(ScalarValue::Null), lit(0)),
                Some(lit(false)),
            ),
            // [1, 3] BETWEEN NULL AND 1 => unknown
            (col("x").between(lit(ScalarValue::Null), lit(1)), None),
            // [1, 3] BETWEEN NULL AND 2 => unknown
            (col("x").between(lit(ScalarValue::Null), lit(2)), None),
            // [1, 3] BETWEEN NULL AND 3 => guaranteed NULL
            (
                col("x").between(lit(ScalarValue::Null), lit(3)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] BETWEEN NULL AND 4 => guaranteed NULL
            (
                col("x").between(lit(ScalarValue::Null), lit(4)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] BETWEEN 0 AND NULL => guaranteed NULL
            (
                col("x").between(lit(0), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] BETWEEN 1 AND NULL => guaranteed NULL
            (
                col("x").between(lit(1), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] BETWEEN 2 AND NULL => unknown
            (col("x").between(lit(2), lit(ScalarValue::Null)), None),
            // [1, 3] BETWEEN 3 AND NULL => unknown
            (col("x").between(lit(3), lit(ScalarValue::Null)), None),
            // [1, 3] BETWEEN 4 AND NULL => guaranteed false
            (
                col("x").between(lit(4), lit(ScalarValue::Null)),
                Some(lit(false)),
            ),
            // [1, 3] NOT BETWEEN NULL AND 0 => guaranteed true
            (
                col("x").not_between(lit(ScalarValue::Null), lit(0)),
                Some(lit(true)),
            ),
            // [1, 3] NOT BETWEEN NULL AND 1 => unknown
            (col("x").not_between(lit(ScalarValue::Null), lit(1)), None),
            // [1, 3] NOT BETWEEN NULL AND 2 => unknown
            (col("x").not_between(lit(ScalarValue::Null), lit(2)), None),
            // [1, 3] NOT BETWEEN NULL AND 3 => guaranteed NULL
            (
                col("x").not_between(lit(ScalarValue::Null), lit(3)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] NOT BETWEEN NULL AND 4 => guaranteed NULL
            (
                col("x").not_between(lit(ScalarValue::Null), lit(4)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] NOT BETWEEN 0 AND NULL => guaranteed NULL
            (
                col("x").not_between(lit(0), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] NOT BETWEEN 1 AND NULL => guaranteed NULL
            (
                col("x").not_between(lit(1), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] NOT BETWEEN 2 AND NULL => unknown
            (col("x").not_between(lit(2), lit(ScalarValue::Null)), None),
            // [1, 3] NOT BETWEEN 3 AND NULL => unknown
            (col("x").not_between(lit(3), lit(ScalarValue::Null)), None),
            // [1, 3] NOT BETWEEN 4 AND NULL => guaranteed true
            (
                col("x").not_between(lit(4), lit(ScalarValue::Null)),
                Some(lit(true)),
            ),
        ];

        for case in is_null_cases {
            let output = rewrite_with_guarantees(case.0.clone(), guarantees.iter())
                .data()
                .unwrap();
            let expected = match case.1 {
                None => case.0.clone(),
                Some(expected) => expected,
            };

            assert_eq!(output, expected, "Failed for {}", case.0);
        }
    }

    fn validate_simplified_cases<T>(
        guarantees: &[(Expr, NullableInterval)],
        cases: &[(Expr, T)],
    ) where
        ScalarValue: From<T>,
        T: Clone,
    {
        for (expr, expected_value) in cases {
            let output = rewrite_with_guarantees(expr.clone(), guarantees.iter())
                .data()
                .unwrap();
            let expected = lit(ScalarValue::from(expected_value.clone()));
            assert_eq!(
                output, expected,
                "{expr} simplified to {output}, but expected {expected}"
            );
        }
    }

    fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Expr]) {
        for expr in cases {
            let output = rewrite_with_guarantees(expr.clone(), guarantees.iter())
                .data()
                .unwrap();
            assert_eq!(
                &output, expr,
                "{expr} was simplified to {output}, but expected it to be unchanged"
            );
        }
    }

    #[test]
    fn test_inequalities_non_null_unbounded() {
        let guarantees = [
            // x ∈ [2021-01-01, ∞) (not null)
            (
                col("x"),
                NullableInterval::NotNull {
                    values: Interval::try_new(
                        ScalarValue::Date32(Some(18628)),
                        ScalarValue::Date32(None),
                    )
                    .unwrap(),
                },
            ),
        ];

        // (original_expr, expected_simplification)
        let simplified_cases = &[
            (col("x").lt(lit(ScalarValue::Date32(Some(18628)))), false),
            (col("x").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false),
            (col("x").gt(lit(ScalarValue::Date32(Some(18627)))), true),
            (col("x").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true),
            (col("x").eq(lit(ScalarValue::Date32(Some(17000)))), false),
            (col("x").not_eq(lit(ScalarValue::Date32(Some(17000)))), true),
            (
                col("x").between(
                    lit(ScalarValue::Date32(Some(16000))),
                    lit(ScalarValue::Date32(Some(17000))),
                ),
                false,
            ),
            (
                col("x").not_between(
                    lit(ScalarValue::Date32(Some(16000))),
                    lit(ScalarValue::Date32(Some(17000))),
                ),
                true,
            ),
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsDistinctFrom,
                    right: Box::new(lit(ScalarValue::Null)),
                }),
                true,
            ),
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsDistinctFrom,
                    right: Box::new(lit(ScalarValue::Date32(Some(17000)))),
                }),
                true,
            ),
        ];

        validate_simplified_cases(&guarantees, simplified_cases);

        let unchanged_cases = &[
            col("x").lt(lit(ScalarValue::Date32(Some(19000)))),
            col("x").lt_eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").gt(lit(ScalarValue::Date32(Some(19000)))),
            col("x").gt_eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").not_eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").between(
                lit(ScalarValue::Date32(Some(18000))),
                lit(ScalarValue::Date32(Some(19000))),
            ),
            col("x").not_between(
                lit(ScalarValue::Date32(Some(18000))),
                lit(ScalarValue::Date32(Some(19000))),
            ),
        ];

        validate_unchanged_cases(&guarantees, unchanged_cases);
    }

    #[test]
    fn test_inequalities_maybe_null() {
        let guarantees = [
            // x ∈ ["abc", "def"] (maybe null)
            (
                col("x"),
                NullableInterval::MaybeNull {
                    values: Interval::try_new(
                        ScalarValue::from("abc"),
                        ScalarValue::from("def"),
                    )
                    .unwrap(),
                },
            ),
        ];

        // (original_expr, expected_simplification)
        let simplified_cases = &[
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsDistinctFrom,
                    right: Box::new(lit("z")),
                }),
                true,
            ),
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsNotDistinctFrom,
                    right: Box::new(lit("z")),
                }),
                false,
            ),
        ];

        validate_simplified_cases(&guarantees, simplified_cases);

        let unchanged_cases = &[
            col("x").lt(lit("z")),
            col("x").lt_eq(lit("z")),
            col("x").gt(lit("a")),
            col("x").gt_eq(lit("a")),
            col("x").eq(lit("abc")),
            col("x").not_eq(lit("a")),
            col("x").between(lit("a"), lit("z")),
            col("x").not_between(lit("a"), lit("z")),
            Expr::BinaryExpr(BinaryExpr {
                left: Box::new(col("x")),
                op: Operator::IsDistinctFrom,
                right: Box::new(lit(ScalarValue::Null)),
            }),
        ];

        validate_unchanged_cases(&guarantees, unchanged_cases);
    }

    #[test]
    fn test_column_single_value() {
        let scalars = [
            ScalarValue::Null,
            ScalarValue::Int32(Some(1)),
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Boolean(None),
            ScalarValue::from("abc"),
            ScalarValue::LargeUtf8(Some("def".to_string())),
            ScalarValue::Date32(Some(18628)),
            ScalarValue::Date32(None),
            ScalarValue::Decimal128(Some(1000), 19, 2),
        ];

        for scalar in scalars {
            let guarantees = [(col("x"), NullableInterval::from(scalar.clone()))];

            let output = rewrite_with_guarantees(col("x"), guarantees.iter())
                .data()
                .unwrap();
            assert_eq!(output, Expr::Literal(scalar.clone(), None));
        }
    }

    #[test]
    fn test_in_list() {
        let guarantees = [
            // x ∈ [1, 10] (not null)
            (
                col("x"),
                NullableInterval::NotNull {
                    values: Interval::try_new(
                        ScalarValue::Int32(Some(1)),
                        ScalarValue::Int32(Some(10)),
                    )
                    .unwrap(),
                },
            ),
        ];

        // These cases should be simplified so the list doesn't contain any
        // values the guarantee says are outside the range.
        // (column_name, starting_list, negated, expected_list)
        let cases = &[
            // x IN (9, 11) => x IN (9)
            ("x", vec![9, 11], false, vec![9]),
            // x IN (10, 2) => x IN (10, 2)
            ("x", vec![10, 2], false, vec![10, 2]),
            // x NOT IN (9, 11) => x NOT IN (9)
            ("x", vec![9, 11], true, vec![9]),
            // x NOT IN (0, 22) => x NOT IN ()
            ("x", vec![0, 22], true, vec![]),
        ];

        for (column_name, starting_list, negated, expected_list) in cases {
            let expr = col(*column_name).in_list(
                starting_list
                    .iter()
                    .map(|v| lit(ScalarValue::Int32(Some(*v))))
                    .collect(),
                *negated,
            );
            let output = rewrite_with_guarantees(expr.clone(), guarantees.iter())
                .data()
                .unwrap();
            let expected_list = expected_list
                .iter()
                .map(|v| lit(ScalarValue::Int32(Some(*v))))
                .collect();
            assert_eq!(
                output,
                Expr::InList(InList {
                    expr: Box::new(col(*column_name)),
                    list: expected_list,
                    negated: *negated,
                })
            );
        }
    }
}