Skip to main content

datafusion_expr/expr_rewriter/
guarantees.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Rewrite expressions based on external expression value range guarantees.

20use crate::{Between, BinaryExpr, Expr, expr::InList, lit};
21use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
22use datafusion_common::{DataFusionError, HashMap, Result, ScalarValue};
23use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval};
24use std::borrow::Cow;
25
26/// Rewrite expressions to incorporate guarantees.
27///
28/// See [`rewrite_with_guarantees`] for more information
29pub struct GuaranteeRewriter<'a> {
30    guarantees: HashMap<&'a Expr, &'a NullableInterval>,
31}
32
33impl<'a> GuaranteeRewriter<'a> {
34    pub fn new(
35        guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
36    ) -> Self {
37        Self {
38            guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(),
39        }
40    }
41}
42
43/// Rewrite expressions to incorporate guarantees.
44///
45/// Guarantees are a mapping from an expression (which currently is always a
46/// column reference) to a [NullableInterval] that represents the known possible
47/// values of the expression.
48///
49/// Rewriting expressions using this type of guarantee can make the work of other expression
50/// simplifications, like const evaluation, easier.
51///
52/// For example, if we know that a column is not null and has values in the
53/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`.
54///
55/// If the set of guarantees will be used to rewrite more than one expression, consider using
56/// [rewrite_with_guarantees_map] instead.
57///
58/// A full example of using this rewrite rule can be found in
59/// [`ExprSimplifier::with_guarantees()`](https://docs.rs/datafusion/latest/datafusion/optimizer/simplify_expressions/struct.ExprSimplifier.html#method.with_guarantees).
60pub fn rewrite_with_guarantees<'a>(
61    expr: Expr,
62    guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
63) -> Result<Transformed<Expr>> {
64    let guarantees_map: HashMap<&Expr, &NullableInterval> =
65        guarantees.into_iter().map(|(k, v)| (k, v)).collect();
66    rewrite_with_guarantees_map(expr, &guarantees_map)
67}
68
69/// Rewrite expressions to incorporate guarantees.
70///
71/// Guarantees are a mapping from an expression (which currently is always a
72/// column reference) to a [NullableInterval]. The interval represents the known
73/// possible values of the column.
74///
75/// For example, if we know that a column is not null and has values in the
76/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`.
77pub fn rewrite_with_guarantees_map<'a>(
78    expr: Expr,
79    guarantees: &'a HashMap<&'a Expr, &'a NullableInterval>,
80) -> Result<Transformed<Expr>> {
81    if guarantees.is_empty() {
82        return Ok(Transformed::no(expr));
83    }
84
85    expr.transform_up(|e| rewrite_expr(e, guarantees))
86}
87
88impl TreeNodeRewriter for GuaranteeRewriter<'_> {
89    type Node = Expr;
90
91    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
92        if self.guarantees.is_empty() {
93            return Ok(Transformed::no(expr));
94        }
95
96        rewrite_expr(expr, &self.guarantees)
97    }
98}
99
100fn rewrite_expr(
101    expr: Expr,
102    guarantees: &HashMap<&Expr, &NullableInterval>,
103) -> Result<Transformed<Expr>> {
104    // If an expression collapses to a single value, replace it with a literal
105    if let Some(interval) = guarantees.get(&expr)
106        && let Some(value) = interval.single_value()
107    {
108        return Ok(Transformed::yes(lit(value)));
109    }
110
111    let result = match expr {
112        Expr::IsNull(inner) => match guarantees.get(inner.as_ref()) {
113            Some(NullableInterval::Null { .. }) => Transformed::yes(lit(true)),
114            Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(false)),
115            _ => Transformed::no(Expr::IsNull(inner)),
116        },
117        Expr::IsNotNull(inner) => match guarantees.get(inner.as_ref()) {
118            Some(NullableInterval::Null { .. }) => Transformed::yes(lit(false)),
119            Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(true)),
120            _ => Transformed::no(Expr::IsNotNull(inner)),
121        },
122        Expr::Between(b) => rewrite_between(b, guarantees)?,
123        Expr::BinaryExpr(b) => rewrite_binary_expr(b, guarantees)?,
124        Expr::InList(i) => rewrite_inlist(i, guarantees)?,
125        expr => Transformed::no(expr),
126    };
127    Ok(result)
128}
129
130fn rewrite_between(
131    between: Between,
132    guarantees: &HashMap<&Expr, &NullableInterval>,
133) -> Result<Transformed<Expr>> {
134    let (Some(expr_interval), Expr::Literal(low, _), Expr::Literal(high, _)) = (
135        guarantees.get(between.expr.as_ref()),
136        between.low.as_ref(),
137        between.high.as_ref(),
138    ) else {
139        return Ok(Transformed::no(Expr::Between(between)));
140    };
141
142    // Ensure that, if low or high are null, their type matches the other bound
143    let low = ensure_typed_null(low, high)?;
144    let high = ensure_typed_null(high, &low)?;
145
146    let Ok(between_interval) = Interval::try_new(low, high) else {
147        // If we can't create an interval from the literals, be conservative and simply leave
148        // the expression unmodified.
149        return Ok(Transformed::no(Expr::Between(between)));
150    };
151
152    if between_interval.lower().is_null() && between_interval.upper().is_null() {
153        return Ok(Transformed::yes(lit(between_interval.lower().clone())));
154    }
155
156    let expr_interval = match expr_interval {
157        NullableInterval::Null { datatype } => {
158            // Value is guaranteed to be null, so we can simplify to null.
159            return Ok(Transformed::yes(lit(
160                ScalarValue::try_new_null(datatype).unwrap_or(ScalarValue::Null)
161            )));
162        }
163        NullableInterval::MaybeNull { .. } => {
164            // Value may or may not be null, so we can't simplify the expression.
165            return Ok(Transformed::no(Expr::Between(between)));
166        }
167        NullableInterval::NotNull { values } => values,
168    };
169
170    let result = if between_interval.lower().is_null() {
171        // <expr> (NOT) BETWEEN NULL AND <high>
172        let upper_bound = Interval::from(between_interval.upper().clone());
173        if expr_interval.gt(&upper_bound)?.eq(&Interval::TRUE) {
174            // if <expr> > high, then certainly false
175            Transformed::yes(lit(between.negated))
176        } else if expr_interval.lt_eq(&upper_bound)?.eq(&Interval::TRUE) {
177            // if <expr> <= high, then certainly null
178            Transformed::yes(lit(ScalarValue::try_new_null(&expr_interval.data_type())
179                .unwrap_or(ScalarValue::Null)))
180        } else {
181            // otherwise unknown
182            Transformed::no(Expr::Between(between))
183        }
184    } else if between_interval.upper().is_null() {
185        // <expr> (NOT) BETWEEN <low> AND NULL
186        let lower_bound = Interval::from(between_interval.lower().clone());
187        if expr_interval.lt(&lower_bound)?.eq(&Interval::TRUE) {
188            // if <expr> < low, then certainly false
189            Transformed::yes(lit(between.negated))
190        } else if expr_interval.gt_eq(&lower_bound)?.eq(&Interval::TRUE) {
191            // if <expr> >= low, then certainly null
192            Transformed::yes(lit(ScalarValue::try_new_null(&expr_interval.data_type())
193                .unwrap_or(ScalarValue::Null)))
194        } else {
195            // otherwise unknown
196            Transformed::no(Expr::Between(between))
197        }
198    } else {
199        let contains = between_interval.contains(expr_interval)?;
200        if contains.eq(&Interval::TRUE) {
201            Transformed::yes(lit(!between.negated))
202        } else if contains.eq(&Interval::FALSE) {
203            Transformed::yes(lit(between.negated))
204        } else {
205            Transformed::no(Expr::Between(between))
206        }
207    };
208    Ok(result)
209}
210
211fn ensure_typed_null(
212    value: &ScalarValue,
213    other: &ScalarValue,
214) -> Result<ScalarValue, DataFusionError> {
215    Ok(
216        if value.data_type().is_null() && !other.data_type().is_null() {
217            ScalarValue::try_new_null(&other.data_type())?
218        } else {
219            value.clone()
220        },
221    )
222}
223
224fn rewrite_binary_expr(
225    binary: BinaryExpr,
226    guarantees: &HashMap<&Expr, &NullableInterval>,
227) -> Result<Transformed<Expr>, DataFusionError> {
228    // The left or right side of expression might either have a guarantee
229    // or be a literal. Either way, we can resolve them to a NullableInterval.
230    let left_interval = guarantees
231        .get(binary.left.as_ref())
232        .map(|interval| Cow::Borrowed(*interval))
233        .or_else(|| {
234            if let Expr::Literal(value, _) = binary.left.as_ref() {
235                Some(Cow::Owned(value.clone().into()))
236            } else {
237                None
238            }
239        });
240    let right_interval = guarantees
241        .get(binary.right.as_ref())
242        .map(|interval| Cow::Borrowed(*interval))
243        .or_else(|| {
244            if let Expr::Literal(value, _) = binary.right.as_ref() {
245                Some(Cow::Owned(value.clone().into()))
246            } else {
247                None
248            }
249        });
250
251    if let (Some(left_interval), Some(right_interval)) = (left_interval, right_interval) {
252        let result = left_interval.apply_operator(&binary.op, right_interval.as_ref())?;
253        if result.is_certainly_true() {
254            return Ok(Transformed::yes(lit(true)));
255        } else if result.is_certainly_false() {
256            return Ok(Transformed::yes(lit(false)));
257        }
258    }
259    Ok(Transformed::no(Expr::BinaryExpr(binary)))
260}
261
262fn rewrite_inlist(
263    inlist: InList,
264    guarantees: &HashMap<&Expr, &NullableInterval>,
265) -> Result<Transformed<Expr>, DataFusionError> {
266    let Some(interval) = guarantees.get(inlist.expr.as_ref()) else {
267        return Ok(Transformed::no(Expr::InList(inlist)));
268    };
269
270    let InList {
271        expr,
272        list,
273        negated,
274    } = inlist;
275
276    // Can remove items from the list that don't match the guarantee
277    let list: Vec<Expr> = list
278        .into_iter()
279        .filter_map(|expr| {
280            if let Expr::Literal(item, _) = &expr {
281                match interval.contains(NullableInterval::from(item.clone())) {
282                    // If we know for certain the value isn't in the column's interval,
283                    // we can skip checking it.
284                    Ok(interval) if interval.is_certainly_false() => None,
285                    Ok(_) => Some(Ok(expr)),
286                    Err(e) => Some(Err(e)),
287                }
288            } else {
289                Some(Ok(expr))
290            }
291        })
292        .collect::<Result<_, DataFusionError>>()?;
293
294    Ok(Transformed::yes(Expr::InList(InList {
295        expr,
296        list,
297        negated,
298    })))
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Operator, col};
    use datafusion_common::tree_node::TransformedResult;

    /// `IS [NOT] NULL` and `[NOT] BETWEEN` rewrites for a non-null column.
    #[test]
    fn test_not_null_guarantee() {
        // IsNull / IsNotNull can be rewritten to true / false
        let guarantees = [
            // Note: the `NullableInterval::Null` case is handled by the
            // test_column_single_value test, since it's a special case of a
            // column with a single value.
            (
                col("x"),
                NullableInterval::NotNull {
                    values: Interval::make(Some(1), Some(3)).unwrap(),
                },
            ),
        ];

        // (original_expr, expected_rewrite); `None` means "left unchanged"
        let is_null_cases = vec![
            // x IS NULL => guaranteed false
            (col("x").is_null(), Some(lit(false))),
            // x IS NOT NULL => guaranteed true
            (col("x").is_not_null(), Some(lit(true))),
            // [1, 3] BETWEEN 0 AND 10 => guaranteed true
            (col("x").between(lit(0), lit(10)), Some(lit(true))),
            // x BETWEEN 1 AND -2 => unknown to the rewriter, which leaves
            // reversed bounds alone (actually guaranteed false)
            (col("x").between(lit(1), lit(-2)), None),
            // [1, 3] BETWEEN NULL AND 0 => guaranteed false
            (
                col("x").between(lit(ScalarValue::Null), lit(0)),
                Some(lit(false)),
            ),
            // [1, 3] BETWEEN NULL AND 1 => unknown
            (col("x").between(lit(ScalarValue::Null), lit(1)), None),
            // [1, 3] BETWEEN NULL AND 2 => unknown
            (col("x").between(lit(ScalarValue::Null), lit(2)), None),
            // [1, 3] BETWEEN NULL AND 3 => guaranteed NULL
            (
                col("x").between(lit(ScalarValue::Null), lit(3)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] BETWEEN NULL AND 4 => guaranteed NULL
            (
                col("x").between(lit(ScalarValue::Null), lit(4)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] BETWEEN 0 AND NULL => guaranteed NULL
            (
                col("x").between(lit(0), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] BETWEEN 1 AND NULL => guaranteed NULL
            (
                col("x").between(lit(1), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] BETWEEN 2 AND NULL => unknown
            (col("x").between(lit(2), lit(ScalarValue::Null)), None),
            // [1, 3] BETWEEN 3 AND NULL => unknown
            (col("x").between(lit(3), lit(ScalarValue::Null)), None),
            // [1, 3] BETWEEN 4 AND NULL => guaranteed false
            (
                col("x").between(lit(4), lit(ScalarValue::Null)),
                Some(lit(false)),
            ),
            // [1, 3] NOT BETWEEN NULL AND 0 => guaranteed true
            (
                col("x").not_between(lit(ScalarValue::Null), lit(0)),
                Some(lit(true)),
            ),
            // [1, 3] NOT BETWEEN NULL AND 1 => unknown
            (col("x").not_between(lit(ScalarValue::Null), lit(1)), None),
            // [1, 3] NOT BETWEEN NULL AND 2 => unknown
            (col("x").not_between(lit(ScalarValue::Null), lit(2)), None),
            // [1, 3] NOT BETWEEN NULL AND 3 => guaranteed NULL
            (
                col("x").not_between(lit(ScalarValue::Null), lit(3)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] NOT BETWEEN NULL AND 4 => guaranteed NULL
            (
                col("x").not_between(lit(ScalarValue::Null), lit(4)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] NOT BETWEEN 0 AND NULL => guaranteed NULL
            (
                col("x").not_between(lit(0), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] NOT BETWEEN 1 AND NULL => guaranteed NULL
            (
                col("x").not_between(lit(1), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // [1, 3] NOT BETWEEN 2 AND NULL => unknown
            (col("x").not_between(lit(2), lit(ScalarValue::Null)), None),
            // [1, 3] NOT BETWEEN 3 AND NULL => unknown
            (col("x").not_between(lit(3), lit(ScalarValue::Null)), None),
            // [1, 3] NOT BETWEEN 4 AND NULL => guaranteed true
            (
                col("x").not_between(lit(4), lit(ScalarValue::Null)),
                Some(lit(true)),
            ),
        ];

        for case in is_null_cases {
            let output = rewrite_with_guarantees(case.0.clone(), guarantees.iter())
                .data()
                .unwrap();
            let expected = match case.1 {
                None => case.0.clone(),
                Some(expected) => expected,
            };

            assert_eq!(output, expected, "Failed for {}", case.0);
        }
    }

    /// Asserts that every expression in `cases` is rewritten to the literal
    /// holding its paired expected value.
    fn validate_simplified_cases<T>(
        guarantees: &[(Expr, NullableInterval)],
        cases: &[(Expr, T)],
    ) where
        ScalarValue: From<T>,
        T: Clone,
    {
        for (expr, expected_value) in cases {
            let output = rewrite_with_guarantees(expr.clone(), guarantees.iter())
                .data()
                .unwrap();
            let expected = lit(ScalarValue::from(expected_value.clone()));
            assert_eq!(
                output, expected,
                "{expr} simplified to {output}, but expected {expected}"
            );
        }
    }

    /// Asserts that every expression in `cases` is left unchanged.
    fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Expr]) {
        for expr in cases {
            let output = rewrite_with_guarantees(expr.clone(), guarantees.iter())
                .data()
                .unwrap();
            assert_eq!(
                &output, expr,
                "{expr} was simplified to {output}, but expected it to be unchanged"
            );
        }
    }

    /// Comparisons against a non-null column bounded only from below.
    #[test]
    fn test_inequalities_non_null_unbounded() {
        let guarantees = [
            // x ∈ [2021-01-01, ∞) (not null)
            (
                col("x"),
                NullableInterval::NotNull {
                    values: Interval::try_new(
                        ScalarValue::Date32(Some(18628)),
                        ScalarValue::Date32(None),
                    )
                    .unwrap(),
                },
            ),
        ];

        // (original_expr, expected_simplification)
        let simplified_cases = &[
            (col("x").lt(lit(ScalarValue::Date32(Some(18628)))), false),
            (col("x").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false),
            (col("x").gt(lit(ScalarValue::Date32(Some(18627)))), true),
            (col("x").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true),
            (col("x").eq(lit(ScalarValue::Date32(Some(17000)))), false),
            (col("x").not_eq(lit(ScalarValue::Date32(Some(17000)))), true),
            (
                col("x").between(
                    lit(ScalarValue::Date32(Some(16000))),
                    lit(ScalarValue::Date32(Some(17000))),
                ),
                false,
            ),
            (
                col("x").not_between(
                    lit(ScalarValue::Date32(Some(16000))),
                    lit(ScalarValue::Date32(Some(17000))),
                ),
                true,
            ),
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsDistinctFrom,
                    right: Box::new(lit(ScalarValue::Null)),
                }),
                true,
            ),
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsDistinctFrom,
                    right: Box::new(lit(ScalarValue::Date32(Some(17000)))),
                }),
                true,
            ),
        ];

        validate_simplified_cases(&guarantees, simplified_cases);

        let unchanged_cases = &[
            col("x").lt(lit(ScalarValue::Date32(Some(19000)))),
            col("x").lt_eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").gt(lit(ScalarValue::Date32(Some(19000)))),
            col("x").gt_eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").not_eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").between(
                lit(ScalarValue::Date32(Some(18000))),
                lit(ScalarValue::Date32(Some(19000))),
            ),
            col("x").not_between(
                lit(ScalarValue::Date32(Some(18000))),
                lit(ScalarValue::Date32(Some(19000))),
            ),
        ];

        validate_unchanged_cases(&guarantees, unchanged_cases);
    }

    /// Comparisons against a column that may be NULL.
    #[test]
    fn test_inequalities_maybe_null() {
        let guarantees = [
            // x ∈ ["abc", "def"] (maybe null)
            (
                col("x"),
                NullableInterval::MaybeNull {
                    values: Interval::try_new(
                        ScalarValue::from("abc"),
                        ScalarValue::from("def"),
                    )
                    .unwrap(),
                },
            ),
        ];

        // (original_expr, expected_simplification)
        let simplified_cases = &[
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsDistinctFrom,
                    right: Box::new(lit("z")),
                }),
                true,
            ),
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsNotDistinctFrom,
                    right: Box::new(lit("z")),
                }),
                false,
            ),
        ];

        validate_simplified_cases(&guarantees, simplified_cases);

        let unchanged_cases = &[
            col("x").lt(lit("z")),
            col("x").lt_eq(lit("z")),
            col("x").gt(lit("a")),
            col("x").gt_eq(lit("a")),
            col("x").eq(lit("abc")),
            col("x").not_eq(lit("a")),
            col("x").between(lit("a"), lit("z")),
            col("x").not_between(lit("a"), lit("z")),
            Expr::BinaryExpr(BinaryExpr {
                left: Box::new(col("x")),
                op: Operator::IsDistinctFrom,
                right: Box::new(lit(ScalarValue::Null)),
            }),
        ];

        validate_unchanged_cases(&guarantees, unchanged_cases);
    }

    /// A column guaranteed to hold a single value is replaced by that literal.
    #[test]
    fn test_column_single_value() {
        let scalars = [
            ScalarValue::Null,
            ScalarValue::Int32(Some(1)),
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Boolean(None),
            ScalarValue::from("abc"),
            ScalarValue::LargeUtf8(Some("def".to_string())),
            ScalarValue::Date32(Some(18628)),
            ScalarValue::Date32(None),
            ScalarValue::Decimal128(Some(1000), 19, 2),
        ];

        for scalar in scalars {
            let guarantees = [(col("x"), NullableInterval::from(scalar.clone()))];

            let output = rewrite_with_guarantees(col("x"), guarantees.iter())
                .data()
                .unwrap();
            assert_eq!(output, Expr::Literal(scalar.clone(), None));
        }
    }

    /// `IN` lists drop literals outside the guaranteed range.
    #[test]
    fn test_in_list() {
        let guarantees = [
            // x ∈ [1, 10] (not null)
            (
                col("x"),
                NullableInterval::NotNull {
                    values: Interval::try_new(
                        ScalarValue::Int32(Some(1)),
                        ScalarValue::Int32(Some(10)),
                    )
                    .unwrap(),
                },
            ),
        ];

        // These cases should be simplified so the list doesn't contain any
        // values the guarantee says are outside the range.
        // (column_name, starting_list, negated, expected_list)
        let cases = &[
            // x IN (9, 11) => x IN (9)
            ("x", vec![9, 11], false, vec![9]),
            // x IN (10, 2) => x IN (10, 2)
            ("x", vec![10, 2], false, vec![10, 2]),
            // x NOT IN (9, 11) => x NOT IN (9)
            ("x", vec![9, 11], true, vec![9]),
            // x NOT IN (0, 22) => x NOT IN ()
            ("x", vec![0, 22], true, vec![]),
        ];

        for (column_name, starting_list, negated, expected_list) in cases {
            let expr = col(*column_name).in_list(
                starting_list
                    .iter()
                    .map(|v| lit(ScalarValue::Int32(Some(*v))))
                    .collect(),
                *negated,
            );
            let output = rewrite_with_guarantees(expr.clone(), guarantees.iter())
                .data()
                .unwrap();
            let expected_list = expected_list
                .iter()
                .map(|v| lit(ScalarValue::Int32(Some(*v))))
                .collect();
            assert_eq!(
                output,
                Expr::InList(InList {
                    expr: Box::new(col(*column_name)),
                    list: expected_list,
                    negated: *negated,
                })
            );
        }
    }
}