datafusion_optimizer/simplify_expressions/
expr_simplifier.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression simplification API
19
20use arrow::{
21    array::{AsArray, new_null_array},
22    datatypes::{DataType, Field, Schema},
23    record_batch::RecordBatch,
24};
25use std::borrow::Cow;
26use std::collections::HashSet;
27use std::ops::Not;
28use std::sync::Arc;
29
30use datafusion_common::{
31    DFSchema, DataFusionError, Result, ScalarValue, exec_datafusion_err, internal_err,
32};
33use datafusion_common::{
34    HashMap,
35    cast::{as_large_list_array, as_list_array},
36    metadata::FieldMetadata,
37    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
38};
39use datafusion_expr::{
40    BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, and,
41    binary::BinaryTypeCoercer, lit, or,
42};
43use datafusion_expr::{Cast, TryCast, simplify::ExprSimplifyResult};
44use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
45use datafusion_expr::{
46    expr::{InList, InSubquery},
47    utils::{iter_conjunction, iter_conjunction_owned},
48};
49use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
50
51use super::inlist_simplifier::ShortenInListSimplifier;
52use super::utils::*;
53use crate::analyzer::type_coercion::TypeCoercionRewriter;
54use crate::simplify_expressions::SimplifyInfo;
55use crate::simplify_expressions::regex::simplify_regex_expr;
56use crate::simplify_expressions::unwrap_cast::{
57    is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
58    is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
59    unwrap_cast_in_comparison_for_binary,
60};
61use datafusion_expr::expr_rewriter::rewrite_with_guarantees_map;
62use datafusion_expr_common::casts::try_cast_literal_to_type;
63use indexmap::IndexSet;
64use regex::Regex;
65
66/// This structure handles API for expression simplification
67///
68/// Provides simplification information based on DFSchema and
69/// [`ExecutionProps`]. This is the default implementation used by DataFusion
70///
71/// For example:
72/// ```
73/// use arrow::datatypes::{DataType, Field, Schema};
74/// use datafusion_common::{DataFusionError, ToDFSchema};
75/// use datafusion_expr::execution_props::ExecutionProps;
76/// use datafusion_expr::simplify::SimplifyContext;
77/// use datafusion_expr::{col, lit};
78/// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
79///
80/// // Create the schema
/// let schema = Schema::new(vec![Field::new("b", DataType::Int64, false)])
82///     .to_dfschema_ref()
83///     .unwrap();
84///
85/// // Create the simplifier
86/// let props = ExecutionProps::new();
87/// let context = SimplifyContext::new(&props).with_schema(schema);
88/// let simplifier = ExprSimplifier::new(context);
89///
90/// // Use the simplifier
91///
92/// // b < 2 or (1 > 3)
93/// let expr = col("b").lt(lit(2)).or(lit(1).gt(lit(3)));
94///
95/// // b < 2
96/// let simplified = simplifier.simplify(expr).unwrap();
97/// assert_eq!(simplified, col("b").lt(lit(2)));
98/// ```
pub struct ExprSimplifier<S> {
    /// Provider of the information the simplifier consults while rewriting
    /// (for example the execution properties handed to the constant
    /// evaluator). The simplification methods require `S: SimplifyInfo`.
    info: S,
    /// Guarantees about the values of columns. This is provided by the user
    /// in [ExprSimplifier::with_guarantees()].
    guarantees: Vec<(Expr, NullableInterval)>,
    /// Should expressions be canonicalized before simplification? Defaults to
    /// true
    canonicalize: bool,
    /// Maximum number of simplifier cycles. Defaults to
    /// [`DEFAULT_MAX_SIMPLIFIER_CYCLES`]; see [ExprSimplifier::with_max_cycles()].
    max_simplifier_cycles: u32,
}
110
/// Size threshold related to inlining `IN` lists.
/// NOTE(review): not referenced in the visible part of this file; presumably
/// the in-list rewrite rules compare list lengths against it — confirm at the
/// use site.
pub const THRESHOLD_INLINE_INLIST: usize = 3;
/// Default upper bound on the number of simplification cycles; used by
/// [`ExprSimplifier::new`] and overridable via [`ExprSimplifier::with_max_cycles`].
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3;
113
114impl<S: SimplifyInfo> ExprSimplifier<S> {
115    /// Create a new `ExprSimplifier` with the given `info` such as an
116    /// instance of [`SimplifyContext`]. See
117    /// [`simplify`](Self::simplify) for an example.
118    ///
119    /// [`SimplifyContext`]: datafusion_expr::simplify::SimplifyContext
120    pub fn new(info: S) -> Self {
121        Self {
122            info,
123            guarantees: vec![],
124            canonicalize: true,
125            max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES,
126        }
127    }
128
129    /// Simplifies this [`Expr`] as much as possible, evaluating
130    /// constants and applying algebraic simplifications.
131    ///
132    /// The types of the expression must match what operators expect,
133    /// or else an error may occur trying to evaluate. See
134    /// [`coerce`](Self::coerce) for a function to help.
135    ///
136    /// # Example:
137    ///
138    /// `b > 2 AND b > 2`
139    ///
140    /// can be written to
141    ///
142    /// `b > 2`
143    ///
144    /// ```
145    /// use arrow::datatypes::DataType;
146    /// use datafusion_common::DFSchema;
147    /// use datafusion_common::Result;
148    /// use datafusion_expr::execution_props::ExecutionProps;
149    /// use datafusion_expr::simplify::SimplifyContext;
150    /// use datafusion_expr::simplify::SimplifyInfo;
151    /// use datafusion_expr::{col, lit, Expr};
152    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
153    /// use std::sync::Arc;
154    ///
155    /// /// Simple implementation that provides `Simplifier` the information it needs
156    /// /// See SimplifyContext for a structure that does this.
157    /// #[derive(Default)]
158    /// struct Info {
159    ///     execution_props: ExecutionProps,
160    /// };
161    ///
162    /// impl SimplifyInfo for Info {
163    ///     fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
164    ///         Ok(false)
165    ///     }
166    ///     fn nullable(&self, expr: &Expr) -> Result<bool> {
167    ///         Ok(true)
168    ///     }
169    ///     fn execution_props(&self) -> &ExecutionProps {
170    ///         &self.execution_props
171    ///     }
172    ///     fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
173    ///         Ok(DataType::Int32)
174    ///     }
175    /// }
176    ///
177    /// // Create the simplifier
178    /// let simplifier = ExprSimplifier::new(Info::default());
179    ///
    /// // b > 2
    /// let b_gt_2 = col("b").gt(lit(2));
    ///
    /// // (b > 2) OR (b > 2)
    /// let expr = b_gt_2.clone().or(b_gt_2.clone());
    ///
    /// // (b > 2) OR (b > 2) --> (b > 2)
    /// let expr = simplifier.simplify(expr).unwrap();
    /// assert_eq!(expr, b_gt_2);
189    /// ```
190    pub fn simplify(&self, expr: Expr) -> Result<Expr> {
191        Ok(self.simplify_with_cycle_count_transformed(expr)?.0.data)
192    }
193
194    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
195    /// constants and applying algebraic simplifications. Additionally returns a `u32`
196    /// representing the number of simplification cycles performed, which can be useful for testing
197    /// optimizations.
198    ///
199    /// See [Self::simplify] for details and usage examples.
200    #[deprecated(
201        since = "48.0.0",
202        note = "Use `simplify_with_cycle_count_transformed` instead"
203    )]
204    #[expect(unused_mut)]
205    pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
206        let (transformed, cycle_count) =
207            self.simplify_with_cycle_count_transformed(expr)?;
208        Ok((transformed.data, cycle_count))
209    }
210
211    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
212    /// constants and applying algebraic simplifications. Additionally returns a `u32`
213    /// representing the number of simplification cycles performed, which can be useful for testing
214    /// optimizations.
215    ///
216    /// # Returns
217    ///
218    /// A tuple containing:
219    /// - The simplified expression wrapped in a `Transformed<Expr>` indicating if changes were made
220    /// - The number of simplification cycles that were performed
221    ///
222    /// See [Self::simplify] for details and usage examples.
223    pub fn simplify_with_cycle_count_transformed(
224        &self,
225        mut expr: Expr,
226    ) -> Result<(Transformed<Expr>, u32)> {
227        let mut simplifier = Simplifier::new(&self.info);
228        let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
229        let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
230        let guarantees_map: HashMap<&Expr, &NullableInterval> =
231            self.guarantees.iter().map(|(k, v)| (k, v)).collect();
232
233        if self.canonicalize {
234            expr = expr.rewrite(&mut Canonicalizer::new()).data()?
235        }
236
237        // Evaluating constants can enable new simplifications and
238        // simplifications can enable new constant evaluation
239        // see `Self::with_max_cycles`
240        let mut num_cycles = 0;
241        let mut has_transformed = false;
242        loop {
243            let Transformed {
244                data, transformed, ..
245            } = expr
246                .rewrite(&mut const_evaluator)?
247                .transform_data(|expr| expr.rewrite(&mut simplifier))?
248                .transform_data(|expr| {
249                    rewrite_with_guarantees_map(expr, &guarantees_map)
250                })?;
251            expr = data;
252            num_cycles += 1;
253            // Track if any transformation occurred
254            has_transformed = has_transformed || transformed;
255            if !transformed || num_cycles >= self.max_simplifier_cycles {
256                break;
257            }
258        }
259        // shorten inlist should be started after other inlist rules are applied
260        expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
261        Ok((
262            Transformed::new_transformed(expr, has_transformed),
263            num_cycles,
264        ))
265    }
266
267    /// Apply type coercion to an [`Expr`] so that it can be
268    /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr).
269    ///
270    /// See the [type coercion module](datafusion_expr::type_coercion)
271    /// documentation for more details on type coercion
272    pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> {
273        let mut expr_rewrite = TypeCoercionRewriter { schema };
274        expr.rewrite(&mut expr_rewrite).data()
275    }
276
277    /// Input guarantees about the values of columns.
278    ///
279    /// The guarantees can simplify expressions. For example, if a column `x` is
280    /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the
281    /// literal `true`.
282    ///
283    /// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`,
284    /// where the [Expr] is a column reference and the [NullableInterval]
285    /// is an interval representing the known possible values of that column.
286    ///
287    /// ```rust
288    /// use arrow::datatypes::{DataType, Field, Schema};
289    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
290    /// use datafusion_expr::execution_props::ExecutionProps;
291    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
292    /// use datafusion_expr::simplify::SimplifyContext;
293    /// use datafusion_expr::{col, lit, Expr};
294    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
295    ///
296    /// let schema = Schema::new(vec![
297    ///     Field::new("x", DataType::Int64, false),
298    ///     Field::new("y", DataType::UInt32, false),
299    ///     Field::new("z", DataType::Int64, false),
300    /// ])
301    /// .to_dfschema_ref()
302    /// .unwrap();
303    ///
304    /// // Create the simplifier
305    /// let props = ExecutionProps::new();
306    /// let context = SimplifyContext::new(&props).with_schema(schema);
307    ///
308    /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5)
309    /// let expr_x = col("x").gt_eq(lit(3_i64));
310    /// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32));
311    /// let expr_z = col("z").gt(lit(5_i64));
312    /// let expr = expr_x.and(expr_y).and(expr_z.clone());
313    ///
314    /// let guarantees = vec![
315    ///     // x ∈ [3, 5]
316    ///     (
317    ///         col("x"),
318    ///         NullableInterval::NotNull {
319    ///             values: Interval::make(Some(3_i64), Some(5_i64)).unwrap(),
320    ///         },
321    ///     ),
322    ///     // y = 3
323    ///     (
324    ///         col("y"),
325    ///         NullableInterval::from(ScalarValue::UInt32(Some(3))),
326    ///     ),
327    /// ];
328    /// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees);
329    /// let output = simplifier.simplify(expr).unwrap();
330    /// // Expression becomes: true AND true AND (z > 5), which simplifies to
331    /// // z > 5.
332    /// assert_eq!(output, expr_z);
333    /// ```
334    pub fn with_guarantees(mut self, guarantees: Vec<(Expr, NullableInterval)>) -> Self {
335        self.guarantees = guarantees;
336        self
337    }
338
339    /// Should `Canonicalizer` be applied before simplification?
340    ///
341    /// If true (the default), the expression will be rewritten to canonical
342    /// form before simplification. This is useful to ensure that the simplifier
343    /// can apply all possible simplifications.
344    ///
345    /// Some expressions, such as those in some Joins, can not be canonicalized
346    /// without changing their meaning. In these cases, canonicalization should
347    /// be disabled.
348    ///
349    /// ```rust
350    /// use arrow::datatypes::{DataType, Field, Schema};
351    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
352    /// use datafusion_expr::execution_props::ExecutionProps;
353    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
354    /// use datafusion_expr::simplify::SimplifyContext;
355    /// use datafusion_expr::{col, lit, Expr};
356    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
357    ///
358    /// let schema = Schema::new(vec![
359    ///     Field::new("a", DataType::Int64, false),
360    ///     Field::new("b", DataType::Int64, false),
361    ///     Field::new("c", DataType::Int64, false),
362    /// ])
363    /// .to_dfschema_ref()
364    /// .unwrap();
365    ///
366    /// // Create the simplifier
367    /// let props = ExecutionProps::new();
368    /// let context = SimplifyContext::new(&props).with_schema(schema);
369    /// let simplifier = ExprSimplifier::new(context);
370    ///
371    /// // Expression: a = c AND 1 = b
372    /// let expr = col("a").eq(col("c")).and(lit(1).eq(col("b")));
373    ///
374    /// // With canonicalization, the expression is rewritten to canonical form
375    /// // (though it is no simpler in this case):
376    /// let canonical = simplifier.simplify(expr.clone()).unwrap();
377    /// // Expression has been rewritten to: (c = a AND b = 1)
378    /// assert_eq!(canonical, col("c").eq(col("a")).and(col("b").eq(lit(1))));
379    ///
380    /// // If canonicalization is disabled, the expression is not changed
381    /// let non_canonicalized = simplifier
382    ///     .with_canonicalize(false)
383    ///     .simplify(expr.clone())
384    ///     .unwrap();
385    ///
386    /// assert_eq!(non_canonicalized, expr);
387    /// ```
388    pub fn with_canonicalize(mut self, canonicalize: bool) -> Self {
389        self.canonicalize = canonicalize;
390        self
391    }
392
393    /// Specifies the maximum number of simplification cycles to run.
394    ///
395    /// The simplifier can perform multiple passes of simplification. This is
396    /// because the output of one simplification step can allow more optimizations
397    /// in another simplification step. For example, constant evaluation can allow more
398    /// expression simplifications, and expression simplifications can allow more constant
399    /// evaluations.
400    ///
401    /// This method specifies the maximum number of allowed iteration cycles before the simplifier
402    /// returns an [Expr] output. However, it does not always perform the maximum number of cycles.
403    /// The simplifier will attempt to detect when an [Expr] is unchanged by all the simplification
404    /// passes, and return early. This avoids wasting time on unnecessary [Expr] tree traversals.
405    ///
406    /// If no maximum is specified, the value of [DEFAULT_MAX_SIMPLIFIER_CYCLES] is used
407    /// instead.
408    ///
409    /// ```rust
410    /// use arrow::datatypes::{DataType, Field, Schema};
411    /// use datafusion_expr::{col, lit, Expr};
412    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
413    /// use datafusion_expr::execution_props::ExecutionProps;
414    /// use datafusion_expr::simplify::SimplifyContext;
415    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
416    ///
417    /// let schema = Schema::new(vec![
418    ///   Field::new("a", DataType::Int64, false),
419    ///   ])
420    ///   .to_dfschema_ref().unwrap();
421    ///
422    /// // Create the simplifier
423    /// let props = ExecutionProps::new();
424    /// let context = SimplifyContext::new(&props)
425    ///    .with_schema(schema);
426    /// let simplifier = ExprSimplifier::new(context);
427    ///
428    /// // Expression: a IS NOT NULL
429    /// let expr = col("a").is_not_null();
430    ///
431    /// // When using default maximum cycles, 2 cycles will be performed.
432    /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count_transformed(expr.clone()).unwrap();
433    /// assert_eq!(simplified_expr.data, lit(true));
434    /// // 2 cycles were executed, but only 1 was needed
435    /// assert_eq!(count, 2);
436    ///
437    /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
438    /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count_transformed(expr.clone()).unwrap();
    /// // The expression is still fully simplified to `true`
440    /// assert_eq!(simplified_expr.data, lit(true));
441    /// // Only 1 cycle was executed
442    /// assert_eq!(count, 1);
443    /// ```
444    pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
445        self.max_simplifier_cycles = max_simplifier_cycles;
446        self
447    }
448}
449
450/// Canonicalize any BinaryExprs that are not in canonical form
451///
452/// `<literal> <op> <col>` is rewritten to `<col> <op> <literal>`
453///
454/// `<col1> <op> <col2>` is rewritten so that the name of `col1` sorts higher
455/// than `col2` (`a > b` would be canonicalized to `b < a`)
struct Canonicalizer {}

impl Canonicalizer {
    /// Builds the rewriter; it carries no state.
    fn new() -> Self {
        Canonicalizer {}
    }
}
463
464impl TreeNodeRewriter for Canonicalizer {
465    type Node = Expr;
466
467    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
468        let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
469            return Ok(Transformed::no(expr));
470        };
471        match (left.as_ref(), right.as_ref(), op.swap()) {
472            // <col1> <op> <col2>
473            (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op))
474                if right_col > left_col =>
475            {
476                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
477                    left: right,
478                    op: swapped_op,
479                    right: left,
480                })))
481            }
482            // <literal> <op> <col>
483            (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => {
484                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
485                    left: right,
486                    op: swapped_op,
487                    right: left,
488                })))
489            }
490            _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr {
491                left,
492                op,
493                right,
494            }))),
495        }
496    }
497}
498
499/// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time.
500///
501/// Note it does not handle algebraic rewrites such as `(a or false)`
502/// --> `a`, which is handled by [`Simplifier`]
struct ConstEvaluator<'a> {
    /// `can_evaluate` is used during the depth-first-search of the
    /// `Expr` tree to track if any siblings (or their descendants) were
    /// non evaluatable (e.g. had a column reference or volatile
    /// function)
    ///
    /// Specifically, `can_evaluate[N]` represents the state of
    /// traversal when we are N levels deep in the tree, one entry for
    /// this Expr and each of its parents.
    ///
    /// After visiting all siblings, if the entry on top of the
    /// `can_evaluate` stack is true, that means there were no non
    /// evaluatable siblings (or their descendants) so this `Expr` can be
    /// evaluated
    can_evaluate: Vec<bool>,

    /// Execution properties handed to `create_physical_expr` when an
    /// expression is evaluated.
    execution_props: &'a ExecutionProps,
    /// Schema with a single dummy `Null` column, used to plan the physical
    /// expression (built in `try_new`).
    input_schema: DFSchema,
    /// One-row batch over the same dummy schema; the single input row yields
    /// a single output value when an expression is evaluated against it.
    input_batch: RecordBatch,
}
522
523/// The simplify result of ConstEvaluator
enum ConstSimplifyResult {
    /// Expr was simplified and contains the new expression
    Simplified(ScalarValue, Option<FieldMetadata>),
    /// Expr was not simplified and original value is returned
    /// (produced when the input already was a literal)
    NotSimplified(ScalarValue, Option<FieldMetadata>),
    /// Evaluation encountered an error, contains the original expression
    SimplifyRuntimeError(DataFusionError, Expr),
}
532
533impl TreeNodeRewriter for ConstEvaluator<'_> {
534    type Node = Expr;
535
536    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
537        // Default to being able to evaluate this node
538        self.can_evaluate.push(true);
539
540        // if this expr is not ok to evaluate, mark entire parent
541        // stack as not ok (as all parents have at least one child or
542        // descendant that can not be evaluated
543
544        if !Self::can_evaluate(&expr) {
545            // walk back up stack, marking first parent that is not mutable
546            let parent_iter = self.can_evaluate.iter_mut().rev();
547            for p in parent_iter {
548                if !*p {
549                    // optimization: if we find an element on the
550                    // stack already marked, know all elements above are also marked
551                    break;
552                }
553                *p = false;
554            }
555        }
556
557        // NB: do not short circuit recursion even if we find a non
558        // evaluatable node (so we can fold other children, args to
559        // functions, etc.)
560        Ok(Transformed::no(expr))
561    }
562
563    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
564        match self.can_evaluate.pop() {
565            // Certain expressions such as `CASE` and `COALESCE` are short-circuiting
566            // and may not evaluate all their sub expressions. Thus, if
567            // any error is countered during simplification, return the original
568            // so that normal evaluation can occur
569            Some(true) => match self.evaluate_to_scalar(expr) {
570                ConstSimplifyResult::Simplified(s, m) => {
571                    Ok(Transformed::yes(Expr::Literal(s, m)))
572                }
573                ConstSimplifyResult::NotSimplified(s, m) => {
574                    Ok(Transformed::no(Expr::Literal(s, m)))
575                }
576                ConstSimplifyResult::SimplifyRuntimeError(err, expr) => {
577                    // For CAST expressions with literal inputs, propagate the error at plan time rather than deferring to execution time.
578                    // This provides clearer error messages and fails fast.
579                    if let Expr::Cast(Cast { ref expr, .. })
580                    | Expr::TryCast(TryCast { ref expr, .. }) = expr
581                        && matches!(expr.as_ref(), Expr::Literal(_, _))
582                    {
583                        return Err(err);
584                    }
585                    // For other expressions (like CASE, COALESCE), preserve the original
586                    // to allow short-circuit evaluation at execution time
587                    Ok(Transformed::yes(expr))
588                }
589            },
590            Some(false) => Ok(Transformed::no(expr)),
591            _ => internal_err!("Failed to pop can_evaluate"),
592        }
593    }
594}
595
596impl<'a> ConstEvaluator<'a> {
    /// Create a new `ConstEvaluator`. Session constants (such as
    /// the time for `now()`) are taken from the passed
    /// `execution_props`.
600    pub fn try_new(execution_props: &'a ExecutionProps) -> Result<Self> {
601        // The dummy column name is unused and doesn't matter as only
602        // expressions without column references can be evaluated
603        static DUMMY_COL_NAME: &str = ".";
604        let schema = Arc::new(Schema::new(vec![Field::new(
605            DUMMY_COL_NAME,
606            DataType::Null,
607            true,
608        )]));
609        let input_schema = DFSchema::try_from(Arc::clone(&schema))?;
610        // Need a single "input" row to produce a single output row
611        let col = new_null_array(&DataType::Null, 1);
612        let input_batch = RecordBatch::try_new(schema, vec![col])?;
613
614        Ok(Self {
615            can_evaluate: vec![],
616            execution_props,
617            input_schema,
618            input_batch,
619        })
620    }
621
622    /// Can a function of the specified volatility be evaluated?
623    fn volatility_ok(volatility: Volatility) -> bool {
624        match volatility {
625            Volatility::Immutable => true,
626            // Values for functions such as now() are taken from ExecutionProps
627            Volatility::Stable => true,
628            Volatility::Volatile => false,
629        }
630    }
631
632    /// Can the expression be evaluated at plan time, (assuming all of
633    /// its children can also be evaluated)?
634    fn can_evaluate(expr: &Expr) -> bool {
635        // check for reasons we can't evaluate this node
636        //
637        // NOTE all expr types are listed here so when new ones are
638        // added they can be checked for their ability to be evaluated
639        // at plan time
640        match expr {
641            // TODO: remove the next line after `Expr::Wildcard` is removed
642            #[expect(deprecated)]
643            Expr::AggregateFunction { .. }
644            | Expr::ScalarVariable(_, _)
645            | Expr::Column(_)
646            | Expr::OuterReferenceColumn(_, _)
647            | Expr::Exists { .. }
648            | Expr::InSubquery(_)
649            | Expr::ScalarSubquery(_)
650            | Expr::WindowFunction { .. }
651            | Expr::GroupingSet(_)
652            | Expr::Wildcard { .. }
653            | Expr::Placeholder(_) => false,
654            Expr::ScalarFunction(ScalarFunction { func, .. }) => {
655                Self::volatility_ok(func.signature().volatility)
656            }
657            Expr::Literal(_, _)
658            | Expr::Alias(..)
659            | Expr::Unnest(_)
660            | Expr::BinaryExpr { .. }
661            | Expr::Not(_)
662            | Expr::IsNotNull(_)
663            | Expr::IsNull(_)
664            | Expr::IsTrue(_)
665            | Expr::IsFalse(_)
666            | Expr::IsUnknown(_)
667            | Expr::IsNotTrue(_)
668            | Expr::IsNotFalse(_)
669            | Expr::IsNotUnknown(_)
670            | Expr::Negative(_)
671            | Expr::Between { .. }
672            | Expr::Like { .. }
673            | Expr::SimilarTo { .. }
674            | Expr::Case(_)
675            | Expr::Cast { .. }
676            | Expr::TryCast { .. }
677            | Expr::InList { .. } => true,
678        }
679    }
680
    /// Internal helper that evaluates an `Expr` to a scalar value
682    pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult {
683        if let Expr::Literal(s, m) = expr {
684            return ConstSimplifyResult::NotSimplified(s, m);
685        }
686
687        let phys_expr =
688            match create_physical_expr(&expr, &self.input_schema, self.execution_props) {
689                Ok(e) => e,
690                Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
691            };
692        let metadata = phys_expr
693            .return_field(self.input_batch.schema_ref())
694            .ok()
695            .and_then(|f| {
696                let m = f.metadata();
697                match m.is_empty() {
698                    true => None,
699                    false => Some(FieldMetadata::from(m)),
700                }
701            });
702        let col_val = match phys_expr.evaluate(&self.input_batch) {
703            Ok(v) => v,
704            Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
705        };
706        match col_val {
707            ColumnarValue::Array(a) => {
708                if a.len() != 1 {
709                    ConstSimplifyResult::SimplifyRuntimeError(
710                        exec_datafusion_err!(
711                            "Could not evaluate the expression, found a result of length {}",
712                            a.len()
713                        ),
714                        expr,
715                    )
716                } else if as_list_array(&a).is_ok() {
717                    ConstSimplifyResult::Simplified(
718                        ScalarValue::List(a.as_list::<i32>().to_owned().into()),
719                        metadata,
720                    )
721                } else if as_large_list_array(&a).is_ok() {
722                    ConstSimplifyResult::Simplified(
723                        ScalarValue::LargeList(a.as_list::<i64>().to_owned().into()),
724                        metadata,
725                    )
726                } else {
727                    // Non-ListArray
728                    match ScalarValue::try_from_array(&a, 0) {
729                        Ok(s) => ConstSimplifyResult::Simplified(s, metadata),
730                        Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr),
731                    }
732                }
733            }
734            ColumnarValue::Scalar(s) => ConstSimplifyResult::Simplified(s, metadata),
735        }
736    }
737}
738
739/// Simplifies [`Expr`]s by applying algebraic transformation rules
740///
741/// Example transformations that are applied:
742/// * `expr = true` and `expr != false` to `expr` when `expr` is of boolean type
743/// * `expr = false` and `expr != true` to `!expr` when `expr` is of boolean type
744/// * `true = true` and `false = false` to `true`
745/// * `false = true` and `true = false` to `false`
746/// * `!!expr` to `expr`
747/// * `expr = null` and `expr != null` to `null`
struct Simplifier<'a, S> {
    /// Borrowed information provider consulted by the rewrite rules (for
    /// example to check whether an expression is boolean, or to obtain its
    /// data type).
    info: &'a S,
}
751
752impl<'a, S> Simplifier<'a, S> {
753    pub fn new(info: &'a S) -> Self {
754        Self { info }
755    }
756}
757
758impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
759    type Node = Expr;
760
761    /// rewrite the expression simplifying any constant expressions
762    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
763        use datafusion_expr::Operator::{
764            And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor,
765            Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch,
766            RegexNotIMatch, RegexNotMatch,
767        };
768
769        let info = self.info;
770        Ok(match expr {
771            // `value op NULL` -> `NULL`
772            // `NULL op value` -> `NULL`
            // except for a few operators that can return a non-null value even when one of the operands is NULL
774            ref expr @ Expr::BinaryExpr(BinaryExpr {
775                ref left,
776                ref op,
777                ref right,
778            }) if op.returns_null_on_null()
779                && (is_null(left.as_ref()) || is_null(right.as_ref())) =>
780            {
781                Transformed::yes(Expr::Literal(
782                    ScalarValue::try_new_null(&info.get_data_type(expr)?)?,
783                    None,
784                ))
785            }
786
787            // `NULL {AND, OR} NULL` -> `NULL`
788            Expr::BinaryExpr(BinaryExpr {
789                left,
790                op: And | Or,
791                right,
792            }) if is_null(&left) && is_null(&right) => Transformed::yes(lit_bool_null()),
793
794            //
795            // Rules for Eq
796            //
797
798            // true = A  --> A
799            // false = A --> !A
800            // null = A --> null
801            Expr::BinaryExpr(BinaryExpr {
802                left,
803                op: Eq,
804                right,
805            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
806                Transformed::yes(match as_bool_lit(&left)? {
807                    Some(true) => *right,
808                    Some(false) => Expr::Not(right),
809                    None => lit_bool_null(),
810                })
811            }
812            // A = true  --> A
813            // A = false --> !A
814            // A = null --> null
815            Expr::BinaryExpr(BinaryExpr {
816                left,
817                op: Eq,
818                right,
819            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
820                Transformed::yes(match as_bool_lit(&right)? {
821                    Some(true) => *left,
822                    Some(false) => Expr::Not(left),
823                    None => lit_bool_null(),
824                })
825            }
826            // According to SQL's null semantics, NULL = NULL evaluates to NULL
827            // Both sides are the same expression (A = A) and A is non-volatile expression
828            // A = A --> A IS NOT NULL OR NULL
829            // A = A --> true (if A not nullable)
830            Expr::BinaryExpr(BinaryExpr {
831                left,
832                op: Eq,
833                right,
834            }) if (left == right) & !left.is_volatile() => {
835                Transformed::yes(match !info.nullable(&left)? {
836                    true => lit(true),
837                    false => Expr::BinaryExpr(BinaryExpr {
838                        left: Box::new(Expr::IsNotNull(left)),
839                        op: Or,
840                        right: Box::new(lit_bool_null()),
841                    }),
842                })
843            }
844
845            // Rules for NotEq
846            //
847
848            // true != A  --> !A
849            // false != A --> A
850            // null != A --> null
851            Expr::BinaryExpr(BinaryExpr {
852                left,
853                op: NotEq,
854                right,
855            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
856                Transformed::yes(match as_bool_lit(&left)? {
857                    Some(true) => Expr::Not(right),
858                    Some(false) => *right,
859                    None => lit_bool_null(),
860                })
861            }
862            // A != true  --> !A
863            // A != false --> A
864            // A != null --> null,
865            Expr::BinaryExpr(BinaryExpr {
866                left,
867                op: NotEq,
868                right,
869            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
870                Transformed::yes(match as_bool_lit(&right)? {
871                    Some(true) => Expr::Not(left),
872                    Some(false) => *left,
873                    None => lit_bool_null(),
874                })
875            }
876
877            //
878            // Rules for OR
879            //
880
881            // true OR A --> true (even if A is null)
882            Expr::BinaryExpr(BinaryExpr {
883                left,
884                op: Or,
885                right: _,
886            }) if is_true(&left) => Transformed::yes(*left),
887            // false OR A --> A
888            Expr::BinaryExpr(BinaryExpr {
889                left,
890                op: Or,
891                right,
892            }) if is_false(&left) => Transformed::yes(*right),
893            // A OR true --> true (even if A is null)
894            Expr::BinaryExpr(BinaryExpr {
895                left: _,
896                op: Or,
897                right,
898            }) if is_true(&right) => Transformed::yes(*right),
899            // A OR false --> A
900            Expr::BinaryExpr(BinaryExpr {
901                left,
902                op: Or,
903                right,
904            }) if is_false(&right) => Transformed::yes(*left),
905            // A OR !A ---> true (if A not nullable)
906            Expr::BinaryExpr(BinaryExpr {
907                left,
908                op: Or,
909                right,
910            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
911                Transformed::yes(lit(true))
912            }
913            // !A OR A ---> true (if A not nullable)
914            Expr::BinaryExpr(BinaryExpr {
915                left,
916                op: Or,
917                right,
918            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
919                Transformed::yes(lit(true))
920            }
921            // (..A..) OR A --> (..A..)
922            Expr::BinaryExpr(BinaryExpr {
923                left,
924                op: Or,
925                right,
926            }) if expr_contains(&left, &right, Or) => Transformed::yes(*left),
927            // A OR (..A..) --> (..A..)
928            Expr::BinaryExpr(BinaryExpr {
929                left,
930                op: Or,
931                right,
932            }) if expr_contains(&right, &left, Or) => Transformed::yes(*right),
933            // A OR (A AND B) --> A
934            Expr::BinaryExpr(BinaryExpr {
935                left,
936                op: Or,
937                right,
938            }) if is_op_with(And, &right, &left) => Transformed::yes(*left),
939            // (A AND B) OR A --> A
940            Expr::BinaryExpr(BinaryExpr {
941                left,
942                op: Or,
943                right,
944            }) if is_op_with(And, &left, &right) => Transformed::yes(*right),
945            // Eliminate common factors in conjunctions e.g
946            // (A AND B) OR (A AND C) -> A AND (B OR C)
947            Expr::BinaryExpr(BinaryExpr {
948                left,
949                op: Or,
950                right,
951            }) if has_common_conjunction(&left, &right) => {
952                let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect();
953                let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right)
954                    .partition(|e| lhs.contains(e) && !e.is_volatile());
955
956                let new_rhs = rhs.into_iter().reduce(and);
957                let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and);
958                let common_conjunction = common.into_iter().reduce(and).unwrap();
959
960                let new_expr = match (new_lhs, new_rhs) {
961                    (Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)),
962                    (_, _) => common_conjunction,
963                };
964                Transformed::yes(new_expr)
965            }
966
967            //
968            // Rules for AND
969            //
970
971            // true AND A --> A
972            Expr::BinaryExpr(BinaryExpr {
973                left,
974                op: And,
975                right,
976            }) if is_true(&left) => Transformed::yes(*right),
977            // false AND A --> false (even if A is null)
978            Expr::BinaryExpr(BinaryExpr {
979                left,
980                op: And,
981                right: _,
982            }) if is_false(&left) => Transformed::yes(*left),
983            // A AND true --> A
984            Expr::BinaryExpr(BinaryExpr {
985                left,
986                op: And,
987                right,
988            }) if is_true(&right) => Transformed::yes(*left),
989            // A AND false --> false (even if A is null)
990            Expr::BinaryExpr(BinaryExpr {
991                left: _,
992                op: And,
993                right,
994            }) if is_false(&right) => Transformed::yes(*right),
995            // A AND !A ---> false (if A not nullable)
996            Expr::BinaryExpr(BinaryExpr {
997                left,
998                op: And,
999                right,
1000            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
1001                Transformed::yes(lit(false))
1002            }
1003            // !A AND A ---> false (if A not nullable)
1004            Expr::BinaryExpr(BinaryExpr {
1005                left,
1006                op: And,
1007                right,
1008            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
1009                Transformed::yes(lit(false))
1010            }
1011            // (..A..) AND A --> (..A..)
1012            Expr::BinaryExpr(BinaryExpr {
1013                left,
1014                op: And,
1015                right,
1016            }) if expr_contains(&left, &right, And) => Transformed::yes(*left),
1017            // A AND (..A..) --> (..A..)
1018            Expr::BinaryExpr(BinaryExpr {
1019                left,
1020                op: And,
1021                right,
1022            }) if expr_contains(&right, &left, And) => Transformed::yes(*right),
1023            // A AND (A OR B) --> A
1024            Expr::BinaryExpr(BinaryExpr {
1025                left,
1026                op: And,
1027                right,
1028            }) if is_op_with(Or, &right, &left) => Transformed::yes(*left),
1029            // (A OR B) AND A --> A
1030            Expr::BinaryExpr(BinaryExpr {
1031                left,
1032                op: And,
1033                right,
1034            }) if is_op_with(Or, &left, &right) => Transformed::yes(*right),
            // A >= constant AND A <= constant --> A = constant
1036            Expr::BinaryExpr(BinaryExpr {
1037                left,
1038                op: And,
1039                right,
1040            }) if can_reduce_to_equal_statement(&left, &right) => {
1041                if let Expr::BinaryExpr(BinaryExpr {
1042                    left: left_left,
1043                    right: left_right,
1044                    ..
1045                }) = *left
1046                {
1047                    Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1048                        left: left_left,
1049                        op: Eq,
1050                        right: left_right,
1051                    }))
1052                } else {
1053                    return internal_err!(
1054                        "can_reduce_to_equal_statement should only be called with a BinaryExpr"
1055                    );
1056                }
1057            }
1058
1059            //
1060            // Rules for Multiply
1061            //
1062
1063            // A * 1 --> A (with type coercion if needed)
1064            Expr::BinaryExpr(BinaryExpr {
1065                left,
1066                op: Multiply,
1067                right,
1068            }) if is_one(&right) => {
1069                simplify_right_is_one_case(info, left, &Multiply, &right)?
1070            }
1071            // 1 * A --> A
1072            Expr::BinaryExpr(BinaryExpr {
1073                left,
1074                op: Multiply,
1075                right,
1076            }) if is_one(&left) => {
1077                // 1 * A is equivalent to A * 1
1078                simplify_right_is_one_case(info, right, &Multiply, &left)?
1079            }
1080
1081            // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN)
1082            Expr::BinaryExpr(BinaryExpr {
1083                left,
1084                op: Multiply,
1085                right,
1086            }) if !info.nullable(&left)?
1087                && !info.get_data_type(&left)?.is_floating()
1088                && is_zero(&right) =>
1089            {
1090                Transformed::yes(*right)
1091            }
1092            // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN)
1093            Expr::BinaryExpr(BinaryExpr {
1094                left,
1095                op: Multiply,
1096                right,
1097            }) if !info.nullable(&right)?
1098                && !info.get_data_type(&right)?.is_floating()
1099                && is_zero(&left) =>
1100            {
1101                Transformed::yes(*left)
1102            }
1103
1104            //
1105            // Rules for Divide
1106            //
1107
1108            // A / 1 --> A
1109            Expr::BinaryExpr(BinaryExpr {
1110                left,
1111                op: Divide,
1112                right,
1113            }) if is_one(&right) => {
1114                simplify_right_is_one_case(info, left, &Divide, &right)?
1115            }
1116
1117            //
1118            // Rules for Modulo
1119            //
1120
1121            // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN)
1122            Expr::BinaryExpr(BinaryExpr {
1123                left,
1124                op: Modulo,
1125                right,
1126            }) if !info.nullable(&left)?
1127                && !info.get_data_type(&left)?.is_floating()
1128                && is_one(&right) =>
1129            {
1130                Transformed::yes(Expr::Literal(
1131                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1132                    None,
1133                ))
1134            }
1135
1136            //
1137            // Rules for BitwiseAnd
1138            //
1139
1140            // A & 0 -> 0 (if A not nullable)
1141            Expr::BinaryExpr(BinaryExpr {
1142                left,
1143                op: BitwiseAnd,
1144                right,
1145            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right),
1146
1147            // 0 & A -> 0 (if A not nullable)
1148            Expr::BinaryExpr(BinaryExpr {
1149                left,
1150                op: BitwiseAnd,
1151                right,
1152            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left),
1153
1154            // !A & A -> 0 (if A not nullable)
1155            Expr::BinaryExpr(BinaryExpr {
1156                left,
1157                op: BitwiseAnd,
1158                right,
1159            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1160                Transformed::yes(Expr::Literal(
1161                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1162                    None,
1163                ))
1164            }
1165
1166            // A & !A -> 0 (if A not nullable)
1167            Expr::BinaryExpr(BinaryExpr {
1168                left,
1169                op: BitwiseAnd,
1170                right,
1171            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1172                Transformed::yes(Expr::Literal(
1173                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1174                    None,
1175                ))
1176            }
1177
1178            // (..A..) & A --> (..A..)
1179            Expr::BinaryExpr(BinaryExpr {
1180                left,
1181                op: BitwiseAnd,
1182                right,
1183            }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left),
1184
1185            // A & (..A..) --> (..A..)
1186            Expr::BinaryExpr(BinaryExpr {
1187                left,
1188                op: BitwiseAnd,
1189                right,
1190            }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right),
1191
1192            // A & (A | B) --> A (if B not null)
1193            Expr::BinaryExpr(BinaryExpr {
1194                left,
1195                op: BitwiseAnd,
1196                right,
1197            }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => {
1198                Transformed::yes(*left)
1199            }
1200
1201            // (A | B) & A --> A (if B not null)
1202            Expr::BinaryExpr(BinaryExpr {
1203                left,
1204                op: BitwiseAnd,
1205                right,
1206            }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => {
1207                Transformed::yes(*right)
1208            }
1209
1210            //
1211            // Rules for BitwiseOr
1212            //
1213
1214            // A | 0 -> A (even if A is null)
1215            Expr::BinaryExpr(BinaryExpr {
1216                left,
1217                op: BitwiseOr,
1218                right,
1219            }) if is_zero(&right) => Transformed::yes(*left),
1220
1221            // 0 | A -> A (even if A is null)
1222            Expr::BinaryExpr(BinaryExpr {
1223                left,
1224                op: BitwiseOr,
1225                right,
1226            }) if is_zero(&left) => Transformed::yes(*right),
1227
1228            // !A | A -> -1 (if A not nullable)
1229            Expr::BinaryExpr(BinaryExpr {
1230                left,
1231                op: BitwiseOr,
1232                right,
1233            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1234                Transformed::yes(Expr::Literal(
1235                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1236                    None,
1237                ))
1238            }
1239
1240            // A | !A -> -1 (if A not nullable)
1241            Expr::BinaryExpr(BinaryExpr {
1242                left,
1243                op: BitwiseOr,
1244                right,
1245            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1246                Transformed::yes(Expr::Literal(
1247                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1248                    None,
1249                ))
1250            }
1251
1252            // (..A..) | A --> (..A..)
1253            Expr::BinaryExpr(BinaryExpr {
1254                left,
1255                op: BitwiseOr,
1256                right,
1257            }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left),
1258
1259            // A | (..A..) --> (..A..)
1260            Expr::BinaryExpr(BinaryExpr {
1261                left,
1262                op: BitwiseOr,
1263                right,
1264            }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right),
1265
1266            // A | (A & B) --> A (if B not null)
1267            Expr::BinaryExpr(BinaryExpr {
1268                left,
1269                op: BitwiseOr,
1270                right,
1271            }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => {
1272                Transformed::yes(*left)
1273            }
1274
1275            // (A & B) | A --> A (if B not null)
1276            Expr::BinaryExpr(BinaryExpr {
1277                left,
1278                op: BitwiseOr,
1279                right,
1280            }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => {
1281                Transformed::yes(*right)
1282            }
1283
1284            //
1285            // Rules for BitwiseXor
1286            //
1287
1288            // A ^ 0 -> A (if A not nullable)
1289            Expr::BinaryExpr(BinaryExpr {
1290                left,
1291                op: BitwiseXor,
1292                right,
1293            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left),
1294
1295            // 0 ^ A -> A (if A not nullable)
1296            Expr::BinaryExpr(BinaryExpr {
1297                left,
1298                op: BitwiseXor,
1299                right,
1300            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right),
1301
1302            // !A ^ A -> -1 (if A not nullable)
1303            Expr::BinaryExpr(BinaryExpr {
1304                left,
1305                op: BitwiseXor,
1306                right,
1307            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1308                Transformed::yes(Expr::Literal(
1309                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1310                    None,
1311                ))
1312            }
1313
1314            // A ^ !A -> -1 (if A not nullable)
1315            Expr::BinaryExpr(BinaryExpr {
1316                left,
1317                op: BitwiseXor,
1318                right,
1319            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1320                Transformed::yes(Expr::Literal(
1321                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1322                    None,
1323                ))
1324            }
1325
1326            // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A)
1327            Expr::BinaryExpr(BinaryExpr {
1328                left,
1329                op: BitwiseXor,
1330                right,
1331            }) if expr_contains(&left, &right, BitwiseXor) => {
1332                let expr = delete_xor_in_complex_expr(&left, &right, false);
1333                Transformed::yes(if expr == *right {
1334                    Expr::Literal(
1335                        ScalarValue::new_zero(&info.get_data_type(&right)?)?,
1336                        None,
1337                    )
1338                } else {
1339                    expr
1340                })
1341            }
1342
1343            // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A)
1344            Expr::BinaryExpr(BinaryExpr {
1345                left,
1346                op: BitwiseXor,
1347                right,
1348            }) if expr_contains(&right, &left, BitwiseXor) => {
1349                let expr = delete_xor_in_complex_expr(&right, &left, true);
1350                Transformed::yes(if expr == *left {
1351                    Expr::Literal(
1352                        ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1353                        None,
1354                    )
1355                } else {
1356                    expr
1357                })
1358            }
1359
1360            //
1361            // Rules for BitwiseShiftRight
1362            //
1363
1364            // A >> 0 -> A (even if A is null)
1365            Expr::BinaryExpr(BinaryExpr {
1366                left,
1367                op: BitwiseShiftRight,
1368                right,
1369            }) if is_zero(&right) => Transformed::yes(*left),
1370
1371            //
            // Rules for BitwiseShiftLeft
1373            //
1374
1375            // A << 0 -> A (even if A is null)
1376            Expr::BinaryExpr(BinaryExpr {
1377                left,
1378                op: BitwiseShiftLeft,
1379                right,
1380            }) if is_zero(&right) => Transformed::yes(*left),
1381
1382            //
1383            // Rules for Not
1384            //
1385            Expr::Not(inner) => Transformed::yes(negate_clause(*inner)),
1386
1387            //
1388            // Rules for Negative
1389            //
1390            Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)),
1391
1392            //
1393            // Rules for Case
1394            //
1395
1396            // Inline a comparison to a literal with the case statement into the `THEN` clauses.
1397            // which can enable further simplifications
1398            // CASE WHEN X THEN "a" WHEN Y THEN "b" ... END = "a" --> CASE WHEN X THEN "a" = "a" WHEN Y THEN "b" = "a" END
1399            Expr::BinaryExpr(BinaryExpr {
1400                left,
1401                op: op @ (Eq | NotEq),
1402                right,
1403            }) if is_case_with_literal_outputs(&left) && is_lit(&right) => {
1404                let case = into_case(*left)?;
1405                Transformed::yes(Expr::Case(Case {
1406                    expr: None,
1407                    when_then_expr: case
1408                        .when_then_expr
1409                        .into_iter()
1410                        .map(|(when, then)| {
1411                            (
1412                                when,
1413                                Box::new(Expr::BinaryExpr(BinaryExpr {
1414                                    left: then,
1415                                    op,
1416                                    right: right.clone(),
1417                                })),
1418                            )
1419                        })
1420                        .collect(),
1421                    else_expr: case.else_expr.map(|els| {
1422                        Box::new(Expr::BinaryExpr(BinaryExpr {
1423                            left: els,
1424                            op,
1425                            right,
1426                        }))
1427                    }),
1428                }))
1429            }
1430
1431            // CASE WHEN true THEN A ... END --> A
1432            // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1433            // CASE WHEN false THEN A END --> NULL
1434            // CASE WHEN false THEN A ELSE B END --> B
1435            // CASE WHEN X THEN A WHEN false THEN B END --> CASE WHEN X THEN A ELSE B END
1436            Expr::Case(Case {
1437                expr: None,
1438                when_then_expr,
1439                mut else_expr,
1440            }) if when_then_expr
1441                .iter()
1442                .any(|(when, _)| is_true(when.as_ref()) || is_false(when.as_ref())) =>
1443            {
1444                let out_type = info.get_data_type(&when_then_expr[0].1)?;
1445                let mut new_when_then_expr = Vec::with_capacity(when_then_expr.len());
1446
1447                for (when, then) in when_then_expr.into_iter() {
1448                    if is_true(when.as_ref()) {
1449                        // Skip adding the rest of the when-then expressions after WHEN true
1450                        // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
1451                        else_expr = Some(then);
1452                        break;
1453                    } else if !is_false(when.as_ref()) {
1454                        new_when_then_expr.push((when, then));
1455                    }
1456                    // else: skip WHEN false cases
1457                }
1458
1459                // Exclude CASE statement altogether if there are no when-then expressions left
1460                if new_when_then_expr.is_empty() {
1461                    // CASE WHEN false THEN A ELSE B END --> B
1462                    if let Some(else_expr) = else_expr {
1463                        return Ok(Transformed::yes(*else_expr));
1464                    // CASE WHEN false THEN A END --> NULL
1465                    } else {
1466                        let null =
1467                            Expr::Literal(ScalarValue::try_new_null(&out_type)?, None);
1468                        return Ok(Transformed::yes(null));
1469                    }
1470                }
1471
1472                Transformed::yes(Expr::Case(Case {
1473                    expr: None,
1474                    when_then_expr: new_when_then_expr,
1475                    else_expr,
1476                }))
1477            }
1478
1479            // CASE
1480            //   WHEN X THEN A
1481            //   WHEN Y THEN B
1482            //   ...
1483            //   ELSE Q
1484            // END
1485            //
1486            // ---> (X AND A) OR (Y AND B AND NOT X) OR ... (NOT (X OR Y) AND Q)
1487            //
1488            // Note: the rationale for this rewrite is that the expr can then be further
1489            // simplified using the existing rules for AND/OR
1490            Expr::Case(Case {
1491                expr: None,
1492                when_then_expr,
1493                else_expr,
1494            }) if !when_then_expr.is_empty()
1495                // The rewrite is O(n²) in general so limit to small number of when-thens that can be true
1496                && (when_then_expr.len() < 3 // small number of input whens
1497                    // or all thens are literal bools and a small number of them are true
1498                    || (when_then_expr.iter().all(|(_, then)| is_bool_lit(then))
1499                        && when_then_expr.iter().filter(|(_, then)| is_true(then)).count() < 3))
1500                && info.is_boolean_type(&when_then_expr[0].1)? =>
1501            {
1502                // String disjunction of all the when predicates encountered so far. Not nullable.
1503                let mut filter_expr = lit(false);
1504                // The disjunction of all the cases
1505                let mut out_expr = lit(false);
1506
1507                for (when, then) in when_then_expr {
1508                    let when = is_exactly_true(*when, info)?;
1509                    let case_expr =
1510                        when.clone().and(filter_expr.clone().not()).and(*then);
1511
1512                    out_expr = out_expr.or(case_expr);
1513                    filter_expr = filter_expr.or(when);
1514                }
1515
1516                let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
1517                let case_expr = filter_expr.not().and(else_expr);
1518                out_expr = out_expr.or(case_expr);
1519
1520                // Do a first pass at simplification
1521                out_expr.rewrite(self)?
1522            }
1523            // CASE
1524            //   WHEN X THEN true
1525            //   WHEN Y THEN true
1526            //   WHEN Z THEN false
1527            //   ...
1528            //   ELSE true
1529            // END
1530            //
1531            // --->
1532            //
1533            // NOT(CASE
1534            //   WHEN X THEN false
1535            //   WHEN Y THEN false
1536            //   WHEN Z THEN true
1537            //   ...
1538            //   ELSE false
1539            // END)
1540            //
1541            // Note: the rationale for this rewrite is that the case can then be further
1542            // simplified into a small number of ANDs and ORs
1543            Expr::Case(Case {
1544                expr: None,
1545                when_then_expr,
1546                else_expr,
1547            }) if !when_then_expr.is_empty()
1548                && when_then_expr
1549                    .iter()
1550                    .all(|(_, then)| is_bool_lit(then)) // all thens are literal bools
1551                // This simplification is only helpful if we end up with a small number of true thens
1552                && when_then_expr
1553                    .iter()
1554                    .filter(|(_, then)| is_false(then))
1555                    .count()
1556                    < 3
1557                && else_expr.as_deref().is_none_or(is_bool_lit) =>
1558            {
1559                Transformed::yes(
1560                    Expr::Case(Case {
1561                        expr: None,
1562                        when_then_expr: when_then_expr
1563                            .into_iter()
1564                            .map(|(when, then)| (when, Box::new(Expr::Not(then))))
1565                            .collect(),
1566                        else_expr: else_expr
1567                            .map(|else_expr| Box::new(Expr::Not(else_expr))),
1568                    })
1569                    .not(),
1570                )
1571            }
1572            Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
1573                match udf.simplify(args, info)? {
1574                    ExprSimplifyResult::Original(args) => {
1575                        Transformed::no(Expr::ScalarFunction(ScalarFunction {
1576                            func: udf,
1577                            args,
1578                        }))
1579                    }
1580                    ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
1581                }
1582            }
1583
1584            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
1585                ref func,
1586                ..
1587            }) => match (func.simplify(), expr) {
1588                (Some(simplify_function), Expr::AggregateFunction(af)) => {
1589                    Transformed::yes(simplify_function(af, info)?)
1590                }
1591                (_, expr) => Transformed::no(expr),
1592            },
1593
1594            Expr::WindowFunction(ref window_fun) => match (window_fun.simplify(), expr) {
1595                (Some(simplify_function), Expr::WindowFunction(wf)) => {
1596                    Transformed::yes(simplify_function(*wf, info)?)
1597                }
1598                (_, expr) => Transformed::no(expr),
1599            },
1600
1601            //
1602            // Rules for Between
1603            //
1604
1605            // a between 3 and 5  -->  a >= 3 AND a <=5
1606            // a not between 3 and 5  -->  a < 3 OR a > 5
1607            Expr::Between(between) => Transformed::yes(if between.negated {
1608                let l = *between.expr.clone();
1609                let r = *between.expr;
1610                or(l.lt(*between.low), r.gt(*between.high))
1611            } else {
1612                and(
1613                    between.expr.clone().gt_eq(*between.low),
1614                    between.expr.lt_eq(*between.high),
1615                )
1616            }),
1617
1618            //
1619            // Rules for regexes
1620            //
1621            Expr::BinaryExpr(BinaryExpr {
1622                left,
1623                op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch),
1624                right,
1625            }) => Transformed::yes(simplify_regex_expr(left, op, right)?),
1626
1627            // Rules for Like
1628            Expr::Like(like) => {
1629                // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291
1630                let escape_char = like.escape_char.unwrap_or('\\');
1631                match as_string_scalar(&like.pattern) {
1632                    Some((data_type, pattern_str)) => {
1633                        match pattern_str {
1634                            None => return Ok(Transformed::yes(lit_bool_null())),
1635                            Some(pattern_str) if pattern_str == "%" => {
1636                                // exp LIKE '%' is
1637                                //   - when exp is not NULL, it's true
1638                                //   - when exp is NULL, it's NULL
1639                                // exp NOT LIKE '%' is
1640                                //   - when exp is not NULL, it's false
1641                                //   - when exp is NULL, it's NULL
1642                                let result_for_non_null = lit(!like.negated);
1643                                Transformed::yes(if !info.nullable(&like.expr)? {
1644                                    result_for_non_null
1645                                } else {
1646                                    Expr::Case(Case {
1647                                        expr: Some(Box::new(Expr::IsNotNull(like.expr))),
1648                                        when_then_expr: vec![(
1649                                            Box::new(lit(true)),
1650                                            Box::new(result_for_non_null),
1651                                        )],
1652                                        else_expr: None,
1653                                    })
1654                                })
1655                            }
1656                            Some(pattern_str)
1657                                if pattern_str.contains("%%")
1658                                    && !pattern_str.contains(escape_char) =>
1659                            {
1660                                // Repeated occurrences of wildcard are redundant so remove them
1661                                // exp LIKE '%%'  --> exp LIKE '%'
1662                                let simplified_pattern = Regex::new("%%+")
1663                                    .unwrap()
1664                                    .replace_all(pattern_str, "%")
1665                                    .to_string();
1666                                Transformed::yes(Expr::Like(Like {
1667                                    pattern: Box::new(to_string_scalar(
1668                                        &data_type,
1669                                        Some(simplified_pattern),
1670                                    )),
1671                                    ..like
1672                                }))
1673                            }
1674                            Some(pattern_str)
1675                                if !like.case_insensitive
1676                                    && !pattern_str
1677                                        .contains(['%', '_', escape_char].as_ref()) =>
1678                            {
1679                                // If the pattern does not contain any wildcards, we can simplify the like expression to an equality expression
1680                                // TODO: handle escape characters
1681                                Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1682                                    left: like.expr.clone(),
1683                                    op: if like.negated { NotEq } else { Eq },
1684                                    right: like.pattern.clone(),
1685                                }))
1686                            }
1687
1688                            Some(_pattern_str) => Transformed::no(Expr::Like(like)),
1689                        }
1690                    }
1691                    None => Transformed::no(Expr::Like(like)),
1692                }
1693            }
1694
1695            // a is not null/unknown --> true (if a is not nullable)
1696            Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr)
1697                if !info.nullable(&expr)? =>
1698            {
1699                Transformed::yes(lit(true))
1700            }
1701
1702            // a is null/unknown --> false (if a is not nullable)
1703            Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => {
1704                Transformed::yes(lit(false))
1705            }
1706
1707            // expr IN () --> false
1708            // expr NOT IN () --> true
1709            Expr::InList(InList {
1710                expr: _,
1711                list,
1712                negated,
1713            }) if list.is_empty() => Transformed::yes(lit(negated)),
1714
1715            // null in (x, y, z) --> null
1716            // null not in (x, y, z) --> null
1717            Expr::InList(InList {
1718                expr,
1719                list,
1720                negated: _,
1721            }) if is_null(expr.as_ref()) && !list.is_empty() => {
1722                Transformed::yes(lit_bool_null())
1723            }
1724
            // expr IN ((subquery)) -> expr IN (subquery), see #5529
1726            Expr::InList(InList {
1727                expr,
1728                mut list,
1729                negated,
1730            }) if list.len() == 1
1731                && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) =>
1732            {
1733                let Expr::ScalarSubquery(subquery) = list.remove(0) else {
1734                    unreachable!()
1735                };
1736
1737                Transformed::yes(Expr::InSubquery(InSubquery::new(
1738                    expr, subquery, negated,
1739                )))
1740            }
1741
1742            // Combine multiple OR expressions into a single IN list expression if possible
1743            //
1744            // i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)`
1745            Expr::BinaryExpr(BinaryExpr {
1746                left,
1747                op: Or,
1748                right,
1749            }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => {
1750                let lhs = to_inlist(*left).unwrap();
1751                let rhs = to_inlist(*right).unwrap();
1752                let mut seen: HashSet<Expr> = HashSet::new();
1753                let list = lhs
1754                    .list
1755                    .into_iter()
1756                    .chain(rhs.list)
1757                    .filter(|e| seen.insert(e.to_owned()))
1758                    .collect::<Vec<_>>();
1759
1760                let merged_inlist = InList {
1761                    expr: lhs.expr,
1762                    list,
1763                    negated: false,
1764                };
1765
1766                Transformed::yes(Expr::InList(merged_inlist))
1767            }
1768
            // Simplify expressions that are guaranteed to be true or false to a literal boolean expression
            //
            // Rules:
            // If both expressions are `IN` or both are `NOT IN`, then we can apply intersection or union on both lists
            //   Intersection:
            //     1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false`
            //     2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)`
            //     3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)`
            //   Union:
            //     4. `a not in (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)`
            //     # This rule is handled by `or_in_list_simplifier.rs`
            //     5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)`
            // If one of the expressions is `IN` and the other one is `NOT IN`, then we remove the `NOT IN` items from the `IN` list (set difference)
            //     6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false`
            //     7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5`
            //     8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)`
1785            Expr::BinaryExpr(BinaryExpr {
1786                left,
1787                op: And,
1788                right,
1789            }) if are_inlist_and_eq_and_match_neg(
1790                left.as_ref(),
1791                right.as_ref(),
1792                false,
1793                false,
1794            ) =>
1795            {
1796                match (*left, *right) {
1797                    (Expr::InList(l1), Expr::InList(l2)) => {
1798                        return inlist_intersection(l1, &l2, false).map(Transformed::yes);
1799                    }
1800                    // Matched previously once
1801                    _ => unreachable!(),
1802                }
1803            }
1804
1805            Expr::BinaryExpr(BinaryExpr {
1806                left,
1807                op: And,
1808                right,
1809            }) if are_inlist_and_eq_and_match_neg(
1810                left.as_ref(),
1811                right.as_ref(),
1812                true,
1813                true,
1814            ) =>
1815            {
1816                match (*left, *right) {
1817                    (Expr::InList(l1), Expr::InList(l2)) => {
1818                        return inlist_union(l1, l2, true).map(Transformed::yes);
1819                    }
1820                    // Matched previously once
1821                    _ => unreachable!(),
1822                }
1823            }
1824
1825            Expr::BinaryExpr(BinaryExpr {
1826                left,
1827                op: And,
1828                right,
1829            }) if are_inlist_and_eq_and_match_neg(
1830                left.as_ref(),
1831                right.as_ref(),
1832                false,
1833                true,
1834            ) =>
1835            {
1836                match (*left, *right) {
1837                    (Expr::InList(l1), Expr::InList(l2)) => {
1838                        return inlist_except(l1, &l2).map(Transformed::yes);
1839                    }
1840                    // Matched previously once
1841                    _ => unreachable!(),
1842                }
1843            }
1844
1845            Expr::BinaryExpr(BinaryExpr {
1846                left,
1847                op: And,
1848                right,
1849            }) if are_inlist_and_eq_and_match_neg(
1850                left.as_ref(),
1851                right.as_ref(),
1852                true,
1853                false,
1854            ) =>
1855            {
1856                match (*left, *right) {
1857                    (Expr::InList(l1), Expr::InList(l2)) => {
1858                        return inlist_except(l2, &l1).map(Transformed::yes);
1859                    }
1860                    // Matched previously once
1861                    _ => unreachable!(),
1862                }
1863            }
1864
1865            Expr::BinaryExpr(BinaryExpr {
1866                left,
1867                op: Or,
1868                right,
1869            }) if are_inlist_and_eq_and_match_neg(
1870                left.as_ref(),
1871                right.as_ref(),
1872                true,
1873                true,
1874            ) =>
1875            {
1876                match (*left, *right) {
1877                    (Expr::InList(l1), Expr::InList(l2)) => {
1878                        return inlist_intersection(l1, &l2, true).map(Transformed::yes);
1879                    }
1880                    // Matched previously once
1881                    _ => unreachable!(),
1882                }
1883            }
1884
1885            // =======================================
1886            // unwrap_cast_in_comparison
1887            // =======================================
1888            //
1889            // For case:
1890            // try_cast/cast(expr as data_type) op literal
1891            Expr::BinaryExpr(BinaryExpr { left, op, right })
1892                if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1893                    info, &left, op, &right,
1894                ) && op.supports_propagation() =>
1895            {
1896                unwrap_cast_in_comparison_for_binary(info, *left, *right, op)?
1897            }
1898            // literal op try_cast/cast(expr as data_type)
1899            // -->
1900            // try_cast/cast(expr as data_type) op_swap literal
1901            Expr::BinaryExpr(BinaryExpr { left, op, right })
1902                if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1903                    info, &right, op, &left,
1904                ) && op.supports_propagation()
1905                    && op.swap().is_some() =>
1906            {
1907                unwrap_cast_in_comparison_for_binary(
1908                    info,
1909                    *right,
1910                    *left,
1911                    op.swap().unwrap(),
1912                )?
1913            }
1914            // For case:
1915            // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
1916            Expr::InList(InList {
1917                expr: mut left,
1918                list,
1919                negated,
1920            }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
1921                info, &left, &list,
1922            ) =>
1923            {
1924                let (Expr::TryCast(TryCast {
1925                    expr: left_expr, ..
1926                })
1927                | Expr::Cast(Cast {
1928                    expr: left_expr, ..
1929                })) = left.as_mut()
1930                else {
1931                    return internal_err!("Expect cast expr, but got {:?}", left)?;
1932                };
1933
1934                let expr_type = info.get_data_type(left_expr)?;
1935                let right_exprs = list
1936                    .into_iter()
1937                    .map(|right| {
1938                        match right {
1939                            Expr::Literal(right_lit_value, _) => {
1940                                // if the right_lit_value can be casted to the type of internal_left_expr
1941                                // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
1942                                let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else {
1943                                    internal_err!(
1944                                        "Can't cast the list expr {:?} to type {}",
1945                                        right_lit_value, &expr_type
1946                                    )?
1947                                };
1948                                Ok(lit(value))
1949                            }
1950                            other_expr => internal_err!(
1951                                "Only support literal expr to optimize, but the expr is {:?}",
1952                                &other_expr
1953                            ),
1954                        }
1955                    })
1956                    .collect::<Result<Vec<_>>>()?;
1957
1958                Transformed::yes(Expr::InList(InList {
1959                    expr: std::mem::take(left_expr),
1960                    list: right_exprs,
1961                    negated,
1962                }))
1963            }
1964
1965            // no additional rewrites possible
1966            expr => Transformed::no(expr),
1967        })
1968    }
1969}
1970
1971fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option<String>)> {
1972    match expr {
1973        Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)),
1974        Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)),
1975        Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)),
1976        _ => None,
1977    }
1978}
1979
1980fn to_string_scalar(data_type: &DataType, value: Option<String>) -> Expr {
1981    match data_type {
1982        DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None),
1983        DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None),
1984        DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None),
1985        _ => unreachable!(),
1986    }
1987}
1988
1989fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
1990    let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect();
1991    iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
1992}
1993
1994// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
1995fn are_inlist_and_eq_and_match_neg(
1996    left: &Expr,
1997    right: &Expr,
1998    is_left_neg: bool,
1999    is_right_neg: bool,
2000) -> bool {
2001    match (left, right) {
2002        (Expr::InList(l), Expr::InList(r)) => {
2003            l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg
2004        }
2005        _ => false,
2006    }
2007}
2008
2009// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
2010fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool {
2011    let left = as_inlist(left);
2012    let right = as_inlist(right);
2013    if let (Some(lhs), Some(rhs)) = (left, right) {
2014        matches!(lhs.expr.as_ref(), Expr::Column(_))
2015            && matches!(rhs.expr.as_ref(), Expr::Column(_))
2016            && lhs.expr == rhs.expr
2017            && !lhs.negated
2018            && !rhs.negated
2019    } else {
2020        false
2021    }
2022}
2023
2024/// Try to convert an expression to an in-list expression
2025fn as_inlist(expr: &'_ Expr) -> Option<Cow<'_, InList>> {
2026    match expr {
2027        Expr::InList(inlist) => Some(Cow::Borrowed(inlist)),
2028        Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => {
2029            match (left.as_ref(), right.as_ref()) {
2030                (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList {
2031                    expr: left.clone(),
2032                    list: vec![*right.clone()],
2033                    negated: false,
2034                })),
2035                (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList {
2036                    expr: right.clone(),
2037                    list: vec![*left.clone()],
2038                    negated: false,
2039                })),
2040                _ => None,
2041            }
2042        }
2043        _ => None,
2044    }
2045}
2046
2047fn to_inlist(expr: Expr) -> Option<InList> {
2048    match expr {
2049        Expr::InList(inlist) => Some(inlist),
2050        Expr::BinaryExpr(BinaryExpr {
2051            left,
2052            op: Operator::Eq,
2053            right,
2054        }) => match (left.as_ref(), right.as_ref()) {
2055            (Expr::Column(_), Expr::Literal(_, _)) => Some(InList {
2056                expr: left,
2057                list: vec![*right],
2058                negated: false,
2059            }),
2060            (Expr::Literal(_, _), Expr::Column(_)) => Some(InList {
2061                expr: right,
2062                list: vec![*left],
2063                negated: false,
2064            }),
2065            _ => None,
2066        },
2067        _ => None,
2068    }
2069}
2070
2071/// Return the union of two inlist expressions
2072/// maintaining the order of the elements in the two lists
2073fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
2074    // extend the list in l1 with the elements in l2 that are not already in l1
2075    let l1_items: HashSet<_> = l1.list.iter().collect();
2076
2077    // keep all l2 items that do not also appear in l1
2078    let keep_l2: Vec<_> = l2
2079        .list
2080        .into_iter()
2081        .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) })
2082        .collect();
2083
2084    l1.list.extend(keep_l2);
2085    l1.negated = negated;
2086    Ok(Expr::InList(l1))
2087}
2088
2089/// Return the intersection of two inlist expressions
2090/// maintaining the order of the elements in the two lists
2091fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result<Expr> {
2092    let l2_items = l2.list.iter().collect::<HashSet<_>>();
2093
2094    // remove all items from l1 that are not in l2
2095    l1.list.retain(|e| l2_items.contains(e));
2096
2097    // e in () is always false
2098    // e not in () is always true
2099    if l1.list.is_empty() {
2100        return Ok(lit(negated));
2101    }
2102    Ok(Expr::InList(l1))
2103}
2104
2105/// Return the all items in l1 that are not in l2
2106/// maintaining the order of the elements in the two lists
2107fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
2108    let l2_items = l2.list.iter().collect::<HashSet<_>>();
2109
2110    // keep only items from l1 that are not in l2
2111    l1.list.retain(|e| !l2_items.contains(e));
2112
2113    if l1.list.is_empty() {
2114        return Ok(lit(false));
2115    }
2116    Ok(Expr::InList(l1))
2117}
2118
2119/// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL).
2120fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
2121    if !info.nullable(&expr)? {
2122        Ok(expr)
2123    } else {
2124        Ok(Expr::BinaryExpr(BinaryExpr {
2125            left: Box::new(expr),
2126            op: Operator::IsNotDistinctFrom,
2127            right: Box::new(lit(true)),
2128        }))
2129    }
2130}
2131
2132// A * 1 -> A
2133// A / 1 -> A
2134//
2135// Move this function body out of the large match branch avoid stack overflow
2136fn simplify_right_is_one_case<S: SimplifyInfo>(
2137    info: &S,
2138    left: Box<Expr>,
2139    op: &Operator,
2140    right: &Expr,
2141) -> Result<Transformed<Expr>> {
2142    // Check if resulting type would be different due to coercion
2143    let left_type = info.get_data_type(&left)?;
2144    let right_type = info.get_data_type(right)?;
2145    match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() {
2146        Ok(result_type) => {
2147            // Only cast if the types differ
2148            if left_type != result_type {
2149                Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type))))
2150            } else {
2151                Ok(Transformed::yes(*left))
2152            }
2153        }
2154        Err(_) => Ok(Transformed::yes(*left)),
2155    }
2156}
2157
2158#[cfg(test)]
2159mod tests {
2160    use super::*;
2161    use crate::simplify_expressions::SimplifyContext;
2162    use crate::test::test_table_scan_with_name;
2163    use arrow::datatypes::FieldRef;
2164    use datafusion_common::{DFSchemaRef, ToDFSchema, assert_contains};
2165    use datafusion_expr::{
2166        expr::WindowFunction,
2167        function::{
2168            AccumulatorArgs, AggregateFunctionSimplification,
2169            WindowFunctionSimplification,
2170        },
2171        interval_arithmetic::Interval,
2172        *,
2173    };
2174    use datafusion_functions_window_common::field::WindowUDFFieldArgs;
2175    use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
2176    use datafusion_physical_expr::PhysicalExpr;
2177    use std::hash::Hash;
2178    use std::sync::LazyLock;
2179    use std::{
2180        collections::HashMap,
2181        ops::{BitAnd, BitOr, BitXor},
2182        sync::Arc,
2183    };
2184
2185    // ------------------------------
2186    // --- ExprSimplifier tests -----
2187    // ------------------------------
2188    #[test]
2189    fn api_basic() {
2190        let props = ExecutionProps::new();
2191        let simplifier =
2192            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2193
2194        let expr = lit(1) + lit(2);
2195        let expected = lit(3);
2196        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2197    }
2198
2199    #[test]
2200    fn basic_coercion() {
2201        let schema = test_schema();
2202        let props = ExecutionProps::new();
2203        let simplifier = ExprSimplifier::new(
2204            SimplifyContext::new(&props).with_schema(Arc::clone(&schema)),
2205        );
2206
2207        // Note expr type is int32 (not int64)
2208        // (1i64 + 2i32) < i
2209        let expr = (lit(1i64) + lit(2i32)).lt(col("i"));
2210        // should fully simplify to 3 < i (though i has been coerced to i64)
2211        let expected = lit(3i64).lt(col("i"));
2212
2213        let expr = simplifier.coerce(expr, &schema).unwrap();
2214
2215        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2216    }
2217
2218    fn test_schema() -> DFSchemaRef {
2219        static TEST_SCHEMA: LazyLock<DFSchemaRef> = LazyLock::new(|| {
2220            Schema::new(vec![
2221                Field::new("i", DataType::Int64, false),
2222                Field::new("b", DataType::Boolean, true),
2223            ])
2224            .to_dfschema_ref()
2225            .unwrap()
2226        });
2227        Arc::clone(&TEST_SCHEMA)
2228    }
2229
2230    #[test]
2231    fn simplify_and_constant_prop() {
2232        let props = ExecutionProps::new();
2233        let simplifier =
2234            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2235
2236        // should be able to simplify to false
2237        // (i * (1 - 2)) > 0
2238        let expr = (col("i") * (lit(1) - lit(1))).gt(lit(0));
2239        let expected = lit(false);
2240        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2241    }
2242
2243    #[test]
2244    fn simplify_and_constant_prop_with_case() {
2245        let props = ExecutionProps::new();
2246        let simplifier =
2247            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2248
2249        //   CASE
2250        //     WHEN i>5 AND false THEN i > 5
2251        //     WHEN i<5 AND true THEN i < 5
2252        //     ELSE false
2253        //   END
2254        //
2255        // Can be simplified to `i < 5`
2256        let expr = when(col("i").gt(lit(5)).and(lit(false)), col("i").gt(lit(5)))
2257            .when(col("i").lt(lit(5)).and(lit(true)), col("i").lt(lit(5)))
2258            .otherwise(lit(false))
2259            .unwrap();
2260        let expected = col("i").lt(lit(5));
2261        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2262    }
2263
2264    // ------------------------------
2265    // --- Simplifier tests -----
2266    // ------------------------------
2267
2268    #[test]
2269    fn test_simplify_canonicalize() {
2270        {
2271            let expr = lit(1).lt(col("c2")).and(col("c2").gt(lit(1)));
2272            let expected = col("c2").gt(lit(1));
2273            assert_eq!(simplify(expr), expected);
2274        }
2275        {
2276            let expr = col("c1").lt(col("c2")).and(col("c2").gt(col("c1")));
2277            let expected = col("c2").gt(col("c1"));
2278            assert_eq!(simplify(expr), expected);
2279        }
2280        {
2281            let expr = col("c1")
2282                .eq(lit(1))
2283                .and(lit(1).eq(col("c1")))
2284                .and(col("c1").eq(lit(3)));
2285            let expected = col("c1").eq(lit(1)).and(col("c1").eq(lit(3)));
2286            assert_eq!(simplify(expr), expected);
2287        }
2288        {
2289            let expr = col("c1")
2290                .eq(col("c2"))
2291                .and(col("c1").gt(lit(5)))
2292                .and(col("c2").eq(col("c1")));
2293            let expected = col("c2").eq(col("c1")).and(col("c1").gt(lit(5)));
2294            assert_eq!(simplify(expr), expected);
2295        }
2296        {
2297            let expr = col("c1")
2298                .eq(lit(1))
2299                .and(col("c2").gt(lit(3)).or(lit(3).lt(col("c2"))));
2300            let expected = col("c1").eq(lit(1)).and(col("c2").gt(lit(3)));
2301            assert_eq!(simplify(expr), expected);
2302        }
2303        {
2304            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2305            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2306            assert_eq!(simplify(expr), expected);
2307        }
2308        {
2309            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2310            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2311            assert_eq!(simplify(expr), expected);
2312        }
2313        {
2314            let expr = col("c1").gt(col("c2")).and(col("c1").gt(col("c2")));
2315            let expected = col("c2").lt(col("c1"));
2316            assert_eq!(simplify(expr), expected);
2317        }
2318    }
2319
2320    #[test]
2321    fn test_simplify_eq_not_self() {
2322        // `expr_a`: column `c2` is nullable, so `c2 = c2` simplifies to `c2 IS NOT NULL OR NULL`
2323        // This ensures the expression is only true when `c2` is not NULL, accounting for SQL's NULL semantics.
2324        let expr_a = col("c2").eq(col("c2"));
2325        let expected_a = col("c2").is_not_null().or(lit_bool_null());
2326
2327        // `expr_b`: column `c2_non_null` is explicitly non-nullable, so `c2_non_null = c2_non_null` is always true
2328        let expr_b = col("c2_non_null").eq(col("c2_non_null"));
2329        let expected_b = lit(true);
2330
2331        assert_eq!(simplify(expr_a), expected_a);
2332        assert_eq!(simplify(expr_b), expected_b);
2333    }
2334
2335    #[test]
2336    fn test_simplify_or_true() {
2337        let expr_a = col("c2").or(lit(true));
2338        let expr_b = lit(true).or(col("c2"));
2339        let expected = lit(true);
2340
2341        assert_eq!(simplify(expr_a), expected);
2342        assert_eq!(simplify(expr_b), expected);
2343    }
2344
2345    #[test]
2346    fn test_simplify_or_false() {
2347        let expr_a = lit(false).or(col("c2"));
2348        let expr_b = col("c2").or(lit(false));
2349        let expected = col("c2");
2350
2351        assert_eq!(simplify(expr_a), expected);
2352        assert_eq!(simplify(expr_b), expected);
2353    }
2354
2355    #[test]
2356    fn test_simplify_or_same() {
2357        let expr = col("c2").or(col("c2"));
2358        let expected = col("c2");
2359
2360        assert_eq!(simplify(expr), expected);
2361    }
2362
2363    #[test]
2364    fn test_simplify_or_not_self() {
2365        // A OR !A if A is not nullable --> true
2366        // !A OR A if A is not nullable --> true
2367        let expr_a = col("c2_non_null").or(col("c2_non_null").not());
2368        let expr_b = col("c2_non_null").not().or(col("c2_non_null"));
2369        let expected = lit(true);
2370
2371        assert_eq!(simplify(expr_a), expected);
2372        assert_eq!(simplify(expr_b), expected);
2373    }
2374
2375    #[test]
2376    fn test_simplify_and_false() {
2377        let expr_a = lit(false).and(col("c2"));
2378        let expr_b = col("c2").and(lit(false));
2379        let expected = lit(false);
2380
2381        assert_eq!(simplify(expr_a), expected);
2382        assert_eq!(simplify(expr_b), expected);
2383    }
2384
2385    #[test]
2386    fn test_simplify_and_same() {
2387        let expr = col("c2").and(col("c2"));
2388        let expected = col("c2");
2389
2390        assert_eq!(simplify(expr), expected);
2391    }
2392
2393    #[test]
2394    fn test_simplify_and_true() {
2395        let expr_a = lit(true).and(col("c2"));
2396        let expr_b = col("c2").and(lit(true));
2397        let expected = col("c2");
2398
2399        assert_eq!(simplify(expr_a), expected);
2400        assert_eq!(simplify(expr_b), expected);
2401    }
2402
2403    #[test]
2404    fn test_simplify_and_not_self() {
2405        // A AND !A if A is not nullable --> false
2406        // !A AND A if A is not nullable --> false
2407        let expr_a = col("c2_non_null").and(col("c2_non_null").not());
2408        let expr_b = col("c2_non_null").not().and(col("c2_non_null"));
2409        let expected = lit(false);
2410
2411        assert_eq!(simplify(expr_a), expected);
2412        assert_eq!(simplify(expr_b), expected);
2413    }
2414
2415    #[test]
2416    fn test_simplify_multiply_by_one() {
2417        let expr_a = col("c2") * lit(1);
2418        let expr_b = lit(1) * col("c2");
2419        let expected = col("c2");
2420
2421        assert_eq!(simplify(expr_a), expected);
2422        assert_eq!(simplify(expr_b), expected);
2423
2424        let expr = col("c2") * lit(ScalarValue::Decimal128(Some(10000000000), 38, 10));
2425        assert_eq!(simplify(expr), expected);
2426
2427        let expr = lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)) * col("c2");
2428        assert_eq!(simplify(expr), expected);
2429    }
2430
2431    #[test]
2432    fn test_simplify_multiply_by_null() {
2433        let null = lit(ScalarValue::Int64(None));
2434        // A * null --> null
2435        {
2436            let expr = col("c3") * null.clone();
2437            assert_eq!(simplify(expr), null);
2438        }
2439        // null * A --> null
2440        {
2441            let expr = null.clone() * col("c3");
2442            assert_eq!(simplify(expr), null);
2443        }
2444    }
2445
2446    #[test]
2447    fn test_simplify_multiply_by_zero() {
2448        // cannot optimize A * null (null * A) if A is nullable
2449        {
2450            let expr_a = col("c2") * lit(0);
2451            let expr_b = lit(0) * col("c2");
2452
2453            assert_eq!(simplify(expr_a.clone()), expr_a);
2454            assert_eq!(simplify(expr_b.clone()), expr_b);
2455        }
2456        // 0 * A --> 0 if A is not nullable
2457        {
2458            let expr = lit(0) * col("c2_non_null");
2459            assert_eq!(simplify(expr), lit(0));
2460        }
2461        // A * 0 --> 0 if A is not nullable
2462        {
2463            let expr = col("c2_non_null") * lit(0);
2464            assert_eq!(simplify(expr), lit(0));
2465        }
2466        // A * Decimal128(0) --> 0 if A is not nullable
2467        {
2468            let expr = col("c2_non_null") * lit(ScalarValue::Decimal128(Some(0), 31, 10));
2469            assert_eq!(
2470                simplify(expr),
2471                lit(ScalarValue::Decimal128(Some(0), 31, 10))
2472            );
2473            let expr = binary_expr(
2474                lit(ScalarValue::Decimal128(Some(0), 31, 10)),
2475                Operator::Multiply,
2476                col("c2_non_null"),
2477            );
2478            assert_eq!(
2479                simplify(expr),
2480                lit(ScalarValue::Decimal128(Some(0), 31, 10))
2481            );
2482        }
2483    }
2484
2485    #[test]
2486    fn test_simplify_divide_by_one() {
2487        let expr = binary_expr(col("c2"), Operator::Divide, lit(1));
2488        let expected = col("c2");
2489        assert_eq!(simplify(expr), expected);
2490        let expr = col("c2") / lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2491        assert_eq!(simplify(expr), expected);
2492    }
2493
2494    #[test]
2495    fn test_simplify_divide_null() {
2496        // A / null --> null
2497        let null = lit(ScalarValue::Int64(None));
2498        {
2499            let expr = col("c3") / null.clone();
2500            assert_eq!(simplify(expr), null);
2501        }
2502        // null / A --> null
2503        {
2504            let expr = null.clone() / col("c3");
2505            assert_eq!(simplify(expr), null);
2506        }
2507    }
2508
2509    #[test]
2510    fn test_simplify_divide_by_same() {
2511        let expr = col("c2") / col("c2");
2512        // if c2 is null, c2 / c2 = null, so can't simplify
2513        let expected = expr.clone();
2514
2515        assert_eq!(simplify(expr), expected);
2516    }
2517
2518    #[test]
2519    fn test_simplify_modulo_by_null() {
2520        let null = lit(ScalarValue::Int64(None));
2521        // A % null --> null
2522        {
2523            let expr = col("c3") % null.clone();
2524            assert_eq!(simplify(expr), null);
2525        }
2526        // null % A --> null
2527        {
2528            let expr = null.clone() % col("c3");
2529            assert_eq!(simplify(expr), null);
2530        }
2531    }
2532
2533    #[test]
2534    fn test_simplify_modulo_by_one() {
2535        let expr = col("c2") % lit(1);
2536        // if c2 is null, c2 % 1 = null, so can't simplify
2537        let expected = expr.clone();
2538
2539        assert_eq!(simplify(expr), expected);
2540    }
2541
2542    #[test]
2543    fn test_simplify_divide_zero_by_zero() {
2544        // because divide by 0 maybe occur in short-circuit expression
2545        // so we should not simplify this, and throw error in runtime
2546        let expr = lit(0) / lit(0);
2547        let expected = expr.clone();
2548
2549        assert_eq!(simplify(expr), expected);
2550    }
2551
2552    #[test]
2553    fn test_simplify_divide_by_zero() {
2554        // because divide by 0 maybe occur in short-circuit expression
2555        // so we should not simplify this, and throw error in runtime
2556        let expr = col("c2_non_null") / lit(0);
2557        let expected = expr.clone();
2558
2559        assert_eq!(simplify(expr), expected);
2560    }
2561
2562    #[test]
2563    fn test_simplify_modulo_by_one_non_null() {
2564        let expr = col("c3_non_null") % lit(1);
2565        let expected = lit(0_i64);
2566        assert_eq!(simplify(expr), expected);
2567        let expr =
2568            col("c3_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2569        assert_eq!(simplify(expr), expected);
2570    }
2571
2572    #[test]
2573    fn test_simplify_bitwise_xor_by_null() {
2574        let null = lit(ScalarValue::Int64(None));
2575        // A ^ null --> null
2576        {
2577            let expr = col("c3") ^ null.clone();
2578            assert_eq!(simplify(expr), null);
2579        }
2580        // null ^ A --> null
2581        {
2582            let expr = null.clone() ^ col("c3");
2583            assert_eq!(simplify(expr), null);
2584        }
2585    }
2586
2587    #[test]
2588    fn test_simplify_bitwise_shift_right_by_null() {
2589        let null = lit(ScalarValue::Int64(None));
2590        // A >> null --> null
2591        {
2592            let expr = col("c3") >> null.clone();
2593            assert_eq!(simplify(expr), null);
2594        }
2595        // null >> A --> null
2596        {
2597            let expr = null.clone() >> col("c3");
2598            assert_eq!(simplify(expr), null);
2599        }
2600    }
2601
2602    #[test]
2603    fn test_simplify_bitwise_shift_left_by_null() {
2604        let null = lit(ScalarValue::Int64(None));
2605        // A << null --> null
2606        {
2607            let expr = col("c3") << null.clone();
2608            assert_eq!(simplify(expr), null);
2609        }
2610        // null << A --> null
2611        {
2612            let expr = null.clone() << col("c3");
2613            assert_eq!(simplify(expr), null);
2614        }
2615    }
2616
2617    #[test]
2618    fn test_simplify_bitwise_and_by_zero() {
2619        // A & 0 --> 0
2620        {
2621            let expr = col("c2_non_null") & lit(0);
2622            assert_eq!(simplify(expr), lit(0));
2623        }
2624        // 0 & A --> 0
2625        {
2626            let expr = lit(0) & col("c2_non_null");
2627            assert_eq!(simplify(expr), lit(0));
2628        }
2629    }
2630
2631    #[test]
2632    fn test_simplify_bitwise_or_by_zero() {
2633        // A | 0 --> A
2634        {
2635            let expr = col("c2_non_null") | lit(0);
2636            assert_eq!(simplify(expr), col("c2_non_null"));
2637        }
2638        // 0 | A --> A
2639        {
2640            let expr = lit(0) | col("c2_non_null");
2641            assert_eq!(simplify(expr), col("c2_non_null"));
2642        }
2643    }
2644
2645    #[test]
2646    fn test_simplify_bitwise_xor_by_zero() {
2647        // A ^ 0 --> A
2648        {
2649            let expr = col("c2_non_null") ^ lit(0);
2650            assert_eq!(simplify(expr), col("c2_non_null"));
2651        }
2652        // 0 ^ A --> A
2653        {
2654            let expr = lit(0) ^ col("c2_non_null");
2655            assert_eq!(simplify(expr), col("c2_non_null"));
2656        }
2657    }
2658
2659    #[test]
2660    fn test_simplify_bitwise_bitwise_shift_right_by_zero() {
2661        // A >> 0 --> A
2662        {
2663            let expr = col("c2_non_null") >> lit(0);
2664            assert_eq!(simplify(expr), col("c2_non_null"));
2665        }
2666    }
2667
2668    #[test]
2669    fn test_simplify_bitwise_bitwise_shift_left_by_zero() {
2670        // A << 0 --> A
2671        {
2672            let expr = col("c2_non_null") << lit(0);
2673            assert_eq!(simplify(expr), col("c2_non_null"));
2674        }
2675    }
2676
2677    #[test]
2678    fn test_simplify_bitwise_and_by_null() {
2679        let null = Expr::Literal(ScalarValue::Int64(None), None);
2680        // A & null --> null
2681        {
2682            let expr = col("c3") & null.clone();
2683            assert_eq!(simplify(expr), null);
2684        }
2685        // null & A --> null
2686        {
2687            let expr = null.clone() & col("c3");
2688            assert_eq!(simplify(expr), null);
2689        }
2690    }
2691
2692    #[test]
2693    fn test_simplify_composed_bitwise_and() {
2694        // ((c2 > 5) & (c1 < 6)) & (c2 > 5) --> (c2 > 5) & (c1 < 6)
2695
2696        let expr = bitwise_and(
2697            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2698            col("c2").gt(lit(5)),
2699        );
2700        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2701
2702        assert_eq!(simplify(expr), expected);
2703
2704        // (c2 > 5) & ((c2 > 5) & (c1 < 6)) --> (c2 > 5) & (c1 < 6)
2705
2706        let expr = bitwise_and(
2707            col("c2").gt(lit(5)),
2708            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2709        );
2710        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2711        assert_eq!(simplify(expr), expected);
2712    }
2713
2714    #[test]
2715    fn test_simplify_composed_bitwise_or() {
2716        // ((c2 > 5) | (c1 < 6)) | (c2 > 5) --> (c2 > 5) | (c1 < 6)
2717
2718        let expr = bitwise_or(
2719            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2720            col("c2").gt(lit(5)),
2721        );
2722        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2723
2724        assert_eq!(simplify(expr), expected);
2725
2726        // (c2 > 5) | ((c2 > 5) | (c1 < 6)) --> (c2 > 5) | (c1 < 6)
2727
2728        let expr = bitwise_or(
2729            col("c2").gt(lit(5)),
2730            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2731        );
2732        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2733
2734        assert_eq!(simplify(expr), expected);
2735    }
2736
2737    #[test]
2738    fn test_simplify_composed_bitwise_xor() {
2739        // with an even number of the column "c2"
2740        // c2 ^ ((c2 ^ (c2 | c1)) ^ (c1 & c2)) --> (c2 | c1) ^ (c1 & c2)
2741
2742        let expr = bitwise_xor(
2743            col("c2"),
2744            bitwise_xor(
2745                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2746                bitwise_and(col("c1"), col("c2")),
2747            ),
2748        );
2749
2750        let expected = bitwise_xor(
2751            bitwise_or(col("c2"), col("c1")),
2752            bitwise_and(col("c1"), col("c2")),
2753        );
2754
2755        assert_eq!(simplify(expr), expected);
2756
2757        // with an odd number of the column "c2"
2758        // c2 ^ (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) --> c2 ^ ((c2 | c1) ^ (c1 & c2))
2759
2760        let expr = bitwise_xor(
2761            col("c2"),
2762            bitwise_xor(
2763                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2764                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2765            ),
2766        );
2767
2768        let expected = bitwise_xor(
2769            col("c2"),
2770            bitwise_xor(
2771                bitwise_or(col("c2"), col("c1")),
2772                bitwise_and(col("c1"), col("c2")),
2773            ),
2774        );
2775
2776        assert_eq!(simplify(expr), expected);
2777
2778        // with an even number of the column "c2"
2779        // ((c2 ^ (c2 | c1)) ^ (c1 & c2)) ^ c2 --> (c2 | c1) ^ (c1 & c2)
2780
2781        let expr = bitwise_xor(
2782            bitwise_xor(
2783                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2784                bitwise_and(col("c1"), col("c2")),
2785            ),
2786            col("c2"),
2787        );
2788
2789        let expected = bitwise_xor(
2790            bitwise_or(col("c2"), col("c1")),
2791            bitwise_and(col("c1"), col("c2")),
2792        );
2793
2794        assert_eq!(simplify(expr), expected);
2795
2796        // with an odd number of the column "c2"
2797        // (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) ^ c2 --> ((c2 | c1) ^ (c1 & c2)) ^ c2
2798
2799        let expr = bitwise_xor(
2800            bitwise_xor(
2801                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2802                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2803            ),
2804            col("c2"),
2805        );
2806
2807        let expected = bitwise_xor(
2808            bitwise_xor(
2809                bitwise_or(col("c2"), col("c1")),
2810                bitwise_and(col("c1"), col("c2")),
2811            ),
2812            col("c2"),
2813        );
2814
2815        assert_eq!(simplify(expr), expected);
2816    }
2817
2818    #[test]
2819    fn test_simplify_negated_bitwise_and() {
2820        // !c4 & c4 --> 0
2821        let expr = (-col("c4_non_null")) & col("c4_non_null");
2822        let expected = lit(0u32);
2823
2824        assert_eq!(simplify(expr), expected);
2825        // c4 & !c4 --> 0
2826        let expr = col("c4_non_null") & (-col("c4_non_null"));
2827        let expected = lit(0u32);
2828
2829        assert_eq!(simplify(expr), expected);
2830
2831        // !c3 & c3 --> 0
2832        let expr = (-col("c3_non_null")) & col("c3_non_null");
2833        let expected = lit(0i64);
2834
2835        assert_eq!(simplify(expr), expected);
2836        // c3 & !c3 --> 0
2837        let expr = col("c3_non_null") & (-col("c3_non_null"));
2838        let expected = lit(0i64);
2839
2840        assert_eq!(simplify(expr), expected);
2841    }
2842
2843    #[test]
2844    fn test_simplify_negated_bitwise_or() {
2845        // !c4 | c4 --> -1
2846        let expr = (-col("c4_non_null")) | col("c4_non_null");
2847        let expected = lit(-1i32);
2848
2849        assert_eq!(simplify(expr), expected);
2850
2851        // c4 | !c4 --> -1
2852        let expr = col("c4_non_null") | (-col("c4_non_null"));
2853        let expected = lit(-1i32);
2854
2855        assert_eq!(simplify(expr), expected);
2856
2857        // !c3 | c3 --> -1
2858        let expr = (-col("c3_non_null")) | col("c3_non_null");
2859        let expected = lit(-1i64);
2860
2861        assert_eq!(simplify(expr), expected);
2862
2863        // c3 | !c3 --> -1
2864        let expr = col("c3_non_null") | (-col("c3_non_null"));
2865        let expected = lit(-1i64);
2866
2867        assert_eq!(simplify(expr), expected);
2868    }
2869
2870    #[test]
2871    fn test_simplify_negated_bitwise_xor() {
2872        // !c4 ^ c4 --> -1
2873        let expr = (-col("c4_non_null")) ^ col("c4_non_null");
2874        let expected = lit(-1i32);
2875
2876        assert_eq!(simplify(expr), expected);
2877
2878        // c4 ^ !c4 --> -1
2879        let expr = col("c4_non_null") ^ (-col("c4_non_null"));
2880        let expected = lit(-1i32);
2881
2882        assert_eq!(simplify(expr), expected);
2883
2884        // !c3 ^ c3 --> -1
2885        let expr = (-col("c3_non_null")) ^ col("c3_non_null");
2886        let expected = lit(-1i64);
2887
2888        assert_eq!(simplify(expr), expected);
2889
2890        // c3 ^ !c3 --> -1
2891        let expr = col("c3_non_null") ^ (-col("c3_non_null"));
2892        let expected = lit(-1i64);
2893
2894        assert_eq!(simplify(expr), expected);
2895    }
2896
2897    #[test]
2898    fn test_simplify_bitwise_and_or() {
2899        // (c2 < 3) & ((c2 < 3) | c1) -> (c2 < 3)
2900        let expr = bitwise_and(
2901            col("c2_non_null").lt(lit(3)),
2902            bitwise_or(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2903        );
2904        let expected = col("c2_non_null").lt(lit(3));
2905
2906        assert_eq!(simplify(expr), expected);
2907    }
2908
2909    #[test]
2910    fn test_simplify_bitwise_or_and() {
2911        // (c2 < 3) | ((c2 < 3) & c1) -> (c2 < 3)
2912        let expr = bitwise_or(
2913            col("c2_non_null").lt(lit(3)),
2914            bitwise_and(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2915        );
2916        let expected = col("c2_non_null").lt(lit(3));
2917
2918        assert_eq!(simplify(expr), expected);
2919    }
2920
2921    #[test]
2922    fn test_simplify_simple_bitwise_and() {
2923        // (c2 > 5) & (c2 > 5) -> (c2 > 5)
2924        let expr = (col("c2").gt(lit(5))).bitand(col("c2").gt(lit(5)));
2925        let expected = col("c2").gt(lit(5));
2926
2927        assert_eq!(simplify(expr), expected);
2928    }
2929
2930    #[test]
2931    fn test_simplify_simple_bitwise_or() {
2932        // (c2 > 5) | (c2 > 5) -> (c2 > 5)
2933        let expr = (col("c2").gt(lit(5))).bitor(col("c2").gt(lit(5)));
2934        let expected = col("c2").gt(lit(5));
2935
2936        assert_eq!(simplify(expr), expected);
2937    }
2938
2939    #[test]
2940    fn test_simplify_simple_bitwise_xor() {
2941        // c4 ^ c4 -> 0
2942        let expr = (col("c4")).bitxor(col("c4"));
2943        let expected = lit(0u32);
2944
2945        assert_eq!(simplify(expr), expected);
2946
2947        // c3 ^ c3 -> 0
2948        let expr = col("c3").bitxor(col("c3"));
2949        let expected = lit(0i64);
2950
2951        assert_eq!(simplify(expr), expected);
2952    }
2953
2954    #[test]
2955    fn test_simplify_modulo_by_zero_non_null() {
2956        // because modulo by 0 maybe occur in short-circuit expression
2957        // so we should not simplify this, and throw error in runtime.
2958        let expr = col("c2_non_null") % lit(0);
2959        let expected = expr.clone();
2960
2961        assert_eq!(simplify(expr), expected);
2962    }
2963
2964    #[test]
2965    fn test_simplify_simple_and() {
2966        // (c2 > 5) AND (c2 > 5) -> (c2 > 5)
2967        let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5)));
2968        let expected = col("c2").gt(lit(5));
2969
2970        assert_eq!(simplify(expr), expected);
2971    }
2972
2973    #[test]
2974    fn test_simplify_composed_and() {
2975        // ((c2 > 5) AND (c1 < 6)) AND (c2 > 5)
2976        let expr = and(
2977            and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2978            col("c2").gt(lit(5)),
2979        );
2980        let expected = and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2981
2982        assert_eq!(simplify(expr), expected);
2983    }
2984
2985    #[test]
2986    fn test_simplify_negated_and() {
2987        // (c2 > 5) AND !(c2 > 5) --> (c2 > 5) AND (c2 <= 5)
2988        let expr = and(col("c2").gt(lit(5)), Expr::not(col("c2").gt(lit(5))));
2989        let expected = col("c2").gt(lit(5)).and(col("c2").lt_eq(lit(5)));
2990
2991        assert_eq!(simplify(expr), expected);
2992    }
2993
2994    #[test]
2995    fn test_simplify_or_and() {
2996        let l = col("c2").gt(lit(5));
2997        let r = and(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
2998
2999        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5))
3000        let expr = or(l.clone(), r.clone());
3001
3002        let expected = l.clone();
3003        assert_eq!(simplify(expr), expected);
3004
3005        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5)
3006        let expr = or(r, l);
3007        assert_eq!(simplify(expr), expected);
3008    }
3009
3010    #[test]
3011    fn test_simplify_or_and_non_null() {
3012        let l = col("c2_non_null").gt(lit(5));
3013        let r = and(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
3014
3015        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) --> c2 > 5
3016        let expr = or(l.clone(), r.clone());
3017
3018        // This is only true if `c1 < 6` is not nullable / can not be null.
3019        let expected = col("c2_non_null").gt(lit(5));
3020
3021        assert_eq!(simplify(expr), expected);
3022
3023        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) --> c2 > 5
3024        let expr = or(l, r);
3025
3026        assert_eq!(simplify(expr), expected);
3027    }
3028
3029    #[test]
3030    fn test_simplify_and_or() {
3031        let l = col("c2").gt(lit(5));
3032        let r = or(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
3033
3034        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
3035        let expr = and(l.clone(), r.clone());
3036
3037        let expected = l.clone();
3038        assert_eq!(simplify(expr), expected);
3039
3040        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
3041        let expr = and(r, l);
3042        assert_eq!(simplify(expr), expected);
3043    }
3044
3045    #[test]
3046    fn test_simplify_and_or_non_null() {
3047        let l = col("c2_non_null").gt(lit(5));
3048        let r = or(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
3049
3050        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
3051        let expr = and(l.clone(), r.clone());
3052
3053        // This is only true if `c1 < 6` is not nullable / can not be null.
3054        let expected = col("c2_non_null").gt(lit(5));
3055
3056        assert_eq!(simplify(expr), expected);
3057
3058        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
3059        let expr = and(l, r);
3060
3061        assert_eq!(simplify(expr), expected);
3062    }
3063
3064    #[test]
3065    fn test_simplify_by_de_morgan_laws() {
3066        // Laws with logical operations
3067        // !(c3 AND c4) --> !c3 OR !c4
3068        let expr = and(col("c3"), col("c4")).not();
3069        let expected = or(col("c3").not(), col("c4").not());
3070        assert_eq!(simplify(expr), expected);
3071        // !(c3 OR c4) --> !c3 AND !c4
3072        let expr = or(col("c3"), col("c4")).not();
3073        let expected = and(col("c3").not(), col("c4").not());
3074        assert_eq!(simplify(expr), expected);
3075        // !(!c3) --> c3
3076        let expr = col("c3").not().not();
3077        let expected = col("c3");
3078        assert_eq!(simplify(expr), expected);
3079
3080        // Laws with bitwise operations
3081        // !(c3 & c4) --> !c3 | !c4
3082        let expr = -bitwise_and(col("c3"), col("c4"));
3083        let expected = bitwise_or(-col("c3"), -col("c4"));
3084        assert_eq!(simplify(expr), expected);
3085        // !(c3 | c4) --> !c3 & !c4
3086        let expr = -bitwise_or(col("c3"), col("c4"));
3087        let expected = bitwise_and(-col("c3"), -col("c4"));
3088        assert_eq!(simplify(expr), expected);
3089        // !(!c3) --> c3
3090        let expr = -(-col("c3"));
3091        let expected = col("c3");
3092        assert_eq!(simplify(expr), expected);
3093    }
3094
3095    #[test]
3096    fn test_simplify_null_and_false() {
3097        let expr = and(lit_bool_null(), lit(false));
3098        let expr_eq = lit(false);
3099
3100        assert_eq!(simplify(expr), expr_eq);
3101    }
3102
3103    #[test]
3104    fn test_simplify_divide_null_by_null() {
3105        let null = lit(ScalarValue::Int32(None));
3106        let expr_plus = null.clone() / null.clone();
3107        let expr_eq = null;
3108
3109        assert_eq!(simplify(expr_plus), expr_eq);
3110    }
3111
3112    #[test]
3113    fn test_simplify_simplify_arithmetic_expr() {
3114        let expr_plus = lit(1) + lit(1);
3115
3116        assert_eq!(simplify(expr_plus), lit(2));
3117    }
3118
3119    #[test]
3120    fn test_simplify_simplify_eq_expr() {
3121        let expr_eq = binary_expr(lit(1), Operator::Eq, lit(1));
3122
3123        assert_eq!(simplify(expr_eq), lit(true));
3124    }
3125
3126    #[test]
3127    fn test_simplify_regex() {
3128        // malformed regex
3129        assert_contains!(
3130            try_simplify(regex_match(col("c1"), lit("foo{")))
3131                .unwrap_err()
3132                .to_string(),
3133            "regex parse error"
3134        );
3135
3136        // unsupported cases
3137        assert_no_change(regex_match(col("c1"), lit("foo.*")));
3138        assert_no_change(regex_match(col("c1"), lit("(foo)")));
3139        assert_no_change(regex_match(col("c1"), lit("%")));
3140        assert_no_change(regex_match(col("c1"), lit("_")));
3141        assert_no_change(regex_match(col("c1"), lit("f%o")));
3142        assert_no_change(regex_match(col("c1"), lit("^f%o")));
3143        assert_no_change(regex_match(col("c1"), lit("f_o")));
3144
3145        // empty cases
3146        assert_change(
3147            regex_match(col("c1"), lit("")),
3148            if_not_null(col("c1"), true),
3149        );
3150        assert_change(
3151            regex_not_match(col("c1"), lit("")),
3152            if_not_null(col("c1"), false),
3153        );
3154        assert_change(
3155            regex_imatch(col("c1"), lit("")),
3156            if_not_null(col("c1"), true),
3157        );
3158        assert_change(
3159            regex_not_imatch(col("c1"), lit("")),
3160            if_not_null(col("c1"), false),
3161        );
3162
3163        // single character
3164        assert_change(regex_match(col("c1"), lit("x")), col("c1").like(lit("%x%")));
3165
3166        // single word
3167        assert_change(
3168            regex_match(col("c1"), lit("foo")),
3169            col("c1").like(lit("%foo%")),
3170        );
3171
3172        // regular expressions that match an exact literal
3173        assert_change(regex_match(col("c1"), lit("^$")), col("c1").eq(lit("")));
3174        assert_change(
3175            regex_not_match(col("c1"), lit("^$")),
3176            col("c1").not_eq(lit("")),
3177        );
3178        assert_change(
3179            regex_match(col("c1"), lit("^foo$")),
3180            col("c1").eq(lit("foo")),
3181        );
3182        assert_change(
3183            regex_not_match(col("c1"), lit("^foo$")),
3184            col("c1").not_eq(lit("foo")),
3185        );
3186
3187        // regular expressions that match exact captured literals
3188        assert_change(
3189            regex_match(col("c1"), lit("^(foo|bar)$")),
3190            col("c1").eq(lit("foo")).or(col("c1").eq(lit("bar"))),
3191        );
3192        assert_change(
3193            regex_not_match(col("c1"), lit("^(foo|bar)$")),
3194            col("c1")
3195                .not_eq(lit("foo"))
3196                .and(col("c1").not_eq(lit("bar"))),
3197        );
3198        assert_change(
3199            regex_match(col("c1"), lit("^(foo)$")),
3200            col("c1").eq(lit("foo")),
3201        );
3202        assert_change(
3203            regex_match(col("c1"), lit("^(foo|bar|baz)$")),
3204            ((col("c1").eq(lit("foo"))).or(col("c1").eq(lit("bar"))))
3205                .or(col("c1").eq(lit("baz"))),
3206        );
3207        assert_change(
3208            regex_match(col("c1"), lit("^(foo|bar|baz|qux)$")),
3209            col("c1")
3210                .in_list(vec![lit("foo"), lit("bar"), lit("baz"), lit("qux")], false),
3211        );
3212        assert_change(
3213            regex_match(col("c1"), lit("^(fo_o)$")),
3214            col("c1").eq(lit("fo_o")),
3215        );
3216        assert_change(
3217            regex_match(col("c1"), lit("^(fo_o)$")),
3218            col("c1").eq(lit("fo_o")),
3219        );
3220        assert_change(
3221            regex_match(col("c1"), lit("^(fo_o|ba_r)$")),
3222            col("c1").eq(lit("fo_o")).or(col("c1").eq(lit("ba_r"))),
3223        );
3224        assert_change(
3225            regex_not_match(col("c1"), lit("^(fo_o|ba_r)$")),
3226            col("c1")
3227                .not_eq(lit("fo_o"))
3228                .and(col("c1").not_eq(lit("ba_r"))),
3229        );
3230        assert_change(
3231            regex_match(col("c1"), lit("^(fo_o|ba_r|ba_z)$")),
3232            ((col("c1").eq(lit("fo_o"))).or(col("c1").eq(lit("ba_r"))))
3233                .or(col("c1").eq(lit("ba_z"))),
3234        );
3235        assert_change(
3236            regex_match(col("c1"), lit("^(fo_o|ba_r|baz|qu_x)$")),
3237            col("c1").in_list(
3238                vec![lit("fo_o"), lit("ba_r"), lit("baz"), lit("qu_x")],
3239                false,
3240            ),
3241        );
3242
3243        // regular expressions that mismatch captured literals
3244        assert_no_change(regex_match(col("c1"), lit("(foo|bar)")));
3245        assert_no_change(regex_match(col("c1"), lit("(foo|bar)*")));
3246        assert_no_change(regex_match(col("c1"), lit("(fo_o|b_ar)")));
3247        assert_no_change(regex_match(col("c1"), lit("(foo|ba_r)*")));
3248        assert_no_change(regex_match(col("c1"), lit("(fo_o|ba_r)*")));
3249        assert_no_change(regex_match(col("c1"), lit("^(foo|bar)*")));
3250        assert_no_change(regex_match(col("c1"), lit("^(foo)(bar)$")));
3251        assert_no_change(regex_match(col("c1"), lit("^")));
3252        assert_no_change(regex_match(col("c1"), lit("$")));
3253        assert_no_change(regex_match(col("c1"), lit("$^")));
3254        assert_no_change(regex_match(col("c1"), lit("$foo^")));
3255
3256        // regular expressions that match a partial literal
3257        assert_change(
3258            regex_match(col("c1"), lit("^foo")),
3259            col("c1").like(lit("foo%")),
3260        );
3261        assert_change(
3262            regex_match(col("c1"), lit("foo$")),
3263            col("c1").like(lit("%foo")),
3264        );
3265        assert_change(
3266            regex_match(col("c1"), lit("^foo|bar$")),
3267            col("c1").like(lit("foo%")).or(col("c1").like(lit("%bar"))),
3268        );
3269
3270        // OR-chain
3271        assert_change(
3272            regex_match(col("c1"), lit("foo|bar|baz")),
3273            col("c1")
3274                .like(lit("%foo%"))
3275                .or(col("c1").like(lit("%bar%")))
3276                .or(col("c1").like(lit("%baz%"))),
3277        );
3278        assert_change(
3279            regex_match(col("c1"), lit("foo|x|baz")),
3280            col("c1")
3281                .like(lit("%foo%"))
3282                .or(col("c1").like(lit("%x%")))
3283                .or(col("c1").like(lit("%baz%"))),
3284        );
3285        assert_change(
3286            regex_not_match(col("c1"), lit("foo|bar|baz")),
3287            col("c1")
3288                .not_like(lit("%foo%"))
3289                .and(col("c1").not_like(lit("%bar%")))
3290                .and(col("c1").not_like(lit("%baz%"))),
3291        );
3292        // both anchored expressions (translated to equality) and unanchored
3293        assert_change(
3294            regex_match(col("c1"), lit("foo|^x$|baz")),
3295            col("c1")
3296                .like(lit("%foo%"))
3297                .or(col("c1").eq(lit("x")))
3298                .or(col("c1").like(lit("%baz%"))),
3299        );
3300        assert_change(
3301            regex_not_match(col("c1"), lit("foo|^bar$|baz")),
3302            col("c1")
3303                .not_like(lit("%foo%"))
3304                .and(col("c1").not_eq(lit("bar")))
3305                .and(col("c1").not_like(lit("%baz%"))),
3306        );
3307        // Too many patterns (MAX_REGEX_ALTERNATIONS_EXPANSION)
3308        assert_no_change(regex_match(col("c1"), lit("foo|bar|baz|blarg|bozo|etc")));
3309    }
3310
3311    #[track_caller]
3312    fn assert_no_change(expr: Expr) {
3313        let optimized = simplify(expr.clone());
3314        assert_eq!(expr, optimized);
3315    }
3316
3317    #[track_caller]
3318    fn assert_change(expr: Expr, expected: Expr) {
3319        let optimized = simplify(expr);
3320        assert_eq!(optimized, expected);
3321    }
3322
3323    fn regex_match(left: Expr, right: Expr) -> Expr {
3324        Expr::BinaryExpr(BinaryExpr {
3325            left: Box::new(left),
3326            op: Operator::RegexMatch,
3327            right: Box::new(right),
3328        })
3329    }
3330
3331    fn regex_not_match(left: Expr, right: Expr) -> Expr {
3332        Expr::BinaryExpr(BinaryExpr {
3333            left: Box::new(left),
3334            op: Operator::RegexNotMatch,
3335            right: Box::new(right),
3336        })
3337    }
3338
3339    fn regex_imatch(left: Expr, right: Expr) -> Expr {
3340        Expr::BinaryExpr(BinaryExpr {
3341            left: Box::new(left),
3342            op: Operator::RegexIMatch,
3343            right: Box::new(right),
3344        })
3345    }
3346
3347    fn regex_not_imatch(left: Expr, right: Expr) -> Expr {
3348        Expr::BinaryExpr(BinaryExpr {
3349            left: Box::new(left),
3350            op: Operator::RegexNotIMatch,
3351            right: Box::new(right),
3352        })
3353    }
3354
3355    // ------------------------------
3356    // ----- Simplifier tests -------
3357    // ------------------------------
3358
3359    fn try_simplify(expr: Expr) -> Result<Expr> {
3360        let schema = expr_test_schema();
3361        let execution_props = ExecutionProps::new();
3362        let simplifier = ExprSimplifier::new(
3363            SimplifyContext::new(&execution_props).with_schema(schema),
3364        );
3365        simplifier.simplify(expr)
3366    }
3367
3368    fn coerce(expr: Expr) -> Expr {
3369        let schema = expr_test_schema();
3370        let execution_props = ExecutionProps::new();
3371        let simplifier = ExprSimplifier::new(
3372            SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)),
3373        );
3374        simplifier.coerce(expr, schema.as_ref()).unwrap()
3375    }
3376
3377    fn simplify(expr: Expr) -> Expr {
3378        try_simplify(expr).unwrap()
3379    }
3380
3381    fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
3382        let schema = expr_test_schema();
3383        let execution_props = ExecutionProps::new();
3384        let simplifier = ExprSimplifier::new(
3385            SimplifyContext::new(&execution_props).with_schema(schema),
3386        );
3387        let (expr, count) = simplifier.simplify_with_cycle_count_transformed(expr)?;
3388        Ok((expr.data, count))
3389    }
3390
3391    fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
3392        try_simplify_with_cycle_count(expr).unwrap()
3393    }
3394
3395    fn simplify_with_guarantee(
3396        expr: Expr,
3397        guarantees: Vec<(Expr, NullableInterval)>,
3398    ) -> Expr {
3399        let schema = expr_test_schema();
3400        let execution_props = ExecutionProps::new();
3401        let simplifier = ExprSimplifier::new(
3402            SimplifyContext::new(&execution_props).with_schema(schema),
3403        )
3404        .with_guarantees(guarantees);
3405        simplifier.simplify(expr).unwrap()
3406    }
3407
3408    fn expr_test_schema() -> DFSchemaRef {
3409        static EXPR_TEST_SCHEMA: LazyLock<DFSchemaRef> = LazyLock::new(|| {
3410            Arc::new(
3411                DFSchema::from_unqualified_fields(
3412                    vec![
3413                        Field::new("c1", DataType::Utf8, true),
3414                        Field::new("c2", DataType::Boolean, true),
3415                        Field::new("c3", DataType::Int64, true),
3416                        Field::new("c4", DataType::UInt32, true),
3417                        Field::new("c1_non_null", DataType::Utf8, false),
3418                        Field::new("c2_non_null", DataType::Boolean, false),
3419                        Field::new("c3_non_null", DataType::Int64, false),
3420                        Field::new("c4_non_null", DataType::UInt32, false),
3421                        Field::new("c5", DataType::FixedSizeBinary(3), true),
3422                    ]
3423                    .into(),
3424                    HashMap::new(),
3425                )
3426                .unwrap(),
3427            )
3428        });
3429        Arc::clone(&EXPR_TEST_SCHEMA)
3430    }
3431
3432    #[test]
3433    fn simplify_expr_null_comparison() {
3434        // x = null is always null
3435        assert_eq!(
3436            simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))),
3437            lit(ScalarValue::Boolean(None)),
3438        );
3439
3440        // null != null is always null
3441        assert_eq!(
3442            simplify(
3443                lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))
3444            ),
3445            lit(ScalarValue::Boolean(None)),
3446        );
3447
3448        // x != null is always null
3449        assert_eq!(
3450            simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))),
3451            lit(ScalarValue::Boolean(None)),
3452        );
3453
3454        // null = x is always null
3455        assert_eq!(
3456            simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))),
3457            lit(ScalarValue::Boolean(None)),
3458        );
3459    }
3460
3461    #[test]
3462    fn simplify_expr_is_not_null() {
3463        assert_eq!(
3464            simplify(Expr::IsNotNull(Box::new(col("c1")))),
3465            Expr::IsNotNull(Box::new(col("c1")))
3466        );
3467
3468        // 'c1_non_null IS NOT NULL' is always true
3469        assert_eq!(
3470            simplify(Expr::IsNotNull(Box::new(col("c1_non_null")))),
3471            lit(true)
3472        );
3473    }
3474
3475    #[test]
3476    fn simplify_expr_is_null() {
3477        assert_eq!(
3478            simplify(Expr::IsNull(Box::new(col("c1")))),
3479            Expr::IsNull(Box::new(col("c1")))
3480        );
3481
3482        // 'c1_non_null IS NULL' is always false
3483        assert_eq!(
3484            simplify(Expr::IsNull(Box::new(col("c1_non_null")))),
3485            lit(false)
3486        );
3487    }
3488
3489    #[test]
3490    fn simplify_expr_is_unknown() {
3491        assert_eq!(simplify(col("c2").is_unknown()), col("c2").is_unknown(),);
3492
3493        // 'c2_non_null is unknown' is always false
3494        assert_eq!(simplify(col("c2_non_null").is_unknown()), lit(false));
3495    }
3496
3497    #[test]
3498    fn simplify_expr_is_not_known() {
3499        assert_eq!(
3500            simplify(col("c2").is_not_unknown()),
3501            col("c2").is_not_unknown()
3502        );
3503
3504        // 'c2_non_null is not unknown' is always true
3505        assert_eq!(simplify(col("c2_non_null").is_not_unknown()), lit(true));
3506    }
3507
3508    #[test]
3509    fn simplify_expr_eq() {
3510        let schema = expr_test_schema();
3511        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3512
3513        // true = true -> true
3514        assert_eq!(simplify(lit(true).eq(lit(true))), lit(true));
3515
3516        // true = false -> false
3517        assert_eq!(simplify(lit(true).eq(lit(false))), lit(false),);
3518
3519        // c2 = true -> c2
3520        assert_eq!(simplify(col("c2").eq(lit(true))), col("c2"));
3521
3522        // c2 = false => !c2
3523        assert_eq!(simplify(col("c2").eq(lit(false))), col("c2").not(),);
3524    }
3525
3526    #[test]
3527    fn simplify_expr_eq_skip_nonboolean_type() {
3528        let schema = expr_test_schema();
3529
3530        // When one of the operand is not of boolean type, folding the
3531        // other boolean constant will change return type of
3532        // expression to non-boolean.
3533        //
3534        // Make sure c1 column to be used in tests is not boolean type
3535        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3536
3537        // don't fold c1 = foo
3538        assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),);
3539    }
3540
3541    #[test]
3542    fn simplify_expr_not_eq() {
3543        let schema = expr_test_schema();
3544
3545        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3546
3547        // c2 != true -> !c2
3548        assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),);
3549
3550        // c2 != false -> c2
3551        assert_eq!(simplify(col("c2").not_eq(lit(false))), col("c2"),);
3552
3553        // test constant
3554        assert_eq!(simplify(lit(true).not_eq(lit(true))), lit(false),);
3555
3556        assert_eq!(simplify(lit(true).not_eq(lit(false))), lit(true),);
3557    }
3558
3559    #[test]
3560    fn simplify_expr_not_eq_skip_nonboolean_type() {
3561        let schema = expr_test_schema();
3562
3563        // when one of the operand is not of boolean type, folding the
3564        // other boolean constant will change return type of
3565        // expression to non-boolean.
3566        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3567
3568        assert_eq!(
3569            simplify(col("c1").not_eq(lit("foo"))),
3570            col("c1").not_eq(lit("foo")),
3571        );
3572    }
3573
3574    #[test]
3575    fn simplify_literal_case_equality() {
3576        // CASE WHEN c2 != false THEN "ok" ELSE "not_ok"
3577        let simple_case = Expr::Case(Case::new(
3578            None,
3579            vec![(
3580                Box::new(col("c2_non_null").not_eq(lit(false))),
3581                Box::new(lit("ok")),
3582            )],
3583            Some(Box::new(lit("not_ok"))),
3584        ));
3585
3586        // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" == "ok"
3587        // -->
3588        // CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok"
3589        // -->
3590        // CASE WHEN c2 != false THEN true ELSE false
3591        // -->
3592        // c2
3593        assert_eq!(
3594            simplify(binary_expr(simple_case.clone(), Operator::Eq, lit("ok"),)),
3595            col("c2_non_null"),
3596        );
3597
3598        // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" != "ok"
3599        // -->
3600        // NOT(CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok")
3601        // -->
3602        // NOT(CASE WHEN c2 != false THEN true ELSE false)
3603        // -->
3604        // NOT(c2)
3605        assert_eq!(
3606            simplify(binary_expr(simple_case, Operator::NotEq, lit("ok"),)),
3607            not(col("c2_non_null")),
3608        );
3609
3610        let complex_case = Expr::Case(Case::new(
3611            None,
3612            vec![
3613                (
3614                    Box::new(col("c1").eq(lit("inboxed"))),
3615                    Box::new(lit("pending")),
3616                ),
3617                (
3618                    Box::new(col("c1").eq(lit("scheduled"))),
3619                    Box::new(lit("pending")),
3620                ),
3621                (
3622                    Box::new(col("c1").eq(lit("completed"))),
3623                    Box::new(lit("completed")),
3624                ),
3625                (
3626                    Box::new(col("c1").eq(lit("paused"))),
3627                    Box::new(lit("paused")),
3628                ),
3629                (Box::new(col("c2")), Box::new(lit("running"))),
3630                (
3631                    Box::new(col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0)))),
3632                    Box::new(lit("backing-off")),
3633                ),
3634            ],
3635            Some(Box::new(lit("ready"))),
3636        ));
3637
3638        assert_eq!(
3639            simplify(binary_expr(
3640                complex_case.clone(),
3641                Operator::Eq,
3642                lit("completed"),
3643            )),
3644            not_distinct_from(col("c1").eq(lit("completed")), lit(true)).and(
3645                distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3646                    .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
3647            )
3648        );
3649
3650        assert_eq!(
3651            simplify(binary_expr(
3652                complex_case.clone(),
3653                Operator::NotEq,
3654                lit("completed"),
3655            )),
3656            distinct_from(col("c1").eq(lit("completed")), lit(true))
3657                .or(not_distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3658                    .or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true))))
3659        );
3660
3661        assert_eq!(
3662            simplify(binary_expr(
3663                complex_case.clone(),
3664                Operator::Eq,
3665                lit("running"),
3666            )),
3667            not_distinct_from(col("c2"), lit(true)).and(
3668                distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3669                    .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
3670                    .and(distinct_from(col("c1").eq(lit("completed")), lit(true)))
3671                    .and(distinct_from(col("c1").eq(lit("paused")), lit(true)))
3672            )
3673        );
3674
3675        assert_eq!(
3676            simplify(binary_expr(
3677                complex_case.clone(),
3678                Operator::Eq,
3679                lit("ready"),
3680            )),
3681            distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3682                .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
3683                .and(distinct_from(col("c1").eq(lit("completed")), lit(true)))
3684                .and(distinct_from(col("c1").eq(lit("paused")), lit(true)))
3685                .and(distinct_from(col("c2"), lit(true)))
3686                .and(distinct_from(
3687                    col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))),
3688                    lit(true)
3689                ))
3690        );
3691
3692        assert_eq!(
3693            simplify(binary_expr(
3694                complex_case.clone(),
3695                Operator::NotEq,
3696                lit("ready"),
3697            )),
3698            not_distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3699                .or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
3700                .or(not_distinct_from(col("c1").eq(lit("completed")), lit(true)))
3701                .or(not_distinct_from(col("c1").eq(lit("paused")), lit(true)))
3702                .or(not_distinct_from(col("c2"), lit(true)))
3703                .or(not_distinct_from(
3704                    col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))),
3705                    lit(true)
3706                ))
3707        );
3708    }
3709
3710    #[test]
3711    fn simplify_expr_case_when_then_else() {
3712        // CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true
3713        // -->
3714        // CASE WHEN c2 THEN false ELSE c2
3715        // -->
3716        // false
3717        assert_eq!(
3718            simplify(Expr::Case(Case::new(
3719                None,
3720                vec![(
3721                    Box::new(col("c2_non_null").not_eq(lit(false))),
3722                    Box::new(lit("ok").eq(lit("not_ok"))),
3723                )],
3724                Some(Box::new(col("c2_non_null").eq(lit(true)))),
3725            ))),
3726            lit(false) // #1716
3727        );
3728
3729        // CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2
3730        // -->
3731        // CASE WHEN c2 THEN true ELSE c2
3732        // -->
3733        // c2
3734        //
3735        // Need to call simplify 2x due to
3736        // https://github.com/apache/datafusion/issues/1160
3737        assert_eq!(
3738            simplify(simplify(Expr::Case(Case::new(
3739                None,
3740                vec![(
3741                    Box::new(col("c2_non_null").not_eq(lit(false))),
3742                    Box::new(lit("ok").eq(lit("ok"))),
3743                )],
3744                Some(Box::new(col("c2_non_null").eq(lit(true)))),
3745            )))),
3746            col("c2_non_null")
3747        );
3748
3749        // CASE WHEN ISNULL(c2) THEN true ELSE c2
3750        // -->
3751        // ISNULL(c2) OR c2
3752        //
3753        // Need to call simplify 2x due to
3754        // https://github.com/apache/datafusion/issues/1160
3755        assert_eq!(
3756            simplify(simplify(Expr::Case(Case::new(
3757                None,
3758                vec![(Box::new(col("c2").is_null()), Box::new(lit(true)),)],
3759                Some(Box::new(col("c2"))),
3760            )))),
3761            col("c2")
3762                .is_null()
3763                .or(col("c2").is_not_null().and(col("c2")))
3764        );
3765
3766        // CASE WHEN c1 then true WHEN c2 then false ELSE true
3767        // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE)
3768        // --> c1 OR (NOT(c1) AND NOT(c2))
3769        // --> c1 OR NOT(c2)
3770        //
3771        // Need to call simplify 2x due to
3772        // https://github.com/apache/datafusion/issues/1160
3773        assert_eq!(
3774            simplify(simplify(Expr::Case(Case::new(
3775                None,
3776                vec![
3777                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
3778                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
3779                ],
3780                Some(Box::new(lit(true))),
3781            )))),
3782            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
3783        );
3784
3785        // CASE WHEN c1 then true WHEN c2 then true ELSE false
3786        // --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE)
3787        // --> c1 OR (NOT(c1) AND c2)
3788        // --> c1 OR c2
3789        //
3790        // Need to call simplify 2x due to
3791        // https://github.com/apache/datafusion/issues/1160
3792        assert_eq!(
3793            simplify(simplify(Expr::Case(Case::new(
3794                None,
3795                vec![
3796                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
3797                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
3798                ],
3799                Some(Box::new(lit(true))),
3800            )))),
3801            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
3802        );
3803
3804        // CASE WHEN c > 0 THEN true END AS c1
3805        assert_eq!(
3806            simplify(simplify(Expr::Case(Case::new(
3807                None,
3808                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3809                None,
3810            )))),
3811            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
3812                col("c3").gt(lit(0_i64)),
3813                lit(true)
3814            )
3815            .and(lit_bool_null()))
3816        );
3817
3818        // CASE WHEN c > 0 THEN true ELSE false END AS c1
3819        assert_eq!(
3820            simplify(simplify(Expr::Case(Case::new(
3821                None,
3822                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3823                Some(Box::new(lit(false))),
3824            )))),
3825            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
3826        );
3827    }
3828
3829    #[test]
3830    fn simplify_expr_case_when_first_true() {
3831        // CASE WHEN true THEN 1 ELSE c1 END --> 1
3832        assert_eq!(
3833            simplify(Expr::Case(Case::new(
3834                None,
3835                vec![(Box::new(lit(true)), Box::new(lit(1)),)],
3836                Some(Box::new(col("c1"))),
3837            ))),
3838            lit(1)
3839        );
3840
3841        // CASE WHEN true THEN col('a') ELSE col('b') END --> col('a')
3842        assert_eq!(
3843            simplify(Expr::Case(Case::new(
3844                None,
3845                vec![(Box::new(lit(true)), Box::new(lit("a")),)],
3846                Some(Box::new(lit("b"))),
3847            ))),
3848            lit("a")
3849        );
3850
3851        // CASE WHEN true THEN col('a') WHEN col('x') > 5 THEN col('b') ELSE col('c') END --> col('a')
3852        assert_eq!(
3853            simplify(Expr::Case(Case::new(
3854                None,
3855                vec![
3856                    (Box::new(lit(true)), Box::new(lit("a"))),
3857                    (Box::new(lit("x").gt(lit(5))), Box::new(lit("b"))),
3858                ],
3859                Some(Box::new(lit("c"))),
3860            ))),
3861            lit("a")
3862        );
3863
3864        // CASE WHEN true THEN col('a') END --> col('a') (no else clause)
3865        assert_eq!(
3866            simplify(Expr::Case(Case::new(
3867                None,
3868                vec![(Box::new(lit(true)), Box::new(lit("a")),)],
3869                None,
3870            ))),
3871            lit("a")
3872        );
3873
3874        // Negative test: CASE WHEN c2 THEN 1 ELSE 2 END should not be simplified
3875        let expr = Expr::Case(Case::new(
3876            None,
3877            vec![(Box::new(col("c2")), Box::new(lit(1)))],
3878            Some(Box::new(lit(2))),
3879        ));
3880        assert_eq!(simplify(expr.clone()), expr);
3881
3882        // Negative test: CASE WHEN false THEN 1 ELSE 2 END should not use this rule
3883        let expr = Expr::Case(Case::new(
3884            None,
3885            vec![(Box::new(lit(false)), Box::new(lit(1)))],
3886            Some(Box::new(lit(2))),
3887        ));
3888        assert_ne!(simplify(expr), lit(1));
3889
3890        // Negative test: CASE WHEN col('c1') > 5 THEN 1 ELSE 2 END should not be simplified
3891        let expr = Expr::Case(Case::new(
3892            None,
3893            vec![(Box::new(col("c1").gt(lit(5))), Box::new(lit(1)))],
3894            Some(Box::new(lit(2))),
3895        ));
3896        assert_eq!(simplify(expr.clone()), expr);
3897    }
3898
3899    #[test]
3900    fn simplify_expr_case_when_any_true() {
3901        // CASE WHEN c3 > 0 THEN 'a' WHEN true THEN 'b' ELSE 'c' END --> CASE WHEN c3 > 0 THEN 'a' ELSE 'b' END
3902        assert_eq!(
3903            simplify(Expr::Case(Case::new(
3904                None,
3905                vec![
3906                    (Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))),
3907                    (Box::new(lit(true)), Box::new(lit("b"))),
3908                ],
3909                Some(Box::new(lit("c"))),
3910            ))),
3911            Expr::Case(Case::new(
3912                None,
3913                vec![(Box::new(col("c3").gt(lit(0))), Box::new(lit("a")))],
3914                Some(Box::new(lit("b"))),
3915            ))
3916        );
3917
3918        // CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' WHEN true THEN 'c' WHEN c3 = 0 THEN 'd' ELSE 'e' END
3919        // --> CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' ELSE 'c' END
3920        assert_eq!(
3921            simplify(Expr::Case(Case::new(
3922                None,
3923                vec![
3924                    (Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))),
3925                    (Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))),
3926                    (Box::new(lit(true)), Box::new(lit("c"))),
3927                    (Box::new(col("c3").eq(lit(0))), Box::new(lit("d"))),
3928                ],
3929                Some(Box::new(lit("e"))),
3930            ))),
3931            Expr::Case(Case::new(
3932                None,
3933                vec![
3934                    (Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))),
3935                    (Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))),
3936                ],
3937                Some(Box::new(lit("c"))),
3938            ))
3939        );
3940
3941        // CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 WHEN true THEN 3 END (no else)
3942        // --> CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 ELSE 3 END
3943        assert_eq!(
3944            simplify(Expr::Case(Case::new(
3945                None,
3946                vec![
3947                    (Box::new(col("c3").gt(lit(0))), Box::new(lit(1))),
3948                    (Box::new(col("c4").lt(lit(0))), Box::new(lit(2))),
3949                    (Box::new(lit(true)), Box::new(lit(3))),
3950                ],
3951                None,
3952            ))),
3953            Expr::Case(Case::new(
3954                None,
3955                vec![
3956                    (Box::new(col("c3").gt(lit(0))), Box::new(lit(1))),
3957                    (Box::new(col("c4").lt(lit(0))), Box::new(lit(2))),
3958                ],
3959                Some(Box::new(lit(3))),
3960            ))
3961        );
3962
3963        // Negative test: CASE WHEN c3 > 0 THEN c3 WHEN c4 < 0 THEN 2 ELSE 3 END should not be simplified
3964        let expr = Expr::Case(Case::new(
3965            None,
3966            vec![
3967                (Box::new(col("c3").gt(lit(0))), Box::new(col("c3"))),
3968                (Box::new(col("c4").lt(lit(0))), Box::new(lit(2))),
3969            ],
3970            Some(Box::new(lit(3))),
3971        ));
3972        assert_eq!(simplify(expr.clone()), expr);
3973    }
3974
3975    #[test]
3976    fn simplify_expr_case_when_any_false() {
3977        // CASE WHEN false THEN 'a' END --> NULL
3978        assert_eq!(
3979            simplify(Expr::Case(Case::new(
3980                None,
3981                vec![(Box::new(lit(false)), Box::new(lit("a")))],
3982                None,
3983            ))),
3984            Expr::Literal(ScalarValue::Utf8(None), None)
3985        );
3986
3987        // CASE WHEN false THEN 2 ELSE 1 END --> 1
3988        assert_eq!(
3989            simplify(Expr::Case(Case::new(
3990                None,
3991                vec![(Box::new(lit(false)), Box::new(lit(2)))],
3992                Some(Box::new(lit(1))),
3993            ))),
3994            lit(1),
3995        );
3996
3997        // CASE WHEN c3 < 10 THEN 'b' WHEN false then c3 ELSE c4 END --> CASE WHEN c3 < 10 THEN b ELSE c4 END
3998        assert_eq!(
3999            simplify(Expr::Case(Case::new(
4000                None,
4001                vec![
4002                    (Box::new(col("c3").lt(lit(10))), Box::new(lit("b"))),
4003                    (Box::new(lit(false)), Box::new(col("c3"))),
4004                ],
4005                Some(Box::new(col("c4"))),
4006            ))),
4007            Expr::Case(Case::new(
4008                None,
4009                vec![(Box::new(col("c3").lt(lit(10))), Box::new(lit("b")))],
4010                Some(Box::new(col("c4"))),
4011            ))
4012        );
4013
4014        // Negative test: CASE WHEN c3 = 4 THEN 1 ELSE 2 END should not be simplified
4015        let expr = Expr::Case(Case::new(
4016            None,
4017            vec![(Box::new(col("c3").eq(lit(4))), Box::new(lit(1)))],
4018            Some(Box::new(lit(2))),
4019        ));
4020        assert_eq!(simplify(expr.clone()), expr);
4021    }
4022
4023    fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
4024        Expr::BinaryExpr(BinaryExpr {
4025            left: Box::new(left.into()),
4026            op: Operator::IsDistinctFrom,
4027            right: Box::new(right.into()),
4028        })
4029    }
4030
4031    fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
4032        Expr::BinaryExpr(BinaryExpr {
4033            left: Box::new(left.into()),
4034            op: Operator::IsNotDistinctFrom,
4035            right: Box::new(right.into()),
4036        })
4037    }
4038
4039    #[test]
4040    fn simplify_expr_bool_or() {
4041        // col || true is always true
4042        assert_eq!(simplify(col("c2").or(lit(true))), lit(true),);
4043
4044        // col || false is always col
4045        assert_eq!(simplify(col("c2").or(lit(false))), col("c2"),);
4046
4047        // true || null is always true
4048        assert_eq!(simplify(lit(true).or(lit_bool_null())), lit(true),);
4049
4050        // null || true is always true
4051        assert_eq!(simplify(lit_bool_null().or(lit(true))), lit(true),);
4052
4053        // false || null is always null
4054        assert_eq!(simplify(lit(false).or(lit_bool_null())), lit_bool_null(),);
4055
4056        // null || false is always null
4057        assert_eq!(simplify(lit_bool_null().or(lit(false))), lit_bool_null(),);
4058
4059        // ( c1 BETWEEN Int32(0) AND Int32(10) ) OR Boolean(NULL)
4060        // it can be either NULL or  TRUE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
4061        // and should not be rewritten
4062        let expr = col("c1").between(lit(0), lit(10));
4063        let expr = expr.or(lit_bool_null());
4064        let result = simplify(expr);
4065
4066        let expected_expr = or(
4067            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
4068            lit_bool_null(),
4069        );
4070        assert_eq!(expected_expr, result);
4071    }
4072
4073    #[test]
4074    fn simplify_inlist() {
4075        assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false));
4076        assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true));
4077
4078        // null in (...)  --> null
4079        assert_eq!(
4080            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], false)),
4081            lit_bool_null()
4082        );
4083
4084        // null not in (...)  --> null
4085        assert_eq!(
4086            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], true)),
4087            lit_bool_null()
4088        );
4089
4090        assert_eq!(
4091            simplify(in_list(col("c1"), vec![lit(1)], false)),
4092            col("c1").eq(lit(1))
4093        );
4094        assert_eq!(
4095            simplify(in_list(col("c1"), vec![lit(1)], true)),
4096            col("c1").not_eq(lit(1))
4097        );
4098
4099        // more complex expressions can be simplified if list contains
4100        // one element only
4101        assert_eq!(
4102            simplify(in_list(col("c1") * lit(10), vec![lit(2)], false)),
4103            (col("c1") * lit(10)).eq(lit(2))
4104        );
4105
4106        assert_eq!(
4107            simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)),
4108            col("c1").eq(lit(1)).or(col("c1").eq(lit(2)))
4109        );
4110        assert_eq!(
4111            simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)),
4112            col("c1").not_eq(lit(1)).and(col("c1").not_eq(lit(2)))
4113        );
4114
4115        let subquery = Arc::new(test_table_scan_with_name("test").unwrap());
4116        assert_eq!(
4117            simplify(in_list(
4118                col("c1"),
4119                vec![scalar_subquery(Arc::clone(&subquery))],
4120                false
4121            )),
4122            in_subquery(col("c1"), Arc::clone(&subquery))
4123        );
4124        assert_eq!(
4125            simplify(in_list(
4126                col("c1"),
4127                vec![scalar_subquery(Arc::clone(&subquery))],
4128                true
4129            )),
4130            not_in_subquery(col("c1"), subquery)
4131        );
4132
4133        let subquery1 =
4134            scalar_subquery(Arc::new(test_table_scan_with_name("test1").unwrap()));
4135        let subquery2 =
4136            scalar_subquery(Arc::new(test_table_scan_with_name("test2").unwrap()));
4137
4138        // c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery1> AND c1 != <subquery2>
4139        assert_eq!(
4140            simplify(in_list(
4141                col("c1"),
4142                vec![subquery1.clone(), subquery2.clone()],
4143                true
4144            )),
4145            col("c1")
4146                .not_eq(subquery1.clone())
4147                .and(col("c1").not_eq(subquery2.clone()))
4148        );
4149
4150        // c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery1> OR c1 == <subquery2>
4151        assert_eq!(
4152            simplify(in_list(
4153                col("c1"),
4154                vec![subquery1.clone(), subquery2.clone()],
4155                false
4156            )),
4157            col("c1").eq(subquery1).or(col("c1").eq(subquery2))
4158        );
4159
4160        // 1. c1 IN (1,2,3,4) AND c1 IN (5,6,7,8) -> false
4161        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
4162            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false),
4163        );
4164        assert_eq!(simplify(expr), lit(false));
4165
4166        // 2. c1 IN (1,2,3,4) AND c1 IN (4,5,6,7) -> c1 = 4
4167        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
4168            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], false),
4169        );
4170        assert_eq!(simplify(expr), col("c1").eq(lit(4)));
4171
4172        // 3. c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) -> true
4173        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
4174            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
4175        );
4176        assert_eq!(simplify(expr), lit(true));
4177
4178        // 3.5 c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (4, 5, 6, 7) -> c1 != 4 (4 overlaps)
4179        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
4180            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
4181        );
4182        assert_eq!(simplify(expr), col("c1").not_eq(lit(4)));
4183
4184        // 4. c1 NOT IN (1,2,3,4) AND c1 NOT IN (4,5,6,7) -> c1 NOT IN (1,2,3,4,5,6,7)
4185        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
4186            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
4187        );
4188        assert_eq!(
4189            simplify(expr),
4190            in_list(
4191                col("c1"),
4192                vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6), lit(7)],
4193                true
4194            )
4195        );
4196
4197        // 5. c1 IN (1,2,3,4) OR c1 IN (2,3,4,5) -> c1 IN (1,2,3,4,5)
4198        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).or(
4199            in_list(col("c1"), vec![lit(2), lit(3), lit(4), lit(5)], false),
4200        );
4201        assert_eq!(
4202            simplify(expr),
4203            in_list(
4204                col("c1"),
4205                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
4206                false
4207            )
4208        );
4209
4210        // 6. c1 IN (1,2,3) AND c1 NOT INT (1,2,3,4,5) -> false
4211        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3)], false).and(in_list(
4212            col("c1"),
4213            vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
4214            true,
4215        ));
4216        assert_eq!(simplify(expr), lit(false));
4217
4218        // 7. c1 NOT IN (1,2,3,4) AND c1 IN (1,2,3,4,5) -> c1 = 5
4219        let expr =
4220            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(in_list(
4221                col("c1"),
4222                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
4223                false,
4224            ));
4225        assert_eq!(simplify(expr), col("c1").eq(lit(5)));
4226
4227        // 8. c1 IN (1,2,3,4) AND c1 NOT IN (5,6,7,8) -> c1 IN (1,2,3,4)
4228        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
4229            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
4230        );
4231        assert_eq!(
4232            simplify(expr),
4233            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false)
4234        );
4235
4236        // inlist with more than two expressions
4237        // c1 IN (1,2,3,4,5,6) AND c1 IN (1,3,5,6) AND c1 IN (3,6) -> c1 = 3 OR c1 = 6
4238        let expr = in_list(
4239            col("c1"),
4240            vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6)],
4241            false,
4242        )
4243        .and(in_list(
4244            col("c1"),
4245            vec![lit(1), lit(3), lit(5), lit(6)],
4246            false,
4247        ))
4248        .and(in_list(col("c1"), vec![lit(3), lit(6)], false));
4249        assert_eq!(
4250            simplify(expr),
4251            col("c1").eq(lit(3)).or(col("c1").eq(lit(6)))
4252        );
4253
4254        // c1 NOT IN (1,2,3,4) AND c1 IN (5,6,7,8) AND c1 NOT IN (3,4,5,6) AND c1 IN (8,9,10) -> c1 = 8
4255        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
4256            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false)
4257                .and(in_list(
4258                    col("c1"),
4259                    vec![lit(3), lit(4), lit(5), lit(6)],
4260                    true,
4261                ))
4262                .and(in_list(col("c1"), vec![lit(8), lit(9), lit(10)], false)),
4263        );
4264        assert_eq!(simplify(expr), col("c1").eq(lit(8)));
4265
4266        // Contains non-InList expression
4267        // c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9) -> c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9)
4268        let expr =
4269            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(col("c1")
4270                .not_eq(lit(5))
4271                .or(in_list(
4272                    col("c1"),
4273                    vec![lit(6), lit(7), lit(8), lit(9)],
4274                    true,
4275                )));
4276        // TODO: Further simplify this expression
4277        // https://github.com/apache/datafusion/issues/8970
4278        // assert_eq!(simplify(expr.clone()), lit(true));
4279        assert_eq!(simplify(expr.clone()), expr);
4280    }
4281
4282    #[test]
4283    fn simplify_null_in_empty_inlist() {
4284        // `NULL::boolean IN ()` == `NULL::boolean IN (SELECT foo FROM empty)` == false
4285        let expr = in_list(lit_bool_null(), vec![], false);
4286        assert_eq!(simplify(expr), lit(false));
4287
4288        // `NULL::boolean NOT IN ()` == `NULL::boolean NOT IN (SELECT foo FROM empty)` == true
4289        let expr = in_list(lit_bool_null(), vec![], true);
4290        assert_eq!(simplify(expr), lit(true));
4291
4292        // `NULL IN ()` == `NULL IN (SELECT foo FROM empty)` == false
4293        let null_null = || Expr::Literal(ScalarValue::Null, None);
4294        let expr = in_list(null_null(), vec![], false);
4295        assert_eq!(simplify(expr), lit(false));
4296
4297        // `NULL NOT IN ()` == `NULL NOT IN (SELECT foo FROM empty)` == true
4298        let expr = in_list(null_null(), vec![], true);
4299        assert_eq!(simplify(expr), lit(true));
4300    }
4301
4302    #[test]
4303    fn just_simplifier_simplify_null_in_empty_inlist() {
4304        let simplify = |expr: Expr| -> Expr {
4305            let schema = expr_test_schema();
4306            let execution_props = ExecutionProps::new();
4307            let info = SimplifyContext::new(&execution_props).with_schema(schema);
4308            let simplifier = &mut Simplifier::new(&info);
4309            expr.rewrite(simplifier)
4310                .expect("Failed to simplify expression")
4311                .data
4312        };
4313
4314        // `NULL::boolean IN ()` == `NULL::boolean IN (SELECT foo FROM empty)` == false
4315        let expr = in_list(lit_bool_null(), vec![], false);
4316        assert_eq!(simplify(expr), lit(false));
4317
4318        // `NULL::boolean NOT IN ()` == `NULL::boolean NOT IN (SELECT foo FROM empty)` == true
4319        let expr = in_list(lit_bool_null(), vec![], true);
4320        assert_eq!(simplify(expr), lit(true));
4321
4322        // `NULL IN ()` == `NULL IN (SELECT foo FROM empty)` == false
4323        let null_null = || Expr::Literal(ScalarValue::Null, None);
4324        let expr = in_list(null_null(), vec![], false);
4325        assert_eq!(simplify(expr), lit(false));
4326
4327        // `NULL NOT IN ()` == `NULL NOT IN (SELECT foo FROM empty)` == true
4328        let expr = in_list(null_null(), vec![], true);
4329        assert_eq!(simplify(expr), lit(true));
4330    }
4331
4332    #[test]
4333    fn simplify_large_or() {
4334        let expr = (0..5)
4335            .map(|i| col("c1").eq(lit(i)))
4336            .fold(lit(false), |acc, e| acc.or(e));
4337        assert_eq!(
4338            simplify(expr),
4339            in_list(col("c1"), (0..5).map(lit).collect(), false),
4340        );
4341    }
4342
4343    #[test]
4344    fn simplify_expr_bool_and() {
4345        // col & true is always col
4346        assert_eq!(simplify(col("c2").and(lit(true))), col("c2"),);
4347        // col & false is always false
4348        assert_eq!(simplify(col("c2").and(lit(false))), lit(false),);
4349
4350        // true && null is always null
4351        assert_eq!(simplify(lit(true).and(lit_bool_null())), lit_bool_null(),);
4352
4353        // null && true is always null
4354        assert_eq!(simplify(lit_bool_null().and(lit(true))), lit_bool_null(),);
4355
4356        // false && null is always false
4357        assert_eq!(simplify(lit(false).and(lit_bool_null())), lit(false),);
4358
4359        // null && false is always false
4360        assert_eq!(simplify(lit_bool_null().and(lit(false))), lit(false),);
4361
4362        // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
4363        // it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
4364        // and the Boolean(NULL) should remain
4365        let expr = col("c1").between(lit(0), lit(10));
4366        let expr = expr.and(lit_bool_null());
4367        let result = simplify(expr);
4368
4369        let expected_expr = and(
4370            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
4371            lit_bool_null(),
4372        );
4373        assert_eq!(expected_expr, result);
4374    }
4375
4376    #[test]
4377    fn simplify_expr_between() {
4378        // c2 between 3 and 4 is c2 >= 3 and c2 <= 4
4379        let expr = col("c2").between(lit(3), lit(4));
4380        assert_eq!(
4381            simplify(expr),
4382            and(col("c2").gt_eq(lit(3)), col("c2").lt_eq(lit(4)))
4383        );
4384
4385        // c2 not between 3 and 4 is c2 < 3 or c2 > 4
4386        let expr = col("c2").not_between(lit(3), lit(4));
4387        assert_eq!(
4388            simplify(expr),
4389            or(col("c2").lt(lit(3)), col("c2").gt(lit(4)))
4390        );
4391    }
4392
4393    #[test]
4394    fn test_like_and_ilike() {
4395        let null = lit(ScalarValue::Utf8(None));
4396
4397        // expr [NOT] [I]LIKE NULL
4398        let expr = col("c1").like(null.clone());
4399        assert_eq!(simplify(expr), lit_bool_null());
4400
4401        let expr = col("c1").not_like(null.clone());
4402        assert_eq!(simplify(expr), lit_bool_null());
4403
4404        let expr = col("c1").ilike(null.clone());
4405        assert_eq!(simplify(expr), lit_bool_null());
4406
4407        let expr = col("c1").not_ilike(null.clone());
4408        assert_eq!(simplify(expr), lit_bool_null());
4409
4410        // expr [NOT] [I]LIKE '%'
4411        let expr = col("c1").like(lit("%"));
4412        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4413
4414        let expr = col("c1").not_like(lit("%"));
4415        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4416
4417        let expr = col("c1").ilike(lit("%"));
4418        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4419
4420        let expr = col("c1").not_ilike(lit("%"));
4421        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4422
4423        // expr [NOT] [I]LIKE '%%'
4424        let expr = col("c1").like(lit("%%"));
4425        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4426
4427        let expr = col("c1").not_like(lit("%%"));
4428        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4429
4430        let expr = col("c1").ilike(lit("%%"));
4431        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4432
4433        let expr = col("c1").not_ilike(lit("%%"));
4434        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4435
4436        // not_null_expr [NOT] [I]LIKE '%'
4437        let expr = col("c1_non_null").like(lit("%"));
4438        assert_eq!(simplify(expr), lit(true));
4439
4440        let expr = col("c1_non_null").not_like(lit("%"));
4441        assert_eq!(simplify(expr), lit(false));
4442
4443        let expr = col("c1_non_null").ilike(lit("%"));
4444        assert_eq!(simplify(expr), lit(true));
4445
4446        let expr = col("c1_non_null").not_ilike(lit("%"));
4447        assert_eq!(simplify(expr), lit(false));
4448
4449        // not_null_expr [NOT] [I]LIKE '%%'
4450        let expr = col("c1_non_null").like(lit("%%"));
4451        assert_eq!(simplify(expr), lit(true));
4452
4453        let expr = col("c1_non_null").not_like(lit("%%"));
4454        assert_eq!(simplify(expr), lit(false));
4455
4456        let expr = col("c1_non_null").ilike(lit("%%"));
4457        assert_eq!(simplify(expr), lit(true));
4458
4459        let expr = col("c1_non_null").not_ilike(lit("%%"));
4460        assert_eq!(simplify(expr), lit(false));
4461
4462        // null_constant [NOT] [I]LIKE '%'
4463        let expr = null.clone().like(lit("%"));
4464        assert_eq!(simplify(expr), lit_bool_null());
4465
4466        let expr = null.clone().not_like(lit("%"));
4467        assert_eq!(simplify(expr), lit_bool_null());
4468
4469        let expr = null.clone().ilike(lit("%"));
4470        assert_eq!(simplify(expr), lit_bool_null());
4471
4472        let expr = null.clone().not_ilike(lit("%"));
4473        assert_eq!(simplify(expr), lit_bool_null());
4474
4475        // null_constant [NOT] [I]LIKE '%%'
4476        let expr = null.clone().like(lit("%%"));
4477        assert_eq!(simplify(expr), lit_bool_null());
4478
4479        let expr = null.clone().not_like(lit("%%"));
4480        assert_eq!(simplify(expr), lit_bool_null());
4481
4482        let expr = null.clone().ilike(lit("%%"));
4483        assert_eq!(simplify(expr), lit_bool_null());
4484
4485        let expr = null.clone().not_ilike(lit("%%"));
4486        assert_eq!(simplify(expr), lit_bool_null());
4487
4488        // null_constant [NOT] [I]LIKE 'a%'
4489        let expr = null.clone().like(lit("a%"));
4490        assert_eq!(simplify(expr), lit_bool_null());
4491
4492        let expr = null.clone().not_like(lit("a%"));
4493        assert_eq!(simplify(expr), lit_bool_null());
4494
4495        let expr = null.clone().ilike(lit("a%"));
4496        assert_eq!(simplify(expr), lit_bool_null());
4497
4498        let expr = null.clone().not_ilike(lit("a%"));
4499        assert_eq!(simplify(expr), lit_bool_null());
4500
4501        // expr [NOT] [I]LIKE with pattern without wildcards
4502        let expr = col("c1").like(lit("a"));
4503        assert_eq!(simplify(expr), col("c1").eq(lit("a")));
4504        let expr = col("c1").not_like(lit("a"));
4505        assert_eq!(simplify(expr), col("c1").not_eq(lit("a")));
4506        let expr = col("c1").like(lit("a_"));
4507        assert_eq!(simplify(expr), col("c1").like(lit("a_")));
4508        let expr = col("c1").not_like(lit("a_"));
4509        assert_eq!(simplify(expr), col("c1").not_like(lit("a_")));
4510
4511        let expr = col("c1").ilike(lit("a"));
4512        assert_eq!(simplify(expr), col("c1").ilike(lit("a")));
4513        let expr = col("c1").not_ilike(lit("a"));
4514        assert_eq!(simplify(expr), col("c1").not_ilike(lit("a")));
4515    }
4516
4517    #[test]
4518    fn test_simplify_with_guarantee() {
4519        // (c3 >= 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b")))
4520        let expr_x = col("c3").gt(lit(3_i64));
4521        let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32));
4522        let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true);
4523        let expr = expr_x.clone().and(expr_y.or(expr_z));
4524
4525        // All guaranteed null
4526        let guarantees = vec![
4527            (col("c3"), NullableInterval::from(ScalarValue::Int64(None))),
4528            (col("c4"), NullableInterval::from(ScalarValue::UInt32(None))),
4529            (col("c1"), NullableInterval::from(ScalarValue::Utf8(None))),
4530        ];
4531
4532        let output = simplify_with_guarantee(expr.clone(), guarantees);
4533        assert_eq!(output, lit_bool_null());
4534
4535        // All guaranteed false
4536        let guarantees = vec![
4537            (
4538                col("c3"),
4539                NullableInterval::NotNull {
4540                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
4541                },
4542            ),
4543            (
4544                col("c4"),
4545                NullableInterval::from(ScalarValue::UInt32(Some(9))),
4546            ),
4547            (col("c1"), NullableInterval::from(ScalarValue::from("a"))),
4548        ];
4549        let output = simplify_with_guarantee(expr.clone(), guarantees);
4550        assert_eq!(output, lit(false));
4551
4552        // Guaranteed false or null -> no change.
4553        let guarantees = vec![
4554            (
4555                col("c3"),
4556                NullableInterval::MaybeNull {
4557                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
4558                },
4559            ),
4560            (
4561                col("c4"),
4562                NullableInterval::MaybeNull {
4563                    values: Interval::make(Some(9_u32), Some(9_u32)).unwrap(),
4564                },
4565            ),
4566            (
4567                col("c1"),
4568                NullableInterval::NotNull {
4569                    values: Interval::try_new(
4570                        ScalarValue::from("d"),
4571                        ScalarValue::from("f"),
4572                    )
4573                    .unwrap(),
4574                },
4575            ),
4576        ];
4577        let output = simplify_with_guarantee(expr.clone(), guarantees);
4578        assert_eq!(&output, &expr_x);
4579
4580        // Sufficient true guarantees
4581        let guarantees = vec![
4582            (
4583                col("c3"),
4584                NullableInterval::from(ScalarValue::Int64(Some(9))),
4585            ),
4586            (
4587                col("c4"),
4588                NullableInterval::from(ScalarValue::UInt32(Some(3))),
4589            ),
4590        ];
4591        let output = simplify_with_guarantee(expr.clone(), guarantees);
4592        assert_eq!(output, lit(true));
4593
4594        // Only partially simplify
4595        let guarantees = vec![(
4596            col("c4"),
4597            NullableInterval::from(ScalarValue::UInt32(Some(3))),
4598        )];
4599        let output = simplify_with_guarantee(expr, guarantees);
4600        assert_eq!(&output, &expr_x);
4601    }
4602
4603    #[test]
4604    fn test_expression_partial_simplify_1() {
4605        // (1 + 2) + (4 / 0) -> 3 + (4 / 0)
4606        let expr = (lit(1) + lit(2)) + (lit(4) / lit(0));
4607        let expected = (lit(3)) + (lit(4) / lit(0));
4608
4609        assert_eq!(simplify(expr), expected);
4610    }
4611
4612    #[test]
4613    fn test_expression_partial_simplify_2() {
4614        // (1 > 2) and (4 / 0) -> false
4615        let expr = (lit(1).gt(lit(2))).and(lit(4) / lit(0));
4616        let expected = lit(false);
4617
4618        assert_eq!(simplify(expr), expected);
4619    }
4620
4621    #[test]
4622    fn test_simplify_cycles() {
4623        // TRUE
4624        let expr = lit(true);
4625        let expected = lit(true);
4626        let (expr, num_iter) = simplify_with_cycle_count(expr);
4627        assert_eq!(expr, expected);
4628        assert_eq!(num_iter, 1);
4629
4630        // (true != NULL) OR (5 > 10)
4631        let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
4632        let expected = lit_bool_null();
4633        let (expr, num_iter) = simplify_with_cycle_count(expr);
4634        assert_eq!(expr, expected);
4635        assert_eq!(num_iter, 2);
4636
4637        // NOTE: this currently does not simplify
4638        // (((c4 - 10) + 10) *100) / 100
4639        let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
4640        let expected = expr.clone();
4641        let (expr, num_iter) = simplify_with_cycle_count(expr);
4642        assert_eq!(expr, expected);
4643        assert_eq!(num_iter, 1);
4644
4645        // ((c4<1 or c3<2) and c3_non_null<3) and false
4646        let expr = col("c4")
4647            .lt(lit(1))
4648            .or(col("c3").lt(lit(2)))
4649            .and(col("c3_non_null").lt(lit(3)))
4650            .and(lit(false));
4651        let expected = lit(false);
4652        let (expr, num_iter) = simplify_with_cycle_count(expr);
4653        assert_eq!(expr, expected);
4654        assert_eq!(num_iter, 2);
4655    }
4656
4657    fn boolean_test_schema() -> DFSchemaRef {
4658        static BOOLEAN_TEST_SCHEMA: LazyLock<DFSchemaRef> = LazyLock::new(|| {
4659            Schema::new(vec![
4660                Field::new("A", DataType::Boolean, false),
4661                Field::new("B", DataType::Boolean, false),
4662                Field::new("C", DataType::Boolean, false),
4663                Field::new("D", DataType::Boolean, false),
4664            ])
4665            .to_dfschema_ref()
4666            .unwrap()
4667        });
4668        Arc::clone(&BOOLEAN_TEST_SCHEMA)
4669    }
4670
4671    #[test]
4672    fn simplify_common_factor_conjunction_in_disjunction() {
4673        let props = ExecutionProps::new();
4674        let schema = boolean_test_schema();
4675        let simplifier =
4676            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));
4677
4678        let a = || col("A");
4679        let b = || col("B");
4680        let c = || col("C");
4681        let d = || col("D");
4682
4683        // (A AND B) OR (A AND C) -> A AND (B OR C)
4684        let expr = a().and(b()).or(a().and(c()));
4685        let expected = a().and(b().or(c()));
4686
4687        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4688
4689        // (A AND B) OR (A AND C) OR (A AND D) -> A AND (B OR C OR D)
4690        let expr = a().and(b()).or(a().and(c())).or(a().and(d()));
4691        let expected = a().and(b().or(c()).or(d()));
4692        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4693
4694        // A OR (B AND C AND A) -> A
4695        let expr = a().or(b().and(c().and(a())));
4696        let expected = a();
4697        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4698    }
4699
4700    #[test]
4701    fn test_simplify_udaf() {
4702        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());
4703        let aggregate_function_expr =
4704            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4705                udaf.into(),
4706                vec![],
4707                false,
4708                None,
4709                vec![],
4710                None,
4711            ));
4712
4713        let expected = col("result_column");
4714        assert_eq!(simplify(aggregate_function_expr), expected);
4715
4716        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify());
4717        let aggregate_function_expr =
4718            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4719                udaf.into(),
4720                vec![],
4721                false,
4722                None,
4723                vec![],
4724                None,
4725            ));
4726
4727        let expected = aggregate_function_expr.clone();
4728        assert_eq!(simplify(aggregate_function_expr), expected);
4729    }
4730
4731    /// A Mock UDAF which defines `simplify` to be used in tests
4732    /// related to UDAF simplification
4733    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
4734    struct SimplifyMockUdaf {
4735        simplify: bool,
4736    }
4737
4738    impl SimplifyMockUdaf {
4739        /// make simplify method return new expression
4740        fn new_with_simplify() -> Self {
4741            Self { simplify: true }
4742        }
4743        /// make simplify method return no change
4744        fn new_without_simplify() -> Self {
4745            Self { simplify: false }
4746        }
4747    }
4748
4749    impl AggregateUDFImpl for SimplifyMockUdaf {
4750        fn as_any(&self) -> &dyn std::any::Any {
4751            self
4752        }
4753
4754        fn name(&self) -> &str {
4755            "mock_simplify"
4756        }
4757
4758        fn signature(&self) -> &Signature {
4759            unimplemented!()
4760        }
4761
4762        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4763            unimplemented!("not needed for tests")
4764        }
4765
4766        fn accumulator(
4767            &self,
4768            _acc_args: AccumulatorArgs,
4769        ) -> Result<Box<dyn Accumulator>> {
4770            unimplemented!("not needed for tests")
4771        }
4772
4773        fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
4774            unimplemented!("not needed for testing")
4775        }
4776
4777        fn create_groups_accumulator(
4778            &self,
4779            _args: AccumulatorArgs,
4780        ) -> Result<Box<dyn GroupsAccumulator>> {
4781            unimplemented!("not needed for testing")
4782        }
4783
4784        fn simplify(&self) -> Option<AggregateFunctionSimplification> {
4785            if self.simplify {
4786                Some(Box::new(|_, _| Ok(col("result_column"))))
4787            } else {
4788                None
4789            }
4790        }
4791    }
4792
4793    #[test]
4794    fn test_simplify_udwf() {
4795        let udwf = WindowFunctionDefinition::WindowUDF(
4796            WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(),
4797        );
4798        let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![]));
4799
4800        let expected = col("result_column");
4801        assert_eq!(simplify(window_function_expr), expected);
4802
4803        let udwf = WindowFunctionDefinition::WindowUDF(
4804            WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(),
4805        );
4806        let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![]));
4807
4808        let expected = window_function_expr.clone();
4809        assert_eq!(simplify(window_function_expr), expected);
4810    }
4811
4812    /// A Mock UDWF which defines `simplify` to be used in tests
4813    /// related to UDWF simplification
4814    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
4815    struct SimplifyMockUdwf {
4816        simplify: bool,
4817    }
4818
4819    impl SimplifyMockUdwf {
4820        /// make simplify method return new expression
4821        fn new_with_simplify() -> Self {
4822            Self { simplify: true }
4823        }
4824        /// make simplify method return no change
4825        fn new_without_simplify() -> Self {
4826            Self { simplify: false }
4827        }
4828    }
4829
4830    impl WindowUDFImpl for SimplifyMockUdwf {
4831        fn as_any(&self) -> &dyn std::any::Any {
4832            self
4833        }
4834
4835        fn name(&self) -> &str {
4836            "mock_simplify"
4837        }
4838
4839        fn signature(&self) -> &Signature {
4840            unimplemented!()
4841        }
4842
4843        fn simplify(&self) -> Option<WindowFunctionSimplification> {
4844            if self.simplify {
4845                Some(Box::new(|_, _| Ok(col("result_column"))))
4846            } else {
4847                None
4848            }
4849        }
4850
4851        fn partition_evaluator(
4852            &self,
4853            _partition_evaluator_args: PartitionEvaluatorArgs,
4854        ) -> Result<Box<dyn PartitionEvaluator>> {
4855            unimplemented!("not needed for tests")
4856        }
4857
4858        fn field(&self, _field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
4859            unimplemented!("not needed for tests")
4860        }
4861
4862        fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
4863            LimitEffect::Unknown
4864        }
4865    }
4866    #[derive(Debug, PartialEq, Eq, Hash)]
4867    struct VolatileUdf {
4868        signature: Signature,
4869    }
4870
4871    impl VolatileUdf {
4872        pub fn new() -> Self {
4873            Self {
4874                signature: Signature::exact(vec![], Volatility::Volatile),
4875            }
4876        }
4877    }
4878    impl ScalarUDFImpl for VolatileUdf {
4879        fn as_any(&self) -> &dyn std::any::Any {
4880            self
4881        }
4882
4883        fn name(&self) -> &str {
4884            "VolatileUdf"
4885        }
4886
4887        fn signature(&self) -> &Signature {
4888            &self.signature
4889        }
4890
4891        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4892            Ok(DataType::Int16)
4893        }
4894
4895        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
4896            panic!("dummy - not implemented")
4897        }
4898    }
4899
/// Checks how OR/AND simplification treats predicates that call a
/// volatile UDF.
#[test]
fn test_optimize_volatile_conditions() {
    let udf = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new()));
    let rand = Expr::ScalarFunction(ScalarFunction::new_udf(udf, vec![]));

    // Fresh copies of the two predicates used throughout the test.
    let rand_is_zero = || rand.clone().eq(lit(0));
    let column_is_two = || col("column1").eq(lit(2));

    // `rand() = 0 OR (column1 = 2 AND rand() = 0)` is expected to come back
    // unchanged.
    let expr = rand_is_zero().or(column_is_two().and(rand_is_zero()));
    assert_eq!(simplify(expr.clone()), expr);

    // `column1 = 2 OR (column1 = 2 AND rand() = 0)` is expected to reduce
    // to `column1 = 2`.
    let expr = column_is_two().or(column_is_two().and(rand_is_zero()));
    assert_eq!(simplify(expr), column_is_two());

    // `(column1 = 2 AND rand() = 0) OR (column1 = 2 AND rand() = 0)`:
    // the expected result factors out `column1 = 2` but keeps both
    // volatile calls.
    let left = column_is_two().and(rand_is_zero());
    let right = column_is_two().and(rand_is_zero());
    let expected = column_is_two().and(rand_is_zero().or(rand_is_zero()));
    assert_eq!(simplify(left.or(right)), expected);
}
4936
/// Compares `c5` with a three-byte binary literal and checks both the
/// coerced form and the simplified form of the comparison.
#[test]
fn simplify_fixed_size_binary_eq_lit() {
    let bytes: &[u8] = &[1u8, 2, 3];

    // Starting point: a plain equality against the byte-slice literal.
    let expr = col("c5").eq(lit(bytes));

    // After type coercion the column side is wrapped in a cast to Binary.
    let coerced = coerce(expr.clone());
    let schema = expr_test_schema();
    let column_as_binary = col("c5")
        .cast_to(&DataType::Binary, schema.as_ref())
        .unwrap();
    assert_eq!(coerced, column_as_binary.eq(lit(bytes)));

    // After simplification the cast is gone and the literal is a
    // FixedSizeBinary of width 3 instead.
    let fixed_literal = Expr::Literal(
        ScalarValue::FixedSizeBinary(3, Some(bytes.to_vec())),
        None,
    );
    assert_eq!(simplify(coerced), col("c5").eq(fixed_literal));
}
4964
/// Verifies that `CAST(<literal>)` is evaluated during simplification,
/// including the case where the evaluation fails.
#[test]
fn simplify_cast_literal() {
    // Target type shared by the two timestamp casts below.
    let utc_nanos = || {
        DataType::Timestamp(
            arrow::datatypes::TimeUnit::Nanosecond,
            Some("+00:00".into()),
        )
    };

    // CAST(123 AS Int64) folds to the literal 123i64.
    let widened = simplify(Expr::Cast(Cast::new(
        Box::new(lit(123i32)),
        DataType::Int64,
    )));
    assert_eq!(widened, lit(123i64));

    // CAST(1761630189642 AS Timestamp(Nanosecond, Some("+00:00"))):
    // an integer literal folds to a timestamp literal with the same raw
    // value and the requested time zone.
    let from_integer = Expr::Cast(Cast::new(
        Box::new(lit(1761630189642i64)),
        utc_nanos(),
    ));
    let folded = simplify(from_integer);
    match &folded {
        Expr::Literal(ScalarValue::TimestampNanosecond(Some(val), tz), _) => {
            assert_eq!(*val, 1761630189642i64);
            assert_eq!(tz.as_deref(), Some("+00:00"));
        }
        other => panic!("Expected TimestampNanosecond literal, got: {other:?}"),
    }

    // CAST(Utf8("1761630189642") AS Timestamp): the string is not a valid
    // timestamp, so the simplifier is called directly here in order to
    // inspect the error it returns at plan time.
    let from_string = Expr::Cast(Cast::new(
        Box::new(lit("1761630189642")),
        utc_nanos(),
    ));

    let schema = test_schema();
    let props = ExecutionProps::new();
    let context = SimplifyContext::new(&props).with_schema(schema);
    let simplifier = ExprSimplifier::new(context);
    let result = simplifier.simplify(from_string);
    assert!(result.is_err(), "Expected error for invalid cast");
    let err_msg = result.unwrap_err().to_string();
    assert_contains!(err_msg, "Error parsing timestamp");
}
5014
5015    fn if_not_null(expr: Expr, then: bool) -> Expr {
5016        Expr::Case(Case {
5017            expr: Some(expr.is_not_null().into()),
5018            when_then_expr: vec![(lit(true).into(), lit(then).into())],
5019            else_expr: None,
5020        })
5021    }
5022}