datafusion_optimizer/simplify_expressions/
expr_simplifier.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression simplification API
19
20use arrow::{
21    array::{new_null_array, AsArray},
22    datatypes::{DataType, Field, Schema},
23    record_batch::RecordBatch,
24};
25use std::borrow::Cow;
26use std::collections::HashSet;
27use std::ops::Not;
28use std::sync::Arc;
29
30use datafusion_common::{
31    cast::{as_large_list_array, as_list_array},
32    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
33};
34use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
35use datafusion_expr::{
36    and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like,
37    Operator, Volatility,
38};
39use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
40use datafusion_expr::{
41    expr::{InList, InSubquery},
42    utils::{iter_conjunction, iter_conjunction_owned},
43};
44use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast};
45use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
46
47use super::inlist_simplifier::ShortenInListSimplifier;
48use super::utils::*;
49use crate::analyzer::type_coercion::TypeCoercionRewriter;
50use crate::simplify_expressions::guarantees::GuaranteeRewriter;
51use crate::simplify_expressions::regex::simplify_regex_expr;
52use crate::simplify_expressions::unwrap_cast::{
53    is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
54    is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
55    unwrap_cast_in_comparison_for_binary,
56};
57use crate::simplify_expressions::SimplifyInfo;
58use datafusion_expr::expr::FieldMetadata;
59use datafusion_expr_common::casts::try_cast_literal_to_type;
60use indexmap::IndexSet;
61use regex::Regex;
62
63/// This structure handles API for expression simplification
64///
65/// Provides simplification information based on DFSchema and
66/// [`ExecutionProps`]. This is the default implementation used by DataFusion
67///
68/// For example:
69/// ```
70/// use arrow::datatypes::{Schema, Field, DataType};
71/// use datafusion_expr::{col, lit};
72/// use datafusion_common::{DataFusionError, ToDFSchema};
73/// use datafusion_expr::execution_props::ExecutionProps;
74/// use datafusion_expr::simplify::SimplifyContext;
75/// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
76///
77/// // Create the schema
78/// let schema = Schema::new(vec![
79///     Field::new("i", DataType::Int64, false),
80///   ])
81///   .to_dfschema_ref().unwrap();
82///
83/// // Create the simplifier
84/// let props = ExecutionProps::new();
85/// let context = SimplifyContext::new(&props)
86///    .with_schema(schema);
87/// let simplifier = ExprSimplifier::new(context);
88///
89/// // Use the simplifier
90///
91/// // b < 2 or (1 > 3)
92/// let expr = col("b").lt(lit(2)).or(lit(1).gt(lit(3)));
93///
94/// // b < 2
95/// let simplified = simplifier.simplify(expr).unwrap();
96/// assert_eq!(simplified, col("b").lt(lit(2)));
97/// ```
pub struct ExprSimplifier<S> {
    /// Provider of the information the simplifier consults while rewriting
    /// (for example the execution properties used for constant evaluation).
    info: S,
    /// Guarantees about the values of columns. This is provided by the user
    /// in [ExprSimplifier::with_guarantees()].
    guarantees: Vec<(Expr, NullableInterval)>,
    /// Should expressions be canonicalized before simplification? Defaults to
    /// true
    canonicalize: bool,
    /// Maximum number of simplifier cycles. Defaults to
    /// `DEFAULT_MAX_SIMPLIFIER_CYCLES` and can be changed with
    /// `ExprSimplifier::with_max_cycles`.
    max_simplifier_cycles: u32,
}
109
/// Threshold used when deciding whether to inline an `IN` list.
///
/// NOTE(review): this constant is not referenced in this part of the file;
/// confirm its exact meaning against its use sites.
pub const THRESHOLD_INLINE_INLIST: usize = 3;
/// Default upper bound on the number of simplification cycles; this is the
/// value installed by `ExprSimplifier::new`.
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3;
112
113impl<S: SimplifyInfo> ExprSimplifier<S> {
114    /// Create a new `ExprSimplifier` with the given `info` such as an
115    /// instance of [`SimplifyContext`]. See
116    /// [`simplify`](Self::simplify) for an example.
117    ///
118    /// [`SimplifyContext`]: datafusion_expr::simplify::SimplifyContext
119    pub fn new(info: S) -> Self {
120        Self {
121            info,
122            guarantees: vec![],
123            canonicalize: true,
124            max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES,
125        }
126    }
127
128    /// Simplifies this [`Expr`] as much as possible, evaluating
129    /// constants and applying algebraic simplifications.
130    ///
131    /// The types of the expression must match what operators expect,
132    /// or else an error may occur trying to evaluate. See
133    /// [`coerce`](Self::coerce) for a function to help.
134    ///
135    /// # Example:
136    ///
137    /// `b > 2 AND b > 2`
138    ///
139    /// can be written to
140    ///
141    /// `b > 2`
142    ///
143    /// ```
144    /// use arrow::datatypes::DataType;
145    /// use datafusion_expr::{col, lit, Expr};
146    /// use datafusion_common::Result;
147    /// use datafusion_expr::execution_props::ExecutionProps;
148    /// use datafusion_expr::simplify::SimplifyContext;
149    /// use datafusion_expr::simplify::SimplifyInfo;
150    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
151    /// use datafusion_common::DFSchema;
152    /// use std::sync::Arc;
153    ///
154    /// /// Simple implementation that provides `Simplifier` the information it needs
155    /// /// See SimplifyContext for a structure that does this.
156    /// #[derive(Default)]
157    /// struct Info {
158    ///   execution_props: ExecutionProps,
159    /// };
160    ///
161    /// impl SimplifyInfo for Info {
162    ///   fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
163    ///     Ok(false)
164    ///   }
165    ///   fn nullable(&self, expr: &Expr) -> Result<bool> {
166    ///     Ok(true)
167    ///   }
168    ///   fn execution_props(&self) -> &ExecutionProps {
169    ///     &self.execution_props
170    ///   }
171    ///   fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
172    ///     Ok(DataType::Int32)
173    ///   }
174    /// }
175    ///
176    /// // Create the simplifier
177    /// let simplifier = ExprSimplifier::new(Info::default());
178    ///
179    /// // b < 2
    /// let b_lt_2 = col("b").lt(lit(2));
181    ///
182    /// // (b < 2) OR (b < 2)
183    /// let expr = b_lt_2.clone().or(b_lt_2.clone());
184    ///
185    /// // (b < 2) OR (b < 2) --> (b < 2)
186    /// let expr = simplifier.simplify(expr).unwrap();
187    /// assert_eq!(expr, b_lt_2);
188    /// ```
189    pub fn simplify(&self, expr: Expr) -> Result<Expr> {
190        Ok(self.simplify_with_cycle_count_transformed(expr)?.0.data)
191    }
192
193    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
194    /// constants and applying algebraic simplifications. Additionally returns a `u32`
195    /// representing the number of simplification cycles performed, which can be useful for testing
196    /// optimizations.
197    ///
198    /// See [Self::simplify] for details and usage examples.
199    ///
200    #[deprecated(
201        since = "48.0.0",
202        note = "Use `simplify_with_cycle_count_transformed` instead"
203    )]
204    #[allow(unused_mut)]
205    pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
206        let (transformed, cycle_count) =
207            self.simplify_with_cycle_count_transformed(expr)?;
208        Ok((transformed.data, cycle_count))
209    }
210
211    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
212    /// constants and applying algebraic simplifications. Additionally returns a `u32`
213    /// representing the number of simplification cycles performed, which can be useful for testing
214    /// optimizations.
215    ///
216    /// # Returns
217    ///
218    /// A tuple containing:
219    /// - The simplified expression wrapped in a `Transformed<Expr>` indicating if changes were made
220    /// - The number of simplification cycles that were performed
221    ///
222    /// See [Self::simplify] for details and usage examples.
223    ///
224    pub fn simplify_with_cycle_count_transformed(
225        &self,
226        mut expr: Expr,
227    ) -> Result<(Transformed<Expr>, u32)> {
228        let mut simplifier = Simplifier::new(&self.info);
229        let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
230        let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
231        let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);
232
233        if self.canonicalize {
234            expr = expr.rewrite(&mut Canonicalizer::new()).data()?
235        }
236
237        // Evaluating constants can enable new simplifications and
238        // simplifications can enable new constant evaluation
239        // see `Self::with_max_cycles`
240        let mut num_cycles = 0;
241        let mut has_transformed = false;
242        loop {
243            let Transformed {
244                data, transformed, ..
245            } = expr
246                .rewrite(&mut const_evaluator)?
247                .transform_data(|expr| expr.rewrite(&mut simplifier))?
248                .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
249            expr = data;
250            num_cycles += 1;
251            // Track if any transformation occurred
252            has_transformed = has_transformed || transformed;
253            if !transformed || num_cycles >= self.max_simplifier_cycles {
254                break;
255            }
256        }
257        // shorten inlist should be started after other inlist rules are applied
258        expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
259        Ok((
260            Transformed::new_transformed(expr, has_transformed),
261            num_cycles,
262        ))
263    }
264
265    /// Apply type coercion to an [`Expr`] so that it can be
266    /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr).
267    ///
268    /// See the [type coercion module](datafusion_expr::type_coercion)
269    /// documentation for more details on type coercion
270    pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> {
271        let mut expr_rewrite = TypeCoercionRewriter { schema };
272        expr.rewrite(&mut expr_rewrite).data()
273    }
274
275    /// Input guarantees about the values of columns.
276    ///
277    /// The guarantees can simplify expressions. For example, if a column `x` is
278    /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the
279    /// literal `true`.
280    ///
281    /// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`,
282    /// where the [Expr] is a column reference and the [NullableInterval]
283    /// is an interval representing the known possible values of that column.
284    ///
285    /// ```rust
286    /// use arrow::datatypes::{DataType, Field, Schema};
287    /// use datafusion_expr::{col, lit, Expr};
288    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
289    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
290    /// use datafusion_expr::execution_props::ExecutionProps;
291    /// use datafusion_expr::simplify::SimplifyContext;
292    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
293    ///
294    /// let schema = Schema::new(vec![
295    ///   Field::new("x", DataType::Int64, false),
296    ///   Field::new("y", DataType::UInt32, false),
297    ///   Field::new("z", DataType::Int64, false),
298    ///   ])
299    ///   .to_dfschema_ref().unwrap();
300    ///
301    /// // Create the simplifier
302    /// let props = ExecutionProps::new();
303    /// let context = SimplifyContext::new(&props)
304    ///    .with_schema(schema);
305    ///
306    /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5)
307    /// let expr_x = col("x").gt_eq(lit(3_i64));
308    /// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32));
309    /// let expr_z = col("z").gt(lit(5_i64));
310    /// let expr = expr_x.and(expr_y).and(expr_z.clone());
311    ///
312    /// let guarantees = vec![
313    ///    // x ∈ [3, 5]
314    ///    (
315    ///        col("x"),
316    ///        NullableInterval::NotNull {
317    ///            values: Interval::make(Some(3_i64), Some(5_i64)).unwrap()
318    ///        }
319    ///    ),
320    ///    // y = 3
321    ///    (col("y"), NullableInterval::from(ScalarValue::UInt32(Some(3)))),
322    /// ];
323    /// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees);
324    /// let output = simplifier.simplify(expr).unwrap();
325    /// // Expression becomes: true AND true AND (z > 5), which simplifies to
326    /// // z > 5.
327    /// assert_eq!(output, expr_z);
328    /// ```
329    pub fn with_guarantees(mut self, guarantees: Vec<(Expr, NullableInterval)>) -> Self {
330        self.guarantees = guarantees;
331        self
332    }
333
334    /// Should `Canonicalizer` be applied before simplification?
335    ///
336    /// If true (the default), the expression will be rewritten to canonical
337    /// form before simplification. This is useful to ensure that the simplifier
338    /// can apply all possible simplifications.
339    ///
340    /// Some expressions, such as those in some Joins, can not be canonicalized
341    /// without changing their meaning. In these cases, canonicalization should
342    /// be disabled.
343    ///
344    /// ```rust
345    /// use arrow::datatypes::{DataType, Field, Schema};
346    /// use datafusion_expr::{col, lit, Expr};
347    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
348    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
349    /// use datafusion_expr::execution_props::ExecutionProps;
350    /// use datafusion_expr::simplify::SimplifyContext;
351    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
352    ///
353    /// let schema = Schema::new(vec![
354    ///   Field::new("a", DataType::Int64, false),
355    ///   Field::new("b", DataType::Int64, false),
356    ///   Field::new("c", DataType::Int64, false),
357    ///   ])
358    ///   .to_dfschema_ref().unwrap();
359    ///
360    /// // Create the simplifier
361    /// let props = ExecutionProps::new();
362    /// let context = SimplifyContext::new(&props)
363    ///    .with_schema(schema);
364    /// let simplifier = ExprSimplifier::new(context);
365    ///
366    /// // Expression: a = c AND 1 = b
367    /// let expr = col("a").eq(col("c")).and(lit(1).eq(col("b")));
368    ///
369    /// // With canonicalization, the expression is rewritten to canonical form
370    /// // (though it is no simpler in this case):
371    /// let canonical = simplifier.simplify(expr.clone()).unwrap();
372    /// // Expression has been rewritten to: (c = a AND b = 1)
373    /// assert_eq!(canonical, col("c").eq(col("a")).and(col("b").eq(lit(1))));
374    ///
375    /// // If canonicalization is disabled, the expression is not changed
376    /// let non_canonicalized = simplifier
377    ///   .with_canonicalize(false)
378    ///   .simplify(expr.clone())
379    ///   .unwrap();
380    ///
381    /// assert_eq!(non_canonicalized, expr);
382    /// ```
383    pub fn with_canonicalize(mut self, canonicalize: bool) -> Self {
384        self.canonicalize = canonicalize;
385        self
386    }
387
388    /// Specifies the maximum number of simplification cycles to run.
389    ///
390    /// The simplifier can perform multiple passes of simplification. This is
391    /// because the output of one simplification step can allow more optimizations
392    /// in another simplification step. For example, constant evaluation can allow more
393    /// expression simplifications, and expression simplifications can allow more constant
394    /// evaluations.
395    ///
396    /// This method specifies the maximum number of allowed iteration cycles before the simplifier
397    /// returns an [Expr] output. However, it does not always perform the maximum number of cycles.
398    /// The simplifier will attempt to detect when an [Expr] is unchanged by all the simplification
399    /// passes, and return early. This avoids wasting time on unnecessary [Expr] tree traversals.
400    ///
401    /// If no maximum is specified, the value of [DEFAULT_MAX_SIMPLIFIER_CYCLES] is used
402    /// instead.
403    ///
404    /// ```rust
405    /// use arrow::datatypes::{DataType, Field, Schema};
406    /// use datafusion_expr::{col, lit, Expr};
407    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
408    /// use datafusion_expr::execution_props::ExecutionProps;
409    /// use datafusion_expr::simplify::SimplifyContext;
410    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
411    ///
412    /// let schema = Schema::new(vec![
413    ///   Field::new("a", DataType::Int64, false),
414    ///   ])
415    ///   .to_dfschema_ref().unwrap();
416    ///
417    /// // Create the simplifier
418    /// let props = ExecutionProps::new();
419    /// let context = SimplifyContext::new(&props)
420    ///    .with_schema(schema);
421    /// let simplifier = ExprSimplifier::new(context);
422    ///
423    /// // Expression: a IS NOT NULL
424    /// let expr = col("a").is_not_null();
425    ///
426    /// // When using default maximum cycles, 2 cycles will be performed.
427    /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count_transformed(expr.clone()).unwrap();
428    /// assert_eq!(simplified_expr.data, lit(true));
429    /// // 2 cycles were executed, but only 1 was needed
430    /// assert_eq!(count, 2);
431    ///
432    /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
433    /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count_transformed(expr.clone()).unwrap();
    /// // Expression has been rewritten to: true
435    /// assert_eq!(simplified_expr.data, lit(true));
436    /// // Only 1 cycle was executed
437    /// assert_eq!(count, 1);
438    ///
439    /// ```
440    pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
441        self.max_simplifier_cycles = max_simplifier_cycles;
442        self
443    }
444}
445
446/// Canonicalize any BinaryExprs that are not in canonical form
447///
448/// `<literal> <op> <col>` is rewritten to `<col> <op> <literal>`
449///
450/// `<col1> <op> <col2>` is rewritten so that the name of `col1` sorts higher
451/// than `col2` (`a > b` would be canonicalized to `b < a`)
struct Canonicalizer {}

impl Canonicalizer {
    /// Creates a canonicalizer; the rewriter carries no state.
    fn new() -> Self {
        Canonicalizer {}
    }
}
459
460impl TreeNodeRewriter for Canonicalizer {
461    type Node = Expr;
462
463    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
464        let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
465            return Ok(Transformed::no(expr));
466        };
467        match (left.as_ref(), right.as_ref(), op.swap()) {
468            // <col1> <op> <col2>
469            (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op))
470                if right_col > left_col =>
471            {
472                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
473                    left: right,
474                    op: swapped_op,
475                    right: left,
476                })))
477            }
478            // <literal> <op> <col>
479            (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => {
480                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
481                    left: right,
482                    op: swapped_op,
483                    right: left,
484                })))
485            }
486            _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr {
487                left,
488                op,
489                right,
490            }))),
491        }
492    }
493}
494
495#[allow(rustdoc::private_intra_doc_links)]
496/// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time.
497///
498/// Note it does not handle algebraic rewrites such as `(a or false)`
499/// --> `a`, which is handled by [`Simplifier`]
struct ConstEvaluator<'a> {
    /// `can_evaluate` is used during the depth-first-search of the
    /// `Expr` tree to track if any siblings (or their descendants) were
    /// non evaluatable (e.g. had a column reference or volatile
    /// function)
    ///
    /// Specifically, `can_evaluate[N]` represents the state of
    /// traversal when we are N levels deep in the tree, one entry for
    /// this Expr and each of its parents.
    ///
    /// After visiting all siblings if the top (last) entry of
    /// `can_evaluate` is true, that means there were no non evaluatable
    /// siblings (or their descendants) so this `Expr` can be evaluated
    can_evaluate: Vec<bool>,

    /// Execution properties handed to `create_physical_expr` when a constant
    /// subtree is evaluated.
    execution_props: &'a ExecutionProps,
    /// Schema (a single dummy `Null` column) used to plan the physical
    /// expression for a constant subtree.
    input_schema: DFSchema,
    /// One-row batch matching `input_schema`; evaluating against it produces
    /// a single output value.
    input_batch: RecordBatch,
}
519
#[allow(dead_code)]
/// The simplify result of ConstEvaluator
enum ConstSimplifyResult {
    /// The expression was evaluated; holds the resulting scalar value and
    /// the (optional) field metadata of the evaluated expression
    Simplified(ScalarValue, Option<FieldMetadata>),
    /// The expression was already a literal; holds its original value and
    /// metadata unchanged
    NotSimplified(ScalarValue, Option<FieldMetadata>),
    /// Evaluation encountered an error; holds the error and the original
    /// expression
    SimplifyRuntimeError(DataFusionError, Expr),
}
530
531impl TreeNodeRewriter for ConstEvaluator<'_> {
532    type Node = Expr;
533
534    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
535        // Default to being able to evaluate this node
536        self.can_evaluate.push(true);
537
538        // if this expr is not ok to evaluate, mark entire parent
539        // stack as not ok (as all parents have at least one child or
540        // descendant that can not be evaluated
541
542        if !Self::can_evaluate(&expr) {
543            // walk back up stack, marking first parent that is not mutable
544            let parent_iter = self.can_evaluate.iter_mut().rev();
545            for p in parent_iter {
546                if !*p {
547                    // optimization: if we find an element on the
548                    // stack already marked, know all elements above are also marked
549                    break;
550                }
551                *p = false;
552            }
553        }
554
555        // NB: do not short circuit recursion even if we find a non
556        // evaluatable node (so we can fold other children, args to
557        // functions, etc.)
558        Ok(Transformed::no(expr))
559    }
560
561    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
562        match self.can_evaluate.pop() {
563            // Certain expressions such as `CASE` and `COALESCE` are short-circuiting
564            // and may not evaluate all their sub expressions. Thus, if
565            // any error is countered during simplification, return the original
566            // so that normal evaluation can occur
567            Some(true) => match self.evaluate_to_scalar(expr) {
568                ConstSimplifyResult::Simplified(s, m) => {
569                    Ok(Transformed::yes(Expr::Literal(s, m)))
570                }
571                ConstSimplifyResult::NotSimplified(s, m) => {
572                    Ok(Transformed::no(Expr::Literal(s, m)))
573                }
574                ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
575                    Ok(Transformed::yes(expr))
576                }
577            },
578            Some(false) => Ok(Transformed::no(expr)),
579            _ => internal_err!("Failed to pop can_evaluate"),
580        }
581    }
582}
583
584impl<'a> ConstEvaluator<'a> {
    /// Create a new `ConstEvaluator`. Session constants (such as
    /// the time for `now()`) are taken from the passed
    /// `execution_props`.
588    pub fn try_new(execution_props: &'a ExecutionProps) -> Result<Self> {
589        // The dummy column name is unused and doesn't matter as only
590        // expressions without column references can be evaluated
591        static DUMMY_COL_NAME: &str = ".";
592        let schema = Arc::new(Schema::new(vec![Field::new(
593            DUMMY_COL_NAME,
594            DataType::Null,
595            true,
596        )]));
597        let input_schema = DFSchema::try_from(Arc::clone(&schema))?;
598        // Need a single "input" row to produce a single output row
599        let col = new_null_array(&DataType::Null, 1);
600        let input_batch = RecordBatch::try_new(schema, vec![col])?;
601
602        Ok(Self {
603            can_evaluate: vec![],
604            execution_props,
605            input_schema,
606            input_batch,
607        })
608    }
609
610    /// Can a function of the specified volatility be evaluated?
611    fn volatility_ok(volatility: Volatility) -> bool {
612        match volatility {
613            Volatility::Immutable => true,
614            // Values for functions such as now() are taken from ExecutionProps
615            Volatility::Stable => true,
616            Volatility::Volatile => false,
617        }
618    }
619
620    /// Can the expression be evaluated at plan time, (assuming all of
621    /// its children can also be evaluated)?
622    fn can_evaluate(expr: &Expr) -> bool {
623        // check for reasons we can't evaluate this node
624        //
625        // NOTE all expr types are listed here so when new ones are
626        // added they can be checked for their ability to be evaluated
627        // at plan time
628        match expr {
629            // TODO: remove the next line after `Expr::Wildcard` is removed
630            #[expect(deprecated)]
631            Expr::AggregateFunction { .. }
632            | Expr::ScalarVariable(_, _)
633            | Expr::Column(_)
634            | Expr::OuterReferenceColumn(_, _)
635            | Expr::Exists { .. }
636            | Expr::InSubquery(_)
637            | Expr::ScalarSubquery(_)
638            | Expr::WindowFunction { .. }
639            | Expr::GroupingSet(_)
640            | Expr::Wildcard { .. }
641            | Expr::Placeholder(_) => false,
642            Expr::ScalarFunction(ScalarFunction { func, .. }) => {
643                Self::volatility_ok(func.signature().volatility)
644            }
645            Expr::Literal(_, _)
646            | Expr::Alias(..)
647            | Expr::Unnest(_)
648            | Expr::BinaryExpr { .. }
649            | Expr::Not(_)
650            | Expr::IsNotNull(_)
651            | Expr::IsNull(_)
652            | Expr::IsTrue(_)
653            | Expr::IsFalse(_)
654            | Expr::IsUnknown(_)
655            | Expr::IsNotTrue(_)
656            | Expr::IsNotFalse(_)
657            | Expr::IsNotUnknown(_)
658            | Expr::Negative(_)
659            | Expr::Between { .. }
660            | Expr::Like { .. }
661            | Expr::SimilarTo { .. }
662            | Expr::Case(_)
663            | Expr::Cast { .. }
664            | Expr::TryCast { .. }
665            | Expr::InList { .. } => true,
666        }
667    }
668
669    /// Internal helper to evaluates an Expr
670    pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult {
671        if let Expr::Literal(s, m) = expr {
672            return ConstSimplifyResult::NotSimplified(s, m);
673        }
674
675        let phys_expr =
676            match create_physical_expr(&expr, &self.input_schema, self.execution_props) {
677                Ok(e) => e,
678                Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
679            };
680        let metadata = phys_expr
681            .return_field(self.input_batch.schema_ref())
682            .ok()
683            .and_then(|f| {
684                let m = f.metadata();
685                match m.is_empty() {
686                    true => None,
687                    false => Some(FieldMetadata::from(m)),
688                }
689            });
690        let col_val = match phys_expr.evaluate(&self.input_batch) {
691            Ok(v) => v,
692            Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
693        };
694        match col_val {
695            ColumnarValue::Array(a) => {
696                if a.len() != 1 {
697                    ConstSimplifyResult::SimplifyRuntimeError(
698                        DataFusionError::Execution(format!("Could not evaluate the expression, found a result of length {}", a.len())),
699                        expr,
700                    )
701                } else if as_list_array(&a).is_ok() {
702                    ConstSimplifyResult::Simplified(
703                        ScalarValue::List(a.as_list::<i32>().to_owned().into()),
704                        metadata,
705                    )
706                } else if as_large_list_array(&a).is_ok() {
707                    ConstSimplifyResult::Simplified(
708                        ScalarValue::LargeList(a.as_list::<i64>().to_owned().into()),
709                        metadata,
710                    )
711                } else {
712                    // Non-ListArray
713                    match ScalarValue::try_from_array(&a, 0) {
714                        Ok(s) => {
715                            // TODO: support the optimization for `Map` type after support impl hash for it
716                            if matches!(&s, ScalarValue::Map(_)) {
717                                ConstSimplifyResult::SimplifyRuntimeError(
718                                    DataFusionError::NotImplemented("Const evaluate for Map type is still not supported".to_string()),
719                                    expr,
720                                )
721                            } else {
722                                ConstSimplifyResult::Simplified(s, metadata)
723                            }
724                        }
725                        Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr),
726                    }
727                }
728            }
729            ColumnarValue::Scalar(s) => {
730                // TODO: support the optimization for `Map` type after support impl hash for it
731                if matches!(&s, ScalarValue::Map(_)) {
732                    ConstSimplifyResult::SimplifyRuntimeError(
733                        DataFusionError::NotImplemented(
734                            "Const evaluate for Map type is still not supported"
735                                .to_string(),
736                        ),
737                        expr,
738                    )
739                } else {
740                    ConstSimplifyResult::Simplified(s, metadata)
741                }
742            }
743        }
744    }
745}
746
/// Rewriter that applies algebraic simplification rules to [`Expr`] trees.
///
/// A few of the rewrites it performs:
/// * boolean `expr = true` / `expr != false` become `expr`
/// * boolean `expr = false` / `expr != true` become `!expr`
/// * `true = true` and `false = false` become `true`
/// * `false = true` and `true = false` become `false`
/// * `!!expr` becomes `expr`
/// * `expr = null` and `expr != null` become `null`
struct Simplifier<'a, S> {
    /// Borrowed context handed to [`Simplifier::new`]; the rewrite rules query
    /// it for expression types and nullability.
    info: &'a S,
}

impl<'a, S> Simplifier<'a, S> {
    /// Builds a simplifier that borrows `info` for the lifetime `'a`.
    pub fn new(info: &'a S) -> Self {
        Simplifier { info }
    }
}
765
766impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
767    type Node = Expr;
768
769    /// rewrite the expression simplifying any constant expressions
770    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
771        use datafusion_expr::Operator::{
772            And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor,
773            Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch,
774            RegexNotIMatch, RegexNotMatch,
775        };
776
777        let info = self.info;
778        Ok(match expr {
779            // `value op NULL` -> `NULL`
780            // `NULL op value` -> `NULL`
            // except for a few operators that can return a non-null value even when one of the operands is NULL
782            ref expr @ Expr::BinaryExpr(BinaryExpr {
783                ref left,
784                ref op,
785                ref right,
786            }) if op.returns_null_on_null()
787                && (is_null(left.as_ref()) || is_null(right.as_ref())) =>
788            {
789                Transformed::yes(Expr::Literal(
790                    ScalarValue::try_new_null(&info.get_data_type(expr)?)?,
791                    None,
792                ))
793            }
794
795            // `NULL {AND, OR} NULL` -> `NULL`
796            Expr::BinaryExpr(BinaryExpr {
797                left,
798                op: And | Or,
799                right,
800            }) if is_null(&left) && is_null(&right) => Transformed::yes(lit_bool_null()),
801
802            //
803            // Rules for Eq
804            //
805
806            // true = A  --> A
807            // false = A --> !A
808            // null = A --> null
809            Expr::BinaryExpr(BinaryExpr {
810                left,
811                op: Eq,
812                right,
813            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
814                Transformed::yes(match as_bool_lit(&left)? {
815                    Some(true) => *right,
816                    Some(false) => Expr::Not(right),
817                    None => lit_bool_null(),
818                })
819            }
820            // A = true  --> A
821            // A = false --> !A
822            // A = null --> null
823            Expr::BinaryExpr(BinaryExpr {
824                left,
825                op: Eq,
826                right,
827            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
828                Transformed::yes(match as_bool_lit(&right)? {
829                    Some(true) => *left,
830                    Some(false) => Expr::Not(left),
831                    None => lit_bool_null(),
832                })
833            }
834            // According to SQL's null semantics, NULL = NULL evaluates to NULL
835            // Both sides are the same expression (A = A) and A is non-volatile expression
836            // A = A --> A IS NOT NULL OR NULL
837            // A = A --> true (if A not nullable)
838            Expr::BinaryExpr(BinaryExpr {
839                left,
840                op: Eq,
841                right,
842            }) if (left == right) & !left.is_volatile() => {
843                Transformed::yes(match !info.nullable(&left)? {
844                    true => lit(true),
845                    false => Expr::BinaryExpr(BinaryExpr {
846                        left: Box::new(Expr::IsNotNull(left)),
847                        op: Or,
848                        right: Box::new(lit_bool_null()),
849                    }),
850                })
851            }
852
853            // Rules for NotEq
854            //
855
856            // true != A  --> !A
857            // false != A --> A
858            // null != A --> null
859            Expr::BinaryExpr(BinaryExpr {
860                left,
861                op: NotEq,
862                right,
863            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
864                Transformed::yes(match as_bool_lit(&left)? {
865                    Some(true) => Expr::Not(right),
866                    Some(false) => *right,
867                    None => lit_bool_null(),
868                })
869            }
870            // A != true  --> !A
871            // A != false --> A
872            // A != null --> null,
873            Expr::BinaryExpr(BinaryExpr {
874                left,
875                op: NotEq,
876                right,
877            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
878                Transformed::yes(match as_bool_lit(&right)? {
879                    Some(true) => Expr::Not(left),
880                    Some(false) => *left,
881                    None => lit_bool_null(),
882                })
883            }
884
885            //
886            // Rules for OR
887            //
888
889            // true OR A --> true (even if A is null)
890            Expr::BinaryExpr(BinaryExpr {
891                left,
892                op: Or,
893                right: _,
894            }) if is_true(&left) => Transformed::yes(*left),
895            // false OR A --> A
896            Expr::BinaryExpr(BinaryExpr {
897                left,
898                op: Or,
899                right,
900            }) if is_false(&left) => Transformed::yes(*right),
901            // A OR true --> true (even if A is null)
902            Expr::BinaryExpr(BinaryExpr {
903                left: _,
904                op: Or,
905                right,
906            }) if is_true(&right) => Transformed::yes(*right),
907            // A OR false --> A
908            Expr::BinaryExpr(BinaryExpr {
909                left,
910                op: Or,
911                right,
912            }) if is_false(&right) => Transformed::yes(*left),
913            // A OR !A ---> true (if A not nullable)
914            Expr::BinaryExpr(BinaryExpr {
915                left,
916                op: Or,
917                right,
918            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
919                Transformed::yes(lit(true))
920            }
921            // !A OR A ---> true (if A not nullable)
922            Expr::BinaryExpr(BinaryExpr {
923                left,
924                op: Or,
925                right,
926            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
927                Transformed::yes(lit(true))
928            }
929            // (..A..) OR A --> (..A..)
930            Expr::BinaryExpr(BinaryExpr {
931                left,
932                op: Or,
933                right,
934            }) if expr_contains(&left, &right, Or) => Transformed::yes(*left),
935            // A OR (..A..) --> (..A..)
936            Expr::BinaryExpr(BinaryExpr {
937                left,
938                op: Or,
939                right,
940            }) if expr_contains(&right, &left, Or) => Transformed::yes(*right),
941            // A OR (A AND B) --> A
942            Expr::BinaryExpr(BinaryExpr {
943                left,
944                op: Or,
945                right,
946            }) if is_op_with(And, &right, &left) => Transformed::yes(*left),
947            // (A AND B) OR A --> A
948            Expr::BinaryExpr(BinaryExpr {
949                left,
950                op: Or,
951                right,
952            }) if is_op_with(And, &left, &right) => Transformed::yes(*right),
953            // Eliminate common factors in conjunctions e.g
954            // (A AND B) OR (A AND C) -> A AND (B OR C)
955            Expr::BinaryExpr(BinaryExpr {
956                left,
957                op: Or,
958                right,
959            }) if has_common_conjunction(&left, &right) => {
960                let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect();
961                let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right)
962                    .partition(|e| lhs.contains(e) && !e.is_volatile());
963
964                let new_rhs = rhs.into_iter().reduce(and);
965                let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and);
966                let common_conjunction = common.into_iter().reduce(and).unwrap();
967
968                let new_expr = match (new_lhs, new_rhs) {
969                    (Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)),
970                    (_, _) => common_conjunction,
971                };
972                Transformed::yes(new_expr)
973            }
974
975            //
976            // Rules for AND
977            //
978
979            // true AND A --> A
980            Expr::BinaryExpr(BinaryExpr {
981                left,
982                op: And,
983                right,
984            }) if is_true(&left) => Transformed::yes(*right),
985            // false AND A --> false (even if A is null)
986            Expr::BinaryExpr(BinaryExpr {
987                left,
988                op: And,
989                right: _,
990            }) if is_false(&left) => Transformed::yes(*left),
991            // A AND true --> A
992            Expr::BinaryExpr(BinaryExpr {
993                left,
994                op: And,
995                right,
996            }) if is_true(&right) => Transformed::yes(*left),
997            // A AND false --> false (even if A is null)
998            Expr::BinaryExpr(BinaryExpr {
999                left: _,
1000                op: And,
1001                right,
1002            }) if is_false(&right) => Transformed::yes(*right),
1003            // A AND !A ---> false (if A not nullable)
1004            Expr::BinaryExpr(BinaryExpr {
1005                left,
1006                op: And,
1007                right,
1008            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
1009                Transformed::yes(lit(false))
1010            }
1011            // !A AND A ---> false (if A not nullable)
1012            Expr::BinaryExpr(BinaryExpr {
1013                left,
1014                op: And,
1015                right,
1016            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
1017                Transformed::yes(lit(false))
1018            }
1019            // (..A..) AND A --> (..A..)
1020            Expr::BinaryExpr(BinaryExpr {
1021                left,
1022                op: And,
1023                right,
1024            }) if expr_contains(&left, &right, And) => Transformed::yes(*left),
1025            // A AND (..A..) --> (..A..)
1026            Expr::BinaryExpr(BinaryExpr {
1027                left,
1028                op: And,
1029                right,
1030            }) if expr_contains(&right, &left, And) => Transformed::yes(*right),
1031            // A AND (A OR B) --> A
1032            Expr::BinaryExpr(BinaryExpr {
1033                left,
1034                op: And,
1035                right,
1036            }) if is_op_with(Or, &right, &left) => Transformed::yes(*left),
1037            // (A OR B) AND A --> A
1038            Expr::BinaryExpr(BinaryExpr {
1039                left,
1040                op: And,
1041                right,
1042            }) if is_op_with(Or, &left, &right) => Transformed::yes(*right),
1043            // A >= constant AND constant <= A --> A = constant
1044            Expr::BinaryExpr(BinaryExpr {
1045                left,
1046                op: And,
1047                right,
1048            }) if can_reduce_to_equal_statement(&left, &right) => {
1049                if let Expr::BinaryExpr(BinaryExpr {
1050                    left: left_left,
1051                    right: left_right,
1052                    ..
1053                }) = *left
1054                {
1055                    Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1056                        left: left_left,
1057                        op: Eq,
1058                        right: left_right,
1059                    }))
1060                } else {
1061                    return internal_err!("can_reduce_to_equal_statement should only be called with a BinaryExpr");
1062                }
1063            }
1064
1065            //
1066            // Rules for Multiply
1067            //
1068
1069            // A * 1 --> A (with type coercion if needed)
1070            Expr::BinaryExpr(BinaryExpr {
1071                left,
1072                op: Multiply,
1073                right,
1074            }) if is_one(&right) => {
1075                simplify_right_is_one_case(info, left, &Multiply, &right)?
1076            }
1077            // 1 * A --> A
1078            Expr::BinaryExpr(BinaryExpr {
1079                left,
1080                op: Multiply,
1081                right,
1082            }) if is_one(&left) => {
1083                // 1 * A is equivalent to A * 1
1084                simplify_right_is_one_case(info, right, &Multiply, &left)?
1085            }
1086
1087            // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN)
1088            Expr::BinaryExpr(BinaryExpr {
1089                left,
1090                op: Multiply,
1091                right,
1092            }) if !info.nullable(&left)?
1093                && !info.get_data_type(&left)?.is_floating()
1094                && is_zero(&right) =>
1095            {
1096                Transformed::yes(*right)
1097            }
1098            // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN)
1099            Expr::BinaryExpr(BinaryExpr {
1100                left,
1101                op: Multiply,
1102                right,
1103            }) if !info.nullable(&right)?
1104                && !info.get_data_type(&right)?.is_floating()
1105                && is_zero(&left) =>
1106            {
1107                Transformed::yes(*left)
1108            }
1109
1110            //
1111            // Rules for Divide
1112            //
1113
1114            // A / 1 --> A
1115            Expr::BinaryExpr(BinaryExpr {
1116                left,
1117                op: Divide,
1118                right,
1119            }) if is_one(&right) => {
1120                simplify_right_is_one_case(info, left, &Divide, &right)?
1121            }
1122
1123            //
1124            // Rules for Modulo
1125            //
1126
1127            // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN)
1128            Expr::BinaryExpr(BinaryExpr {
1129                left,
1130                op: Modulo,
1131                right,
1132            }) if !info.nullable(&left)?
1133                && !info.get_data_type(&left)?.is_floating()
1134                && is_one(&right) =>
1135            {
1136                Transformed::yes(Expr::Literal(
1137                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1138                    None,
1139                ))
1140            }
1141
1142            //
1143            // Rules for BitwiseAnd
1144            //
1145
1146            // A & 0 -> 0 (if A not nullable)
1147            Expr::BinaryExpr(BinaryExpr {
1148                left,
1149                op: BitwiseAnd,
1150                right,
1151            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right),
1152
1153            // 0 & A -> 0 (if A not nullable)
1154            Expr::BinaryExpr(BinaryExpr {
1155                left,
1156                op: BitwiseAnd,
1157                right,
1158            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left),
1159
1160            // !A & A -> 0 (if A not nullable)
1161            Expr::BinaryExpr(BinaryExpr {
1162                left,
1163                op: BitwiseAnd,
1164                right,
1165            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1166                Transformed::yes(Expr::Literal(
1167                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1168                    None,
1169                ))
1170            }
1171
1172            // A & !A -> 0 (if A not nullable)
1173            Expr::BinaryExpr(BinaryExpr {
1174                left,
1175                op: BitwiseAnd,
1176                right,
1177            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1178                Transformed::yes(Expr::Literal(
1179                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1180                    None,
1181                ))
1182            }
1183
1184            // (..A..) & A --> (..A..)
1185            Expr::BinaryExpr(BinaryExpr {
1186                left,
1187                op: BitwiseAnd,
1188                right,
1189            }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left),
1190
1191            // A & (..A..) --> (..A..)
1192            Expr::BinaryExpr(BinaryExpr {
1193                left,
1194                op: BitwiseAnd,
1195                right,
1196            }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right),
1197
1198            // A & (A | B) --> A (if B not null)
1199            Expr::BinaryExpr(BinaryExpr {
1200                left,
1201                op: BitwiseAnd,
1202                right,
1203            }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => {
1204                Transformed::yes(*left)
1205            }
1206
1207            // (A | B) & A --> A (if B not null)
1208            Expr::BinaryExpr(BinaryExpr {
1209                left,
1210                op: BitwiseAnd,
1211                right,
1212            }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => {
1213                Transformed::yes(*right)
1214            }
1215
1216            //
1217            // Rules for BitwiseOr
1218            //
1219
1220            // A | 0 -> A (even if A is null)
1221            Expr::BinaryExpr(BinaryExpr {
1222                left,
1223                op: BitwiseOr,
1224                right,
1225            }) if is_zero(&right) => Transformed::yes(*left),
1226
1227            // 0 | A -> A (even if A is null)
1228            Expr::BinaryExpr(BinaryExpr {
1229                left,
1230                op: BitwiseOr,
1231                right,
1232            }) if is_zero(&left) => Transformed::yes(*right),
1233
1234            // !A | A -> -1 (if A not nullable)
1235            Expr::BinaryExpr(BinaryExpr {
1236                left,
1237                op: BitwiseOr,
1238                right,
1239            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1240                Transformed::yes(Expr::Literal(
1241                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1242                    None,
1243                ))
1244            }
1245
1246            // A | !A -> -1 (if A not nullable)
1247            Expr::BinaryExpr(BinaryExpr {
1248                left,
1249                op: BitwiseOr,
1250                right,
1251            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1252                Transformed::yes(Expr::Literal(
1253                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1254                    None,
1255                ))
1256            }
1257
1258            // (..A..) | A --> (..A..)
1259            Expr::BinaryExpr(BinaryExpr {
1260                left,
1261                op: BitwiseOr,
1262                right,
1263            }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left),
1264
1265            // A | (..A..) --> (..A..)
1266            Expr::BinaryExpr(BinaryExpr {
1267                left,
1268                op: BitwiseOr,
1269                right,
1270            }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right),
1271
1272            // A | (A & B) --> A (if B not null)
1273            Expr::BinaryExpr(BinaryExpr {
1274                left,
1275                op: BitwiseOr,
1276                right,
1277            }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => {
1278                Transformed::yes(*left)
1279            }
1280
1281            // (A & B) | A --> A (if B not null)
1282            Expr::BinaryExpr(BinaryExpr {
1283                left,
1284                op: BitwiseOr,
1285                right,
1286            }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => {
1287                Transformed::yes(*right)
1288            }
1289
1290            //
1291            // Rules for BitwiseXor
1292            //
1293
1294            // A ^ 0 -> A (if A not nullable)
1295            Expr::BinaryExpr(BinaryExpr {
1296                left,
1297                op: BitwiseXor,
1298                right,
1299            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left),
1300
1301            // 0 ^ A -> A (if A not nullable)
1302            Expr::BinaryExpr(BinaryExpr {
1303                left,
1304                op: BitwiseXor,
1305                right,
1306            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right),
1307
1308            // !A ^ A -> -1 (if A not nullable)
1309            Expr::BinaryExpr(BinaryExpr {
1310                left,
1311                op: BitwiseXor,
1312                right,
1313            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1314                Transformed::yes(Expr::Literal(
1315                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1316                    None,
1317                ))
1318            }
1319
1320            // A ^ !A -> -1 (if A not nullable)
1321            Expr::BinaryExpr(BinaryExpr {
1322                left,
1323                op: BitwiseXor,
1324                right,
1325            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1326                Transformed::yes(Expr::Literal(
1327                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1328                    None,
1329                ))
1330            }
1331
1332            // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A)
1333            Expr::BinaryExpr(BinaryExpr {
1334                left,
1335                op: BitwiseXor,
1336                right,
1337            }) if expr_contains(&left, &right, BitwiseXor) => {
1338                let expr = delete_xor_in_complex_expr(&left, &right, false);
1339                Transformed::yes(if expr == *right {
1340                    Expr::Literal(
1341                        ScalarValue::new_zero(&info.get_data_type(&right)?)?,
1342                        None,
1343                    )
1344                } else {
1345                    expr
1346                })
1347            }
1348
1349            // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A)
1350            Expr::BinaryExpr(BinaryExpr {
1351                left,
1352                op: BitwiseXor,
1353                right,
1354            }) if expr_contains(&right, &left, BitwiseXor) => {
1355                let expr = delete_xor_in_complex_expr(&right, &left, true);
1356                Transformed::yes(if expr == *left {
1357                    Expr::Literal(
1358                        ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1359                        None,
1360                    )
1361                } else {
1362                    expr
1363                })
1364            }
1365
1366            //
1367            // Rules for BitwiseShiftRight
1368            //
1369
1370            // A >> 0 -> A (even if A is null)
1371            Expr::BinaryExpr(BinaryExpr {
1372                left,
1373                op: BitwiseShiftRight,
1374                right,
1375            }) if is_zero(&right) => Transformed::yes(*left),
1376
1377            //
            // Rules for BitwiseShiftLeft
1379            //
1380
1381            // A << 0 -> A (even if A is null)
1382            Expr::BinaryExpr(BinaryExpr {
1383                left,
1384                op: BitwiseShiftLeft,
1385                right,
1386            }) if is_zero(&right) => Transformed::yes(*left),
1387
1388            //
1389            // Rules for Not
1390            //
1391            Expr::Not(inner) => Transformed::yes(negate_clause(*inner)),
1392
1393            //
1394            // Rules for Negative
1395            //
1396            Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)),
1397
1398            //
1399            // Rules for Case
1400            //
1401
1402            // CASE
1403            //   WHEN X THEN A
1404            //   WHEN Y THEN B
1405            //   ...
1406            //   ELSE Q
1407            // END
1408            //
1409            // ---> (X AND A) OR (Y AND B AND NOT X) OR ... (NOT (X OR Y) AND Q)
1410            //
1411            // Note: the rationale for this rewrite is that the expr can then be further
1412            // simplified using the existing rules for AND/OR
1413            Expr::Case(Case {
1414                expr: None,
1415                when_then_expr,
1416                else_expr,
1417            }) if !when_then_expr.is_empty()
1418                && when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number
1419                && info.is_boolean_type(&when_then_expr[0].1)? =>
1420            {
                // Disjunction of all the `WHEN` predicates encountered so far. Not nullable.
1422                let mut filter_expr = lit(false);
1423                // The disjunction of all the cases
1424                let mut out_expr = lit(false);
1425
1426                for (when, then) in when_then_expr {
1427                    let when = is_exactly_true(*when, info)?;
1428                    let case_expr =
1429                        when.clone().and(filter_expr.clone().not()).and(*then);
1430
1431                    out_expr = out_expr.or(case_expr);
1432                    filter_expr = filter_expr.or(when);
1433                }
1434
1435                let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
1436                let case_expr = filter_expr.not().and(else_expr);
1437                out_expr = out_expr.or(case_expr);
1438
1439                // Do a first pass at simplification
1440                out_expr.rewrite(self)?
1441            }
1442            Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
1443                match udf.simplify(args, info)? {
1444                    ExprSimplifyResult::Original(args) => {
1445                        Transformed::no(Expr::ScalarFunction(ScalarFunction {
1446                            func: udf,
1447                            args,
1448                        }))
1449                    }
1450                    ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
1451                }
1452            }
1453
1454            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
1455                ref func,
1456                ..
1457            }) => match (func.simplify(), expr) {
1458                (Some(simplify_function), Expr::AggregateFunction(af)) => {
1459                    Transformed::yes(simplify_function(af, info)?)
1460                }
1461                (_, expr) => Transformed::no(expr),
1462            },
1463
1464            Expr::WindowFunction(ref window_fun) => match (window_fun.simplify(), expr) {
1465                (Some(simplify_function), Expr::WindowFunction(wf)) => {
1466                    Transformed::yes(simplify_function(*wf, info)?)
1467                }
1468                (_, expr) => Transformed::no(expr),
1469            },
1470
1471            //
1472            // Rules for Between
1473            //
1474
1475            // a between 3 and 5  -->  a >= 3 AND a <=5
1476            // a not between 3 and 5  -->  a < 3 OR a > 5
1477            Expr::Between(between) => Transformed::yes(if between.negated {
1478                let l = *between.expr.clone();
1479                let r = *between.expr;
1480                or(l.lt(*between.low), r.gt(*between.high))
1481            } else {
1482                and(
1483                    between.expr.clone().gt_eq(*between.low),
1484                    between.expr.lt_eq(*between.high),
1485                )
1486            }),
1487
1488            //
1489            // Rules for regexes
1490            //
1491            Expr::BinaryExpr(BinaryExpr {
1492                left,
1493                op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch),
1494                right,
1495            }) => Transformed::yes(simplify_regex_expr(left, op, right)?),
1496
1497            // Rules for Like
1498            Expr::Like(like) => {
1499                // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291
1500                let escape_char = like.escape_char.unwrap_or('\\');
1501                match as_string_scalar(&like.pattern) {
1502                    Some((data_type, pattern_str)) => {
1503                        match pattern_str {
1504                            None => return Ok(Transformed::yes(lit_bool_null())),
1505                            Some(pattern_str) if pattern_str == "%" => {
1506                                // exp LIKE '%' is
1507                                //   - when exp is not NULL, it's true
1508                                //   - when exp is NULL, it's NULL
1509                                // exp NOT LIKE '%' is
1510                                //   - when exp is not NULL, it's false
1511                                //   - when exp is NULL, it's NULL
1512                                let result_for_non_null = lit(!like.negated);
1513                                Transformed::yes(if !info.nullable(&like.expr)? {
1514                                    result_for_non_null
1515                                } else {
1516                                    Expr::Case(Case {
1517                                        expr: Some(Box::new(Expr::IsNotNull(like.expr))),
1518                                        when_then_expr: vec![(
1519                                            Box::new(lit(true)),
1520                                            Box::new(result_for_non_null),
1521                                        )],
1522                                        else_expr: None,
1523                                    })
1524                                })
1525                            }
1526                            Some(pattern_str)
1527                                if pattern_str.contains("%%")
1528                                    && !pattern_str.contains(escape_char) =>
1529                            {
1530                                // Repeated occurrences of wildcard are redundant so remove them
1531                                // exp LIKE '%%'  --> exp LIKE '%'
1532                                let simplified_pattern = Regex::new("%%+")
1533                                    .unwrap()
1534                                    .replace_all(pattern_str, "%")
1535                                    .to_string();
1536                                Transformed::yes(Expr::Like(Like {
1537                                    pattern: Box::new(to_string_scalar(
1538                                        data_type,
1539                                        Some(simplified_pattern),
1540                                    )),
1541                                    ..like
1542                                }))
1543                            }
1544                            Some(pattern_str)
1545                                if !like.case_insensitive
1546                                    && !pattern_str
1547                                        .contains(['%', '_', escape_char].as_ref()) =>
1548                            {
1549                                // If the pattern does not contain any wildcards, we can simplify the like expression to an equality expression
1550                                // TODO: handle escape characters
1551                                Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1552                                    left: like.expr.clone(),
1553                                    op: if like.negated { NotEq } else { Eq },
1554                                    right: like.pattern.clone(),
1555                                }))
1556                            }
1557
1558                            Some(_pattern_str) => Transformed::no(Expr::Like(like)),
1559                        }
1560                    }
1561                    None => Transformed::no(Expr::Like(like)),
1562                }
1563            }
1564
1565            // a is not null/unknown --> true (if a is not nullable)
1566            Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr)
1567                if !info.nullable(&expr)? =>
1568            {
1569                Transformed::yes(lit(true))
1570            }
1571
1572            // a is null/unknown --> false (if a is not nullable)
1573            Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => {
1574                Transformed::yes(lit(false))
1575            }
1576
1577            // expr IN () --> false
1578            // expr NOT IN () --> true
1579            Expr::InList(InList {
1580                expr: _,
1581                list,
1582                negated,
1583            }) if list.is_empty() => Transformed::yes(lit(negated)),
1584
1585            // null in (x, y, z) --> null
1586            // null not in (x, y, z) --> null
1587            Expr::InList(InList {
1588                expr,
1589                list,
1590                negated: _,
1591            }) if is_null(expr.as_ref()) && !list.is_empty() => {
1592                Transformed::yes(lit_bool_null())
1593            }
1594
1595            // expr IN ((subquery)) -> expr IN (subquery), see #5529
1596            Expr::InList(InList {
1597                expr,
1598                mut list,
1599                negated,
1600            }) if list.len() == 1
1601                && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) =>
1602            {
1603                let Expr::ScalarSubquery(subquery) = list.remove(0) else {
1604                    unreachable!()
1605                };
1606
1607                Transformed::yes(Expr::InSubquery(InSubquery::new(
1608                    expr, subquery, negated,
1609                )))
1610            }
1611
1612            // Combine multiple OR expressions into a single IN list expression if possible
1613            //
1614            // i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)`
1615            Expr::BinaryExpr(BinaryExpr {
1616                left,
1617                op: Or,
1618                right,
1619            }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => {
1620                let lhs = to_inlist(*left).unwrap();
1621                let rhs = to_inlist(*right).unwrap();
1622                let mut seen: HashSet<Expr> = HashSet::new();
1623                let list = lhs
1624                    .list
1625                    .into_iter()
1626                    .chain(rhs.list)
1627                    .filter(|e| seen.insert(e.to_owned()))
1628                    .collect::<Vec<_>>();
1629
1630                let merged_inlist = InList {
1631                    expr: lhs.expr,
1632                    list,
1633                    negated: false,
1634                };
1635
1636                Transformed::yes(Expr::InList(merged_inlist))
1637            }
1638
1639            // Simplify expressions that are guaranteed to be true or false to a literal boolean expression
1640            //
1641            // Rules:
1642            // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists
1643            //   Intersection:
1644            //     1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false`
1645            //     2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)`
1646            //     3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)`
1647            //   Union:
1648            //     4. `a not in (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)`
1649            //     # This rule is handled by `or_in_list_simplifier.rs`
1650            //     5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)`
1651            // If one of the expressions is `IN` and the other one is `NOT IN`, then we remove the `NOT IN` items from the `IN` list (set difference)
1652            //     6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false`
1653            //     7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5`
1654            //     8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)`
1655            Expr::BinaryExpr(BinaryExpr {
1656                left,
1657                op: And,
1658                right,
1659            }) if are_inlist_and_eq_and_match_neg(
1660                left.as_ref(),
1661                right.as_ref(),
1662                false,
1663                false,
1664            ) =>
1665            {
1666                match (*left, *right) {
1667                    (Expr::InList(l1), Expr::InList(l2)) => {
1668                        return inlist_intersection(l1, &l2, false).map(Transformed::yes);
1669                    }
1670                    // Matched previously once
1671                    _ => unreachable!(),
1672                }
1673            }
1674
1675            Expr::BinaryExpr(BinaryExpr {
1676                left,
1677                op: And,
1678                right,
1679            }) if are_inlist_and_eq_and_match_neg(
1680                left.as_ref(),
1681                right.as_ref(),
1682                true,
1683                true,
1684            ) =>
1685            {
1686                match (*left, *right) {
1687                    (Expr::InList(l1), Expr::InList(l2)) => {
1688                        return inlist_union(l1, l2, true).map(Transformed::yes);
1689                    }
1690                    // Matched previously once
1691                    _ => unreachable!(),
1692                }
1693            }
1694
1695            Expr::BinaryExpr(BinaryExpr {
1696                left,
1697                op: And,
1698                right,
1699            }) if are_inlist_and_eq_and_match_neg(
1700                left.as_ref(),
1701                right.as_ref(),
1702                false,
1703                true,
1704            ) =>
1705            {
1706                match (*left, *right) {
1707                    (Expr::InList(l1), Expr::InList(l2)) => {
1708                        return inlist_except(l1, &l2).map(Transformed::yes);
1709                    }
1710                    // Matched previously once
1711                    _ => unreachable!(),
1712                }
1713            }
1714
1715            Expr::BinaryExpr(BinaryExpr {
1716                left,
1717                op: And,
1718                right,
1719            }) if are_inlist_and_eq_and_match_neg(
1720                left.as_ref(),
1721                right.as_ref(),
1722                true,
1723                false,
1724            ) =>
1725            {
1726                match (*left, *right) {
1727                    (Expr::InList(l1), Expr::InList(l2)) => {
1728                        return inlist_except(l2, &l1).map(Transformed::yes);
1729                    }
1730                    // Matched previously once
1731                    _ => unreachable!(),
1732                }
1733            }
1734
1735            Expr::BinaryExpr(BinaryExpr {
1736                left,
1737                op: Or,
1738                right,
1739            }) if are_inlist_and_eq_and_match_neg(
1740                left.as_ref(),
1741                right.as_ref(),
1742                true,
1743                true,
1744            ) =>
1745            {
1746                match (*left, *right) {
1747                    (Expr::InList(l1), Expr::InList(l2)) => {
1748                        return inlist_intersection(l1, &l2, true).map(Transformed::yes);
1749                    }
1750                    // Matched previously once
1751                    _ => unreachable!(),
1752                }
1753            }
1754
1755            // =======================================
1756            // unwrap_cast_in_comparison
1757            // =======================================
1758            //
1759            // For case:
1760            // try_cast/cast(expr as data_type) op literal
1761            Expr::BinaryExpr(BinaryExpr { left, op, right })
1762                if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1763                    info, &left, op, &right,
1764                ) && op.supports_propagation() =>
1765            {
1766                unwrap_cast_in_comparison_for_binary(info, *left, *right, op)?
1767            }
1768            // literal op try_cast/cast(expr as data_type)
1769            // -->
1770            // try_cast/cast(expr as data_type) op_swap literal
1771            Expr::BinaryExpr(BinaryExpr { left, op, right })
1772                if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1773                    info, &right, op, &left,
1774                ) && op.supports_propagation()
1775                    && op.swap().is_some() =>
1776            {
1777                unwrap_cast_in_comparison_for_binary(
1778                    info,
1779                    *right,
1780                    *left,
1781                    op.swap().unwrap(),
1782                )?
1783            }
1784            // For case:
1785            // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
1786            Expr::InList(InList {
1787                expr: mut left,
1788                list,
1789                negated,
1790            }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
1791                info, &left, &list,
1792            ) =>
1793            {
1794                let (Expr::TryCast(TryCast {
1795                    expr: left_expr, ..
1796                })
1797                | Expr::Cast(Cast {
1798                    expr: left_expr, ..
1799                })) = left.as_mut()
1800                else {
1801                    return internal_err!("Expect cast expr, but got {:?}", left)?;
1802                };
1803
1804                let expr_type = info.get_data_type(left_expr)?;
1805                let right_exprs = list
1806                    .into_iter()
1807                    .map(|right| {
1808                        match right {
1809                            Expr::Literal(right_lit_value, _) => {
1810                                // if the right_lit_value can be cast to the type of left_expr,
1811                                // we unwrap the cast/try_cast expr and cast the literal to that type instead
1812                                let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else {
1813                                    internal_err!(
1814                                        "Can't cast the list expr {:?} to type {:?}",
1815                                        right_lit_value, &expr_type
1816                                    )?
1817                                };
1818                                Ok(lit(value))
1819                            }
1820                            other_expr => internal_err!(
1821                                "Only support literal expr to optimize, but the expr is {:?}",
1822                                &other_expr
1823                            ),
1824                        }
1825                    })
1826                    .collect::<Result<Vec<_>>>()?;
1827
1828                Transformed::yes(Expr::InList(InList {
1829                    expr: std::mem::take(left_expr),
1830                    list: right_exprs,
1831                    negated,
1832                }))
1833            }
1834
1835            // no additional rewrites possible
1836            expr => Transformed::no(expr),
1837        })
1838    }
1839}
1840
1841fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option<String>)> {
1842    match expr {
1843        Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)),
1844        Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)),
1845        Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)),
1846        _ => None,
1847    }
1848}
1849
1850fn to_string_scalar(data_type: DataType, value: Option<String>) -> Expr {
1851    match data_type {
1852        DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None),
1853        DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None),
1854        DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None),
1855        _ => unreachable!(),
1856    }
1857}
1858
1859fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
1860    let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect();
1861    iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
1862}
1863
1864// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
1865fn are_inlist_and_eq_and_match_neg(
1866    left: &Expr,
1867    right: &Expr,
1868    is_left_neg: bool,
1869    is_right_neg: bool,
1870) -> bool {
1871    match (left, right) {
1872        (Expr::InList(l), Expr::InList(r)) => {
1873            l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg
1874        }
1875        _ => false,
1876    }
1877}
1878
1879// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
1880fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool {
1881    let left = as_inlist(left);
1882    let right = as_inlist(right);
1883    if let (Some(lhs), Some(rhs)) = (left, right) {
1884        matches!(lhs.expr.as_ref(), Expr::Column(_))
1885            && matches!(rhs.expr.as_ref(), Expr::Column(_))
1886            && lhs.expr == rhs.expr
1887            && !lhs.negated
1888            && !rhs.negated
1889    } else {
1890        false
1891    }
1892}
1893
1894/// Try to convert an expression to an in-list expression
1895fn as_inlist(expr: &'_ Expr) -> Option<Cow<'_, InList>> {
1896    match expr {
1897        Expr::InList(inlist) => Some(Cow::Borrowed(inlist)),
1898        Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => {
1899            match (left.as_ref(), right.as_ref()) {
1900                (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList {
1901                    expr: left.clone(),
1902                    list: vec![*right.clone()],
1903                    negated: false,
1904                })),
1905                (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList {
1906                    expr: right.clone(),
1907                    list: vec![*left.clone()],
1908                    negated: false,
1909                })),
1910                _ => None,
1911            }
1912        }
1913        _ => None,
1914    }
1915}
1916
1917fn to_inlist(expr: Expr) -> Option<InList> {
1918    match expr {
1919        Expr::InList(inlist) => Some(inlist),
1920        Expr::BinaryExpr(BinaryExpr {
1921            left,
1922            op: Operator::Eq,
1923            right,
1924        }) => match (left.as_ref(), right.as_ref()) {
1925            (Expr::Column(_), Expr::Literal(_, _)) => Some(InList {
1926                expr: left,
1927                list: vec![*right],
1928                negated: false,
1929            }),
1930            (Expr::Literal(_, _), Expr::Column(_)) => Some(InList {
1931                expr: right,
1932                list: vec![*left],
1933                negated: false,
1934            }),
1935            _ => None,
1936        },
1937        _ => None,
1938    }
1939}
1940
1941/// Return the union of two inlist expressions
1942/// maintaining the order of the elements in the two lists
1943fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
1944    // extend the list in l1 with the elements in l2 that are not already in l1
1945    let l1_items: HashSet<_> = l1.list.iter().collect();
1946
1947    // keep all l2 items that do not also appear in l1
1948    let keep_l2: Vec<_> = l2
1949        .list
1950        .into_iter()
1951        .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) })
1952        .collect();
1953
1954    l1.list.extend(keep_l2);
1955    l1.negated = negated;
1956    Ok(Expr::InList(l1))
1957}
1958
1959/// Return the intersection of two inlist expressions
1960/// maintaining the order of the elements in the two lists
1961fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result<Expr> {
1962    let l2_items = l2.list.iter().collect::<HashSet<_>>();
1963
1964    // remove all items from l1 that are not in l2
1965    l1.list.retain(|e| l2_items.contains(e));
1966
1967    // e in () is always false
1968    // e not in () is always true
1969    if l1.list.is_empty() {
1970        return Ok(lit(negated));
1971    }
1972    Ok(Expr::InList(l1))
1973}
1974
1975/// Return the all items in l1 that are not in l2
1976/// maintaining the order of the elements in the two lists
1977fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
1978    let l2_items = l2.list.iter().collect::<HashSet<_>>();
1979
1980    // keep only items from l1 that are not in l2
1981    l1.list.retain(|e| !l2_items.contains(e));
1982
1983    if l1.list.is_empty() {
1984        return Ok(lit(false));
1985    }
1986    Ok(Expr::InList(l1))
1987}
1988
1989/// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL).
1990fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
1991    if !info.nullable(&expr)? {
1992        Ok(expr)
1993    } else {
1994        Ok(Expr::BinaryExpr(BinaryExpr {
1995            left: Box::new(expr),
1996            op: Operator::IsNotDistinctFrom,
1997            right: Box::new(lit(true)),
1998        }))
1999    }
2000}
2001
2002// A * 1 -> A
2003// A / 1 -> A
2004//
2005// Move this function body out of the large match branch avoid stack overflow
2006fn simplify_right_is_one_case<S: SimplifyInfo>(
2007    info: &S,
2008    left: Box<Expr>,
2009    op: &Operator,
2010    right: &Expr,
2011) -> Result<Transformed<Expr>> {
2012    // Check if resulting type would be different due to coercion
2013    let left_type = info.get_data_type(&left)?;
2014    let right_type = info.get_data_type(right)?;
2015    match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() {
2016        Ok(result_type) => {
2017            // Only cast if the types differ
2018            if left_type != result_type {
2019                Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type))))
2020            } else {
2021                Ok(Transformed::yes(*left))
2022            }
2023        }
2024        Err(_) => Ok(Transformed::yes(*left)),
2025    }
2026}
2027
2028#[cfg(test)]
2029mod tests {
2030    use super::*;
2031    use crate::simplify_expressions::SimplifyContext;
2032    use crate::test::test_table_scan_with_name;
2033    use arrow::datatypes::FieldRef;
2034    use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
2035    use datafusion_expr::{
2036        expr::WindowFunction,
2037        function::{
2038            AccumulatorArgs, AggregateFunctionSimplification,
2039            WindowFunctionSimplification,
2040        },
2041        interval_arithmetic::Interval,
2042        *,
2043    };
2044    use datafusion_functions_window_common::field::WindowUDFFieldArgs;
2045    use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
2046    use datafusion_physical_expr::PhysicalExpr;
2047    use std::hash::Hash;
2048    use std::{
2049        collections::HashMap,
2050        ops::{BitAnd, BitOr, BitXor},
2051        sync::Arc,
2052    };
2053
2054    // ------------------------------
2055    // --- ExprSimplifier tests -----
2056    // ------------------------------
2057    #[test]
2058    fn api_basic() {
2059        let props = ExecutionProps::new();
2060        let simplifier =
2061            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2062
2063        let expr = lit(1) + lit(2);
2064        let expected = lit(3);
2065        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2066    }
2067
2068    #[test]
2069    fn basic_coercion() {
2070        let schema = test_schema();
2071        let props = ExecutionProps::new();
2072        let simplifier = ExprSimplifier::new(
2073            SimplifyContext::new(&props).with_schema(Arc::clone(&schema)),
2074        );
2075
2076        // Note expr type is int32 (not int64)
2077        // (1i64 + 2i32) < i
2078        let expr = (lit(1i64) + lit(2i32)).lt(col("i"));
2079        // should fully simplify to 3 < i (though i has been coerced to i64)
2080        let expected = lit(3i64).lt(col("i"));
2081
2082        let expr = simplifier.coerce(expr, &schema).unwrap();
2083
2084        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2085    }
2086
2087    fn test_schema() -> DFSchemaRef {
2088        Schema::new(vec![
2089            Field::new("i", DataType::Int64, false),
2090            Field::new("b", DataType::Boolean, true),
2091        ])
2092        .to_dfschema_ref()
2093        .unwrap()
2094    }
2095
2096    #[test]
2097    fn simplify_and_constant_prop() {
2098        let props = ExecutionProps::new();
2099        let simplifier =
2100            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2101
2102        // should be able to simplify to false
2103        // (i * (1 - 1)) > 0
2104        let expr = (col("i") * (lit(1) - lit(1))).gt(lit(0));
2105        let expected = lit(false);
2106        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2107    }
2108
2109    #[test]
2110    fn simplify_and_constant_prop_with_case() {
2111        let props = ExecutionProps::new();
2112        let simplifier =
2113            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2114
2115        //   CASE
2116        //     WHEN i>5 AND false THEN i > 5
2117        //     WHEN i<5 AND true THEN i < 5
2118        //     ELSE false
2119        //   END
2120        //
2121        // Can be simplified to `i < 5`
2122        let expr = when(col("i").gt(lit(5)).and(lit(false)), col("i").gt(lit(5)))
2123            .when(col("i").lt(lit(5)).and(lit(true)), col("i").lt(lit(5)))
2124            .otherwise(lit(false))
2125            .unwrap();
2126        let expected = col("i").lt(lit(5));
2127        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2128    }
2129
2130    // ------------------------------
2131    // --- Simplifier tests -----
2132    // ------------------------------
2133
2134    #[test]
2135    fn test_simplify_canonicalize() {
2136        {
2137            let expr = lit(1).lt(col("c2")).and(col("c2").gt(lit(1)));
2138            let expected = col("c2").gt(lit(1));
2139            assert_eq!(simplify(expr), expected);
2140        }
2141        {
2142            let expr = col("c1").lt(col("c2")).and(col("c2").gt(col("c1")));
2143            let expected = col("c2").gt(col("c1"));
2144            assert_eq!(simplify(expr), expected);
2145        }
2146        {
2147            let expr = col("c1")
2148                .eq(lit(1))
2149                .and(lit(1).eq(col("c1")))
2150                .and(col("c1").eq(lit(3)));
2151            let expected = col("c1").eq(lit(1)).and(col("c1").eq(lit(3)));
2152            assert_eq!(simplify(expr), expected);
2153        }
2154        {
2155            let expr = col("c1")
2156                .eq(col("c2"))
2157                .and(col("c1").gt(lit(5)))
2158                .and(col("c2").eq(col("c1")));
2159            let expected = col("c2").eq(col("c1")).and(col("c1").gt(lit(5)));
2160            assert_eq!(simplify(expr), expected);
2161        }
2162        {
2163            let expr = col("c1")
2164                .eq(lit(1))
2165                .and(col("c2").gt(lit(3)).or(lit(3).lt(col("c2"))));
2166            let expected = col("c1").eq(lit(1)).and(col("c2").gt(lit(3)));
2167            assert_eq!(simplify(expr), expected);
2168        }
2169        {
2170            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2171            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2172            assert_eq!(simplify(expr), expected);
2173        }
2174        {
2175            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2176            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2177            assert_eq!(simplify(expr), expected);
2178        }
2179        {
2180            let expr = col("c1").gt(col("c2")).and(col("c1").gt(col("c2")));
2181            let expected = col("c2").lt(col("c1"));
2182            assert_eq!(simplify(expr), expected);
2183        }
2184    }
2185
2186    #[test]
2187    fn test_simplify_eq_not_self() {
2188        // `expr_a`: column `c2` is nullable, so `c2 = c2` simplifies to `c2 IS NOT NULL OR NULL`
2189        // This ensures the expression is only true when `c2` is not NULL, accounting for SQL's NULL semantics.
2190        let expr_a = col("c2").eq(col("c2"));
2191        let expected_a = col("c2").is_not_null().or(lit_bool_null());
2192
2193        // `expr_b`: column `c2_non_null` is explicitly non-nullable, so `c2_non_null = c2_non_null` is always true
2194        let expr_b = col("c2_non_null").eq(col("c2_non_null"));
2195        let expected_b = lit(true);
2196
2197        assert_eq!(simplify(expr_a), expected_a);
2198        assert_eq!(simplify(expr_b), expected_b);
2199    }
2200
2201    #[test]
2202    fn test_simplify_or_true() {
2203        let expr_a = col("c2").or(lit(true));
2204        let expr_b = lit(true).or(col("c2"));
2205        let expected = lit(true);
2206
2207        assert_eq!(simplify(expr_a), expected);
2208        assert_eq!(simplify(expr_b), expected);
2209    }
2210
2211    #[test]
2212    fn test_simplify_or_false() {
2213        let expr_a = lit(false).or(col("c2"));
2214        let expr_b = col("c2").or(lit(false));
2215        let expected = col("c2");
2216
2217        assert_eq!(simplify(expr_a), expected);
2218        assert_eq!(simplify(expr_b), expected);
2219    }
2220
2221    #[test]
2222    fn test_simplify_or_same() {
2223        let expr = col("c2").or(col("c2"));
2224        let expected = col("c2");
2225
2226        assert_eq!(simplify(expr), expected);
2227    }
2228
2229    #[test]
2230    fn test_simplify_or_not_self() {
2231        // A OR !A if A is not nullable --> true
2232        // !A OR A if A is not nullable --> true
2233        let expr_a = col("c2_non_null").or(col("c2_non_null").not());
2234        let expr_b = col("c2_non_null").not().or(col("c2_non_null"));
2235        let expected = lit(true);
2236
2237        assert_eq!(simplify(expr_a), expected);
2238        assert_eq!(simplify(expr_b), expected);
2239    }
2240
2241    #[test]
2242    fn test_simplify_and_false() {
2243        let expr_a = lit(false).and(col("c2"));
2244        let expr_b = col("c2").and(lit(false));
2245        let expected = lit(false);
2246
2247        assert_eq!(simplify(expr_a), expected);
2248        assert_eq!(simplify(expr_b), expected);
2249    }
2250
2251    #[test]
2252    fn test_simplify_and_same() {
2253        let expr = col("c2").and(col("c2"));
2254        let expected = col("c2");
2255
2256        assert_eq!(simplify(expr), expected);
2257    }
2258
2259    #[test]
2260    fn test_simplify_and_true() {
2261        let expr_a = lit(true).and(col("c2"));
2262        let expr_b = col("c2").and(lit(true));
2263        let expected = col("c2");
2264
2265        assert_eq!(simplify(expr_a), expected);
2266        assert_eq!(simplify(expr_b), expected);
2267    }
2268
2269    #[test]
2270    fn test_simplify_and_not_self() {
2271        // A AND !A if A is not nullable --> false
2272        // !A AND A if A is not nullable --> false
2273        let expr_a = col("c2_non_null").and(col("c2_non_null").not());
2274        let expr_b = col("c2_non_null").not().and(col("c2_non_null"));
2275        let expected = lit(false);
2276
2277        assert_eq!(simplify(expr_a), expected);
2278        assert_eq!(simplify(expr_b), expected);
2279    }
2280
2281    #[test]
2282    fn test_simplify_multiply_by_one() {
2283        let expr_a = col("c2") * lit(1);
2284        let expr_b = lit(1) * col("c2");
2285        let expected = col("c2");
2286
2287        assert_eq!(simplify(expr_a), expected);
2288        assert_eq!(simplify(expr_b), expected);
2289
2290        let expr = col("c2") * lit(ScalarValue::Decimal128(Some(10000000000), 38, 10));
2291        assert_eq!(simplify(expr), expected);
2292
2293        let expr = lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)) * col("c2");
2294        assert_eq!(simplify(expr), expected);
2295    }
2296
2297    #[test]
2298    fn test_simplify_multiply_by_null() {
2299        let null = lit(ScalarValue::Int64(None));
2300        // A * null --> null
2301        {
2302            let expr = col("c3") * null.clone();
2303            assert_eq!(simplify(expr), null);
2304        }
2305        // null * A --> null
2306        {
2307            let expr = null.clone() * col("c3");
2308            assert_eq!(simplify(expr), null);
2309        }
2310    }
2311
2312    #[test]
2313    fn test_simplify_multiply_by_zero() {
2314        // cannot optimize A * 0 (0 * A) if A is nullable
2315        {
2316            let expr_a = col("c2") * lit(0);
2317            let expr_b = lit(0) * col("c2");
2318
2319            assert_eq!(simplify(expr_a.clone()), expr_a);
2320            assert_eq!(simplify(expr_b.clone()), expr_b);
2321        }
2322        // 0 * A --> 0 if A is not nullable
2323        {
2324            let expr = lit(0) * col("c2_non_null");
2325            assert_eq!(simplify(expr), lit(0));
2326        }
2327        // A * 0 --> 0 if A is not nullable
2328        {
2329            let expr = col("c2_non_null") * lit(0);
2330            assert_eq!(simplify(expr), lit(0));
2331        }
2332        // A * Decimal128(0) --> 0 if A is not nullable
2333        {
2334            let expr = col("c2_non_null") * lit(ScalarValue::Decimal128(Some(0), 31, 10));
2335            assert_eq!(
2336                simplify(expr),
2337                lit(ScalarValue::Decimal128(Some(0), 31, 10))
2338            );
2339            let expr = binary_expr(
2340                lit(ScalarValue::Decimal128(Some(0), 31, 10)),
2341                Operator::Multiply,
2342                col("c2_non_null"),
2343            );
2344            assert_eq!(
2345                simplify(expr),
2346                lit(ScalarValue::Decimal128(Some(0), 31, 10))
2347            );
2348        }
2349    }
2350
2351    #[test]
2352    fn test_simplify_divide_by_one() {
2353        let expr = binary_expr(col("c2"), Operator::Divide, lit(1));
2354        let expected = col("c2");
2355        assert_eq!(simplify(expr), expected);
2356        let expr = col("c2") / lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2357        assert_eq!(simplify(expr), expected);
2358    }
2359
2360    #[test]
2361    fn test_simplify_divide_null() {
2362        // A / null --> null
2363        let null = lit(ScalarValue::Int64(None));
2364        {
2365            let expr = col("c3") / null.clone();
2366            assert_eq!(simplify(expr), null);
2367        }
2368        // null / A --> null
2369        {
2370            let expr = null.clone() / col("c3");
2371            assert_eq!(simplify(expr), null);
2372        }
2373    }
2374
2375    #[test]
2376    fn test_simplify_divide_by_same() {
2377        let expr = col("c2") / col("c2");
2378        // if c2 is null, c2 / c2 = null, so can't simplify
2379        let expected = expr.clone();
2380
2381        assert_eq!(simplify(expr), expected);
2382    }
2383
2384    #[test]
2385    fn test_simplify_modulo_by_null() {
2386        let null = lit(ScalarValue::Int64(None));
2387        // A % null --> null
2388        {
2389            let expr = col("c3") % null.clone();
2390            assert_eq!(simplify(expr), null);
2391        }
2392        // null % A --> null
2393        {
2394            let expr = null.clone() % col("c3");
2395            assert_eq!(simplify(expr), null);
2396        }
2397    }
2398
2399    #[test]
2400    fn test_simplify_modulo_by_one() {
2401        let expr = col("c2") % lit(1);
2402        // if c2 is null, c2 % 1 = null, so can't simplify
2403        let expected = expr.clone();
2404
2405        assert_eq!(simplify(expr), expected);
2406    }
2407
2408    #[test]
2409    fn test_simplify_divide_zero_by_zero() {
2410        // because divide by 0 maybe occur in short-circuit expression
2411        // so we should not simplify this, and throw error in runtime
2412        let expr = lit(0) / lit(0);
2413        let expected = expr.clone();
2414
2415        assert_eq!(simplify(expr), expected);
2416    }
2417
2418    #[test]
2419    fn test_simplify_divide_by_zero() {
2420        // because divide by 0 maybe occur in short-circuit expression
2421        // so we should not simplify this, and throw error in runtime
2422        let expr = col("c2_non_null") / lit(0);
2423        let expected = expr.clone();
2424
2425        assert_eq!(simplify(expr), expected);
2426    }
2427
2428    #[test]
2429    fn test_simplify_modulo_by_one_non_null() {
2430        let expr = col("c3_non_null") % lit(1);
2431        let expected = lit(0_i64);
2432        assert_eq!(simplify(expr), expected);
2433        let expr =
2434            col("c3_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2435        assert_eq!(simplify(expr), expected);
2436    }
2437
2438    #[test]
2439    fn test_simplify_bitwise_xor_by_null() {
2440        let null = lit(ScalarValue::Int64(None));
2441        // A ^ null --> null
2442        {
2443            let expr = col("c3") ^ null.clone();
2444            assert_eq!(simplify(expr), null);
2445        }
2446        // null ^ A --> null
2447        {
2448            let expr = null.clone() ^ col("c3");
2449            assert_eq!(simplify(expr), null);
2450        }
2451    }
2452
2453    #[test]
2454    fn test_simplify_bitwise_shift_right_by_null() {
2455        let null = lit(ScalarValue::Int64(None));
2456        // A >> null --> null
2457        {
2458            let expr = col("c3") >> null.clone();
2459            assert_eq!(simplify(expr), null);
2460        }
2461        // null >> A --> null
2462        {
2463            let expr = null.clone() >> col("c3");
2464            assert_eq!(simplify(expr), null);
2465        }
2466    }
2467
2468    #[test]
2469    fn test_simplify_bitwise_shift_left_by_null() {
2470        let null = lit(ScalarValue::Int64(None));
2471        // A << null --> null
2472        {
2473            let expr = col("c3") << null.clone();
2474            assert_eq!(simplify(expr), null);
2475        }
2476        // null << A --> null
2477        {
2478            let expr = null.clone() << col("c3");
2479            assert_eq!(simplify(expr), null);
2480        }
2481    }
2482
2483    #[test]
2484    fn test_simplify_bitwise_and_by_zero() {
2485        // A & 0 --> 0
2486        {
2487            let expr = col("c2_non_null") & lit(0);
2488            assert_eq!(simplify(expr), lit(0));
2489        }
2490        // 0 & A --> 0
2491        {
2492            let expr = lit(0) & col("c2_non_null");
2493            assert_eq!(simplify(expr), lit(0));
2494        }
2495    }
2496
2497    #[test]
2498    fn test_simplify_bitwise_or_by_zero() {
2499        // A | 0 --> A
2500        {
2501            let expr = col("c2_non_null") | lit(0);
2502            assert_eq!(simplify(expr), col("c2_non_null"));
2503        }
2504        // 0 | A --> A
2505        {
2506            let expr = lit(0) | col("c2_non_null");
2507            assert_eq!(simplify(expr), col("c2_non_null"));
2508        }
2509    }
2510
2511    #[test]
2512    fn test_simplify_bitwise_xor_by_zero() {
2513        // A ^ 0 --> A
2514        {
2515            let expr = col("c2_non_null") ^ lit(0);
2516            assert_eq!(simplify(expr), col("c2_non_null"));
2517        }
2518        // 0 ^ A --> A
2519        {
2520            let expr = lit(0) ^ col("c2_non_null");
2521            assert_eq!(simplify(expr), col("c2_non_null"));
2522        }
2523    }
2524
2525    #[test]
2526    fn test_simplify_bitwise_bitwise_shift_right_by_zero() {
2527        // A >> 0 --> A
2528        {
2529            let expr = col("c2_non_null") >> lit(0);
2530            assert_eq!(simplify(expr), col("c2_non_null"));
2531        }
2532    }
2533
2534    #[test]
2535    fn test_simplify_bitwise_bitwise_shift_left_by_zero() {
2536        // A << 0 --> A
2537        {
2538            let expr = col("c2_non_null") << lit(0);
2539            assert_eq!(simplify(expr), col("c2_non_null"));
2540        }
2541    }
2542
2543    #[test]
2544    fn test_simplify_bitwise_and_by_null() {
2545        let null = Expr::Literal(ScalarValue::Int64(None), None);
2546        // A & null --> null
2547        {
2548            let expr = col("c3") & null.clone();
2549            assert_eq!(simplify(expr), null);
2550        }
2551        // null & A --> null
2552        {
2553            let expr = null.clone() & col("c3");
2554            assert_eq!(simplify(expr), null);
2555        }
2556    }
2557
2558    #[test]
2559    fn test_simplify_composed_bitwise_and() {
2560        // ((c2 > 5) & (c1 < 6)) & (c2 > 5) --> (c2 > 5) & (c1 < 6)
2561
2562        let expr = bitwise_and(
2563            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2564            col("c2").gt(lit(5)),
2565        );
2566        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2567
2568        assert_eq!(simplify(expr), expected);
2569
2570        // (c2 > 5) & ((c2 > 5) & (c1 < 6)) --> (c2 > 5) & (c1 < 6)
2571
2572        let expr = bitwise_and(
2573            col("c2").gt(lit(5)),
2574            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2575        );
2576        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2577        assert_eq!(simplify(expr), expected);
2578    }
2579
2580    #[test]
2581    fn test_simplify_composed_bitwise_or() {
2582        // ((c2 > 5) | (c1 < 6)) | (c2 > 5) --> (c2 > 5) | (c1 < 6)
2583
2584        let expr = bitwise_or(
2585            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2586            col("c2").gt(lit(5)),
2587        );
2588        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2589
2590        assert_eq!(simplify(expr), expected);
2591
2592        // (c2 > 5) | ((c2 > 5) | (c1 < 6)) --> (c2 > 5) | (c1 < 6)
2593
2594        let expr = bitwise_or(
2595            col("c2").gt(lit(5)),
2596            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2597        );
2598        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2599
2600        assert_eq!(simplify(expr), expected);
2601    }
2602
2603    #[test]
2604    fn test_simplify_composed_bitwise_xor() {
2605        // with an even number of the column "c2"
2606        // c2 ^ ((c2 ^ (c2 | c1)) ^ (c1 & c2)) --> (c2 | c1) ^ (c1 & c2)
2607
2608        let expr = bitwise_xor(
2609            col("c2"),
2610            bitwise_xor(
2611                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2612                bitwise_and(col("c1"), col("c2")),
2613            ),
2614        );
2615
2616        let expected = bitwise_xor(
2617            bitwise_or(col("c2"), col("c1")),
2618            bitwise_and(col("c1"), col("c2")),
2619        );
2620
2621        assert_eq!(simplify(expr), expected);
2622
2623        // with an odd number of the column "c2"
2624        // c2 ^ (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) --> c2 ^ ((c2 | c1) ^ (c1 & c2))
2625
2626        let expr = bitwise_xor(
2627            col("c2"),
2628            bitwise_xor(
2629                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2630                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2631            ),
2632        );
2633
2634        let expected = bitwise_xor(
2635            col("c2"),
2636            bitwise_xor(
2637                bitwise_or(col("c2"), col("c1")),
2638                bitwise_and(col("c1"), col("c2")),
2639            ),
2640        );
2641
2642        assert_eq!(simplify(expr), expected);
2643
2644        // with an even number of the column "c2"
2645        // ((c2 ^ (c2 | c1)) ^ (c1 & c2)) ^ c2 --> (c2 | c1) ^ (c1 & c2)
2646
2647        let expr = bitwise_xor(
2648            bitwise_xor(
2649                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2650                bitwise_and(col("c1"), col("c2")),
2651            ),
2652            col("c2"),
2653        );
2654
2655        let expected = bitwise_xor(
2656            bitwise_or(col("c2"), col("c1")),
2657            bitwise_and(col("c1"), col("c2")),
2658        );
2659
2660        assert_eq!(simplify(expr), expected);
2661
2662        // with an odd number of the column "c2"
2663        // (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) ^ c2 --> ((c2 | c1) ^ (c1 & c2)) ^ c2
2664
2665        let expr = bitwise_xor(
2666            bitwise_xor(
2667                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2668                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2669            ),
2670            col("c2"),
2671        );
2672
2673        let expected = bitwise_xor(
2674            bitwise_xor(
2675                bitwise_or(col("c2"), col("c1")),
2676                bitwise_and(col("c1"), col("c2")),
2677            ),
2678            col("c2"),
2679        );
2680
2681        assert_eq!(simplify(expr), expected);
2682    }
2683
2684    #[test]
2685    fn test_simplify_negated_bitwise_and() {
2686        // !c4 & c4 --> 0
2687        let expr = (-col("c4_non_null")) & col("c4_non_null");
2688        let expected = lit(0u32);
2689
2690        assert_eq!(simplify(expr), expected);
2691        // c4 & !c4 --> 0
2692        let expr = col("c4_non_null") & (-col("c4_non_null"));
2693        let expected = lit(0u32);
2694
2695        assert_eq!(simplify(expr), expected);
2696
2697        // !c3 & c3 --> 0
2698        let expr = (-col("c3_non_null")) & col("c3_non_null");
2699        let expected = lit(0i64);
2700
2701        assert_eq!(simplify(expr), expected);
2702        // c3 & !c3 --> 0
2703        let expr = col("c3_non_null") & (-col("c3_non_null"));
2704        let expected = lit(0i64);
2705
2706        assert_eq!(simplify(expr), expected);
2707    }
2708
2709    #[test]
2710    fn test_simplify_negated_bitwise_or() {
2711        // !c4 | c4 --> -1
2712        let expr = (-col("c4_non_null")) | col("c4_non_null");
2713        let expected = lit(-1i32);
2714
2715        assert_eq!(simplify(expr), expected);
2716
2717        // c4 | !c4 --> -1
2718        let expr = col("c4_non_null") | (-col("c4_non_null"));
2719        let expected = lit(-1i32);
2720
2721        assert_eq!(simplify(expr), expected);
2722
2723        // !c3 | c3 --> -1
2724        let expr = (-col("c3_non_null")) | col("c3_non_null");
2725        let expected = lit(-1i64);
2726
2727        assert_eq!(simplify(expr), expected);
2728
2729        // c3 | !c3 --> -1
2730        let expr = col("c3_non_null") | (-col("c3_non_null"));
2731        let expected = lit(-1i64);
2732
2733        assert_eq!(simplify(expr), expected);
2734    }
2735
2736    #[test]
2737    fn test_simplify_negated_bitwise_xor() {
2738        // !c4 ^ c4 --> -1
2739        let expr = (-col("c4_non_null")) ^ col("c4_non_null");
2740        let expected = lit(-1i32);
2741
2742        assert_eq!(simplify(expr), expected);
2743
2744        // c4 ^ !c4 --> -1
2745        let expr = col("c4_non_null") ^ (-col("c4_non_null"));
2746        let expected = lit(-1i32);
2747
2748        assert_eq!(simplify(expr), expected);
2749
2750        // !c3 ^ c3 --> -1
2751        let expr = (-col("c3_non_null")) ^ col("c3_non_null");
2752        let expected = lit(-1i64);
2753
2754        assert_eq!(simplify(expr), expected);
2755
2756        // c3 ^ !c3 --> -1
2757        let expr = col("c3_non_null") ^ (-col("c3_non_null"));
2758        let expected = lit(-1i64);
2759
2760        assert_eq!(simplify(expr), expected);
2761    }
2762
2763    #[test]
2764    fn test_simplify_bitwise_and_or() {
2765        // (c2 < 3) & ((c2 < 3) | c1) -> (c2 < 3)
2766        let expr = bitwise_and(
2767            col("c2_non_null").lt(lit(3)),
2768            bitwise_or(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2769        );
2770        let expected = col("c2_non_null").lt(lit(3));
2771
2772        assert_eq!(simplify(expr), expected);
2773    }
2774
2775    #[test]
2776    fn test_simplify_bitwise_or_and() {
2777        // (c2 < 3) | ((c2 < 3) & c1) -> (c2 < 3)
2778        let expr = bitwise_or(
2779            col("c2_non_null").lt(lit(3)),
2780            bitwise_and(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2781        );
2782        let expected = col("c2_non_null").lt(lit(3));
2783
2784        assert_eq!(simplify(expr), expected);
2785    }
2786
2787    #[test]
2788    fn test_simplify_simple_bitwise_and() {
2789        // (c2 > 5) & (c2 > 5) -> (c2 > 5)
2790        let expr = (col("c2").gt(lit(5))).bitand(col("c2").gt(lit(5)));
2791        let expected = col("c2").gt(lit(5));
2792
2793        assert_eq!(simplify(expr), expected);
2794    }
2795
2796    #[test]
2797    fn test_simplify_simple_bitwise_or() {
2798        // (c2 > 5) | (c2 > 5) -> (c2 > 5)
2799        let expr = (col("c2").gt(lit(5))).bitor(col("c2").gt(lit(5)));
2800        let expected = col("c2").gt(lit(5));
2801
2802        assert_eq!(simplify(expr), expected);
2803    }
2804
2805    #[test]
2806    fn test_simplify_simple_bitwise_xor() {
2807        // c4 ^ c4 -> 0
2808        let expr = (col("c4")).bitxor(col("c4"));
2809        let expected = lit(0u32);
2810
2811        assert_eq!(simplify(expr), expected);
2812
2813        // c3 ^ c3 -> 0
2814        let expr = col("c3").bitxor(col("c3"));
2815        let expected = lit(0i64);
2816
2817        assert_eq!(simplify(expr), expected);
2818    }
2819
2820    #[test]
2821    fn test_simplify_modulo_by_zero_non_null() {
2822        // because modulo by 0 maybe occur in short-circuit expression
2823        // so we should not simplify this, and throw error in runtime.
2824        let expr = col("c2_non_null") % lit(0);
2825        let expected = expr.clone();
2826
2827        assert_eq!(simplify(expr), expected);
2828    }
2829
2830    #[test]
2831    fn test_simplify_simple_and() {
2832        // (c2 > 5) AND (c2 > 5) -> (c2 > 5)
2833        let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5)));
2834        let expected = col("c2").gt(lit(5));
2835
2836        assert_eq!(simplify(expr), expected);
2837    }
2838
2839    #[test]
2840    fn test_simplify_composed_and() {
2841        // ((c2 > 5) AND (c1 < 6)) AND (c2 > 5)
2842        let expr = and(
2843            and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2844            col("c2").gt(lit(5)),
2845        );
2846        let expected = and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2847
2848        assert_eq!(simplify(expr), expected);
2849    }
2850
2851    #[test]
2852    fn test_simplify_negated_and() {
2853        // (c2 > 5) AND !(c2 > 5) --> (c2 > 5) AND (c2 <= 5)
2854        let expr = and(col("c2").gt(lit(5)), Expr::not(col("c2").gt(lit(5))));
2855        let expected = col("c2").gt(lit(5)).and(col("c2").lt_eq(lit(5)));
2856
2857        assert_eq!(simplify(expr), expected);
2858    }
2859
2860    #[test]
2861    fn test_simplify_or_and() {
2862        let l = col("c2").gt(lit(5));
2863        let r = and(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
2864
2865        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5))
2866        let expr = or(l.clone(), r.clone());
2867
2868        let expected = l.clone();
2869        assert_eq!(simplify(expr), expected);
2870
2871        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5)
2872        let expr = or(r, l);
2873        assert_eq!(simplify(expr), expected);
2874    }
2875
2876    #[test]
2877    fn test_simplify_or_and_non_null() {
2878        let l = col("c2_non_null").gt(lit(5));
2879        let r = and(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
2880
2881        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) --> c2 > 5
2882        let expr = or(l.clone(), r.clone());
2883
2884        // This is only true if `c1 < 6` is not nullable / can not be null.
2885        let expected = col("c2_non_null").gt(lit(5));
2886
2887        assert_eq!(simplify(expr), expected);
2888
2889        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) --> c2 > 5
2890        let expr = or(l, r);
2891
2892        assert_eq!(simplify(expr), expected);
2893    }
2894
2895    #[test]
2896    fn test_simplify_and_or() {
2897        let l = col("c2").gt(lit(5));
2898        let r = or(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
2899
2900        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
2901        let expr = and(l.clone(), r.clone());
2902
2903        let expected = l.clone();
2904        assert_eq!(simplify(expr), expected);
2905
2906        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
2907        let expr = and(r, l);
2908        assert_eq!(simplify(expr), expected);
2909    }
2910
2911    #[test]
2912    fn test_simplify_and_or_non_null() {
2913        let l = col("c2_non_null").gt(lit(5));
2914        let r = or(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
2915
2916        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
2917        let expr = and(l.clone(), r.clone());
2918
2919        // This is only true if `c1 < 6` is not nullable / can not be null.
2920        let expected = col("c2_non_null").gt(lit(5));
2921
2922        assert_eq!(simplify(expr), expected);
2923
2924        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
2925        let expr = and(l, r);
2926
2927        assert_eq!(simplify(expr), expected);
2928    }
2929
2930    #[test]
2931    fn test_simplify_by_de_morgan_laws() {
2932        // Laws with logical operations
2933        // !(c3 AND c4) --> !c3 OR !c4
2934        let expr = and(col("c3"), col("c4")).not();
2935        let expected = or(col("c3").not(), col("c4").not());
2936        assert_eq!(simplify(expr), expected);
2937        // !(c3 OR c4) --> !c3 AND !c4
2938        let expr = or(col("c3"), col("c4")).not();
2939        let expected = and(col("c3").not(), col("c4").not());
2940        assert_eq!(simplify(expr), expected);
2941        // !(!c3) --> c3
2942        let expr = col("c3").not().not();
2943        let expected = col("c3");
2944        assert_eq!(simplify(expr), expected);
2945
2946        // Laws with bitwise operations
2947        // !(c3 & c4) --> !c3 | !c4
2948        let expr = -bitwise_and(col("c3"), col("c4"));
2949        let expected = bitwise_or(-col("c3"), -col("c4"));
2950        assert_eq!(simplify(expr), expected);
2951        // !(c3 | c4) --> !c3 & !c4
2952        let expr = -bitwise_or(col("c3"), col("c4"));
2953        let expected = bitwise_and(-col("c3"), -col("c4"));
2954        assert_eq!(simplify(expr), expected);
2955        // !(!c3) --> c3
2956        let expr = -(-col("c3"));
2957        let expected = col("c3");
2958        assert_eq!(simplify(expr), expected);
2959    }
2960
2961    #[test]
2962    fn test_simplify_null_and_false() {
2963        let expr = and(lit_bool_null(), lit(false));
2964        let expr_eq = lit(false);
2965
2966        assert_eq!(simplify(expr), expr_eq);
2967    }
2968
2969    #[test]
2970    fn test_simplify_divide_null_by_null() {
2971        let null = lit(ScalarValue::Int32(None));
2972        let expr_plus = null.clone() / null.clone();
2973        let expr_eq = null;
2974
2975        assert_eq!(simplify(expr_plus), expr_eq);
2976    }
2977
2978    #[test]
2979    fn test_simplify_simplify_arithmetic_expr() {
2980        let expr_plus = lit(1) + lit(1);
2981
2982        assert_eq!(simplify(expr_plus), lit(2));
2983    }
2984
2985    #[test]
2986    fn test_simplify_simplify_eq_expr() {
2987        let expr_eq = binary_expr(lit(1), Operator::Eq, lit(1));
2988
2989        assert_eq!(simplify(expr_eq), lit(true));
2990    }
2991
2992    #[test]
2993    fn test_simplify_regex() {
2994        // malformed regex
2995        assert_contains!(
2996            try_simplify(regex_match(col("c1"), lit("foo{")))
2997                .unwrap_err()
2998                .to_string(),
2999            "regex parse error"
3000        );
3001
3002        // unsupported cases
3003        assert_no_change(regex_match(col("c1"), lit("foo.*")));
3004        assert_no_change(regex_match(col("c1"), lit("(foo)")));
3005        assert_no_change(regex_match(col("c1"), lit("%")));
3006        assert_no_change(regex_match(col("c1"), lit("_")));
3007        assert_no_change(regex_match(col("c1"), lit("f%o")));
3008        assert_no_change(regex_match(col("c1"), lit("^f%o")));
3009        assert_no_change(regex_match(col("c1"), lit("f_o")));
3010
3011        // empty cases
3012        assert_change(
3013            regex_match(col("c1"), lit("")),
3014            if_not_null(col("c1"), true),
3015        );
3016        assert_change(
3017            regex_not_match(col("c1"), lit("")),
3018            if_not_null(col("c1"), false),
3019        );
3020        assert_change(
3021            regex_imatch(col("c1"), lit("")),
3022            if_not_null(col("c1"), true),
3023        );
3024        assert_change(
3025            regex_not_imatch(col("c1"), lit("")),
3026            if_not_null(col("c1"), false),
3027        );
3028
3029        // single character
3030        assert_change(regex_match(col("c1"), lit("x")), col("c1").like(lit("%x%")));
3031
3032        // single word
3033        assert_change(
3034            regex_match(col("c1"), lit("foo")),
3035            col("c1").like(lit("%foo%")),
3036        );
3037
3038        // regular expressions that match an exact literal
3039        assert_change(regex_match(col("c1"), lit("^$")), col("c1").eq(lit("")));
3040        assert_change(
3041            regex_not_match(col("c1"), lit("^$")),
3042            col("c1").not_eq(lit("")),
3043        );
3044        assert_change(
3045            regex_match(col("c1"), lit("^foo$")),
3046            col("c1").eq(lit("foo")),
3047        );
3048        assert_change(
3049            regex_not_match(col("c1"), lit("^foo$")),
3050            col("c1").not_eq(lit("foo")),
3051        );
3052
3053        // regular expressions that match exact captured literals
3054        assert_change(
3055            regex_match(col("c1"), lit("^(foo|bar)$")),
3056            col("c1").eq(lit("foo")).or(col("c1").eq(lit("bar"))),
3057        );
3058        assert_change(
3059            regex_not_match(col("c1"), lit("^(foo|bar)$")),
3060            col("c1")
3061                .not_eq(lit("foo"))
3062                .and(col("c1").not_eq(lit("bar"))),
3063        );
3064        assert_change(
3065            regex_match(col("c1"), lit("^(foo)$")),
3066            col("c1").eq(lit("foo")),
3067        );
3068        assert_change(
3069            regex_match(col("c1"), lit("^(foo|bar|baz)$")),
3070            ((col("c1").eq(lit("foo"))).or(col("c1").eq(lit("bar"))))
3071                .or(col("c1").eq(lit("baz"))),
3072        );
3073        assert_change(
3074            regex_match(col("c1"), lit("^(foo|bar|baz|qux)$")),
3075            col("c1")
3076                .in_list(vec![lit("foo"), lit("bar"), lit("baz"), lit("qux")], false),
3077        );
3078        assert_change(
3079            regex_match(col("c1"), lit("^(fo_o)$")),
3080            col("c1").eq(lit("fo_o")),
3081        );
3082        assert_change(
3083            regex_match(col("c1"), lit("^(fo_o)$")),
3084            col("c1").eq(lit("fo_o")),
3085        );
3086        assert_change(
3087            regex_match(col("c1"), lit("^(fo_o|ba_r)$")),
3088            col("c1").eq(lit("fo_o")).or(col("c1").eq(lit("ba_r"))),
3089        );
3090        assert_change(
3091            regex_not_match(col("c1"), lit("^(fo_o|ba_r)$")),
3092            col("c1")
3093                .not_eq(lit("fo_o"))
3094                .and(col("c1").not_eq(lit("ba_r"))),
3095        );
3096        assert_change(
3097            regex_match(col("c1"), lit("^(fo_o|ba_r|ba_z)$")),
3098            ((col("c1").eq(lit("fo_o"))).or(col("c1").eq(lit("ba_r"))))
3099                .or(col("c1").eq(lit("ba_z"))),
3100        );
3101        assert_change(
3102            regex_match(col("c1"), lit("^(fo_o|ba_r|baz|qu_x)$")),
3103            col("c1").in_list(
3104                vec![lit("fo_o"), lit("ba_r"), lit("baz"), lit("qu_x")],
3105                false,
3106            ),
3107        );
3108
3109        // regular expressions that mismatch captured literals
3110        assert_no_change(regex_match(col("c1"), lit("(foo|bar)")));
3111        assert_no_change(regex_match(col("c1"), lit("(foo|bar)*")));
3112        assert_no_change(regex_match(col("c1"), lit("(fo_o|b_ar)")));
3113        assert_no_change(regex_match(col("c1"), lit("(foo|ba_r)*")));
3114        assert_no_change(regex_match(col("c1"), lit("(fo_o|ba_r)*")));
3115        assert_no_change(regex_match(col("c1"), lit("^(foo|bar)*")));
3116        assert_no_change(regex_match(col("c1"), lit("^(foo)(bar)$")));
3117        assert_no_change(regex_match(col("c1"), lit("^")));
3118        assert_no_change(regex_match(col("c1"), lit("$")));
3119        assert_no_change(regex_match(col("c1"), lit("$^")));
3120        assert_no_change(regex_match(col("c1"), lit("$foo^")));
3121
3122        // regular expressions that match a partial literal
3123        assert_change(
3124            regex_match(col("c1"), lit("^foo")),
3125            col("c1").like(lit("foo%")),
3126        );
3127        assert_change(
3128            regex_match(col("c1"), lit("foo$")),
3129            col("c1").like(lit("%foo")),
3130        );
3131        assert_change(
3132            regex_match(col("c1"), lit("^foo|bar$")),
3133            col("c1").like(lit("foo%")).or(col("c1").like(lit("%bar"))),
3134        );
3135
3136        // OR-chain
3137        assert_change(
3138            regex_match(col("c1"), lit("foo|bar|baz")),
3139            col("c1")
3140                .like(lit("%foo%"))
3141                .or(col("c1").like(lit("%bar%")))
3142                .or(col("c1").like(lit("%baz%"))),
3143        );
3144        assert_change(
3145            regex_match(col("c1"), lit("foo|x|baz")),
3146            col("c1")
3147                .like(lit("%foo%"))
3148                .or(col("c1").like(lit("%x%")))
3149                .or(col("c1").like(lit("%baz%"))),
3150        );
3151        assert_change(
3152            regex_not_match(col("c1"), lit("foo|bar|baz")),
3153            col("c1")
3154                .not_like(lit("%foo%"))
3155                .and(col("c1").not_like(lit("%bar%")))
3156                .and(col("c1").not_like(lit("%baz%"))),
3157        );
3158        // both anchored expressions (translated to equality) and unanchored
3159        assert_change(
3160            regex_match(col("c1"), lit("foo|^x$|baz")),
3161            col("c1")
3162                .like(lit("%foo%"))
3163                .or(col("c1").eq(lit("x")))
3164                .or(col("c1").like(lit("%baz%"))),
3165        );
3166        assert_change(
3167            regex_not_match(col("c1"), lit("foo|^bar$|baz")),
3168            col("c1")
3169                .not_like(lit("%foo%"))
3170                .and(col("c1").not_eq(lit("bar")))
3171                .and(col("c1").not_like(lit("%baz%"))),
3172        );
3173        // Too many patterns (MAX_REGEX_ALTERNATIONS_EXPANSION)
3174        assert_no_change(regex_match(col("c1"), lit("foo|bar|baz|blarg|bozo|etc")));
3175    }
3176
3177    #[track_caller]
3178    fn assert_no_change(expr: Expr) {
3179        let optimized = simplify(expr.clone());
3180        assert_eq!(expr, optimized);
3181    }
3182
3183    #[track_caller]
3184    fn assert_change(expr: Expr, expected: Expr) {
3185        let optimized = simplify(expr);
3186        assert_eq!(optimized, expected);
3187    }
3188
3189    fn regex_match(left: Expr, right: Expr) -> Expr {
3190        Expr::BinaryExpr(BinaryExpr {
3191            left: Box::new(left),
3192            op: Operator::RegexMatch,
3193            right: Box::new(right),
3194        })
3195    }
3196
3197    fn regex_not_match(left: Expr, right: Expr) -> Expr {
3198        Expr::BinaryExpr(BinaryExpr {
3199            left: Box::new(left),
3200            op: Operator::RegexNotMatch,
3201            right: Box::new(right),
3202        })
3203    }
3204
3205    fn regex_imatch(left: Expr, right: Expr) -> Expr {
3206        Expr::BinaryExpr(BinaryExpr {
3207            left: Box::new(left),
3208            op: Operator::RegexIMatch,
3209            right: Box::new(right),
3210        })
3211    }
3212
3213    fn regex_not_imatch(left: Expr, right: Expr) -> Expr {
3214        Expr::BinaryExpr(BinaryExpr {
3215            left: Box::new(left),
3216            op: Operator::RegexNotIMatch,
3217            right: Box::new(right),
3218        })
3219    }
3220
3221    // ------------------------------
3222    // ----- Simplifier tests -------
3223    // ------------------------------
3224
3225    fn try_simplify(expr: Expr) -> Result<Expr> {
3226        let schema = expr_test_schema();
3227        let execution_props = ExecutionProps::new();
3228        let simplifier = ExprSimplifier::new(
3229            SimplifyContext::new(&execution_props).with_schema(schema),
3230        );
3231        simplifier.simplify(expr)
3232    }
3233
3234    fn coerce(expr: Expr) -> Expr {
3235        let schema = expr_test_schema();
3236        let execution_props = ExecutionProps::new();
3237        let simplifier = ExprSimplifier::new(
3238            SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)),
3239        );
3240        simplifier.coerce(expr, schema.as_ref()).unwrap()
3241    }
3242
3243    fn simplify(expr: Expr) -> Expr {
3244        try_simplify(expr).unwrap()
3245    }
3246
3247    fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
3248        let schema = expr_test_schema();
3249        let execution_props = ExecutionProps::new();
3250        let simplifier = ExprSimplifier::new(
3251            SimplifyContext::new(&execution_props).with_schema(schema),
3252        );
3253        let (expr, count) = simplifier.simplify_with_cycle_count_transformed(expr)?;
3254        Ok((expr.data, count))
3255    }
3256
3257    fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
3258        try_simplify_with_cycle_count(expr).unwrap()
3259    }
3260
3261    fn simplify_with_guarantee(
3262        expr: Expr,
3263        guarantees: Vec<(Expr, NullableInterval)>,
3264    ) -> Expr {
3265        let schema = expr_test_schema();
3266        let execution_props = ExecutionProps::new();
3267        let simplifier = ExprSimplifier::new(
3268            SimplifyContext::new(&execution_props).with_schema(schema),
3269        )
3270        .with_guarantees(guarantees);
3271        simplifier.simplify(expr).unwrap()
3272    }
3273
3274    fn expr_test_schema() -> DFSchemaRef {
3275        Arc::new(
3276            DFSchema::from_unqualified_fields(
3277                vec![
3278                    Field::new("c1", DataType::Utf8, true),
3279                    Field::new("c2", DataType::Boolean, true),
3280                    Field::new("c3", DataType::Int64, true),
3281                    Field::new("c4", DataType::UInt32, true),
3282                    Field::new("c1_non_null", DataType::Utf8, false),
3283                    Field::new("c2_non_null", DataType::Boolean, false),
3284                    Field::new("c3_non_null", DataType::Int64, false),
3285                    Field::new("c4_non_null", DataType::UInt32, false),
3286                    Field::new("c5", DataType::FixedSizeBinary(3), true),
3287                ]
3288                .into(),
3289                HashMap::new(),
3290            )
3291            .unwrap(),
3292        )
3293    }
3294
3295    #[test]
3296    fn simplify_expr_null_comparison() {
3297        // x = null is always null
3298        assert_eq!(
3299            simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))),
3300            lit(ScalarValue::Boolean(None)),
3301        );
3302
3303        // null != null is always null
3304        assert_eq!(
3305            simplify(
3306                lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))
3307            ),
3308            lit(ScalarValue::Boolean(None)),
3309        );
3310
3311        // x != null is always null
3312        assert_eq!(
3313            simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))),
3314            lit(ScalarValue::Boolean(None)),
3315        );
3316
3317        // null = x is always null
3318        assert_eq!(
3319            simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))),
3320            lit(ScalarValue::Boolean(None)),
3321        );
3322    }
3323
3324    #[test]
3325    fn simplify_expr_is_not_null() {
3326        assert_eq!(
3327            simplify(Expr::IsNotNull(Box::new(col("c1")))),
3328            Expr::IsNotNull(Box::new(col("c1")))
3329        );
3330
3331        // 'c1_non_null IS NOT NULL' is always true
3332        assert_eq!(
3333            simplify(Expr::IsNotNull(Box::new(col("c1_non_null")))),
3334            lit(true)
3335        );
3336    }
3337
3338    #[test]
3339    fn simplify_expr_is_null() {
3340        assert_eq!(
3341            simplify(Expr::IsNull(Box::new(col("c1")))),
3342            Expr::IsNull(Box::new(col("c1")))
3343        );
3344
3345        // 'c1_non_null IS NULL' is always false
3346        assert_eq!(
3347            simplify(Expr::IsNull(Box::new(col("c1_non_null")))),
3348            lit(false)
3349        );
3350    }
3351
3352    #[test]
3353    fn simplify_expr_is_unknown() {
3354        assert_eq!(simplify(col("c2").is_unknown()), col("c2").is_unknown(),);
3355
3356        // 'c2_non_null is unknown' is always false
3357        assert_eq!(simplify(col("c2_non_null").is_unknown()), lit(false));
3358    }
3359
3360    #[test]
3361    fn simplify_expr_is_not_known() {
3362        assert_eq!(
3363            simplify(col("c2").is_not_unknown()),
3364            col("c2").is_not_unknown()
3365        );
3366
3367        // 'c2_non_null is not unknown' is always true
3368        assert_eq!(simplify(col("c2_non_null").is_not_unknown()), lit(true));
3369    }
3370
3371    #[test]
3372    fn simplify_expr_eq() {
3373        let schema = expr_test_schema();
3374        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3375
3376        // true = true -> true
3377        assert_eq!(simplify(lit(true).eq(lit(true))), lit(true));
3378
3379        // true = false -> false
3380        assert_eq!(simplify(lit(true).eq(lit(false))), lit(false),);
3381
3382        // c2 = true -> c2
3383        assert_eq!(simplify(col("c2").eq(lit(true))), col("c2"));
3384
3385        // c2 = false => !c2
3386        assert_eq!(simplify(col("c2").eq(lit(false))), col("c2").not(),);
3387    }
3388
3389    #[test]
3390    fn simplify_expr_eq_skip_nonboolean_type() {
3391        let schema = expr_test_schema();
3392
3393        // When one of the operand is not of boolean type, folding the
3394        // other boolean constant will change return type of
3395        // expression to non-boolean.
3396        //
3397        // Make sure c1 column to be used in tests is not boolean type
3398        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3399
3400        // don't fold c1 = foo
3401        assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),);
3402    }
3403
3404    #[test]
3405    fn simplify_expr_not_eq() {
3406        let schema = expr_test_schema();
3407
3408        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3409
3410        // c2 != true -> !c2
3411        assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),);
3412
3413        // c2 != false -> c2
3414        assert_eq!(simplify(col("c2").not_eq(lit(false))), col("c2"),);
3415
3416        // test constant
3417        assert_eq!(simplify(lit(true).not_eq(lit(true))), lit(false),);
3418
3419        assert_eq!(simplify(lit(true).not_eq(lit(false))), lit(true),);
3420    }
3421
3422    #[test]
3423    fn simplify_expr_not_eq_skip_nonboolean_type() {
3424        let schema = expr_test_schema();
3425
3426        // when one of the operand is not of boolean type, folding the
3427        // other boolean constant will change return type of
3428        // expression to non-boolean.
3429        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3430
3431        assert_eq!(
3432            simplify(col("c1").not_eq(lit("foo"))),
3433            col("c1").not_eq(lit("foo")),
3434        );
3435    }
3436
3437    #[test]
3438    fn simplify_expr_case_when_then_else() {
3439        // CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true
3440        // -->
3441        // CASE WHEN c2 THEN false ELSE c2
3442        // -->
3443        // false
3444        assert_eq!(
3445            simplify(Expr::Case(Case::new(
3446                None,
3447                vec![(
3448                    Box::new(col("c2_non_null").not_eq(lit(false))),
3449                    Box::new(lit("ok").eq(lit("not_ok"))),
3450                )],
3451                Some(Box::new(col("c2_non_null").eq(lit(true)))),
3452            ))),
3453            lit(false) // #1716
3454        );
3455
3456        // CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2
3457        // -->
3458        // CASE WHEN c2 THEN true ELSE c2
3459        // -->
3460        // c2
3461        //
3462        // Need to call simplify 2x due to
3463        // https://github.com/apache/datafusion/issues/1160
3464        assert_eq!(
3465            simplify(simplify(Expr::Case(Case::new(
3466                None,
3467                vec![(
3468                    Box::new(col("c2_non_null").not_eq(lit(false))),
3469                    Box::new(lit("ok").eq(lit("ok"))),
3470                )],
3471                Some(Box::new(col("c2_non_null").eq(lit(true)))),
3472            )))),
3473            col("c2_non_null")
3474        );
3475
3476        // CASE WHEN ISNULL(c2) THEN true ELSE c2
3477        // -->
3478        // ISNULL(c2) OR c2
3479        //
3480        // Need to call simplify 2x due to
3481        // https://github.com/apache/datafusion/issues/1160
3482        assert_eq!(
3483            simplify(simplify(Expr::Case(Case::new(
3484                None,
3485                vec![(Box::new(col("c2").is_null()), Box::new(lit(true)),)],
3486                Some(Box::new(col("c2"))),
3487            )))),
3488            col("c2")
3489                .is_null()
3490                .or(col("c2").is_not_null().and(col("c2")))
3491        );
3492
3493        // CASE WHEN c1 then true WHEN c2 then false ELSE true
3494        // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE)
3495        // --> c1 OR (NOT(c1) AND NOT(c2))
3496        // --> c1 OR NOT(c2)
3497        //
3498        // Need to call simplify 2x due to
3499        // https://github.com/apache/datafusion/issues/1160
3500        assert_eq!(
3501            simplify(simplify(Expr::Case(Case::new(
3502                None,
3503                vec![
3504                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
3505                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
3506                ],
3507                Some(Box::new(lit(true))),
3508            )))),
3509            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
3510        );
3511
3512        // CASE WHEN c1 then true WHEN c2 then true ELSE false
3513        // --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE)
3514        // --> c1 OR (NOT(c1) AND c2)
3515        // --> c1 OR c2
3516        //
3517        // Need to call simplify 2x due to
3518        // https://github.com/apache/datafusion/issues/1160
3519        assert_eq!(
3520            simplify(simplify(Expr::Case(Case::new(
3521                None,
3522                vec![
3523                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
3524                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
3525                ],
3526                Some(Box::new(lit(true))),
3527            )))),
3528            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
3529        );
3530
3531        // CASE WHEN c > 0 THEN true END AS c1
3532        assert_eq!(
3533            simplify(simplify(Expr::Case(Case::new(
3534                None,
3535                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3536                None,
3537            )))),
3538            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
3539                col("c3").gt(lit(0_i64)),
3540                lit(true)
3541            )
3542            .and(lit_bool_null()))
3543        );
3544
3545        // CASE WHEN c > 0 THEN true ELSE false END AS c1
3546        assert_eq!(
3547            simplify(simplify(Expr::Case(Case::new(
3548                None,
3549                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3550                Some(Box::new(lit(false))),
3551            )))),
3552            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
3553        );
3554    }
3555
3556    fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3557        Expr::BinaryExpr(BinaryExpr {
3558            left: Box::new(left.into()),
3559            op: Operator::IsDistinctFrom,
3560            right: Box::new(right.into()),
3561        })
3562    }
3563
3564    fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3565        Expr::BinaryExpr(BinaryExpr {
3566            left: Box::new(left.into()),
3567            op: Operator::IsNotDistinctFrom,
3568            right: Box::new(right.into()),
3569        })
3570    }
3571
3572    #[test]
3573    fn simplify_expr_bool_or() {
3574        // col || true is always true
3575        assert_eq!(simplify(col("c2").or(lit(true))), lit(true),);
3576
3577        // col || false is always col
3578        assert_eq!(simplify(col("c2").or(lit(false))), col("c2"),);
3579
3580        // true || null is always true
3581        assert_eq!(simplify(lit(true).or(lit_bool_null())), lit(true),);
3582
3583        // null || true is always true
3584        assert_eq!(simplify(lit_bool_null().or(lit(true))), lit(true),);
3585
3586        // false || null is always null
3587        assert_eq!(simplify(lit(false).or(lit_bool_null())), lit_bool_null(),);
3588
3589        // null || false is always null
3590        assert_eq!(simplify(lit_bool_null().or(lit(false))), lit_bool_null(),);
3591
3592        // ( c1 BETWEEN Int32(0) AND Int32(10) ) OR Boolean(NULL)
3593        // it can be either NULL or  TRUE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
3594        // and should not be rewritten
3595        let expr = col("c1").between(lit(0), lit(10));
3596        let expr = expr.or(lit_bool_null());
3597        let result = simplify(expr);
3598
3599        let expected_expr = or(
3600            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
3601            lit_bool_null(),
3602        );
3603        assert_eq!(expected_expr, result);
3604    }
3605
3606    #[test]
3607    fn simplify_inlist() {
3608        assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false));
3609        assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true));
3610
3611        // null in (...)  --> null
3612        assert_eq!(
3613            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], false)),
3614            lit_bool_null()
3615        );
3616
3617        // null not in (...)  --> null
3618        assert_eq!(
3619            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], true)),
3620            lit_bool_null()
3621        );
3622
3623        assert_eq!(
3624            simplify(in_list(col("c1"), vec![lit(1)], false)),
3625            col("c1").eq(lit(1))
3626        );
3627        assert_eq!(
3628            simplify(in_list(col("c1"), vec![lit(1)], true)),
3629            col("c1").not_eq(lit(1))
3630        );
3631
3632        // more complex expressions can be simplified if list contains
3633        // one element only
3634        assert_eq!(
3635            simplify(in_list(col("c1") * lit(10), vec![lit(2)], false)),
3636            (col("c1") * lit(10)).eq(lit(2))
3637        );
3638
3639        assert_eq!(
3640            simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)),
3641            col("c1").eq(lit(1)).or(col("c1").eq(lit(2)))
3642        );
3643        assert_eq!(
3644            simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)),
3645            col("c1").not_eq(lit(1)).and(col("c1").not_eq(lit(2)))
3646        );
3647
3648        let subquery = Arc::new(test_table_scan_with_name("test").unwrap());
3649        assert_eq!(
3650            simplify(in_list(
3651                col("c1"),
3652                vec![scalar_subquery(Arc::clone(&subquery))],
3653                false
3654            )),
3655            in_subquery(col("c1"), Arc::clone(&subquery))
3656        );
3657        assert_eq!(
3658            simplify(in_list(
3659                col("c1"),
3660                vec![scalar_subquery(Arc::clone(&subquery))],
3661                true
3662            )),
3663            not_in_subquery(col("c1"), subquery)
3664        );
3665
3666        let subquery1 =
3667            scalar_subquery(Arc::new(test_table_scan_with_name("test1").unwrap()));
3668        let subquery2 =
3669            scalar_subquery(Arc::new(test_table_scan_with_name("test2").unwrap()));
3670
3671        // c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery1> AND c1 != <subquery2>
3672        assert_eq!(
3673            simplify(in_list(
3674                col("c1"),
3675                vec![subquery1.clone(), subquery2.clone()],
3676                true
3677            )),
3678            col("c1")
3679                .not_eq(subquery1.clone())
3680                .and(col("c1").not_eq(subquery2.clone()))
3681        );
3682
3683        // c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery1> OR c1 == <subquery2>
3684        assert_eq!(
3685            simplify(in_list(
3686                col("c1"),
3687                vec![subquery1.clone(), subquery2.clone()],
3688                false
3689            )),
3690            col("c1").eq(subquery1).or(col("c1").eq(subquery2))
3691        );
3692
3693        // 1. c1 IN (1,2,3,4) AND c1 IN (5,6,7,8) -> false
3694        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
3695            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false),
3696        );
3697        assert_eq!(simplify(expr), lit(false));
3698
3699        // 2. c1 IN (1,2,3,4) AND c1 IN (4,5,6,7) -> c1 = 4
3700        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
3701            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], false),
3702        );
3703        assert_eq!(simplify(expr), col("c1").eq(lit(4)));
3704
3705        // 3. c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) -> true
3706        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
3707            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
3708        );
3709        assert_eq!(simplify(expr), lit(true));
3710
3711        // 3.5 c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (4, 5, 6, 7) -> c1 != 4 (4 overlaps)
3712        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
3713            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
3714        );
3715        assert_eq!(simplify(expr), col("c1").not_eq(lit(4)));
3716
3717        // 4. c1 NOT IN (1,2,3,4) AND c1 NOT IN (4,5,6,7) -> c1 NOT IN (1,2,3,4,5,6,7)
3718        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
3719            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
3720        );
3721        assert_eq!(
3722            simplify(expr),
3723            in_list(
3724                col("c1"),
3725                vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6), lit(7)],
3726                true
3727            )
3728        );
3729
3730        // 5. c1 IN (1,2,3,4) OR c1 IN (2,3,4,5) -> c1 IN (1,2,3,4,5)
3731        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).or(
3732            in_list(col("c1"), vec![lit(2), lit(3), lit(4), lit(5)], false),
3733        );
3734        assert_eq!(
3735            simplify(expr),
3736            in_list(
3737                col("c1"),
3738                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
3739                false
3740            )
3741        );
3742
3743        // 6. c1 IN (1,2,3) AND c1 NOT INT (1,2,3,4,5) -> false
3744        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3)], false).and(in_list(
3745            col("c1"),
3746            vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
3747            true,
3748        ));
3749        assert_eq!(simplify(expr), lit(false));
3750
3751        // 7. c1 NOT IN (1,2,3,4) AND c1 IN (1,2,3,4,5) -> c1 = 5
3752        let expr =
3753            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(in_list(
3754                col("c1"),
3755                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
3756                false,
3757            ));
3758        assert_eq!(simplify(expr), col("c1").eq(lit(5)));
3759
3760        // 8. c1 IN (1,2,3,4) AND c1 NOT IN (5,6,7,8) -> c1 IN (1,2,3,4)
3761        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
3762            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
3763        );
3764        assert_eq!(
3765            simplify(expr),
3766            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false)
3767        );
3768
3769        // inlist with more than two expressions
3770        // c1 IN (1,2,3,4,5,6) AND c1 IN (1,3,5,6) AND c1 IN (3,6) -> c1 = 3 OR c1 = 6
3771        let expr = in_list(
3772            col("c1"),
3773            vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6)],
3774            false,
3775        )
3776        .and(in_list(
3777            col("c1"),
3778            vec![lit(1), lit(3), lit(5), lit(6)],
3779            false,
3780        ))
3781        .and(in_list(col("c1"), vec![lit(3), lit(6)], false));
3782        assert_eq!(
3783            simplify(expr),
3784            col("c1").eq(lit(3)).or(col("c1").eq(lit(6)))
3785        );
3786
3787        // c1 NOT IN (1,2,3,4) AND c1 IN (5,6,7,8) AND c1 NOT IN (3,4,5,6) AND c1 IN (8,9,10) -> c1 = 8
3788        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
3789            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false)
3790                .and(in_list(
3791                    col("c1"),
3792                    vec![lit(3), lit(4), lit(5), lit(6)],
3793                    true,
3794                ))
3795                .and(in_list(col("c1"), vec![lit(8), lit(9), lit(10)], false)),
3796        );
3797        assert_eq!(simplify(expr), col("c1").eq(lit(8)));
3798
3799        // Contains non-InList expression
3800        // c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9) -> c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9)
3801        let expr =
3802            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(col("c1")
3803                .not_eq(lit(5))
3804                .or(in_list(
3805                    col("c1"),
3806                    vec![lit(6), lit(7), lit(8), lit(9)],
3807                    true,
3808                )));
3809        // TODO: Further simplify this expression
3810        // https://github.com/apache/datafusion/issues/8970
3811        // assert_eq!(simplify(expr.clone()), lit(true));
3812        assert_eq!(simplify(expr.clone()), expr);
3813    }
3814
3815    #[test]
3816    fn simplify_null_in_empty_inlist() {
3817        // `NULL::boolean IN ()` == `NULL::boolean IN (SELECT foo FROM empty)` == false
3818        let expr = in_list(lit_bool_null(), vec![], false);
3819        assert_eq!(simplify(expr), lit(false));
3820
3821        // `NULL::boolean NOT IN ()` == `NULL::boolean NOT IN (SELECT foo FROM empty)` == true
3822        let expr = in_list(lit_bool_null(), vec![], true);
3823        assert_eq!(simplify(expr), lit(true));
3824
3825        // `NULL IN ()` == `NULL IN (SELECT foo FROM empty)` == false
3826        let null_null = || Expr::Literal(ScalarValue::Null, None);
3827        let expr = in_list(null_null(), vec![], false);
3828        assert_eq!(simplify(expr), lit(false));
3829
3830        // `NULL NOT IN ()` == `NULL NOT IN (SELECT foo FROM empty)` == true
3831        let expr = in_list(null_null(), vec![], true);
3832        assert_eq!(simplify(expr), lit(true));
3833    }
3834
3835    #[test]
3836    fn just_simplifier_simplify_null_in_empty_inlist() {
3837        let simplify = |expr: Expr| -> Expr {
3838            let schema = expr_test_schema();
3839            let execution_props = ExecutionProps::new();
3840            let info = SimplifyContext::new(&execution_props).with_schema(schema);
3841            let simplifier = &mut Simplifier::new(&info);
3842            expr.rewrite(simplifier)
3843                .expect("Failed to simplify expression")
3844                .data
3845        };
3846
3847        // `NULL::boolean IN ()` == `NULL::boolean IN (SELECT foo FROM empty)` == false
3848        let expr = in_list(lit_bool_null(), vec![], false);
3849        assert_eq!(simplify(expr), lit(false));
3850
3851        // `NULL::boolean NOT IN ()` == `NULL::boolean NOT IN (SELECT foo FROM empty)` == true
3852        let expr = in_list(lit_bool_null(), vec![], true);
3853        assert_eq!(simplify(expr), lit(true));
3854
3855        // `NULL IN ()` == `NULL IN (SELECT foo FROM empty)` == false
3856        let null_null = || Expr::Literal(ScalarValue::Null, None);
3857        let expr = in_list(null_null(), vec![], false);
3858        assert_eq!(simplify(expr), lit(false));
3859
3860        // `NULL NOT IN ()` == `NULL NOT IN (SELECT foo FROM empty)` == true
3861        let expr = in_list(null_null(), vec![], true);
3862        assert_eq!(simplify(expr), lit(true));
3863    }
3864
3865    #[test]
3866    fn simplify_large_or() {
3867        let expr = (0..5)
3868            .map(|i| col("c1").eq(lit(i)))
3869            .fold(lit(false), |acc, e| acc.or(e));
3870        assert_eq!(
3871            simplify(expr),
3872            in_list(col("c1"), (0..5).map(lit).collect(), false),
3873        );
3874    }
3875
3876    #[test]
3877    fn simplify_expr_bool_and() {
3878        // col & true is always col
3879        assert_eq!(simplify(col("c2").and(lit(true))), col("c2"),);
3880        // col & false is always false
3881        assert_eq!(simplify(col("c2").and(lit(false))), lit(false),);
3882
3883        // true && null is always null
3884        assert_eq!(simplify(lit(true).and(lit_bool_null())), lit_bool_null(),);
3885
3886        // null && true is always null
3887        assert_eq!(simplify(lit_bool_null().and(lit(true))), lit_bool_null(),);
3888
3889        // false && null is always false
3890        assert_eq!(simplify(lit(false).and(lit_bool_null())), lit(false),);
3891
3892        // null && false is always false
3893        assert_eq!(simplify(lit_bool_null().and(lit(false))), lit(false),);
3894
3895        // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
3896        // it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
3897        // and the Boolean(NULL) should remain
3898        let expr = col("c1").between(lit(0), lit(10));
3899        let expr = expr.and(lit_bool_null());
3900        let result = simplify(expr);
3901
3902        let expected_expr = and(
3903            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
3904            lit_bool_null(),
3905        );
3906        assert_eq!(expected_expr, result);
3907    }
3908
3909    #[test]
3910    fn simplify_expr_between() {
3911        // c2 between 3 and 4 is c2 >= 3 and c2 <= 4
3912        let expr = col("c2").between(lit(3), lit(4));
3913        assert_eq!(
3914            simplify(expr),
3915            and(col("c2").gt_eq(lit(3)), col("c2").lt_eq(lit(4)))
3916        );
3917
3918        // c2 not between 3 and 4 is c2 < 3 or c2 > 4
3919        let expr = col("c2").not_between(lit(3), lit(4));
3920        assert_eq!(
3921            simplify(expr),
3922            or(col("c2").lt(lit(3)), col("c2").gt(lit(4)))
3923        );
3924    }
3925
3926    #[test]
3927    fn test_like_and_ilike() {
3928        let null = lit(ScalarValue::Utf8(None));
3929
3930        // expr [NOT] [I]LIKE NULL
3931        let expr = col("c1").like(null.clone());
3932        assert_eq!(simplify(expr), lit_bool_null());
3933
3934        let expr = col("c1").not_like(null.clone());
3935        assert_eq!(simplify(expr), lit_bool_null());
3936
3937        let expr = col("c1").ilike(null.clone());
3938        assert_eq!(simplify(expr), lit_bool_null());
3939
3940        let expr = col("c1").not_ilike(null.clone());
3941        assert_eq!(simplify(expr), lit_bool_null());
3942
3943        // expr [NOT] [I]LIKE '%'
3944        let expr = col("c1").like(lit("%"));
3945        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
3946
3947        let expr = col("c1").not_like(lit("%"));
3948        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
3949
3950        let expr = col("c1").ilike(lit("%"));
3951        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
3952
3953        let expr = col("c1").not_ilike(lit("%"));
3954        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
3955
3956        // expr [NOT] [I]LIKE '%%'
3957        let expr = col("c1").like(lit("%%"));
3958        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
3959
3960        let expr = col("c1").not_like(lit("%%"));
3961        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
3962
3963        let expr = col("c1").ilike(lit("%%"));
3964        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
3965
3966        let expr = col("c1").not_ilike(lit("%%"));
3967        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
3968
3969        // not_null_expr [NOT] [I]LIKE '%'
3970        let expr = col("c1_non_null").like(lit("%"));
3971        assert_eq!(simplify(expr), lit(true));
3972
3973        let expr = col("c1_non_null").not_like(lit("%"));
3974        assert_eq!(simplify(expr), lit(false));
3975
3976        let expr = col("c1_non_null").ilike(lit("%"));
3977        assert_eq!(simplify(expr), lit(true));
3978
3979        let expr = col("c1_non_null").not_ilike(lit("%"));
3980        assert_eq!(simplify(expr), lit(false));
3981
3982        // not_null_expr [NOT] [I]LIKE '%%'
3983        let expr = col("c1_non_null").like(lit("%%"));
3984        assert_eq!(simplify(expr), lit(true));
3985
3986        let expr = col("c1_non_null").not_like(lit("%%"));
3987        assert_eq!(simplify(expr), lit(false));
3988
3989        let expr = col("c1_non_null").ilike(lit("%%"));
3990        assert_eq!(simplify(expr), lit(true));
3991
3992        let expr = col("c1_non_null").not_ilike(lit("%%"));
3993        assert_eq!(simplify(expr), lit(false));
3994
3995        // null_constant [NOT] [I]LIKE '%'
3996        let expr = null.clone().like(lit("%"));
3997        assert_eq!(simplify(expr), lit_bool_null());
3998
3999        let expr = null.clone().not_like(lit("%"));
4000        assert_eq!(simplify(expr), lit_bool_null());
4001
4002        let expr = null.clone().ilike(lit("%"));
4003        assert_eq!(simplify(expr), lit_bool_null());
4004
4005        let expr = null.clone().not_ilike(lit("%"));
4006        assert_eq!(simplify(expr), lit_bool_null());
4007
4008        // null_constant [NOT] [I]LIKE '%%'
4009        let expr = null.clone().like(lit("%%"));
4010        assert_eq!(simplify(expr), lit_bool_null());
4011
4012        let expr = null.clone().not_like(lit("%%"));
4013        assert_eq!(simplify(expr), lit_bool_null());
4014
4015        let expr = null.clone().ilike(lit("%%"));
4016        assert_eq!(simplify(expr), lit_bool_null());
4017
4018        let expr = null.clone().not_ilike(lit("%%"));
4019        assert_eq!(simplify(expr), lit_bool_null());
4020
4021        // null_constant [NOT] [I]LIKE 'a%'
4022        let expr = null.clone().like(lit("a%"));
4023        assert_eq!(simplify(expr), lit_bool_null());
4024
4025        let expr = null.clone().not_like(lit("a%"));
4026        assert_eq!(simplify(expr), lit_bool_null());
4027
4028        let expr = null.clone().ilike(lit("a%"));
4029        assert_eq!(simplify(expr), lit_bool_null());
4030
4031        let expr = null.clone().not_ilike(lit("a%"));
4032        assert_eq!(simplify(expr), lit_bool_null());
4033
4034        // expr [NOT] [I]LIKE with pattern without wildcards
4035        let expr = col("c1").like(lit("a"));
4036        assert_eq!(simplify(expr), col("c1").eq(lit("a")));
4037        let expr = col("c1").not_like(lit("a"));
4038        assert_eq!(simplify(expr), col("c1").not_eq(lit("a")));
4039        let expr = col("c1").like(lit("a_"));
4040        assert_eq!(simplify(expr), col("c1").like(lit("a_")));
4041        let expr = col("c1").not_like(lit("a_"));
4042        assert_eq!(simplify(expr), col("c1").not_like(lit("a_")));
4043
4044        let expr = col("c1").ilike(lit("a"));
4045        assert_eq!(simplify(expr), col("c1").ilike(lit("a")));
4046        let expr = col("c1").not_ilike(lit("a"));
4047        assert_eq!(simplify(expr), col("c1").not_ilike(lit("a")));
4048    }
4049
4050    #[test]
4051    fn test_simplify_with_guarantee() {
4052        // (c3 >= 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b")))
4053        let expr_x = col("c3").gt(lit(3_i64));
4054        let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32));
4055        let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true);
4056        let expr = expr_x.clone().and(expr_y.or(expr_z));
4057
4058        // All guaranteed null
4059        let guarantees = vec![
4060            (col("c3"), NullableInterval::from(ScalarValue::Int64(None))),
4061            (col("c4"), NullableInterval::from(ScalarValue::UInt32(None))),
4062            (col("c1"), NullableInterval::from(ScalarValue::Utf8(None))),
4063        ];
4064
4065        let output = simplify_with_guarantee(expr.clone(), guarantees);
4066        assert_eq!(output, lit_bool_null());
4067
4068        // All guaranteed false
4069        let guarantees = vec![
4070            (
4071                col("c3"),
4072                NullableInterval::NotNull {
4073                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
4074                },
4075            ),
4076            (
4077                col("c4"),
4078                NullableInterval::from(ScalarValue::UInt32(Some(9))),
4079            ),
4080            (col("c1"), NullableInterval::from(ScalarValue::from("a"))),
4081        ];
4082        let output = simplify_with_guarantee(expr.clone(), guarantees);
4083        assert_eq!(output, lit(false));
4084
4085        // Guaranteed false or null -> no change.
4086        let guarantees = vec![
4087            (
4088                col("c3"),
4089                NullableInterval::MaybeNull {
4090                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
4091                },
4092            ),
4093            (
4094                col("c4"),
4095                NullableInterval::MaybeNull {
4096                    values: Interval::make(Some(9_u32), Some(9_u32)).unwrap(),
4097                },
4098            ),
4099            (
4100                col("c1"),
4101                NullableInterval::NotNull {
4102                    values: Interval::try_new(
4103                        ScalarValue::from("d"),
4104                        ScalarValue::from("f"),
4105                    )
4106                    .unwrap(),
4107                },
4108            ),
4109        ];
4110        let output = simplify_with_guarantee(expr.clone(), guarantees);
4111        assert_eq!(&output, &expr_x);
4112
4113        // Sufficient true guarantees
4114        let guarantees = vec![
4115            (
4116                col("c3"),
4117                NullableInterval::from(ScalarValue::Int64(Some(9))),
4118            ),
4119            (
4120                col("c4"),
4121                NullableInterval::from(ScalarValue::UInt32(Some(3))),
4122            ),
4123        ];
4124        let output = simplify_with_guarantee(expr.clone(), guarantees);
4125        assert_eq!(output, lit(true));
4126
4127        // Only partially simplify
4128        let guarantees = vec![(
4129            col("c4"),
4130            NullableInterval::from(ScalarValue::UInt32(Some(3))),
4131        )];
4132        let output = simplify_with_guarantee(expr, guarantees);
4133        assert_eq!(&output, &expr_x);
4134    }
4135
4136    #[test]
4137    fn test_expression_partial_simplify_1() {
4138        // (1 + 2) + (4 / 0) -> 3 + (4 / 0)
4139        let expr = (lit(1) + lit(2)) + (lit(4) / lit(0));
4140        let expected = (lit(3)) + (lit(4) / lit(0));
4141
4142        assert_eq!(simplify(expr), expected);
4143    }
4144
4145    #[test]
4146    fn test_expression_partial_simplify_2() {
4147        // (1 > 2) and (4 / 0) -> false
4148        let expr = (lit(1).gt(lit(2))).and(lit(4) / lit(0));
4149        let expected = lit(false);
4150
4151        assert_eq!(simplify(expr), expected);
4152    }
4153
4154    #[test]
4155    fn test_simplify_cycles() {
4156        // TRUE
4157        let expr = lit(true);
4158        let expected = lit(true);
4159        let (expr, num_iter) = simplify_with_cycle_count(expr);
4160        assert_eq!(expr, expected);
4161        assert_eq!(num_iter, 1);
4162
4163        // (true != NULL) OR (5 > 10)
4164        let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
4165        let expected = lit_bool_null();
4166        let (expr, num_iter) = simplify_with_cycle_count(expr);
4167        assert_eq!(expr, expected);
4168        assert_eq!(num_iter, 2);
4169
4170        // NOTE: this currently does not simplify
4171        // (((c4 - 10) + 10) *100) / 100
4172        let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
4173        let expected = expr.clone();
4174        let (expr, num_iter) = simplify_with_cycle_count(expr);
4175        assert_eq!(expr, expected);
4176        assert_eq!(num_iter, 1);
4177
4178        // ((c4<1 or c3<2) and c3_non_null<3) and false
4179        let expr = col("c4")
4180            .lt(lit(1))
4181            .or(col("c3").lt(lit(2)))
4182            .and(col("c3_non_null").lt(lit(3)))
4183            .and(lit(false));
4184        let expected = lit(false);
4185        let (expr, num_iter) = simplify_with_cycle_count(expr);
4186        assert_eq!(expr, expected);
4187        assert_eq!(num_iter, 2);
4188    }
4189
4190    fn boolean_test_schema() -> DFSchemaRef {
4191        Schema::new(vec![
4192            Field::new("A", DataType::Boolean, false),
4193            Field::new("B", DataType::Boolean, false),
4194            Field::new("C", DataType::Boolean, false),
4195            Field::new("D", DataType::Boolean, false),
4196        ])
4197        .to_dfschema_ref()
4198        .unwrap()
4199    }
4200
4201    #[test]
4202    fn simplify_common_factor_conjunction_in_disjunction() {
4203        let props = ExecutionProps::new();
4204        let schema = boolean_test_schema();
4205        let simplifier =
4206            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));
4207
4208        let a = || col("A");
4209        let b = || col("B");
4210        let c = || col("C");
4211        let d = || col("D");
4212
4213        // (A AND B) OR (A AND C) -> A AND (B OR C)
4214        let expr = a().and(b()).or(a().and(c()));
4215        let expected = a().and(b().or(c()));
4216
4217        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4218
4219        // (A AND B) OR (A AND C) OR (A AND D) -> A AND (B OR C OR D)
4220        let expr = a().and(b()).or(a().and(c())).or(a().and(d()));
4221        let expected = a().and(b().or(c()).or(d()));
4222        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4223
4224        // A OR (B AND C AND A) -> A
4225        let expr = a().or(b().and(c().and(a())));
4226        let expected = a();
4227        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4228    }
4229
4230    #[test]
4231    fn test_simplify_udaf() {
4232        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());
4233        let aggregate_function_expr =
4234            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4235                udaf.into(),
4236                vec![],
4237                false,
4238                None,
4239                vec![],
4240                None,
4241            ));
4242
4243        let expected = col("result_column");
4244        assert_eq!(simplify(aggregate_function_expr), expected);
4245
4246        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify());
4247        let aggregate_function_expr =
4248            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4249                udaf.into(),
4250                vec![],
4251                false,
4252                None,
4253                vec![],
4254                None,
4255            ));
4256
4257        let expected = aggregate_function_expr.clone();
4258        assert_eq!(simplify(aggregate_function_expr), expected);
4259    }
4260
4261    /// A Mock UDAF which defines `simplify` to be used in tests
4262    /// related to UDAF simplification
4263    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
4264    struct SimplifyMockUdaf {
4265        simplify: bool,
4266    }
4267
4268    impl SimplifyMockUdaf {
4269        /// make simplify method return new expression
4270        fn new_with_simplify() -> Self {
4271            Self { simplify: true }
4272        }
4273        /// make simplify method return no change
4274        fn new_without_simplify() -> Self {
4275            Self { simplify: false }
4276        }
4277    }
4278
4279    impl AggregateUDFImpl for SimplifyMockUdaf {
4280        fn as_any(&self) -> &dyn std::any::Any {
4281            self
4282        }
4283
4284        fn name(&self) -> &str {
4285            "mock_simplify"
4286        }
4287
4288        fn signature(&self) -> &Signature {
4289            unimplemented!()
4290        }
4291
4292        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4293            unimplemented!("not needed for tests")
4294        }
4295
4296        fn accumulator(
4297            &self,
4298            _acc_args: AccumulatorArgs,
4299        ) -> Result<Box<dyn Accumulator>> {
4300            unimplemented!("not needed for tests")
4301        }
4302
4303        fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
4304            unimplemented!("not needed for testing")
4305        }
4306
4307        fn create_groups_accumulator(
4308            &self,
4309            _args: AccumulatorArgs,
4310        ) -> Result<Box<dyn GroupsAccumulator>> {
4311            unimplemented!("not needed for testing")
4312        }
4313
4314        fn simplify(&self) -> Option<AggregateFunctionSimplification> {
4315            if self.simplify {
4316                Some(Box::new(|_, _| Ok(col("result_column"))))
4317            } else {
4318                None
4319            }
4320        }
4321    }
4322
4323    #[test]
4324    fn test_simplify_udwf() {
4325        let udwf = WindowFunctionDefinition::WindowUDF(
4326            WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(),
4327        );
4328        let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![]));
4329
4330        let expected = col("result_column");
4331        assert_eq!(simplify(window_function_expr), expected);
4332
4333        let udwf = WindowFunctionDefinition::WindowUDF(
4334            WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(),
4335        );
4336        let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![]));
4337
4338        let expected = window_function_expr.clone();
4339        assert_eq!(simplify(window_function_expr), expected);
4340    }
4341
4342    /// A Mock UDWF which defines `simplify` to be used in tests
4343    /// related to UDWF simplification
4344    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
4345    struct SimplifyMockUdwf {
4346        simplify: bool,
4347    }
4348
4349    impl SimplifyMockUdwf {
4350        /// make simplify method return new expression
4351        fn new_with_simplify() -> Self {
4352            Self { simplify: true }
4353        }
4354        /// make simplify method return no change
4355        fn new_without_simplify() -> Self {
4356            Self { simplify: false }
4357        }
4358    }
4359
4360    impl WindowUDFImpl for SimplifyMockUdwf {
4361        fn as_any(&self) -> &dyn std::any::Any {
4362            self
4363        }
4364
4365        fn name(&self) -> &str {
4366            "mock_simplify"
4367        }
4368
4369        fn signature(&self) -> &Signature {
4370            unimplemented!()
4371        }
4372
4373        fn simplify(&self) -> Option<WindowFunctionSimplification> {
4374            if self.simplify {
4375                Some(Box::new(|_, _| Ok(col("result_column"))))
4376            } else {
4377                None
4378            }
4379        }
4380
4381        fn partition_evaluator(
4382            &self,
4383            _partition_evaluator_args: PartitionEvaluatorArgs,
4384        ) -> Result<Box<dyn PartitionEvaluator>> {
4385            unimplemented!("not needed for tests")
4386        }
4387
4388        fn field(&self, _field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
4389            unimplemented!("not needed for tests")
4390        }
4391
4392        fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
4393            LimitEffect::Unknown
4394        }
4395    }
4396    #[derive(Debug, PartialEq, Eq, Hash)]
4397    struct VolatileUdf {
4398        signature: Signature,
4399    }
4400
4401    impl VolatileUdf {
4402        pub fn new() -> Self {
4403            Self {
4404                signature: Signature::exact(vec![], Volatility::Volatile),
4405            }
4406        }
4407    }
4408    impl ScalarUDFImpl for VolatileUdf {
4409        fn as_any(&self) -> &dyn std::any::Any {
4410            self
4411        }
4412
4413        fn name(&self) -> &str {
4414            "VolatileUdf"
4415        }
4416
4417        fn signature(&self) -> &Signature {
4418            &self.signature
4419        }
4420
4421        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4422            Ok(DataType::Int16)
4423        }
4424
4425        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
4426            panic!("dummy - not implemented")
4427        }
4428    }
4429
4430    #[test]
4431    fn test_optimize_volatile_conditions() {
4432        let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new()));
4433        let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
4434        {
4435            let expr = rand
4436                .clone()
4437                .eq(lit(0))
4438                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4439
4440            assert_eq!(simplify(expr.clone()), expr);
4441        }
4442
4443        {
4444            let expr = col("column1")
4445                .eq(lit(2))
4446                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4447
4448            assert_eq!(simplify(expr), col("column1").eq(lit(2)));
4449        }
4450
4451        {
4452            let expr = (col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))).or(col(
4453                "column1",
4454            )
4455            .eq(lit(2))
4456            .and(rand.clone().eq(lit(0))));
4457
4458            assert_eq!(
4459                simplify(expr),
4460                col("column1")
4461                    .eq(lit(2))
4462                    .and((rand.clone().eq(lit(0))).or(rand.clone().eq(lit(0))))
4463            );
4464        }
4465    }
4466
4467    #[test]
4468    fn simplify_fixed_size_binary_eq_lit() {
4469        let bytes = [1u8, 2, 3].as_slice();
4470
4471        // The expression starts simple.
4472        let expr = col("c5").eq(lit(bytes));
4473
4474        // The type coercer introduces a cast.
4475        let coerced = coerce(expr.clone());
4476        let schema = expr_test_schema();
4477        assert_eq!(
4478            coerced,
4479            col("c5")
4480                .cast_to(&DataType::Binary, schema.as_ref())
4481                .unwrap()
4482                .eq(lit(bytes))
4483        );
4484
4485        // The simplifier removes the cast.
4486        assert_eq!(
4487            simplify(coerced),
4488            col("c5").eq(Expr::Literal(
4489                ScalarValue::FixedSizeBinary(3, Some(bytes.to_vec()),),
4490                None
4491            ))
4492        );
4493    }
4494
4495    fn if_not_null(expr: Expr, then: bool) -> Expr {
4496        Expr::Case(Case {
4497            expr: Some(expr.is_not_null().into()),
4498            when_then_expr: vec![(lit(true).into(), lit(then).into())],
4499            else_expr: None,
4500        })
4501    }
4502}