datafusion_optimizer/simplify_expressions/
expr_simplifier.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression simplification API
19
20use arrow::{
21    array::{new_null_array, AsArray},
22    datatypes::{DataType, Field, Schema},
23    record_batch::RecordBatch,
24};
25use std::borrow::Cow;
26use std::collections::HashSet;
27use std::ops::Not;
28use std::sync::Arc;
29
30use datafusion_common::{
31    cast::{as_large_list_array, as_list_array},
32    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
33};
34use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
35use datafusion_expr::{
36    and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like,
37    Operator, Volatility,
38};
39use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
40use datafusion_expr::{
41    expr::{InList, InSubquery},
42    utils::{iter_conjunction, iter_conjunction_owned},
43};
44use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast};
45use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
46
47use super::inlist_simplifier::ShortenInListSimplifier;
48use super::utils::*;
49use crate::analyzer::type_coercion::TypeCoercionRewriter;
50use crate::simplify_expressions::guarantees::GuaranteeRewriter;
51use crate::simplify_expressions::regex::simplify_regex_expr;
52use crate::simplify_expressions::unwrap_cast::{
53    is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
54    is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
55    unwrap_cast_in_comparison_for_binary,
56};
57use crate::simplify_expressions::SimplifyInfo;
58use datafusion_expr::expr::FieldMetadata;
59use datafusion_expr_common::casts::try_cast_literal_to_type;
60use indexmap::IndexSet;
61use regex::Regex;
62
63/// This structure handles API for expression simplification
64///
65/// Provides simplification information based on DFSchema and
66/// [`ExecutionProps`]. This is the default implementation used by DataFusion
67///
68/// For example:
69/// ```
70/// use arrow::datatypes::{Schema, Field, DataType};
71/// use datafusion_expr::{col, lit};
72/// use datafusion_common::{DataFusionError, ToDFSchema};
73/// use datafusion_expr::execution_props::ExecutionProps;
74/// use datafusion_expr::simplify::SimplifyContext;
75/// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
76///
77/// // Create the schema
78/// let schema = Schema::new(vec![
79///     Field::new("i", DataType::Int64, false),
80///   ])
81///   .to_dfschema_ref().unwrap();
82///
83/// // Create the simplifier
84/// let props = ExecutionProps::new();
85/// let context = SimplifyContext::new(&props)
86///    .with_schema(schema);
87/// let simplifier = ExprSimplifier::new(context);
88///
89/// // Use the simplifier
90///
91/// // b < 2 or (1 > 3)
92/// let expr = col("b").lt(lit(2)).or(lit(1).gt(lit(3)));
93///
94/// // b < 2
95/// let simplified = simplifier.simplify(expr).unwrap();
96/// assert_eq!(simplified, col("b").lt(lit(2)));
97/// ```
98pub struct ExprSimplifier<S> {
99    info: S,
100    /// Guarantees about the values of columns. This is provided by the user
101    /// in [ExprSimplifier::with_guarantees()].
102    guarantees: Vec<(Expr, NullableInterval)>,
103    /// Should expressions be canonicalized before simplification? Defaults to
104    /// true
105    canonicalize: bool,
106    /// Maximum number of simplifier cycles
107    max_simplifier_cycles: u32,
108}
109
/// Size threshold applied to `IN` lists.
///
/// NOTE(review): no use of this constant is visible in this part of the
/// file; presumably it bounds the length of an `IN` list that is still
/// expanded inline — confirm against its usages.
pub const THRESHOLD_INLINE_INLIST: usize = 3;

/// Number of simplifier cycles an [`ExprSimplifier`] runs at most unless a
/// different limit is set with [`ExprSimplifier::with_max_cycles`].
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3;
112
113impl<S: SimplifyInfo> ExprSimplifier<S> {
114    /// Create a new `ExprSimplifier` with the given `info` such as an
115    /// instance of [`SimplifyContext`]. See
116    /// [`simplify`](Self::simplify) for an example.
117    ///
118    /// [`SimplifyContext`]: datafusion_expr::simplify::SimplifyContext
119    pub fn new(info: S) -> Self {
120        Self {
121            info,
122            guarantees: vec![],
123            canonicalize: true,
124            max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES,
125        }
126    }
127
128    /// Simplifies this [`Expr`] as much as possible, evaluating
129    /// constants and applying algebraic simplifications.
130    ///
131    /// The types of the expression must match what operators expect,
132    /// or else an error may occur trying to evaluate. See
133    /// [`coerce`](Self::coerce) for a function to help.
134    ///
135    /// # Example:
136    ///
137    /// `b > 2 AND b > 2`
138    ///
139    /// can be written to
140    ///
141    /// `b > 2`
142    ///
143    /// ```
144    /// use arrow::datatypes::DataType;
145    /// use datafusion_expr::{col, lit, Expr};
146    /// use datafusion_common::Result;
147    /// use datafusion_expr::execution_props::ExecutionProps;
148    /// use datafusion_expr::simplify::SimplifyContext;
149    /// use datafusion_expr::simplify::SimplifyInfo;
150    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
151    /// use datafusion_common::DFSchema;
152    /// use std::sync::Arc;
153    ///
154    /// /// Simple implementation that provides `Simplifier` the information it needs
155    /// /// See SimplifyContext for a structure that does this.
156    /// #[derive(Default)]
157    /// struct Info {
158    ///   execution_props: ExecutionProps,
159    /// };
160    ///
161    /// impl SimplifyInfo for Info {
162    ///   fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
163    ///     Ok(false)
164    ///   }
165    ///   fn nullable(&self, expr: &Expr) -> Result<bool> {
166    ///     Ok(true)
167    ///   }
168    ///   fn execution_props(&self) -> &ExecutionProps {
169    ///     &self.execution_props
170    ///   }
171    ///   fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
172    ///     Ok(DataType::Int32)
173    ///   }
174    /// }
175    ///
176    /// // Create the simplifier
177    /// let simplifier = ExprSimplifier::new(Info::default());
178    ///
179    /// // b < 2
180    /// let b_lt_2 = col("b").gt(lit(2));
181    ///
182    /// // (b < 2) OR (b < 2)
183    /// let expr = b_lt_2.clone().or(b_lt_2.clone());
184    ///
185    /// // (b < 2) OR (b < 2) --> (b < 2)
186    /// let expr = simplifier.simplify(expr).unwrap();
187    /// assert_eq!(expr, b_lt_2);
188    /// ```
189    pub fn simplify(&self, expr: Expr) -> Result<Expr> {
190        Ok(self.simplify_with_cycle_count_transformed(expr)?.0.data)
191    }
192
193    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
194    /// constants and applying algebraic simplifications. Additionally returns a `u32`
195    /// representing the number of simplification cycles performed, which can be useful for testing
196    /// optimizations.
197    ///
198    /// See [Self::simplify] for details and usage examples.
199    ///
200    #[deprecated(
201        since = "48.0.0",
202        note = "Use `simplify_with_cycle_count_transformed` instead"
203    )]
204    #[allow(unused_mut)]
205    pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
206        let (transformed, cycle_count) =
207            self.simplify_with_cycle_count_transformed(expr)?;
208        Ok((transformed.data, cycle_count))
209    }
210
211    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
212    /// constants and applying algebraic simplifications. Additionally returns a `u32`
213    /// representing the number of simplification cycles performed, which can be useful for testing
214    /// optimizations.
215    ///
216    /// # Returns
217    ///
218    /// A tuple containing:
219    /// - The simplified expression wrapped in a `Transformed<Expr>` indicating if changes were made
220    /// - The number of simplification cycles that were performed
221    ///
222    /// See [Self::simplify] for details and usage examples.
223    ///
224    pub fn simplify_with_cycle_count_transformed(
225        &self,
226        mut expr: Expr,
227    ) -> Result<(Transformed<Expr>, u32)> {
228        let mut simplifier = Simplifier::new(&self.info);
229        let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
230        let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
231        let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);
232
233        if self.canonicalize {
234            expr = expr.rewrite(&mut Canonicalizer::new()).data()?
235        }
236
237        // Evaluating constants can enable new simplifications and
238        // simplifications can enable new constant evaluation
239        // see `Self::with_max_cycles`
240        let mut num_cycles = 0;
241        let mut has_transformed = false;
242        loop {
243            let Transformed {
244                data, transformed, ..
245            } = expr
246                .rewrite(&mut const_evaluator)?
247                .transform_data(|expr| expr.rewrite(&mut simplifier))?
248                .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
249            expr = data;
250            num_cycles += 1;
251            // Track if any transformation occurred
252            has_transformed = has_transformed || transformed;
253            if !transformed || num_cycles >= self.max_simplifier_cycles {
254                break;
255            }
256        }
257        // shorten inlist should be started after other inlist rules are applied
258        expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
259        Ok((
260            Transformed::new_transformed(expr, has_transformed),
261            num_cycles,
262        ))
263    }
264
265    /// Apply type coercion to an [`Expr`] so that it can be
266    /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr).
267    ///
268    /// See the [type coercion module](datafusion_expr::type_coercion)
269    /// documentation for more details on type coercion
270    pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> {
271        let mut expr_rewrite = TypeCoercionRewriter { schema };
272        expr.rewrite(&mut expr_rewrite).data()
273    }
274
275    /// Input guarantees about the values of columns.
276    ///
277    /// The guarantees can simplify expressions. For example, if a column `x` is
278    /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the
279    /// literal `true`.
280    ///
281    /// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`,
282    /// where the [Expr] is a column reference and the [NullableInterval]
283    /// is an interval representing the known possible values of that column.
284    ///
285    /// ```rust
286    /// use arrow::datatypes::{DataType, Field, Schema};
287    /// use datafusion_expr::{col, lit, Expr};
288    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
289    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
290    /// use datafusion_expr::execution_props::ExecutionProps;
291    /// use datafusion_expr::simplify::SimplifyContext;
292    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
293    ///
294    /// let schema = Schema::new(vec![
295    ///   Field::new("x", DataType::Int64, false),
296    ///   Field::new("y", DataType::UInt32, false),
297    ///   Field::new("z", DataType::Int64, false),
298    ///   ])
299    ///   .to_dfschema_ref().unwrap();
300    ///
301    /// // Create the simplifier
302    /// let props = ExecutionProps::new();
303    /// let context = SimplifyContext::new(&props)
304    ///    .with_schema(schema);
305    ///
306    /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5)
307    /// let expr_x = col("x").gt_eq(lit(3_i64));
308    /// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32));
309    /// let expr_z = col("z").gt(lit(5_i64));
310    /// let expr = expr_x.and(expr_y).and(expr_z.clone());
311    ///
312    /// let guarantees = vec![
313    ///    // x ∈ [3, 5]
314    ///    (
315    ///        col("x"),
316    ///        NullableInterval::NotNull {
317    ///            values: Interval::make(Some(3_i64), Some(5_i64)).unwrap()
318    ///        }
319    ///    ),
320    ///    // y = 3
321    ///    (col("y"), NullableInterval::from(ScalarValue::UInt32(Some(3)))),
322    /// ];
323    /// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees);
324    /// let output = simplifier.simplify(expr).unwrap();
325    /// // Expression becomes: true AND true AND (z > 5), which simplifies to
326    /// // z > 5.
327    /// assert_eq!(output, expr_z);
328    /// ```
329    pub fn with_guarantees(mut self, guarantees: Vec<(Expr, NullableInterval)>) -> Self {
330        self.guarantees = guarantees;
331        self
332    }
333
334    /// Should `Canonicalizer` be applied before simplification?
335    ///
336    /// If true (the default), the expression will be rewritten to canonical
337    /// form before simplification. This is useful to ensure that the simplifier
338    /// can apply all possible simplifications.
339    ///
340    /// Some expressions, such as those in some Joins, can not be canonicalized
341    /// without changing their meaning. In these cases, canonicalization should
342    /// be disabled.
343    ///
344    /// ```rust
345    /// use arrow::datatypes::{DataType, Field, Schema};
346    /// use datafusion_expr::{col, lit, Expr};
347    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
348    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
349    /// use datafusion_expr::execution_props::ExecutionProps;
350    /// use datafusion_expr::simplify::SimplifyContext;
351    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
352    ///
353    /// let schema = Schema::new(vec![
354    ///   Field::new("a", DataType::Int64, false),
355    ///   Field::new("b", DataType::Int64, false),
356    ///   Field::new("c", DataType::Int64, false),
357    ///   ])
358    ///   .to_dfschema_ref().unwrap();
359    ///
360    /// // Create the simplifier
361    /// let props = ExecutionProps::new();
362    /// let context = SimplifyContext::new(&props)
363    ///    .with_schema(schema);
364    /// let simplifier = ExprSimplifier::new(context);
365    ///
366    /// // Expression: a = c AND 1 = b
367    /// let expr = col("a").eq(col("c")).and(lit(1).eq(col("b")));
368    ///
369    /// // With canonicalization, the expression is rewritten to canonical form
370    /// // (though it is no simpler in this case):
371    /// let canonical = simplifier.simplify(expr.clone()).unwrap();
372    /// // Expression has been rewritten to: (c = a AND b = 1)
373    /// assert_eq!(canonical, col("c").eq(col("a")).and(col("b").eq(lit(1))));
374    ///
375    /// // If canonicalization is disabled, the expression is not changed
376    /// let non_canonicalized = simplifier
377    ///   .with_canonicalize(false)
378    ///   .simplify(expr.clone())
379    ///   .unwrap();
380    ///
381    /// assert_eq!(non_canonicalized, expr);
382    /// ```
383    pub fn with_canonicalize(mut self, canonicalize: bool) -> Self {
384        self.canonicalize = canonicalize;
385        self
386    }
387
388    /// Specifies the maximum number of simplification cycles to run.
389    ///
390    /// The simplifier can perform multiple passes of simplification. This is
391    /// because the output of one simplification step can allow more optimizations
392    /// in another simplification step. For example, constant evaluation can allow more
393    /// expression simplifications, and expression simplifications can allow more constant
394    /// evaluations.
395    ///
396    /// This method specifies the maximum number of allowed iteration cycles before the simplifier
397    /// returns an [Expr] output. However, it does not always perform the maximum number of cycles.
398    /// The simplifier will attempt to detect when an [Expr] is unchanged by all the simplification
399    /// passes, and return early. This avoids wasting time on unnecessary [Expr] tree traversals.
400    ///
401    /// If no maximum is specified, the value of [DEFAULT_MAX_SIMPLIFIER_CYCLES] is used
402    /// instead.
403    ///
404    /// ```rust
405    /// use arrow::datatypes::{DataType, Field, Schema};
406    /// use datafusion_expr::{col, lit, Expr};
407    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
408    /// use datafusion_expr::execution_props::ExecutionProps;
409    /// use datafusion_expr::simplify::SimplifyContext;
410    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
411    ///
412    /// let schema = Schema::new(vec![
413    ///   Field::new("a", DataType::Int64, false),
414    ///   ])
415    ///   .to_dfschema_ref().unwrap();
416    ///
417    /// // Create the simplifier
418    /// let props = ExecutionProps::new();
419    /// let context = SimplifyContext::new(&props)
420    ///    .with_schema(schema);
421    /// let simplifier = ExprSimplifier::new(context);
422    ///
423    /// // Expression: a IS NOT NULL
424    /// let expr = col("a").is_not_null();
425    ///
426    /// // When using default maximum cycles, 2 cycles will be performed.
427    /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count_transformed(expr.clone()).unwrap();
428    /// assert_eq!(simplified_expr.data, lit(true));
429    /// // 2 cycles were executed, but only 1 was needed
430    /// assert_eq!(count, 2);
431    ///
432    /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
433    /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count_transformed(expr.clone()).unwrap();
434    /// // Expression has been rewritten to: (c = a AND b = 1)
435    /// assert_eq!(simplified_expr.data, lit(true));
436    /// // Only 1 cycle was executed
437    /// assert_eq!(count, 1);
438    ///
439    /// ```
440    pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
441        self.max_simplifier_cycles = max_simplifier_cycles;
442        self
443    }
444}
445
/// Rewriter that brings `BinaryExpr`s into canonical form.
///
/// * `<literal> <op> <col>` becomes `<col> <op> <literal>`
/// * `<col1> <op> <col2>` is ordered so that the name of `col1` sorts higher
///   than `col2` (`a > b` would be canonicalized to `b < a`)
struct Canonicalizer {}

impl Canonicalizer {
    /// Creates the (stateless) canonicalizer.
    fn new() -> Self {
        Canonicalizer {}
    }
}
459
460impl TreeNodeRewriter for Canonicalizer {
461    type Node = Expr;
462
463    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
464        let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
465            return Ok(Transformed::no(expr));
466        };
467        match (left.as_ref(), right.as_ref(), op.swap()) {
468            // <col1> <op> <col2>
469            (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op))
470                if right_col > left_col =>
471            {
472                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
473                    left: right,
474                    op: swapped_op,
475                    right: left,
476                })))
477            }
478            // <literal> <op> <col>
479            (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => {
480                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
481                    left: right,
482                    op: swapped_op,
483                    right: left,
484                })))
485            }
486            _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr {
487                left,
488                op,
489                right,
490            }))),
491        }
492    }
493}
494
495#[allow(rustdoc::private_intra_doc_links)]
496/// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time.
497///
498/// Note it does not handle algebraic rewrites such as `(a or false)`
499/// --> `a`, which is handled by [`Simplifier`]
500struct ConstEvaluator<'a> {
501    /// `can_evaluate` is used during the depth-first-search of the
502    /// `Expr` tree to track if any siblings (or their descendants) were
503    /// non evaluatable (e.g. had a column reference or volatile
504    /// function)
505    ///
506    /// Specifically, `can_evaluate[N]` represents the state of
507    /// traversal when we are N levels deep in the tree, one entry for
508    /// this Expr and each of its parents.
509    ///
510    /// After visiting all siblings if `can_evaluate.top()` is true, that
511    /// means there were no non evaluatable siblings (or their
512    /// descendants) so this `Expr` can be evaluated
513    can_evaluate: Vec<bool>,
514
515    execution_props: &'a ExecutionProps,
516    input_schema: DFSchema,
517    input_batch: RecordBatch,
518}
519
520#[allow(dead_code)]
521/// The simplify result of ConstEvaluator
522enum ConstSimplifyResult {
523    // Expr was simplified and contains the new expression
524    Simplified(ScalarValue, Option<FieldMetadata>),
525    // Expr was not simplified and original value is returned
526    NotSimplified(ScalarValue, Option<FieldMetadata>),
527    // Evaluation encountered an error, contains the original expression
528    SimplifyRuntimeError(DataFusionError, Expr),
529}
530
531impl TreeNodeRewriter for ConstEvaluator<'_> {
532    type Node = Expr;
533
534    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
535        // Default to being able to evaluate this node
536        self.can_evaluate.push(true);
537
538        // if this expr is not ok to evaluate, mark entire parent
539        // stack as not ok (as all parents have at least one child or
540        // descendant that can not be evaluated
541
542        if !Self::can_evaluate(&expr) {
543            // walk back up stack, marking first parent that is not mutable
544            let parent_iter = self.can_evaluate.iter_mut().rev();
545            for p in parent_iter {
546                if !*p {
547                    // optimization: if we find an element on the
548                    // stack already marked, know all elements above are also marked
549                    break;
550                }
551                *p = false;
552            }
553        }
554
555        // NB: do not short circuit recursion even if we find a non
556        // evaluatable node (so we can fold other children, args to
557        // functions, etc.)
558        Ok(Transformed::no(expr))
559    }
560
561    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
562        match self.can_evaluate.pop() {
563            // Certain expressions such as `CASE` and `COALESCE` are short-circuiting
564            // and may not evaluate all their sub expressions. Thus, if
565            // any error is countered during simplification, return the original
566            // so that normal evaluation can occur
567            Some(true) => match self.evaluate_to_scalar(expr) {
568                ConstSimplifyResult::Simplified(s, m) => {
569                    Ok(Transformed::yes(Expr::Literal(s, m)))
570                }
571                ConstSimplifyResult::NotSimplified(s, m) => {
572                    Ok(Transformed::no(Expr::Literal(s, m)))
573                }
574                ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
575                    Ok(Transformed::yes(expr))
576                }
577            },
578            Some(false) => Ok(Transformed::no(expr)),
579            _ => internal_err!("Failed to pop can_evaluate"),
580        }
581    }
582}
583
584impl<'a> ConstEvaluator<'a> {
585    /// Create a new `ConstantEvaluator`. Session constants (such as
586    /// the time for `now()` are taken from the passed
587    /// `execution_props`.
588    pub fn try_new(execution_props: &'a ExecutionProps) -> Result<Self> {
589        // The dummy column name is unused and doesn't matter as only
590        // expressions without column references can be evaluated
591        static DUMMY_COL_NAME: &str = ".";
592        let schema = Arc::new(Schema::new(vec![Field::new(
593            DUMMY_COL_NAME,
594            DataType::Null,
595            true,
596        )]));
597        let input_schema = DFSchema::try_from(Arc::clone(&schema))?;
598        // Need a single "input" row to produce a single output row
599        let col = new_null_array(&DataType::Null, 1);
600        let input_batch = RecordBatch::try_new(schema, vec![col])?;
601
602        Ok(Self {
603            can_evaluate: vec![],
604            execution_props,
605            input_schema,
606            input_batch,
607        })
608    }
609
610    /// Can a function of the specified volatility be evaluated?
611    fn volatility_ok(volatility: Volatility) -> bool {
612        match volatility {
613            Volatility::Immutable => true,
614            // Values for functions such as now() are taken from ExecutionProps
615            Volatility::Stable => true,
616            Volatility::Volatile => false,
617        }
618    }
619
620    /// Can the expression be evaluated at plan time, (assuming all of
621    /// its children can also be evaluated)?
622    fn can_evaluate(expr: &Expr) -> bool {
623        // check for reasons we can't evaluate this node
624        //
625        // NOTE all expr types are listed here so when new ones are
626        // added they can be checked for their ability to be evaluated
627        // at plan time
628        match expr {
629            // TODO: remove the next line after `Expr::Wildcard` is removed
630            #[expect(deprecated)]
631            Expr::AggregateFunction { .. }
632            | Expr::ScalarVariable(_, _)
633            | Expr::Column(_)
634            | Expr::OuterReferenceColumn(_, _)
635            | Expr::Exists { .. }
636            | Expr::InSubquery(_)
637            | Expr::ScalarSubquery(_)
638            | Expr::WindowFunction { .. }
639            | Expr::GroupingSet(_)
640            | Expr::Wildcard { .. }
641            | Expr::Placeholder(_) => false,
642            Expr::ScalarFunction(ScalarFunction { func, .. }) => {
643                Self::volatility_ok(func.signature().volatility)
644            }
645            Expr::Literal(_, _)
646            | Expr::Alias(..)
647            | Expr::Unnest(_)
648            | Expr::BinaryExpr { .. }
649            | Expr::Not(_)
650            | Expr::IsNotNull(_)
651            | Expr::IsNull(_)
652            | Expr::IsTrue(_)
653            | Expr::IsFalse(_)
654            | Expr::IsUnknown(_)
655            | Expr::IsNotTrue(_)
656            | Expr::IsNotFalse(_)
657            | Expr::IsNotUnknown(_)
658            | Expr::Negative(_)
659            | Expr::Between { .. }
660            | Expr::Like { .. }
661            | Expr::SimilarTo { .. }
662            | Expr::Case(_)
663            | Expr::Cast { .. }
664            | Expr::TryCast { .. }
665            | Expr::InList { .. } => true,
666        }
667    }
668
669    /// Internal helper to evaluates an Expr
670    pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult {
671        if let Expr::Literal(s, m) = expr {
672            return ConstSimplifyResult::NotSimplified(s, m);
673        }
674
675        let phys_expr =
676            match create_physical_expr(&expr, &self.input_schema, self.execution_props) {
677                Ok(e) => e,
678                Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
679            };
680        let metadata = phys_expr
681            .return_field(self.input_batch.schema_ref())
682            .ok()
683            .and_then(|f| {
684                let m = f.metadata();
685                match m.is_empty() {
686                    true => None,
687                    false => Some(FieldMetadata::from(m)),
688                }
689            });
690        let col_val = match phys_expr.evaluate(&self.input_batch) {
691            Ok(v) => v,
692            Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
693        };
694        match col_val {
695            ColumnarValue::Array(a) => {
696                if a.len() != 1 {
697                    ConstSimplifyResult::SimplifyRuntimeError(
698                        DataFusionError::Execution(format!("Could not evaluate the expression, found a result of length {}", a.len())),
699                        expr,
700                    )
701                } else if as_list_array(&a).is_ok() {
702                    ConstSimplifyResult::Simplified(
703                        ScalarValue::List(a.as_list::<i32>().to_owned().into()),
704                        metadata,
705                    )
706                } else if as_large_list_array(&a).is_ok() {
707                    ConstSimplifyResult::Simplified(
708                        ScalarValue::LargeList(a.as_list::<i64>().to_owned().into()),
709                        metadata,
710                    )
711                } else {
712                    // Non-ListArray
713                    match ScalarValue::try_from_array(&a, 0) {
714                        Ok(s) => {
715                            // TODO: support the optimization for `Map` type after support impl hash for it
716                            if matches!(&s, ScalarValue::Map(_)) {
717                                ConstSimplifyResult::SimplifyRuntimeError(
718                                    DataFusionError::NotImplemented("Const evaluate for Map type is still not supported".to_string()),
719                                    expr,
720                                )
721                            } else {
722                                ConstSimplifyResult::Simplified(s, metadata)
723                            }
724                        }
725                        Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr),
726                    }
727                }
728            }
729            ColumnarValue::Scalar(s) => {
730                // TODO: support the optimization for `Map` type after support impl hash for it
731                if matches!(&s, ScalarValue::Map(_)) {
732                    ConstSimplifyResult::SimplifyRuntimeError(
733                        DataFusionError::NotImplemented(
734                            "Const evaluate for Map type is still not supported"
735                                .to_string(),
736                        ),
737                        expr,
738                    )
739                } else {
740                    ConstSimplifyResult::Simplified(s, metadata)
741                }
742            }
743        }
744    }
745}
746
/// Rewriter that simplifies [`Expr`]s by means of algebraic transformation
/// rules
///
/// Examples of the transformations that are applied:
/// * `expr = true` and `expr != false` to `expr` when `expr` is of boolean type
/// * `expr = false` and `expr != true` to `!expr` when `expr` is of boolean type
/// * `true = true` and `false = false` to `true`
/// * `false = true` and `true = false` to `false`
/// * `!!expr` to `expr`
/// * `expr = null` and `expr != null` to `null`
struct Simplifier<'a, S> {
    /// Borrowed provider of the information consulted by the rules
    info: &'a S,
}

impl<'a, S> Simplifier<'a, S> {
    /// Creates a simplifier that reads its information from `info`.
    pub fn new(info: &'a S) -> Self {
        Simplifier { info }
    }
}
765
766impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
767    type Node = Expr;
768
769    /// rewrite the expression simplifying any constant expressions
770    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
771        use datafusion_expr::Operator::{
772            And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor,
773            Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch,
774            RegexNotIMatch, RegexNotMatch,
775        };
776
777        let info = self.info;
778        Ok(match expr {
779            // `value op NULL` -> `NULL`
780            // `NULL op value` -> `NULL`
781            // except for few operators that can return non-null value even when one of the operands is NULL
782            ref expr @ Expr::BinaryExpr(BinaryExpr {
783                ref left,
784                ref op,
785                ref right,
786            }) if op.returns_null_on_null()
787                && (is_null(left.as_ref()) || is_null(right.as_ref())) =>
788            {
789                Transformed::yes(Expr::Literal(
790                    ScalarValue::try_new_null(&info.get_data_type(expr)?)?,
791                    None,
792                ))
793            }
794
795            // `NULL {AND, OR} NULL` -> `NULL`
796            Expr::BinaryExpr(BinaryExpr {
797                left,
798                op: And | Or,
799                right,
800            }) if is_null(&left) && is_null(&right) => Transformed::yes(lit_bool_null()),
801
802            //
803            // Rules for Eq
804            //
805
806            // true = A  --> A
807            // false = A --> !A
808            // null = A --> null
809            Expr::BinaryExpr(BinaryExpr {
810                left,
811                op: Eq,
812                right,
813            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
814                Transformed::yes(match as_bool_lit(&left)? {
815                    Some(true) => *right,
816                    Some(false) => Expr::Not(right),
817                    None => lit_bool_null(),
818                })
819            }
820            // A = true  --> A
821            // A = false --> !A
822            // A = null --> null
823            Expr::BinaryExpr(BinaryExpr {
824                left,
825                op: Eq,
826                right,
827            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
828                Transformed::yes(match as_bool_lit(&right)? {
829                    Some(true) => *left,
830                    Some(false) => Expr::Not(left),
831                    None => lit_bool_null(),
832                })
833            }
834            // According to SQL's null semantics, NULL = NULL evaluates to NULL
835            // Both sides are the same expression (A = A) and A is non-volatile expression
836            // A = A --> A IS NOT NULL OR NULL
837            // A = A --> true (if A not nullable)
838            Expr::BinaryExpr(BinaryExpr {
839                left,
840                op: Eq,
841                right,
842            }) if (left == right) & !left.is_volatile() => {
843                Transformed::yes(match !info.nullable(&left)? {
844                    true => lit(true),
845                    false => Expr::BinaryExpr(BinaryExpr {
846                        left: Box::new(Expr::IsNotNull(left)),
847                        op: Or,
848                        right: Box::new(lit_bool_null()),
849                    }),
850                })
851            }
852
853            // Rules for NotEq
854            //
855
856            // true != A  --> !A
857            // false != A --> A
858            // null != A --> null
859            Expr::BinaryExpr(BinaryExpr {
860                left,
861                op: NotEq,
862                right,
863            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
864                Transformed::yes(match as_bool_lit(&left)? {
865                    Some(true) => Expr::Not(right),
866                    Some(false) => *right,
867                    None => lit_bool_null(),
868                })
869            }
870            // A != true  --> !A
871            // A != false --> A
872            // A != null --> null,
873            Expr::BinaryExpr(BinaryExpr {
874                left,
875                op: NotEq,
876                right,
877            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
878                Transformed::yes(match as_bool_lit(&right)? {
879                    Some(true) => Expr::Not(left),
880                    Some(false) => *left,
881                    None => lit_bool_null(),
882                })
883            }
884
885            //
886            // Rules for OR
887            //
888
889            // true OR A --> true (even if A is null)
890            Expr::BinaryExpr(BinaryExpr {
891                left,
892                op: Or,
893                right: _,
894            }) if is_true(&left) => Transformed::yes(*left),
895            // false OR A --> A
896            Expr::BinaryExpr(BinaryExpr {
897                left,
898                op: Or,
899                right,
900            }) if is_false(&left) => Transformed::yes(*right),
901            // A OR true --> true (even if A is null)
902            Expr::BinaryExpr(BinaryExpr {
903                left: _,
904                op: Or,
905                right,
906            }) if is_true(&right) => Transformed::yes(*right),
907            // A OR false --> A
908            Expr::BinaryExpr(BinaryExpr {
909                left,
910                op: Or,
911                right,
912            }) if is_false(&right) => Transformed::yes(*left),
913            // A OR !A ---> true (if A not nullable)
914            Expr::BinaryExpr(BinaryExpr {
915                left,
916                op: Or,
917                right,
918            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
919                Transformed::yes(lit(true))
920            }
921            // !A OR A ---> true (if A not nullable)
922            Expr::BinaryExpr(BinaryExpr {
923                left,
924                op: Or,
925                right,
926            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
927                Transformed::yes(lit(true))
928            }
929            // (..A..) OR A --> (..A..)
930            Expr::BinaryExpr(BinaryExpr {
931                left,
932                op: Or,
933                right,
934            }) if expr_contains(&left, &right, Or) => Transformed::yes(*left),
935            // A OR (..A..) --> (..A..)
936            Expr::BinaryExpr(BinaryExpr {
937                left,
938                op: Or,
939                right,
940            }) if expr_contains(&right, &left, Or) => Transformed::yes(*right),
941            // A OR (A AND B) --> A
942            Expr::BinaryExpr(BinaryExpr {
943                left,
944                op: Or,
945                right,
946            }) if is_op_with(And, &right, &left) => Transformed::yes(*left),
947            // (A AND B) OR A --> A
948            Expr::BinaryExpr(BinaryExpr {
949                left,
950                op: Or,
951                right,
952            }) if is_op_with(And, &left, &right) => Transformed::yes(*right),
953            // Eliminate common factors in conjunctions e.g
954            // (A AND B) OR (A AND C) -> A AND (B OR C)
955            Expr::BinaryExpr(BinaryExpr {
956                left,
957                op: Or,
958                right,
959            }) if has_common_conjunction(&left, &right) => {
960                let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect();
961                let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right)
962                    .partition(|e| lhs.contains(e) && !e.is_volatile());
963
964                let new_rhs = rhs.into_iter().reduce(and);
965                let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and);
966                let common_conjunction = common.into_iter().reduce(and).unwrap();
967
968                let new_expr = match (new_lhs, new_rhs) {
969                    (Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)),
970                    (_, _) => common_conjunction,
971                };
972                Transformed::yes(new_expr)
973            }
974
975            //
976            // Rules for AND
977            //
978
979            // true AND A --> A
980            Expr::BinaryExpr(BinaryExpr {
981                left,
982                op: And,
983                right,
984            }) if is_true(&left) => Transformed::yes(*right),
985            // false AND A --> false (even if A is null)
986            Expr::BinaryExpr(BinaryExpr {
987                left,
988                op: And,
989                right: _,
990            }) if is_false(&left) => Transformed::yes(*left),
991            // A AND true --> A
992            Expr::BinaryExpr(BinaryExpr {
993                left,
994                op: And,
995                right,
996            }) if is_true(&right) => Transformed::yes(*left),
997            // A AND false --> false (even if A is null)
998            Expr::BinaryExpr(BinaryExpr {
999                left: _,
1000                op: And,
1001                right,
1002            }) if is_false(&right) => Transformed::yes(*right),
1003            // A AND !A ---> false (if A not nullable)
1004            Expr::BinaryExpr(BinaryExpr {
1005                left,
1006                op: And,
1007                right,
1008            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
1009                Transformed::yes(lit(false))
1010            }
1011            // !A AND A ---> false (if A not nullable)
1012            Expr::BinaryExpr(BinaryExpr {
1013                left,
1014                op: And,
1015                right,
1016            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
1017                Transformed::yes(lit(false))
1018            }
1019            // (..A..) AND A --> (..A..)
1020            Expr::BinaryExpr(BinaryExpr {
1021                left,
1022                op: And,
1023                right,
1024            }) if expr_contains(&left, &right, And) => Transformed::yes(*left),
1025            // A AND (..A..) --> (..A..)
1026            Expr::BinaryExpr(BinaryExpr {
1027                left,
1028                op: And,
1029                right,
1030            }) if expr_contains(&right, &left, And) => Transformed::yes(*right),
1031            // A AND (A OR B) --> A
1032            Expr::BinaryExpr(BinaryExpr {
1033                left,
1034                op: And,
1035                right,
1036            }) if is_op_with(Or, &right, &left) => Transformed::yes(*left),
1037            // (A OR B) AND A --> A
1038            Expr::BinaryExpr(BinaryExpr {
1039                left,
1040                op: And,
1041                right,
1042            }) if is_op_with(Or, &left, &right) => Transformed::yes(*right),
1043            // A >= constant AND constant <= A --> A = constant
1044            Expr::BinaryExpr(BinaryExpr {
1045                left,
1046                op: And,
1047                right,
1048            }) if can_reduce_to_equal_statement(&left, &right) => {
1049                if let Expr::BinaryExpr(BinaryExpr {
1050                    left: left_left,
1051                    right: left_right,
1052                    ..
1053                }) = *left
1054                {
1055                    Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1056                        left: left_left,
1057                        op: Eq,
1058                        right: left_right,
1059                    }))
1060                } else {
1061                    return internal_err!("can_reduce_to_equal_statement should only be called with a BinaryExpr");
1062                }
1063            }
1064
1065            //
1066            // Rules for Multiply
1067            //
1068
1069            // A * 1 --> A (with type coercion if needed)
1070            Expr::BinaryExpr(BinaryExpr {
1071                left,
1072                op: Multiply,
1073                right,
1074            }) if is_one(&right) => {
1075                simplify_right_is_one_case(info, left, &Multiply, &right)?
1076            }
1077            // 1 * A --> A
1078            Expr::BinaryExpr(BinaryExpr {
1079                left,
1080                op: Multiply,
1081                right,
1082            }) if is_one(&left) => {
1083                // 1 * A is equivalent to A * 1
1084                simplify_right_is_one_case(info, right, &Multiply, &left)?
1085            }
1086
1087            // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN)
1088            Expr::BinaryExpr(BinaryExpr {
1089                left,
1090                op: Multiply,
1091                right,
1092            }) if !info.nullable(&left)?
1093                && !info.get_data_type(&left)?.is_floating()
1094                && is_zero(&right) =>
1095            {
1096                Transformed::yes(*right)
1097            }
1098            // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN)
1099            Expr::BinaryExpr(BinaryExpr {
1100                left,
1101                op: Multiply,
1102                right,
1103            }) if !info.nullable(&right)?
1104                && !info.get_data_type(&right)?.is_floating()
1105                && is_zero(&left) =>
1106            {
1107                Transformed::yes(*left)
1108            }
1109
1110            //
1111            // Rules for Divide
1112            //
1113
1114            // A / 1 --> A
1115            Expr::BinaryExpr(BinaryExpr {
1116                left,
1117                op: Divide,
1118                right,
1119            }) if is_one(&right) => {
1120                simplify_right_is_one_case(info, left, &Divide, &right)?
1121            }
1122
1123            //
1124            // Rules for Modulo
1125            //
1126
1127            // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN)
1128            Expr::BinaryExpr(BinaryExpr {
1129                left,
1130                op: Modulo,
1131                right,
1132            }) if !info.nullable(&left)?
1133                && !info.get_data_type(&left)?.is_floating()
1134                && is_one(&right) =>
1135            {
1136                Transformed::yes(Expr::Literal(
1137                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1138                    None,
1139                ))
1140            }
1141
1142            //
1143            // Rules for BitwiseAnd
1144            //
1145
1146            // A & 0 -> 0 (if A not nullable)
1147            Expr::BinaryExpr(BinaryExpr {
1148                left,
1149                op: BitwiseAnd,
1150                right,
1151            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right),
1152
1153            // 0 & A -> 0 (if A not nullable)
1154            Expr::BinaryExpr(BinaryExpr {
1155                left,
1156                op: BitwiseAnd,
1157                right,
1158            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left),
1159
1160            // !A & A -> 0 (if A not nullable)
1161            Expr::BinaryExpr(BinaryExpr {
1162                left,
1163                op: BitwiseAnd,
1164                right,
1165            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1166                Transformed::yes(Expr::Literal(
1167                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1168                    None,
1169                ))
1170            }
1171
1172            // A & !A -> 0 (if A not nullable)
1173            Expr::BinaryExpr(BinaryExpr {
1174                left,
1175                op: BitwiseAnd,
1176                right,
1177            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1178                Transformed::yes(Expr::Literal(
1179                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1180                    None,
1181                ))
1182            }
1183
1184            // (..A..) & A --> (..A..)
1185            Expr::BinaryExpr(BinaryExpr {
1186                left,
1187                op: BitwiseAnd,
1188                right,
1189            }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left),
1190
1191            // A & (..A..) --> (..A..)
1192            Expr::BinaryExpr(BinaryExpr {
1193                left,
1194                op: BitwiseAnd,
1195                right,
1196            }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right),
1197
1198            // A & (A | B) --> A (if B not null)
1199            Expr::BinaryExpr(BinaryExpr {
1200                left,
1201                op: BitwiseAnd,
1202                right,
1203            }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => {
1204                Transformed::yes(*left)
1205            }
1206
1207            // (A | B) & A --> A (if B not null)
1208            Expr::BinaryExpr(BinaryExpr {
1209                left,
1210                op: BitwiseAnd,
1211                right,
1212            }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => {
1213                Transformed::yes(*right)
1214            }
1215
1216            //
1217            // Rules for BitwiseOr
1218            //
1219
1220            // A | 0 -> A (even if A is null)
1221            Expr::BinaryExpr(BinaryExpr {
1222                left,
1223                op: BitwiseOr,
1224                right,
1225            }) if is_zero(&right) => Transformed::yes(*left),
1226
1227            // 0 | A -> A (even if A is null)
1228            Expr::BinaryExpr(BinaryExpr {
1229                left,
1230                op: BitwiseOr,
1231                right,
1232            }) if is_zero(&left) => Transformed::yes(*right),
1233
1234            // !A | A -> -1 (if A not nullable)
1235            Expr::BinaryExpr(BinaryExpr {
1236                left,
1237                op: BitwiseOr,
1238                right,
1239            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1240                Transformed::yes(Expr::Literal(
1241                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1242                    None,
1243                ))
1244            }
1245
1246            // A | !A -> -1 (if A not nullable)
1247            Expr::BinaryExpr(BinaryExpr {
1248                left,
1249                op: BitwiseOr,
1250                right,
1251            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1252                Transformed::yes(Expr::Literal(
1253                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1254                    None,
1255                ))
1256            }
1257
1258            // (..A..) | A --> (..A..)
1259            Expr::BinaryExpr(BinaryExpr {
1260                left,
1261                op: BitwiseOr,
1262                right,
1263            }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left),
1264
1265            // A | (..A..) --> (..A..)
1266            Expr::BinaryExpr(BinaryExpr {
1267                left,
1268                op: BitwiseOr,
1269                right,
1270            }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right),
1271
1272            // A | (A & B) --> A (if B not null)
1273            Expr::BinaryExpr(BinaryExpr {
1274                left,
1275                op: BitwiseOr,
1276                right,
1277            }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => {
1278                Transformed::yes(*left)
1279            }
1280
1281            // (A & B) | A --> A (if B not null)
1282            Expr::BinaryExpr(BinaryExpr {
1283                left,
1284                op: BitwiseOr,
1285                right,
1286            }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => {
1287                Transformed::yes(*right)
1288            }
1289
1290            //
1291            // Rules for BitwiseXor
1292            //
1293
1294            // A ^ 0 -> A (if A not nullable)
1295            Expr::BinaryExpr(BinaryExpr {
1296                left,
1297                op: BitwiseXor,
1298                right,
1299            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left),
1300
1301            // 0 ^ A -> A (if A not nullable)
1302            Expr::BinaryExpr(BinaryExpr {
1303                left,
1304                op: BitwiseXor,
1305                right,
1306            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right),
1307
1308            // !A ^ A -> -1 (if A not nullable)
1309            Expr::BinaryExpr(BinaryExpr {
1310                left,
1311                op: BitwiseXor,
1312                right,
1313            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1314                Transformed::yes(Expr::Literal(
1315                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1316                    None,
1317                ))
1318            }
1319
1320            // A ^ !A -> -1 (if A not nullable)
1321            Expr::BinaryExpr(BinaryExpr {
1322                left,
1323                op: BitwiseXor,
1324                right,
1325            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1326                Transformed::yes(Expr::Literal(
1327                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1328                    None,
1329                ))
1330            }
1331
1332            // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A)
1333            Expr::BinaryExpr(BinaryExpr {
1334                left,
1335                op: BitwiseXor,
1336                right,
1337            }) if expr_contains(&left, &right, BitwiseXor) => {
1338                let expr = delete_xor_in_complex_expr(&left, &right, false);
1339                Transformed::yes(if expr == *right {
1340                    Expr::Literal(
1341                        ScalarValue::new_zero(&info.get_data_type(&right)?)?,
1342                        None,
1343                    )
1344                } else {
1345                    expr
1346                })
1347            }
1348
1349            // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A)
1350            Expr::BinaryExpr(BinaryExpr {
1351                left,
1352                op: BitwiseXor,
1353                right,
1354            }) if expr_contains(&right, &left, BitwiseXor) => {
1355                let expr = delete_xor_in_complex_expr(&right, &left, true);
1356                Transformed::yes(if expr == *left {
1357                    Expr::Literal(
1358                        ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1359                        None,
1360                    )
1361                } else {
1362                    expr
1363                })
1364            }
1365
1366            //
1367            // Rules for BitwiseShiftRight
1368            //
1369
1370            // A >> 0 -> A (even if A is null)
1371            Expr::BinaryExpr(BinaryExpr {
1372                left,
1373                op: BitwiseShiftRight,
1374                right,
1375            }) if is_zero(&right) => Transformed::yes(*left),
1376
1377            //
            // Rules for BitwiseShiftLeft
1379            //
1380
1381            // A << 0 -> A (even if A is null)
1382            Expr::BinaryExpr(BinaryExpr {
1383                left,
1384                op: BitwiseShiftLeft,
1385                right,
1386            }) if is_zero(&right) => Transformed::yes(*left),
1387
1388            //
1389            // Rules for Not
1390            //
1391            Expr::Not(inner) => Transformed::yes(negate_clause(*inner)),
1392
1393            //
1394            // Rules for Negative
1395            //
1396            Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)),
1397
1398            //
1399            // Rules for Case
1400            //
1401
1402            // CASE
1403            //   WHEN X THEN A
1404            //   WHEN Y THEN B
1405            //   ...
1406            //   ELSE Q
1407            // END
1408            //
1409            // ---> (X AND A) OR (Y AND B AND NOT X) OR ... (NOT (X OR Y) AND Q)
1410            //
1411            // Note: the rationale for this rewrite is that the expr can then be further
1412            // simplified using the existing rules for AND/OR
1413            Expr::Case(Case {
1414                expr: None,
1415                when_then_expr,
1416                else_expr,
1417            }) if !when_then_expr.is_empty()
1418                && when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number
1419                && info.is_boolean_type(&when_then_expr[0].1)? =>
1420            {
                // Disjunction (OR) of all the when predicates encountered so far. Not nullable.
1422                let mut filter_expr = lit(false);
1423                // The disjunction of all the cases
1424                let mut out_expr = lit(false);
1425
1426                for (when, then) in when_then_expr {
1427                    let when = is_exactly_true(*when, info)?;
1428                    let case_expr =
1429                        when.clone().and(filter_expr.clone().not()).and(*then);
1430
1431                    out_expr = out_expr.or(case_expr);
1432                    filter_expr = filter_expr.or(when);
1433                }
1434
1435                let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
1436                let case_expr = filter_expr.not().and(else_expr);
1437                out_expr = out_expr.or(case_expr);
1438
1439                // Do a first pass at simplification
1440                out_expr.rewrite(self)?
1441            }
1442            Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
1443                match udf.simplify(args, info)? {
1444                    ExprSimplifyResult::Original(args) => {
1445                        Transformed::no(Expr::ScalarFunction(ScalarFunction {
1446                            func: udf,
1447                            args,
1448                        }))
1449                    }
1450                    ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
1451                }
1452            }
1453
1454            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
1455                ref func,
1456                ..
1457            }) => match (func.simplify(), expr) {
1458                (Some(simplify_function), Expr::AggregateFunction(af)) => {
1459                    Transformed::yes(simplify_function(af, info)?)
1460                }
1461                (_, expr) => Transformed::no(expr),
1462            },
1463
1464            Expr::WindowFunction(ref window_fun) => match (window_fun.simplify(), expr) {
1465                (Some(simplify_function), Expr::WindowFunction(wf)) => {
1466                    Transformed::yes(simplify_function(*wf, info)?)
1467                }
1468                (_, expr) => Transformed::no(expr),
1469            },
1470
1471            //
1472            // Rules for Between
1473            //
1474
1475            // a between 3 and 5  -->  a >= 3 AND a <=5
1476            // a not between 3 and 5  -->  a < 3 OR a > 5
1477            Expr::Between(between) => Transformed::yes(if between.negated {
1478                let l = *between.expr.clone();
1479                let r = *between.expr;
1480                or(l.lt(*between.low), r.gt(*between.high))
1481            } else {
1482                and(
1483                    between.expr.clone().gt_eq(*between.low),
1484                    between.expr.lt_eq(*between.high),
1485                )
1486            }),
1487
1488            //
1489            // Rules for regexes
1490            //
1491            Expr::BinaryExpr(BinaryExpr {
1492                left,
1493                op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch),
1494                right,
1495            }) => Transformed::yes(simplify_regex_expr(left, op, right)?),
1496
1497            // Rules for Like
1498            Expr::Like(like) => {
1499                // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291
1500                let escape_char = like.escape_char.unwrap_or('\\');
1501                match as_string_scalar(&like.pattern) {
1502                    Some((data_type, pattern_str)) => {
1503                        match pattern_str {
1504                            None => return Ok(Transformed::yes(lit_bool_null())),
1505                            Some(pattern_str) if pattern_str == "%" => {
1506                                // exp LIKE '%' is
1507                                //   - when exp is not NULL, it's true
1508                                //   - when exp is NULL, it's NULL
1509                                // exp NOT LIKE '%' is
1510                                //   - when exp is not NULL, it's false
1511                                //   - when exp is NULL, it's NULL
1512                                let result_for_non_null = lit(!like.negated);
1513                                Transformed::yes(if !info.nullable(&like.expr)? {
1514                                    result_for_non_null
1515                                } else {
1516                                    Expr::Case(Case {
1517                                        expr: Some(Box::new(Expr::IsNotNull(like.expr))),
1518                                        when_then_expr: vec![(
1519                                            Box::new(lit(true)),
1520                                            Box::new(result_for_non_null),
1521                                        )],
1522                                        else_expr: None,
1523                                    })
1524                                })
1525                            }
1526                            Some(pattern_str)
1527                                if pattern_str.contains("%%")
1528                                    && !pattern_str.contains(escape_char) =>
1529                            {
1530                                // Repeated occurrences of wildcard are redundant so remove them
1531                                // exp LIKE '%%'  --> exp LIKE '%'
1532                                let simplified_pattern = Regex::new("%%+")
1533                                    .unwrap()
1534                                    .replace_all(pattern_str, "%")
1535                                    .to_string();
1536                                Transformed::yes(Expr::Like(Like {
1537                                    pattern: Box::new(to_string_scalar(
1538                                        data_type,
1539                                        Some(simplified_pattern),
1540                                    )),
1541                                    ..like
1542                                }))
1543                            }
1544                            Some(pattern_str)
1545                                if !like.case_insensitive
1546                                    && !pattern_str
1547                                        .contains(['%', '_', escape_char].as_ref()) =>
1548                            {
1549                                // If the pattern does not contain any wildcards, we can simplify the like expression to an equality expression
1550                                // TODO: handle escape characters
1551                                Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1552                                    left: like.expr.clone(),
1553                                    op: if like.negated { NotEq } else { Eq },
1554                                    right: like.pattern.clone(),
1555                                }))
1556                            }
1557
1558                            Some(_pattern_str) => Transformed::no(Expr::Like(like)),
1559                        }
1560                    }
1561                    None => Transformed::no(Expr::Like(like)),
1562                }
1563            }
1564
1565            // a is not null/unknown --> true (if a is not nullable)
1566            Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr)
1567                if !info.nullable(&expr)? =>
1568            {
1569                Transformed::yes(lit(true))
1570            }
1571
1572            // a is null/unknown --> false (if a is not nullable)
1573            Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => {
1574                Transformed::yes(lit(false))
1575            }
1576
1577            // expr IN () --> false
1578            // expr NOT IN () --> true
1579            Expr::InList(InList {
1580                expr: _,
1581                list,
1582                negated,
1583            }) if list.is_empty() => Transformed::yes(lit(negated)),
1584
1585            // null in (x, y, z) --> null
1586            // null not in (x, y, z) --> null
1587            Expr::InList(InList {
1588                expr,
1589                list,
1590                negated: _,
1591            }) if is_null(expr.as_ref()) && !list.is_empty() => {
1592                Transformed::yes(lit_bool_null())
1593            }
1594
1595            // expr IN ((subquery)) -> expr IN (subquery), see ##5529
1596            Expr::InList(InList {
1597                expr,
1598                mut list,
1599                negated,
1600            }) if list.len() == 1
1601                && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) =>
1602            {
1603                let Expr::ScalarSubquery(subquery) = list.remove(0) else {
1604                    unreachable!()
1605                };
1606
1607                Transformed::yes(Expr::InSubquery(InSubquery::new(
1608                    expr, subquery, negated,
1609                )))
1610            }
1611
1612            // Combine multiple OR expressions into a single IN list expression if possible
1613            //
1614            // i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)`
1615            Expr::BinaryExpr(BinaryExpr {
1616                left,
1617                op: Or,
1618                right,
1619            }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => {
1620                let lhs = to_inlist(*left).unwrap();
1621                let rhs = to_inlist(*right).unwrap();
1622                let mut seen: HashSet<Expr> = HashSet::new();
1623                let list = lhs
1624                    .list
1625                    .into_iter()
1626                    .chain(rhs.list)
1627                    .filter(|e| seen.insert(e.to_owned()))
1628                    .collect::<Vec<_>>();
1629
1630                let merged_inlist = InList {
1631                    expr: lhs.expr,
1632                    list,
1633                    negated: false,
1634                };
1635
1636                Transformed::yes(Expr::InList(merged_inlist))
1637            }
1638
1639            // Simplify expressions that is guaranteed to be true or false to a literal boolean expression
1640            //
1641            // Rules:
1642            // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists
1643            //   Intersection:
1644            //     1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false`
1645            //     2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)`
1646            //     3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)`
1647            //   Union:
1648            //     4. `a not int (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)`
1649            //     # This rule is handled by `or_in_list_simplifier.rs`
1650            //     5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)`
1651            // If one of the expressions is `IN` and another one is `NOT IN`, then we apply exception on `In` expression
1652            //     6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false`
1653            //     7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5`
1654            //     8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)`
1655            Expr::BinaryExpr(BinaryExpr {
1656                left,
1657                op: And,
1658                right,
1659            }) if are_inlist_and_eq_and_match_neg(
1660                left.as_ref(),
1661                right.as_ref(),
1662                false,
1663                false,
1664            ) =>
1665            {
1666                match (*left, *right) {
1667                    (Expr::InList(l1), Expr::InList(l2)) => {
1668                        return inlist_intersection(l1, &l2, false).map(Transformed::yes);
1669                    }
1670                    // Matched previously once
1671                    _ => unreachable!(),
1672                }
1673            }
1674
1675            Expr::BinaryExpr(BinaryExpr {
1676                left,
1677                op: And,
1678                right,
1679            }) if are_inlist_and_eq_and_match_neg(
1680                left.as_ref(),
1681                right.as_ref(),
1682                true,
1683                true,
1684            ) =>
1685            {
1686                match (*left, *right) {
1687                    (Expr::InList(l1), Expr::InList(l2)) => {
1688                        return inlist_union(l1, l2, true).map(Transformed::yes);
1689                    }
1690                    // Matched previously once
1691                    _ => unreachable!(),
1692                }
1693            }
1694
1695            Expr::BinaryExpr(BinaryExpr {
1696                left,
1697                op: And,
1698                right,
1699            }) if are_inlist_and_eq_and_match_neg(
1700                left.as_ref(),
1701                right.as_ref(),
1702                false,
1703                true,
1704            ) =>
1705            {
1706                match (*left, *right) {
1707                    (Expr::InList(l1), Expr::InList(l2)) => {
1708                        return inlist_except(l1, &l2).map(Transformed::yes);
1709                    }
1710                    // Matched previously once
1711                    _ => unreachable!(),
1712                }
1713            }
1714
1715            Expr::BinaryExpr(BinaryExpr {
1716                left,
1717                op: And,
1718                right,
1719            }) if are_inlist_and_eq_and_match_neg(
1720                left.as_ref(),
1721                right.as_ref(),
1722                true,
1723                false,
1724            ) =>
1725            {
1726                match (*left, *right) {
1727                    (Expr::InList(l1), Expr::InList(l2)) => {
1728                        return inlist_except(l2, &l1).map(Transformed::yes);
1729                    }
1730                    // Matched previously once
1731                    _ => unreachable!(),
1732                }
1733            }
1734
1735            Expr::BinaryExpr(BinaryExpr {
1736                left,
1737                op: Or,
1738                right,
1739            }) if are_inlist_and_eq_and_match_neg(
1740                left.as_ref(),
1741                right.as_ref(),
1742                true,
1743                true,
1744            ) =>
1745            {
1746                match (*left, *right) {
1747                    (Expr::InList(l1), Expr::InList(l2)) => {
1748                        return inlist_intersection(l1, &l2, true).map(Transformed::yes);
1749                    }
1750                    // Matched previously once
1751                    _ => unreachable!(),
1752                }
1753            }
1754
1755            // =======================================
1756            // unwrap_cast_in_comparison
1757            // =======================================
1758            //
1759            // For case:
1760            // try_cast/cast(expr as data_type) op literal
1761            Expr::BinaryExpr(BinaryExpr { left, op, right })
1762                if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1763                    info, &left, op, &right,
1764                ) && op.supports_propagation() =>
1765            {
1766                unwrap_cast_in_comparison_for_binary(info, *left, *right, op)?
1767            }
1768            // literal op try_cast/cast(expr as data_type)
1769            // -->
1770            // try_cast/cast(expr as data_type) op_swap literal
1771            Expr::BinaryExpr(BinaryExpr { left, op, right })
1772                if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1773                    info, &right, op, &left,
1774                ) && op.supports_propagation()
1775                    && op.swap().is_some() =>
1776            {
1777                unwrap_cast_in_comparison_for_binary(
1778                    info,
1779                    *right,
1780                    *left,
1781                    op.swap().unwrap(),
1782                )?
1783            }
1784            // For case:
1785            // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
1786            Expr::InList(InList {
1787                expr: mut left,
1788                list,
1789                negated,
1790            }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
1791                info, &left, &list,
1792            ) =>
1793            {
1794                let (Expr::TryCast(TryCast {
1795                    expr: left_expr, ..
1796                })
1797                | Expr::Cast(Cast {
1798                    expr: left_expr, ..
1799                })) = left.as_mut()
1800                else {
1801                    return internal_err!("Expect cast expr, but got {:?}", left)?;
1802                };
1803
1804                let expr_type = info.get_data_type(left_expr)?;
1805                let right_exprs = list
1806                    .into_iter()
1807                    .map(|right| {
1808                        match right {
1809                            Expr::Literal(right_lit_value, _) => {
1810                                // if the right_lit_value can be casted to the type of internal_left_expr
1811                                // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
1812                                let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else {
1813                                    internal_err!(
1814                                        "Can't cast the list expr {:?} to type {:?}",
1815                                        right_lit_value, &expr_type
1816                                    )?
1817                                };
1818                                Ok(lit(value))
1819                            }
1820                            other_expr => internal_err!(
1821                                "Only support literal expr to optimize, but the expr is {:?}",
1822                                &other_expr
1823                            ),
1824                        }
1825                    })
1826                    .collect::<Result<Vec<_>>>()?;
1827
1828                Transformed::yes(Expr::InList(InList {
1829                    expr: std::mem::take(left_expr),
1830                    list: right_exprs,
1831                    negated,
1832                }))
1833            }
1834
1835            // no additional rewrites possible
1836            expr => Transformed::no(expr),
1837        })
1838    }
1839}
1840
1841fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option<String>)> {
1842    match expr {
1843        Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)),
1844        Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)),
1845        Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)),
1846        _ => None,
1847    }
1848}
1849
1850fn to_string_scalar(data_type: DataType, value: Option<String>) -> Expr {
1851    match data_type {
1852        DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None),
1853        DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None),
1854        DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None),
1855        _ => unreachable!(),
1856    }
1857}
1858
1859fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
1860    let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect();
1861    iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
1862}
1863
1864// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
1865fn are_inlist_and_eq_and_match_neg(
1866    left: &Expr,
1867    right: &Expr,
1868    is_left_neg: bool,
1869    is_right_neg: bool,
1870) -> bool {
1871    match (left, right) {
1872        (Expr::InList(l), Expr::InList(r)) => {
1873            l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg
1874        }
1875        _ => false,
1876    }
1877}
1878
1879// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
1880fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool {
1881    let left = as_inlist(left);
1882    let right = as_inlist(right);
1883    if let (Some(lhs), Some(rhs)) = (left, right) {
1884        matches!(lhs.expr.as_ref(), Expr::Column(_))
1885            && matches!(rhs.expr.as_ref(), Expr::Column(_))
1886            && lhs.expr == rhs.expr
1887            && !lhs.negated
1888            && !rhs.negated
1889    } else {
1890        false
1891    }
1892}
1893
1894/// Try to convert an expression to an in-list expression
1895fn as_inlist(expr: &'_ Expr) -> Option<Cow<'_, InList>> {
1896    match expr {
1897        Expr::InList(inlist) => Some(Cow::Borrowed(inlist)),
1898        Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => {
1899            match (left.as_ref(), right.as_ref()) {
1900                (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList {
1901                    expr: left.clone(),
1902                    list: vec![*right.clone()],
1903                    negated: false,
1904                })),
1905                (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList {
1906                    expr: right.clone(),
1907                    list: vec![*left.clone()],
1908                    negated: false,
1909                })),
1910                _ => None,
1911            }
1912        }
1913        _ => None,
1914    }
1915}
1916
1917fn to_inlist(expr: Expr) -> Option<InList> {
1918    match expr {
1919        Expr::InList(inlist) => Some(inlist),
1920        Expr::BinaryExpr(BinaryExpr {
1921            left,
1922            op: Operator::Eq,
1923            right,
1924        }) => match (left.as_ref(), right.as_ref()) {
1925            (Expr::Column(_), Expr::Literal(_, _)) => Some(InList {
1926                expr: left,
1927                list: vec![*right],
1928                negated: false,
1929            }),
1930            (Expr::Literal(_, _), Expr::Column(_)) => Some(InList {
1931                expr: right,
1932                list: vec![*left],
1933                negated: false,
1934            }),
1935            _ => None,
1936        },
1937        _ => None,
1938    }
1939}
1940
1941/// Return the union of two inlist expressions
1942/// maintaining the order of the elements in the two lists
1943fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
1944    // extend the list in l1 with the elements in l2 that are not already in l1
1945    let l1_items: HashSet<_> = l1.list.iter().collect();
1946
1947    // keep all l2 items that do not also appear in l1
1948    let keep_l2: Vec<_> = l2
1949        .list
1950        .into_iter()
1951        .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) })
1952        .collect();
1953
1954    l1.list.extend(keep_l2);
1955    l1.negated = negated;
1956    Ok(Expr::InList(l1))
1957}
1958
1959/// Return the intersection of two inlist expressions
1960/// maintaining the order of the elements in the two lists
1961fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result<Expr> {
1962    let l2_items = l2.list.iter().collect::<HashSet<_>>();
1963
1964    // remove all items from l1 that are not in l2
1965    l1.list.retain(|e| l2_items.contains(e));
1966
1967    // e in () is always false
1968    // e not in () is always true
1969    if l1.list.is_empty() {
1970        return Ok(lit(negated));
1971    }
1972    Ok(Expr::InList(l1))
1973}
1974
1975/// Return the all items in l1 that are not in l2
1976/// maintaining the order of the elements in the two lists
1977fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
1978    let l2_items = l2.list.iter().collect::<HashSet<_>>();
1979
1980    // keep only items from l1 that are not in l2
1981    l1.list.retain(|e| !l2_items.contains(e));
1982
1983    if l1.list.is_empty() {
1984        return Ok(lit(false));
1985    }
1986    Ok(Expr::InList(l1))
1987}
1988
1989/// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL).
1990fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
1991    if !info.nullable(&expr)? {
1992        Ok(expr)
1993    } else {
1994        Ok(Expr::BinaryExpr(BinaryExpr {
1995            left: Box::new(expr),
1996            op: Operator::IsNotDistinctFrom,
1997            right: Box::new(lit(true)),
1998        }))
1999    }
2000}
2001
2002// A * 1 -> A
2003// A / 1 -> A
2004//
2005// Move this function body out of the large match branch avoid stack overflow
2006fn simplify_right_is_one_case<S: SimplifyInfo>(
2007    info: &S,
2008    left: Box<Expr>,
2009    op: &Operator,
2010    right: &Expr,
2011) -> Result<Transformed<Expr>> {
2012    // Check if resulting type would be different due to coercion
2013    let left_type = info.get_data_type(&left)?;
2014    let right_type = info.get_data_type(right)?;
2015    match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() {
2016        Ok(result_type) => {
2017            // Only cast if the types differ
2018            if left_type != result_type {
2019                Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type))))
2020            } else {
2021                Ok(Transformed::yes(*left))
2022            }
2023        }
2024        Err(_) => Ok(Transformed::yes(*left)),
2025    }
2026}
2027
2028#[cfg(test)]
2029mod tests {
2030    use super::*;
2031    use crate::simplify_expressions::SimplifyContext;
2032    use crate::test::test_table_scan_with_name;
2033    use arrow::datatypes::FieldRef;
2034    use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
2035    use datafusion_expr::{
2036        expr::WindowFunction,
2037        function::{
2038            AccumulatorArgs, AggregateFunctionSimplification,
2039            WindowFunctionSimplification,
2040        },
2041        interval_arithmetic::Interval,
2042        *,
2043    };
2044    use datafusion_functions_window_common::field::WindowUDFFieldArgs;
2045    use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
2046    use std::hash::Hash;
2047    use std::{
2048        collections::HashMap,
2049        ops::{BitAnd, BitOr, BitXor},
2050        sync::Arc,
2051    };
2052
2053    // ------------------------------
2054    // --- ExprSimplifier tests -----
2055    // ------------------------------
// Smoke test of the `ExprSimplifier` API: `1 + 2` folds to `3`.
#[test]
fn api_basic() {
    let execution_props = ExecutionProps::new();
    let context = SimplifyContext::new(&execution_props).with_schema(test_schema());
    let simplifier = ExprSimplifier::new(context);

    let simplified = simplifier.simplify(lit(1) + lit(2)).unwrap();
    assert_eq!(lit(3), simplified);
}
2066
// Coercing and then simplifying folds mixed-width integer literals.
#[test]
fn basic_coercion() {
    let schema = test_schema();
    let execution_props = ExecutionProps::new();
    let context =
        SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema));
    let simplifier = ExprSimplifier::new(context);

    // The literal side mixes an i64 and an i32: (1i64 + 2i32) < i
    let uncoerced = (lit(1i64) + lit(2i32)).lt(col("i"));
    let coerced = simplifier.coerce(uncoerced, &schema).unwrap();

    // After coercion the literal side folds completely, leaving 3i64 < i
    let expected = lit(3i64).lt(col("i"));
    assert_eq!(expected, simplifier.simplify(coerced).unwrap());
}
2085
2086    fn test_schema() -> DFSchemaRef {
2087        Schema::new(vec![
2088            Field::new("i", DataType::Int64, false),
2089            Field::new("b", DataType::Boolean, true),
2090        ])
2091        .to_dfschema_ref()
2092        .unwrap()
2093    }
2094
// Constant folding and simplification rules combine to fold a comparison.
#[test]
fn simplify_and_constant_prop() {
    let execution_props = ExecutionProps::new();
    let context = SimplifyContext::new(&execution_props).with_schema(test_schema());
    let simplifier = ExprSimplifier::new(context);

    // (i * (1 - 1)) > 0 should simplify all the way down to false
    let product = col("i") * (lit(1) - lit(1));
    let simplified = simplifier.simplify(product.gt(lit(0))).unwrap();
    assert_eq!(lit(false), simplified);
}
2107
// Constant propagation inside CASE branches lets the whole CASE collapse.
#[test]
fn simplify_and_constant_prop_with_case() {
    let execution_props = ExecutionProps::new();
    let context = SimplifyContext::new(&execution_props).with_schema(test_schema());
    let simplifier = ExprSimplifier::new(context);

    let i_gt_5 = col("i").gt(lit(5));
    let i_lt_5 = col("i").lt(lit(5));

    //   CASE
    //     WHEN i>5 AND false THEN i > 5
    //     WHEN i<5 AND true THEN i < 5
    //     ELSE false
    //   END
    let case_expr = when(i_gt_5.clone().and(lit(false)), i_gt_5)
        .when(i_lt_5.clone().and(lit(true)), i_lt_5.clone())
        .otherwise(lit(false))
        .unwrap();

    // The whole CASE is equivalent to `i < 5`
    assert_eq!(i_lt_5, simplifier.simplify(case_expr).unwrap());
}
2128
2129    // ------------------------------
2130    // --- Simplifier tests -----
2131    // ------------------------------
2132
// Each case pairs an input with the form `simplify` is expected to produce.
#[test]
fn test_simplify_canonicalize() {
    let cases = vec![
        // 1 < c2 AND c2 > 1: both sides are the same comparison
        (
            lit(1).lt(col("c2")).and(col("c2").gt(lit(1))),
            col("c2").gt(lit(1)),
        ),
        // c1 < c2 AND c2 > c1
        (
            col("c1").lt(col("c2")).and(col("c2").gt(col("c1"))),
            col("c2").gt(col("c1")),
        ),
        // c1 = 1 AND 1 = c1 AND c1 = 3: only the mirrored duplicate goes away
        (
            col("c1")
                .eq(lit(1))
                .and(lit(1).eq(col("c1")))
                .and(col("c1").eq(lit(3))),
            col("c1").eq(lit(1)).and(col("c1").eq(lit(3))),
        ),
        // c1 = c2 AND c1 > 5 AND c2 = c1
        (
            col("c1")
                .eq(col("c2"))
                .and(col("c1").gt(lit(5)))
                .and(col("c2").eq(col("c1"))),
            col("c2").eq(col("c1")).and(col("c1").gt(lit(5))),
        ),
        // c1 = 1 AND (c2 > 3 OR 3 < c2)
        (
            col("c1")
                .eq(lit(1))
                .and(col("c2").gt(lit(3)).or(lit(3).lt(col("c2")))),
            col("c1").eq(lit(1)).and(col("c2").gt(lit(3))),
        ),
        // c1 < 5 AND c1 >= 5 is expected to come back unchanged
        (
            col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5))),
            col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5))),
        ),
        // c1 > c2 AND c1 > c2
        (
            col("c1").gt(col("c2")).and(col("c1").gt(col("c2"))),
            col("c2").lt(col("c1")),
        ),
    ];

    for (input, expected) in cases {
        assert_eq!(simplify(input), expected);
    }
}
2184
#[test]
fn test_simplify_eq_not_self() {
    // `c2` is nullable, so `c2 = c2` simplifies to `c2 IS NOT NULL OR NULL`:
    // true only when `c2` is not NULL, which respects SQL's NULL semantics.
    let nullable_self_eq = col("c2").eq(col("c2"));
    let nullable_expected = col("c2").is_not_null().or(lit_bool_null());

    // `c2_non_null` is non-nullable, so comparing it with itself is always true.
    let non_null_self_eq = col("c2_non_null").eq(col("c2_non_null"));
    let non_null_expected = lit(true);

    assert_eq!(simplify(nullable_self_eq), nullable_expected);
    assert_eq!(simplify(non_null_self_eq), non_null_expected);
}
2199
#[test]
fn test_simplify_or_true() {
    // A OR true --> true, whichever side the literal is on
    let expected = lit(true);
    for expr in [col("c2").or(lit(true)), lit(true).or(col("c2"))] {
        assert_eq!(simplify(expr), expected);
    }
}
2209
#[test]
fn test_simplify_or_false() {
    // A OR false --> A, whichever side the literal is on
    let expected = col("c2");
    for expr in [lit(false).or(col("c2")), col("c2").or(lit(false))] {
        assert_eq!(simplify(expr), expected);
    }
}
2219
#[test]
fn test_simplify_or_same() {
    // A OR A --> A
    let duplicated = col("c2").or(col("c2"));
    assert_eq!(simplify(duplicated), col("c2"));
}
2227
#[test]
fn test_simplify_or_not_self() {
    // A OR !A --> true and !A OR A --> true, when A is not nullable
    let expected = lit(true);
    let inputs = [
        col("c2_non_null").or(col("c2_non_null").not()),
        col("c2_non_null").not().or(col("c2_non_null")),
    ];
    for expr in inputs {
        assert_eq!(simplify(expr), expected);
    }
}
2239
#[test]
fn test_simplify_and_false() {
    // A AND false --> false, whichever side the literal is on
    let expected = lit(false);
    for expr in [lit(false).and(col("c2")), col("c2").and(lit(false))] {
        assert_eq!(simplify(expr), expected);
    }
}
2249
#[test]
fn test_simplify_and_same() {
    // A AND A --> A
    let duplicated = col("c2").and(col("c2"));
    assert_eq!(simplify(duplicated), col("c2"));
}
2257
#[test]
fn test_simplify_and_true() {
    // A AND true --> A, whichever side the literal is on
    let expected = col("c2");
    for expr in [lit(true).and(col("c2")), col("c2").and(lit(true))] {
        assert_eq!(simplify(expr), expected);
    }
}
2267
#[test]
fn test_simplify_and_not_self() {
    // A AND !A --> false and !A AND A --> false, when A is not nullable
    let expected = lit(false);
    let inputs = [
        col("c2_non_null").and(col("c2_non_null").not()),
        col("c2_non_null").not().and(col("c2_non_null")),
    ];
    for expr in inputs {
        assert_eq!(simplify(expr), expected);
    }
}
2279
#[test]
fn test_simplify_multiply_by_one() {
    // A * 1 --> A and 1 * A --> A, for integer and Decimal128 ones alike
    let expected = col("c2");
    let inputs = [
        col("c2") * lit(1),
        lit(1) * col("c2"),
        col("c2") * lit(ScalarValue::Decimal128(Some(10000000000), 38, 10)),
        lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)) * col("c2"),
    ];
    for expr in inputs {
        assert_eq!(simplify(expr), expected);
    }
}
2295
#[test]
fn test_simplify_multiply_by_null() {
    let null = lit(ScalarValue::Int64(None));
    // A * null --> null, then null * A --> null
    let inputs = [col("c3") * null.clone(), null.clone() * col("c3")];
    for expr in inputs {
        assert_eq!(simplify(expr), null);
    }
}
2310
#[test]
fn test_simplify_multiply_by_zero() {
    // cannot optimize A * 0 (0 * A) if A is nullable
    for expr in [col("c2") * lit(0), lit(0) * col("c2")] {
        assert_eq!(simplify(expr.clone()), expr);
    }

    // 0 * A --> 0 and A * 0 --> 0 if A is not nullable
    for expr in [lit(0) * col("c2_non_null"), col("c2_non_null") * lit(0)] {
        assert_eq!(simplify(expr), lit(0));
    }

    // A * Decimal128(0) --> Decimal128(0) if A is not nullable, in either order
    let decimal_zero = lit(ScalarValue::Decimal128(Some(0), 31, 10));
    let inputs = [
        col("c2_non_null") * decimal_zero.clone(),
        binary_expr(
            decimal_zero.clone(),
            Operator::Multiply,
            col("c2_non_null"),
        ),
    ];
    for expr in inputs {
        assert_eq!(simplify(expr), decimal_zero);
    }
}
2349
#[test]
fn test_simplify_divide_by_one() {
    // A / 1 --> A, for an integer one and a Decimal128 one
    let expected = col("c2");
    let inputs = [
        binary_expr(col("c2"), Operator::Divide, lit(1)),
        col("c2") / lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)),
    ];
    for expr in inputs {
        assert_eq!(simplify(expr), expected);
    }
}
2358
2359    #[test]
2360    fn test_simplify_divide_null() {
2361        // A / null --> null
2362        let null = lit(ScalarValue::Int64(None));
2363        {
2364            let expr = col("c3") / null.clone();
2365            assert_eq!(simplify(expr), null);
2366        }
2367        // null / A --> null
2368        {
2369            let expr = null.clone() / col("c3");
2370            assert_eq!(simplify(expr), null);
2371        }
2372    }
2373
2374    #[test]
2375    fn test_simplify_divide_by_same() {
2376        let expr = col("c2") / col("c2");
2377        // if c2 is null, c2 / c2 = null, so can't simplify
2378        let expected = expr.clone();
2379
2380        assert_eq!(simplify(expr), expected);
2381    }
2382
2383    #[test]
2384    fn test_simplify_modulo_by_null() {
2385        let null = lit(ScalarValue::Int64(None));
2386        // A % null --> null
2387        {
2388            let expr = col("c3") % null.clone();
2389            assert_eq!(simplify(expr), null);
2390        }
2391        // null % A --> null
2392        {
2393            let expr = null.clone() % col("c3");
2394            assert_eq!(simplify(expr), null);
2395        }
2396    }
2397
2398    #[test]
2399    fn test_simplify_modulo_by_one() {
2400        let expr = col("c2") % lit(1);
2401        // if c2 is null, c2 % 1 = null, so can't simplify
2402        let expected = expr.clone();
2403
2404        assert_eq!(simplify(expr), expected);
2405    }
2406
2407    #[test]
2408    fn test_simplify_divide_zero_by_zero() {
2409        // because divide by 0 maybe occur in short-circuit expression
2410        // so we should not simplify this, and throw error in runtime
2411        let expr = lit(0) / lit(0);
2412        let expected = expr.clone();
2413
2414        assert_eq!(simplify(expr), expected);
2415    }
2416
2417    #[test]
2418    fn test_simplify_divide_by_zero() {
2419        // because divide by 0 maybe occur in short-circuit expression
2420        // so we should not simplify this, and throw error in runtime
2421        let expr = col("c2_non_null") / lit(0);
2422        let expected = expr.clone();
2423
2424        assert_eq!(simplify(expr), expected);
2425    }
2426
2427    #[test]
2428    fn test_simplify_modulo_by_one_non_null() {
2429        let expr = col("c3_non_null") % lit(1);
2430        let expected = lit(0_i64);
2431        assert_eq!(simplify(expr), expected);
2432        let expr =
2433            col("c3_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2434        assert_eq!(simplify(expr), expected);
2435    }
2436
2437    #[test]
2438    fn test_simplify_bitwise_xor_by_null() {
2439        let null = lit(ScalarValue::Int64(None));
2440        // A ^ null --> null
2441        {
2442            let expr = col("c3") ^ null.clone();
2443            assert_eq!(simplify(expr), null);
2444        }
2445        // null ^ A --> null
2446        {
2447            let expr = null.clone() ^ col("c3");
2448            assert_eq!(simplify(expr), null);
2449        }
2450    }
2451
2452    #[test]
2453    fn test_simplify_bitwise_shift_right_by_null() {
2454        let null = lit(ScalarValue::Int64(None));
2455        // A >> null --> null
2456        {
2457            let expr = col("c3") >> null.clone();
2458            assert_eq!(simplify(expr), null);
2459        }
2460        // null >> A --> null
2461        {
2462            let expr = null.clone() >> col("c3");
2463            assert_eq!(simplify(expr), null);
2464        }
2465    }
2466
2467    #[test]
2468    fn test_simplify_bitwise_shift_left_by_null() {
2469        let null = lit(ScalarValue::Int64(None));
2470        // A << null --> null
2471        {
2472            let expr = col("c3") << null.clone();
2473            assert_eq!(simplify(expr), null);
2474        }
2475        // null << A --> null
2476        {
2477            let expr = null.clone() << col("c3");
2478            assert_eq!(simplify(expr), null);
2479        }
2480    }
2481
2482    #[test]
2483    fn test_simplify_bitwise_and_by_zero() {
2484        // A & 0 --> 0
2485        {
2486            let expr = col("c2_non_null") & lit(0);
2487            assert_eq!(simplify(expr), lit(0));
2488        }
2489        // 0 & A --> 0
2490        {
2491            let expr = lit(0) & col("c2_non_null");
2492            assert_eq!(simplify(expr), lit(0));
2493        }
2494    }
2495
2496    #[test]
2497    fn test_simplify_bitwise_or_by_zero() {
2498        // A | 0 --> A
2499        {
2500            let expr = col("c2_non_null") | lit(0);
2501            assert_eq!(simplify(expr), col("c2_non_null"));
2502        }
2503        // 0 | A --> A
2504        {
2505            let expr = lit(0) | col("c2_non_null");
2506            assert_eq!(simplify(expr), col("c2_non_null"));
2507        }
2508    }
2509
2510    #[test]
2511    fn test_simplify_bitwise_xor_by_zero() {
2512        // A ^ 0 --> A
2513        {
2514            let expr = col("c2_non_null") ^ lit(0);
2515            assert_eq!(simplify(expr), col("c2_non_null"));
2516        }
2517        // 0 ^ A --> A
2518        {
2519            let expr = lit(0) ^ col("c2_non_null");
2520            assert_eq!(simplify(expr), col("c2_non_null"));
2521        }
2522    }
2523
2524    #[test]
2525    fn test_simplify_bitwise_bitwise_shift_right_by_zero() {
2526        // A >> 0 --> A
2527        {
2528            let expr = col("c2_non_null") >> lit(0);
2529            assert_eq!(simplify(expr), col("c2_non_null"));
2530        }
2531    }
2532
2533    #[test]
2534    fn test_simplify_bitwise_bitwise_shift_left_by_zero() {
2535        // A << 0 --> A
2536        {
2537            let expr = col("c2_non_null") << lit(0);
2538            assert_eq!(simplify(expr), col("c2_non_null"));
2539        }
2540    }
2541
2542    #[test]
2543    fn test_simplify_bitwise_and_by_null() {
2544        let null = Expr::Literal(ScalarValue::Int64(None), None);
2545        // A & null --> null
2546        {
2547            let expr = col("c3") & null.clone();
2548            assert_eq!(simplify(expr), null);
2549        }
2550        // null & A --> null
2551        {
2552            let expr = null.clone() & col("c3");
2553            assert_eq!(simplify(expr), null);
2554        }
2555    }
2556
2557    #[test]
2558    fn test_simplify_composed_bitwise_and() {
2559        // ((c2 > 5) & (c1 < 6)) & (c2 > 5) --> (c2 > 5) & (c1 < 6)
2560
2561        let expr = bitwise_and(
2562            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2563            col("c2").gt(lit(5)),
2564        );
2565        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2566
2567        assert_eq!(simplify(expr), expected);
2568
2569        // (c2 > 5) & ((c2 > 5) & (c1 < 6)) --> (c2 > 5) & (c1 < 6)
2570
2571        let expr = bitwise_and(
2572            col("c2").gt(lit(5)),
2573            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2574        );
2575        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2576        assert_eq!(simplify(expr), expected);
2577    }
2578
2579    #[test]
2580    fn test_simplify_composed_bitwise_or() {
2581        // ((c2 > 5) | (c1 < 6)) | (c2 > 5) --> (c2 > 5) | (c1 < 6)
2582
2583        let expr = bitwise_or(
2584            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2585            col("c2").gt(lit(5)),
2586        );
2587        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2588
2589        assert_eq!(simplify(expr), expected);
2590
2591        // (c2 > 5) | ((c2 > 5) | (c1 < 6)) --> (c2 > 5) | (c1 < 6)
2592
2593        let expr = bitwise_or(
2594            col("c2").gt(lit(5)),
2595            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2596        );
2597        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2598
2599        assert_eq!(simplify(expr), expected);
2600    }
2601
2602    #[test]
2603    fn test_simplify_composed_bitwise_xor() {
2604        // with an even number of the column "c2"
2605        // c2 ^ ((c2 ^ (c2 | c1)) ^ (c1 & c2)) --> (c2 | c1) ^ (c1 & c2)
2606
2607        let expr = bitwise_xor(
2608            col("c2"),
2609            bitwise_xor(
2610                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2611                bitwise_and(col("c1"), col("c2")),
2612            ),
2613        );
2614
2615        let expected = bitwise_xor(
2616            bitwise_or(col("c2"), col("c1")),
2617            bitwise_and(col("c1"), col("c2")),
2618        );
2619
2620        assert_eq!(simplify(expr), expected);
2621
2622        // with an odd number of the column "c2"
2623        // c2 ^ (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) --> c2 ^ ((c2 | c1) ^ (c1 & c2))
2624
2625        let expr = bitwise_xor(
2626            col("c2"),
2627            bitwise_xor(
2628                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2629                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2630            ),
2631        );
2632
2633        let expected = bitwise_xor(
2634            col("c2"),
2635            bitwise_xor(
2636                bitwise_or(col("c2"), col("c1")),
2637                bitwise_and(col("c1"), col("c2")),
2638            ),
2639        );
2640
2641        assert_eq!(simplify(expr), expected);
2642
2643        // with an even number of the column "c2"
2644        // ((c2 ^ (c2 | c1)) ^ (c1 & c2)) ^ c2 --> (c2 | c1) ^ (c1 & c2)
2645
2646        let expr = bitwise_xor(
2647            bitwise_xor(
2648                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2649                bitwise_and(col("c1"), col("c2")),
2650            ),
2651            col("c2"),
2652        );
2653
2654        let expected = bitwise_xor(
2655            bitwise_or(col("c2"), col("c1")),
2656            bitwise_and(col("c1"), col("c2")),
2657        );
2658
2659        assert_eq!(simplify(expr), expected);
2660
2661        // with an odd number of the column "c2"
2662        // (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) ^ c2 --> ((c2 | c1) ^ (c1 & c2)) ^ c2
2663
2664        let expr = bitwise_xor(
2665            bitwise_xor(
2666                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2667                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2668            ),
2669            col("c2"),
2670        );
2671
2672        let expected = bitwise_xor(
2673            bitwise_xor(
2674                bitwise_or(col("c2"), col("c1")),
2675                bitwise_and(col("c1"), col("c2")),
2676            ),
2677            col("c2"),
2678        );
2679
2680        assert_eq!(simplify(expr), expected);
2681    }
2682
2683    #[test]
2684    fn test_simplify_negated_bitwise_and() {
2685        // !c4 & c4 --> 0
2686        let expr = (-col("c4_non_null")) & col("c4_non_null");
2687        let expected = lit(0u32);
2688
2689        assert_eq!(simplify(expr), expected);
2690        // c4 & !c4 --> 0
2691        let expr = col("c4_non_null") & (-col("c4_non_null"));
2692        let expected = lit(0u32);
2693
2694        assert_eq!(simplify(expr), expected);
2695
2696        // !c3 & c3 --> 0
2697        let expr = (-col("c3_non_null")) & col("c3_non_null");
2698        let expected = lit(0i64);
2699
2700        assert_eq!(simplify(expr), expected);
2701        // c3 & !c3 --> 0
2702        let expr = col("c3_non_null") & (-col("c3_non_null"));
2703        let expected = lit(0i64);
2704
2705        assert_eq!(simplify(expr), expected);
2706    }
2707
2708    #[test]
2709    fn test_simplify_negated_bitwise_or() {
2710        // !c4 | c4 --> -1
2711        let expr = (-col("c4_non_null")) | col("c4_non_null");
2712        let expected = lit(-1i32);
2713
2714        assert_eq!(simplify(expr), expected);
2715
2716        // c4 | !c4 --> -1
2717        let expr = col("c4_non_null") | (-col("c4_non_null"));
2718        let expected = lit(-1i32);
2719
2720        assert_eq!(simplify(expr), expected);
2721
2722        // !c3 | c3 --> -1
2723        let expr = (-col("c3_non_null")) | col("c3_non_null");
2724        let expected = lit(-1i64);
2725
2726        assert_eq!(simplify(expr), expected);
2727
2728        // c3 | !c3 --> -1
2729        let expr = col("c3_non_null") | (-col("c3_non_null"));
2730        let expected = lit(-1i64);
2731
2732        assert_eq!(simplify(expr), expected);
2733    }
2734
2735    #[test]
2736    fn test_simplify_negated_bitwise_xor() {
2737        // !c4 ^ c4 --> -1
2738        let expr = (-col("c4_non_null")) ^ col("c4_non_null");
2739        let expected = lit(-1i32);
2740
2741        assert_eq!(simplify(expr), expected);
2742
2743        // c4 ^ !c4 --> -1
2744        let expr = col("c4_non_null") ^ (-col("c4_non_null"));
2745        let expected = lit(-1i32);
2746
2747        assert_eq!(simplify(expr), expected);
2748
2749        // !c3 ^ c3 --> -1
2750        let expr = (-col("c3_non_null")) ^ col("c3_non_null");
2751        let expected = lit(-1i64);
2752
2753        assert_eq!(simplify(expr), expected);
2754
2755        // c3 ^ !c3 --> -1
2756        let expr = col("c3_non_null") ^ (-col("c3_non_null"));
2757        let expected = lit(-1i64);
2758
2759        assert_eq!(simplify(expr), expected);
2760    }
2761
2762    #[test]
2763    fn test_simplify_bitwise_and_or() {
2764        // (c2 < 3) & ((c2 < 3) | c1) -> (c2 < 3)
2765        let expr = bitwise_and(
2766            col("c2_non_null").lt(lit(3)),
2767            bitwise_or(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2768        );
2769        let expected = col("c2_non_null").lt(lit(3));
2770
2771        assert_eq!(simplify(expr), expected);
2772    }
2773
2774    #[test]
2775    fn test_simplify_bitwise_or_and() {
2776        // (c2 < 3) | ((c2 < 3) & c1) -> (c2 < 3)
2777        let expr = bitwise_or(
2778            col("c2_non_null").lt(lit(3)),
2779            bitwise_and(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2780        );
2781        let expected = col("c2_non_null").lt(lit(3));
2782
2783        assert_eq!(simplify(expr), expected);
2784    }
2785
2786    #[test]
2787    fn test_simplify_simple_bitwise_and() {
2788        // (c2 > 5) & (c2 > 5) -> (c2 > 5)
2789        let expr = (col("c2").gt(lit(5))).bitand(col("c2").gt(lit(5)));
2790        let expected = col("c2").gt(lit(5));
2791
2792        assert_eq!(simplify(expr), expected);
2793    }
2794
2795    #[test]
2796    fn test_simplify_simple_bitwise_or() {
2797        // (c2 > 5) | (c2 > 5) -> (c2 > 5)
2798        let expr = (col("c2").gt(lit(5))).bitor(col("c2").gt(lit(5)));
2799        let expected = col("c2").gt(lit(5));
2800
2801        assert_eq!(simplify(expr), expected);
2802    }
2803
2804    #[test]
2805    fn test_simplify_simple_bitwise_xor() {
2806        // c4 ^ c4 -> 0
2807        let expr = (col("c4")).bitxor(col("c4"));
2808        let expected = lit(0u32);
2809
2810        assert_eq!(simplify(expr), expected);
2811
2812        // c3 ^ c3 -> 0
2813        let expr = col("c3").bitxor(col("c3"));
2814        let expected = lit(0i64);
2815
2816        assert_eq!(simplify(expr), expected);
2817    }
2818
2819    #[test]
2820    fn test_simplify_modulo_by_zero_non_null() {
2821        // because modulo by 0 maybe occur in short-circuit expression
2822        // so we should not simplify this, and throw error in runtime.
2823        let expr = col("c2_non_null") % lit(0);
2824        let expected = expr.clone();
2825
2826        assert_eq!(simplify(expr), expected);
2827    }
2828
2829    #[test]
2830    fn test_simplify_simple_and() {
2831        // (c2 > 5) AND (c2 > 5) -> (c2 > 5)
2832        let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5)));
2833        let expected = col("c2").gt(lit(5));
2834
2835        assert_eq!(simplify(expr), expected);
2836    }
2837
2838    #[test]
2839    fn test_simplify_composed_and() {
2840        // ((c2 > 5) AND (c1 < 6)) AND (c2 > 5)
2841        let expr = and(
2842            and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2843            col("c2").gt(lit(5)),
2844        );
2845        let expected = and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2846
2847        assert_eq!(simplify(expr), expected);
2848    }
2849
2850    #[test]
2851    fn test_simplify_negated_and() {
2852        // (c2 > 5) AND !(c2 > 5) --> (c2 > 5) AND (c2 <= 5)
2853        let expr = and(col("c2").gt(lit(5)), Expr::not(col("c2").gt(lit(5))));
2854        let expected = col("c2").gt(lit(5)).and(col("c2").lt_eq(lit(5)));
2855
2856        assert_eq!(simplify(expr), expected);
2857    }
2858
2859    #[test]
2860    fn test_simplify_or_and() {
2861        let l = col("c2").gt(lit(5));
2862        let r = and(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
2863
2864        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5))
2865        let expr = or(l.clone(), r.clone());
2866
2867        let expected = l.clone();
2868        assert_eq!(simplify(expr), expected);
2869
2870        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5)
2871        let expr = or(r, l);
2872        assert_eq!(simplify(expr), expected);
2873    }
2874
2875    #[test]
2876    fn test_simplify_or_and_non_null() {
2877        let l = col("c2_non_null").gt(lit(5));
2878        let r = and(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
2879
2880        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) --> c2 > 5
2881        let expr = or(l.clone(), r.clone());
2882
2883        // This is only true if `c1 < 6` is not nullable / can not be null.
2884        let expected = col("c2_non_null").gt(lit(5));
2885
2886        assert_eq!(simplify(expr), expected);
2887
2888        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) --> c2 > 5
2889        let expr = or(l, r);
2890
2891        assert_eq!(simplify(expr), expected);
2892    }
2893
2894    #[test]
2895    fn test_simplify_and_or() {
2896        let l = col("c2").gt(lit(5));
2897        let r = or(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
2898
2899        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
2900        let expr = and(l.clone(), r.clone());
2901
2902        let expected = l.clone();
2903        assert_eq!(simplify(expr), expected);
2904
2905        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
2906        let expr = and(r, l);
2907        assert_eq!(simplify(expr), expected);
2908    }
2909
2910    #[test]
2911    fn test_simplify_and_or_non_null() {
2912        let l = col("c2_non_null").gt(lit(5));
2913        let r = or(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
2914
2915        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
2916        let expr = and(l.clone(), r.clone());
2917
2918        // This is only true if `c1 < 6` is not nullable / can not be null.
2919        let expected = col("c2_non_null").gt(lit(5));
2920
2921        assert_eq!(simplify(expr), expected);
2922
2923        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
2924        let expr = and(l, r);
2925
2926        assert_eq!(simplify(expr), expected);
2927    }
2928
2929    #[test]
2930    fn test_simplify_by_de_morgan_laws() {
2931        // Laws with logical operations
2932        // !(c3 AND c4) --> !c3 OR !c4
2933        let expr = and(col("c3"), col("c4")).not();
2934        let expected = or(col("c3").not(), col("c4").not());
2935        assert_eq!(simplify(expr), expected);
2936        // !(c3 OR c4) --> !c3 AND !c4
2937        let expr = or(col("c3"), col("c4")).not();
2938        let expected = and(col("c3").not(), col("c4").not());
2939        assert_eq!(simplify(expr), expected);
2940        // !(!c3) --> c3
2941        let expr = col("c3").not().not();
2942        let expected = col("c3");
2943        assert_eq!(simplify(expr), expected);
2944
2945        // Laws with bitwise operations
2946        // !(c3 & c4) --> !c3 | !c4
2947        let expr = -bitwise_and(col("c3"), col("c4"));
2948        let expected = bitwise_or(-col("c3"), -col("c4"));
2949        assert_eq!(simplify(expr), expected);
2950        // !(c3 | c4) --> !c3 & !c4
2951        let expr = -bitwise_or(col("c3"), col("c4"));
2952        let expected = bitwise_and(-col("c3"), -col("c4"));
2953        assert_eq!(simplify(expr), expected);
2954        // !(!c3) --> c3
2955        let expr = -(-col("c3"));
2956        let expected = col("c3");
2957        assert_eq!(simplify(expr), expected);
2958    }
2959
2960    #[test]
2961    fn test_simplify_null_and_false() {
2962        let expr = and(lit_bool_null(), lit(false));
2963        let expr_eq = lit(false);
2964
2965        assert_eq!(simplify(expr), expr_eq);
2966    }
2967
2968    #[test]
2969    fn test_simplify_divide_null_by_null() {
2970        let null = lit(ScalarValue::Int32(None));
2971        let expr_plus = null.clone() / null.clone();
2972        let expr_eq = null;
2973
2974        assert_eq!(simplify(expr_plus), expr_eq);
2975    }
2976
2977    #[test]
2978    fn test_simplify_simplify_arithmetic_expr() {
2979        let expr_plus = lit(1) + lit(1);
2980
2981        assert_eq!(simplify(expr_plus), lit(2));
2982    }
2983
2984    #[test]
2985    fn test_simplify_simplify_eq_expr() {
2986        let expr_eq = binary_expr(lit(1), Operator::Eq, lit(1));
2987
2988        assert_eq!(simplify(expr_eq), lit(true));
2989    }
2990
2991    #[test]
2992    fn test_simplify_regex() {
2993        // malformed regex
2994        assert_contains!(
2995            try_simplify(regex_match(col("c1"), lit("foo{")))
2996                .unwrap_err()
2997                .to_string(),
2998            "regex parse error"
2999        );
3000
3001        // unsupported cases
3002        assert_no_change(regex_match(col("c1"), lit("foo.*")));
3003        assert_no_change(regex_match(col("c1"), lit("(foo)")));
3004        assert_no_change(regex_match(col("c1"), lit("%")));
3005        assert_no_change(regex_match(col("c1"), lit("_")));
3006        assert_no_change(regex_match(col("c1"), lit("f%o")));
3007        assert_no_change(regex_match(col("c1"), lit("^f%o")));
3008        assert_no_change(regex_match(col("c1"), lit("f_o")));
3009
3010        // empty cases
3011        assert_change(
3012            regex_match(col("c1"), lit("")),
3013            if_not_null(col("c1"), true),
3014        );
3015        assert_change(
3016            regex_not_match(col("c1"), lit("")),
3017            if_not_null(col("c1"), false),
3018        );
3019        assert_change(
3020            regex_imatch(col("c1"), lit("")),
3021            if_not_null(col("c1"), true),
3022        );
3023        assert_change(
3024            regex_not_imatch(col("c1"), lit("")),
3025            if_not_null(col("c1"), false),
3026        );
3027
3028        // single character
3029        assert_change(regex_match(col("c1"), lit("x")), col("c1").like(lit("%x%")));
3030
3031        // single word
3032        assert_change(
3033            regex_match(col("c1"), lit("foo")),
3034            col("c1").like(lit("%foo%")),
3035        );
3036
3037        // regular expressions that match an exact literal
3038        assert_change(regex_match(col("c1"), lit("^$")), col("c1").eq(lit("")));
3039        assert_change(
3040            regex_not_match(col("c1"), lit("^$")),
3041            col("c1").not_eq(lit("")),
3042        );
3043        assert_change(
3044            regex_match(col("c1"), lit("^foo$")),
3045            col("c1").eq(lit("foo")),
3046        );
3047        assert_change(
3048            regex_not_match(col("c1"), lit("^foo$")),
3049            col("c1").not_eq(lit("foo")),
3050        );
3051
3052        // regular expressions that match exact captured literals
3053        assert_change(
3054            regex_match(col("c1"), lit("^(foo|bar)$")),
3055            col("c1").eq(lit("foo")).or(col("c1").eq(lit("bar"))),
3056        );
3057        assert_change(
3058            regex_not_match(col("c1"), lit("^(foo|bar)$")),
3059            col("c1")
3060                .not_eq(lit("foo"))
3061                .and(col("c1").not_eq(lit("bar"))),
3062        );
3063        assert_change(
3064            regex_match(col("c1"), lit("^(foo)$")),
3065            col("c1").eq(lit("foo")),
3066        );
3067        assert_change(
3068            regex_match(col("c1"), lit("^(foo|bar|baz)$")),
3069            ((col("c1").eq(lit("foo"))).or(col("c1").eq(lit("bar"))))
3070                .or(col("c1").eq(lit("baz"))),
3071        );
3072        assert_change(
3073            regex_match(col("c1"), lit("^(foo|bar|baz|qux)$")),
3074            col("c1")
3075                .in_list(vec![lit("foo"), lit("bar"), lit("baz"), lit("qux")], false),
3076        );
3077        assert_change(
3078            regex_match(col("c1"), lit("^(fo_o)$")),
3079            col("c1").eq(lit("fo_o")),
3080        );
3081        assert_change(
3082            regex_match(col("c1"), lit("^(fo_o)$")),
3083            col("c1").eq(lit("fo_o")),
3084        );
3085        assert_change(
3086            regex_match(col("c1"), lit("^(fo_o|ba_r)$")),
3087            col("c1").eq(lit("fo_o")).or(col("c1").eq(lit("ba_r"))),
3088        );
3089        assert_change(
3090            regex_not_match(col("c1"), lit("^(fo_o|ba_r)$")),
3091            col("c1")
3092                .not_eq(lit("fo_o"))
3093                .and(col("c1").not_eq(lit("ba_r"))),
3094        );
3095        assert_change(
3096            regex_match(col("c1"), lit("^(fo_o|ba_r|ba_z)$")),
3097            ((col("c1").eq(lit("fo_o"))).or(col("c1").eq(lit("ba_r"))))
3098                .or(col("c1").eq(lit("ba_z"))),
3099        );
3100        assert_change(
3101            regex_match(col("c1"), lit("^(fo_o|ba_r|baz|qu_x)$")),
3102            col("c1").in_list(
3103                vec![lit("fo_o"), lit("ba_r"), lit("baz"), lit("qu_x")],
3104                false,
3105            ),
3106        );
3107
3108        // regular expressions that mismatch captured literals
3109        assert_no_change(regex_match(col("c1"), lit("(foo|bar)")));
3110        assert_no_change(regex_match(col("c1"), lit("(foo|bar)*")));
3111        assert_no_change(regex_match(col("c1"), lit("(fo_o|b_ar)")));
3112        assert_no_change(regex_match(col("c1"), lit("(foo|ba_r)*")));
3113        assert_no_change(regex_match(col("c1"), lit("(fo_o|ba_r)*")));
3114        assert_no_change(regex_match(col("c1"), lit("^(foo|bar)*")));
3115        assert_no_change(regex_match(col("c1"), lit("^(foo)(bar)$")));
3116        assert_no_change(regex_match(col("c1"), lit("^")));
3117        assert_no_change(regex_match(col("c1"), lit("$")));
3118        assert_no_change(regex_match(col("c1"), lit("$^")));
3119        assert_no_change(regex_match(col("c1"), lit("$foo^")));
3120
3121        // regular expressions that match a partial literal
3122        assert_change(
3123            regex_match(col("c1"), lit("^foo")),
3124            col("c1").like(lit("foo%")),
3125        );
3126        assert_change(
3127            regex_match(col("c1"), lit("foo$")),
3128            col("c1").like(lit("%foo")),
3129        );
3130        assert_change(
3131            regex_match(col("c1"), lit("^foo|bar$")),
3132            col("c1").like(lit("foo%")).or(col("c1").like(lit("%bar"))),
3133        );
3134
3135        // OR-chain
3136        assert_change(
3137            regex_match(col("c1"), lit("foo|bar|baz")),
3138            col("c1")
3139                .like(lit("%foo%"))
3140                .or(col("c1").like(lit("%bar%")))
3141                .or(col("c1").like(lit("%baz%"))),
3142        );
3143        assert_change(
3144            regex_match(col("c1"), lit("foo|x|baz")),
3145            col("c1")
3146                .like(lit("%foo%"))
3147                .or(col("c1").like(lit("%x%")))
3148                .or(col("c1").like(lit("%baz%"))),
3149        );
3150        assert_change(
3151            regex_not_match(col("c1"), lit("foo|bar|baz")),
3152            col("c1")
3153                .not_like(lit("%foo%"))
3154                .and(col("c1").not_like(lit("%bar%")))
3155                .and(col("c1").not_like(lit("%baz%"))),
3156        );
3157        // both anchored expressions (translated to equality) and unanchored
3158        assert_change(
3159            regex_match(col("c1"), lit("foo|^x$|baz")),
3160            col("c1")
3161                .like(lit("%foo%"))
3162                .or(col("c1").eq(lit("x")))
3163                .or(col("c1").like(lit("%baz%"))),
3164        );
3165        assert_change(
3166            regex_not_match(col("c1"), lit("foo|^bar$|baz")),
3167            col("c1")
3168                .not_like(lit("%foo%"))
3169                .and(col("c1").not_eq(lit("bar")))
3170                .and(col("c1").not_like(lit("%baz%"))),
3171        );
3172        // Too many patterns (MAX_REGEX_ALTERNATIONS_EXPANSION)
3173        assert_no_change(regex_match(col("c1"), lit("foo|bar|baz|blarg|bozo|etc")));
3174    }
3175
3176    #[track_caller]
3177    fn assert_no_change(expr: Expr) {
3178        let optimized = simplify(expr.clone());
3179        assert_eq!(expr, optimized);
3180    }
3181
3182    #[track_caller]
3183    fn assert_change(expr: Expr, expected: Expr) {
3184        let optimized = simplify(expr);
3185        assert_eq!(optimized, expected);
3186    }
3187
3188    fn regex_match(left: Expr, right: Expr) -> Expr {
3189        Expr::BinaryExpr(BinaryExpr {
3190            left: Box::new(left),
3191            op: Operator::RegexMatch,
3192            right: Box::new(right),
3193        })
3194    }
3195
3196    fn regex_not_match(left: Expr, right: Expr) -> Expr {
3197        Expr::BinaryExpr(BinaryExpr {
3198            left: Box::new(left),
3199            op: Operator::RegexNotMatch,
3200            right: Box::new(right),
3201        })
3202    }
3203
3204    fn regex_imatch(left: Expr, right: Expr) -> Expr {
3205        Expr::BinaryExpr(BinaryExpr {
3206            left: Box::new(left),
3207            op: Operator::RegexIMatch,
3208            right: Box::new(right),
3209        })
3210    }
3211
3212    fn regex_not_imatch(left: Expr, right: Expr) -> Expr {
3213        Expr::BinaryExpr(BinaryExpr {
3214            left: Box::new(left),
3215            op: Operator::RegexNotIMatch,
3216            right: Box::new(right),
3217        })
3218    }
3219
3220    // ------------------------------
3221    // ----- Simplifier tests -------
3222    // ------------------------------
3223
3224    fn try_simplify(expr: Expr) -> Result<Expr> {
3225        let schema = expr_test_schema();
3226        let execution_props = ExecutionProps::new();
3227        let simplifier = ExprSimplifier::new(
3228            SimplifyContext::new(&execution_props).with_schema(schema),
3229        );
3230        simplifier.simplify(expr)
3231    }
3232
3233    fn coerce(expr: Expr) -> Expr {
3234        let schema = expr_test_schema();
3235        let execution_props = ExecutionProps::new();
3236        let simplifier = ExprSimplifier::new(
3237            SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)),
3238        );
3239        simplifier.coerce(expr, schema.as_ref()).unwrap()
3240    }
3241
3242    fn simplify(expr: Expr) -> Expr {
3243        try_simplify(expr).unwrap()
3244    }
3245
3246    fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
3247        let schema = expr_test_schema();
3248        let execution_props = ExecutionProps::new();
3249        let simplifier = ExprSimplifier::new(
3250            SimplifyContext::new(&execution_props).with_schema(schema),
3251        );
3252        let (expr, count) = simplifier.simplify_with_cycle_count_transformed(expr)?;
3253        Ok((expr.data, count))
3254    }
3255
3256    fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
3257        try_simplify_with_cycle_count(expr).unwrap()
3258    }
3259
3260    fn simplify_with_guarantee(
3261        expr: Expr,
3262        guarantees: Vec<(Expr, NullableInterval)>,
3263    ) -> Expr {
3264        let schema = expr_test_schema();
3265        let execution_props = ExecutionProps::new();
3266        let simplifier = ExprSimplifier::new(
3267            SimplifyContext::new(&execution_props).with_schema(schema),
3268        )
3269        .with_guarantees(guarantees);
3270        simplifier.simplify(expr).unwrap()
3271    }
3272
3273    fn expr_test_schema() -> DFSchemaRef {
3274        Arc::new(
3275            DFSchema::from_unqualified_fields(
3276                vec![
3277                    Field::new("c1", DataType::Utf8, true),
3278                    Field::new("c2", DataType::Boolean, true),
3279                    Field::new("c3", DataType::Int64, true),
3280                    Field::new("c4", DataType::UInt32, true),
3281                    Field::new("c1_non_null", DataType::Utf8, false),
3282                    Field::new("c2_non_null", DataType::Boolean, false),
3283                    Field::new("c3_non_null", DataType::Int64, false),
3284                    Field::new("c4_non_null", DataType::UInt32, false),
3285                    Field::new("c5", DataType::FixedSizeBinary(3), true),
3286                ]
3287                .into(),
3288                HashMap::new(),
3289            )
3290            .unwrap(),
3291        )
3292    }
3293
3294    #[test]
3295    fn simplify_expr_null_comparison() {
3296        // x = null is always null
3297        assert_eq!(
3298            simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))),
3299            lit(ScalarValue::Boolean(None)),
3300        );
3301
3302        // null != null is always null
3303        assert_eq!(
3304            simplify(
3305                lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))
3306            ),
3307            lit(ScalarValue::Boolean(None)),
3308        );
3309
3310        // x != null is always null
3311        assert_eq!(
3312            simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))),
3313            lit(ScalarValue::Boolean(None)),
3314        );
3315
3316        // null = x is always null
3317        assert_eq!(
3318            simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))),
3319            lit(ScalarValue::Boolean(None)),
3320        );
3321    }
3322
3323    #[test]
3324    fn simplify_expr_is_not_null() {
3325        assert_eq!(
3326            simplify(Expr::IsNotNull(Box::new(col("c1")))),
3327            Expr::IsNotNull(Box::new(col("c1")))
3328        );
3329
3330        // 'c1_non_null IS NOT NULL' is always true
3331        assert_eq!(
3332            simplify(Expr::IsNotNull(Box::new(col("c1_non_null")))),
3333            lit(true)
3334        );
3335    }
3336
3337    #[test]
3338    fn simplify_expr_is_null() {
3339        assert_eq!(
3340            simplify(Expr::IsNull(Box::new(col("c1")))),
3341            Expr::IsNull(Box::new(col("c1")))
3342        );
3343
3344        // 'c1_non_null IS NULL' is always false
3345        assert_eq!(
3346            simplify(Expr::IsNull(Box::new(col("c1_non_null")))),
3347            lit(false)
3348        );
3349    }
3350
3351    #[test]
3352    fn simplify_expr_is_unknown() {
3353        assert_eq!(simplify(col("c2").is_unknown()), col("c2").is_unknown(),);
3354
3355        // 'c2_non_null is unknown' is always false
3356        assert_eq!(simplify(col("c2_non_null").is_unknown()), lit(false));
3357    }
3358
3359    #[test]
3360    fn simplify_expr_is_not_known() {
3361        assert_eq!(
3362            simplify(col("c2").is_not_unknown()),
3363            col("c2").is_not_unknown()
3364        );
3365
3366        // 'c2_non_null is not unknown' is always true
3367        assert_eq!(simplify(col("c2_non_null").is_not_unknown()), lit(true));
3368    }
3369
3370    #[test]
3371    fn simplify_expr_eq() {
3372        let schema = expr_test_schema();
3373        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3374
3375        // true = true -> true
3376        assert_eq!(simplify(lit(true).eq(lit(true))), lit(true));
3377
3378        // true = false -> false
3379        assert_eq!(simplify(lit(true).eq(lit(false))), lit(false),);
3380
3381        // c2 = true -> c2
3382        assert_eq!(simplify(col("c2").eq(lit(true))), col("c2"));
3383
3384        // c2 = false => !c2
3385        assert_eq!(simplify(col("c2").eq(lit(false))), col("c2").not(),);
3386    }
3387
3388    #[test]
3389    fn simplify_expr_eq_skip_nonboolean_type() {
3390        let schema = expr_test_schema();
3391
3392        // When one of the operand is not of boolean type, folding the
3393        // other boolean constant will change return type of
3394        // expression to non-boolean.
3395        //
3396        // Make sure c1 column to be used in tests is not boolean type
3397        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3398
3399        // don't fold c1 = foo
3400        assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),);
3401    }
3402
3403    #[test]
3404    fn simplify_expr_not_eq() {
3405        let schema = expr_test_schema();
3406
3407        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3408
3409        // c2 != true -> !c2
3410        assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),);
3411
3412        // c2 != false -> c2
3413        assert_eq!(simplify(col("c2").not_eq(lit(false))), col("c2"),);
3414
3415        // test constant
3416        assert_eq!(simplify(lit(true).not_eq(lit(true))), lit(false),);
3417
3418        assert_eq!(simplify(lit(true).not_eq(lit(false))), lit(true),);
3419    }
3420
3421    #[test]
3422    fn simplify_expr_not_eq_skip_nonboolean_type() {
3423        let schema = expr_test_schema();
3424
3425        // when one of the operand is not of boolean type, folding the
3426        // other boolean constant will change return type of
3427        // expression to non-boolean.
3428        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3429
3430        assert_eq!(
3431            simplify(col("c1").not_eq(lit("foo"))),
3432            col("c1").not_eq(lit("foo")),
3433        );
3434    }
3435
    #[test]
    /// Checks how searched `CASE` expressions (no base expression) whose
    /// branches are boolean are rewritten by the simplifier. Several cases
    /// run `simplify` twice; see the issue linked in each comment.
    fn simplify_expr_case_when_then_else() {
        // CASE WHEN c2_non_null != false THEN "ok" == "not_ok" ELSE c2_non_null == true
        // -->
        // CASE WHEN c2_non_null THEN false ELSE c2_non_null
        // -->
        // false
        assert_eq!(
            simplify(Expr::Case(Case::new(
                None,
                vec![(
                    Box::new(col("c2_non_null").not_eq(lit(false))),
                    Box::new(lit("ok").eq(lit("not_ok"))),
                )],
                Some(Box::new(col("c2_non_null").eq(lit(true)))),
            ))),
            lit(false) // #1716
        );

        // CASE WHEN c2_non_null != false THEN "ok" == "ok" ELSE c2_non_null == true
        // -->
        // CASE WHEN c2_non_null THEN true ELSE c2_non_null
        // -->
        // c2_non_null
        //
        // Need to call simplify 2x due to
        // https://github.com/apache/datafusion/issues/1160
        assert_eq!(
            simplify(simplify(Expr::Case(Case::new(
                None,
                vec![(
                    Box::new(col("c2_non_null").not_eq(lit(false))),
                    Box::new(lit("ok").eq(lit("ok"))),
                )],
                Some(Box::new(col("c2_non_null").eq(lit(true)))),
            )))),
            col("c2_non_null")
        );

        // CASE WHEN ISNULL(c2) THEN true ELSE c2
        // -->
        // ISNULL(c2) OR (c2 IS NOT NULL AND c2)
        //
        // Need to call simplify 2x due to
        // https://github.com/apache/datafusion/issues/1160
        assert_eq!(
            simplify(simplify(Expr::Case(Case::new(
                None,
                vec![(Box::new(col("c2").is_null()), Box::new(lit(true)),)],
                Some(Box::new(col("c2"))),
            )))),
            col("c2")
                .is_null()
                .or(col("c2").is_not_null().and(col("c2")))
        );

        // CASE WHEN c1 then true WHEN c2 then false ELSE true
        // (c1 / c2 stand for c1_non_null / c2_non_null)
        // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE)
        // --> c1 OR (NOT(c1) AND NOT(c2))
        //
        // Need to call simplify 2x due to
        // https://github.com/apache/datafusion/issues/1160
        assert_eq!(
            simplify(simplify(Expr::Case(Case::new(
                None,
                vec![
                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
                ],
                Some(Box::new(lit(true))),
            )))),
            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
        );

        // NOTE(review): this assertion is identical to the previous one
        // (WHEN c1 THEN true WHEN c2 THEN false ELSE true). It used to be
        // labelled as the `WHEN c1 THEN true WHEN c2 THEN true ELSE false`
        // case, but the expression below does not build that variant.
        // TODO confirm whether the other variant was meant to be covered here.
        //
        // Need to call simplify 2x due to
        // https://github.com/apache/datafusion/issues/1160
        assert_eq!(
            simplify(simplify(Expr::Case(Case::new(
                None,
                vec![
                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
                ],
                Some(Box::new(lit(true))),
            )))),
            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
        );

        // CASE WHEN c3 > 0 THEN true END
        // (no ELSE branch: the expected result keeps a NULL literal for the
        // rows where the condition is not true)
        assert_eq!(
            simplify(simplify(Expr::Case(Case::new(
                None,
                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
                None,
            )))),
            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
                col("c3").gt(lit(0_i64)),
                lit(true)
            )
            .and(lit_bool_null()))
        );

        // CASE WHEN c3 > 0 THEN true ELSE false END
        assert_eq!(
            simplify(simplify(Expr::Case(Case::new(
                None,
                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
                Some(Box::new(lit(false))),
            )))),
            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
        );
    }
3554
3555    fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3556        Expr::BinaryExpr(BinaryExpr {
3557            left: Box::new(left.into()),
3558            op: Operator::IsDistinctFrom,
3559            right: Box::new(right.into()),
3560        })
3561    }
3562
3563    fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3564        Expr::BinaryExpr(BinaryExpr {
3565            left: Box::new(left.into()),
3566            op: Operator::IsNotDistinctFrom,
3567            right: Box::new(right.into()),
3568        })
3569    }
3570
3571    #[test]
3572    fn simplify_expr_bool_or() {
3573        // col || true is always true
3574        assert_eq!(simplify(col("c2").or(lit(true))), lit(true),);
3575
3576        // col || false is always col
3577        assert_eq!(simplify(col("c2").or(lit(false))), col("c2"),);
3578
3579        // true || null is always true
3580        assert_eq!(simplify(lit(true).or(lit_bool_null())), lit(true),);
3581
3582        // null || true is always true
3583        assert_eq!(simplify(lit_bool_null().or(lit(true))), lit(true),);
3584
3585        // false || null is always null
3586        assert_eq!(simplify(lit(false).or(lit_bool_null())), lit_bool_null(),);
3587
3588        // null || false is always null
3589        assert_eq!(simplify(lit_bool_null().or(lit(false))), lit_bool_null(),);
3590
3591        // ( c1 BETWEEN Int32(0) AND Int32(10) ) OR Boolean(NULL)
3592        // it can be either NULL or  TRUE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
3593        // and should not be rewritten
3594        let expr = col("c1").between(lit(0), lit(10));
3595        let expr = expr.or(lit_bool_null());
3596        let result = simplify(expr);
3597
3598        let expected_expr = or(
3599            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
3600            lit_bool_null(),
3601        );
3602        assert_eq!(expected_expr, result);
3603    }
3604
    #[test]
    /// Exercises the `IN` / `NOT IN` list rewrites: empty lists, a NULL
    /// needle, short lists expanded into comparisons, scalar subqueries
    /// inside the list, and combinations of several lists over the same
    /// column joined with AND / OR.
    fn simplify_inlist() {
        // An empty list: IN () --> false, NOT IN () --> true
        assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false));
        assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true));

        // null in (...)  --> null
        assert_eq!(
            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], false)),
            lit_bool_null()
        );

        // null not in (...)  --> null
        assert_eq!(
            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], true)),
            lit_bool_null()
        );

        // A single-element list becomes a plain (in)equality
        assert_eq!(
            simplify(in_list(col("c1"), vec![lit(1)], false)),
            col("c1").eq(lit(1))
        );
        assert_eq!(
            simplify(in_list(col("c1"), vec![lit(1)], true)),
            col("c1").not_eq(lit(1))
        );

        // more complex expressions can be simplified if list contains
        // one element only
        assert_eq!(
            simplify(in_list(col("c1") * lit(10), vec![lit(2)], false)),
            (col("c1") * lit(10)).eq(lit(2))
        );

        // A two-element list becomes an OR of equalities (IN) or an AND of
        // inequalities (NOT IN)
        assert_eq!(
            simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)),
            col("c1").eq(lit(1)).or(col("c1").eq(lit(2)))
        );
        assert_eq!(
            simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)),
            col("c1").not_eq(lit(1)).and(col("c1").not_eq(lit(2)))
        );

        // A list holding exactly one scalar subquery becomes [NOT] IN <subquery>
        let subquery = Arc::new(test_table_scan_with_name("test").unwrap());
        assert_eq!(
            simplify(in_list(
                col("c1"),
                vec![scalar_subquery(Arc::clone(&subquery))],
                false
            )),
            in_subquery(col("c1"), Arc::clone(&subquery))
        );
        assert_eq!(
            simplify(in_list(
                col("c1"),
                vec![scalar_subquery(Arc::clone(&subquery))],
                true
            )),
            not_in_subquery(col("c1"), subquery)
        );

        let subquery1 =
            scalar_subquery(Arc::new(test_table_scan_with_name("test1").unwrap()));
        let subquery2 =
            scalar_subquery(Arc::new(test_table_scan_with_name("test2").unwrap()));

        // c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery1> AND c1 != <subquery2>
        assert_eq!(
            simplify(in_list(
                col("c1"),
                vec![subquery1.clone(), subquery2.clone()],
                true
            )),
            col("c1")
                .not_eq(subquery1.clone())
                .and(col("c1").not_eq(subquery2.clone()))
        );

        // c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery1> OR c1 == <subquery2>
        assert_eq!(
            simplify(in_list(
                col("c1"),
                vec![subquery1.clone(), subquery2.clone()],
                false
            )),
            col("c1").eq(subquery1).or(col("c1").eq(subquery2))
        );

        // 1. c1 IN (1,2,3,4) AND c1 IN (5,6,7,8) -> false
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false),
        );
        assert_eq!(simplify(expr), lit(false));

        // 2. c1 IN (1,2,3,4) AND c1 IN (4,5,6,7) -> c1 = 4
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], false),
        );
        assert_eq!(simplify(expr), col("c1").eq(lit(4)));

        // 3. c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) -> true
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
        );
        assert_eq!(simplify(expr), lit(true));

        // 3.5 c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (4, 5, 6, 7) -> c1 != 4 (4 overlaps)
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
        );
        assert_eq!(simplify(expr), col("c1").not_eq(lit(4)));

        // 4. c1 NOT IN (1,2,3,4) AND c1 NOT IN (4,5,6,7) -> c1 NOT IN (1,2,3,4,5,6,7)
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
        );
        assert_eq!(
            simplify(expr),
            in_list(
                col("c1"),
                vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6), lit(7)],
                true
            )
        );

        // 5. c1 IN (1,2,3,4) OR c1 IN (2,3,4,5) -> c1 IN (1,2,3,4,5)
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).or(
            in_list(col("c1"), vec![lit(2), lit(3), lit(4), lit(5)], false),
        );
        assert_eq!(
            simplify(expr),
            in_list(
                col("c1"),
                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
                false
            )
        );

        // 6. c1 IN (1,2,3) AND c1 NOT IN (1,2,3,4,5) -> false
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3)], false).and(in_list(
            col("c1"),
            vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
            true,
        ));
        assert_eq!(simplify(expr), lit(false));

        // 7. c1 NOT IN (1,2,3,4) AND c1 IN (1,2,3,4,5) -> c1 = 5
        let expr =
            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(in_list(
                col("c1"),
                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
                false,
            ));
        assert_eq!(simplify(expr), col("c1").eq(lit(5)));

        // 8. c1 IN (1,2,3,4) AND c1 NOT IN (5,6,7,8) -> c1 IN (1,2,3,4)
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
        );
        assert_eq!(
            simplify(expr),
            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false)
        );

        // inlist with more than two expressions
        // c1 IN (1,2,3,4,5,6) AND c1 IN (1,3,5,6) AND c1 IN (3,6) -> c1 = 3 OR c1 = 6
        let expr = in_list(
            col("c1"),
            vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6)],
            false,
        )
        .and(in_list(
            col("c1"),
            vec![lit(1), lit(3), lit(5), lit(6)],
            false,
        ))
        .and(in_list(col("c1"), vec![lit(3), lit(6)], false));
        assert_eq!(
            simplify(expr),
            col("c1").eq(lit(3)).or(col("c1").eq(lit(6)))
        );

        // c1 NOT IN (1,2,3,4) AND c1 IN (5,6,7,8) AND c1 NOT IN (3,4,5,6) AND c1 IN (8,9,10) -> c1 = 8
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false)
                .and(in_list(
                    col("c1"),
                    vec![lit(3), lit(4), lit(5), lit(6)],
                    true,
                ))
                .and(in_list(col("c1"), vec![lit(8), lit(9), lit(10)], false)),
        );
        assert_eq!(simplify(expr), col("c1").eq(lit(8)));

        // Contains non-InList expression
        // c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9) -> c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9)
        let expr =
            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(col("c1")
                .not_eq(lit(5))
                .or(in_list(
                    col("c1"),
                    vec![lit(6), lit(7), lit(8), lit(9)],
                    true,
                )));
        // TODO: Further simplify this expression
        // https://github.com/apache/datafusion/issues/8970
        // assert_eq!(simplify(expr.clone()), lit(true));
        assert_eq!(simplify(expr.clone()), expr);
    }
3813
3814    #[test]
3815    fn simplify_null_in_empty_inlist() {
3816        // `NULL::boolean IN ()` == `NULL::boolean IN (SELECT foo FROM empty)` == false
3817        let expr = in_list(lit_bool_null(), vec![], false);
3818        assert_eq!(simplify(expr), lit(false));
3819
3820        // `NULL::boolean NOT IN ()` == `NULL::boolean NOT IN (SELECT foo FROM empty)` == true
3821        let expr = in_list(lit_bool_null(), vec![], true);
3822        assert_eq!(simplify(expr), lit(true));
3823
3824        // `NULL IN ()` == `NULL IN (SELECT foo FROM empty)` == false
3825        let null_null = || Expr::Literal(ScalarValue::Null, None);
3826        let expr = in_list(null_null(), vec![], false);
3827        assert_eq!(simplify(expr), lit(false));
3828
3829        // `NULL NOT IN ()` == `NULL NOT IN (SELECT foo FROM empty)` == true
3830        let expr = in_list(null_null(), vec![], true);
3831        assert_eq!(simplify(expr), lit(true));
3832    }
3833
3834    #[test]
3835    fn just_simplifier_simplify_null_in_empty_inlist() {
3836        let simplify = |expr: Expr| -> Expr {
3837            let schema = expr_test_schema();
3838            let execution_props = ExecutionProps::new();
3839            let info = SimplifyContext::new(&execution_props).with_schema(schema);
3840            let simplifier = &mut Simplifier::new(&info);
3841            expr.rewrite(simplifier)
3842                .expect("Failed to simplify expression")
3843                .data
3844        };
3845
3846        // `NULL::boolean IN ()` == `NULL::boolean IN (SELECT foo FROM empty)` == false
3847        let expr = in_list(lit_bool_null(), vec![], false);
3848        assert_eq!(simplify(expr), lit(false));
3849
3850        // `NULL::boolean NOT IN ()` == `NULL::boolean NOT IN (SELECT foo FROM empty)` == true
3851        let expr = in_list(lit_bool_null(), vec![], true);
3852        assert_eq!(simplify(expr), lit(true));
3853
3854        // `NULL IN ()` == `NULL IN (SELECT foo FROM empty)` == false
3855        let null_null = || Expr::Literal(ScalarValue::Null, None);
3856        let expr = in_list(null_null(), vec![], false);
3857        assert_eq!(simplify(expr), lit(false));
3858
3859        // `NULL NOT IN ()` == `NULL NOT IN (SELECT foo FROM empty)` == true
3860        let expr = in_list(null_null(), vec![], true);
3861        assert_eq!(simplify(expr), lit(true));
3862    }
3863
3864    #[test]
3865    fn simplify_large_or() {
3866        let expr = (0..5)
3867            .map(|i| col("c1").eq(lit(i)))
3868            .fold(lit(false), |acc, e| acc.or(e));
3869        assert_eq!(
3870            simplify(expr),
3871            in_list(col("c1"), (0..5).map(lit).collect(), false),
3872        );
3873    }
3874
3875    #[test]
3876    fn simplify_expr_bool_and() {
3877        // col & true is always col
3878        assert_eq!(simplify(col("c2").and(lit(true))), col("c2"),);
3879        // col & false is always false
3880        assert_eq!(simplify(col("c2").and(lit(false))), lit(false),);
3881
3882        // true && null is always null
3883        assert_eq!(simplify(lit(true).and(lit_bool_null())), lit_bool_null(),);
3884
3885        // null && true is always null
3886        assert_eq!(simplify(lit_bool_null().and(lit(true))), lit_bool_null(),);
3887
3888        // false && null is always false
3889        assert_eq!(simplify(lit(false).and(lit_bool_null())), lit(false),);
3890
3891        // null && false is always false
3892        assert_eq!(simplify(lit_bool_null().and(lit(false))), lit(false),);
3893
3894        // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
3895        // it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
3896        // and the Boolean(NULL) should remain
3897        let expr = col("c1").between(lit(0), lit(10));
3898        let expr = expr.and(lit_bool_null());
3899        let result = simplify(expr);
3900
3901        let expected_expr = and(
3902            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
3903            lit_bool_null(),
3904        );
3905        assert_eq!(expected_expr, result);
3906    }
3907
3908    #[test]
3909    fn simplify_expr_between() {
3910        // c2 between 3 and 4 is c2 >= 3 and c2 <= 4
3911        let expr = col("c2").between(lit(3), lit(4));
3912        assert_eq!(
3913            simplify(expr),
3914            and(col("c2").gt_eq(lit(3)), col("c2").lt_eq(lit(4)))
3915        );
3916
3917        // c2 not between 3 and 4 is c2 < 3 or c2 > 4
3918        let expr = col("c2").not_between(lit(3), lit(4));
3919        assert_eq!(
3920            simplify(expr),
3921            or(col("c2").lt(lit(3)), col("c2").gt(lit(4)))
3922        );
3923    }
3924
3925    #[test]
3926    fn test_like_and_ilike() {
3927        let null = lit(ScalarValue::Utf8(None));
3928
3929        // expr [NOT] [I]LIKE NULL
3930        let expr = col("c1").like(null.clone());
3931        assert_eq!(simplify(expr), lit_bool_null());
3932
3933        let expr = col("c1").not_like(null.clone());
3934        assert_eq!(simplify(expr), lit_bool_null());
3935
3936        let expr = col("c1").ilike(null.clone());
3937        assert_eq!(simplify(expr), lit_bool_null());
3938
3939        let expr = col("c1").not_ilike(null.clone());
3940        assert_eq!(simplify(expr), lit_bool_null());
3941
3942        // expr [NOT] [I]LIKE '%'
3943        let expr = col("c1").like(lit("%"));
3944        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
3945
3946        let expr = col("c1").not_like(lit("%"));
3947        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
3948
3949        let expr = col("c1").ilike(lit("%"));
3950        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
3951
3952        let expr = col("c1").not_ilike(lit("%"));
3953        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
3954
3955        // expr [NOT] [I]LIKE '%%'
3956        let expr = col("c1").like(lit("%%"));
3957        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
3958
3959        let expr = col("c1").not_like(lit("%%"));
3960        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
3961
3962        let expr = col("c1").ilike(lit("%%"));
3963        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
3964
3965        let expr = col("c1").not_ilike(lit("%%"));
3966        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
3967
3968        // not_null_expr [NOT] [I]LIKE '%'
3969        let expr = col("c1_non_null").like(lit("%"));
3970        assert_eq!(simplify(expr), lit(true));
3971
3972        let expr = col("c1_non_null").not_like(lit("%"));
3973        assert_eq!(simplify(expr), lit(false));
3974
3975        let expr = col("c1_non_null").ilike(lit("%"));
3976        assert_eq!(simplify(expr), lit(true));
3977
3978        let expr = col("c1_non_null").not_ilike(lit("%"));
3979        assert_eq!(simplify(expr), lit(false));
3980
3981        // not_null_expr [NOT] [I]LIKE '%%'
3982        let expr = col("c1_non_null").like(lit("%%"));
3983        assert_eq!(simplify(expr), lit(true));
3984
3985        let expr = col("c1_non_null").not_like(lit("%%"));
3986        assert_eq!(simplify(expr), lit(false));
3987
3988        let expr = col("c1_non_null").ilike(lit("%%"));
3989        assert_eq!(simplify(expr), lit(true));
3990
3991        let expr = col("c1_non_null").not_ilike(lit("%%"));
3992        assert_eq!(simplify(expr), lit(false));
3993
3994        // null_constant [NOT] [I]LIKE '%'
3995        let expr = null.clone().like(lit("%"));
3996        assert_eq!(simplify(expr), lit_bool_null());
3997
3998        let expr = null.clone().not_like(lit("%"));
3999        assert_eq!(simplify(expr), lit_bool_null());
4000
4001        let expr = null.clone().ilike(lit("%"));
4002        assert_eq!(simplify(expr), lit_bool_null());
4003
4004        let expr = null.clone().not_ilike(lit("%"));
4005        assert_eq!(simplify(expr), lit_bool_null());
4006
4007        // null_constant [NOT] [I]LIKE '%%'
4008        let expr = null.clone().like(lit("%%"));
4009        assert_eq!(simplify(expr), lit_bool_null());
4010
4011        let expr = null.clone().not_like(lit("%%"));
4012        assert_eq!(simplify(expr), lit_bool_null());
4013
4014        let expr = null.clone().ilike(lit("%%"));
4015        assert_eq!(simplify(expr), lit_bool_null());
4016
4017        let expr = null.clone().not_ilike(lit("%%"));
4018        assert_eq!(simplify(expr), lit_bool_null());
4019
4020        // null_constant [NOT] [I]LIKE 'a%'
4021        let expr = null.clone().like(lit("a%"));
4022        assert_eq!(simplify(expr), lit_bool_null());
4023
4024        let expr = null.clone().not_like(lit("a%"));
4025        assert_eq!(simplify(expr), lit_bool_null());
4026
4027        let expr = null.clone().ilike(lit("a%"));
4028        assert_eq!(simplify(expr), lit_bool_null());
4029
4030        let expr = null.clone().not_ilike(lit("a%"));
4031        assert_eq!(simplify(expr), lit_bool_null());
4032
4033        // expr [NOT] [I]LIKE with pattern without wildcards
4034        let expr = col("c1").like(lit("a"));
4035        assert_eq!(simplify(expr), col("c1").eq(lit("a")));
4036        let expr = col("c1").not_like(lit("a"));
4037        assert_eq!(simplify(expr), col("c1").not_eq(lit("a")));
4038        let expr = col("c1").like(lit("a_"));
4039        assert_eq!(simplify(expr), col("c1").like(lit("a_")));
4040        let expr = col("c1").not_like(lit("a_"));
4041        assert_eq!(simplify(expr), col("c1").not_like(lit("a_")));
4042
4043        let expr = col("c1").ilike(lit("a"));
4044        assert_eq!(simplify(expr), col("c1").ilike(lit("a")));
4045        let expr = col("c1").not_ilike(lit("a"));
4046        assert_eq!(simplify(expr), col("c1").not_ilike(lit("a")));
4047    }
4048
4049    #[test]
4050    fn test_simplify_with_guarantee() {
4051        // (c3 >= 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b")))
4052        let expr_x = col("c3").gt(lit(3_i64));
4053        let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32));
4054        let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true);
4055        let expr = expr_x.clone().and(expr_y.or(expr_z));
4056
4057        // All guaranteed null
4058        let guarantees = vec![
4059            (col("c3"), NullableInterval::from(ScalarValue::Int64(None))),
4060            (col("c4"), NullableInterval::from(ScalarValue::UInt32(None))),
4061            (col("c1"), NullableInterval::from(ScalarValue::Utf8(None))),
4062        ];
4063
4064        let output = simplify_with_guarantee(expr.clone(), guarantees);
4065        assert_eq!(output, lit_bool_null());
4066
4067        // All guaranteed false
4068        let guarantees = vec![
4069            (
4070                col("c3"),
4071                NullableInterval::NotNull {
4072                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
4073                },
4074            ),
4075            (
4076                col("c4"),
4077                NullableInterval::from(ScalarValue::UInt32(Some(9))),
4078            ),
4079            (col("c1"), NullableInterval::from(ScalarValue::from("a"))),
4080        ];
4081        let output = simplify_with_guarantee(expr.clone(), guarantees);
4082        assert_eq!(output, lit(false));
4083
4084        // Guaranteed false or null -> no change.
4085        let guarantees = vec![
4086            (
4087                col("c3"),
4088                NullableInterval::MaybeNull {
4089                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
4090                },
4091            ),
4092            (
4093                col("c4"),
4094                NullableInterval::MaybeNull {
4095                    values: Interval::make(Some(9_u32), Some(9_u32)).unwrap(),
4096                },
4097            ),
4098            (
4099                col("c1"),
4100                NullableInterval::NotNull {
4101                    values: Interval::try_new(
4102                        ScalarValue::from("d"),
4103                        ScalarValue::from("f"),
4104                    )
4105                    .unwrap(),
4106                },
4107            ),
4108        ];
4109        let output = simplify_with_guarantee(expr.clone(), guarantees);
4110        assert_eq!(&output, &expr_x);
4111
4112        // Sufficient true guarantees
4113        let guarantees = vec![
4114            (
4115                col("c3"),
4116                NullableInterval::from(ScalarValue::Int64(Some(9))),
4117            ),
4118            (
4119                col("c4"),
4120                NullableInterval::from(ScalarValue::UInt32(Some(3))),
4121            ),
4122        ];
4123        let output = simplify_with_guarantee(expr.clone(), guarantees);
4124        assert_eq!(output, lit(true));
4125
4126        // Only partially simplify
4127        let guarantees = vec![(
4128            col("c4"),
4129            NullableInterval::from(ScalarValue::UInt32(Some(3))),
4130        )];
4131        let output = simplify_with_guarantee(expr, guarantees);
4132        assert_eq!(&output, &expr_x);
4133    }
4134
4135    #[test]
4136    fn test_expression_partial_simplify_1() {
4137        // (1 + 2) + (4 / 0) -> 3 + (4 / 0)
4138        let expr = (lit(1) + lit(2)) + (lit(4) / lit(0));
4139        let expected = (lit(3)) + (lit(4) / lit(0));
4140
4141        assert_eq!(simplify(expr), expected);
4142    }
4143
4144    #[test]
4145    fn test_expression_partial_simplify_2() {
4146        // (1 > 2) and (4 / 0) -> false
4147        let expr = (lit(1).gt(lit(2))).and(lit(4) / lit(0));
4148        let expected = lit(false);
4149
4150        assert_eq!(simplify(expr), expected);
4151    }
4152
4153    #[test]
4154    fn test_simplify_cycles() {
4155        // TRUE
4156        let expr = lit(true);
4157        let expected = lit(true);
4158        let (expr, num_iter) = simplify_with_cycle_count(expr);
4159        assert_eq!(expr, expected);
4160        assert_eq!(num_iter, 1);
4161
4162        // (true != NULL) OR (5 > 10)
4163        let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
4164        let expected = lit_bool_null();
4165        let (expr, num_iter) = simplify_with_cycle_count(expr);
4166        assert_eq!(expr, expected);
4167        assert_eq!(num_iter, 2);
4168
4169        // NOTE: this currently does not simplify
4170        // (((c4 - 10) + 10) *100) / 100
4171        let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
4172        let expected = expr.clone();
4173        let (expr, num_iter) = simplify_with_cycle_count(expr);
4174        assert_eq!(expr, expected);
4175        assert_eq!(num_iter, 1);
4176
4177        // ((c4<1 or c3<2) and c3_non_null<3) and false
4178        let expr = col("c4")
4179            .lt(lit(1))
4180            .or(col("c3").lt(lit(2)))
4181            .and(col("c3_non_null").lt(lit(3)))
4182            .and(lit(false));
4183        let expected = lit(false);
4184        let (expr, num_iter) = simplify_with_cycle_count(expr);
4185        assert_eq!(expr, expected);
4186        assert_eq!(num_iter, 2);
4187    }
4188
4189    fn boolean_test_schema() -> DFSchemaRef {
4190        Schema::new(vec![
4191            Field::new("A", DataType::Boolean, false),
4192            Field::new("B", DataType::Boolean, false),
4193            Field::new("C", DataType::Boolean, false),
4194            Field::new("D", DataType::Boolean, false),
4195        ])
4196        .to_dfschema_ref()
4197        .unwrap()
4198    }
4199
4200    #[test]
4201    fn simplify_common_factor_conjunction_in_disjunction() {
4202        let props = ExecutionProps::new();
4203        let schema = boolean_test_schema();
4204        let simplifier =
4205            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));
4206
4207        let a = || col("A");
4208        let b = || col("B");
4209        let c = || col("C");
4210        let d = || col("D");
4211
4212        // (A AND B) OR (A AND C) -> A AND (B OR C)
4213        let expr = a().and(b()).or(a().and(c()));
4214        let expected = a().and(b().or(c()));
4215
4216        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4217
4218        // (A AND B) OR (A AND C) OR (A AND D) -> A AND (B OR C OR D)
4219        let expr = a().and(b()).or(a().and(c())).or(a().and(d()));
4220        let expected = a().and(b().or(c()).or(d()));
4221        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4222
4223        // A OR (B AND C AND A) -> A
4224        let expr = a().or(b().and(c().and(a())));
4225        let expected = a();
4226        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4227    }
4228
4229    #[test]
4230    fn test_simplify_udaf() {
4231        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());
4232        let aggregate_function_expr =
4233            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4234                udaf.into(),
4235                vec![],
4236                false,
4237                None,
4238                vec![],
4239                None,
4240            ));
4241
4242        let expected = col("result_column");
4243        assert_eq!(simplify(aggregate_function_expr), expected);
4244
4245        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify());
4246        let aggregate_function_expr =
4247            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4248                udaf.into(),
4249                vec![],
4250                false,
4251                None,
4252                vec![],
4253                None,
4254            ));
4255
4256        let expected = aggregate_function_expr.clone();
4257        assert_eq!(simplify(aggregate_function_expr), expected);
4258    }
4259
    /// A Mock UDAF which defines `simplify` to be used in tests
    /// related to UDAF simplification
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct SimplifyMockUdaf {
        /// When `true`, `simplify()` returns a rewrite closure; when `false`
        /// it returns `None`.
        simplify: bool,
    }
4266
4267    impl SimplifyMockUdaf {
4268        /// make simplify method return new expression
4269        fn new_with_simplify() -> Self {
4270            Self { simplify: true }
4271        }
4272        /// make simplify method return no change
4273        fn new_without_simplify() -> Self {
4274            Self { simplify: false }
4275        }
4276    }
4277
4278    impl AggregateUDFImpl for SimplifyMockUdaf {
4279        fn as_any(&self) -> &dyn std::any::Any {
4280            self
4281        }
4282
4283        fn name(&self) -> &str {
4284            "mock_simplify"
4285        }
4286
4287        fn signature(&self) -> &Signature {
4288            unimplemented!()
4289        }
4290
4291        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4292            unimplemented!("not needed for tests")
4293        }
4294
4295        fn accumulator(
4296            &self,
4297            _acc_args: AccumulatorArgs,
4298        ) -> Result<Box<dyn Accumulator>> {
4299            unimplemented!("not needed for tests")
4300        }
4301
4302        fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
4303            unimplemented!("not needed for testing")
4304        }
4305
4306        fn create_groups_accumulator(
4307            &self,
4308            _args: AccumulatorArgs,
4309        ) -> Result<Box<dyn GroupsAccumulator>> {
4310            unimplemented!("not needed for testing")
4311        }
4312
4313        fn simplify(&self) -> Option<AggregateFunctionSimplification> {
4314            if self.simplify {
4315                Some(Box::new(|_, _| Ok(col("result_column"))))
4316            } else {
4317                None
4318            }
4319        }
4320    }
4321
4322    #[test]
4323    fn test_simplify_udwf() {
4324        let udwf = WindowFunctionDefinition::WindowUDF(
4325            WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(),
4326        );
4327        let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![]));
4328
4329        let expected = col("result_column");
4330        assert_eq!(simplify(window_function_expr), expected);
4331
4332        let udwf = WindowFunctionDefinition::WindowUDF(
4333            WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(),
4334        );
4335        let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![]));
4336
4337        let expected = window_function_expr.clone();
4338        assert_eq!(simplify(window_function_expr), expected);
4339    }
4340
    /// A Mock UDWF which defines `simplify` to be used in tests
    /// related to UDWF simplification
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct SimplifyMockUdwf {
        /// When `true`, `simplify()` returns a rewrite closure; when `false`
        /// it returns `None`.
        simplify: bool,
    }
4347
4348    impl SimplifyMockUdwf {
4349        /// make simplify method return new expression
4350        fn new_with_simplify() -> Self {
4351            Self { simplify: true }
4352        }
4353        /// make simplify method return no change
4354        fn new_without_simplify() -> Self {
4355            Self { simplify: false }
4356        }
4357    }
4358
4359    impl WindowUDFImpl for SimplifyMockUdwf {
4360        fn as_any(&self) -> &dyn std::any::Any {
4361            self
4362        }
4363
4364        fn name(&self) -> &str {
4365            "mock_simplify"
4366        }
4367
4368        fn signature(&self) -> &Signature {
4369            unimplemented!()
4370        }
4371
4372        fn simplify(&self) -> Option<WindowFunctionSimplification> {
4373            if self.simplify {
4374                Some(Box::new(|_, _| Ok(col("result_column"))))
4375            } else {
4376                None
4377            }
4378        }
4379
4380        fn partition_evaluator(
4381            &self,
4382            _partition_evaluator_args: PartitionEvaluatorArgs,
4383        ) -> Result<Box<dyn PartitionEvaluator>> {
4384            unimplemented!("not needed for tests")
4385        }
4386
4387        fn field(&self, _field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
4388            unimplemented!("not needed for tests")
4389        }
4390    }
    /// Test-only scalar UDF holding the signature it reports through
    /// `ScalarUDFImpl::signature`.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct VolatileUdf {
        /// Signature returned by `signature()`.
        signature: Signature,
    }
4395
4396    impl VolatileUdf {
4397        pub fn new() -> Self {
4398            Self {
4399                signature: Signature::exact(vec![], Volatility::Volatile),
4400            }
4401        }
4402    }
    /// `ScalarUDFImpl` for the test-only `VolatileUdf`; invoking it panics.
    impl ScalarUDFImpl for VolatileUdf {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        /// Name under which the function is reported.
        fn name(&self) -> &str {
            "VolatileUdf"
        }

        /// Returns the signature stored in the struct.
        fn signature(&self) -> &Signature {
            &self.signature
        }

        /// Always reports `Int16`, regardless of the argument types.
        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int16)
        }

        /// Not implemented: panics if called.
        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            panic!("dummy - not implemented")
        }
    }
4424
4425    #[test]
4426    fn test_optimize_volatile_conditions() {
4427        let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new()));
4428        let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
4429        {
4430            let expr = rand
4431                .clone()
4432                .eq(lit(0))
4433                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4434
4435            assert_eq!(simplify(expr.clone()), expr);
4436        }
4437
4438        {
4439            let expr = col("column1")
4440                .eq(lit(2))
4441                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4442
4443            assert_eq!(simplify(expr), col("column1").eq(lit(2)));
4444        }
4445
4446        {
4447            let expr = (col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))).or(col(
4448                "column1",
4449            )
4450            .eq(lit(2))
4451            .and(rand.clone().eq(lit(0))));
4452
4453            assert_eq!(
4454                simplify(expr),
4455                col("column1")
4456                    .eq(lit(2))
4457                    .and((rand.clone().eq(lit(0))).or(rand.clone().eq(lit(0))))
4458            );
4459        }
4460    }
4461
4462    #[test]
4463    fn simplify_fixed_size_binary_eq_lit() {
4464        let bytes = [1u8, 2, 3].as_slice();
4465
4466        // The expression starts simple.
4467        let expr = col("c5").eq(lit(bytes));
4468
4469        // The type coercer introduces a cast.
4470        let coerced = coerce(expr.clone());
4471        let schema = expr_test_schema();
4472        assert_eq!(
4473            coerced,
4474            col("c5")
4475                .cast_to(&DataType::Binary, schema.as_ref())
4476                .unwrap()
4477                .eq(lit(bytes))
4478        );
4479
4480        // The simplifier removes the cast.
4481        assert_eq!(
4482            simplify(coerced),
4483            col("c5").eq(Expr::Literal(
4484                ScalarValue::FixedSizeBinary(3, Some(bytes.to_vec()),),
4485                None
4486            ))
4487        );
4488    }
4489
4490    fn if_not_null(expr: Expr, then: bool) -> Expr {
4491        Expr::Case(Case {
4492            expr: Some(expr.is_not_null().into()),
4493            when_then_expr: vec![(lit(true).into(), lit(then).into())],
4494            else_expr: None,
4495        })
4496    }
4497}