datafusion_optimizer/simplify_expressions/
expr_simplifier.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression simplification API
19
20use arrow::{
21    array::{new_null_array, AsArray},
22    datatypes::{DataType, Field, Schema},
23    record_batch::RecordBatch,
24};
25use std::borrow::Cow;
26use std::collections::HashSet;
27use std::ops::Not;
28use std::sync::Arc;
29
30use datafusion_common::{
31    cast::{as_large_list_array, as_list_array},
32    metadata::FieldMetadata,
33    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
34};
35use datafusion_common::{
36    exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue,
37};
38use datafusion_expr::{
39    and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like,
40    Operator, Volatility,
41};
42use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
43use datafusion_expr::{
44    expr::{InList, InSubquery},
45    utils::{iter_conjunction, iter_conjunction_owned},
46};
47use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast};
48use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
49
50use super::inlist_simplifier::ShortenInListSimplifier;
51use super::utils::*;
52use crate::analyzer::type_coercion::TypeCoercionRewriter;
53use crate::simplify_expressions::guarantees::GuaranteeRewriter;
54use crate::simplify_expressions::regex::simplify_regex_expr;
55use crate::simplify_expressions::unwrap_cast::{
56    is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
57    is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
58    unwrap_cast_in_comparison_for_binary,
59};
60use crate::simplify_expressions::SimplifyInfo;
61use datafusion_expr_common::casts::try_cast_literal_to_type;
62use indexmap::IndexSet;
63use regex::Regex;
64
65/// This structure handles API for expression simplification
66///
67/// Provides simplification information based on DFSchema and
68/// [`ExecutionProps`]. This is the default implementation used by DataFusion
69///
70/// For example:
71/// ```
72/// use arrow::datatypes::{DataType, Field, Schema};
73/// use datafusion_common::{DataFusionError, ToDFSchema};
74/// use datafusion_expr::execution_props::ExecutionProps;
75/// use datafusion_expr::simplify::SimplifyContext;
76/// use datafusion_expr::{col, lit};
77/// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
78///
79/// // Create the schema
80/// let schema = Schema::new(vec![Field::new("i", DataType::Int64, false)])
81///     .to_dfschema_ref()
82///     .unwrap();
83///
84/// // Create the simplifier
85/// let props = ExecutionProps::new();
86/// let context = SimplifyContext::new(&props).with_schema(schema);
87/// let simplifier = ExprSimplifier::new(context);
88///
89/// // Use the simplifier
90///
91/// // b < 2 or (1 > 3)
92/// let expr = col("b").lt(lit(2)).or(lit(1).gt(lit(3)));
93///
94/// // b < 2
95/// let simplified = simplifier.simplify(expr).unwrap();
96/// assert_eq!(simplified, col("b").lt(lit(2)));
97/// ```
pub struct ExprSimplifier<S> {
    /// Provider of expression type/nullability information and of the
    /// [`ExecutionProps`] used for constant evaluation. Supplied by the
    /// caller in [ExprSimplifier::new()].
    info: S,
    /// Guarantees about the values of columns. This is provided by the user
    /// in [ExprSimplifier::with_guarantees()].
    guarantees: Vec<(Expr, NullableInterval)>,
    /// Should expressions be canonicalized before simplification? Defaults to
    /// true
    canonicalize: bool,
    /// Maximum number of simplifier cycles. Defaults to
    /// [`DEFAULT_MAX_SIMPLIFIER_CYCLES`] and can be changed with
    /// [ExprSimplifier::with_max_cycles()].
    max_simplifier_cycles: u32,
}
109
/// Size threshold for inlining `IN` lists.
///
/// NOTE(review): the code that consumes this constant is not in this part of
/// the file — confirm the exact comparison (`<` vs `<=`) against the in-list
/// simplification rules before relying on this description.
pub const THRESHOLD_INLINE_INLIST: usize = 3;
/// Default upper bound on the number of simplification cycles, used by
/// [`ExprSimplifier::new`] unless overridden with
/// [`ExprSimplifier::with_max_cycles`].
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3;
112
113impl<S: SimplifyInfo> ExprSimplifier<S> {
114    /// Create a new `ExprSimplifier` with the given `info` such as an
115    /// instance of [`SimplifyContext`]. See
116    /// [`simplify`](Self::simplify) for an example.
117    ///
118    /// [`SimplifyContext`]: datafusion_expr::simplify::SimplifyContext
119    pub fn new(info: S) -> Self {
120        Self {
121            info,
122            guarantees: vec![],
123            canonicalize: true,
124            max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES,
125        }
126    }
127
128    /// Simplifies this [`Expr`] as much as possible, evaluating
129    /// constants and applying algebraic simplifications.
130    ///
131    /// The types of the expression must match what operators expect,
132    /// or else an error may occur trying to evaluate. See
133    /// [`coerce`](Self::coerce) for a function to help.
134    ///
135    /// # Example:
136    ///
    /// `(b > 2) OR (b > 2)`
138    ///
139    /// can be written to
140    ///
141    /// `b > 2`
142    ///
143    /// ```
144    /// use arrow::datatypes::DataType;
145    /// use datafusion_common::DFSchema;
146    /// use datafusion_common::Result;
147    /// use datafusion_expr::execution_props::ExecutionProps;
148    /// use datafusion_expr::simplify::SimplifyContext;
149    /// use datafusion_expr::simplify::SimplifyInfo;
150    /// use datafusion_expr::{col, lit, Expr};
151    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
152    /// use std::sync::Arc;
153    ///
154    /// /// Simple implementation that provides `Simplifier` the information it needs
155    /// /// See SimplifyContext for a structure that does this.
156    /// #[derive(Default)]
157    /// struct Info {
158    ///     execution_props: ExecutionProps,
159    /// };
160    ///
161    /// impl SimplifyInfo for Info {
162    ///     fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
163    ///         Ok(false)
164    ///     }
165    ///     fn nullable(&self, expr: &Expr) -> Result<bool> {
166    ///         Ok(true)
167    ///     }
168    ///     fn execution_props(&self) -> &ExecutionProps {
169    ///         &self.execution_props
170    ///     }
171    ///     fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
172    ///         Ok(DataType::Int32)
173    ///     }
174    /// }
175    ///
176    /// // Create the simplifier
177    /// let simplifier = ExprSimplifier::new(Info::default());
178    ///
    /// // b > 2
    /// let b_gt_2 = col("b").gt(lit(2));
    ///
    /// // (b > 2) OR (b > 2)
    /// let expr = b_gt_2.clone().or(b_gt_2.clone());
    ///
    /// // (b > 2) OR (b > 2) --> (b > 2)
    /// let expr = simplifier.simplify(expr).unwrap();
    /// assert_eq!(expr, b_gt_2);
188    /// ```
189    pub fn simplify(&self, expr: Expr) -> Result<Expr> {
190        Ok(self.simplify_with_cycle_count_transformed(expr)?.0.data)
191    }
192
193    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
194    /// constants and applying algebraic simplifications. Additionally returns a `u32`
195    /// representing the number of simplification cycles performed, which can be useful for testing
196    /// optimizations.
197    ///
198    /// See [Self::simplify] for details and usage examples.
199    #[deprecated(
200        since = "48.0.0",
201        note = "Use `simplify_with_cycle_count_transformed` instead"
202    )]
203    #[allow(unused_mut)]
204    pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
205        let (transformed, cycle_count) =
206            self.simplify_with_cycle_count_transformed(expr)?;
207        Ok((transformed.data, cycle_count))
208    }
209
210    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
211    /// constants and applying algebraic simplifications. Additionally returns a `u32`
212    /// representing the number of simplification cycles performed, which can be useful for testing
213    /// optimizations.
214    ///
215    /// # Returns
216    ///
217    /// A tuple containing:
218    /// - The simplified expression wrapped in a `Transformed<Expr>` indicating if changes were made
219    /// - The number of simplification cycles that were performed
220    ///
221    /// See [Self::simplify] for details and usage examples.
222    pub fn simplify_with_cycle_count_transformed(
223        &self,
224        mut expr: Expr,
225    ) -> Result<(Transformed<Expr>, u32)> {
226        let mut simplifier = Simplifier::new(&self.info);
227        let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
228        let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
229        let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);
230
231        if self.canonicalize {
232            expr = expr.rewrite(&mut Canonicalizer::new()).data()?
233        }
234
235        // Evaluating constants can enable new simplifications and
236        // simplifications can enable new constant evaluation
237        // see `Self::with_max_cycles`
238        let mut num_cycles = 0;
239        let mut has_transformed = false;
240        loop {
241            let Transformed {
242                data, transformed, ..
243            } = expr
244                .rewrite(&mut const_evaluator)?
245                .transform_data(|expr| expr.rewrite(&mut simplifier))?
246                .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
247            expr = data;
248            num_cycles += 1;
249            // Track if any transformation occurred
250            has_transformed = has_transformed || transformed;
251            if !transformed || num_cycles >= self.max_simplifier_cycles {
252                break;
253            }
254        }
255        // shorten inlist should be started after other inlist rules are applied
256        expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
257        Ok((
258            Transformed::new_transformed(expr, has_transformed),
259            num_cycles,
260        ))
261    }
262
263    /// Apply type coercion to an [`Expr`] so that it can be
264    /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr).
265    ///
266    /// See the [type coercion module](datafusion_expr::type_coercion)
267    /// documentation for more details on type coercion
268    pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> {
269        let mut expr_rewrite = TypeCoercionRewriter { schema };
270        expr.rewrite(&mut expr_rewrite).data()
271    }
272
273    /// Input guarantees about the values of columns.
274    ///
275    /// The guarantees can simplify expressions. For example, if a column `x` is
276    /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the
277    /// literal `true`.
278    ///
279    /// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`,
280    /// where the [Expr] is a column reference and the [NullableInterval]
281    /// is an interval representing the known possible values of that column.
282    ///
283    /// ```rust
284    /// use arrow::datatypes::{DataType, Field, Schema};
285    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
286    /// use datafusion_expr::execution_props::ExecutionProps;
287    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
288    /// use datafusion_expr::simplify::SimplifyContext;
289    /// use datafusion_expr::{col, lit, Expr};
290    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
291    ///
292    /// let schema = Schema::new(vec![
293    ///     Field::new("x", DataType::Int64, false),
294    ///     Field::new("y", DataType::UInt32, false),
295    ///     Field::new("z", DataType::Int64, false),
296    /// ])
297    /// .to_dfschema_ref()
298    /// .unwrap();
299    ///
300    /// // Create the simplifier
301    /// let props = ExecutionProps::new();
302    /// let context = SimplifyContext::new(&props).with_schema(schema);
303    ///
304    /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5)
305    /// let expr_x = col("x").gt_eq(lit(3_i64));
306    /// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32));
307    /// let expr_z = col("z").gt(lit(5_i64));
308    /// let expr = expr_x.and(expr_y).and(expr_z.clone());
309    ///
310    /// let guarantees = vec![
311    ///     // x ∈ [3, 5]
312    ///     (
313    ///         col("x"),
314    ///         NullableInterval::NotNull {
315    ///             values: Interval::make(Some(3_i64), Some(5_i64)).unwrap(),
316    ///         },
317    ///     ),
318    ///     // y = 3
319    ///     (
320    ///         col("y"),
321    ///         NullableInterval::from(ScalarValue::UInt32(Some(3))),
322    ///     ),
323    /// ];
324    /// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees);
325    /// let output = simplifier.simplify(expr).unwrap();
326    /// // Expression becomes: true AND true AND (z > 5), which simplifies to
327    /// // z > 5.
328    /// assert_eq!(output, expr_z);
329    /// ```
330    pub fn with_guarantees(mut self, guarantees: Vec<(Expr, NullableInterval)>) -> Self {
331        self.guarantees = guarantees;
332        self
333    }
334
335    /// Should `Canonicalizer` be applied before simplification?
336    ///
337    /// If true (the default), the expression will be rewritten to canonical
338    /// form before simplification. This is useful to ensure that the simplifier
339    /// can apply all possible simplifications.
340    ///
341    /// Some expressions, such as those in some Joins, can not be canonicalized
342    /// without changing their meaning. In these cases, canonicalization should
343    /// be disabled.
344    ///
345    /// ```rust
346    /// use arrow::datatypes::{DataType, Field, Schema};
347    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
348    /// use datafusion_expr::execution_props::ExecutionProps;
349    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
350    /// use datafusion_expr::simplify::SimplifyContext;
351    /// use datafusion_expr::{col, lit, Expr};
352    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
353    ///
354    /// let schema = Schema::new(vec![
355    ///     Field::new("a", DataType::Int64, false),
356    ///     Field::new("b", DataType::Int64, false),
357    ///     Field::new("c", DataType::Int64, false),
358    /// ])
359    /// .to_dfschema_ref()
360    /// .unwrap();
361    ///
362    /// // Create the simplifier
363    /// let props = ExecutionProps::new();
364    /// let context = SimplifyContext::new(&props).with_schema(schema);
365    /// let simplifier = ExprSimplifier::new(context);
366    ///
367    /// // Expression: a = c AND 1 = b
368    /// let expr = col("a").eq(col("c")).and(lit(1).eq(col("b")));
369    ///
370    /// // With canonicalization, the expression is rewritten to canonical form
371    /// // (though it is no simpler in this case):
372    /// let canonical = simplifier.simplify(expr.clone()).unwrap();
373    /// // Expression has been rewritten to: (c = a AND b = 1)
374    /// assert_eq!(canonical, col("c").eq(col("a")).and(col("b").eq(lit(1))));
375    ///
376    /// // If canonicalization is disabled, the expression is not changed
377    /// let non_canonicalized = simplifier
378    ///     .with_canonicalize(false)
379    ///     .simplify(expr.clone())
380    ///     .unwrap();
381    ///
382    /// assert_eq!(non_canonicalized, expr);
383    /// ```
384    pub fn with_canonicalize(mut self, canonicalize: bool) -> Self {
385        self.canonicalize = canonicalize;
386        self
387    }
388
389    /// Specifies the maximum number of simplification cycles to run.
390    ///
391    /// The simplifier can perform multiple passes of simplification. This is
392    /// because the output of one simplification step can allow more optimizations
393    /// in another simplification step. For example, constant evaluation can allow more
394    /// expression simplifications, and expression simplifications can allow more constant
395    /// evaluations.
396    ///
397    /// This method specifies the maximum number of allowed iteration cycles before the simplifier
398    /// returns an [Expr] output. However, it does not always perform the maximum number of cycles.
399    /// The simplifier will attempt to detect when an [Expr] is unchanged by all the simplification
400    /// passes, and return early. This avoids wasting time on unnecessary [Expr] tree traversals.
401    ///
402    /// If no maximum is specified, the value of [DEFAULT_MAX_SIMPLIFIER_CYCLES] is used
403    /// instead.
404    ///
405    /// ```rust
406    /// use arrow::datatypes::{DataType, Field, Schema};
407    /// use datafusion_expr::{col, lit, Expr};
408    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
409    /// use datafusion_expr::execution_props::ExecutionProps;
410    /// use datafusion_expr::simplify::SimplifyContext;
411    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
412    ///
413    /// let schema = Schema::new(vec![
414    ///   Field::new("a", DataType::Int64, false),
415    ///   ])
416    ///   .to_dfschema_ref().unwrap();
417    ///
418    /// // Create the simplifier
419    /// let props = ExecutionProps::new();
420    /// let context = SimplifyContext::new(&props)
421    ///    .with_schema(schema);
422    /// let simplifier = ExprSimplifier::new(context);
423    ///
424    /// // Expression: a IS NOT NULL
425    /// let expr = col("a").is_not_null();
426    ///
427    /// // When using default maximum cycles, 2 cycles will be performed.
428    /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count_transformed(expr.clone()).unwrap();
429    /// assert_eq!(simplified_expr.data, lit(true));
430    /// // 2 cycles were executed, but only 1 was needed
431    /// assert_eq!(count, 2);
432    ///
433    /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
434    /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count_transformed(expr.clone()).unwrap();
    /// // Expression has been simplified to: true
436    /// assert_eq!(simplified_expr.data, lit(true));
437    /// // Only 1 cycle was executed
438    /// assert_eq!(count, 1);
439    /// ```
440    pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
441        self.max_simplifier_cycles = max_simplifier_cycles;
442        self
443    }
444}
445
/// Canonicalize any BinaryExprs that are not in canonical form
///
/// `<literal> <op> <col>` is rewritten to `<col> <op> <literal>`
///
/// `<col1> <op> <col2>` is rewritten so that the name of `col1` sorts higher
/// than `col2` (`a > b` would be canonicalized to `b < a`)
///
/// Both rewrites are applied only when the operator has a swapped
/// counterpart (`Operator::swap` returns `Some`); every other expression is
/// returned unchanged.
struct Canonicalizer {}
453
454impl Canonicalizer {
455    fn new() -> Self {
456        Self {}
457    }
458}
459
460impl TreeNodeRewriter for Canonicalizer {
461    type Node = Expr;
462
463    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
464        let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
465            return Ok(Transformed::no(expr));
466        };
467        match (left.as_ref(), right.as_ref(), op.swap()) {
468            // <col1> <op> <col2>
469            (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op))
470                if right_col > left_col =>
471            {
472                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
473                    left: right,
474                    op: swapped_op,
475                    right: left,
476                })))
477            }
478            // <literal> <op> <col>
479            (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => {
480                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
481                    left: right,
482                    op: swapped_op,
483                    right: left,
484                })))
485            }
486            _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr {
487                left,
488                op,
489                right,
490            }))),
491        }
492    }
493}
494
#[allow(rustdoc::private_intra_doc_links)]
/// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time.
///
/// Note it does not handle algebraic rewrites such as `(a or false)`
/// --> `a`, which is handled by [`Simplifier`]
struct ConstEvaluator<'a> {
    /// `can_evaluate` is used during the depth-first-search of the
    /// `Expr` tree to track if any siblings (or their descendants) were
    /// non evaluatable (e.g. had a column reference or volatile
    /// function)
    ///
    /// Specifically, `can_evaluate[N]` represents the state of
    /// traversal when we are N levels deep in the tree, one entry for
    /// this Expr and each of its parents.
    ///
    /// After visiting all siblings if `can_evaluate.top()` is true, that
    /// means there were no non evaluatable siblings (or their
    /// descendants) so this `Expr` can be evaluated
    can_evaluate: Vec<bool>,

    /// Execution properties passed to `create_physical_expr` when an
    /// expression is planned for evaluation.
    execution_props: &'a ExecutionProps,
    /// Schema with a single dummy `Null` column, used to plan the expressions
    /// being evaluated.
    input_schema: DFSchema,
    /// One-row batch over the dummy column; evaluating against it produces a
    /// single output value.
    input_batch: RecordBatch,
}
519
#[allow(dead_code)]
/// The simplify result of ConstEvaluator
enum ConstSimplifyResult {
    /// Expr was simplified and contains the new value (plus any field
    /// metadata reported for the evaluated expression)
    Simplified(ScalarValue, Option<FieldMetadata>),
    /// Expr was not simplified and original value is returned (the input was
    /// already a literal)
    NotSimplified(ScalarValue, Option<FieldMetadata>),
    /// Planning or evaluation encountered an error; contains the error and
    /// the original expression
    SimplifyRuntimeError(DataFusionError, Expr),
}
530
531impl TreeNodeRewriter for ConstEvaluator<'_> {
532    type Node = Expr;
533
534    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
535        // Default to being able to evaluate this node
536        self.can_evaluate.push(true);
537
538        // if this expr is not ok to evaluate, mark entire parent
539        // stack as not ok (as all parents have at least one child or
540        // descendant that can not be evaluated
541
542        if !Self::can_evaluate(&expr) {
543            // walk back up stack, marking first parent that is not mutable
544            let parent_iter = self.can_evaluate.iter_mut().rev();
545            for p in parent_iter {
546                if !*p {
547                    // optimization: if we find an element on the
548                    // stack already marked, know all elements above are also marked
549                    break;
550                }
551                *p = false;
552            }
553        }
554
555        // NB: do not short circuit recursion even if we find a non
556        // evaluatable node (so we can fold other children, args to
557        // functions, etc.)
558        Ok(Transformed::no(expr))
559    }
560
561    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
562        match self.can_evaluate.pop() {
563            // Certain expressions such as `CASE` and `COALESCE` are short-circuiting
564            // and may not evaluate all their sub expressions. Thus, if
565            // any error is countered during simplification, return the original
566            // so that normal evaluation can occur
567            Some(true) => match self.evaluate_to_scalar(expr) {
568                ConstSimplifyResult::Simplified(s, m) => {
569                    Ok(Transformed::yes(Expr::Literal(s, m)))
570                }
571                ConstSimplifyResult::NotSimplified(s, m) => {
572                    Ok(Transformed::no(Expr::Literal(s, m)))
573                }
574                ConstSimplifyResult::SimplifyRuntimeError(err, expr) => {
575                    // For CAST expressions with literal inputs, propagate the error at plan time rather than deferring to execution time.
576                    // This provides clearer error messages and fails fast.
577                    if let Expr::Cast(Cast { ref expr, .. })
578                    | Expr::TryCast(TryCast { ref expr, .. }) = expr
579                    {
580                        if matches!(expr.as_ref(), Expr::Literal(_, _)) {
581                            return Err(err);
582                        }
583                    }
584                    // For other expressions (like CASE, COALESCE), preserve the original
585                    // to allow short-circuit evaluation at execution time
586                    Ok(Transformed::yes(expr))
587                }
588            },
589            Some(false) => Ok(Transformed::no(expr)),
590            _ => internal_err!("Failed to pop can_evaluate"),
591        }
592    }
593}
594
595impl<'a> ConstEvaluator<'a> {
596    /// Create a new `ConstantEvaluator`. Session constants (such as
597    /// the time for `now()` are taken from the passed
598    /// `execution_props`.
599    pub fn try_new(execution_props: &'a ExecutionProps) -> Result<Self> {
600        // The dummy column name is unused and doesn't matter as only
601        // expressions without column references can be evaluated
602        static DUMMY_COL_NAME: &str = ".";
603        let schema = Arc::new(Schema::new(vec![Field::new(
604            DUMMY_COL_NAME,
605            DataType::Null,
606            true,
607        )]));
608        let input_schema = DFSchema::try_from(Arc::clone(&schema))?;
609        // Need a single "input" row to produce a single output row
610        let col = new_null_array(&DataType::Null, 1);
611        let input_batch = RecordBatch::try_new(schema, vec![col])?;
612
613        Ok(Self {
614            can_evaluate: vec![],
615            execution_props,
616            input_schema,
617            input_batch,
618        })
619    }
620
621    /// Can a function of the specified volatility be evaluated?
622    fn volatility_ok(volatility: Volatility) -> bool {
623        match volatility {
624            Volatility::Immutable => true,
625            // Values for functions such as now() are taken from ExecutionProps
626            Volatility::Stable => true,
627            Volatility::Volatile => false,
628        }
629    }
630
631    /// Can the expression be evaluated at plan time, (assuming all of
632    /// its children can also be evaluated)?
633    fn can_evaluate(expr: &Expr) -> bool {
634        // check for reasons we can't evaluate this node
635        //
636        // NOTE all expr types are listed here so when new ones are
637        // added they can be checked for their ability to be evaluated
638        // at plan time
639        match expr {
640            // TODO: remove the next line after `Expr::Wildcard` is removed
641            #[expect(deprecated)]
642            Expr::AggregateFunction { .. }
643            | Expr::ScalarVariable(_, _)
644            | Expr::Column(_)
645            | Expr::OuterReferenceColumn(_, _)
646            | Expr::Exists { .. }
647            | Expr::InSubquery(_)
648            | Expr::ScalarSubquery(_)
649            | Expr::WindowFunction { .. }
650            | Expr::GroupingSet(_)
651            | Expr::Wildcard { .. }
652            | Expr::Placeholder(_) => false,
653            Expr::ScalarFunction(ScalarFunction { func, .. }) => {
654                Self::volatility_ok(func.signature().volatility)
655            }
656            Expr::Literal(_, _)
657            | Expr::Alias(..)
658            | Expr::Unnest(_)
659            | Expr::BinaryExpr { .. }
660            | Expr::Not(_)
661            | Expr::IsNotNull(_)
662            | Expr::IsNull(_)
663            | Expr::IsTrue(_)
664            | Expr::IsFalse(_)
665            | Expr::IsUnknown(_)
666            | Expr::IsNotTrue(_)
667            | Expr::IsNotFalse(_)
668            | Expr::IsNotUnknown(_)
669            | Expr::Negative(_)
670            | Expr::Between { .. }
671            | Expr::Like { .. }
672            | Expr::SimilarTo { .. }
673            | Expr::Case(_)
674            | Expr::Cast { .. }
675            | Expr::TryCast { .. }
676            | Expr::InList { .. } => true,
677        }
678    }
679
680    /// Internal helper to evaluates an Expr
681    pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult {
682        if let Expr::Literal(s, m) = expr {
683            return ConstSimplifyResult::NotSimplified(s, m);
684        }
685
686        let phys_expr =
687            match create_physical_expr(&expr, &self.input_schema, self.execution_props) {
688                Ok(e) => e,
689                Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
690            };
691        let metadata = phys_expr
692            .return_field(self.input_batch.schema_ref())
693            .ok()
694            .and_then(|f| {
695                let m = f.metadata();
696                match m.is_empty() {
697                    true => None,
698                    false => Some(FieldMetadata::from(m)),
699                }
700            });
701        let col_val = match phys_expr.evaluate(&self.input_batch) {
702            Ok(v) => v,
703            Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
704        };
705        match col_val {
706            ColumnarValue::Array(a) => {
707                if a.len() != 1 {
708                    ConstSimplifyResult::SimplifyRuntimeError(
709                        exec_datafusion_err!("Could not evaluate the expression, found a result of length {}", a.len()),
710                        expr,
711                    )
712                } else if as_list_array(&a).is_ok() {
713                    ConstSimplifyResult::Simplified(
714                        ScalarValue::List(a.as_list::<i32>().to_owned().into()),
715                        metadata,
716                    )
717                } else if as_large_list_array(&a).is_ok() {
718                    ConstSimplifyResult::Simplified(
719                        ScalarValue::LargeList(a.as_list::<i64>().to_owned().into()),
720                        metadata,
721                    )
722                } else {
723                    // Non-ListArray
724                    match ScalarValue::try_from_array(&a, 0) {
725                        Ok(s) => ConstSimplifyResult::Simplified(s, metadata),
726                        Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr),
727                    }
728                }
729            }
730            ColumnarValue::Scalar(s) => ConstSimplifyResult::Simplified(s, metadata),
731        }
732    }
733}
734
/// Simplifies [`Expr`]s by applying algebraic transformation rules
///
/// Example transformations that are applied:
/// * `expr = true` and `expr != false` to `expr` when `expr` is of boolean type
/// * `expr = false` and `expr != true` to `!expr` when `expr` is of boolean type
/// * `true = true` and `false = false` to `true`
/// * `false = true` and `true = false` to `false`
/// * `!!expr` to `expr`
/// * `expr = null` and `expr != null` to `null`
struct Simplifier<'a, S> {
    /// Borrowed provider of type information consulted by the rewrite rules
    /// (for example `get_data_type` and `is_boolean_type`).
    info: &'a S,
}
747
748impl<'a, S> Simplifier<'a, S> {
749    pub fn new(info: &'a S) -> Self {
750        Self { info }
751    }
752}
753
754impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
755    type Node = Expr;
756
757    /// rewrite the expression simplifying any constant expressions
758    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
759        use datafusion_expr::Operator::{
760            And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor,
761            Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch,
762            RegexNotIMatch, RegexNotMatch,
763        };
764
765        let info = self.info;
766        Ok(match expr {
767            // `value op NULL` -> `NULL`
768            // `NULL op value` -> `NULL`
            // except for a few operators that can return a non-NULL value even when one of the operands is NULL
770            ref expr @ Expr::BinaryExpr(BinaryExpr {
771                ref left,
772                ref op,
773                ref right,
774            }) if op.returns_null_on_null()
775                && (is_null(left.as_ref()) || is_null(right.as_ref())) =>
776            {
777                Transformed::yes(Expr::Literal(
778                    ScalarValue::try_new_null(&info.get_data_type(expr)?)?,
779                    None,
780                ))
781            }
782
783            // `NULL {AND, OR} NULL` -> `NULL`
784            Expr::BinaryExpr(BinaryExpr {
785                left,
786                op: And | Or,
787                right,
788            }) if is_null(&left) && is_null(&right) => Transformed::yes(lit_bool_null()),
789
790            //
791            // Rules for Eq
792            //
793
794            // true = A  --> A
795            // false = A --> !A
796            // null = A --> null
797            Expr::BinaryExpr(BinaryExpr {
798                left,
799                op: Eq,
800                right,
801            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
802                Transformed::yes(match as_bool_lit(&left)? {
803                    Some(true) => *right,
804                    Some(false) => Expr::Not(right),
805                    None => lit_bool_null(),
806                })
807            }
808            // A = true  --> A
809            // A = false --> !A
810            // A = null --> null
811            Expr::BinaryExpr(BinaryExpr {
812                left,
813                op: Eq,
814                right,
815            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
816                Transformed::yes(match as_bool_lit(&right)? {
817                    Some(true) => *left,
818                    Some(false) => Expr::Not(left),
819                    None => lit_bool_null(),
820                })
821            }
822            // According to SQL's null semantics, NULL = NULL evaluates to NULL
823            // Both sides are the same expression (A = A) and A is non-volatile expression
824            // A = A --> A IS NOT NULL OR NULL
825            // A = A --> true (if A not nullable)
826            Expr::BinaryExpr(BinaryExpr {
827                left,
828                op: Eq,
829                right,
830            }) if (left == right) & !left.is_volatile() => {
831                Transformed::yes(match !info.nullable(&left)? {
832                    true => lit(true),
833                    false => Expr::BinaryExpr(BinaryExpr {
834                        left: Box::new(Expr::IsNotNull(left)),
835                        op: Or,
836                        right: Box::new(lit_bool_null()),
837                    }),
838                })
839            }
840
841            // Rules for NotEq
842            //
843
844            // true != A  --> !A
845            // false != A --> A
846            // null != A --> null
847            Expr::BinaryExpr(BinaryExpr {
848                left,
849                op: NotEq,
850                right,
851            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
852                Transformed::yes(match as_bool_lit(&left)? {
853                    Some(true) => Expr::Not(right),
854                    Some(false) => *right,
855                    None => lit_bool_null(),
856                })
857            }
858            // A != true  --> !A
859            // A != false --> A
860            // A != null --> null,
861            Expr::BinaryExpr(BinaryExpr {
862                left,
863                op: NotEq,
864                right,
865            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
866                Transformed::yes(match as_bool_lit(&right)? {
867                    Some(true) => Expr::Not(left),
868                    Some(false) => *left,
869                    None => lit_bool_null(),
870                })
871            }
872
873            //
874            // Rules for OR
875            //
876
877            // true OR A --> true (even if A is null)
878            Expr::BinaryExpr(BinaryExpr {
879                left,
880                op: Or,
881                right: _,
882            }) if is_true(&left) => Transformed::yes(*left),
883            // false OR A --> A
884            Expr::BinaryExpr(BinaryExpr {
885                left,
886                op: Or,
887                right,
888            }) if is_false(&left) => Transformed::yes(*right),
889            // A OR true --> true (even if A is null)
890            Expr::BinaryExpr(BinaryExpr {
891                left: _,
892                op: Or,
893                right,
894            }) if is_true(&right) => Transformed::yes(*right),
895            // A OR false --> A
896            Expr::BinaryExpr(BinaryExpr {
897                left,
898                op: Or,
899                right,
900            }) if is_false(&right) => Transformed::yes(*left),
901            // A OR !A ---> true (if A not nullable)
902            Expr::BinaryExpr(BinaryExpr {
903                left,
904                op: Or,
905                right,
906            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
907                Transformed::yes(lit(true))
908            }
909            // !A OR A ---> true (if A not nullable)
910            Expr::BinaryExpr(BinaryExpr {
911                left,
912                op: Or,
913                right,
914            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
915                Transformed::yes(lit(true))
916            }
917            // (..A..) OR A --> (..A..)
918            Expr::BinaryExpr(BinaryExpr {
919                left,
920                op: Or,
921                right,
922            }) if expr_contains(&left, &right, Or) => Transformed::yes(*left),
923            // A OR (..A..) --> (..A..)
924            Expr::BinaryExpr(BinaryExpr {
925                left,
926                op: Or,
927                right,
928            }) if expr_contains(&right, &left, Or) => Transformed::yes(*right),
929            // A OR (A AND B) --> A
930            Expr::BinaryExpr(BinaryExpr {
931                left,
932                op: Or,
933                right,
934            }) if is_op_with(And, &right, &left) => Transformed::yes(*left),
935            // (A AND B) OR A --> A
936            Expr::BinaryExpr(BinaryExpr {
937                left,
938                op: Or,
939                right,
940            }) if is_op_with(And, &left, &right) => Transformed::yes(*right),
941            // Eliminate common factors in conjunctions e.g
942            // (A AND B) OR (A AND C) -> A AND (B OR C)
943            Expr::BinaryExpr(BinaryExpr {
944                left,
945                op: Or,
946                right,
947            }) if has_common_conjunction(&left, &right) => {
948                let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect();
949                let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right)
950                    .partition(|e| lhs.contains(e) && !e.is_volatile());
951
952                let new_rhs = rhs.into_iter().reduce(and);
953                let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and);
954                let common_conjunction = common.into_iter().reduce(and).unwrap();
955
956                let new_expr = match (new_lhs, new_rhs) {
957                    (Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)),
958                    (_, _) => common_conjunction,
959                };
960                Transformed::yes(new_expr)
961            }
962
963            //
964            // Rules for AND
965            //
966
967            // true AND A --> A
968            Expr::BinaryExpr(BinaryExpr {
969                left,
970                op: And,
971                right,
972            }) if is_true(&left) => Transformed::yes(*right),
973            // false AND A --> false (even if A is null)
974            Expr::BinaryExpr(BinaryExpr {
975                left,
976                op: And,
977                right: _,
978            }) if is_false(&left) => Transformed::yes(*left),
979            // A AND true --> A
980            Expr::BinaryExpr(BinaryExpr {
981                left,
982                op: And,
983                right,
984            }) if is_true(&right) => Transformed::yes(*left),
985            // A AND false --> false (even if A is null)
986            Expr::BinaryExpr(BinaryExpr {
987                left: _,
988                op: And,
989                right,
990            }) if is_false(&right) => Transformed::yes(*right),
991            // A AND !A ---> false (if A not nullable)
992            Expr::BinaryExpr(BinaryExpr {
993                left,
994                op: And,
995                right,
996            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
997                Transformed::yes(lit(false))
998            }
999            // !A AND A ---> false (if A not nullable)
1000            Expr::BinaryExpr(BinaryExpr {
1001                left,
1002                op: And,
1003                right,
1004            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
1005                Transformed::yes(lit(false))
1006            }
1007            // (..A..) AND A --> (..A..)
1008            Expr::BinaryExpr(BinaryExpr {
1009                left,
1010                op: And,
1011                right,
1012            }) if expr_contains(&left, &right, And) => Transformed::yes(*left),
1013            // A AND (..A..) --> (..A..)
1014            Expr::BinaryExpr(BinaryExpr {
1015                left,
1016                op: And,
1017                right,
1018            }) if expr_contains(&right, &left, And) => Transformed::yes(*right),
1019            // A AND (A OR B) --> A
1020            Expr::BinaryExpr(BinaryExpr {
1021                left,
1022                op: And,
1023                right,
1024            }) if is_op_with(Or, &right, &left) => Transformed::yes(*left),
1025            // (A OR B) AND A --> A
1026            Expr::BinaryExpr(BinaryExpr {
1027                left,
1028                op: And,
1029                right,
1030            }) if is_op_with(Or, &left, &right) => Transformed::yes(*right),
            // A >= constant AND A <= constant --> A = constant
1032            Expr::BinaryExpr(BinaryExpr {
1033                left,
1034                op: And,
1035                right,
1036            }) if can_reduce_to_equal_statement(&left, &right) => {
1037                if let Expr::BinaryExpr(BinaryExpr {
1038                    left: left_left,
1039                    right: left_right,
1040                    ..
1041                }) = *left
1042                {
1043                    Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1044                        left: left_left,
1045                        op: Eq,
1046                        right: left_right,
1047                    }))
1048                } else {
1049                    return internal_err!("can_reduce_to_equal_statement should only be called with a BinaryExpr");
1050                }
1051            }
1052
1053            //
1054            // Rules for Multiply
1055            //
1056
1057            // A * 1 --> A (with type coercion if needed)
1058            Expr::BinaryExpr(BinaryExpr {
1059                left,
1060                op: Multiply,
1061                right,
1062            }) if is_one(&right) => {
1063                simplify_right_is_one_case(info, left, &Multiply, &right)?
1064            }
1065            // 1 * A --> A
1066            Expr::BinaryExpr(BinaryExpr {
1067                left,
1068                op: Multiply,
1069                right,
1070            }) if is_one(&left) => {
1071                // 1 * A is equivalent to A * 1
1072                simplify_right_is_one_case(info, right, &Multiply, &left)?
1073            }
1074
1075            // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN)
1076            Expr::BinaryExpr(BinaryExpr {
1077                left,
1078                op: Multiply,
1079                right,
1080            }) if !info.nullable(&left)?
1081                && !info.get_data_type(&left)?.is_floating()
1082                && is_zero(&right) =>
1083            {
1084                Transformed::yes(*right)
1085            }
1086            // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN)
1087            Expr::BinaryExpr(BinaryExpr {
1088                left,
1089                op: Multiply,
1090                right,
1091            }) if !info.nullable(&right)?
1092                && !info.get_data_type(&right)?.is_floating()
1093                && is_zero(&left) =>
1094            {
1095                Transformed::yes(*left)
1096            }
1097
1098            //
1099            // Rules for Divide
1100            //
1101
1102            // A / 1 --> A
1103            Expr::BinaryExpr(BinaryExpr {
1104                left,
1105                op: Divide,
1106                right,
1107            }) if is_one(&right) => {
1108                simplify_right_is_one_case(info, left, &Divide, &right)?
1109            }
1110
1111            //
1112            // Rules for Modulo
1113            //
1114
1115            // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN)
1116            Expr::BinaryExpr(BinaryExpr {
1117                left,
1118                op: Modulo,
1119                right,
1120            }) if !info.nullable(&left)?
1121                && !info.get_data_type(&left)?.is_floating()
1122                && is_one(&right) =>
1123            {
1124                Transformed::yes(Expr::Literal(
1125                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1126                    None,
1127                ))
1128            }
1129
1130            //
1131            // Rules for BitwiseAnd
1132            //
1133
1134            // A & 0 -> 0 (if A not nullable)
1135            Expr::BinaryExpr(BinaryExpr {
1136                left,
1137                op: BitwiseAnd,
1138                right,
1139            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right),
1140
1141            // 0 & A -> 0 (if A not nullable)
1142            Expr::BinaryExpr(BinaryExpr {
1143                left,
1144                op: BitwiseAnd,
1145                right,
1146            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left),
1147
1148            // !A & A -> 0 (if A not nullable)
1149            Expr::BinaryExpr(BinaryExpr {
1150                left,
1151                op: BitwiseAnd,
1152                right,
1153            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1154                Transformed::yes(Expr::Literal(
1155                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1156                    None,
1157                ))
1158            }
1159
1160            // A & !A -> 0 (if A not nullable)
1161            Expr::BinaryExpr(BinaryExpr {
1162                left,
1163                op: BitwiseAnd,
1164                right,
1165            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1166                Transformed::yes(Expr::Literal(
1167                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1168                    None,
1169                ))
1170            }
1171
1172            // (..A..) & A --> (..A..)
1173            Expr::BinaryExpr(BinaryExpr {
1174                left,
1175                op: BitwiseAnd,
1176                right,
1177            }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left),
1178
1179            // A & (..A..) --> (..A..)
1180            Expr::BinaryExpr(BinaryExpr {
1181                left,
1182                op: BitwiseAnd,
1183                right,
1184            }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right),
1185
1186            // A & (A | B) --> A (if B not null)
1187            Expr::BinaryExpr(BinaryExpr {
1188                left,
1189                op: BitwiseAnd,
1190                right,
1191            }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => {
1192                Transformed::yes(*left)
1193            }
1194
1195            // (A | B) & A --> A (if B not null)
1196            Expr::BinaryExpr(BinaryExpr {
1197                left,
1198                op: BitwiseAnd,
1199                right,
1200            }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => {
1201                Transformed::yes(*right)
1202            }
1203
1204            //
1205            // Rules for BitwiseOr
1206            //
1207
1208            // A | 0 -> A (even if A is null)
1209            Expr::BinaryExpr(BinaryExpr {
1210                left,
1211                op: BitwiseOr,
1212                right,
1213            }) if is_zero(&right) => Transformed::yes(*left),
1214
1215            // 0 | A -> A (even if A is null)
1216            Expr::BinaryExpr(BinaryExpr {
1217                left,
1218                op: BitwiseOr,
1219                right,
1220            }) if is_zero(&left) => Transformed::yes(*right),
1221
1222            // !A | A -> -1 (if A not nullable)
1223            Expr::BinaryExpr(BinaryExpr {
1224                left,
1225                op: BitwiseOr,
1226                right,
1227            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1228                Transformed::yes(Expr::Literal(
1229                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1230                    None,
1231                ))
1232            }
1233
1234            // A | !A -> -1 (if A not nullable)
1235            Expr::BinaryExpr(BinaryExpr {
1236                left,
1237                op: BitwiseOr,
1238                right,
1239            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1240                Transformed::yes(Expr::Literal(
1241                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1242                    None,
1243                ))
1244            }
1245
1246            // (..A..) | A --> (..A..)
1247            Expr::BinaryExpr(BinaryExpr {
1248                left,
1249                op: BitwiseOr,
1250                right,
1251            }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left),
1252
1253            // A | (..A..) --> (..A..)
1254            Expr::BinaryExpr(BinaryExpr {
1255                left,
1256                op: BitwiseOr,
1257                right,
1258            }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right),
1259
1260            // A | (A & B) --> A (if B not null)
1261            Expr::BinaryExpr(BinaryExpr {
1262                left,
1263                op: BitwiseOr,
1264                right,
1265            }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => {
1266                Transformed::yes(*left)
1267            }
1268
1269            // (A & B) | A --> A (if B not null)
1270            Expr::BinaryExpr(BinaryExpr {
1271                left,
1272                op: BitwiseOr,
1273                right,
1274            }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => {
1275                Transformed::yes(*right)
1276            }
1277
1278            //
1279            // Rules for BitwiseXor
1280            //
1281
1282            // A ^ 0 -> A (if A not nullable)
1283            Expr::BinaryExpr(BinaryExpr {
1284                left,
1285                op: BitwiseXor,
1286                right,
1287            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left),
1288
1289            // 0 ^ A -> A (if A not nullable)
1290            Expr::BinaryExpr(BinaryExpr {
1291                left,
1292                op: BitwiseXor,
1293                right,
1294            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right),
1295
1296            // !A ^ A -> -1 (if A not nullable)
1297            Expr::BinaryExpr(BinaryExpr {
1298                left,
1299                op: BitwiseXor,
1300                right,
1301            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1302                Transformed::yes(Expr::Literal(
1303                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1304                    None,
1305                ))
1306            }
1307
1308            // A ^ !A -> -1 (if A not nullable)
1309            Expr::BinaryExpr(BinaryExpr {
1310                left,
1311                op: BitwiseXor,
1312                right,
1313            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1314                Transformed::yes(Expr::Literal(
1315                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1316                    None,
1317                ))
1318            }
1319
1320            // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A)
1321            Expr::BinaryExpr(BinaryExpr {
1322                left,
1323                op: BitwiseXor,
1324                right,
1325            }) if expr_contains(&left, &right, BitwiseXor) => {
1326                let expr = delete_xor_in_complex_expr(&left, &right, false);
1327                Transformed::yes(if expr == *right {
1328                    Expr::Literal(
1329                        ScalarValue::new_zero(&info.get_data_type(&right)?)?,
1330                        None,
1331                    )
1332                } else {
1333                    expr
1334                })
1335            }
1336
1337            // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A)
1338            Expr::BinaryExpr(BinaryExpr {
1339                left,
1340                op: BitwiseXor,
1341                right,
1342            }) if expr_contains(&right, &left, BitwiseXor) => {
1343                let expr = delete_xor_in_complex_expr(&right, &left, true);
1344                Transformed::yes(if expr == *left {
1345                    Expr::Literal(
1346                        ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1347                        None,
1348                    )
1349                } else {
1350                    expr
1351                })
1352            }
1353
1354            //
1355            // Rules for BitwiseShiftRight
1356            //
1357
1358            // A >> 0 -> A (even if A is null)
1359            Expr::BinaryExpr(BinaryExpr {
1360                left,
1361                op: BitwiseShiftRight,
1362                right,
1363            }) if is_zero(&right) => Transformed::yes(*left),
1364
1365            //
            // Rules for BitwiseShiftLeft
1367            //
1368
1369            // A << 0 -> A (even if A is null)
1370            Expr::BinaryExpr(BinaryExpr {
1371                left,
1372                op: BitwiseShiftLeft,
1373                right,
1374            }) if is_zero(&right) => Transformed::yes(*left),
1375
1376            //
1377            // Rules for Not
1378            //
1379            Expr::Not(inner) => Transformed::yes(negate_clause(*inner)),
1380
1381            //
1382            // Rules for Negative
1383            //
1384            Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)),
1385
1386            //
1387            // Rules for Case
1388            //
1389
1390            // Inline a comparison to a literal with the case statement into the `THEN` clauses.
1391            // which can enable further simplifications
1392            // CASE WHEN X THEN "a" WHEN Y THEN "b" ... END = "a" --> CASE WHEN X THEN "a" = "a" WHEN Y THEN "b" = "a" END
1393            Expr::BinaryExpr(BinaryExpr {
1394                left,
1395                op: op @ (Eq | NotEq),
1396                right,
1397            }) if is_case_with_literal_outputs(&left) && is_lit(&right) => {
1398                let case = into_case(*left)?;
1399                Transformed::yes(Expr::Case(Case {
1400                    expr: None,
1401                    when_then_expr: case
1402                        .when_then_expr
1403                        .into_iter()
1404                        .map(|(when, then)| {
1405                            (
1406                                when,
1407                                Box::new(Expr::BinaryExpr(BinaryExpr {
1408                                    left: then,
1409                                    op,
1410                                    right: right.clone(),
1411                                })),
1412                            )
1413                        })
1414                        .collect(),
1415                    else_expr: case.else_expr.map(|els| {
1416                        Box::new(Expr::BinaryExpr(BinaryExpr {
1417                            left: els,
1418                            op,
1419                            right,
1420                        }))
1421                    }),
1422                }))
1423            }
1424
1425            // CASE WHEN true THEN A ... END --> A
1426            // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1427            // CASE WHEN false THEN A END --> NULL
1428            // CASE WHEN false THEN A ELSE B END --> B
1429            // CASE WHEN X THEN A WHEN false THEN B END --> CASE WHEN X THEN A ELSE B END
1430            Expr::Case(Case {
1431                expr: None,
1432                when_then_expr,
1433                mut else_expr,
1434            }) if when_then_expr
1435                .iter()
1436                .any(|(when, _)| is_true(when.as_ref()) || is_false(when.as_ref())) =>
1437            {
1438                let out_type = info.get_data_type(&when_then_expr[0].1)?;
1439                let mut new_when_then_expr = Vec::with_capacity(when_then_expr.len());
1440
1441                for (when, then) in when_then_expr.into_iter() {
1442                    if is_true(when.as_ref()) {
1443                        // Skip adding the rest of the when-then expressions after WHEN true
1444                        // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1445                        else_expr = Some(then);
1446                        break;
1447                    } else if !is_false(when.as_ref()) {
1448                        new_when_then_expr.push((when, then));
1449                    }
1450                    // else: skip WHEN false cases
1451                }
1452
1453                // Exclude CASE statement altogether if there are no when-then expressions left
1454                if new_when_then_expr.is_empty() {
1455                    // CASE WHEN false THEN A ELSE B END --> B
1456                    if let Some(else_expr) = else_expr {
1457                        return Ok(Transformed::yes(*else_expr));
1458                    // CASE WHEN false THEN A END --> NULL
1459                    } else {
1460                        let null =
1461                            Expr::Literal(ScalarValue::try_new_null(&out_type)?, None);
1462                        return Ok(Transformed::yes(null));
1463                    }
1464                }
1465
1466                Transformed::yes(Expr::Case(Case {
1467                    expr: None,
1468                    when_then_expr: new_when_then_expr,
1469                    else_expr,
1470                }))
1471            }
1472
1473            // CASE
1474            //   WHEN X THEN A
1475            //   WHEN Y THEN B
1476            //   ...
1477            //   ELSE Q
1478            // END
1479            //
1480            // ---> (X AND A) OR (Y AND B AND NOT X) OR ... (NOT (X OR Y) AND Q)
1481            //
1482            // Note: the rationale for this rewrite is that the expr can then be further
1483            // simplified using the existing rules for AND/OR
1484            Expr::Case(Case {
1485                expr: None,
1486                when_then_expr,
1487                else_expr,
1488            }) if !when_then_expr.is_empty()
1489                // The rewrite is O(n²) in general so limit to small number of when-thens that can be true
1490                && (when_then_expr.len() < 3 // small number of input whens
1491                    // or all thens are literal bools and a small number of them are true
1492                    || (when_then_expr.iter().all(|(_, then)| is_bool_lit(then))
1493                        && when_then_expr.iter().filter(|(_, then)| is_true(then)).count() < 3))
1494                && info.is_boolean_type(&when_then_expr[0].1)? =>
1495            {
                // Running disjunction of all the when predicates encountered so far. Not nullable.
1497                let mut filter_expr = lit(false);
1498                // The disjunction of all the cases
1499                let mut out_expr = lit(false);
1500
1501                for (when, then) in when_then_expr {
1502                    let when = is_exactly_true(*when, info)?;
1503                    let case_expr =
1504                        when.clone().and(filter_expr.clone().not()).and(*then);
1505
1506                    out_expr = out_expr.or(case_expr);
1507                    filter_expr = filter_expr.or(when);
1508                }
1509
1510                let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
1511                let case_expr = filter_expr.not().and(else_expr);
1512                out_expr = out_expr.or(case_expr);
1513
1514                // Do a first pass at simplification
1515                out_expr.rewrite(self)?
1516            }
1517            // CASE
1518            //   WHEN X THEN true
1519            //   WHEN Y THEN true
1520            //   WHEN Z THEN false
1521            //   ...
1522            //   ELSE true
1523            // END
1524            //
1525            // --->
1526            //
1527            // NOT(CASE
1528            //   WHEN X THEN false
1529            //   WHEN Y THEN false
1530            //   WHEN Z THEN true
1531            //   ...
1532            //   ELSE false
1533            // END)
1534            //
1535            // Note: the rationale for this rewrite is that the case can then be further
1536            // simplified into a small number of ANDs and ORs
1537            Expr::Case(Case {
1538                expr: None,
1539                when_then_expr,
1540                else_expr,
1541            }) if !when_then_expr.is_empty()
1542                && when_then_expr
1543                    .iter()
1544                    .all(|(_, then)| is_bool_lit(then)) // all thens are literal bools
1545                // This simplification is only helpful if we end up with a small number of true thens
1546                && when_then_expr
1547                    .iter()
1548                    .filter(|(_, then)| is_false(then))
1549                    .count()
1550                    < 3
1551                && else_expr.as_deref().is_none_or(is_bool_lit) =>
1552            {
1553                Transformed::yes(
1554                    Expr::Case(Case {
1555                        expr: None,
1556                        when_then_expr: when_then_expr
1557                            .into_iter()
1558                            .map(|(when, then)| (when, Box::new(Expr::Not(then))))
1559                            .collect(),
1560                        else_expr: else_expr
1561                            .map(|else_expr| Box::new(Expr::Not(else_expr))),
1562                    })
1563                    .not(),
1564                )
1565            }
1566            Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
1567                match udf.simplify(args, info)? {
1568                    ExprSimplifyResult::Original(args) => {
1569                        Transformed::no(Expr::ScalarFunction(ScalarFunction {
1570                            func: udf,
1571                            args,
1572                        }))
1573                    }
1574                    ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
1575                }
1576            }
1577
1578            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
1579                ref func,
1580                ..
1581            }) => match (func.simplify(), expr) {
1582                (Some(simplify_function), Expr::AggregateFunction(af)) => {
1583                    Transformed::yes(simplify_function(af, info)?)
1584                }
1585                (_, expr) => Transformed::no(expr),
1586            },
1587
1588            Expr::WindowFunction(ref window_fun) => match (window_fun.simplify(), expr) {
1589                (Some(simplify_function), Expr::WindowFunction(wf)) => {
1590                    Transformed::yes(simplify_function(*wf, info)?)
1591                }
1592                (_, expr) => Transformed::no(expr),
1593            },
1594
1595            //
1596            // Rules for Between
1597            //
1598
1599            // a between 3 and 5  -->  a >= 3 AND a <=5
1600            // a not between 3 and 5  -->  a < 3 OR a > 5
1601            Expr::Between(between) => Transformed::yes(if between.negated {
1602                let l = *between.expr.clone();
1603                let r = *between.expr;
1604                or(l.lt(*between.low), r.gt(*between.high))
1605            } else {
1606                and(
1607                    between.expr.clone().gt_eq(*between.low),
1608                    between.expr.lt_eq(*between.high),
1609                )
1610            }),
1611
1612            //
1613            // Rules for regexes
1614            //
1615            Expr::BinaryExpr(BinaryExpr {
1616                left,
1617                op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch),
1618                right,
1619            }) => Transformed::yes(simplify_regex_expr(left, op, right)?),
1620
1621            // Rules for Like
1622            Expr::Like(like) => {
1623                // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291
1624                let escape_char = like.escape_char.unwrap_or('\\');
1625                match as_string_scalar(&like.pattern) {
1626                    Some((data_type, pattern_str)) => {
1627                        match pattern_str {
1628                            None => return Ok(Transformed::yes(lit_bool_null())),
1629                            Some(pattern_str) if pattern_str == "%" => {
1630                                // exp LIKE '%' is
1631                                //   - when exp is not NULL, it's true
1632                                //   - when exp is NULL, it's NULL
1633                                // exp NOT LIKE '%' is
1634                                //   - when exp is not NULL, it's false
1635                                //   - when exp is NULL, it's NULL
1636                                let result_for_non_null = lit(!like.negated);
1637                                Transformed::yes(if !info.nullable(&like.expr)? {
1638                                    result_for_non_null
1639                                } else {
1640                                    Expr::Case(Case {
1641                                        expr: Some(Box::new(Expr::IsNotNull(like.expr))),
1642                                        when_then_expr: vec![(
1643                                            Box::new(lit(true)),
1644                                            Box::new(result_for_non_null),
1645                                        )],
1646                                        else_expr: None,
1647                                    })
1648                                })
1649                            }
1650                            Some(pattern_str)
1651                                if pattern_str.contains("%%")
1652                                    && !pattern_str.contains(escape_char) =>
1653                            {
1654                                // Repeated occurrences of wildcard are redundant so remove them
1655                                // exp LIKE '%%'  --> exp LIKE '%'
1656                                let simplified_pattern = Regex::new("%%+")
1657                                    .unwrap()
1658                                    .replace_all(pattern_str, "%")
1659                                    .to_string();
1660                                Transformed::yes(Expr::Like(Like {
1661                                    pattern: Box::new(to_string_scalar(
1662                                        &data_type,
1663                                        Some(simplified_pattern),
1664                                    )),
1665                                    ..like
1666                                }))
1667                            }
1668                            Some(pattern_str)
1669                                if !like.case_insensitive
1670                                    && !pattern_str
1671                                        .contains(['%', '_', escape_char].as_ref()) =>
1672                            {
1673                                // If the pattern does not contain any wildcards, we can simplify the like expression to an equality expression
1674                                // TODO: handle escape characters
1675                                Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1676                                    left: like.expr.clone(),
1677                                    op: if like.negated { NotEq } else { Eq },
1678                                    right: like.pattern.clone(),
1679                                }))
1680                            }
1681
1682                            Some(_pattern_str) => Transformed::no(Expr::Like(like)),
1683                        }
1684                    }
1685                    None => Transformed::no(Expr::Like(like)),
1686                }
1687            }
1688
1689            // a is not null/unknown --> true (if a is not nullable)
1690            Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr)
1691                if !info.nullable(&expr)? =>
1692            {
1693                Transformed::yes(lit(true))
1694            }
1695
1696            // a is null/unknown --> false (if a is not nullable)
1697            Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => {
1698                Transformed::yes(lit(false))
1699            }
1700
1701            // expr IN () --> false
1702            // expr NOT IN () --> true
1703            Expr::InList(InList {
1704                expr: _,
1705                list,
1706                negated,
1707            }) if list.is_empty() => Transformed::yes(lit(negated)),
1708
1709            // null in (x, y, z) --> null
1710            // null not in (x, y, z) --> null
1711            Expr::InList(InList {
1712                expr,
1713                list,
1714                negated: _,
1715            }) if is_null(expr.as_ref()) && !list.is_empty() => {
1716                Transformed::yes(lit_bool_null())
1717            }
1718
            // expr IN ((subquery)) -> expr IN (subquery), see #5529
1720            Expr::InList(InList {
1721                expr,
1722                mut list,
1723                negated,
1724            }) if list.len() == 1
1725                && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) =>
1726            {
1727                let Expr::ScalarSubquery(subquery) = list.remove(0) else {
1728                    unreachable!()
1729                };
1730
1731                Transformed::yes(Expr::InSubquery(InSubquery::new(
1732                    expr, subquery, negated,
1733                )))
1734            }
1735
1736            // Combine multiple OR expressions into a single IN list expression if possible
1737            //
1738            // i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)`
1739            Expr::BinaryExpr(BinaryExpr {
1740                left,
1741                op: Or,
1742                right,
1743            }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => {
1744                let lhs = to_inlist(*left).unwrap();
1745                let rhs = to_inlist(*right).unwrap();
1746                let mut seen: HashSet<Expr> = HashSet::new();
1747                let list = lhs
1748                    .list
1749                    .into_iter()
1750                    .chain(rhs.list)
1751                    .filter(|e| seen.insert(e.to_owned()))
1752                    .collect::<Vec<_>>();
1753
1754                let merged_inlist = InList {
1755                    expr: lhs.expr,
1756                    list,
1757                    negated: false,
1758                };
1759
1760                Transformed::yes(Expr::InList(merged_inlist))
1761            }
1762
            // Simplify expressions that are guaranteed to be true or false to a literal boolean expression
1764            //
1765            // Rules:
1766            // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists
1767            //   Intersection:
1768            //     1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false`
1769            //     2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)`
1770            //     3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)`
1771            //   Union:
            //     4. `a not in (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)`
1773            //     # This rule is handled by `or_in_list_simplifier.rs`
1774            //     5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)`
1775            // If one of the expressions is `IN` and another one is `NOT IN`, then we apply exception on `In` expression
1776            //     6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false`
1777            //     7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5`
1778            //     8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)`
1779            Expr::BinaryExpr(BinaryExpr {
1780                left,
1781                op: And,
1782                right,
1783            }) if are_inlist_and_eq_and_match_neg(
1784                left.as_ref(),
1785                right.as_ref(),
1786                false,
1787                false,
1788            ) =>
1789            {
1790                match (*left, *right) {
1791                    (Expr::InList(l1), Expr::InList(l2)) => {
1792                        return inlist_intersection(l1, &l2, false).map(Transformed::yes);
1793                    }
1794                    // Matched previously once
1795                    _ => unreachable!(),
1796                }
1797            }
1798
1799            Expr::BinaryExpr(BinaryExpr {
1800                left,
1801                op: And,
1802                right,
1803            }) if are_inlist_and_eq_and_match_neg(
1804                left.as_ref(),
1805                right.as_ref(),
1806                true,
1807                true,
1808            ) =>
1809            {
1810                match (*left, *right) {
1811                    (Expr::InList(l1), Expr::InList(l2)) => {
1812                        return inlist_union(l1, l2, true).map(Transformed::yes);
1813                    }
1814                    // Matched previously once
1815                    _ => unreachable!(),
1816                }
1817            }
1818
1819            Expr::BinaryExpr(BinaryExpr {
1820                left,
1821                op: And,
1822                right,
1823            }) if are_inlist_and_eq_and_match_neg(
1824                left.as_ref(),
1825                right.as_ref(),
1826                false,
1827                true,
1828            ) =>
1829            {
1830                match (*left, *right) {
1831                    (Expr::InList(l1), Expr::InList(l2)) => {
1832                        return inlist_except(l1, &l2).map(Transformed::yes);
1833                    }
1834                    // Matched previously once
1835                    _ => unreachable!(),
1836                }
1837            }
1838
1839            Expr::BinaryExpr(BinaryExpr {
1840                left,
1841                op: And,
1842                right,
1843            }) if are_inlist_and_eq_and_match_neg(
1844                left.as_ref(),
1845                right.as_ref(),
1846                true,
1847                false,
1848            ) =>
1849            {
1850                match (*left, *right) {
1851                    (Expr::InList(l1), Expr::InList(l2)) => {
1852                        return inlist_except(l2, &l1).map(Transformed::yes);
1853                    }
1854                    // Matched previously once
1855                    _ => unreachable!(),
1856                }
1857            }
1858
1859            Expr::BinaryExpr(BinaryExpr {
1860                left,
1861                op: Or,
1862                right,
1863            }) if are_inlist_and_eq_and_match_neg(
1864                left.as_ref(),
1865                right.as_ref(),
1866                true,
1867                true,
1868            ) =>
1869            {
1870                match (*left, *right) {
1871                    (Expr::InList(l1), Expr::InList(l2)) => {
1872                        return inlist_intersection(l1, &l2, true).map(Transformed::yes);
1873                    }
1874                    // Matched previously once
1875                    _ => unreachable!(),
1876                }
1877            }
1878
1879            // =======================================
1880            // unwrap_cast_in_comparison
1881            // =======================================
1882            //
1883            // For case:
1884            // try_cast/cast(expr as data_type) op literal
1885            Expr::BinaryExpr(BinaryExpr { left, op, right })
1886                if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1887                    info, &left, op, &right,
1888                ) && op.supports_propagation() =>
1889            {
1890                unwrap_cast_in_comparison_for_binary(info, *left, *right, op)?
1891            }
1892            // literal op try_cast/cast(expr as data_type)
1893            // -->
1894            // try_cast/cast(expr as data_type) op_swap literal
1895            Expr::BinaryExpr(BinaryExpr { left, op, right })
1896                if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1897                    info, &right, op, &left,
1898                ) && op.supports_propagation()
1899                    && op.swap().is_some() =>
1900            {
1901                unwrap_cast_in_comparison_for_binary(
1902                    info,
1903                    *right,
1904                    *left,
1905                    op.swap().unwrap(),
1906                )?
1907            }
1908            // For case:
1909            // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
1910            Expr::InList(InList {
1911                expr: mut left,
1912                list,
1913                negated,
1914            }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
1915                info, &left, &list,
1916            ) =>
1917            {
1918                let (Expr::TryCast(TryCast {
1919                    expr: left_expr, ..
1920                })
1921                | Expr::Cast(Cast {
1922                    expr: left_expr, ..
1923                })) = left.as_mut()
1924                else {
1925                    return internal_err!("Expect cast expr, but got {:?}", left)?;
1926                };
1927
1928                let expr_type = info.get_data_type(left_expr)?;
1929                let right_exprs = list
1930                    .into_iter()
1931                    .map(|right| {
1932                        match right {
1933                            Expr::Literal(right_lit_value, _) => {
1934                                // if the right_lit_value can be casted to the type of internal_left_expr
1935                                // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
1936                                let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else {
1937                                    internal_err!(
1938                                        "Can't cast the list expr {:?} to type {}",
1939                                        right_lit_value, &expr_type
1940                                    )?
1941                                };
1942                                Ok(lit(value))
1943                            }
1944                            other_expr => internal_err!(
1945                                "Only support literal expr to optimize, but the expr is {:?}",
1946                                &other_expr
1947                            ),
1948                        }
1949                    })
1950                    .collect::<Result<Vec<_>>>()?;
1951
1952                Transformed::yes(Expr::InList(InList {
1953                    expr: std::mem::take(left_expr),
1954                    list: right_exprs,
1955                    negated,
1956                }))
1957            }
1958
1959            // no additional rewrites possible
1960            expr => Transformed::no(expr),
1961        })
1962    }
1963}
1964
1965fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option<String>)> {
1966    match expr {
1967        Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)),
1968        Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)),
1969        Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)),
1970        _ => None,
1971    }
1972}
1973
1974fn to_string_scalar(data_type: &DataType, value: Option<String>) -> Expr {
1975    match data_type {
1976        DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None),
1977        DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None),
1978        DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None),
1979        _ => unreachable!(),
1980    }
1981}
1982
1983fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
1984    let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect();
1985    iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
1986}
1987
1988// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
1989fn are_inlist_and_eq_and_match_neg(
1990    left: &Expr,
1991    right: &Expr,
1992    is_left_neg: bool,
1993    is_right_neg: bool,
1994) -> bool {
1995    match (left, right) {
1996        (Expr::InList(l), Expr::InList(r)) => {
1997            l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg
1998        }
1999        _ => false,
2000    }
2001}
2002
2003// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
2004fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool {
2005    let left = as_inlist(left);
2006    let right = as_inlist(right);
2007    if let (Some(lhs), Some(rhs)) = (left, right) {
2008        matches!(lhs.expr.as_ref(), Expr::Column(_))
2009            && matches!(rhs.expr.as_ref(), Expr::Column(_))
2010            && lhs.expr == rhs.expr
2011            && !lhs.negated
2012            && !rhs.negated
2013    } else {
2014        false
2015    }
2016}
2017
2018/// Try to convert an expression to an in-list expression
2019fn as_inlist(expr: &'_ Expr) -> Option<Cow<'_, InList>> {
2020    match expr {
2021        Expr::InList(inlist) => Some(Cow::Borrowed(inlist)),
2022        Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => {
2023            match (left.as_ref(), right.as_ref()) {
2024                (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList {
2025                    expr: left.clone(),
2026                    list: vec![*right.clone()],
2027                    negated: false,
2028                })),
2029                (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList {
2030                    expr: right.clone(),
2031                    list: vec![*left.clone()],
2032                    negated: false,
2033                })),
2034                _ => None,
2035            }
2036        }
2037        _ => None,
2038    }
2039}
2040
2041fn to_inlist(expr: Expr) -> Option<InList> {
2042    match expr {
2043        Expr::InList(inlist) => Some(inlist),
2044        Expr::BinaryExpr(BinaryExpr {
2045            left,
2046            op: Operator::Eq,
2047            right,
2048        }) => match (left.as_ref(), right.as_ref()) {
2049            (Expr::Column(_), Expr::Literal(_, _)) => Some(InList {
2050                expr: left,
2051                list: vec![*right],
2052                negated: false,
2053            }),
2054            (Expr::Literal(_, _), Expr::Column(_)) => Some(InList {
2055                expr: right,
2056                list: vec![*left],
2057                negated: false,
2058            }),
2059            _ => None,
2060        },
2061        _ => None,
2062    }
2063}
2064
2065/// Return the union of two inlist expressions
2066/// maintaining the order of the elements in the two lists
2067fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
2068    // extend the list in l1 with the elements in l2 that are not already in l1
2069    let l1_items: HashSet<_> = l1.list.iter().collect();
2070
2071    // keep all l2 items that do not also appear in l1
2072    let keep_l2: Vec<_> = l2
2073        .list
2074        .into_iter()
2075        .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) })
2076        .collect();
2077
2078    l1.list.extend(keep_l2);
2079    l1.negated = negated;
2080    Ok(Expr::InList(l1))
2081}
2082
2083/// Return the intersection of two inlist expressions
2084/// maintaining the order of the elements in the two lists
2085fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result<Expr> {
2086    let l2_items = l2.list.iter().collect::<HashSet<_>>();
2087
2088    // remove all items from l1 that are not in l2
2089    l1.list.retain(|e| l2_items.contains(e));
2090
2091    // e in () is always false
2092    // e not in () is always true
2093    if l1.list.is_empty() {
2094        return Ok(lit(negated));
2095    }
2096    Ok(Expr::InList(l1))
2097}
2098
2099/// Return the all items in l1 that are not in l2
2100/// maintaining the order of the elements in the two lists
2101fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
2102    let l2_items = l2.list.iter().collect::<HashSet<_>>();
2103
2104    // keep only items from l1 that are not in l2
2105    l1.list.retain(|e| !l2_items.contains(e));
2106
2107    if l1.list.is_empty() {
2108        return Ok(lit(false));
2109    }
2110    Ok(Expr::InList(l1))
2111}
2112
2113/// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL).
2114fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
2115    if !info.nullable(&expr)? {
2116        Ok(expr)
2117    } else {
2118        Ok(Expr::BinaryExpr(BinaryExpr {
2119            left: Box::new(expr),
2120            op: Operator::IsNotDistinctFrom,
2121            right: Box::new(lit(true)),
2122        }))
2123    }
2124}
2125
2126// A * 1 -> A
2127// A / 1 -> A
2128//
2129// Move this function body out of the large match branch avoid stack overflow
2130fn simplify_right_is_one_case<S: SimplifyInfo>(
2131    info: &S,
2132    left: Box<Expr>,
2133    op: &Operator,
2134    right: &Expr,
2135) -> Result<Transformed<Expr>> {
2136    // Check if resulting type would be different due to coercion
2137    let left_type = info.get_data_type(&left)?;
2138    let right_type = info.get_data_type(right)?;
2139    match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() {
2140        Ok(result_type) => {
2141            // Only cast if the types differ
2142            if left_type != result_type {
2143                Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type))))
2144            } else {
2145                Ok(Transformed::yes(*left))
2146            }
2147        }
2148        Err(_) => Ok(Transformed::yes(*left)),
2149    }
2150}
2151
2152#[cfg(test)]
2153mod tests {
2154    use super::*;
2155    use crate::simplify_expressions::SimplifyContext;
2156    use crate::test::test_table_scan_with_name;
2157    use arrow::datatypes::FieldRef;
2158    use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
2159    use datafusion_expr::{
2160        expr::WindowFunction,
2161        function::{
2162            AccumulatorArgs, AggregateFunctionSimplification,
2163            WindowFunctionSimplification,
2164        },
2165        interval_arithmetic::Interval,
2166        *,
2167    };
2168    use datafusion_functions_window_common::field::WindowUDFFieldArgs;
2169    use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
2170    use datafusion_physical_expr::PhysicalExpr;
2171    use std::hash::Hash;
2172    use std::sync::LazyLock;
2173    use std::{
2174        collections::HashMap,
2175        ops::{BitAnd, BitOr, BitXor},
2176        sync::Arc,
2177    };
2178
2179    // ------------------------------
2180    // --- ExprSimplifier tests -----
2181    // ------------------------------
2182    #[test]
2183    fn api_basic() {
2184        let props = ExecutionProps::new();
2185        let simplifier =
2186            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2187
2188        let expr = lit(1) + lit(2);
2189        let expected = lit(3);
2190        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2191    }
2192
2193    #[test]
2194    fn basic_coercion() {
2195        let schema = test_schema();
2196        let props = ExecutionProps::new();
2197        let simplifier = ExprSimplifier::new(
2198            SimplifyContext::new(&props).with_schema(Arc::clone(&schema)),
2199        );
2200
2201        // Note expr type is int32 (not int64)
2202        // (1i64 + 2i32) < i
2203        let expr = (lit(1i64) + lit(2i32)).lt(col("i"));
2204        // should fully simplify to 3 < i (though i has been coerced to i64)
2205        let expected = lit(3i64).lt(col("i"));
2206
2207        let expr = simplifier.coerce(expr, &schema).unwrap();
2208
2209        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2210    }
2211
2212    fn test_schema() -> DFSchemaRef {
2213        static TEST_SCHEMA: LazyLock<DFSchemaRef> = LazyLock::new(|| {
2214            Schema::new(vec![
2215                Field::new("i", DataType::Int64, false),
2216                Field::new("b", DataType::Boolean, true),
2217            ])
2218            .to_dfschema_ref()
2219            .unwrap()
2220        });
2221        Arc::clone(&TEST_SCHEMA)
2222    }
2223
2224    #[test]
2225    fn simplify_and_constant_prop() {
2226        let props = ExecutionProps::new();
2227        let simplifier =
2228            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2229
2230        // should be able to simplify to false
2231        // (i * (1 - 2)) > 0
2232        let expr = (col("i") * (lit(1) - lit(1))).gt(lit(0));
2233        let expected = lit(false);
2234        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2235    }
2236
2237    #[test]
2238    fn simplify_and_constant_prop_with_case() {
2239        let props = ExecutionProps::new();
2240        let simplifier =
2241            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2242
2243        //   CASE
2244        //     WHEN i>5 AND false THEN i > 5
2245        //     WHEN i<5 AND true THEN i < 5
2246        //     ELSE false
2247        //   END
2248        //
2249        // Can be simplified to `i < 5`
2250        let expr = when(col("i").gt(lit(5)).and(lit(false)), col("i").gt(lit(5)))
2251            .when(col("i").lt(lit(5)).and(lit(true)), col("i").lt(lit(5)))
2252            .otherwise(lit(false))
2253            .unwrap();
2254        let expected = col("i").lt(lit(5));
2255        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2256    }
2257
2258    // ------------------------------
2259    // --- Simplifier tests -----
2260    // ------------------------------
2261
2262    #[test]
2263    fn test_simplify_canonicalize() {
2264        {
2265            let expr = lit(1).lt(col("c2")).and(col("c2").gt(lit(1)));
2266            let expected = col("c2").gt(lit(1));
2267            assert_eq!(simplify(expr), expected);
2268        }
2269        {
2270            let expr = col("c1").lt(col("c2")).and(col("c2").gt(col("c1")));
2271            let expected = col("c2").gt(col("c1"));
2272            assert_eq!(simplify(expr), expected);
2273        }
2274        {
2275            let expr = col("c1")
2276                .eq(lit(1))
2277                .and(lit(1).eq(col("c1")))
2278                .and(col("c1").eq(lit(3)));
2279            let expected = col("c1").eq(lit(1)).and(col("c1").eq(lit(3)));
2280            assert_eq!(simplify(expr), expected);
2281        }
2282        {
2283            let expr = col("c1")
2284                .eq(col("c2"))
2285                .and(col("c1").gt(lit(5)))
2286                .and(col("c2").eq(col("c1")));
2287            let expected = col("c2").eq(col("c1")).and(col("c1").gt(lit(5)));
2288            assert_eq!(simplify(expr), expected);
2289        }
2290        {
2291            let expr = col("c1")
2292                .eq(lit(1))
2293                .and(col("c2").gt(lit(3)).or(lit(3).lt(col("c2"))));
2294            let expected = col("c1").eq(lit(1)).and(col("c2").gt(lit(3)));
2295            assert_eq!(simplify(expr), expected);
2296        }
2297        {
2298            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2299            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2300            assert_eq!(simplify(expr), expected);
2301        }
2302        {
2303            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2304            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2305            assert_eq!(simplify(expr), expected);
2306        }
2307        {
2308            let expr = col("c1").gt(col("c2")).and(col("c1").gt(col("c2")));
2309            let expected = col("c2").lt(col("c1"));
2310            assert_eq!(simplify(expr), expected);
2311        }
2312    }
2313
2314    #[test]
2315    fn test_simplify_eq_not_self() {
2316        // `expr_a`: column `c2` is nullable, so `c2 = c2` simplifies to `c2 IS NOT NULL OR NULL`
2317        // This ensures the expression is only true when `c2` is not NULL, accounting for SQL's NULL semantics.
2318        let expr_a = col("c2").eq(col("c2"));
2319        let expected_a = col("c2").is_not_null().or(lit_bool_null());
2320
2321        // `expr_b`: column `c2_non_null` is explicitly non-nullable, so `c2_non_null = c2_non_null` is always true
2322        let expr_b = col("c2_non_null").eq(col("c2_non_null"));
2323        let expected_b = lit(true);
2324
2325        assert_eq!(simplify(expr_a), expected_a);
2326        assert_eq!(simplify(expr_b), expected_b);
2327    }
2328
2329    #[test]
2330    fn test_simplify_or_true() {
2331        let expr_a = col("c2").or(lit(true));
2332        let expr_b = lit(true).or(col("c2"));
2333        let expected = lit(true);
2334
2335        assert_eq!(simplify(expr_a), expected);
2336        assert_eq!(simplify(expr_b), expected);
2337    }
2338
2339    #[test]
2340    fn test_simplify_or_false() {
2341        let expr_a = lit(false).or(col("c2"));
2342        let expr_b = col("c2").or(lit(false));
2343        let expected = col("c2");
2344
2345        assert_eq!(simplify(expr_a), expected);
2346        assert_eq!(simplify(expr_b), expected);
2347    }
2348
2349    #[test]
2350    fn test_simplify_or_same() {
2351        let expr = col("c2").or(col("c2"));
2352        let expected = col("c2");
2353
2354        assert_eq!(simplify(expr), expected);
2355    }
2356
2357    #[test]
2358    fn test_simplify_or_not_self() {
2359        // A OR !A if A is not nullable --> true
2360        // !A OR A if A is not nullable --> true
2361        let expr_a = col("c2_non_null").or(col("c2_non_null").not());
2362        let expr_b = col("c2_non_null").not().or(col("c2_non_null"));
2363        let expected = lit(true);
2364
2365        assert_eq!(simplify(expr_a), expected);
2366        assert_eq!(simplify(expr_b), expected);
2367    }
2368
2369    #[test]
2370    fn test_simplify_and_false() {
2371        let expr_a = lit(false).and(col("c2"));
2372        let expr_b = col("c2").and(lit(false));
2373        let expected = lit(false);
2374
2375        assert_eq!(simplify(expr_a), expected);
2376        assert_eq!(simplify(expr_b), expected);
2377    }
2378
2379    #[test]
2380    fn test_simplify_and_same() {
2381        let expr = col("c2").and(col("c2"));
2382        let expected = col("c2");
2383
2384        assert_eq!(simplify(expr), expected);
2385    }
2386
2387    #[test]
2388    fn test_simplify_and_true() {
2389        let expr_a = lit(true).and(col("c2"));
2390        let expr_b = col("c2").and(lit(true));
2391        let expected = col("c2");
2392
2393        assert_eq!(simplify(expr_a), expected);
2394        assert_eq!(simplify(expr_b), expected);
2395    }
2396
2397    #[test]
2398    fn test_simplify_and_not_self() {
2399        // A AND !A if A is not nullable --> false
2400        // !A AND A if A is not nullable --> false
2401        let expr_a = col("c2_non_null").and(col("c2_non_null").not());
2402        let expr_b = col("c2_non_null").not().and(col("c2_non_null"));
2403        let expected = lit(false);
2404
2405        assert_eq!(simplify(expr_a), expected);
2406        assert_eq!(simplify(expr_b), expected);
2407    }
2408
2409    #[test]
2410    fn test_simplify_multiply_by_one() {
2411        let expr_a = col("c2") * lit(1);
2412        let expr_b = lit(1) * col("c2");
2413        let expected = col("c2");
2414
2415        assert_eq!(simplify(expr_a), expected);
2416        assert_eq!(simplify(expr_b), expected);
2417
2418        let expr = col("c2") * lit(ScalarValue::Decimal128(Some(10000000000), 38, 10));
2419        assert_eq!(simplify(expr), expected);
2420
2421        let expr = lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)) * col("c2");
2422        assert_eq!(simplify(expr), expected);
2423    }
2424
2425    #[test]
2426    fn test_simplify_multiply_by_null() {
2427        let null = lit(ScalarValue::Int64(None));
2428        // A * null --> null
2429        {
2430            let expr = col("c3") * null.clone();
2431            assert_eq!(simplify(expr), null);
2432        }
2433        // null * A --> null
2434        {
2435            let expr = null.clone() * col("c3");
2436            assert_eq!(simplify(expr), null);
2437        }
2438    }
2439
2440    #[test]
2441    fn test_simplify_multiply_by_zero() {
2442        // cannot optimize A * null (null * A) if A is nullable
2443        {
2444            let expr_a = col("c2") * lit(0);
2445            let expr_b = lit(0) * col("c2");
2446
2447            assert_eq!(simplify(expr_a.clone()), expr_a);
2448            assert_eq!(simplify(expr_b.clone()), expr_b);
2449        }
2450        // 0 * A --> 0 if A is not nullable
2451        {
2452            let expr = lit(0) * col("c2_non_null");
2453            assert_eq!(simplify(expr), lit(0));
2454        }
2455        // A * 0 --> 0 if A is not nullable
2456        {
2457            let expr = col("c2_non_null") * lit(0);
2458            assert_eq!(simplify(expr), lit(0));
2459        }
2460        // A * Decimal128(0) --> 0 if A is not nullable
2461        {
2462            let expr = col("c2_non_null") * lit(ScalarValue::Decimal128(Some(0), 31, 10));
2463            assert_eq!(
2464                simplify(expr),
2465                lit(ScalarValue::Decimal128(Some(0), 31, 10))
2466            );
2467            let expr = binary_expr(
2468                lit(ScalarValue::Decimal128(Some(0), 31, 10)),
2469                Operator::Multiply,
2470                col("c2_non_null"),
2471            );
2472            assert_eq!(
2473                simplify(expr),
2474                lit(ScalarValue::Decimal128(Some(0), 31, 10))
2475            );
2476        }
2477    }
2478
2479    #[test]
2480    fn test_simplify_divide_by_one() {
2481        let expr = binary_expr(col("c2"), Operator::Divide, lit(1));
2482        let expected = col("c2");
2483        assert_eq!(simplify(expr), expected);
2484        let expr = col("c2") / lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2485        assert_eq!(simplify(expr), expected);
2486    }
2487
2488    #[test]
2489    fn test_simplify_divide_null() {
2490        // A / null --> null
2491        let null = lit(ScalarValue::Int64(None));
2492        {
2493            let expr = col("c3") / null.clone();
2494            assert_eq!(simplify(expr), null);
2495        }
2496        // null / A --> null
2497        {
2498            let expr = null.clone() / col("c3");
2499            assert_eq!(simplify(expr), null);
2500        }
2501    }
2502
2503    #[test]
2504    fn test_simplify_divide_by_same() {
2505        let expr = col("c2") / col("c2");
2506        // if c2 is null, c2 / c2 = null, so can't simplify
2507        let expected = expr.clone();
2508
2509        assert_eq!(simplify(expr), expected);
2510    }
2511
2512    #[test]
2513    fn test_simplify_modulo_by_null() {
2514        let null = lit(ScalarValue::Int64(None));
2515        // A % null --> null
2516        {
2517            let expr = col("c3") % null.clone();
2518            assert_eq!(simplify(expr), null);
2519        }
2520        // null % A --> null
2521        {
2522            let expr = null.clone() % col("c3");
2523            assert_eq!(simplify(expr), null);
2524        }
2525    }
2526
2527    #[test]
2528    fn test_simplify_modulo_by_one() {
2529        let expr = col("c2") % lit(1);
2530        // if c2 is null, c2 % 1 = null, so can't simplify
2531        let expected = expr.clone();
2532
2533        assert_eq!(simplify(expr), expected);
2534    }
2535
2536    #[test]
2537    fn test_simplify_divide_zero_by_zero() {
2538        // because divide by 0 maybe occur in short-circuit expression
2539        // so we should not simplify this, and throw error in runtime
2540        let expr = lit(0) / lit(0);
2541        let expected = expr.clone();
2542
2543        assert_eq!(simplify(expr), expected);
2544    }
2545
2546    #[test]
2547    fn test_simplify_divide_by_zero() {
2548        // because divide by 0 maybe occur in short-circuit expression
2549        // so we should not simplify this, and throw error in runtime
2550        let expr = col("c2_non_null") / lit(0);
2551        let expected = expr.clone();
2552
2553        assert_eq!(simplify(expr), expected);
2554    }
2555
2556    #[test]
2557    fn test_simplify_modulo_by_one_non_null() {
2558        let expr = col("c3_non_null") % lit(1);
2559        let expected = lit(0_i64);
2560        assert_eq!(simplify(expr), expected);
2561        let expr =
2562            col("c3_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2563        assert_eq!(simplify(expr), expected);
2564    }
2565
2566    #[test]
2567    fn test_simplify_bitwise_xor_by_null() {
2568        let null = lit(ScalarValue::Int64(None));
2569        // A ^ null --> null
2570        {
2571            let expr = col("c3") ^ null.clone();
2572            assert_eq!(simplify(expr), null);
2573        }
2574        // null ^ A --> null
2575        {
2576            let expr = null.clone() ^ col("c3");
2577            assert_eq!(simplify(expr), null);
2578        }
2579    }
2580
2581    #[test]
2582    fn test_simplify_bitwise_shift_right_by_null() {
2583        let null = lit(ScalarValue::Int64(None));
2584        // A >> null --> null
2585        {
2586            let expr = col("c3") >> null.clone();
2587            assert_eq!(simplify(expr), null);
2588        }
2589        // null >> A --> null
2590        {
2591            let expr = null.clone() >> col("c3");
2592            assert_eq!(simplify(expr), null);
2593        }
2594    }
2595
2596    #[test]
2597    fn test_simplify_bitwise_shift_left_by_null() {
2598        let null = lit(ScalarValue::Int64(None));
2599        // A << null --> null
2600        {
2601            let expr = col("c3") << null.clone();
2602            assert_eq!(simplify(expr), null);
2603        }
2604        // null << A --> null
2605        {
2606            let expr = null.clone() << col("c3");
2607            assert_eq!(simplify(expr), null);
2608        }
2609    }
2610
2611    #[test]
2612    fn test_simplify_bitwise_and_by_zero() {
2613        // A & 0 --> 0
2614        {
2615            let expr = col("c2_non_null") & lit(0);
2616            assert_eq!(simplify(expr), lit(0));
2617        }
2618        // 0 & A --> 0
2619        {
2620            let expr = lit(0) & col("c2_non_null");
2621            assert_eq!(simplify(expr), lit(0));
2622        }
2623    }
2624
2625    #[test]
2626    fn test_simplify_bitwise_or_by_zero() {
2627        // A | 0 --> A
2628        {
2629            let expr = col("c2_non_null") | lit(0);
2630            assert_eq!(simplify(expr), col("c2_non_null"));
2631        }
2632        // 0 | A --> A
2633        {
2634            let expr = lit(0) | col("c2_non_null");
2635            assert_eq!(simplify(expr), col("c2_non_null"));
2636        }
2637    }
2638
2639    #[test]
2640    fn test_simplify_bitwise_xor_by_zero() {
2641        // A ^ 0 --> A
2642        {
2643            let expr = col("c2_non_null") ^ lit(0);
2644            assert_eq!(simplify(expr), col("c2_non_null"));
2645        }
2646        // 0 ^ A --> A
2647        {
2648            let expr = lit(0) ^ col("c2_non_null");
2649            assert_eq!(simplify(expr), col("c2_non_null"));
2650        }
2651    }
2652
2653    #[test]
2654    fn test_simplify_bitwise_bitwise_shift_right_by_zero() {
2655        // A >> 0 --> A
2656        {
2657            let expr = col("c2_non_null") >> lit(0);
2658            assert_eq!(simplify(expr), col("c2_non_null"));
2659        }
2660    }
2661
2662    #[test]
2663    fn test_simplify_bitwise_bitwise_shift_left_by_zero() {
2664        // A << 0 --> A
2665        {
2666            let expr = col("c2_non_null") << lit(0);
2667            assert_eq!(simplify(expr), col("c2_non_null"));
2668        }
2669    }
2670
2671    #[test]
2672    fn test_simplify_bitwise_and_by_null() {
2673        let null = Expr::Literal(ScalarValue::Int64(None), None);
2674        // A & null --> null
2675        {
2676            let expr = col("c3") & null.clone();
2677            assert_eq!(simplify(expr), null);
2678        }
2679        // null & A --> null
2680        {
2681            let expr = null.clone() & col("c3");
2682            assert_eq!(simplify(expr), null);
2683        }
2684    }
2685
2686    #[test]
2687    fn test_simplify_composed_bitwise_and() {
2688        // ((c2 > 5) & (c1 < 6)) & (c2 > 5) --> (c2 > 5) & (c1 < 6)
2689
2690        let expr = bitwise_and(
2691            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2692            col("c2").gt(lit(5)),
2693        );
2694        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2695
2696        assert_eq!(simplify(expr), expected);
2697
2698        // (c2 > 5) & ((c2 > 5) & (c1 < 6)) --> (c2 > 5) & (c1 < 6)
2699
2700        let expr = bitwise_and(
2701            col("c2").gt(lit(5)),
2702            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2703        );
2704        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2705        assert_eq!(simplify(expr), expected);
2706    }
2707
2708    #[test]
2709    fn test_simplify_composed_bitwise_or() {
2710        // ((c2 > 5) | (c1 < 6)) | (c2 > 5) --> (c2 > 5) | (c1 < 6)
2711
2712        let expr = bitwise_or(
2713            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2714            col("c2").gt(lit(5)),
2715        );
2716        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2717
2718        assert_eq!(simplify(expr), expected);
2719
2720        // (c2 > 5) | ((c2 > 5) | (c1 < 6)) --> (c2 > 5) | (c1 < 6)
2721
2722        let expr = bitwise_or(
2723            col("c2").gt(lit(5)),
2724            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2725        );
2726        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2727
2728        assert_eq!(simplify(expr), expected);
2729    }
2730
2731    #[test]
2732    fn test_simplify_composed_bitwise_xor() {
2733        // with an even number of the column "c2"
2734        // c2 ^ ((c2 ^ (c2 | c1)) ^ (c1 & c2)) --> (c2 | c1) ^ (c1 & c2)
2735
2736        let expr = bitwise_xor(
2737            col("c2"),
2738            bitwise_xor(
2739                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2740                bitwise_and(col("c1"), col("c2")),
2741            ),
2742        );
2743
2744        let expected = bitwise_xor(
2745            bitwise_or(col("c2"), col("c1")),
2746            bitwise_and(col("c1"), col("c2")),
2747        );
2748
2749        assert_eq!(simplify(expr), expected);
2750
2751        // with an odd number of the column "c2"
2752        // c2 ^ (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) --> c2 ^ ((c2 | c1) ^ (c1 & c2))
2753
2754        let expr = bitwise_xor(
2755            col("c2"),
2756            bitwise_xor(
2757                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2758                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2759            ),
2760        );
2761
2762        let expected = bitwise_xor(
2763            col("c2"),
2764            bitwise_xor(
2765                bitwise_or(col("c2"), col("c1")),
2766                bitwise_and(col("c1"), col("c2")),
2767            ),
2768        );
2769
2770        assert_eq!(simplify(expr), expected);
2771
2772        // with an even number of the column "c2"
2773        // ((c2 ^ (c2 | c1)) ^ (c1 & c2)) ^ c2 --> (c2 | c1) ^ (c1 & c2)
2774
2775        let expr = bitwise_xor(
2776            bitwise_xor(
2777                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2778                bitwise_and(col("c1"), col("c2")),
2779            ),
2780            col("c2"),
2781        );
2782
2783        let expected = bitwise_xor(
2784            bitwise_or(col("c2"), col("c1")),
2785            bitwise_and(col("c1"), col("c2")),
2786        );
2787
2788        assert_eq!(simplify(expr), expected);
2789
2790        // with an odd number of the column "c2"
2791        // (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) ^ c2 --> ((c2 | c1) ^ (c1 & c2)) ^ c2
2792
2793        let expr = bitwise_xor(
2794            bitwise_xor(
2795                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2796                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2797            ),
2798            col("c2"),
2799        );
2800
2801        let expected = bitwise_xor(
2802            bitwise_xor(
2803                bitwise_or(col("c2"), col("c1")),
2804                bitwise_and(col("c1"), col("c2")),
2805            ),
2806            col("c2"),
2807        );
2808
2809        assert_eq!(simplify(expr), expected);
2810    }
2811
2812    #[test]
2813    fn test_simplify_negated_bitwise_and() {
2814        // !c4 & c4 --> 0
2815        let expr = (-col("c4_non_null")) & col("c4_non_null");
2816        let expected = lit(0u32);
2817
2818        assert_eq!(simplify(expr), expected);
2819        // c4 & !c4 --> 0
2820        let expr = col("c4_non_null") & (-col("c4_non_null"));
2821        let expected = lit(0u32);
2822
2823        assert_eq!(simplify(expr), expected);
2824
2825        // !c3 & c3 --> 0
2826        let expr = (-col("c3_non_null")) & col("c3_non_null");
2827        let expected = lit(0i64);
2828
2829        assert_eq!(simplify(expr), expected);
2830        // c3 & !c3 --> 0
2831        let expr = col("c3_non_null") & (-col("c3_non_null"));
2832        let expected = lit(0i64);
2833
2834        assert_eq!(simplify(expr), expected);
2835    }
2836
2837    #[test]
2838    fn test_simplify_negated_bitwise_or() {
2839        // !c4 | c4 --> -1
2840        let expr = (-col("c4_non_null")) | col("c4_non_null");
2841        let expected = lit(-1i32);
2842
2843        assert_eq!(simplify(expr), expected);
2844
2845        // c4 | !c4 --> -1
2846        let expr = col("c4_non_null") | (-col("c4_non_null"));
2847        let expected = lit(-1i32);
2848
2849        assert_eq!(simplify(expr), expected);
2850
2851        // !c3 | c3 --> -1
2852        let expr = (-col("c3_non_null")) | col("c3_non_null");
2853        let expected = lit(-1i64);
2854
2855        assert_eq!(simplify(expr), expected);
2856
2857        // c3 | !c3 --> -1
2858        let expr = col("c3_non_null") | (-col("c3_non_null"));
2859        let expected = lit(-1i64);
2860
2861        assert_eq!(simplify(expr), expected);
2862    }
2863
2864    #[test]
2865    fn test_simplify_negated_bitwise_xor() {
2866        // !c4 ^ c4 --> -1
2867        let expr = (-col("c4_non_null")) ^ col("c4_non_null");
2868        let expected = lit(-1i32);
2869
2870        assert_eq!(simplify(expr), expected);
2871
2872        // c4 ^ !c4 --> -1
2873        let expr = col("c4_non_null") ^ (-col("c4_non_null"));
2874        let expected = lit(-1i32);
2875
2876        assert_eq!(simplify(expr), expected);
2877
2878        // !c3 ^ c3 --> -1
2879        let expr = (-col("c3_non_null")) ^ col("c3_non_null");
2880        let expected = lit(-1i64);
2881
2882        assert_eq!(simplify(expr), expected);
2883
2884        // c3 ^ !c3 --> -1
2885        let expr = col("c3_non_null") ^ (-col("c3_non_null"));
2886        let expected = lit(-1i64);
2887
2888        assert_eq!(simplify(expr), expected);
2889    }
2890
2891    #[test]
2892    fn test_simplify_bitwise_and_or() {
2893        // (c2 < 3) & ((c2 < 3) | c1) -> (c2 < 3)
2894        let expr = bitwise_and(
2895            col("c2_non_null").lt(lit(3)),
2896            bitwise_or(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2897        );
2898        let expected = col("c2_non_null").lt(lit(3));
2899
2900        assert_eq!(simplify(expr), expected);
2901    }
2902
2903    #[test]
2904    fn test_simplify_bitwise_or_and() {
2905        // (c2 < 3) | ((c2 < 3) & c1) -> (c2 < 3)
2906        let expr = bitwise_or(
2907            col("c2_non_null").lt(lit(3)),
2908            bitwise_and(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2909        );
2910        let expected = col("c2_non_null").lt(lit(3));
2911
2912        assert_eq!(simplify(expr), expected);
2913    }
2914
2915    #[test]
2916    fn test_simplify_simple_bitwise_and() {
2917        // (c2 > 5) & (c2 > 5) -> (c2 > 5)
2918        let expr = (col("c2").gt(lit(5))).bitand(col("c2").gt(lit(5)));
2919        let expected = col("c2").gt(lit(5));
2920
2921        assert_eq!(simplify(expr), expected);
2922    }
2923
2924    #[test]
2925    fn test_simplify_simple_bitwise_or() {
2926        // (c2 > 5) | (c2 > 5) -> (c2 > 5)
2927        let expr = (col("c2").gt(lit(5))).bitor(col("c2").gt(lit(5)));
2928        let expected = col("c2").gt(lit(5));
2929
2930        assert_eq!(simplify(expr), expected);
2931    }
2932
2933    #[test]
2934    fn test_simplify_simple_bitwise_xor() {
2935        // c4 ^ c4 -> 0
2936        let expr = (col("c4")).bitxor(col("c4"));
2937        let expected = lit(0u32);
2938
2939        assert_eq!(simplify(expr), expected);
2940
2941        // c3 ^ c3 -> 0
2942        let expr = col("c3").bitxor(col("c3"));
2943        let expected = lit(0i64);
2944
2945        assert_eq!(simplify(expr), expected);
2946    }
2947
2948    #[test]
2949    fn test_simplify_modulo_by_zero_non_null() {
2950        // because modulo by 0 maybe occur in short-circuit expression
2951        // so we should not simplify this, and throw error in runtime.
2952        let expr = col("c2_non_null") % lit(0);
2953        let expected = expr.clone();
2954
2955        assert_eq!(simplify(expr), expected);
2956    }
2957
2958    #[test]
2959    fn test_simplify_simple_and() {
2960        // (c2 > 5) AND (c2 > 5) -> (c2 > 5)
2961        let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5)));
2962        let expected = col("c2").gt(lit(5));
2963
2964        assert_eq!(simplify(expr), expected);
2965    }
2966
2967    #[test]
2968    fn test_simplify_composed_and() {
2969        // ((c2 > 5) AND (c1 < 6)) AND (c2 > 5)
2970        let expr = and(
2971            and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2972            col("c2").gt(lit(5)),
2973        );
2974        let expected = and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2975
2976        assert_eq!(simplify(expr), expected);
2977    }
2978
2979    #[test]
2980    fn test_simplify_negated_and() {
2981        // (c2 > 5) AND !(c2 > 5) --> (c2 > 5) AND (c2 <= 5)
2982        let expr = and(col("c2").gt(lit(5)), Expr::not(col("c2").gt(lit(5))));
2983        let expected = col("c2").gt(lit(5)).and(col("c2").lt_eq(lit(5)));
2984
2985        assert_eq!(simplify(expr), expected);
2986    }
2987
2988    #[test]
2989    fn test_simplify_or_and() {
2990        let l = col("c2").gt(lit(5));
2991        let r = and(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
2992
2993        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5))
2994        let expr = or(l.clone(), r.clone());
2995
2996        let expected = l.clone();
2997        assert_eq!(simplify(expr), expected);
2998
2999        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5)
3000        let expr = or(r, l);
3001        assert_eq!(simplify(expr), expected);
3002    }
3003
3004    #[test]
3005    fn test_simplify_or_and_non_null() {
3006        let l = col("c2_non_null").gt(lit(5));
3007        let r = and(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
3008
3009        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) --> c2 > 5
3010        let expr = or(l.clone(), r.clone());
3011
3012        // This is only true if `c1 < 6` is not nullable / can not be null.
3013        let expected = col("c2_non_null").gt(lit(5));
3014
3015        assert_eq!(simplify(expr), expected);
3016
3017        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) --> c2 > 5
3018        let expr = or(l, r);
3019
3020        assert_eq!(simplify(expr), expected);
3021    }
3022
3023    #[test]
3024    fn test_simplify_and_or() {
3025        let l = col("c2").gt(lit(5));
3026        let r = or(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
3027
3028        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
3029        let expr = and(l.clone(), r.clone());
3030
3031        let expected = l.clone();
3032        assert_eq!(simplify(expr), expected);
3033
3034        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
3035        let expr = and(r, l);
3036        assert_eq!(simplify(expr), expected);
3037    }
3038
3039    #[test]
3040    fn test_simplify_and_or_non_null() {
3041        let l = col("c2_non_null").gt(lit(5));
3042        let r = or(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
3043
3044        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
3045        let expr = and(l.clone(), r.clone());
3046
3047        // This is only true if `c1 < 6` is not nullable / can not be null.
3048        let expected = col("c2_non_null").gt(lit(5));
3049
3050        assert_eq!(simplify(expr), expected);
3051
3052        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
3053        let expr = and(l, r);
3054
3055        assert_eq!(simplify(expr), expected);
3056    }
3057
3058    #[test]
3059    fn test_simplify_by_de_morgan_laws() {
3060        // Laws with logical operations
3061        // !(c3 AND c4) --> !c3 OR !c4
3062        let expr = and(col("c3"), col("c4")).not();
3063        let expected = or(col("c3").not(), col("c4").not());
3064        assert_eq!(simplify(expr), expected);
3065        // !(c3 OR c4) --> !c3 AND !c4
3066        let expr = or(col("c3"), col("c4")).not();
3067        let expected = and(col("c3").not(), col("c4").not());
3068        assert_eq!(simplify(expr), expected);
3069        // !(!c3) --> c3
3070        let expr = col("c3").not().not();
3071        let expected = col("c3");
3072        assert_eq!(simplify(expr), expected);
3073
3074        // Laws with bitwise operations
3075        // !(c3 & c4) --> !c3 | !c4
3076        let expr = -bitwise_and(col("c3"), col("c4"));
3077        let expected = bitwise_or(-col("c3"), -col("c4"));
3078        assert_eq!(simplify(expr), expected);
3079        // !(c3 | c4) --> !c3 & !c4
3080        let expr = -bitwise_or(col("c3"), col("c4"));
3081        let expected = bitwise_and(-col("c3"), -col("c4"));
3082        assert_eq!(simplify(expr), expected);
3083        // !(!c3) --> c3
3084        let expr = -(-col("c3"));
3085        let expected = col("c3");
3086        assert_eq!(simplify(expr), expected);
3087    }
3088
3089    #[test]
3090    fn test_simplify_null_and_false() {
3091        let expr = and(lit_bool_null(), lit(false));
3092        let expr_eq = lit(false);
3093
3094        assert_eq!(simplify(expr), expr_eq);
3095    }
3096
3097    #[test]
3098    fn test_simplify_divide_null_by_null() {
3099        let null = lit(ScalarValue::Int32(None));
3100        let expr_plus = null.clone() / null.clone();
3101        let expr_eq = null;
3102
3103        assert_eq!(simplify(expr_plus), expr_eq);
3104    }
3105
3106    #[test]
3107    fn test_simplify_simplify_arithmetic_expr() {
3108        let expr_plus = lit(1) + lit(1);
3109
3110        assert_eq!(simplify(expr_plus), lit(2));
3111    }
3112
3113    #[test]
3114    fn test_simplify_simplify_eq_expr() {
3115        let expr_eq = binary_expr(lit(1), Operator::Eq, lit(1));
3116
3117        assert_eq!(simplify(expr_eq), lit(true));
3118    }
3119
3120    #[test]
3121    fn test_simplify_regex() {
3122        // malformed regex
3123        assert_contains!(
3124            try_simplify(regex_match(col("c1"), lit("foo{")))
3125                .unwrap_err()
3126                .to_string(),
3127            "regex parse error"
3128        );
3129
3130        // unsupported cases
3131        assert_no_change(regex_match(col("c1"), lit("foo.*")));
3132        assert_no_change(regex_match(col("c1"), lit("(foo)")));
3133        assert_no_change(regex_match(col("c1"), lit("%")));
3134        assert_no_change(regex_match(col("c1"), lit("_")));
3135        assert_no_change(regex_match(col("c1"), lit("f%o")));
3136        assert_no_change(regex_match(col("c1"), lit("^f%o")));
3137        assert_no_change(regex_match(col("c1"), lit("f_o")));
3138
3139        // empty cases
3140        assert_change(
3141            regex_match(col("c1"), lit("")),
3142            if_not_null(col("c1"), true),
3143        );
3144        assert_change(
3145            regex_not_match(col("c1"), lit("")),
3146            if_not_null(col("c1"), false),
3147        );
3148        assert_change(
3149            regex_imatch(col("c1"), lit("")),
3150            if_not_null(col("c1"), true),
3151        );
3152        assert_change(
3153            regex_not_imatch(col("c1"), lit("")),
3154            if_not_null(col("c1"), false),
3155        );
3156
3157        // single character
3158        assert_change(regex_match(col("c1"), lit("x")), col("c1").like(lit("%x%")));
3159
3160        // single word
3161        assert_change(
3162            regex_match(col("c1"), lit("foo")),
3163            col("c1").like(lit("%foo%")),
3164        );
3165
3166        // regular expressions that match an exact literal
3167        assert_change(regex_match(col("c1"), lit("^$")), col("c1").eq(lit("")));
3168        assert_change(
3169            regex_not_match(col("c1"), lit("^$")),
3170            col("c1").not_eq(lit("")),
3171        );
3172        assert_change(
3173            regex_match(col("c1"), lit("^foo$")),
3174            col("c1").eq(lit("foo")),
3175        );
3176        assert_change(
3177            regex_not_match(col("c1"), lit("^foo$")),
3178            col("c1").not_eq(lit("foo")),
3179        );
3180
3181        // regular expressions that match exact captured literals
3182        assert_change(
3183            regex_match(col("c1"), lit("^(foo|bar)$")),
3184            col("c1").eq(lit("foo")).or(col("c1").eq(lit("bar"))),
3185        );
3186        assert_change(
3187            regex_not_match(col("c1"), lit("^(foo|bar)$")),
3188            col("c1")
3189                .not_eq(lit("foo"))
3190                .and(col("c1").not_eq(lit("bar"))),
3191        );
3192        assert_change(
3193            regex_match(col("c1"), lit("^(foo)$")),
3194            col("c1").eq(lit("foo")),
3195        );
3196        assert_change(
3197            regex_match(col("c1"), lit("^(foo|bar|baz)$")),
3198            ((col("c1").eq(lit("foo"))).or(col("c1").eq(lit("bar"))))
3199                .or(col("c1").eq(lit("baz"))),
3200        );
3201        assert_change(
3202            regex_match(col("c1"), lit("^(foo|bar|baz|qux)$")),
3203            col("c1")
3204                .in_list(vec![lit("foo"), lit("bar"), lit("baz"), lit("qux")], false),
3205        );
3206        assert_change(
3207            regex_match(col("c1"), lit("^(fo_o)$")),
3208            col("c1").eq(lit("fo_o")),
3209        );
3210        assert_change(
3211            regex_match(col("c1"), lit("^(fo_o)$")),
3212            col("c1").eq(lit("fo_o")),
3213        );
3214        assert_change(
3215            regex_match(col("c1"), lit("^(fo_o|ba_r)$")),
3216            col("c1").eq(lit("fo_o")).or(col("c1").eq(lit("ba_r"))),
3217        );
3218        assert_change(
3219            regex_not_match(col("c1"), lit("^(fo_o|ba_r)$")),
3220            col("c1")
3221                .not_eq(lit("fo_o"))
3222                .and(col("c1").not_eq(lit("ba_r"))),
3223        );
3224        assert_change(
3225            regex_match(col("c1"), lit("^(fo_o|ba_r|ba_z)$")),
3226            ((col("c1").eq(lit("fo_o"))).or(col("c1").eq(lit("ba_r"))))
3227                .or(col("c1").eq(lit("ba_z"))),
3228        );
3229        assert_change(
3230            regex_match(col("c1"), lit("^(fo_o|ba_r|baz|qu_x)$")),
3231            col("c1").in_list(
3232                vec![lit("fo_o"), lit("ba_r"), lit("baz"), lit("qu_x")],
3233                false,
3234            ),
3235        );
3236
3237        // regular expressions that mismatch captured literals
3238        assert_no_change(regex_match(col("c1"), lit("(foo|bar)")));
3239        assert_no_change(regex_match(col("c1"), lit("(foo|bar)*")));
3240        assert_no_change(regex_match(col("c1"), lit("(fo_o|b_ar)")));
3241        assert_no_change(regex_match(col("c1"), lit("(foo|ba_r)*")));
3242        assert_no_change(regex_match(col("c1"), lit("(fo_o|ba_r)*")));
3243        assert_no_change(regex_match(col("c1"), lit("^(foo|bar)*")));
3244        assert_no_change(regex_match(col("c1"), lit("^(foo)(bar)$")));
3245        assert_no_change(regex_match(col("c1"), lit("^")));
3246        assert_no_change(regex_match(col("c1"), lit("$")));
3247        assert_no_change(regex_match(col("c1"), lit("$^")));
3248        assert_no_change(regex_match(col("c1"), lit("$foo^")));
3249
3250        // regular expressions that match a partial literal
3251        assert_change(
3252            regex_match(col("c1"), lit("^foo")),
3253            col("c1").like(lit("foo%")),
3254        );
3255        assert_change(
3256            regex_match(col("c1"), lit("foo$")),
3257            col("c1").like(lit("%foo")),
3258        );
3259        assert_change(
3260            regex_match(col("c1"), lit("^foo|bar$")),
3261            col("c1").like(lit("foo%")).or(col("c1").like(lit("%bar"))),
3262        );
3263
3264        // OR-chain
3265        assert_change(
3266            regex_match(col("c1"), lit("foo|bar|baz")),
3267            col("c1")
3268                .like(lit("%foo%"))
3269                .or(col("c1").like(lit("%bar%")))
3270                .or(col("c1").like(lit("%baz%"))),
3271        );
3272        assert_change(
3273            regex_match(col("c1"), lit("foo|x|baz")),
3274            col("c1")
3275                .like(lit("%foo%"))
3276                .or(col("c1").like(lit("%x%")))
3277                .or(col("c1").like(lit("%baz%"))),
3278        );
3279        assert_change(
3280            regex_not_match(col("c1"), lit("foo|bar|baz")),
3281            col("c1")
3282                .not_like(lit("%foo%"))
3283                .and(col("c1").not_like(lit("%bar%")))
3284                .and(col("c1").not_like(lit("%baz%"))),
3285        );
3286        // both anchored expressions (translated to equality) and unanchored
3287        assert_change(
3288            regex_match(col("c1"), lit("foo|^x$|baz")),
3289            col("c1")
3290                .like(lit("%foo%"))
3291                .or(col("c1").eq(lit("x")))
3292                .or(col("c1").like(lit("%baz%"))),
3293        );
3294        assert_change(
3295            regex_not_match(col("c1"), lit("foo|^bar$|baz")),
3296            col("c1")
3297                .not_like(lit("%foo%"))
3298                .and(col("c1").not_eq(lit("bar")))
3299                .and(col("c1").not_like(lit("%baz%"))),
3300        );
3301        // Too many patterns (MAX_REGEX_ALTERNATIONS_EXPANSION)
3302        assert_no_change(regex_match(col("c1"), lit("foo|bar|baz|blarg|bozo|etc")));
3303    }
3304
3305    #[track_caller]
3306    fn assert_no_change(expr: Expr) {
3307        let optimized = simplify(expr.clone());
3308        assert_eq!(expr, optimized);
3309    }
3310
3311    #[track_caller]
3312    fn assert_change(expr: Expr, expected: Expr) {
3313        let optimized = simplify(expr);
3314        assert_eq!(optimized, expected);
3315    }
3316
3317    fn regex_match(left: Expr, right: Expr) -> Expr {
3318        Expr::BinaryExpr(BinaryExpr {
3319            left: Box::new(left),
3320            op: Operator::RegexMatch,
3321            right: Box::new(right),
3322        })
3323    }
3324
3325    fn regex_not_match(left: Expr, right: Expr) -> Expr {
3326        Expr::BinaryExpr(BinaryExpr {
3327            left: Box::new(left),
3328            op: Operator::RegexNotMatch,
3329            right: Box::new(right),
3330        })
3331    }
3332
3333    fn regex_imatch(left: Expr, right: Expr) -> Expr {
3334        Expr::BinaryExpr(BinaryExpr {
3335            left: Box::new(left),
3336            op: Operator::RegexIMatch,
3337            right: Box::new(right),
3338        })
3339    }
3340
3341    fn regex_not_imatch(left: Expr, right: Expr) -> Expr {
3342        Expr::BinaryExpr(BinaryExpr {
3343            left: Box::new(left),
3344            op: Operator::RegexNotIMatch,
3345            right: Box::new(right),
3346        })
3347    }
3348
3349    // ------------------------------
3350    // ----- Simplifier tests -------
3351    // ------------------------------
3352
3353    fn try_simplify(expr: Expr) -> Result<Expr> {
3354        let schema = expr_test_schema();
3355        let execution_props = ExecutionProps::new();
3356        let simplifier = ExprSimplifier::new(
3357            SimplifyContext::new(&execution_props).with_schema(schema),
3358        );
3359        simplifier.simplify(expr)
3360    }
3361
3362    fn coerce(expr: Expr) -> Expr {
3363        let schema = expr_test_schema();
3364        let execution_props = ExecutionProps::new();
3365        let simplifier = ExprSimplifier::new(
3366            SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)),
3367        );
3368        simplifier.coerce(expr, schema.as_ref()).unwrap()
3369    }
3370
3371    fn simplify(expr: Expr) -> Expr {
3372        try_simplify(expr).unwrap()
3373    }
3374
3375    fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
3376        let schema = expr_test_schema();
3377        let execution_props = ExecutionProps::new();
3378        let simplifier = ExprSimplifier::new(
3379            SimplifyContext::new(&execution_props).with_schema(schema),
3380        );
3381        let (expr, count) = simplifier.simplify_with_cycle_count_transformed(expr)?;
3382        Ok((expr.data, count))
3383    }
3384
3385    fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
3386        try_simplify_with_cycle_count(expr).unwrap()
3387    }
3388
3389    fn simplify_with_guarantee(
3390        expr: Expr,
3391        guarantees: Vec<(Expr, NullableInterval)>,
3392    ) -> Expr {
3393        let schema = expr_test_schema();
3394        let execution_props = ExecutionProps::new();
3395        let simplifier = ExprSimplifier::new(
3396            SimplifyContext::new(&execution_props).with_schema(schema),
3397        )
3398        .with_guarantees(guarantees);
3399        simplifier.simplify(expr).unwrap()
3400    }
3401
3402    fn expr_test_schema() -> DFSchemaRef {
3403        static EXPR_TEST_SCHEMA: LazyLock<DFSchemaRef> = LazyLock::new(|| {
3404            Arc::new(
3405                DFSchema::from_unqualified_fields(
3406                    vec![
3407                        Field::new("c1", DataType::Utf8, true),
3408                        Field::new("c2", DataType::Boolean, true),
3409                        Field::new("c3", DataType::Int64, true),
3410                        Field::new("c4", DataType::UInt32, true),
3411                        Field::new("c1_non_null", DataType::Utf8, false),
3412                        Field::new("c2_non_null", DataType::Boolean, false),
3413                        Field::new("c3_non_null", DataType::Int64, false),
3414                        Field::new("c4_non_null", DataType::UInt32, false),
3415                        Field::new("c5", DataType::FixedSizeBinary(3), true),
3416                    ]
3417                    .into(),
3418                    HashMap::new(),
3419                )
3420                .unwrap(),
3421            )
3422        });
3423        Arc::clone(&EXPR_TEST_SCHEMA)
3424    }
3425
3426    #[test]
3427    fn simplify_expr_null_comparison() {
3428        // x = null is always null
3429        assert_eq!(
3430            simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))),
3431            lit(ScalarValue::Boolean(None)),
3432        );
3433
3434        // null != null is always null
3435        assert_eq!(
3436            simplify(
3437                lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))
3438            ),
3439            lit(ScalarValue::Boolean(None)),
3440        );
3441
3442        // x != null is always null
3443        assert_eq!(
3444            simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))),
3445            lit(ScalarValue::Boolean(None)),
3446        );
3447
3448        // null = x is always null
3449        assert_eq!(
3450            simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))),
3451            lit(ScalarValue::Boolean(None)),
3452        );
3453    }
3454
3455    #[test]
3456    fn simplify_expr_is_not_null() {
3457        assert_eq!(
3458            simplify(Expr::IsNotNull(Box::new(col("c1")))),
3459            Expr::IsNotNull(Box::new(col("c1")))
3460        );
3461
3462        // 'c1_non_null IS NOT NULL' is always true
3463        assert_eq!(
3464            simplify(Expr::IsNotNull(Box::new(col("c1_non_null")))),
3465            lit(true)
3466        );
3467    }
3468
3469    #[test]
3470    fn simplify_expr_is_null() {
3471        assert_eq!(
3472            simplify(Expr::IsNull(Box::new(col("c1")))),
3473            Expr::IsNull(Box::new(col("c1")))
3474        );
3475
3476        // 'c1_non_null IS NULL' is always false
3477        assert_eq!(
3478            simplify(Expr::IsNull(Box::new(col("c1_non_null")))),
3479            lit(false)
3480        );
3481    }
3482
3483    #[test]
3484    fn simplify_expr_is_unknown() {
3485        assert_eq!(simplify(col("c2").is_unknown()), col("c2").is_unknown(),);
3486
3487        // 'c2_non_null is unknown' is always false
3488        assert_eq!(simplify(col("c2_non_null").is_unknown()), lit(false));
3489    }
3490
3491    #[test]
3492    fn simplify_expr_is_not_known() {
3493        assert_eq!(
3494            simplify(col("c2").is_not_unknown()),
3495            col("c2").is_not_unknown()
3496        );
3497
3498        // 'c2_non_null is not unknown' is always true
3499        assert_eq!(simplify(col("c2_non_null").is_not_unknown()), lit(true));
3500    }
3501
3502    #[test]
3503    fn simplify_expr_eq() {
3504        let schema = expr_test_schema();
3505        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3506
3507        // true = true -> true
3508        assert_eq!(simplify(lit(true).eq(lit(true))), lit(true));
3509
3510        // true = false -> false
3511        assert_eq!(simplify(lit(true).eq(lit(false))), lit(false),);
3512
3513        // c2 = true -> c2
3514        assert_eq!(simplify(col("c2").eq(lit(true))), col("c2"));
3515
3516        // c2 = false => !c2
3517        assert_eq!(simplify(col("c2").eq(lit(false))), col("c2").not(),);
3518    }
3519
3520    #[test]
3521    fn simplify_expr_eq_skip_nonboolean_type() {
3522        let schema = expr_test_schema();
3523
3524        // When one of the operand is not of boolean type, folding the
3525        // other boolean constant will change return type of
3526        // expression to non-boolean.
3527        //
3528        // Make sure c1 column to be used in tests is not boolean type
3529        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3530
3531        // don't fold c1 = foo
3532        assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),);
3533    }
3534
3535    #[test]
3536    fn simplify_expr_not_eq() {
3537        let schema = expr_test_schema();
3538
3539        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3540
3541        // c2 != true -> !c2
3542        assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),);
3543
3544        // c2 != false -> c2
3545        assert_eq!(simplify(col("c2").not_eq(lit(false))), col("c2"),);
3546
3547        // test constant
3548        assert_eq!(simplify(lit(true).not_eq(lit(true))), lit(false),);
3549
3550        assert_eq!(simplify(lit(true).not_eq(lit(false))), lit(true),);
3551    }
3552
3553    #[test]
3554    fn simplify_expr_not_eq_skip_nonboolean_type() {
3555        let schema = expr_test_schema();
3556
3557        // when one of the operand is not of boolean type, folding the
3558        // other boolean constant will change return type of
3559        // expression to non-boolean.
3560        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3561
3562        assert_eq!(
3563            simplify(col("c1").not_eq(lit("foo"))),
3564            col("c1").not_eq(lit("foo")),
3565        );
3566    }
3567
3568    #[test]
3569    fn simplify_literal_case_equality() {
3570        // CASE WHEN c2 != false THEN "ok" ELSE "not_ok"
3571        let simple_case = Expr::Case(Case::new(
3572            None,
3573            vec![(
3574                Box::new(col("c2_non_null").not_eq(lit(false))),
3575                Box::new(lit("ok")),
3576            )],
3577            Some(Box::new(lit("not_ok"))),
3578        ));
3579
3580        // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" == "ok"
3581        // -->
3582        // CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok"
3583        // -->
3584        // CASE WHEN c2 != false THEN true ELSE false
3585        // -->
3586        // c2
3587        assert_eq!(
3588            simplify(binary_expr(simple_case.clone(), Operator::Eq, lit("ok"),)),
3589            col("c2_non_null"),
3590        );
3591
3592        // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" != "ok"
3593        // -->
3594        // NOT(CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok")
3595        // -->
3596        // NOT(CASE WHEN c2 != false THEN true ELSE false)
3597        // -->
3598        // NOT(c2)
3599        assert_eq!(
3600            simplify(binary_expr(simple_case, Operator::NotEq, lit("ok"),)),
3601            not(col("c2_non_null")),
3602        );
3603
3604        let complex_case = Expr::Case(Case::new(
3605            None,
3606            vec![
3607                (
3608                    Box::new(col("c1").eq(lit("inboxed"))),
3609                    Box::new(lit("pending")),
3610                ),
3611                (
3612                    Box::new(col("c1").eq(lit("scheduled"))),
3613                    Box::new(lit("pending")),
3614                ),
3615                (
3616                    Box::new(col("c1").eq(lit("completed"))),
3617                    Box::new(lit("completed")),
3618                ),
3619                (
3620                    Box::new(col("c1").eq(lit("paused"))),
3621                    Box::new(lit("paused")),
3622                ),
3623                (Box::new(col("c2")), Box::new(lit("running"))),
3624                (
3625                    Box::new(col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0)))),
3626                    Box::new(lit("backing-off")),
3627                ),
3628            ],
3629            Some(Box::new(lit("ready"))),
3630        ));
3631
3632        assert_eq!(
3633            simplify(binary_expr(
3634                complex_case.clone(),
3635                Operator::Eq,
3636                lit("completed"),
3637            )),
3638            not_distinct_from(col("c1").eq(lit("completed")), lit(true)).and(
3639                distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3640                    .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
3641            )
3642        );
3643
3644        assert_eq!(
3645            simplify(binary_expr(
3646                complex_case.clone(),
3647                Operator::NotEq,
3648                lit("completed"),
3649            )),
3650            distinct_from(col("c1").eq(lit("completed")), lit(true))
3651                .or(not_distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3652                    .or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true))))
3653        );
3654
3655        assert_eq!(
3656            simplify(binary_expr(
3657                complex_case.clone(),
3658                Operator::Eq,
3659                lit("running"),
3660            )),
3661            not_distinct_from(col("c2"), lit(true)).and(
3662                distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3663                    .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
3664                    .and(distinct_from(col("c1").eq(lit("completed")), lit(true)))
3665                    .and(distinct_from(col("c1").eq(lit("paused")), lit(true)))
3666            )
3667        );
3668
3669        assert_eq!(
3670            simplify(binary_expr(
3671                complex_case.clone(),
3672                Operator::Eq,
3673                lit("ready"),
3674            )),
3675            distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3676                .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
3677                .and(distinct_from(col("c1").eq(lit("completed")), lit(true)))
3678                .and(distinct_from(col("c1").eq(lit("paused")), lit(true)))
3679                .and(distinct_from(col("c2"), lit(true)))
3680                .and(distinct_from(
3681                    col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))),
3682                    lit(true)
3683                ))
3684        );
3685
3686        assert_eq!(
3687            simplify(binary_expr(
3688                complex_case.clone(),
3689                Operator::NotEq,
3690                lit("ready"),
3691            )),
3692            not_distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3693                .or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
3694                .or(not_distinct_from(col("c1").eq(lit("completed")), lit(true)))
3695                .or(not_distinct_from(col("c1").eq(lit("paused")), lit(true)))
3696                .or(not_distinct_from(col("c2"), lit(true)))
3697                .or(not_distinct_from(
3698                    col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))),
3699                    lit(true)
3700                ))
3701        );
3702    }
3703
3704    #[test]
3705    fn simplify_expr_case_when_then_else() {
3706        // CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true
3707        // -->
3708        // CASE WHEN c2 THEN false ELSE c2
3709        // -->
3710        // false
3711        assert_eq!(
3712            simplify(Expr::Case(Case::new(
3713                None,
3714                vec![(
3715                    Box::new(col("c2_non_null").not_eq(lit(false))),
3716                    Box::new(lit("ok").eq(lit("not_ok"))),
3717                )],
3718                Some(Box::new(col("c2_non_null").eq(lit(true)))),
3719            ))),
3720            lit(false) // #1716
3721        );
3722
3723        // CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2
3724        // -->
3725        // CASE WHEN c2 THEN true ELSE c2
3726        // -->
3727        // c2
3728        //
3729        // Need to call simplify 2x due to
3730        // https://github.com/apache/datafusion/issues/1160
3731        assert_eq!(
3732            simplify(simplify(Expr::Case(Case::new(
3733                None,
3734                vec![(
3735                    Box::new(col("c2_non_null").not_eq(lit(false))),
3736                    Box::new(lit("ok").eq(lit("ok"))),
3737                )],
3738                Some(Box::new(col("c2_non_null").eq(lit(true)))),
3739            )))),
3740            col("c2_non_null")
3741        );
3742
3743        // CASE WHEN ISNULL(c2) THEN true ELSE c2
3744        // -->
3745        // ISNULL(c2) OR c2
3746        //
3747        // Need to call simplify 2x due to
3748        // https://github.com/apache/datafusion/issues/1160
3749        assert_eq!(
3750            simplify(simplify(Expr::Case(Case::new(
3751                None,
3752                vec![(Box::new(col("c2").is_null()), Box::new(lit(true)),)],
3753                Some(Box::new(col("c2"))),
3754            )))),
3755            col("c2")
3756                .is_null()
3757                .or(col("c2").is_not_null().and(col("c2")))
3758        );
3759
3760        // CASE WHEN c1 then true WHEN c2 then false ELSE true
3761        // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE)
3762        // --> c1 OR (NOT(c1) AND NOT(c2))
3763        // --> c1 OR NOT(c2)
3764        //
3765        // Need to call simplify 2x due to
3766        // https://github.com/apache/datafusion/issues/1160
3767        assert_eq!(
3768            simplify(simplify(Expr::Case(Case::new(
3769                None,
3770                vec![
3771                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
3772                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
3773                ],
3774                Some(Box::new(lit(true))),
3775            )))),
3776            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
3777        );
3778
3779        // CASE WHEN c1 then true WHEN c2 then true ELSE false
3780        // --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE)
3781        // --> c1 OR (NOT(c1) AND c2)
3782        // --> c1 OR c2
3783        //
3784        // Need to call simplify 2x due to
3785        // https://github.com/apache/datafusion/issues/1160
3786        assert_eq!(
3787            simplify(simplify(Expr::Case(Case::new(
3788                None,
3789                vec![
3790                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
3791                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
3792                ],
3793                Some(Box::new(lit(true))),
3794            )))),
3795            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
3796        );
3797
3798        // CASE WHEN c > 0 THEN true END AS c1
3799        assert_eq!(
3800            simplify(simplify(Expr::Case(Case::new(
3801                None,
3802                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3803                None,
3804            )))),
3805            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
3806                col("c3").gt(lit(0_i64)),
3807                lit(true)
3808            )
3809            .and(lit_bool_null()))
3810        );
3811
3812        // CASE WHEN c > 0 THEN true ELSE false END AS c1
3813        assert_eq!(
3814            simplify(simplify(Expr::Case(Case::new(
3815                None,
3816                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3817                Some(Box::new(lit(false))),
3818            )))),
3819            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
3820        );
3821    }
3822
3823    #[test]
3824    fn simplify_expr_case_when_first_true() {
3825        // CASE WHEN true THEN 1 ELSE c1 END --> 1
3826        assert_eq!(
3827            simplify(Expr::Case(Case::new(
3828                None,
3829                vec![(Box::new(lit(true)), Box::new(lit(1)),)],
3830                Some(Box::new(col("c1"))),
3831            ))),
3832            lit(1)
3833        );
3834
3835        // CASE WHEN true THEN col('a') ELSE col('b') END --> col('a')
3836        assert_eq!(
3837            simplify(Expr::Case(Case::new(
3838                None,
3839                vec![(Box::new(lit(true)), Box::new(lit("a")),)],
3840                Some(Box::new(lit("b"))),
3841            ))),
3842            lit("a")
3843        );
3844
3845        // CASE WHEN true THEN col('a') WHEN col('x') > 5 THEN col('b') ELSE col('c') END --> col('a')
3846        assert_eq!(
3847            simplify(Expr::Case(Case::new(
3848                None,
3849                vec![
3850                    (Box::new(lit(true)), Box::new(lit("a"))),
3851                    (Box::new(lit("x").gt(lit(5))), Box::new(lit("b"))),
3852                ],
3853                Some(Box::new(lit("c"))),
3854            ))),
3855            lit("a")
3856        );
3857
3858        // CASE WHEN true THEN col('a') END --> col('a') (no else clause)
3859        assert_eq!(
3860            simplify(Expr::Case(Case::new(
3861                None,
3862                vec![(Box::new(lit(true)), Box::new(lit("a")),)],
3863                None,
3864            ))),
3865            lit("a")
3866        );
3867
3868        // Negative test: CASE WHEN c2 THEN 1 ELSE 2 END should not be simplified
3869        let expr = Expr::Case(Case::new(
3870            None,
3871            vec![(Box::new(col("c2")), Box::new(lit(1)))],
3872            Some(Box::new(lit(2))),
3873        ));
3874        assert_eq!(simplify(expr.clone()), expr);
3875
3876        // Negative test: CASE WHEN false THEN 1 ELSE 2 END should not use this rule
3877        let expr = Expr::Case(Case::new(
3878            None,
3879            vec![(Box::new(lit(false)), Box::new(lit(1)))],
3880            Some(Box::new(lit(2))),
3881        ));
3882        assert_ne!(simplify(expr), lit(1));
3883
3884        // Negative test: CASE WHEN col('c1') > 5 THEN 1 ELSE 2 END should not be simplified
3885        let expr = Expr::Case(Case::new(
3886            None,
3887            vec![(Box::new(col("c1").gt(lit(5))), Box::new(lit(1)))],
3888            Some(Box::new(lit(2))),
3889        ));
3890        assert_eq!(simplify(expr.clone()), expr);
3891    }
3892
3893    #[test]
3894    fn simplify_expr_case_when_any_true() {
3895        // CASE WHEN c3 > 0 THEN 'a' WHEN true THEN 'b' ELSE 'c' END --> CASE WHEN c3 > 0 THEN 'a' ELSE 'b' END
3896        assert_eq!(
3897            simplify(Expr::Case(Case::new(
3898                None,
3899                vec![
3900                    (Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))),
3901                    (Box::new(lit(true)), Box::new(lit("b"))),
3902                ],
3903                Some(Box::new(lit("c"))),
3904            ))),
3905            Expr::Case(Case::new(
3906                None,
3907                vec![(Box::new(col("c3").gt(lit(0))), Box::new(lit("a")))],
3908                Some(Box::new(lit("b"))),
3909            ))
3910        );
3911
3912        // CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' WHEN true THEN 'c' WHEN c3 = 0 THEN 'd' ELSE 'e' END
3913        // --> CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' ELSE 'c' END
3914        assert_eq!(
3915            simplify(Expr::Case(Case::new(
3916                None,
3917                vec![
3918                    (Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))),
3919                    (Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))),
3920                    (Box::new(lit(true)), Box::new(lit("c"))),
3921                    (Box::new(col("c3").eq(lit(0))), Box::new(lit("d"))),
3922                ],
3923                Some(Box::new(lit("e"))),
3924            ))),
3925            Expr::Case(Case::new(
3926                None,
3927                vec![
3928                    (Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))),
3929                    (Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))),
3930                ],
3931                Some(Box::new(lit("c"))),
3932            ))
3933        );
3934
3935        // CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 WHEN true THEN 3 END (no else)
3936        // --> CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 ELSE 3 END
3937        assert_eq!(
3938            simplify(Expr::Case(Case::new(
3939                None,
3940                vec![
3941                    (Box::new(col("c3").gt(lit(0))), Box::new(lit(1))),
3942                    (Box::new(col("c4").lt(lit(0))), Box::new(lit(2))),
3943                    (Box::new(lit(true)), Box::new(lit(3))),
3944                ],
3945                None,
3946            ))),
3947            Expr::Case(Case::new(
3948                None,
3949                vec![
3950                    (Box::new(col("c3").gt(lit(0))), Box::new(lit(1))),
3951                    (Box::new(col("c4").lt(lit(0))), Box::new(lit(2))),
3952                ],
3953                Some(Box::new(lit(3))),
3954            ))
3955        );
3956
3957        // Negative test: CASE WHEN c3 > 0 THEN c3 WHEN c4 < 0 THEN 2 ELSE 3 END should not be simplified
3958        let expr = Expr::Case(Case::new(
3959            None,
3960            vec![
3961                (Box::new(col("c3").gt(lit(0))), Box::new(col("c3"))),
3962                (Box::new(col("c4").lt(lit(0))), Box::new(lit(2))),
3963            ],
3964            Some(Box::new(lit(3))),
3965        ));
3966        assert_eq!(simplify(expr.clone()), expr);
3967    }
3968
3969    #[test]
3970    fn simplify_expr_case_when_any_false() {
3971        // CASE WHEN false THEN 'a' END --> NULL
3972        assert_eq!(
3973            simplify(Expr::Case(Case::new(
3974                None,
3975                vec![(Box::new(lit(false)), Box::new(lit("a")))],
3976                None,
3977            ))),
3978            Expr::Literal(ScalarValue::Utf8(None), None)
3979        );
3980
3981        // CASE WHEN false THEN 2 ELSE 1 END --> 1
3982        assert_eq!(
3983            simplify(Expr::Case(Case::new(
3984                None,
3985                vec![(Box::new(lit(false)), Box::new(lit(2)))],
3986                Some(Box::new(lit(1))),
3987            ))),
3988            lit(1),
3989        );
3990
3991        // CASE WHEN c3 < 10 THEN 'b' WHEN false then c3 ELSE c4 END --> CASE WHEN c3 < 10 THEN b ELSE c4 END
3992        assert_eq!(
3993            simplify(Expr::Case(Case::new(
3994                None,
3995                vec![
3996                    (Box::new(col("c3").lt(lit(10))), Box::new(lit("b"))),
3997                    (Box::new(lit(false)), Box::new(col("c3"))),
3998                ],
3999                Some(Box::new(col("c4"))),
4000            ))),
4001            Expr::Case(Case::new(
4002                None,
4003                vec![(Box::new(col("c3").lt(lit(10))), Box::new(lit("b")))],
4004                Some(Box::new(col("c4"))),
4005            ))
4006        );
4007
4008        // Negative test: CASE WHEN c3 = 4 THEN 1 ELSE 2 END should not be simplified
4009        let expr = Expr::Case(Case::new(
4010            None,
4011            vec![(Box::new(col("c3").eq(lit(4))), Box::new(lit(1)))],
4012            Some(Box::new(lit(2))),
4013        ));
4014        assert_eq!(simplify(expr.clone()), expr);
4015    }
4016
4017    fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
4018        Expr::BinaryExpr(BinaryExpr {
4019            left: Box::new(left.into()),
4020            op: Operator::IsDistinctFrom,
4021            right: Box::new(right.into()),
4022        })
4023    }
4024
4025    fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
4026        Expr::BinaryExpr(BinaryExpr {
4027            left: Box::new(left.into()),
4028            op: Operator::IsNotDistinctFrom,
4029            right: Box::new(right.into()),
4030        })
4031    }
4032
4033    #[test]
4034    fn simplify_expr_bool_or() {
4035        // col || true is always true
4036        assert_eq!(simplify(col("c2").or(lit(true))), lit(true),);
4037
4038        // col || false is always col
4039        assert_eq!(simplify(col("c2").or(lit(false))), col("c2"),);
4040
4041        // true || null is always true
4042        assert_eq!(simplify(lit(true).or(lit_bool_null())), lit(true),);
4043
4044        // null || true is always true
4045        assert_eq!(simplify(lit_bool_null().or(lit(true))), lit(true),);
4046
4047        // false || null is always null
4048        assert_eq!(simplify(lit(false).or(lit_bool_null())), lit_bool_null(),);
4049
4050        // null || false is always null
4051        assert_eq!(simplify(lit_bool_null().or(lit(false))), lit_bool_null(),);
4052
4053        // ( c1 BETWEEN Int32(0) AND Int32(10) ) OR Boolean(NULL)
4054        // it can be either NULL or  TRUE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
4055        // and should not be rewritten
4056        let expr = col("c1").between(lit(0), lit(10));
4057        let expr = expr.or(lit_bool_null());
4058        let result = simplify(expr);
4059
4060        let expected_expr = or(
4061            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
4062            lit_bool_null(),
4063        );
4064        assert_eq!(expected_expr, result);
4065    }
4066
4067    #[test]
4068    fn simplify_inlist() {
4069        assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false));
4070        assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true));
4071
4072        // null in (...)  --> null
4073        assert_eq!(
4074            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], false)),
4075            lit_bool_null()
4076        );
4077
4078        // null not in (...)  --> null
4079        assert_eq!(
4080            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], true)),
4081            lit_bool_null()
4082        );
4083
4084        assert_eq!(
4085            simplify(in_list(col("c1"), vec![lit(1)], false)),
4086            col("c1").eq(lit(1))
4087        );
4088        assert_eq!(
4089            simplify(in_list(col("c1"), vec![lit(1)], true)),
4090            col("c1").not_eq(lit(1))
4091        );
4092
4093        // more complex expressions can be simplified if list contains
4094        // one element only
4095        assert_eq!(
4096            simplify(in_list(col("c1") * lit(10), vec![lit(2)], false)),
4097            (col("c1") * lit(10)).eq(lit(2))
4098        );
4099
4100        assert_eq!(
4101            simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)),
4102            col("c1").eq(lit(1)).or(col("c1").eq(lit(2)))
4103        );
4104        assert_eq!(
4105            simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)),
4106            col("c1").not_eq(lit(1)).and(col("c1").not_eq(lit(2)))
4107        );
4108
4109        let subquery = Arc::new(test_table_scan_with_name("test").unwrap());
4110        assert_eq!(
4111            simplify(in_list(
4112                col("c1"),
4113                vec![scalar_subquery(Arc::clone(&subquery))],
4114                false
4115            )),
4116            in_subquery(col("c1"), Arc::clone(&subquery))
4117        );
4118        assert_eq!(
4119            simplify(in_list(
4120                col("c1"),
4121                vec![scalar_subquery(Arc::clone(&subquery))],
4122                true
4123            )),
4124            not_in_subquery(col("c1"), subquery)
4125        );
4126
4127        let subquery1 =
4128            scalar_subquery(Arc::new(test_table_scan_with_name("test1").unwrap()));
4129        let subquery2 =
4130            scalar_subquery(Arc::new(test_table_scan_with_name("test2").unwrap()));
4131
4132        // c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery1> AND c1 != <subquery2>
4133        assert_eq!(
4134            simplify(in_list(
4135                col("c1"),
4136                vec![subquery1.clone(), subquery2.clone()],
4137                true
4138            )),
4139            col("c1")
4140                .not_eq(subquery1.clone())
4141                .and(col("c1").not_eq(subquery2.clone()))
4142        );
4143
4144        // c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery1> OR c1 == <subquery2>
4145        assert_eq!(
4146            simplify(in_list(
4147                col("c1"),
4148                vec![subquery1.clone(), subquery2.clone()],
4149                false
4150            )),
4151            col("c1").eq(subquery1).or(col("c1").eq(subquery2))
4152        );
4153
4154        // 1. c1 IN (1,2,3,4) AND c1 IN (5,6,7,8) -> false
4155        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
4156            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false),
4157        );
4158        assert_eq!(simplify(expr), lit(false));
4159
4160        // 2. c1 IN (1,2,3,4) AND c1 IN (4,5,6,7) -> c1 = 4
4161        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
4162            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], false),
4163        );
4164        assert_eq!(simplify(expr), col("c1").eq(lit(4)));
4165
4166        // 3. c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) -> true
4167        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
4168            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
4169        );
4170        assert_eq!(simplify(expr), lit(true));
4171
4172        // 3.5 c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (4, 5, 6, 7) -> c1 != 4 (4 overlaps)
4173        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
4174            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
4175        );
4176        assert_eq!(simplify(expr), col("c1").not_eq(lit(4)));
4177
4178        // 4. c1 NOT IN (1,2,3,4) AND c1 NOT IN (4,5,6,7) -> c1 NOT IN (1,2,3,4,5,6,7)
4179        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
4180            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
4181        );
4182        assert_eq!(
4183            simplify(expr),
4184            in_list(
4185                col("c1"),
4186                vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6), lit(7)],
4187                true
4188            )
4189        );
4190
4191        // 5. c1 IN (1,2,3,4) OR c1 IN (2,3,4,5) -> c1 IN (1,2,3,4,5)
4192        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).or(
4193            in_list(col("c1"), vec![lit(2), lit(3), lit(4), lit(5)], false),
4194        );
4195        assert_eq!(
4196            simplify(expr),
4197            in_list(
4198                col("c1"),
4199                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
4200                false
4201            )
4202        );
4203
4204        // 6. c1 IN (1,2,3) AND c1 NOT INT (1,2,3,4,5) -> false
4205        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3)], false).and(in_list(
4206            col("c1"),
4207            vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
4208            true,
4209        ));
4210        assert_eq!(simplify(expr), lit(false));
4211
4212        // 7. c1 NOT IN (1,2,3,4) AND c1 IN (1,2,3,4,5) -> c1 = 5
4213        let expr =
4214            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(in_list(
4215                col("c1"),
4216                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
4217                false,
4218            ));
4219        assert_eq!(simplify(expr), col("c1").eq(lit(5)));
4220
4221        // 8. c1 IN (1,2,3,4) AND c1 NOT IN (5,6,7,8) -> c1 IN (1,2,3,4)
4222        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
4223            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
4224        );
4225        assert_eq!(
4226            simplify(expr),
4227            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false)
4228        );
4229
4230        // inlist with more than two expressions
4231        // c1 IN (1,2,3,4,5,6) AND c1 IN (1,3,5,6) AND c1 IN (3,6) -> c1 = 3 OR c1 = 6
4232        let expr = in_list(
4233            col("c1"),
4234            vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6)],
4235            false,
4236        )
4237        .and(in_list(
4238            col("c1"),
4239            vec![lit(1), lit(3), lit(5), lit(6)],
4240            false,
4241        ))
4242        .and(in_list(col("c1"), vec![lit(3), lit(6)], false));
4243        assert_eq!(
4244            simplify(expr),
4245            col("c1").eq(lit(3)).or(col("c1").eq(lit(6)))
4246        );
4247
4248        // c1 NOT IN (1,2,3,4) AND c1 IN (5,6,7,8) AND c1 NOT IN (3,4,5,6) AND c1 IN (8,9,10) -> c1 = 8
4249        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
4250            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false)
4251                .and(in_list(
4252                    col("c1"),
4253                    vec![lit(3), lit(4), lit(5), lit(6)],
4254                    true,
4255                ))
4256                .and(in_list(col("c1"), vec![lit(8), lit(9), lit(10)], false)),
4257        );
4258        assert_eq!(simplify(expr), col("c1").eq(lit(8)));
4259
4260        // Contains non-InList expression
4261        // c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9) -> c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9)
4262        let expr =
4263            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(col("c1")
4264                .not_eq(lit(5))
4265                .or(in_list(
4266                    col("c1"),
4267                    vec![lit(6), lit(7), lit(8), lit(9)],
4268                    true,
4269                )));
4270        // TODO: Further simplify this expression
4271        // https://github.com/apache/datafusion/issues/8970
4272        // assert_eq!(simplify(expr.clone()), lit(true));
4273        assert_eq!(simplify(expr.clone()), expr);
4274    }
4275
4276    #[test]
4277    fn simplify_null_in_empty_inlist() {
4278        // `NULL::boolean IN ()` == `NULL::boolean IN (SELECT foo FROM empty)` == false
4279        let expr = in_list(lit_bool_null(), vec![], false);
4280        assert_eq!(simplify(expr), lit(false));
4281
4282        // `NULL::boolean NOT IN ()` == `NULL::boolean NOT IN (SELECT foo FROM empty)` == true
4283        let expr = in_list(lit_bool_null(), vec![], true);
4284        assert_eq!(simplify(expr), lit(true));
4285
4286        // `NULL IN ()` == `NULL IN (SELECT foo FROM empty)` == false
4287        let null_null = || Expr::Literal(ScalarValue::Null, None);
4288        let expr = in_list(null_null(), vec![], false);
4289        assert_eq!(simplify(expr), lit(false));
4290
4291        // `NULL NOT IN ()` == `NULL NOT IN (SELECT foo FROM empty)` == true
4292        let expr = in_list(null_null(), vec![], true);
4293        assert_eq!(simplify(expr), lit(true));
4294    }
4295
4296    #[test]
4297    fn just_simplifier_simplify_null_in_empty_inlist() {
4298        let simplify = |expr: Expr| -> Expr {
4299            let schema = expr_test_schema();
4300            let execution_props = ExecutionProps::new();
4301            let info = SimplifyContext::new(&execution_props).with_schema(schema);
4302            let simplifier = &mut Simplifier::new(&info);
4303            expr.rewrite(simplifier)
4304                .expect("Failed to simplify expression")
4305                .data
4306        };
4307
4308        // `NULL::boolean IN ()` == `NULL::boolean IN (SELECT foo FROM empty)` == false
4309        let expr = in_list(lit_bool_null(), vec![], false);
4310        assert_eq!(simplify(expr), lit(false));
4311
4312        // `NULL::boolean NOT IN ()` == `NULL::boolean NOT IN (SELECT foo FROM empty)` == true
4313        let expr = in_list(lit_bool_null(), vec![], true);
4314        assert_eq!(simplify(expr), lit(true));
4315
4316        // `NULL IN ()` == `NULL IN (SELECT foo FROM empty)` == false
4317        let null_null = || Expr::Literal(ScalarValue::Null, None);
4318        let expr = in_list(null_null(), vec![], false);
4319        assert_eq!(simplify(expr), lit(false));
4320
4321        // `NULL NOT IN ()` == `NULL NOT IN (SELECT foo FROM empty)` == true
4322        let expr = in_list(null_null(), vec![], true);
4323        assert_eq!(simplify(expr), lit(true));
4324    }
4325
4326    #[test]
4327    fn simplify_large_or() {
4328        let expr = (0..5)
4329            .map(|i| col("c1").eq(lit(i)))
4330            .fold(lit(false), |acc, e| acc.or(e));
4331        assert_eq!(
4332            simplify(expr),
4333            in_list(col("c1"), (0..5).map(lit).collect(), false),
4334        );
4335    }
4336
4337    #[test]
4338    fn simplify_expr_bool_and() {
4339        // col & true is always col
4340        assert_eq!(simplify(col("c2").and(lit(true))), col("c2"),);
4341        // col & false is always false
4342        assert_eq!(simplify(col("c2").and(lit(false))), lit(false),);
4343
4344        // true && null is always null
4345        assert_eq!(simplify(lit(true).and(lit_bool_null())), lit_bool_null(),);
4346
4347        // null && true is always null
4348        assert_eq!(simplify(lit_bool_null().and(lit(true))), lit_bool_null(),);
4349
4350        // false && null is always false
4351        assert_eq!(simplify(lit(false).and(lit_bool_null())), lit(false),);
4352
4353        // null && false is always false
4354        assert_eq!(simplify(lit_bool_null().and(lit(false))), lit(false),);
4355
4356        // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
4357        // it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
4358        // and the Boolean(NULL) should remain
4359        let expr = col("c1").between(lit(0), lit(10));
4360        let expr = expr.and(lit_bool_null());
4361        let result = simplify(expr);
4362
4363        let expected_expr = and(
4364            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
4365            lit_bool_null(),
4366        );
4367        assert_eq!(expected_expr, result);
4368    }
4369
4370    #[test]
4371    fn simplify_expr_between() {
4372        // c2 between 3 and 4 is c2 >= 3 and c2 <= 4
4373        let expr = col("c2").between(lit(3), lit(4));
4374        assert_eq!(
4375            simplify(expr),
4376            and(col("c2").gt_eq(lit(3)), col("c2").lt_eq(lit(4)))
4377        );
4378
4379        // c2 not between 3 and 4 is c2 < 3 or c2 > 4
4380        let expr = col("c2").not_between(lit(3), lit(4));
4381        assert_eq!(
4382            simplify(expr),
4383            or(col("c2").lt(lit(3)), col("c2").gt(lit(4)))
4384        );
4385    }
4386
4387    #[test]
4388    fn test_like_and_ilike() {
4389        let null = lit(ScalarValue::Utf8(None));
4390
4391        // expr [NOT] [I]LIKE NULL
4392        let expr = col("c1").like(null.clone());
4393        assert_eq!(simplify(expr), lit_bool_null());
4394
4395        let expr = col("c1").not_like(null.clone());
4396        assert_eq!(simplify(expr), lit_bool_null());
4397
4398        let expr = col("c1").ilike(null.clone());
4399        assert_eq!(simplify(expr), lit_bool_null());
4400
4401        let expr = col("c1").not_ilike(null.clone());
4402        assert_eq!(simplify(expr), lit_bool_null());
4403
4404        // expr [NOT] [I]LIKE '%'
4405        let expr = col("c1").like(lit("%"));
4406        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4407
4408        let expr = col("c1").not_like(lit("%"));
4409        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4410
4411        let expr = col("c1").ilike(lit("%"));
4412        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4413
4414        let expr = col("c1").not_ilike(lit("%"));
4415        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4416
4417        // expr [NOT] [I]LIKE '%%'
4418        let expr = col("c1").like(lit("%%"));
4419        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4420
4421        let expr = col("c1").not_like(lit("%%"));
4422        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4423
4424        let expr = col("c1").ilike(lit("%%"));
4425        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4426
4427        let expr = col("c1").not_ilike(lit("%%"));
4428        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4429
4430        // not_null_expr [NOT] [I]LIKE '%'
4431        let expr = col("c1_non_null").like(lit("%"));
4432        assert_eq!(simplify(expr), lit(true));
4433
4434        let expr = col("c1_non_null").not_like(lit("%"));
4435        assert_eq!(simplify(expr), lit(false));
4436
4437        let expr = col("c1_non_null").ilike(lit("%"));
4438        assert_eq!(simplify(expr), lit(true));
4439
4440        let expr = col("c1_non_null").not_ilike(lit("%"));
4441        assert_eq!(simplify(expr), lit(false));
4442
4443        // not_null_expr [NOT] [I]LIKE '%%'
4444        let expr = col("c1_non_null").like(lit("%%"));
4445        assert_eq!(simplify(expr), lit(true));
4446
4447        let expr = col("c1_non_null").not_like(lit("%%"));
4448        assert_eq!(simplify(expr), lit(false));
4449
4450        let expr = col("c1_non_null").ilike(lit("%%"));
4451        assert_eq!(simplify(expr), lit(true));
4452
4453        let expr = col("c1_non_null").not_ilike(lit("%%"));
4454        assert_eq!(simplify(expr), lit(false));
4455
4456        // null_constant [NOT] [I]LIKE '%'
4457        let expr = null.clone().like(lit("%"));
4458        assert_eq!(simplify(expr), lit_bool_null());
4459
4460        let expr = null.clone().not_like(lit("%"));
4461        assert_eq!(simplify(expr), lit_bool_null());
4462
4463        let expr = null.clone().ilike(lit("%"));
4464        assert_eq!(simplify(expr), lit_bool_null());
4465
4466        let expr = null.clone().not_ilike(lit("%"));
4467        assert_eq!(simplify(expr), lit_bool_null());
4468
4469        // null_constant [NOT] [I]LIKE '%%'
4470        let expr = null.clone().like(lit("%%"));
4471        assert_eq!(simplify(expr), lit_bool_null());
4472
4473        let expr = null.clone().not_like(lit("%%"));
4474        assert_eq!(simplify(expr), lit_bool_null());
4475
4476        let expr = null.clone().ilike(lit("%%"));
4477        assert_eq!(simplify(expr), lit_bool_null());
4478
4479        let expr = null.clone().not_ilike(lit("%%"));
4480        assert_eq!(simplify(expr), lit_bool_null());
4481
4482        // null_constant [NOT] [I]LIKE 'a%'
4483        let expr = null.clone().like(lit("a%"));
4484        assert_eq!(simplify(expr), lit_bool_null());
4485
4486        let expr = null.clone().not_like(lit("a%"));
4487        assert_eq!(simplify(expr), lit_bool_null());
4488
4489        let expr = null.clone().ilike(lit("a%"));
4490        assert_eq!(simplify(expr), lit_bool_null());
4491
4492        let expr = null.clone().not_ilike(lit("a%"));
4493        assert_eq!(simplify(expr), lit_bool_null());
4494
4495        // expr [NOT] [I]LIKE with pattern without wildcards
4496        let expr = col("c1").like(lit("a"));
4497        assert_eq!(simplify(expr), col("c1").eq(lit("a")));
4498        let expr = col("c1").not_like(lit("a"));
4499        assert_eq!(simplify(expr), col("c1").not_eq(lit("a")));
4500        let expr = col("c1").like(lit("a_"));
4501        assert_eq!(simplify(expr), col("c1").like(lit("a_")));
4502        let expr = col("c1").not_like(lit("a_"));
4503        assert_eq!(simplify(expr), col("c1").not_like(lit("a_")));
4504
4505        let expr = col("c1").ilike(lit("a"));
4506        assert_eq!(simplify(expr), col("c1").ilike(lit("a")));
4507        let expr = col("c1").not_ilike(lit("a"));
4508        assert_eq!(simplify(expr), col("c1").not_ilike(lit("a")));
4509    }
4510
4511    #[test]
4512    fn test_simplify_with_guarantee() {
4513        // (c3 >= 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b")))
4514        let expr_x = col("c3").gt(lit(3_i64));
4515        let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32));
4516        let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true);
4517        let expr = expr_x.clone().and(expr_y.or(expr_z));
4518
4519        // All guaranteed null
4520        let guarantees = vec![
4521            (col("c3"), NullableInterval::from(ScalarValue::Int64(None))),
4522            (col("c4"), NullableInterval::from(ScalarValue::UInt32(None))),
4523            (col("c1"), NullableInterval::from(ScalarValue::Utf8(None))),
4524        ];
4525
4526        let output = simplify_with_guarantee(expr.clone(), guarantees);
4527        assert_eq!(output, lit_bool_null());
4528
4529        // All guaranteed false
4530        let guarantees = vec![
4531            (
4532                col("c3"),
4533                NullableInterval::NotNull {
4534                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
4535                },
4536            ),
4537            (
4538                col("c4"),
4539                NullableInterval::from(ScalarValue::UInt32(Some(9))),
4540            ),
4541            (col("c1"), NullableInterval::from(ScalarValue::from("a"))),
4542        ];
4543        let output = simplify_with_guarantee(expr.clone(), guarantees);
4544        assert_eq!(output, lit(false));
4545
4546        // Guaranteed false or null -> no change.
4547        let guarantees = vec![
4548            (
4549                col("c3"),
4550                NullableInterval::MaybeNull {
4551                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
4552                },
4553            ),
4554            (
4555                col("c4"),
4556                NullableInterval::MaybeNull {
4557                    values: Interval::make(Some(9_u32), Some(9_u32)).unwrap(),
4558                },
4559            ),
4560            (
4561                col("c1"),
4562                NullableInterval::NotNull {
4563                    values: Interval::try_new(
4564                        ScalarValue::from("d"),
4565                        ScalarValue::from("f"),
4566                    )
4567                    .unwrap(),
4568                },
4569            ),
4570        ];
4571        let output = simplify_with_guarantee(expr.clone(), guarantees);
4572        assert_eq!(&output, &expr_x);
4573
4574        // Sufficient true guarantees
4575        let guarantees = vec![
4576            (
4577                col("c3"),
4578                NullableInterval::from(ScalarValue::Int64(Some(9))),
4579            ),
4580            (
4581                col("c4"),
4582                NullableInterval::from(ScalarValue::UInt32(Some(3))),
4583            ),
4584        ];
4585        let output = simplify_with_guarantee(expr.clone(), guarantees);
4586        assert_eq!(output, lit(true));
4587
4588        // Only partially simplify
4589        let guarantees = vec![(
4590            col("c4"),
4591            NullableInterval::from(ScalarValue::UInt32(Some(3))),
4592        )];
4593        let output = simplify_with_guarantee(expr, guarantees);
4594        assert_eq!(&output, &expr_x);
4595    }
4596
4597    #[test]
4598    fn test_expression_partial_simplify_1() {
4599        // (1 + 2) + (4 / 0) -> 3 + (4 / 0)
4600        let expr = (lit(1) + lit(2)) + (lit(4) / lit(0));
4601        let expected = (lit(3)) + (lit(4) / lit(0));
4602
4603        assert_eq!(simplify(expr), expected);
4604    }
4605
4606    #[test]
4607    fn test_expression_partial_simplify_2() {
4608        // (1 > 2) and (4 / 0) -> false
4609        let expr = (lit(1).gt(lit(2))).and(lit(4) / lit(0));
4610        let expected = lit(false);
4611
4612        assert_eq!(simplify(expr), expected);
4613    }
4614
4615    #[test]
4616    fn test_simplify_cycles() {
4617        // TRUE
4618        let expr = lit(true);
4619        let expected = lit(true);
4620        let (expr, num_iter) = simplify_with_cycle_count(expr);
4621        assert_eq!(expr, expected);
4622        assert_eq!(num_iter, 1);
4623
4624        // (true != NULL) OR (5 > 10)
4625        let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
4626        let expected = lit_bool_null();
4627        let (expr, num_iter) = simplify_with_cycle_count(expr);
4628        assert_eq!(expr, expected);
4629        assert_eq!(num_iter, 2);
4630
4631        // NOTE: this currently does not simplify
4632        // (((c4 - 10) + 10) *100) / 100
4633        let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
4634        let expected = expr.clone();
4635        let (expr, num_iter) = simplify_with_cycle_count(expr);
4636        assert_eq!(expr, expected);
4637        assert_eq!(num_iter, 1);
4638
4639        // ((c4<1 or c3<2) and c3_non_null<3) and false
4640        let expr = col("c4")
4641            .lt(lit(1))
4642            .or(col("c3").lt(lit(2)))
4643            .and(col("c3_non_null").lt(lit(3)))
4644            .and(lit(false));
4645        let expected = lit(false);
4646        let (expr, num_iter) = simplify_with_cycle_count(expr);
4647        assert_eq!(expr, expected);
4648        assert_eq!(num_iter, 2);
4649    }
4650
4651    fn boolean_test_schema() -> DFSchemaRef {
4652        static BOOLEAN_TEST_SCHEMA: LazyLock<DFSchemaRef> = LazyLock::new(|| {
4653            Schema::new(vec![
4654                Field::new("A", DataType::Boolean, false),
4655                Field::new("B", DataType::Boolean, false),
4656                Field::new("C", DataType::Boolean, false),
4657                Field::new("D", DataType::Boolean, false),
4658            ])
4659            .to_dfschema_ref()
4660            .unwrap()
4661        });
4662        Arc::clone(&BOOLEAN_TEST_SCHEMA)
4663    }
4664
4665    #[test]
4666    fn simplify_common_factor_conjunction_in_disjunction() {
4667        let props = ExecutionProps::new();
4668        let schema = boolean_test_schema();
4669        let simplifier =
4670            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));
4671
4672        let a = || col("A");
4673        let b = || col("B");
4674        let c = || col("C");
4675        let d = || col("D");
4676
4677        // (A AND B) OR (A AND C) -> A AND (B OR C)
4678        let expr = a().and(b()).or(a().and(c()));
4679        let expected = a().and(b().or(c()));
4680
4681        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4682
4683        // (A AND B) OR (A AND C) OR (A AND D) -> A AND (B OR C OR D)
4684        let expr = a().and(b()).or(a().and(c())).or(a().and(d()));
4685        let expected = a().and(b().or(c()).or(d()));
4686        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4687
4688        // A OR (B AND C AND A) -> A
4689        let expr = a().or(b().and(c().and(a())));
4690        let expected = a();
4691        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4692    }
4693
4694    #[test]
4695    fn test_simplify_udaf() {
4696        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());
4697        let aggregate_function_expr =
4698            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4699                udaf.into(),
4700                vec![],
4701                false,
4702                None,
4703                vec![],
4704                None,
4705            ));
4706
4707        let expected = col("result_column");
4708        assert_eq!(simplify(aggregate_function_expr), expected);
4709
4710        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify());
4711        let aggregate_function_expr =
4712            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4713                udaf.into(),
4714                vec![],
4715                false,
4716                None,
4717                vec![],
4718                None,
4719            ));
4720
4721        let expected = aggregate_function_expr.clone();
4722        assert_eq!(simplify(aggregate_function_expr), expected);
4723    }
4724
4725    /// A Mock UDAF which defines `simplify` to be used in tests
4726    /// related to UDAF simplification
4727    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
4728    struct SimplifyMockUdaf {
4729        simplify: bool,
4730    }
4731
4732    impl SimplifyMockUdaf {
4733        /// make simplify method return new expression
4734        fn new_with_simplify() -> Self {
4735            Self { simplify: true }
4736        }
4737        /// make simplify method return no change
4738        fn new_without_simplify() -> Self {
4739            Self { simplify: false }
4740        }
4741    }
4742
4743    impl AggregateUDFImpl for SimplifyMockUdaf {
4744        fn as_any(&self) -> &dyn std::any::Any {
4745            self
4746        }
4747
4748        fn name(&self) -> &str {
4749            "mock_simplify"
4750        }
4751
4752        fn signature(&self) -> &Signature {
4753            unimplemented!()
4754        }
4755
4756        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4757            unimplemented!("not needed for tests")
4758        }
4759
4760        fn accumulator(
4761            &self,
4762            _acc_args: AccumulatorArgs,
4763        ) -> Result<Box<dyn Accumulator>> {
4764            unimplemented!("not needed for tests")
4765        }
4766
4767        fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
4768            unimplemented!("not needed for testing")
4769        }
4770
4771        fn create_groups_accumulator(
4772            &self,
4773            _args: AccumulatorArgs,
4774        ) -> Result<Box<dyn GroupsAccumulator>> {
4775            unimplemented!("not needed for testing")
4776        }
4777
4778        fn simplify(&self) -> Option<AggregateFunctionSimplification> {
4779            if self.simplify {
4780                Some(Box::new(|_, _| Ok(col("result_column"))))
4781            } else {
4782                None
4783            }
4784        }
4785    }
4786
4787    #[test]
4788    fn test_simplify_udwf() {
4789        let udwf = WindowFunctionDefinition::WindowUDF(
4790            WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(),
4791        );
4792        let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![]));
4793
4794        let expected = col("result_column");
4795        assert_eq!(simplify(window_function_expr), expected);
4796
4797        let udwf = WindowFunctionDefinition::WindowUDF(
4798            WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(),
4799        );
4800        let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![]));
4801
4802        let expected = window_function_expr.clone();
4803        assert_eq!(simplify(window_function_expr), expected);
4804    }
4805
4806    /// A Mock UDWF which defines `simplify` to be used in tests
4807    /// related to UDWF simplification
4808    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
4809    struct SimplifyMockUdwf {
4810        simplify: bool,
4811    }
4812
4813    impl SimplifyMockUdwf {
4814        /// make simplify method return new expression
4815        fn new_with_simplify() -> Self {
4816            Self { simplify: true }
4817        }
4818        /// make simplify method return no change
4819        fn new_without_simplify() -> Self {
4820            Self { simplify: false }
4821        }
4822    }
4823
4824    impl WindowUDFImpl for SimplifyMockUdwf {
4825        fn as_any(&self) -> &dyn std::any::Any {
4826            self
4827        }
4828
4829        fn name(&self) -> &str {
4830            "mock_simplify"
4831        }
4832
4833        fn signature(&self) -> &Signature {
4834            unimplemented!()
4835        }
4836
4837        fn simplify(&self) -> Option<WindowFunctionSimplification> {
4838            if self.simplify {
4839                Some(Box::new(|_, _| Ok(col("result_column"))))
4840            } else {
4841                None
4842            }
4843        }
4844
4845        fn partition_evaluator(
4846            &self,
4847            _partition_evaluator_args: PartitionEvaluatorArgs,
4848        ) -> Result<Box<dyn PartitionEvaluator>> {
4849            unimplemented!("not needed for tests")
4850        }
4851
4852        fn field(&self, _field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
4853            unimplemented!("not needed for tests")
4854        }
4855
4856        fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
4857            LimitEffect::Unknown
4858        }
4859    }
4860    #[derive(Debug, PartialEq, Eq, Hash)]
4861    struct VolatileUdf {
4862        signature: Signature,
4863    }
4864
4865    impl VolatileUdf {
4866        pub fn new() -> Self {
4867            Self {
4868                signature: Signature::exact(vec![], Volatility::Volatile),
4869            }
4870        }
4871    }
4872    impl ScalarUDFImpl for VolatileUdf {
4873        fn as_any(&self) -> &dyn std::any::Any {
4874            self
4875        }
4876
4877        fn name(&self) -> &str {
4878            "VolatileUdf"
4879        }
4880
4881        fn signature(&self) -> &Signature {
4882            &self.signature
4883        }
4884
4885        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4886            Ok(DataType::Int16)
4887        }
4888
4889        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
4890            panic!("dummy - not implemented")
4891        }
4892    }
4893
4894    #[test]
4895    fn test_optimize_volatile_conditions() {
4896        let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new()));
4897        let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
4898        {
4899            let expr = rand
4900                .clone()
4901                .eq(lit(0))
4902                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4903
4904            assert_eq!(simplify(expr.clone()), expr);
4905        }
4906
4907        {
4908            let expr = col("column1")
4909                .eq(lit(2))
4910                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4911
4912            assert_eq!(simplify(expr), col("column1").eq(lit(2)));
4913        }
4914
4915        {
4916            let expr = (col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))).or(col(
4917                "column1",
4918            )
4919            .eq(lit(2))
4920            .and(rand.clone().eq(lit(0))));
4921
4922            assert_eq!(
4923                simplify(expr),
4924                col("column1")
4925                    .eq(lit(2))
4926                    .and((rand.clone().eq(lit(0))).or(rand.clone().eq(lit(0))))
4927            );
4928        }
4929    }
4930
4931    #[test]
4932    fn simplify_fixed_size_binary_eq_lit() {
4933        let bytes = [1u8, 2, 3].as_slice();
4934
4935        // The expression starts simple.
4936        let expr = col("c5").eq(lit(bytes));
4937
4938        // The type coercer introduces a cast.
4939        let coerced = coerce(expr.clone());
4940        let schema = expr_test_schema();
4941        assert_eq!(
4942            coerced,
4943            col("c5")
4944                .cast_to(&DataType::Binary, schema.as_ref())
4945                .unwrap()
4946                .eq(lit(bytes))
4947        );
4948
4949        // The simplifier removes the cast.
4950        assert_eq!(
4951            simplify(coerced),
4952            col("c5").eq(Expr::Literal(
4953                ScalarValue::FixedSizeBinary(3, Some(bytes.to_vec()),),
4954                None
4955            ))
4956        );
4957    }
4958
4959    #[test]
4960    fn simplify_cast_literal() {
4961        // Test that CAST(literal) expressions are evaluated at plan time
4962
4963        // CAST(123 AS Int64) should become 123i64
4964        let expr = Expr::Cast(Cast::new(Box::new(lit(123i32)), DataType::Int64));
4965        let expected = lit(123i64);
4966        assert_eq!(simplify(expr), expected);
4967
4968        // CAST(1761630189642 AS Timestamp(Nanosecond, Some("+00:00")))
4969        // Integer to timestamp cast
4970        let expr = Expr::Cast(Cast::new(
4971            Box::new(lit(1761630189642i64)),
4972            DataType::Timestamp(
4973                arrow::datatypes::TimeUnit::Nanosecond,
4974                Some("+00:00".into()),
4975            ),
4976        ));
4977        // Should evaluate to a timestamp literal
4978        let result = simplify(expr);
4979        match result {
4980            Expr::Literal(ScalarValue::TimestampNanosecond(Some(val), tz), _) => {
4981                assert_eq!(val, 1761630189642i64);
4982                assert_eq!(tz.as_deref(), Some("+00:00"));
4983            }
4984            other => panic!("Expected TimestampNanosecond literal, got: {other:?}"),
4985        }
4986
4987        // Test CAST of invalid string to timestamp - should return an error at plan time
4988        // This represents the case from the issue: CAST(Utf8("1761630189642") AS Timestamp)
4989        // "1761630189642" is NOT a valid timestamp string format
4990        let expr = Expr::Cast(Cast::new(
4991            Box::new(lit("1761630189642")),
4992            DataType::Timestamp(
4993                arrow::datatypes::TimeUnit::Nanosecond,
4994                Some("+00:00".into()),
4995            ),
4996        ));
4997
4998        // The simplification should now fail with an error at plan time
4999        let schema = test_schema();
5000        let props = ExecutionProps::new();
5001        let simplifier =
5002            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));
5003        let result = simplifier.simplify(expr);
5004        assert!(result.is_err(), "Expected error for invalid cast");
5005        let err_msg = result.unwrap_err().to_string();
5006        assert_contains!(err_msg, "Error parsing timestamp");
5007    }
5008
5009    fn if_not_null(expr: Expr, then: bool) -> Expr {
5010        Expr::Case(Case {
5011            expr: Some(expr.is_not_null().into()),
5012            when_then_expr: vec![(lit(true).into(), lit(then).into())],
5013            else_expr: None,
5014        })
5015    }
5016}