datafusion_optimizer/simplify_expressions/
expr_simplifier.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression simplification API
19
20use std::borrow::Cow;
21use std::collections::HashSet;
22use std::ops::Not;
23
24use arrow::{
25    array::{new_null_array, AsArray},
26    datatypes::{DataType, Field, Schema},
27    record_batch::RecordBatch,
28};
29
30use datafusion_common::{
31    cast::{as_large_list_array, as_list_array},
32    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
33};
34use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
35use datafusion_expr::{
36    and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like,
37    Operator, Volatility,
38};
39use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
40use datafusion_expr::{
41    expr::{InList, InSubquery},
42    utils::{iter_conjunction, iter_conjunction_owned},
43};
44use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast};
45use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
46
47use super::inlist_simplifier::ShortenInListSimplifier;
48use super::utils::*;
49use crate::analyzer::type_coercion::TypeCoercionRewriter;
50use crate::simplify_expressions::guarantees::GuaranteeRewriter;
51use crate::simplify_expressions::regex::simplify_regex_expr;
52use crate::simplify_expressions::unwrap_cast::{
53    is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
54    is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
55    unwrap_cast_in_comparison_for_binary,
56};
57use crate::simplify_expressions::SimplifyInfo;
58use datafusion_expr::expr::FieldMetadata;
59use datafusion_expr_common::casts::try_cast_literal_to_type;
60use indexmap::IndexSet;
61use regex::Regex;
62
63/// This structure handles API for expression simplification
64///
65/// Provides simplification information based on DFSchema and
66/// [`ExecutionProps`]. This is the default implementation used by DataFusion
67///
68/// For example:
69/// ```
70/// use arrow::datatypes::{Schema, Field, DataType};
71/// use datafusion_expr::{col, lit};
72/// use datafusion_common::{DataFusionError, ToDFSchema};
73/// use datafusion_expr::execution_props::ExecutionProps;
74/// use datafusion_expr::simplify::SimplifyContext;
75/// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
76///
77/// // Create the schema
78/// let schema = Schema::new(vec![
79///     Field::new("i", DataType::Int64, false),
80///   ])
81///   .to_dfschema_ref().unwrap();
82///
83/// // Create the simplifier
84/// let props = ExecutionProps::new();
85/// let context = SimplifyContext::new(&props)
86///    .with_schema(schema);
87/// let simplifier = ExprSimplifier::new(context);
88///
89/// // Use the simplifier
90///
91/// // b < 2 or (1 > 3)
92/// let expr = col("b").lt(lit(2)).or(lit(1).gt(lit(3)));
93///
94/// // b < 2
95/// let simplified = simplifier.simplify(expr).unwrap();
96/// assert_eq!(simplified, col("b").lt(lit(2)));
97/// ```
pub struct ExprSimplifier<S> {
    /// Provider of the information the simplifier needs about expressions.
    /// It is handed to the algebraic `Simplifier`, and its execution
    /// properties are used to set up constant evaluation.
    info: S,
    /// Guarantees about the values of columns. This is provided by the user
    /// in [ExprSimplifier::with_guarantees()].
    guarantees: Vec<(Expr, NullableInterval)>,
    /// Should expressions be canonicalized before simplification? Defaults to
    /// true
    canonicalize: bool,
    /// Maximum number of simplifier cycles
    max_simplifier_cycles: u32,
}
109
/// Size threshold for `IN` lists.
///
/// NOTE(review): presumably the list length up to which an `IN` list is
/// inlined into an equivalent chain of comparisons; the use site is not in
/// this part of the file, so confirm against it.
pub const THRESHOLD_INLINE_INLIST: usize = 3;
/// Maximum number of simplifier cycles used by [`ExprSimplifier::new`]
/// unless overridden with [`ExprSimplifier::with_max_cycles`].
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3;
112
113impl<S: SimplifyInfo> ExprSimplifier<S> {
114    /// Create a new `ExprSimplifier` with the given `info` such as an
115    /// instance of [`SimplifyContext`]. See
116    /// [`simplify`](Self::simplify) for an example.
117    ///
118    /// [`SimplifyContext`]: datafusion_expr::simplify::SimplifyContext
119    pub fn new(info: S) -> Self {
120        Self {
121            info,
122            guarantees: vec![],
123            canonicalize: true,
124            max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES,
125        }
126    }
127
128    /// Simplifies this [`Expr`] as much as possible, evaluating
129    /// constants and applying algebraic simplifications.
130    ///
131    /// The types of the expression must match what operators expect,
132    /// or else an error may occur trying to evaluate. See
133    /// [`coerce`](Self::coerce) for a function to help.
134    ///
135    /// # Example:
136    ///
    /// `(b > 2) OR (b > 2)`
    ///
    /// can be rewritten to
    ///
    /// `b > 2`
142    ///
143    /// ```
144    /// use arrow::datatypes::DataType;
145    /// use datafusion_expr::{col, lit, Expr};
146    /// use datafusion_common::Result;
147    /// use datafusion_expr::execution_props::ExecutionProps;
148    /// use datafusion_expr::simplify::SimplifyContext;
149    /// use datafusion_expr::simplify::SimplifyInfo;
150    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
151    /// use datafusion_common::DFSchema;
152    /// use std::sync::Arc;
153    ///
154    /// /// Simple implementation that provides `Simplifier` the information it needs
155    /// /// See SimplifyContext for a structure that does this.
156    /// #[derive(Default)]
157    /// struct Info {
158    ///   execution_props: ExecutionProps,
159    /// };
160    ///
161    /// impl SimplifyInfo for Info {
162    ///   fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
163    ///     Ok(false)
164    ///   }
165    ///   fn nullable(&self, expr: &Expr) -> Result<bool> {
166    ///     Ok(true)
167    ///   }
168    ///   fn execution_props(&self) -> &ExecutionProps {
169    ///     &self.execution_props
170    ///   }
171    ///   fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
172    ///     Ok(DataType::Int32)
173    ///   }
174    /// }
175    ///
176    /// // Create the simplifier
177    /// let simplifier = ExprSimplifier::new(Info::default());
178    ///
    /// // b > 2
    /// let b_gt_2 = col("b").gt(lit(2));
    ///
    /// // (b > 2) OR (b > 2)
    /// let expr = b_gt_2.clone().or(b_gt_2.clone());
    ///
    /// // (b > 2) OR (b > 2) --> (b > 2)
    /// let expr = simplifier.simplify(expr).unwrap();
    /// assert_eq!(expr, b_gt_2);
188    /// ```
189    pub fn simplify(&self, expr: Expr) -> Result<Expr> {
190        Ok(self.simplify_with_cycle_count_transformed(expr)?.0.data)
191    }
192
193    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
194    /// constants and applying algebraic simplifications. Additionally returns a `u32`
195    /// representing the number of simplification cycles performed, which can be useful for testing
196    /// optimizations.
197    ///
198    /// See [Self::simplify] for details and usage examples.
199    ///
200    #[deprecated(
201        since = "48.0.0",
202        note = "Use `simplify_with_cycle_count_transformed` instead"
203    )]
204    #[allow(unused_mut)]
205    pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
206        let (transformed, cycle_count) =
207            self.simplify_with_cycle_count_transformed(expr)?;
208        Ok((transformed.data, cycle_count))
209    }
210
211    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
212    /// constants and applying algebraic simplifications. Additionally returns a `u32`
213    /// representing the number of simplification cycles performed, which can be useful for testing
214    /// optimizations.
215    ///
216    /// # Returns
217    ///
218    /// A tuple containing:
219    /// - The simplified expression wrapped in a `Transformed<Expr>` indicating if changes were made
220    /// - The number of simplification cycles that were performed
221    ///
222    /// See [Self::simplify] for details and usage examples.
223    ///
224    pub fn simplify_with_cycle_count_transformed(
225        &self,
226        mut expr: Expr,
227    ) -> Result<(Transformed<Expr>, u32)> {
228        let mut simplifier = Simplifier::new(&self.info);
229        let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
230        let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
231        let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);
232
233        if self.canonicalize {
234            expr = expr.rewrite(&mut Canonicalizer::new()).data()?
235        }
236
237        // Evaluating constants can enable new simplifications and
238        // simplifications can enable new constant evaluation
239        // see `Self::with_max_cycles`
240        let mut num_cycles = 0;
241        let mut has_transformed = false;
242        loop {
243            let Transformed {
244                data, transformed, ..
245            } = expr
246                .rewrite(&mut const_evaluator)?
247                .transform_data(|expr| expr.rewrite(&mut simplifier))?
248                .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
249            expr = data;
250            num_cycles += 1;
251            // Track if any transformation occurred
252            has_transformed = has_transformed || transformed;
253            if !transformed || num_cycles >= self.max_simplifier_cycles {
254                break;
255            }
256        }
257        // shorten inlist should be started after other inlist rules are applied
258        expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
259        Ok((
260            Transformed::new_transformed(expr, has_transformed),
261            num_cycles,
262        ))
263    }
264
265    /// Apply type coercion to an [`Expr`] so that it can be
266    /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr).
267    ///
268    /// See the [type coercion module](datafusion_expr::type_coercion)
269    /// documentation for more details on type coercion
270    pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> {
271        let mut expr_rewrite = TypeCoercionRewriter { schema };
272        expr.rewrite(&mut expr_rewrite).data()
273    }
274
275    /// Input guarantees about the values of columns.
276    ///
277    /// The guarantees can simplify expressions. For example, if a column `x` is
278    /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the
279    /// literal `true`.
280    ///
281    /// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`,
282    /// where the [Expr] is a column reference and the [NullableInterval]
283    /// is an interval representing the known possible values of that column.
284    ///
285    /// ```rust
286    /// use arrow::datatypes::{DataType, Field, Schema};
287    /// use datafusion_expr::{col, lit, Expr};
288    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
289    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
290    /// use datafusion_expr::execution_props::ExecutionProps;
291    /// use datafusion_expr::simplify::SimplifyContext;
292    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
293    ///
294    /// let schema = Schema::new(vec![
295    ///   Field::new("x", DataType::Int64, false),
296    ///   Field::new("y", DataType::UInt32, false),
297    ///   Field::new("z", DataType::Int64, false),
298    ///   ])
299    ///   .to_dfschema_ref().unwrap();
300    ///
301    /// // Create the simplifier
302    /// let props = ExecutionProps::new();
303    /// let context = SimplifyContext::new(&props)
304    ///    .with_schema(schema);
305    ///
306    /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5)
307    /// let expr_x = col("x").gt_eq(lit(3_i64));
308    /// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32));
309    /// let expr_z = col("z").gt(lit(5_i64));
310    /// let expr = expr_x.and(expr_y).and(expr_z.clone());
311    ///
312    /// let guarantees = vec![
313    ///    // x ∈ [3, 5]
314    ///    (
315    ///        col("x"),
316    ///        NullableInterval::NotNull {
317    ///            values: Interval::make(Some(3_i64), Some(5_i64)).unwrap()
318    ///        }
319    ///    ),
320    ///    // y = 3
321    ///    (col("y"), NullableInterval::from(ScalarValue::UInt32(Some(3)))),
322    /// ];
323    /// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees);
324    /// let output = simplifier.simplify(expr).unwrap();
325    /// // Expression becomes: true AND true AND (z > 5), which simplifies to
326    /// // z > 5.
327    /// assert_eq!(output, expr_z);
328    /// ```
329    pub fn with_guarantees(mut self, guarantees: Vec<(Expr, NullableInterval)>) -> Self {
330        self.guarantees = guarantees;
331        self
332    }
333
334    /// Should `Canonicalizer` be applied before simplification?
335    ///
336    /// If true (the default), the expression will be rewritten to canonical
337    /// form before simplification. This is useful to ensure that the simplifier
338    /// can apply all possible simplifications.
339    ///
340    /// Some expressions, such as those in some Joins, can not be canonicalized
341    /// without changing their meaning. In these cases, canonicalization should
342    /// be disabled.
343    ///
344    /// ```rust
345    /// use arrow::datatypes::{DataType, Field, Schema};
346    /// use datafusion_expr::{col, lit, Expr};
347    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
348    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
349    /// use datafusion_expr::execution_props::ExecutionProps;
350    /// use datafusion_expr::simplify::SimplifyContext;
351    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
352    ///
353    /// let schema = Schema::new(vec![
354    ///   Field::new("a", DataType::Int64, false),
355    ///   Field::new("b", DataType::Int64, false),
356    ///   Field::new("c", DataType::Int64, false),
357    ///   ])
358    ///   .to_dfschema_ref().unwrap();
359    ///
360    /// // Create the simplifier
361    /// let props = ExecutionProps::new();
362    /// let context = SimplifyContext::new(&props)
363    ///    .with_schema(schema);
364    /// let simplifier = ExprSimplifier::new(context);
365    ///
366    /// // Expression: a = c AND 1 = b
367    /// let expr = col("a").eq(col("c")).and(lit(1).eq(col("b")));
368    ///
369    /// // With canonicalization, the expression is rewritten to canonical form
370    /// // (though it is no simpler in this case):
371    /// let canonical = simplifier.simplify(expr.clone()).unwrap();
372    /// // Expression has been rewritten to: (c = a AND b = 1)
373    /// assert_eq!(canonical, col("c").eq(col("a")).and(col("b").eq(lit(1))));
374    ///
375    /// // If canonicalization is disabled, the expression is not changed
376    /// let non_canonicalized = simplifier
377    ///   .with_canonicalize(false)
378    ///   .simplify(expr.clone())
379    ///   .unwrap();
380    ///
381    /// assert_eq!(non_canonicalized, expr);
382    /// ```
383    pub fn with_canonicalize(mut self, canonicalize: bool) -> Self {
384        self.canonicalize = canonicalize;
385        self
386    }
387
388    /// Specifies the maximum number of simplification cycles to run.
389    ///
390    /// The simplifier can perform multiple passes of simplification. This is
391    /// because the output of one simplification step can allow more optimizations
392    /// in another simplification step. For example, constant evaluation can allow more
393    /// expression simplifications, and expression simplifications can allow more constant
394    /// evaluations.
395    ///
396    /// This method specifies the maximum number of allowed iteration cycles before the simplifier
397    /// returns an [Expr] output. However, it does not always perform the maximum number of cycles.
398    /// The simplifier will attempt to detect when an [Expr] is unchanged by all the simplification
399    /// passes, and return early. This avoids wasting time on unnecessary [Expr] tree traversals.
400    ///
401    /// If no maximum is specified, the value of [DEFAULT_MAX_SIMPLIFIER_CYCLES] is used
402    /// instead.
403    ///
404    /// ```rust
405    /// use arrow::datatypes::{DataType, Field, Schema};
406    /// use datafusion_expr::{col, lit, Expr};
407    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
408    /// use datafusion_expr::execution_props::ExecutionProps;
409    /// use datafusion_expr::simplify::SimplifyContext;
410    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
411    ///
412    /// let schema = Schema::new(vec![
413    ///   Field::new("a", DataType::Int64, false),
414    ///   ])
415    ///   .to_dfschema_ref().unwrap();
416    ///
417    /// // Create the simplifier
418    /// let props = ExecutionProps::new();
419    /// let context = SimplifyContext::new(&props)
420    ///    .with_schema(schema);
421    /// let simplifier = ExprSimplifier::new(context);
422    ///
423    /// // Expression: a IS NOT NULL
424    /// let expr = col("a").is_not_null();
425    ///
426    /// // When using default maximum cycles, 2 cycles will be performed.
427    /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count_transformed(expr.clone()).unwrap();
428    /// assert_eq!(simplified_expr.data, lit(true));
429    /// // 2 cycles were executed, but only 1 was needed
430    /// assert_eq!(count, 2);
431    ///
432    /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
433    /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count_transformed(expr.clone()).unwrap();
    /// // Expression has again been simplified to: true
435    /// assert_eq!(simplified_expr.data, lit(true));
436    /// // Only 1 cycle was executed
437    /// assert_eq!(count, 1);
438    ///
439    /// ```
440    pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
441        self.max_simplifier_cycles = max_simplifier_cycles;
442        self
443    }
444}
445
446/// Canonicalize any BinaryExprs that are not in canonical form
447///
448/// `<literal> <op> <col>` is rewritten to `<col> <op> <literal>`
449///
450/// `<col1> <op> <col2>` is rewritten so that the name of `col1` sorts higher
451/// than `col2` (`a > b` would be canonicalized to `b < a`)
struct Canonicalizer {}

impl Canonicalizer {
    /// Create a new `Canonicalizer`; the rewriter carries no state.
    fn new() -> Self {
        Canonicalizer {}
    }
}
459
460impl TreeNodeRewriter for Canonicalizer {
461    type Node = Expr;
462
463    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
464        let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
465            return Ok(Transformed::no(expr));
466        };
467        match (left.as_ref(), right.as_ref(), op.swap()) {
468            // <col1> <op> <col2>
469            (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op))
470                if right_col > left_col =>
471            {
472                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
473                    left: right,
474                    op: swapped_op,
475                    right: left,
476                })))
477            }
478            // <literal> <op> <col>
479            (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => {
480                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
481                    left: right,
482                    op: swapped_op,
483                    right: left,
484                })))
485            }
486            _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr {
487                left,
488                op,
489                right,
490            }))),
491        }
492    }
493}
494
#[allow(rustdoc::private_intra_doc_links)]
/// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time.
///
/// Note it does not handle algebraic rewrites such as `(a or false)`
/// --> `a`, which is handled by [`Simplifier`]
struct ConstEvaluator<'a> {
    /// `can_evaluate` is used during the depth-first-search of the
    /// `Expr` tree to track if any siblings (or their descendants) were
    /// non evaluatable (e.g. had a column reference or volatile
    /// function)
    ///
    /// Specifically, `can_evaluate[N]` represents the state of
    /// traversal when we are N levels deep in the tree, one entry for
    /// this Expr and each of its parents.
    ///
    /// After visiting all siblings if `can_evaluate.top()` is true, that
    /// means there were no non evaluatable siblings (or their
    /// descendants) so this `Expr` can be evaluated
    can_evaluate: Vec<bool>,

    /// Execution properties passed to `create_physical_expr` when an
    /// expression is turned into a physical expression for evaluation
    execution_props: &'a ExecutionProps,
    /// Schema the physical expressions are created against: a single dummy
    /// column of type `Null` (see `try_new`)
    input_schema: DFSchema,
    /// One-row batch (a single `Null` column) the physical expressions are
    /// evaluated on, so that evaluation yields exactly one value
    input_batch: RecordBatch,
}
519
#[allow(dead_code)]
/// The simplify result of ConstEvaluator
#[allow(clippy::large_enum_variant)]
enum ConstSimplifyResult {
    /// The expression was evaluated; holds the resulting scalar value and
    /// the field metadata of the result, if any
    Simplified(ScalarValue, Option<FieldMetadata>),
    /// The expression was already a literal; its original value and metadata
    /// are handed back unchanged
    NotSimplified(ScalarValue, Option<FieldMetadata>),
    /// Evaluation encountered an error; holds that error together with the
    /// original expression
    SimplifyRuntimeError(DataFusionError, Expr),
}
531
532impl TreeNodeRewriter for ConstEvaluator<'_> {
533    type Node = Expr;
534
535    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
536        // Default to being able to evaluate this node
537        self.can_evaluate.push(true);
538
539        // if this expr is not ok to evaluate, mark entire parent
540        // stack as not ok (as all parents have at least one child or
541        // descendant that can not be evaluated
542
543        if !Self::can_evaluate(&expr) {
544            // walk back up stack, marking first parent that is not mutable
545            let parent_iter = self.can_evaluate.iter_mut().rev();
546            for p in parent_iter {
547                if !*p {
548                    // optimization: if we find an element on the
549                    // stack already marked, know all elements above are also marked
550                    break;
551                }
552                *p = false;
553            }
554        }
555
556        // NB: do not short circuit recursion even if we find a non
557        // evaluatable node (so we can fold other children, args to
558        // functions, etc.)
559        Ok(Transformed::no(expr))
560    }
561
562    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
563        match self.can_evaluate.pop() {
564            // Certain expressions such as `CASE` and `COALESCE` are short-circuiting
565            // and may not evaluate all their sub expressions. Thus, if
566            // any error is countered during simplification, return the original
567            // so that normal evaluation can occur
568            Some(true) => match self.evaluate_to_scalar(expr) {
569                ConstSimplifyResult::Simplified(s, m) => {
570                    Ok(Transformed::yes(Expr::Literal(s, m)))
571                }
572                ConstSimplifyResult::NotSimplified(s, m) => {
573                    Ok(Transformed::no(Expr::Literal(s, m)))
574                }
575                ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
576                    Ok(Transformed::yes(expr))
577                }
578            },
579            Some(false) => Ok(Transformed::no(expr)),
580            _ => internal_err!("Failed to pop can_evaluate"),
581        }
582    }
583}
584
585impl<'a> ConstEvaluator<'a> {
    /// Create a new `ConstEvaluator`. Session constants (such as
    /// the time for `now()`) are taken from the passed
    /// `execution_props`.
589    pub fn try_new(execution_props: &'a ExecutionProps) -> Result<Self> {
590        // The dummy column name is unused and doesn't matter as only
591        // expressions without column references can be evaluated
592        static DUMMY_COL_NAME: &str = ".";
593        let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]);
594        let input_schema = DFSchema::try_from(schema.clone())?;
595        // Need a single "input" row to produce a single output row
596        let col = new_null_array(&DataType::Null, 1);
597        let input_batch = RecordBatch::try_new(std::sync::Arc::new(schema), vec![col])?;
598
599        Ok(Self {
600            can_evaluate: vec![],
601            execution_props,
602            input_schema,
603            input_batch,
604        })
605    }
606
607    /// Can a function of the specified volatility be evaluated?
608    fn volatility_ok(volatility: Volatility) -> bool {
609        match volatility {
610            Volatility::Immutable => true,
611            // Values for functions such as now() are taken from ExecutionProps
612            Volatility::Stable => true,
613            Volatility::Volatile => false,
614        }
615    }
616
617    /// Can the expression be evaluated at plan time, (assuming all of
618    /// its children can also be evaluated)?
619    fn can_evaluate(expr: &Expr) -> bool {
620        // check for reasons we can't evaluate this node
621        //
622        // NOTE all expr types are listed here so when new ones are
623        // added they can be checked for their ability to be evaluated
624        // at plan time
625        match expr {
626            // TODO: remove the next line after `Expr::Wildcard` is removed
627            #[expect(deprecated)]
628            Expr::AggregateFunction { .. }
629            | Expr::ScalarVariable(_, _)
630            | Expr::Column(_)
631            | Expr::OuterReferenceColumn(_, _)
632            | Expr::Exists { .. }
633            | Expr::InSubquery(_)
634            | Expr::ScalarSubquery(_)
635            | Expr::WindowFunction { .. }
636            | Expr::GroupingSet(_)
637            | Expr::Wildcard { .. }
638            | Expr::Placeholder(_) => false,
639            Expr::ScalarFunction(ScalarFunction { func, .. }) => {
640                Self::volatility_ok(func.signature().volatility)
641            }
642            Expr::Literal(_, _)
643            | Expr::Alias(..)
644            | Expr::Unnest(_)
645            | Expr::BinaryExpr { .. }
646            | Expr::Not(_)
647            | Expr::IsNotNull(_)
648            | Expr::IsNull(_)
649            | Expr::IsTrue(_)
650            | Expr::IsFalse(_)
651            | Expr::IsUnknown(_)
652            | Expr::IsNotTrue(_)
653            | Expr::IsNotFalse(_)
654            | Expr::IsNotUnknown(_)
655            | Expr::Negative(_)
656            | Expr::Between { .. }
657            | Expr::Like { .. }
658            | Expr::SimilarTo { .. }
659            | Expr::Case(_)
660            | Expr::Cast { .. }
661            | Expr::TryCast { .. }
662            | Expr::InList { .. } => true,
663        }
664    }
665
666    /// Internal helper to evaluates an Expr
667    pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult {
668        if let Expr::Literal(s, m) = expr {
669            return ConstSimplifyResult::NotSimplified(s, m);
670        }
671
672        let phys_expr =
673            match create_physical_expr(&expr, &self.input_schema, self.execution_props) {
674                Ok(e) => e,
675                Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
676            };
677        let metadata = phys_expr
678            .return_field(self.input_batch.schema_ref())
679            .ok()
680            .and_then(|f| {
681                let m = f.metadata();
682                match m.is_empty() {
683                    true => None,
684                    false => Some(FieldMetadata::from(m)),
685                }
686            });
687        let col_val = match phys_expr.evaluate(&self.input_batch) {
688            Ok(v) => v,
689            Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
690        };
691        match col_val {
692            ColumnarValue::Array(a) => {
693                if a.len() != 1 {
694                    ConstSimplifyResult::SimplifyRuntimeError(
695                        DataFusionError::Execution(format!("Could not evaluate the expression, found a result of length {}", a.len())),
696                        expr,
697                    )
698                } else if as_list_array(&a).is_ok() {
699                    ConstSimplifyResult::Simplified(
700                        ScalarValue::List(a.as_list::<i32>().to_owned().into()),
701                        metadata,
702                    )
703                } else if as_large_list_array(&a).is_ok() {
704                    ConstSimplifyResult::Simplified(
705                        ScalarValue::LargeList(a.as_list::<i64>().to_owned().into()),
706                        metadata,
707                    )
708                } else {
709                    // Non-ListArray
710                    match ScalarValue::try_from_array(&a, 0) {
711                        Ok(s) => {
712                            // TODO: support the optimization for `Map` type after support impl hash for it
713                            if matches!(&s, ScalarValue::Map(_)) {
714                                ConstSimplifyResult::SimplifyRuntimeError(
715                                    DataFusionError::NotImplemented("Const evaluate for Map type is still not supported".to_string()),
716                                    expr,
717                                )
718                            } else {
719                                ConstSimplifyResult::Simplified(s, metadata)
720                            }
721                        }
722                        Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr),
723                    }
724                }
725            }
726            ColumnarValue::Scalar(s) => {
727                // TODO: support the optimization for `Map` type after support impl hash for it
728                if matches!(&s, ScalarValue::Map(_)) {
729                    ConstSimplifyResult::SimplifyRuntimeError(
730                        DataFusionError::NotImplemented(
731                            "Const evaluate for Map type is still not supported"
732                                .to_string(),
733                        ),
734                        expr,
735                    )
736                } else {
737                    ConstSimplifyResult::Simplified(s, metadata)
738                }
739            }
740        }
741    }
742}
743
/// Rewriter that applies algebraic simplification rules to [`Expr`] trees.
///
/// Individual rules consult the borrowed `info` for facts about an
/// expression (for example whether it is boolean-typed or nullable) before
/// deciding whether a rewrite is valid.
///
/// A few of the rewrites that are performed:
/// * `expr = true` and `expr != false` become `expr` (boolean `expr` only)
/// * `expr = false` and `expr != true` become `!expr` (boolean `expr` only)
/// * `true = true` and `false = false` become `true`
/// * `false = true` and `true = false` become `false`
/// * `!!expr` becomes `expr`
/// * `expr = null` and `expr != null` become `null`
struct Simplifier<'a, S> {
    /// Borrowed provider of type and nullability information used by the rules.
    info: &'a S,
}

impl<'a, S> Simplifier<'a, S> {
    /// Creates a simplifier that borrows `info` for as long as it lives.
    pub fn new(info: &'a S) -> Self {
        Simplifier { info }
    }
}
762
763impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
764    type Node = Expr;
765
766    /// rewrite the expression simplifying any constant expressions
767    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
768        use datafusion_expr::Operator::{
769            And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor,
770            Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch,
771            RegexNotIMatch, RegexNotMatch,
772        };
773
774        let info = self.info;
775        Ok(match expr {
776            //
777            // Rules for Eq
778            //
779
780            // true = A  --> A
781            // false = A --> !A
782            // null = A --> null
783            Expr::BinaryExpr(BinaryExpr {
784                left,
785                op: Eq,
786                right,
787            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
788                Transformed::yes(match as_bool_lit(&left)? {
789                    Some(true) => *right,
790                    Some(false) => Expr::Not(right),
791                    None => lit_bool_null(),
792                })
793            }
794            // A = true  --> A
795            // A = false --> !A
796            // A = null --> null
797            Expr::BinaryExpr(BinaryExpr {
798                left,
799                op: Eq,
800                right,
801            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
802                Transformed::yes(match as_bool_lit(&right)? {
803                    Some(true) => *left,
804                    Some(false) => Expr::Not(left),
805                    None => lit_bool_null(),
806                })
807            }
808            // According to SQL's null semantics, NULL = NULL evaluates to NULL
809            // Both sides are the same expression (A = A) and A is non-volatile expression
810            // A = A --> A IS NOT NULL OR NULL
811            // A = A --> true (if A not nullable)
812            Expr::BinaryExpr(BinaryExpr {
813                left,
814                op: Eq,
815                right,
816            }) if (left == right) & !left.is_volatile() => {
817                Transformed::yes(match !info.nullable(&left)? {
818                    true => lit(true),
819                    false => Expr::BinaryExpr(BinaryExpr {
820                        left: Box::new(Expr::IsNotNull(left)),
821                        op: Or,
822                        right: Box::new(lit_bool_null()),
823                    }),
824                })
825            }
826
827            // Rules for NotEq
828            //
829
830            // true != A  --> !A
831            // false != A --> A
832            // null != A --> null
833            Expr::BinaryExpr(BinaryExpr {
834                left,
835                op: NotEq,
836                right,
837            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
838                Transformed::yes(match as_bool_lit(&left)? {
839                    Some(true) => Expr::Not(right),
840                    Some(false) => *right,
841                    None => lit_bool_null(),
842                })
843            }
844            // A != true  --> !A
845            // A != false --> A
846            // A != null --> null,
847            Expr::BinaryExpr(BinaryExpr {
848                left,
849                op: NotEq,
850                right,
851            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
852                Transformed::yes(match as_bool_lit(&right)? {
853                    Some(true) => Expr::Not(left),
854                    Some(false) => *left,
855                    None => lit_bool_null(),
856                })
857            }
858
859            //
860            // Rules for OR
861            //
862
863            // true OR A --> true (even if A is null)
864            Expr::BinaryExpr(BinaryExpr {
865                left,
866                op: Or,
867                right: _,
868            }) if is_true(&left) => Transformed::yes(*left),
869            // false OR A --> A
870            Expr::BinaryExpr(BinaryExpr {
871                left,
872                op: Or,
873                right,
874            }) if is_false(&left) => Transformed::yes(*right),
875            // A OR true --> true (even if A is null)
876            Expr::BinaryExpr(BinaryExpr {
877                left: _,
878                op: Or,
879                right,
880            }) if is_true(&right) => Transformed::yes(*right),
881            // A OR false --> A
882            Expr::BinaryExpr(BinaryExpr {
883                left,
884                op: Or,
885                right,
886            }) if is_false(&right) => Transformed::yes(*left),
887            // A OR !A ---> true (if A not nullable)
888            Expr::BinaryExpr(BinaryExpr {
889                left,
890                op: Or,
891                right,
892            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
893                Transformed::yes(lit(true))
894            }
895            // !A OR A ---> true (if A not nullable)
896            Expr::BinaryExpr(BinaryExpr {
897                left,
898                op: Or,
899                right,
900            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
901                Transformed::yes(lit(true))
902            }
903            // (..A..) OR A --> (..A..)
904            Expr::BinaryExpr(BinaryExpr {
905                left,
906                op: Or,
907                right,
908            }) if expr_contains(&left, &right, Or) => Transformed::yes(*left),
909            // A OR (..A..) --> (..A..)
910            Expr::BinaryExpr(BinaryExpr {
911                left,
912                op: Or,
913                right,
914            }) if expr_contains(&right, &left, Or) => Transformed::yes(*right),
915            // A OR (A AND B) --> A
916            Expr::BinaryExpr(BinaryExpr {
917                left,
918                op: Or,
919                right,
920            }) if is_op_with(And, &right, &left) => Transformed::yes(*left),
921            // (A AND B) OR A --> A
922            Expr::BinaryExpr(BinaryExpr {
923                left,
924                op: Or,
925                right,
926            }) if is_op_with(And, &left, &right) => Transformed::yes(*right),
927            // Eliminate common factors in conjunctions e.g
928            // (A AND B) OR (A AND C) -> A AND (B OR C)
929            Expr::BinaryExpr(BinaryExpr {
930                left,
931                op: Or,
932                right,
933            }) if has_common_conjunction(&left, &right) => {
934                let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect();
935                let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right)
936                    .partition(|e| lhs.contains(e) && !e.is_volatile());
937
938                let new_rhs = rhs.into_iter().reduce(and);
939                let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and);
940                let common_conjunction = common.into_iter().reduce(and).unwrap();
941
942                let new_expr = match (new_lhs, new_rhs) {
943                    (Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)),
944                    (_, _) => common_conjunction,
945                };
946                Transformed::yes(new_expr)
947            }
948
949            //
950            // Rules for AND
951            //
952
953            // true AND A --> A
954            Expr::BinaryExpr(BinaryExpr {
955                left,
956                op: And,
957                right,
958            }) if is_true(&left) => Transformed::yes(*right),
959            // false AND A --> false (even if A is null)
960            Expr::BinaryExpr(BinaryExpr {
961                left,
962                op: And,
963                right: _,
964            }) if is_false(&left) => Transformed::yes(*left),
965            // A AND true --> A
966            Expr::BinaryExpr(BinaryExpr {
967                left,
968                op: And,
969                right,
970            }) if is_true(&right) => Transformed::yes(*left),
971            // A AND false --> false (even if A is null)
972            Expr::BinaryExpr(BinaryExpr {
973                left: _,
974                op: And,
975                right,
976            }) if is_false(&right) => Transformed::yes(*right),
977            // A AND !A ---> false (if A not nullable)
978            Expr::BinaryExpr(BinaryExpr {
979                left,
980                op: And,
981                right,
982            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
983                Transformed::yes(lit(false))
984            }
985            // !A AND A ---> false (if A not nullable)
986            Expr::BinaryExpr(BinaryExpr {
987                left,
988                op: And,
989                right,
990            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
991                Transformed::yes(lit(false))
992            }
993            // (..A..) AND A --> (..A..)
994            Expr::BinaryExpr(BinaryExpr {
995                left,
996                op: And,
997                right,
998            }) if expr_contains(&left, &right, And) => Transformed::yes(*left),
999            // A AND (..A..) --> (..A..)
1000            Expr::BinaryExpr(BinaryExpr {
1001                left,
1002                op: And,
1003                right,
1004            }) if expr_contains(&right, &left, And) => Transformed::yes(*right),
1005            // A AND (A OR B) --> A
1006            Expr::BinaryExpr(BinaryExpr {
1007                left,
1008                op: And,
1009                right,
1010            }) if is_op_with(Or, &right, &left) => Transformed::yes(*left),
1011            // (A OR B) AND A --> A
1012            Expr::BinaryExpr(BinaryExpr {
1013                left,
1014                op: And,
1015                right,
1016            }) if is_op_with(Or, &left, &right) => Transformed::yes(*right),
            // A >= constant AND A <= constant --> A = constant
1018            Expr::BinaryExpr(BinaryExpr {
1019                left,
1020                op: And,
1021                right,
1022            }) if can_reduce_to_equal_statement(&left, &right) => {
1023                if let Expr::BinaryExpr(BinaryExpr {
1024                    left: left_left,
1025                    right: left_right,
1026                    ..
1027                }) = *left
1028                {
1029                    Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1030                        left: left_left,
1031                        op: Eq,
1032                        right: left_right,
1033                    }))
1034                } else {
1035                    return internal_err!("can_reduce_to_equal_statement should only be called with a BinaryExpr");
1036                }
1037            }
1038
1039            //
1040            // Rules for Multiply
1041            //
1042
1043            // A * 1 --> A (with type coercion if needed)
1044            Expr::BinaryExpr(BinaryExpr {
1045                left,
1046                op: Multiply,
1047                right,
1048            }) if is_one(&right) => {
1049                simplify_right_is_one_case(info, left, &Multiply, &right)?
1050            }
1051            // A * null --> null
1052            Expr::BinaryExpr(BinaryExpr {
1053                left,
1054                op: Multiply,
1055                right,
1056            }) if is_null(&right) => {
1057                simplify_right_is_null_case(info, &left, &Multiply, right)?
1058            }
1059            // 1 * A --> A
1060            Expr::BinaryExpr(BinaryExpr {
1061                left,
1062                op: Multiply,
1063                right,
1064            }) if is_one(&left) => {
1065                // 1 * A is equivalent to A * 1
1066                simplify_right_is_one_case(info, right, &Multiply, &left)?
1067            }
1068            // null * A --> null
1069            Expr::BinaryExpr(BinaryExpr {
1070                left,
1071                op: Multiply,
1072                right,
1073            }) if is_null(&left) => {
1074                simplify_right_is_null_case(info, &right, &Multiply, left)?
1075            }
1076
1077            // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN)
1078            Expr::BinaryExpr(BinaryExpr {
1079                left,
1080                op: Multiply,
1081                right,
1082            }) if !info.nullable(&left)?
1083                && !info.get_data_type(&left)?.is_floating()
1084                && is_zero(&right) =>
1085            {
1086                Transformed::yes(*right)
1087            }
1088            // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN)
1089            Expr::BinaryExpr(BinaryExpr {
1090                left,
1091                op: Multiply,
1092                right,
1093            }) if !info.nullable(&right)?
1094                && !info.get_data_type(&right)?.is_floating()
1095                && is_zero(&left) =>
1096            {
1097                Transformed::yes(*left)
1098            }
1099
1100            //
1101            // Rules for Divide
1102            //
1103
1104            // A / 1 --> A
1105            Expr::BinaryExpr(BinaryExpr {
1106                left,
1107                op: Divide,
1108                right,
1109            }) if is_one(&right) => {
1110                simplify_right_is_one_case(info, left, &Divide, &right)?
1111            }
1112            // A / null --> null
1113            Expr::BinaryExpr(BinaryExpr {
1114                left,
1115                op: Divide,
1116                right,
1117            }) if is_null(&right) => {
1118                simplify_right_is_null_case(info, &left, &Divide, right)?
1119            }
1120            // null / A --> null
1121            Expr::BinaryExpr(BinaryExpr {
1122                left,
1123                op: Divide,
1124                right,
1125            }) if is_null(&left) => simplify_null_div_other_case(info, left, &right)?,
1126
1127            //
1128            // Rules for Modulo
1129            //
1130
1131            // A % null --> null
1132            Expr::BinaryExpr(BinaryExpr {
1133                left: _,
1134                op: Modulo,
1135                right,
1136            }) if is_null(&right) => Transformed::yes(*right),
1137            // null % A --> null
1138            Expr::BinaryExpr(BinaryExpr {
1139                left,
1140                op: Modulo,
1141                right: _,
1142            }) if is_null(&left) => Transformed::yes(*left),
1143            // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN)
1144            Expr::BinaryExpr(BinaryExpr {
1145                left,
1146                op: Modulo,
1147                right,
1148            }) if !info.nullable(&left)?
1149                && !info.get_data_type(&left)?.is_floating()
1150                && is_one(&right) =>
1151            {
1152                Transformed::yes(Expr::Literal(
1153                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1154                    None,
1155                ))
1156            }
1157
1158            //
1159            // Rules for BitwiseAnd
1160            //
1161
1162            // A & null -> null
1163            Expr::BinaryExpr(BinaryExpr {
1164                left: _,
1165                op: BitwiseAnd,
1166                right,
1167            }) if is_null(&right) => Transformed::yes(*right),
1168
1169            // null & A -> null
1170            Expr::BinaryExpr(BinaryExpr {
1171                left,
1172                op: BitwiseAnd,
1173                right: _,
1174            }) if is_null(&left) => Transformed::yes(*left),
1175
1176            // A & 0 -> 0 (if A not nullable)
1177            Expr::BinaryExpr(BinaryExpr {
1178                left,
1179                op: BitwiseAnd,
1180                right,
1181            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right),
1182
1183            // 0 & A -> 0 (if A not nullable)
1184            Expr::BinaryExpr(BinaryExpr {
1185                left,
1186                op: BitwiseAnd,
1187                right,
1188            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left),
1189
1190            // !A & A -> 0 (if A not nullable)
1191            Expr::BinaryExpr(BinaryExpr {
1192                left,
1193                op: BitwiseAnd,
1194                right,
1195            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1196                Transformed::yes(Expr::Literal(
1197                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1198                    None,
1199                ))
1200            }
1201
1202            // A & !A -> 0 (if A not nullable)
1203            Expr::BinaryExpr(BinaryExpr {
1204                left,
1205                op: BitwiseAnd,
1206                right,
1207            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1208                Transformed::yes(Expr::Literal(
1209                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1210                    None,
1211                ))
1212            }
1213
1214            // (..A..) & A --> (..A..)
1215            Expr::BinaryExpr(BinaryExpr {
1216                left,
1217                op: BitwiseAnd,
1218                right,
1219            }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left),
1220
1221            // A & (..A..) --> (..A..)
1222            Expr::BinaryExpr(BinaryExpr {
1223                left,
1224                op: BitwiseAnd,
1225                right,
1226            }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right),
1227
1228            // A & (A | B) --> A (if B not null)
1229            Expr::BinaryExpr(BinaryExpr {
1230                left,
1231                op: BitwiseAnd,
1232                right,
1233            }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => {
1234                Transformed::yes(*left)
1235            }
1236
1237            // (A | B) & A --> A (if B not null)
1238            Expr::BinaryExpr(BinaryExpr {
1239                left,
1240                op: BitwiseAnd,
1241                right,
1242            }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => {
1243                Transformed::yes(*right)
1244            }
1245
1246            //
1247            // Rules for BitwiseOr
1248            //
1249
1250            // A | null -> null
1251            Expr::BinaryExpr(BinaryExpr {
1252                left: _,
1253                op: BitwiseOr,
1254                right,
1255            }) if is_null(&right) => Transformed::yes(*right),
1256
1257            // null | A -> null
1258            Expr::BinaryExpr(BinaryExpr {
1259                left,
1260                op: BitwiseOr,
1261                right: _,
1262            }) if is_null(&left) => Transformed::yes(*left),
1263
1264            // A | 0 -> A (even if A is null)
1265            Expr::BinaryExpr(BinaryExpr {
1266                left,
1267                op: BitwiseOr,
1268                right,
1269            }) if is_zero(&right) => Transformed::yes(*left),
1270
1271            // 0 | A -> A (even if A is null)
1272            Expr::BinaryExpr(BinaryExpr {
1273                left,
1274                op: BitwiseOr,
1275                right,
1276            }) if is_zero(&left) => Transformed::yes(*right),
1277
1278            // !A | A -> -1 (if A not nullable)
1279            Expr::BinaryExpr(BinaryExpr {
1280                left,
1281                op: BitwiseOr,
1282                right,
1283            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1284                Transformed::yes(Expr::Literal(
1285                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1286                    None,
1287                ))
1288            }
1289
1290            // A | !A -> -1 (if A not nullable)
1291            Expr::BinaryExpr(BinaryExpr {
1292                left,
1293                op: BitwiseOr,
1294                right,
1295            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1296                Transformed::yes(Expr::Literal(
1297                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1298                    None,
1299                ))
1300            }
1301
1302            // (..A..) | A --> (..A..)
1303            Expr::BinaryExpr(BinaryExpr {
1304                left,
1305                op: BitwiseOr,
1306                right,
1307            }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left),
1308
1309            // A | (..A..) --> (..A..)
1310            Expr::BinaryExpr(BinaryExpr {
1311                left,
1312                op: BitwiseOr,
1313                right,
1314            }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right),
1315
1316            // A | (A & B) --> A (if B not null)
1317            Expr::BinaryExpr(BinaryExpr {
1318                left,
1319                op: BitwiseOr,
1320                right,
1321            }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => {
1322                Transformed::yes(*left)
1323            }
1324
1325            // (A & B) | A --> A (if B not null)
1326            Expr::BinaryExpr(BinaryExpr {
1327                left,
1328                op: BitwiseOr,
1329                right,
1330            }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => {
1331                Transformed::yes(*right)
1332            }
1333
1334            //
1335            // Rules for BitwiseXor
1336            //
1337
1338            // A ^ null -> null
1339            Expr::BinaryExpr(BinaryExpr {
1340                left: _,
1341                op: BitwiseXor,
1342                right,
1343            }) if is_null(&right) => Transformed::yes(*right),
1344
1345            // null ^ A -> null
1346            Expr::BinaryExpr(BinaryExpr {
1347                left,
1348                op: BitwiseXor,
1349                right: _,
1350            }) if is_null(&left) => Transformed::yes(*left),
1351
1352            // A ^ 0 -> A (if A not nullable)
1353            Expr::BinaryExpr(BinaryExpr {
1354                left,
1355                op: BitwiseXor,
1356                right,
1357            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left),
1358
1359            // 0 ^ A -> A (if A not nullable)
1360            Expr::BinaryExpr(BinaryExpr {
1361                left,
1362                op: BitwiseXor,
1363                right,
1364            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right),
1365
1366            // !A ^ A -> -1 (if A not nullable)
1367            Expr::BinaryExpr(BinaryExpr {
1368                left,
1369                op: BitwiseXor,
1370                right,
1371            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1372                Transformed::yes(Expr::Literal(
1373                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1374                    None,
1375                ))
1376            }
1377
1378            // A ^ !A -> -1 (if A not nullable)
1379            Expr::BinaryExpr(BinaryExpr {
1380                left,
1381                op: BitwiseXor,
1382                right,
1383            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1384                Transformed::yes(Expr::Literal(
1385                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1386                    None,
1387                ))
1388            }
1389
1390            // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A)
1391            Expr::BinaryExpr(BinaryExpr {
1392                left,
1393                op: BitwiseXor,
1394                right,
1395            }) if expr_contains(&left, &right, BitwiseXor) => {
1396                let expr = delete_xor_in_complex_expr(&left, &right, false);
1397                Transformed::yes(if expr == *right {
1398                    Expr::Literal(
1399                        ScalarValue::new_zero(&info.get_data_type(&right)?)?,
1400                        None,
1401                    )
1402                } else {
1403                    expr
1404                })
1405            }
1406
1407            // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A)
1408            Expr::BinaryExpr(BinaryExpr {
1409                left,
1410                op: BitwiseXor,
1411                right,
1412            }) if expr_contains(&right, &left, BitwiseXor) => {
1413                let expr = delete_xor_in_complex_expr(&right, &left, true);
1414                Transformed::yes(if expr == *left {
1415                    Expr::Literal(
1416                        ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1417                        None,
1418                    )
1419                } else {
1420                    expr
1421                })
1422            }
1423
1424            //
1425            // Rules for BitwiseShiftRight
1426            //
1427
1428            // A >> null -> null
1429            Expr::BinaryExpr(BinaryExpr {
1430                left: _,
1431                op: BitwiseShiftRight,
1432                right,
1433            }) if is_null(&right) => Transformed::yes(*right),
1434
1435            // null >> A -> null
1436            Expr::BinaryExpr(BinaryExpr {
1437                left,
1438                op: BitwiseShiftRight,
1439                right: _,
1440            }) if is_null(&left) => Transformed::yes(*left),
1441
1442            // A >> 0 -> A (even if A is null)
1443            Expr::BinaryExpr(BinaryExpr {
1444                left,
1445                op: BitwiseShiftRight,
1446                right,
1447            }) if is_zero(&right) => Transformed::yes(*left),
1448
1449            //
            // Rules for BitwiseShiftLeft
1451            //
1452
1453            // A << null -> null
1454            Expr::BinaryExpr(BinaryExpr {
1455                left: _,
1456                op: BitwiseShiftLeft,
1457                right,
1458            }) if is_null(&right) => Transformed::yes(*right),
1459
1460            // null << A -> null
1461            Expr::BinaryExpr(BinaryExpr {
1462                left,
1463                op: BitwiseShiftLeft,
1464                right: _,
1465            }) if is_null(&left) => Transformed::yes(*left),
1466
1467            // A << 0 -> A (even if A is null)
1468            Expr::BinaryExpr(BinaryExpr {
1469                left,
1470                op: BitwiseShiftLeft,
1471                right,
1472            }) if is_zero(&right) => Transformed::yes(*left),
1473
1474            //
1475            // Rules for Not
1476            //
1477            Expr::Not(inner) => Transformed::yes(negate_clause(*inner)),
1478
1479            //
1480            // Rules for Negative
1481            //
1482            Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)),
1483
1484            //
1485            // Rules for Case
1486            //
1487
1488            // CASE
1489            //   WHEN X THEN A
1490            //   WHEN Y THEN B
1491            //   ...
1492            //   ELSE Q
1493            // END
1494            //
1495            // ---> (X AND A) OR (Y AND B AND NOT X) OR ... (NOT (X OR Y) AND Q)
1496            //
1497            // Note: the rationale for this rewrite is that the expr can then be further
1498            // simplified using the existing rules for AND/OR
1499            Expr::Case(Case {
1500                expr: None,
1501                when_then_expr,
1502                else_expr,
1503            }) if !when_then_expr.is_empty()
1504                && when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number
1505                && info.is_boolean_type(&when_then_expr[0].1)? =>
1506            {
1507                // String disjunction of all the when predicates encountered so far. Not nullable.
1508                let mut filter_expr = lit(false);
1509                // The disjunction of all the cases
1510                let mut out_expr = lit(false);
1511
1512                for (when, then) in when_then_expr {
1513                    let when = is_exactly_true(*when, info)?;
1514                    let case_expr =
1515                        when.clone().and(filter_expr.clone().not()).and(*then);
1516
1517                    out_expr = out_expr.or(case_expr);
1518                    filter_expr = filter_expr.or(when);
1519                }
1520
1521                let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
1522                let case_expr = filter_expr.not().and(else_expr);
1523                out_expr = out_expr.or(case_expr);
1524
1525                // Do a first pass at simplification
1526                out_expr.rewrite(self)?
1527            }
1528            Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
1529                match udf.simplify(args, info)? {
1530                    ExprSimplifyResult::Original(args) => {
1531                        Transformed::no(Expr::ScalarFunction(ScalarFunction {
1532                            func: udf,
1533                            args,
1534                        }))
1535                    }
1536                    ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
1537                }
1538            }
1539
1540            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
1541                ref func,
1542                ..
1543            }) => match (func.simplify(), expr) {
1544                (Some(simplify_function), Expr::AggregateFunction(af)) => {
1545                    Transformed::yes(simplify_function(af, info)?)
1546                }
1547                (_, expr) => Transformed::no(expr),
1548            },
1549
1550            Expr::WindowFunction(ref window_fun) => match (window_fun.simplify(), expr) {
1551                (Some(simplify_function), Expr::WindowFunction(wf)) => {
1552                    Transformed::yes(simplify_function(*wf, info)?)
1553                }
1554                (_, expr) => Transformed::no(expr),
1555            },
1556
1557            //
1558            // Rules for Between
1559            //
1560
1561            // a between 3 and 5  -->  a >= 3 AND a <=5
1562            // a not between 3 and 5  -->  a < 3 OR a > 5
1563            Expr::Between(between) => Transformed::yes(if between.negated {
1564                let l = *between.expr.clone();
1565                let r = *between.expr;
1566                or(l.lt(*between.low), r.gt(*between.high))
1567            } else {
1568                and(
1569                    between.expr.clone().gt_eq(*between.low),
1570                    between.expr.lt_eq(*between.high),
1571                )
1572            }),
1573
1574            //
1575            // Rules for regexes
1576            //
1577            Expr::BinaryExpr(BinaryExpr {
1578                left,
1579                op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch),
1580                right,
1581            }) => Transformed::yes(simplify_regex_expr(left, op, right)?),
1582
1583            // Rules for Like
1584            Expr::Like(like) => {
1585                // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291
1586                let escape_char = like.escape_char.unwrap_or('\\');
1587                match as_string_scalar(&like.pattern) {
1588                    Some((data_type, pattern_str)) => {
1589                        match pattern_str {
1590                            None => return Ok(Transformed::yes(lit_bool_null())),
1591                            Some(pattern_str) if pattern_str == "%" => {
1592                                // exp LIKE '%' is
1593                                //   - when exp is not NULL, it's true
1594                                //   - when exp is NULL, it's NULL
1595                                // exp NOT LIKE '%' is
1596                                //   - when exp is not NULL, it's false
1597                                //   - when exp is NULL, it's NULL
1598                                let result_for_non_null = lit(!like.negated);
1599                                Transformed::yes(if !info.nullable(&like.expr)? {
1600                                    result_for_non_null
1601                                } else {
1602                                    Expr::Case(Case {
1603                                        expr: Some(Box::new(Expr::IsNotNull(like.expr))),
1604                                        when_then_expr: vec![(
1605                                            Box::new(lit(true)),
1606                                            Box::new(result_for_non_null),
1607                                        )],
1608                                        else_expr: None,
1609                                    })
1610                                })
1611                            }
1612                            Some(pattern_str)
1613                                if pattern_str.contains("%%")
1614                                    && !pattern_str.contains(escape_char) =>
1615                            {
1616                                // Repeated occurrences of wildcard are redundant so remove them
1617                                // exp LIKE '%%'  --> exp LIKE '%'
1618                                let simplified_pattern = Regex::new("%%+")
1619                                    .unwrap()
1620                                    .replace_all(pattern_str, "%")
1621                                    .to_string();
1622                                Transformed::yes(Expr::Like(Like {
1623                                    pattern: Box::new(to_string_scalar(
1624                                        data_type,
1625                                        Some(simplified_pattern),
1626                                    )),
1627                                    ..like
1628                                }))
1629                            }
1630                            Some(pattern_str)
1631                                if !like.case_insensitive
1632                                    && !pattern_str
1633                                        .contains(['%', '_', escape_char].as_ref()) =>
1634                            {
1635                                // If the pattern does not contain any wildcards, we can simplify the like expression to an equality expression
1636                                // TODO: handle escape characters
1637                                Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1638                                    left: like.expr.clone(),
1639                                    op: if like.negated { NotEq } else { Eq },
1640                                    right: like.pattern.clone(),
1641                                }))
1642                            }
1643
1644                            Some(_pattern_str) => Transformed::no(Expr::Like(like)),
1645                        }
1646                    }
1647                    None => Transformed::no(Expr::Like(like)),
1648                }
1649            }
1650
1651            // a is not null/unknown --> true (if a is not nullable)
1652            Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr)
1653                if !info.nullable(&expr)? =>
1654            {
1655                Transformed::yes(lit(true))
1656            }
1657
1658            // a is null/unknown --> false (if a is not nullable)
1659            Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => {
1660                Transformed::yes(lit(false))
1661            }
1662
1663            // expr IN () --> false
1664            // expr NOT IN () --> true
1665            Expr::InList(InList {
1666                expr,
1667                list,
1668                negated,
1669            }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null, None) => {
1670                Transformed::yes(lit(negated))
1671            }
1672
1673            // null in (x, y, z) --> null
1674            // null not in (x, y, z) --> null
1675            Expr::InList(InList {
1676                expr,
1677                list: _,
1678                negated: _,
1679            }) if is_null(expr.as_ref()) => Transformed::yes(lit_bool_null()),
1680
            // expr IN ((subquery)) -> expr IN (subquery), see #5529
1682            Expr::InList(InList {
1683                expr,
1684                mut list,
1685                negated,
1686            }) if list.len() == 1
1687                && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) =>
1688            {
1689                let Expr::ScalarSubquery(subquery) = list.remove(0) else {
1690                    unreachable!()
1691                };
1692
1693                Transformed::yes(Expr::InSubquery(InSubquery::new(
1694                    expr, subquery, negated,
1695                )))
1696            }
1697
1698            // Combine multiple OR expressions into a single IN list expression if possible
1699            //
1700            // i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)`
1701            Expr::BinaryExpr(BinaryExpr {
1702                left,
1703                op: Or,
1704                right,
1705            }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => {
1706                let lhs = to_inlist(*left).unwrap();
1707                let rhs = to_inlist(*right).unwrap();
1708                let mut seen: HashSet<Expr> = HashSet::new();
1709                let list = lhs
1710                    .list
1711                    .into_iter()
1712                    .chain(rhs.list)
1713                    .filter(|e| seen.insert(e.to_owned()))
1714                    .collect::<Vec<_>>();
1715
1716                let merged_inlist = InList {
1717                    expr: lhs.expr,
1718                    list,
1719                    negated: false,
1720                };
1721
1722                Transformed::yes(Expr::InList(merged_inlist))
1723            }
1724
            // Simplify expressions that are guaranteed to be true or false to a literal boolean expression
1726            //
1727            // Rules:
1728            // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists
1729            //   Intersection:
1730            //     1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false`
1731            //     2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)`
1732            //     3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)`
1733            //   Union:
            //     4. `a not in (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)`
1735            //     # This rule is handled by `or_in_list_simplifier.rs`
1736            //     5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)`
1737            // If one of the expressions is `IN` and another one is `NOT IN`, then we apply exception on `In` expression
1738            //     6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false`
1739            //     7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5`
1740            //     8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)`
1741            Expr::BinaryExpr(BinaryExpr {
1742                left,
1743                op: And,
1744                right,
1745            }) if are_inlist_and_eq_and_match_neg(
1746                left.as_ref(),
1747                right.as_ref(),
1748                false,
1749                false,
1750            ) =>
1751            {
1752                match (*left, *right) {
1753                    (Expr::InList(l1), Expr::InList(l2)) => {
1754                        return inlist_intersection(l1, &l2, false).map(Transformed::yes);
1755                    }
1756                    // Matched previously once
1757                    _ => unreachable!(),
1758                }
1759            }
1760
1761            Expr::BinaryExpr(BinaryExpr {
1762                left,
1763                op: And,
1764                right,
1765            }) if are_inlist_and_eq_and_match_neg(
1766                left.as_ref(),
1767                right.as_ref(),
1768                true,
1769                true,
1770            ) =>
1771            {
1772                match (*left, *right) {
1773                    (Expr::InList(l1), Expr::InList(l2)) => {
1774                        return inlist_union(l1, l2, true).map(Transformed::yes);
1775                    }
1776                    // Matched previously once
1777                    _ => unreachable!(),
1778                }
1779            }
1780
1781            Expr::BinaryExpr(BinaryExpr {
1782                left,
1783                op: And,
1784                right,
1785            }) if are_inlist_and_eq_and_match_neg(
1786                left.as_ref(),
1787                right.as_ref(),
1788                false,
1789                true,
1790            ) =>
1791            {
1792                match (*left, *right) {
1793                    (Expr::InList(l1), Expr::InList(l2)) => {
1794                        return inlist_except(l1, &l2).map(Transformed::yes);
1795                    }
1796                    // Matched previously once
1797                    _ => unreachable!(),
1798                }
1799            }
1800
1801            Expr::BinaryExpr(BinaryExpr {
1802                left,
1803                op: And,
1804                right,
1805            }) if are_inlist_and_eq_and_match_neg(
1806                left.as_ref(),
1807                right.as_ref(),
1808                true,
1809                false,
1810            ) =>
1811            {
1812                match (*left, *right) {
1813                    (Expr::InList(l1), Expr::InList(l2)) => {
1814                        return inlist_except(l2, &l1).map(Transformed::yes);
1815                    }
1816                    // Matched previously once
1817                    _ => unreachable!(),
1818                }
1819            }
1820
1821            Expr::BinaryExpr(BinaryExpr {
1822                left,
1823                op: Or,
1824                right,
1825            }) if are_inlist_and_eq_and_match_neg(
1826                left.as_ref(),
1827                right.as_ref(),
1828                true,
1829                true,
1830            ) =>
1831            {
1832                match (*left, *right) {
1833                    (Expr::InList(l1), Expr::InList(l2)) => {
1834                        return inlist_intersection(l1, &l2, true).map(Transformed::yes);
1835                    }
1836                    // Matched previously once
1837                    _ => unreachable!(),
1838                }
1839            }
1840
1841            // =======================================
1842            // unwrap_cast_in_comparison
1843            // =======================================
1844            //
1845            // For case:
1846            // try_cast/cast(expr as data_type) op literal
1847            Expr::BinaryExpr(BinaryExpr { left, op, right })
1848                if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1849                    info, &left, op, &right,
1850                ) && op.supports_propagation() =>
1851            {
1852                unwrap_cast_in_comparison_for_binary(info, *left, *right, op)?
1853            }
1854            // literal op try_cast/cast(expr as data_type)
1855            // -->
1856            // try_cast/cast(expr as data_type) op_swap literal
1857            Expr::BinaryExpr(BinaryExpr { left, op, right })
1858                if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1859                    info, &right, op, &left,
1860                ) && op.supports_propagation()
1861                    && op.swap().is_some() =>
1862            {
1863                unwrap_cast_in_comparison_for_binary(
1864                    info,
1865                    *right,
1866                    *left,
1867                    op.swap().unwrap(),
1868                )?
1869            }
1870            // For case:
1871            // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
1872            Expr::InList(InList {
1873                expr: mut left,
1874                list,
1875                negated,
1876            }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
1877                info, &left, &list,
1878            ) =>
1879            {
1880                let (Expr::TryCast(TryCast {
1881                    expr: left_expr, ..
1882                })
1883                | Expr::Cast(Cast {
1884                    expr: left_expr, ..
1885                })) = left.as_mut()
1886                else {
1887                    return internal_err!("Expect cast expr, but got {:?}", left)?;
1888                };
1889
1890                let expr_type = info.get_data_type(left_expr)?;
1891                let right_exprs = list
1892                    .into_iter()
1893                    .map(|right| {
1894                        match right {
1895                            Expr::Literal(right_lit_value, _) => {
1896                                // if the right_lit_value can be casted to the type of internal_left_expr
1897                                // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
1898                                let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else {
1899                                    internal_err!(
1900                                        "Can't cast the list expr {:?} to type {:?}",
1901                                        right_lit_value, &expr_type
1902                                    )?
1903                                };
1904                                Ok(lit(value))
1905                            }
1906                            other_expr => internal_err!(
1907                                "Only support literal expr to optimize, but the expr is {:?}",
1908                                &other_expr
1909                            ),
1910                        }
1911                    })
1912                    .collect::<Result<Vec<_>>>()?;
1913
1914                Transformed::yes(Expr::InList(InList {
1915                    expr: std::mem::take(left_expr),
1916                    list: right_exprs,
1917                    negated,
1918                }))
1919            }
1920
1921            // no additional rewrites possible
1922            expr => Transformed::no(expr),
1923        })
1924    }
1925}
1926
1927fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option<String>)> {
1928    match expr {
1929        Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)),
1930        Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)),
1931        Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)),
1932        _ => None,
1933    }
1934}
1935
1936fn to_string_scalar(data_type: DataType, value: Option<String>) -> Expr {
1937    match data_type {
1938        DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None),
1939        DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None),
1940        DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None),
1941        _ => unreachable!(),
1942    }
1943}
1944
1945fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
1946    let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect();
1947    iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
1948}
1949
1950// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
1951fn are_inlist_and_eq_and_match_neg(
1952    left: &Expr,
1953    right: &Expr,
1954    is_left_neg: bool,
1955    is_right_neg: bool,
1956) -> bool {
1957    match (left, right) {
1958        (Expr::InList(l), Expr::InList(r)) => {
1959            l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg
1960        }
1961        _ => false,
1962    }
1963}
1964
1965// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
1966fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool {
1967    let left = as_inlist(left);
1968    let right = as_inlist(right);
1969    if let (Some(lhs), Some(rhs)) = (left, right) {
1970        matches!(lhs.expr.as_ref(), Expr::Column(_))
1971            && matches!(rhs.expr.as_ref(), Expr::Column(_))
1972            && lhs.expr == rhs.expr
1973            && !lhs.negated
1974            && !rhs.negated
1975    } else {
1976        false
1977    }
1978}
1979
1980/// Try to convert an expression to an in-list expression
1981fn as_inlist(expr: &Expr) -> Option<Cow<InList>> {
1982    match expr {
1983        Expr::InList(inlist) => Some(Cow::Borrowed(inlist)),
1984        Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => {
1985            match (left.as_ref(), right.as_ref()) {
1986                (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList {
1987                    expr: left.clone(),
1988                    list: vec![*right.clone()],
1989                    negated: false,
1990                })),
1991                (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList {
1992                    expr: right.clone(),
1993                    list: vec![*left.clone()],
1994                    negated: false,
1995                })),
1996                _ => None,
1997            }
1998        }
1999        _ => None,
2000    }
2001}
2002
2003fn to_inlist(expr: Expr) -> Option<InList> {
2004    match expr {
2005        Expr::InList(inlist) => Some(inlist),
2006        Expr::BinaryExpr(BinaryExpr {
2007            left,
2008            op: Operator::Eq,
2009            right,
2010        }) => match (left.as_ref(), right.as_ref()) {
2011            (Expr::Column(_), Expr::Literal(_, _)) => Some(InList {
2012                expr: left,
2013                list: vec![*right],
2014                negated: false,
2015            }),
2016            (Expr::Literal(_, _), Expr::Column(_)) => Some(InList {
2017                expr: right,
2018                list: vec![*left],
2019                negated: false,
2020            }),
2021            _ => None,
2022        },
2023        _ => None,
2024    }
2025}
2026
2027/// Return the union of two inlist expressions
2028/// maintaining the order of the elements in the two lists
2029fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
2030    // extend the list in l1 with the elements in l2 that are not already in l1
2031    let l1_items: HashSet<_> = l1.list.iter().collect();
2032
2033    // keep all l2 items that do not also appear in l1
2034    let keep_l2: Vec<_> = l2
2035        .list
2036        .into_iter()
2037        .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) })
2038        .collect();
2039
2040    l1.list.extend(keep_l2);
2041    l1.negated = negated;
2042    Ok(Expr::InList(l1))
2043}
2044
2045/// Return the intersection of two inlist expressions
2046/// maintaining the order of the elements in the two lists
2047fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result<Expr> {
2048    let l2_items = l2.list.iter().collect::<HashSet<_>>();
2049
2050    // remove all items from l1 that are not in l2
2051    l1.list.retain(|e| l2_items.contains(e));
2052
2053    // e in () is always false
2054    // e not in () is always true
2055    if l1.list.is_empty() {
2056        return Ok(lit(negated));
2057    }
2058    Ok(Expr::InList(l1))
2059}
2060
2061/// Return the all items in l1 that are not in l2
2062/// maintaining the order of the elements in the two lists
2063fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
2064    let l2_items = l2.list.iter().collect::<HashSet<_>>();
2065
2066    // keep only items from l1 that are not in l2
2067    l1.list.retain(|e| !l2_items.contains(e));
2068
2069    if l1.list.is_empty() {
2070        return Ok(lit(false));
2071    }
2072    Ok(Expr::InList(l1))
2073}
2074
2075/// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL).
2076fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
2077    if !info.nullable(&expr)? {
2078        Ok(expr)
2079    } else {
2080        Ok(Expr::BinaryExpr(BinaryExpr {
2081            left: Box::new(expr),
2082            op: Operator::IsNotDistinctFrom,
2083            right: Box::new(lit(true)),
2084        }))
2085    }
2086}
2087
2088// A * 1 -> A
2089// A / 1 -> A
2090//
2091// Move this function body out of the large match branch avoid stack overflow
2092fn simplify_right_is_one_case<S: SimplifyInfo>(
2093    info: &S,
2094    left: Box<Expr>,
2095    op: &Operator,
2096    right: &Expr,
2097) -> Result<Transformed<Expr>> {
2098    // Check if resulting type would be different due to coercion
2099    let left_type = info.get_data_type(&left)?;
2100    let right_type = info.get_data_type(right)?;
2101    match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() {
2102        Ok(result_type) => {
2103            // Only cast if the types differ
2104            if left_type != result_type {
2105                Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type))))
2106            } else {
2107                Ok(Transformed::yes(*left))
2108            }
2109        }
2110        Err(_) => Ok(Transformed::yes(*left)),
2111    }
2112}
2113
2114// A * null -> null
2115// A / null -> null
2116//
2117// Move this function body out of the large match branch avoid stack overflow
2118fn simplify_right_is_null_case<S: SimplifyInfo>(
2119    info: &S,
2120    left: &Expr,
2121    op: &Operator,
2122    right: Box<Expr>,
2123) -> Result<Transformed<Expr>> {
2124    // Check if resulting type would be different due to coercion
2125    let left_type = info.get_data_type(left)?;
2126    let right_type = info.get_data_type(&right)?;
2127    match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() {
2128        Ok(result_type) => {
2129            // Only cast if the types differ
2130            if right_type != result_type {
2131                Ok(Transformed::yes(Expr::Cast(Cast::new(right, result_type))))
2132            } else {
2133                Ok(Transformed::yes(*right))
2134            }
2135        }
2136        Err(_) => Ok(Transformed::yes(*right)),
2137    }
2138}
2139
2140// null / A --> null
2141//
2142// Move this function body out of the large match branch avoid stack overflow
2143fn simplify_null_div_other_case<S: SimplifyInfo>(
2144    info: &S,
2145    left: Box<Expr>,
2146    right: &Expr,
2147) -> Result<Transformed<Expr>> {
2148    // Check if resulting type would be different due to coercion
2149    let left_type = info.get_data_type(&left)?;
2150    let right_type = info.get_data_type(right)?;
2151    match BinaryTypeCoercer::new(&left_type, &Operator::Divide, &right_type)
2152        .get_result_type()
2153    {
2154        Ok(result_type) => {
2155            // Only cast if the types differ
2156            if left_type != result_type {
2157                Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type))))
2158            } else {
2159                Ok(Transformed::yes(*left))
2160            }
2161        }
2162        Err(_) => Ok(Transformed::yes(*left)),
2163    }
2164}
2165
2166#[cfg(test)]
2167mod tests {
2168    use super::*;
2169    use crate::simplify_expressions::SimplifyContext;
2170    use crate::test::test_table_scan_with_name;
2171    use arrow::datatypes::FieldRef;
2172    use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
2173    use datafusion_expr::{
2174        expr::WindowFunction,
2175        function::{
2176            AccumulatorArgs, AggregateFunctionSimplification,
2177            WindowFunctionSimplification,
2178        },
2179        interval_arithmetic::Interval,
2180        *,
2181    };
2182    use datafusion_functions_window_common::field::WindowUDFFieldArgs;
2183    use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
2184    use std::hash::{DefaultHasher, Hash, Hasher};
2185    use std::{
2186        collections::HashMap,
2187        ops::{BitAnd, BitOr, BitXor},
2188        sync::Arc,
2189    };
2190
2191    // ------------------------------
2192    // --- ExprSimplifier tests -----
2193    // ------------------------------
2194    #[test]
2195    fn api_basic() {
2196        let props = ExecutionProps::new();
2197        let simplifier =
2198            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2199
2200        let expr = lit(1) + lit(2);
2201        let expected = lit(3);
2202        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2203    }
2204
2205    #[test]
2206    fn basic_coercion() {
2207        let schema = test_schema();
2208        let props = ExecutionProps::new();
2209        let simplifier = ExprSimplifier::new(
2210            SimplifyContext::new(&props).with_schema(Arc::clone(&schema)),
2211        );
2212
2213        // Note expr type is int32 (not int64)
2214        // (1i64 + 2i32) < i
2215        let expr = (lit(1i64) + lit(2i32)).lt(col("i"));
2216        // should fully simplify to 3 < i (though i has been coerced to i64)
2217        let expected = lit(3i64).lt(col("i"));
2218
2219        let expr = simplifier.coerce(expr, &schema).unwrap();
2220
2221        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2222    }
2223
2224    fn test_schema() -> DFSchemaRef {
2225        Schema::new(vec![
2226            Field::new("i", DataType::Int64, false),
2227            Field::new("b", DataType::Boolean, true),
2228        ])
2229        .to_dfschema_ref()
2230        .unwrap()
2231    }
2232
2233    #[test]
2234    fn simplify_and_constant_prop() {
2235        let props = ExecutionProps::new();
2236        let simplifier =
2237            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2238
2239        // should be able to simplify to false
2240        // (i * (1 - 2)) > 0
2241        let expr = (col("i") * (lit(1) - lit(1))).gt(lit(0));
2242        let expected = lit(false);
2243        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2244    }
2245
2246    #[test]
2247    fn simplify_and_constant_prop_with_case() {
2248        let props = ExecutionProps::new();
2249        let simplifier =
2250            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2251
2252        //   CASE
2253        //     WHEN i>5 AND false THEN i > 5
2254        //     WHEN i<5 AND true THEN i < 5
2255        //     ELSE false
2256        //   END
2257        //
2258        // Can be simplified to `i < 5`
2259        let expr = when(col("i").gt(lit(5)).and(lit(false)), col("i").gt(lit(5)))
2260            .when(col("i").lt(lit(5)).and(lit(true)), col("i").lt(lit(5)))
2261            .otherwise(lit(false))
2262            .unwrap();
2263        let expected = col("i").lt(lit(5));
2264        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2265    }
2266
2267    // ------------------------------
2268    // --- Simplifier tests -----
2269    // ------------------------------
2270
2271    #[test]
2272    fn test_simplify_canonicalize() {
2273        {
2274            let expr = lit(1).lt(col("c2")).and(col("c2").gt(lit(1)));
2275            let expected = col("c2").gt(lit(1));
2276            assert_eq!(simplify(expr), expected);
2277        }
2278        {
2279            let expr = col("c1").lt(col("c2")).and(col("c2").gt(col("c1")));
2280            let expected = col("c2").gt(col("c1"));
2281            assert_eq!(simplify(expr), expected);
2282        }
2283        {
2284            let expr = col("c1")
2285                .eq(lit(1))
2286                .and(lit(1).eq(col("c1")))
2287                .and(col("c1").eq(lit(3)));
2288            let expected = col("c1").eq(lit(1)).and(col("c1").eq(lit(3)));
2289            assert_eq!(simplify(expr), expected);
2290        }
2291        {
2292            let expr = col("c1")
2293                .eq(col("c2"))
2294                .and(col("c1").gt(lit(5)))
2295                .and(col("c2").eq(col("c1")));
2296            let expected = col("c2").eq(col("c1")).and(col("c1").gt(lit(5)));
2297            assert_eq!(simplify(expr), expected);
2298        }
2299        {
2300            let expr = col("c1")
2301                .eq(lit(1))
2302                .and(col("c2").gt(lit(3)).or(lit(3).lt(col("c2"))));
2303            let expected = col("c1").eq(lit(1)).and(col("c2").gt(lit(3)));
2304            assert_eq!(simplify(expr), expected);
2305        }
2306        {
2307            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2308            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2309            assert_eq!(simplify(expr), expected);
2310        }
2311        {
2312            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2313            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2314            assert_eq!(simplify(expr), expected);
2315        }
2316        {
2317            let expr = col("c1").gt(col("c2")).and(col("c1").gt(col("c2")));
2318            let expected = col("c2").lt(col("c1"));
2319            assert_eq!(simplify(expr), expected);
2320        }
2321    }
2322
2323    #[test]
2324    fn test_simplify_eq_not_self() {
2325        // `expr_a`: column `c2` is nullable, so `c2 = c2` simplifies to `c2 IS NOT NULL OR NULL`
2326        // This ensures the expression is only true when `c2` is not NULL, accounting for SQL's NULL semantics.
2327        let expr_a = col("c2").eq(col("c2"));
2328        let expected_a = col("c2").is_not_null().or(lit_bool_null());
2329
2330        // `expr_b`: column `c2_non_null` is explicitly non-nullable, so `c2_non_null = c2_non_null` is always true
2331        let expr_b = col("c2_non_null").eq(col("c2_non_null"));
2332        let expected_b = lit(true);
2333
2334        assert_eq!(simplify(expr_a), expected_a);
2335        assert_eq!(simplify(expr_b), expected_b);
2336    }
2337
2338    #[test]
2339    fn test_simplify_or_true() {
2340        let expr_a = col("c2").or(lit(true));
2341        let expr_b = lit(true).or(col("c2"));
2342        let expected = lit(true);
2343
2344        assert_eq!(simplify(expr_a), expected);
2345        assert_eq!(simplify(expr_b), expected);
2346    }
2347
2348    #[test]
2349    fn test_simplify_or_false() {
2350        let expr_a = lit(false).or(col("c2"));
2351        let expr_b = col("c2").or(lit(false));
2352        let expected = col("c2");
2353
2354        assert_eq!(simplify(expr_a), expected);
2355        assert_eq!(simplify(expr_b), expected);
2356    }
2357
2358    #[test]
2359    fn test_simplify_or_same() {
2360        let expr = col("c2").or(col("c2"));
2361        let expected = col("c2");
2362
2363        assert_eq!(simplify(expr), expected);
2364    }
2365
2366    #[test]
2367    fn test_simplify_or_not_self() {
2368        // A OR !A if A is not nullable --> true
2369        // !A OR A if A is not nullable --> true
2370        let expr_a = col("c2_non_null").or(col("c2_non_null").not());
2371        let expr_b = col("c2_non_null").not().or(col("c2_non_null"));
2372        let expected = lit(true);
2373
2374        assert_eq!(simplify(expr_a), expected);
2375        assert_eq!(simplify(expr_b), expected);
2376    }
2377
2378    #[test]
2379    fn test_simplify_and_false() {
2380        let expr_a = lit(false).and(col("c2"));
2381        let expr_b = col("c2").and(lit(false));
2382        let expected = lit(false);
2383
2384        assert_eq!(simplify(expr_a), expected);
2385        assert_eq!(simplify(expr_b), expected);
2386    }
2387
2388    #[test]
2389    fn test_simplify_and_same() {
2390        let expr = col("c2").and(col("c2"));
2391        let expected = col("c2");
2392
2393        assert_eq!(simplify(expr), expected);
2394    }
2395
2396    #[test]
2397    fn test_simplify_and_true() {
2398        let expr_a = lit(true).and(col("c2"));
2399        let expr_b = col("c2").and(lit(true));
2400        let expected = col("c2");
2401
2402        assert_eq!(simplify(expr_a), expected);
2403        assert_eq!(simplify(expr_b), expected);
2404    }
2405
2406    #[test]
2407    fn test_simplify_and_not_self() {
2408        // A AND !A if A is not nullable --> false
2409        // !A AND A if A is not nullable --> false
2410        let expr_a = col("c2_non_null").and(col("c2_non_null").not());
2411        let expr_b = col("c2_non_null").not().and(col("c2_non_null"));
2412        let expected = lit(false);
2413
2414        assert_eq!(simplify(expr_a), expected);
2415        assert_eq!(simplify(expr_b), expected);
2416    }
2417
2418    #[test]
2419    fn test_simplify_multiply_by_one() {
2420        let expr_a = col("c2") * lit(1);
2421        let expr_b = lit(1) * col("c2");
2422        let expected = col("c2");
2423
2424        assert_eq!(simplify(expr_a), expected);
2425        assert_eq!(simplify(expr_b), expected);
2426
2427        let expr = col("c2") * lit(ScalarValue::Decimal128(Some(10000000000), 38, 10));
2428        assert_eq!(simplify(expr), expected);
2429
2430        let expr = lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)) * col("c2");
2431        assert_eq!(simplify(expr), expected);
2432    }
2433
2434    #[test]
2435    fn test_simplify_multiply_by_null() {
2436        let null = lit(ScalarValue::Null);
2437        // A * null --> null
2438        {
2439            let expr = col("c2") * null.clone();
2440            assert_eq!(simplify(expr), null);
2441        }
2442        // null * A --> null
2443        {
2444            let expr = null.clone() * col("c2");
2445            assert_eq!(simplify(expr), null);
2446        }
2447    }
2448
2449    #[test]
2450    fn test_simplify_multiply_by_zero() {
2451        // cannot optimize A * null (null * A) if A is nullable
2452        {
2453            let expr_a = col("c2") * lit(0);
2454            let expr_b = lit(0) * col("c2");
2455
2456            assert_eq!(simplify(expr_a.clone()), expr_a);
2457            assert_eq!(simplify(expr_b.clone()), expr_b);
2458        }
2459        // 0 * A --> 0 if A is not nullable
2460        {
2461            let expr = lit(0) * col("c2_non_null");
2462            assert_eq!(simplify(expr), lit(0));
2463        }
2464        // A * 0 --> 0 if A is not nullable
2465        {
2466            let expr = col("c2_non_null") * lit(0);
2467            assert_eq!(simplify(expr), lit(0));
2468        }
2469        // A * Decimal128(0) --> 0 if A is not nullable
2470        {
2471            let expr = col("c2_non_null") * lit(ScalarValue::Decimal128(Some(0), 31, 10));
2472            assert_eq!(
2473                simplify(expr),
2474                lit(ScalarValue::Decimal128(Some(0), 31, 10))
2475            );
2476            let expr = binary_expr(
2477                lit(ScalarValue::Decimal128(Some(0), 31, 10)),
2478                Operator::Multiply,
2479                col("c2_non_null"),
2480            );
2481            assert_eq!(
2482                simplify(expr),
2483                lit(ScalarValue::Decimal128(Some(0), 31, 10))
2484            );
2485        }
2486    }
2487
2488    #[test]
2489    fn test_simplify_divide_by_one() {
2490        let expr = binary_expr(col("c2"), Operator::Divide, lit(1));
2491        let expected = col("c2");
2492        assert_eq!(simplify(expr), expected);
2493        let expr = col("c2") / lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2494        assert_eq!(simplify(expr), expected);
2495    }
2496
2497    #[test]
2498    fn test_simplify_divide_null() {
2499        // A / null --> null
2500        let null = lit(ScalarValue::Null);
2501        {
2502            let expr = col("c1") / null.clone();
2503            assert_eq!(simplify(expr), null);
2504        }
2505        // null / A --> null
2506        {
2507            let expr = null.clone() / col("c1");
2508            assert_eq!(simplify(expr), null);
2509        }
2510    }
2511
2512    #[test]
2513    fn test_simplify_divide_by_same() {
2514        let expr = col("c2") / col("c2");
2515        // if c2 is null, c2 / c2 = null, so can't simplify
2516        let expected = expr.clone();
2517
2518        assert_eq!(simplify(expr), expected);
2519    }
2520
2521    #[test]
2522    fn test_simplify_modulo_by_null() {
2523        let null = lit(ScalarValue::Null);
2524        // A % null --> null
2525        {
2526            let expr = col("c2") % null.clone();
2527            assert_eq!(simplify(expr), null);
2528        }
2529        // null % A --> null
2530        {
2531            let expr = null.clone() % col("c2");
2532            assert_eq!(simplify(expr), null);
2533        }
2534    }
2535
2536    #[test]
2537    fn test_simplify_modulo_by_one() {
2538        let expr = col("c2") % lit(1);
2539        // if c2 is null, c2 % 1 = null, so can't simplify
2540        let expected = expr.clone();
2541
2542        assert_eq!(simplify(expr), expected);
2543    }
2544
2545    #[test]
2546    fn test_simplify_divide_zero_by_zero() {
2547        // because divide by 0 maybe occur in short-circuit expression
2548        // so we should not simplify this, and throw error in runtime
2549        let expr = lit(0) / lit(0);
2550        let expected = expr.clone();
2551
2552        assert_eq!(simplify(expr), expected);
2553    }
2554
2555    #[test]
2556    fn test_simplify_divide_by_zero() {
2557        // because divide by 0 maybe occur in short-circuit expression
2558        // so we should not simplify this, and throw error in runtime
2559        let expr = col("c2_non_null") / lit(0);
2560        let expected = expr.clone();
2561
2562        assert_eq!(simplify(expr), expected);
2563    }
2564
2565    #[test]
2566    fn test_simplify_modulo_by_one_non_null() {
2567        let expr = col("c3_non_null") % lit(1);
2568        let expected = lit(0_i64);
2569        assert_eq!(simplify(expr), expected);
2570        let expr =
2571            col("c3_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2572        assert_eq!(simplify(expr), expected);
2573    }
2574
2575    #[test]
2576    fn test_simplify_bitwise_xor_by_null() {
2577        let null = lit(ScalarValue::Null);
2578        // A ^ null --> null
2579        {
2580            let expr = col("c2") ^ null.clone();
2581            assert_eq!(simplify(expr), null);
2582        }
2583        // null ^ A --> null
2584        {
2585            let expr = null.clone() ^ col("c2");
2586            assert_eq!(simplify(expr), null);
2587        }
2588    }
2589
2590    #[test]
2591    fn test_simplify_bitwise_shift_right_by_null() {
2592        let null = lit(ScalarValue::Null);
2593        // A >> null --> null
2594        {
2595            let expr = col("c2") >> null.clone();
2596            assert_eq!(simplify(expr), null);
2597        }
2598        // null >> A --> null
2599        {
2600            let expr = null.clone() >> col("c2");
2601            assert_eq!(simplify(expr), null);
2602        }
2603    }
2604
2605    #[test]
2606    fn test_simplify_bitwise_shift_left_by_null() {
2607        let null = lit(ScalarValue::Null);
2608        // A << null --> null
2609        {
2610            let expr = col("c2") << null.clone();
2611            assert_eq!(simplify(expr), null);
2612        }
2613        // null << A --> null
2614        {
2615            let expr = null.clone() << col("c2");
2616            assert_eq!(simplify(expr), null);
2617        }
2618    }
2619
2620    #[test]
2621    fn test_simplify_bitwise_and_by_zero() {
2622        // A & 0 --> 0
2623        {
2624            let expr = col("c2_non_null") & lit(0);
2625            assert_eq!(simplify(expr), lit(0));
2626        }
2627        // 0 & A --> 0
2628        {
2629            let expr = lit(0) & col("c2_non_null");
2630            assert_eq!(simplify(expr), lit(0));
2631        }
2632    }
2633
2634    #[test]
2635    fn test_simplify_bitwise_or_by_zero() {
2636        // A | 0 --> A
2637        {
2638            let expr = col("c2_non_null") | lit(0);
2639            assert_eq!(simplify(expr), col("c2_non_null"));
2640        }
2641        // 0 | A --> A
2642        {
2643            let expr = lit(0) | col("c2_non_null");
2644            assert_eq!(simplify(expr), col("c2_non_null"));
2645        }
2646    }
2647
2648    #[test]
2649    fn test_simplify_bitwise_xor_by_zero() {
2650        // A ^ 0 --> A
2651        {
2652            let expr = col("c2_non_null") ^ lit(0);
2653            assert_eq!(simplify(expr), col("c2_non_null"));
2654        }
2655        // 0 ^ A --> A
2656        {
2657            let expr = lit(0) ^ col("c2_non_null");
2658            assert_eq!(simplify(expr), col("c2_non_null"));
2659        }
2660    }
2661
2662    #[test]
2663    fn test_simplify_bitwise_bitwise_shift_right_by_zero() {
2664        // A >> 0 --> A
2665        {
2666            let expr = col("c2_non_null") >> lit(0);
2667            assert_eq!(simplify(expr), col("c2_non_null"));
2668        }
2669    }
2670
2671    #[test]
2672    fn test_simplify_bitwise_bitwise_shift_left_by_zero() {
2673        // A << 0 --> A
2674        {
2675            let expr = col("c2_non_null") << lit(0);
2676            assert_eq!(simplify(expr), col("c2_non_null"));
2677        }
2678    }
2679
2680    #[test]
2681    fn test_simplify_bitwise_and_by_null() {
2682        let null = lit(ScalarValue::Null);
2683        // A & null --> null
2684        {
2685            let expr = col("c2") & null.clone();
2686            assert_eq!(simplify(expr), null);
2687        }
2688        // null & A --> null
2689        {
2690            let expr = null.clone() & col("c2");
2691            assert_eq!(simplify(expr), null);
2692        }
2693    }
2694
2695    #[test]
2696    fn test_simplify_composed_bitwise_and() {
2697        // ((c2 > 5) & (c1 < 6)) & (c2 > 5) --> (c2 > 5) & (c1 < 6)
2698
2699        let expr = bitwise_and(
2700            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2701            col("c2").gt(lit(5)),
2702        );
2703        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2704
2705        assert_eq!(simplify(expr), expected);
2706
2707        // (c2 > 5) & ((c2 > 5) & (c1 < 6)) --> (c2 > 5) & (c1 < 6)
2708
2709        let expr = bitwise_and(
2710            col("c2").gt(lit(5)),
2711            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2712        );
2713        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2714        assert_eq!(simplify(expr), expected);
2715    }
2716
2717    #[test]
2718    fn test_simplify_composed_bitwise_or() {
2719        // ((c2 > 5) | (c1 < 6)) | (c2 > 5) --> (c2 > 5) | (c1 < 6)
2720
2721        let expr = bitwise_or(
2722            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2723            col("c2").gt(lit(5)),
2724        );
2725        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2726
2727        assert_eq!(simplify(expr), expected);
2728
2729        // (c2 > 5) | ((c2 > 5) | (c1 < 6)) --> (c2 > 5) | (c1 < 6)
2730
2731        let expr = bitwise_or(
2732            col("c2").gt(lit(5)),
2733            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2734        );
2735        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2736
2737        assert_eq!(simplify(expr), expected);
2738    }
2739
2740    #[test]
2741    fn test_simplify_composed_bitwise_xor() {
2742        // with an even number of the column "c2"
2743        // c2 ^ ((c2 ^ (c2 | c1)) ^ (c1 & c2)) --> (c2 | c1) ^ (c1 & c2)
2744
2745        let expr = bitwise_xor(
2746            col("c2"),
2747            bitwise_xor(
2748                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2749                bitwise_and(col("c1"), col("c2")),
2750            ),
2751        );
2752
2753        let expected = bitwise_xor(
2754            bitwise_or(col("c2"), col("c1")),
2755            bitwise_and(col("c1"), col("c2")),
2756        );
2757
2758        assert_eq!(simplify(expr), expected);
2759
2760        // with an odd number of the column "c2"
2761        // c2 ^ (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) --> c2 ^ ((c2 | c1) ^ (c1 & c2))
2762
2763        let expr = bitwise_xor(
2764            col("c2"),
2765            bitwise_xor(
2766                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2767                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2768            ),
2769        );
2770
2771        let expected = bitwise_xor(
2772            col("c2"),
2773            bitwise_xor(
2774                bitwise_or(col("c2"), col("c1")),
2775                bitwise_and(col("c1"), col("c2")),
2776            ),
2777        );
2778
2779        assert_eq!(simplify(expr), expected);
2780
2781        // with an even number of the column "c2"
2782        // ((c2 ^ (c2 | c1)) ^ (c1 & c2)) ^ c2 --> (c2 | c1) ^ (c1 & c2)
2783
2784        let expr = bitwise_xor(
2785            bitwise_xor(
2786                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2787                bitwise_and(col("c1"), col("c2")),
2788            ),
2789            col("c2"),
2790        );
2791
2792        let expected = bitwise_xor(
2793            bitwise_or(col("c2"), col("c1")),
2794            bitwise_and(col("c1"), col("c2")),
2795        );
2796
2797        assert_eq!(simplify(expr), expected);
2798
2799        // with an odd number of the column "c2"
2800        // (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) ^ c2 --> ((c2 | c1) ^ (c1 & c2)) ^ c2
2801
2802        let expr = bitwise_xor(
2803            bitwise_xor(
2804                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2805                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2806            ),
2807            col("c2"),
2808        );
2809
2810        let expected = bitwise_xor(
2811            bitwise_xor(
2812                bitwise_or(col("c2"), col("c1")),
2813                bitwise_and(col("c1"), col("c2")),
2814            ),
2815            col("c2"),
2816        );
2817
2818        assert_eq!(simplify(expr), expected);
2819    }
2820
2821    #[test]
2822    fn test_simplify_negated_bitwise_and() {
2823        // !c4 & c4 --> 0
2824        let expr = (-col("c4_non_null")) & col("c4_non_null");
2825        let expected = lit(0u32);
2826
2827        assert_eq!(simplify(expr), expected);
2828        // c4 & !c4 --> 0
2829        let expr = col("c4_non_null") & (-col("c4_non_null"));
2830        let expected = lit(0u32);
2831
2832        assert_eq!(simplify(expr), expected);
2833
2834        // !c3 & c3 --> 0
2835        let expr = (-col("c3_non_null")) & col("c3_non_null");
2836        let expected = lit(0i64);
2837
2838        assert_eq!(simplify(expr), expected);
2839        // c3 & !c3 --> 0
2840        let expr = col("c3_non_null") & (-col("c3_non_null"));
2841        let expected = lit(0i64);
2842
2843        assert_eq!(simplify(expr), expected);
2844    }
2845
2846    #[test]
2847    fn test_simplify_negated_bitwise_or() {
2848        // !c4 | c4 --> -1
2849        let expr = (-col("c4_non_null")) | col("c4_non_null");
2850        let expected = lit(-1i32);
2851
2852        assert_eq!(simplify(expr), expected);
2853
2854        // c4 | !c4 --> -1
2855        let expr = col("c4_non_null") | (-col("c4_non_null"));
2856        let expected = lit(-1i32);
2857
2858        assert_eq!(simplify(expr), expected);
2859
2860        // !c3 | c3 --> -1
2861        let expr = (-col("c3_non_null")) | col("c3_non_null");
2862        let expected = lit(-1i64);
2863
2864        assert_eq!(simplify(expr), expected);
2865
2866        // c3 | !c3 --> -1
2867        let expr = col("c3_non_null") | (-col("c3_non_null"));
2868        let expected = lit(-1i64);
2869
2870        assert_eq!(simplify(expr), expected);
2871    }
2872
2873    #[test]
2874    fn test_simplify_negated_bitwise_xor() {
2875        // !c4 ^ c4 --> -1
2876        let expr = (-col("c4_non_null")) ^ col("c4_non_null");
2877        let expected = lit(-1i32);
2878
2879        assert_eq!(simplify(expr), expected);
2880
2881        // c4 ^ !c4 --> -1
2882        let expr = col("c4_non_null") ^ (-col("c4_non_null"));
2883        let expected = lit(-1i32);
2884
2885        assert_eq!(simplify(expr), expected);
2886
2887        // !c3 ^ c3 --> -1
2888        let expr = (-col("c3_non_null")) ^ col("c3_non_null");
2889        let expected = lit(-1i64);
2890
2891        assert_eq!(simplify(expr), expected);
2892
2893        // c3 ^ !c3 --> -1
2894        let expr = col("c3_non_null") ^ (-col("c3_non_null"));
2895        let expected = lit(-1i64);
2896
2897        assert_eq!(simplify(expr), expected);
2898    }
2899
2900    #[test]
2901    fn test_simplify_bitwise_and_or() {
2902        // (c2 < 3) & ((c2 < 3) | c1) -> (c2 < 3)
2903        let expr = bitwise_and(
2904            col("c2_non_null").lt(lit(3)),
2905            bitwise_or(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2906        );
2907        let expected = col("c2_non_null").lt(lit(3));
2908
2909        assert_eq!(simplify(expr), expected);
2910    }
2911
2912    #[test]
2913    fn test_simplify_bitwise_or_and() {
2914        // (c2 < 3) | ((c2 < 3) & c1) -> (c2 < 3)
2915        let expr = bitwise_or(
2916            col("c2_non_null").lt(lit(3)),
2917            bitwise_and(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2918        );
2919        let expected = col("c2_non_null").lt(lit(3));
2920
2921        assert_eq!(simplify(expr), expected);
2922    }
2923
2924    #[test]
2925    fn test_simplify_simple_bitwise_and() {
2926        // (c2 > 5) & (c2 > 5) -> (c2 > 5)
2927        let expr = (col("c2").gt(lit(5))).bitand(col("c2").gt(lit(5)));
2928        let expected = col("c2").gt(lit(5));
2929
2930        assert_eq!(simplify(expr), expected);
2931    }
2932
2933    #[test]
2934    fn test_simplify_simple_bitwise_or() {
2935        // (c2 > 5) | (c2 > 5) -> (c2 > 5)
2936        let expr = (col("c2").gt(lit(5))).bitor(col("c2").gt(lit(5)));
2937        let expected = col("c2").gt(lit(5));
2938
2939        assert_eq!(simplify(expr), expected);
2940    }
2941
2942    #[test]
2943    fn test_simplify_simple_bitwise_xor() {
2944        // c4 ^ c4 -> 0
2945        let expr = (col("c4")).bitxor(col("c4"));
2946        let expected = lit(0u32);
2947
2948        assert_eq!(simplify(expr), expected);
2949
2950        // c3 ^ c3 -> 0
2951        let expr = col("c3").bitxor(col("c3"));
2952        let expected = lit(0i64);
2953
2954        assert_eq!(simplify(expr), expected);
2955    }
2956
2957    #[test]
2958    fn test_simplify_modulo_by_zero_non_null() {
2959        // because modulo by 0 maybe occur in short-circuit expression
2960        // so we should not simplify this, and throw error in runtime.
2961        let expr = col("c2_non_null") % lit(0);
2962        let expected = expr.clone();
2963
2964        assert_eq!(simplify(expr), expected);
2965    }
2966
2967    #[test]
2968    fn test_simplify_simple_and() {
2969        // (c2 > 5) AND (c2 > 5) -> (c2 > 5)
2970        let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5)));
2971        let expected = col("c2").gt(lit(5));
2972
2973        assert_eq!(simplify(expr), expected);
2974    }
2975
2976    #[test]
2977    fn test_simplify_composed_and() {
2978        // ((c2 > 5) AND (c1 < 6)) AND (c2 > 5)
2979        let expr = and(
2980            and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2981            col("c2").gt(lit(5)),
2982        );
2983        let expected = and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2984
2985        assert_eq!(simplify(expr), expected);
2986    }
2987
2988    #[test]
2989    fn test_simplify_negated_and() {
2990        // (c2 > 5) AND !(c2 > 5) --> (c2 > 5) AND (c2 <= 5)
2991        let expr = and(col("c2").gt(lit(5)), Expr::not(col("c2").gt(lit(5))));
2992        let expected = col("c2").gt(lit(5)).and(col("c2").lt_eq(lit(5)));
2993
2994        assert_eq!(simplify(expr), expected);
2995    }
2996
2997    #[test]
2998    fn test_simplify_or_and() {
2999        let l = col("c2").gt(lit(5));
3000        let r = and(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
3001
3002        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5))
3003        let expr = or(l.clone(), r.clone());
3004
3005        let expected = l.clone();
3006        assert_eq!(simplify(expr), expected);
3007
3008        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5)
3009        let expr = or(r, l);
3010        assert_eq!(simplify(expr), expected);
3011    }
3012
3013    #[test]
3014    fn test_simplify_or_and_non_null() {
3015        let l = col("c2_non_null").gt(lit(5));
3016        let r = and(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
3017
3018        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) --> c2 > 5
3019        let expr = or(l.clone(), r.clone());
3020
3021        // This is only true if `c1 < 6` is not nullable / can not be null.
3022        let expected = col("c2_non_null").gt(lit(5));
3023
3024        assert_eq!(simplify(expr), expected);
3025
3026        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) --> c2 > 5
3027        let expr = or(l, r);
3028
3029        assert_eq!(simplify(expr), expected);
3030    }
3031
3032    #[test]
3033    fn test_simplify_and_or() {
3034        let l = col("c2").gt(lit(5));
3035        let r = or(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
3036
3037        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
3038        let expr = and(l.clone(), r.clone());
3039
3040        let expected = l.clone();
3041        assert_eq!(simplify(expr), expected);
3042
3043        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
3044        let expr = and(r, l);
3045        assert_eq!(simplify(expr), expected);
3046    }
3047
3048    #[test]
3049    fn test_simplify_and_or_non_null() {
3050        let l = col("c2_non_null").gt(lit(5));
3051        let r = or(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
3052
3053        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
3054        let expr = and(l.clone(), r.clone());
3055
3056        // This is only true if `c1 < 6` is not nullable / can not be null.
3057        let expected = col("c2_non_null").gt(lit(5));
3058
3059        assert_eq!(simplify(expr), expected);
3060
3061        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
3062        let expr = and(l, r);
3063
3064        assert_eq!(simplify(expr), expected);
3065    }
3066
3067    #[test]
3068    fn test_simplify_by_de_morgan_laws() {
3069        // Laws with logical operations
3070        // !(c3 AND c4) --> !c3 OR !c4
3071        let expr = and(col("c3"), col("c4")).not();
3072        let expected = or(col("c3").not(), col("c4").not());
3073        assert_eq!(simplify(expr), expected);
3074        // !(c3 OR c4) --> !c3 AND !c4
3075        let expr = or(col("c3"), col("c4")).not();
3076        let expected = and(col("c3").not(), col("c4").not());
3077        assert_eq!(simplify(expr), expected);
3078        // !(!c3) --> c3
3079        let expr = col("c3").not().not();
3080        let expected = col("c3");
3081        assert_eq!(simplify(expr), expected);
3082
3083        // Laws with bitwise operations
3084        // !(c3 & c4) --> !c3 | !c4
3085        let expr = -bitwise_and(col("c3"), col("c4"));
3086        let expected = bitwise_or(-col("c3"), -col("c4"));
3087        assert_eq!(simplify(expr), expected);
3088        // !(c3 | c4) --> !c3 & !c4
3089        let expr = -bitwise_or(col("c3"), col("c4"));
3090        let expected = bitwise_and(-col("c3"), -col("c4"));
3091        assert_eq!(simplify(expr), expected);
3092        // !(!c3) --> c3
3093        let expr = -(-col("c3"));
3094        let expected = col("c3");
3095        assert_eq!(simplify(expr), expected);
3096    }
3097
3098    #[test]
3099    fn test_simplify_null_and_false() {
3100        let expr = and(lit_bool_null(), lit(false));
3101        let expr_eq = lit(false);
3102
3103        assert_eq!(simplify(expr), expr_eq);
3104    }
3105
3106    #[test]
3107    fn test_simplify_divide_null_by_null() {
3108        let null = lit(ScalarValue::Int32(None));
3109        let expr_plus = null.clone() / null.clone();
3110        let expr_eq = null;
3111
3112        assert_eq!(simplify(expr_plus), expr_eq);
3113    }
3114
3115    #[test]
3116    fn test_simplify_simplify_arithmetic_expr() {
3117        let expr_plus = lit(1) + lit(1);
3118
3119        assert_eq!(simplify(expr_plus), lit(2));
3120    }
3121
3122    #[test]
3123    fn test_simplify_simplify_eq_expr() {
3124        let expr_eq = binary_expr(lit(1), Operator::Eq, lit(1));
3125
3126        assert_eq!(simplify(expr_eq), lit(true));
3127    }
3128
3129    #[test]
3130    fn test_simplify_regex() {
3131        // malformed regex
3132        assert_contains!(
3133            try_simplify(regex_match(col("c1"), lit("foo{")))
3134                .unwrap_err()
3135                .to_string(),
3136            "regex parse error"
3137        );
3138
3139        // unsupported cases
3140        assert_no_change(regex_match(col("c1"), lit("foo.*")));
3141        assert_no_change(regex_match(col("c1"), lit("(foo)")));
3142        assert_no_change(regex_match(col("c1"), lit("%")));
3143        assert_no_change(regex_match(col("c1"), lit("_")));
3144        assert_no_change(regex_match(col("c1"), lit("f%o")));
3145        assert_no_change(regex_match(col("c1"), lit("^f%o")));
3146        assert_no_change(regex_match(col("c1"), lit("f_o")));
3147
3148        // empty cases
3149        assert_change(
3150            regex_match(col("c1"), lit("")),
3151            if_not_null(col("c1"), true),
3152        );
3153        assert_change(
3154            regex_not_match(col("c1"), lit("")),
3155            if_not_null(col("c1"), false),
3156        );
3157        assert_change(
3158            regex_imatch(col("c1"), lit("")),
3159            if_not_null(col("c1"), true),
3160        );
3161        assert_change(
3162            regex_not_imatch(col("c1"), lit("")),
3163            if_not_null(col("c1"), false),
3164        );
3165
3166        // single character
3167        assert_change(regex_match(col("c1"), lit("x")), col("c1").like(lit("%x%")));
3168
3169        // single word
3170        assert_change(
3171            regex_match(col("c1"), lit("foo")),
3172            col("c1").like(lit("%foo%")),
3173        );
3174
3175        // regular expressions that match an exact literal
3176        assert_change(regex_match(col("c1"), lit("^$")), col("c1").eq(lit("")));
3177        assert_change(
3178            regex_not_match(col("c1"), lit("^$")),
3179            col("c1").not_eq(lit("")),
3180        );
3181        assert_change(
3182            regex_match(col("c1"), lit("^foo$")),
3183            col("c1").eq(lit("foo")),
3184        );
3185        assert_change(
3186            regex_not_match(col("c1"), lit("^foo$")),
3187            col("c1").not_eq(lit("foo")),
3188        );
3189
3190        // regular expressions that match exact captured literals
3191        assert_change(
3192            regex_match(col("c1"), lit("^(foo|bar)$")),
3193            col("c1").eq(lit("foo")).or(col("c1").eq(lit("bar"))),
3194        );
3195        assert_change(
3196            regex_not_match(col("c1"), lit("^(foo|bar)$")),
3197            col("c1")
3198                .not_eq(lit("foo"))
3199                .and(col("c1").not_eq(lit("bar"))),
3200        );
3201        assert_change(
3202            regex_match(col("c1"), lit("^(foo)$")),
3203            col("c1").eq(lit("foo")),
3204        );
3205        assert_change(
3206            regex_match(col("c1"), lit("^(foo|bar|baz)$")),
3207            ((col("c1").eq(lit("foo"))).or(col("c1").eq(lit("bar"))))
3208                .or(col("c1").eq(lit("baz"))),
3209        );
3210        assert_change(
3211            regex_match(col("c1"), lit("^(foo|bar|baz|qux)$")),
3212            col("c1")
3213                .in_list(vec![lit("foo"), lit("bar"), lit("baz"), lit("qux")], false),
3214        );
3215        assert_change(
3216            regex_match(col("c1"), lit("^(fo_o)$")),
3217            col("c1").eq(lit("fo_o")),
3218        );
3219        assert_change(
3220            regex_match(col("c1"), lit("^(fo_o)$")),
3221            col("c1").eq(lit("fo_o")),
3222        );
3223        assert_change(
3224            regex_match(col("c1"), lit("^(fo_o|ba_r)$")),
3225            col("c1").eq(lit("fo_o")).or(col("c1").eq(lit("ba_r"))),
3226        );
3227        assert_change(
3228            regex_not_match(col("c1"), lit("^(fo_o|ba_r)$")),
3229            col("c1")
3230                .not_eq(lit("fo_o"))
3231                .and(col("c1").not_eq(lit("ba_r"))),
3232        );
3233        assert_change(
3234            regex_match(col("c1"), lit("^(fo_o|ba_r|ba_z)$")),
3235            ((col("c1").eq(lit("fo_o"))).or(col("c1").eq(lit("ba_r"))))
3236                .or(col("c1").eq(lit("ba_z"))),
3237        );
3238        assert_change(
3239            regex_match(col("c1"), lit("^(fo_o|ba_r|baz|qu_x)$")),
3240            col("c1").in_list(
3241                vec![lit("fo_o"), lit("ba_r"), lit("baz"), lit("qu_x")],
3242                false,
3243            ),
3244        );
3245
3246        // regular expressions that mismatch captured literals
3247        assert_no_change(regex_match(col("c1"), lit("(foo|bar)")));
3248        assert_no_change(regex_match(col("c1"), lit("(foo|bar)*")));
3249        assert_no_change(regex_match(col("c1"), lit("(fo_o|b_ar)")));
3250        assert_no_change(regex_match(col("c1"), lit("(foo|ba_r)*")));
3251        assert_no_change(regex_match(col("c1"), lit("(fo_o|ba_r)*")));
3252        assert_no_change(regex_match(col("c1"), lit("^(foo|bar)*")));
3253        assert_no_change(regex_match(col("c1"), lit("^(foo)(bar)$")));
3254        assert_no_change(regex_match(col("c1"), lit("^")));
3255        assert_no_change(regex_match(col("c1"), lit("$")));
3256        assert_no_change(regex_match(col("c1"), lit("$^")));
3257        assert_no_change(regex_match(col("c1"), lit("$foo^")));
3258
3259        // regular expressions that match a partial literal
3260        assert_change(
3261            regex_match(col("c1"), lit("^foo")),
3262            col("c1").like(lit("foo%")),
3263        );
3264        assert_change(
3265            regex_match(col("c1"), lit("foo$")),
3266            col("c1").like(lit("%foo")),
3267        );
3268        assert_change(
3269            regex_match(col("c1"), lit("^foo|bar$")),
3270            col("c1").like(lit("foo%")).or(col("c1").like(lit("%bar"))),
3271        );
3272
3273        // OR-chain
3274        assert_change(
3275            regex_match(col("c1"), lit("foo|bar|baz")),
3276            col("c1")
3277                .like(lit("%foo%"))
3278                .or(col("c1").like(lit("%bar%")))
3279                .or(col("c1").like(lit("%baz%"))),
3280        );
3281        assert_change(
3282            regex_match(col("c1"), lit("foo|x|baz")),
3283            col("c1")
3284                .like(lit("%foo%"))
3285                .or(col("c1").like(lit("%x%")))
3286                .or(col("c1").like(lit("%baz%"))),
3287        );
3288        assert_change(
3289            regex_not_match(col("c1"), lit("foo|bar|baz")),
3290            col("c1")
3291                .not_like(lit("%foo%"))
3292                .and(col("c1").not_like(lit("%bar%")))
3293                .and(col("c1").not_like(lit("%baz%"))),
3294        );
3295        // both anchored expressions (translated to equality) and unanchored
3296        assert_change(
3297            regex_match(col("c1"), lit("foo|^x$|baz")),
3298            col("c1")
3299                .like(lit("%foo%"))
3300                .or(col("c1").eq(lit("x")))
3301                .or(col("c1").like(lit("%baz%"))),
3302        );
3303        assert_change(
3304            regex_not_match(col("c1"), lit("foo|^bar$|baz")),
3305            col("c1")
3306                .not_like(lit("%foo%"))
3307                .and(col("c1").not_eq(lit("bar")))
3308                .and(col("c1").not_like(lit("%baz%"))),
3309        );
3310        // Too many patterns (MAX_REGEX_ALTERNATIONS_EXPANSION)
3311        assert_no_change(regex_match(col("c1"), lit("foo|bar|baz|blarg|bozo|etc")));
3312    }
3313
3314    #[track_caller]
3315    fn assert_no_change(expr: Expr) {
3316        let optimized = simplify(expr.clone());
3317        assert_eq!(expr, optimized);
3318    }
3319
3320    #[track_caller]
3321    fn assert_change(expr: Expr, expected: Expr) {
3322        let optimized = simplify(expr);
3323        assert_eq!(optimized, expected);
3324    }
3325
3326    fn regex_match(left: Expr, right: Expr) -> Expr {
3327        Expr::BinaryExpr(BinaryExpr {
3328            left: Box::new(left),
3329            op: Operator::RegexMatch,
3330            right: Box::new(right),
3331        })
3332    }
3333
3334    fn regex_not_match(left: Expr, right: Expr) -> Expr {
3335        Expr::BinaryExpr(BinaryExpr {
3336            left: Box::new(left),
3337            op: Operator::RegexNotMatch,
3338            right: Box::new(right),
3339        })
3340    }
3341
3342    fn regex_imatch(left: Expr, right: Expr) -> Expr {
3343        Expr::BinaryExpr(BinaryExpr {
3344            left: Box::new(left),
3345            op: Operator::RegexIMatch,
3346            right: Box::new(right),
3347        })
3348    }
3349
3350    fn regex_not_imatch(left: Expr, right: Expr) -> Expr {
3351        Expr::BinaryExpr(BinaryExpr {
3352            left: Box::new(left),
3353            op: Operator::RegexNotIMatch,
3354            right: Box::new(right),
3355        })
3356    }
3357
3358    // ------------------------------
3359    // ----- Simplifier tests -------
3360    // ------------------------------
3361
3362    fn try_simplify(expr: Expr) -> Result<Expr> {
3363        let schema = expr_test_schema();
3364        let execution_props = ExecutionProps::new();
3365        let simplifier = ExprSimplifier::new(
3366            SimplifyContext::new(&execution_props).with_schema(schema),
3367        );
3368        simplifier.simplify(expr)
3369    }
3370
3371    fn coerce(expr: Expr) -> Expr {
3372        let schema = expr_test_schema();
3373        let execution_props = ExecutionProps::new();
3374        let simplifier = ExprSimplifier::new(
3375            SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)),
3376        );
3377        simplifier.coerce(expr, schema.as_ref()).unwrap()
3378    }
3379
3380    fn simplify(expr: Expr) -> Expr {
3381        try_simplify(expr).unwrap()
3382    }
3383
3384    fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
3385        let schema = expr_test_schema();
3386        let execution_props = ExecutionProps::new();
3387        let simplifier = ExprSimplifier::new(
3388            SimplifyContext::new(&execution_props).with_schema(schema),
3389        );
3390        let (expr, count) = simplifier.simplify_with_cycle_count_transformed(expr)?;
3391        Ok((expr.data, count))
3392    }
3393
3394    fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
3395        try_simplify_with_cycle_count(expr).unwrap()
3396    }
3397
3398    fn simplify_with_guarantee(
3399        expr: Expr,
3400        guarantees: Vec<(Expr, NullableInterval)>,
3401    ) -> Expr {
3402        let schema = expr_test_schema();
3403        let execution_props = ExecutionProps::new();
3404        let simplifier = ExprSimplifier::new(
3405            SimplifyContext::new(&execution_props).with_schema(schema),
3406        )
3407        .with_guarantees(guarantees);
3408        simplifier.simplify(expr).unwrap()
3409    }
3410
3411    fn expr_test_schema() -> DFSchemaRef {
3412        Arc::new(
3413            DFSchema::from_unqualified_fields(
3414                vec![
3415                    Field::new("c1", DataType::Utf8, true),
3416                    Field::new("c2", DataType::Boolean, true),
3417                    Field::new("c3", DataType::Int64, true),
3418                    Field::new("c4", DataType::UInt32, true),
3419                    Field::new("c1_non_null", DataType::Utf8, false),
3420                    Field::new("c2_non_null", DataType::Boolean, false),
3421                    Field::new("c3_non_null", DataType::Int64, false),
3422                    Field::new("c4_non_null", DataType::UInt32, false),
3423                    Field::new("c5", DataType::FixedSizeBinary(3), true),
3424                ]
3425                .into(),
3426                HashMap::new(),
3427            )
3428            .unwrap(),
3429        )
3430    }
3431
3432    #[test]
3433    fn simplify_expr_null_comparison() {
3434        // x = null is always null
3435        assert_eq!(
3436            simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))),
3437            lit(ScalarValue::Boolean(None)),
3438        );
3439
3440        // null != null is always null
3441        assert_eq!(
3442            simplify(
3443                lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))
3444            ),
3445            lit(ScalarValue::Boolean(None)),
3446        );
3447
3448        // x != null is always null
3449        assert_eq!(
3450            simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))),
3451            lit(ScalarValue::Boolean(None)),
3452        );
3453
3454        // null = x is always null
3455        assert_eq!(
3456            simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))),
3457            lit(ScalarValue::Boolean(None)),
3458        );
3459    }
3460
3461    #[test]
3462    fn simplify_expr_is_not_null() {
3463        assert_eq!(
3464            simplify(Expr::IsNotNull(Box::new(col("c1")))),
3465            Expr::IsNotNull(Box::new(col("c1")))
3466        );
3467
3468        // 'c1_non_null IS NOT NULL' is always true
3469        assert_eq!(
3470            simplify(Expr::IsNotNull(Box::new(col("c1_non_null")))),
3471            lit(true)
3472        );
3473    }
3474
3475    #[test]
3476    fn simplify_expr_is_null() {
3477        assert_eq!(
3478            simplify(Expr::IsNull(Box::new(col("c1")))),
3479            Expr::IsNull(Box::new(col("c1")))
3480        );
3481
3482        // 'c1_non_null IS NULL' is always false
3483        assert_eq!(
3484            simplify(Expr::IsNull(Box::new(col("c1_non_null")))),
3485            lit(false)
3486        );
3487    }
3488
3489    #[test]
3490    fn simplify_expr_is_unknown() {
3491        assert_eq!(simplify(col("c2").is_unknown()), col("c2").is_unknown(),);
3492
3493        // 'c2_non_null is unknown' is always false
3494        assert_eq!(simplify(col("c2_non_null").is_unknown()), lit(false));
3495    }
3496
3497    #[test]
3498    fn simplify_expr_is_not_known() {
3499        assert_eq!(
3500            simplify(col("c2").is_not_unknown()),
3501            col("c2").is_not_unknown()
3502        );
3503
3504        // 'c2_non_null is not unknown' is always true
3505        assert_eq!(simplify(col("c2_non_null").is_not_unknown()), lit(true));
3506    }
3507
3508    #[test]
3509    fn simplify_expr_eq() {
3510        let schema = expr_test_schema();
3511        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3512
3513        // true = true -> true
3514        assert_eq!(simplify(lit(true).eq(lit(true))), lit(true));
3515
3516        // true = false -> false
3517        assert_eq!(simplify(lit(true).eq(lit(false))), lit(false),);
3518
3519        // c2 = true -> c2
3520        assert_eq!(simplify(col("c2").eq(lit(true))), col("c2"));
3521
3522        // c2 = false => !c2
3523        assert_eq!(simplify(col("c2").eq(lit(false))), col("c2").not(),);
3524    }
3525
3526    #[test]
3527    fn simplify_expr_eq_skip_nonboolean_type() {
3528        let schema = expr_test_schema();
3529
3530        // When one of the operand is not of boolean type, folding the
3531        // other boolean constant will change return type of
3532        // expression to non-boolean.
3533        //
3534        // Make sure c1 column to be used in tests is not boolean type
3535        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3536
3537        // don't fold c1 = foo
3538        assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),);
3539    }
3540
3541    #[test]
3542    fn simplify_expr_not_eq() {
3543        let schema = expr_test_schema();
3544
3545        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3546
3547        // c2 != true -> !c2
3548        assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),);
3549
3550        // c2 != false -> c2
3551        assert_eq!(simplify(col("c2").not_eq(lit(false))), col("c2"),);
3552
3553        // test constant
3554        assert_eq!(simplify(lit(true).not_eq(lit(true))), lit(false),);
3555
3556        assert_eq!(simplify(lit(true).not_eq(lit(false))), lit(true),);
3557    }
3558
3559    #[test]
3560    fn simplify_expr_not_eq_skip_nonboolean_type() {
3561        let schema = expr_test_schema();
3562
3563        // when one of the operand is not of boolean type, folding the
3564        // other boolean constant will change return type of
3565        // expression to non-boolean.
3566        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3567
3568        assert_eq!(
3569            simplify(col("c1").not_eq(lit("foo"))),
3570            col("c1").not_eq(lit("foo")),
3571        );
3572    }
3573
    #[test]
    fn simplify_expr_case_when_then_else() {
        // Exercises simplification of searched CASE expressions (no base
        // expression) whose THEN/ELSE branches are boolean, asserting the exact
        // expression tree the simplifier returns for each input.

        // CASE WHEN c2_non_null != false THEN "ok" == "not_ok" ELSE c2_non_null == true
        // -->
        // CASE WHEN c2_non_null THEN false ELSE c2_non_null
        // -->
        // false
        assert_eq!(
            simplify(Expr::Case(Case::new(
                None,
                vec![(
                    Box::new(col("c2_non_null").not_eq(lit(false))),
                    Box::new(lit("ok").eq(lit("not_ok"))),
                )],
                Some(Box::new(col("c2_non_null").eq(lit(true)))),
            ))),
            lit(false) // #1716
        );

        // CASE WHEN c2_non_null != false THEN "ok" == "ok" ELSE c2_non_null == true
        // -->
        // CASE WHEN c2_non_null THEN true ELSE c2_non_null
        // -->
        // c2_non_null
        //
        // Need to call simplify 2x due to
        // https://github.com/apache/datafusion/issues/1160
        assert_eq!(
            simplify(simplify(Expr::Case(Case::new(
                None,
                vec![(
                    Box::new(col("c2_non_null").not_eq(lit(false))),
                    Box::new(lit("ok").eq(lit("ok"))),
                )],
                Some(Box::new(col("c2_non_null").eq(lit(true)))),
            )))),
            col("c2_non_null")
        );

        // CASE WHEN ISNULL(c2) THEN true ELSE c2
        // -->
        // ISNULL(c2) OR (c2 IS NOT NULL AND c2)
        //
        // Need to call simplify 2x due to
        // https://github.com/apache/datafusion/issues/1160
        assert_eq!(
            simplify(simplify(Expr::Case(Case::new(
                None,
                vec![(Box::new(col("c2").is_null()), Box::new(lit(true)),)],
                Some(Box::new(col("c2"))),
            )))),
            col("c2")
                .is_null()
                .or(col("c2").is_not_null().and(col("c2")))
        );

        // CASE WHEN c1 then true WHEN c2 then false ELSE true
        // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE)
        // --> c1 OR (NOT(c1) AND NOT(c2))
        //
        // The assertion pins the form above; it is not reduced further to
        // `c1 OR NOT(c2)`.
        //
        // Need to call simplify 2x due to
        // https://github.com/apache/datafusion/issues/1160
        assert_eq!(
            simplify(simplify(Expr::Case(Case::new(
                None,
                vec![
                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
                ],
                Some(Box::new(lit(true))),
            )))),
            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
        );

        // NOTE(review): this case was presumably meant to cover
        //   CASE WHEN c1 then true WHEN c2 then true ELSE false
        //   --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE)
        //   --> c1 OR (NOT(c1) AND c2)
        //   --> c1 OR c2
        // but the expression and expected value below are identical to the
        // previous assertion (WHEN c2 THEN false ELSE true), so the
        // THEN true / ELSE false variant is not actually exercised here.
        // TODO: confirm the intended input and expected output.
        //
        // Need to call simplify 2x due to
        // https://github.com/apache/datafusion/issues/1160
        assert_eq!(
            simplify(simplify(Expr::Case(Case::new(
                None,
                vec![
                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
                ],
                Some(Box::new(lit(true))),
            )))),
            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
        );

        // CASE WHEN c3 > 0 THEN true END (no ELSE branch)
        // -->
        // (c3 > 0 IS NOT DISTINCT FROM true)
        //   OR ((c3 > 0 IS DISTINCT FROM true) AND NULL)
        assert_eq!(
            simplify(simplify(Expr::Case(Case::new(
                None,
                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
                None,
            )))),
            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
                col("c3").gt(lit(0_i64)),
                lit(true)
            )
            .and(lit_bool_null()))
        );

        // CASE WHEN c3 > 0 THEN true ELSE false END
        // -->
        // c3 > 0 IS NOT DISTINCT FROM true
        assert_eq!(
            simplify(simplify(Expr::Case(Case::new(
                None,
                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
                Some(Box::new(lit(false))),
            )))),
            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
        );
    }
3692
3693    fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3694        Expr::BinaryExpr(BinaryExpr {
3695            left: Box::new(left.into()),
3696            op: Operator::IsDistinctFrom,
3697            right: Box::new(right.into()),
3698        })
3699    }
3700
3701    fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3702        Expr::BinaryExpr(BinaryExpr {
3703            left: Box::new(left.into()),
3704            op: Operator::IsNotDistinctFrom,
3705            right: Box::new(right.into()),
3706        })
3707    }
3708
3709    #[test]
3710    fn simplify_expr_bool_or() {
3711        // col || true is always true
3712        assert_eq!(simplify(col("c2").or(lit(true))), lit(true),);
3713
3714        // col || false is always col
3715        assert_eq!(simplify(col("c2").or(lit(false))), col("c2"),);
3716
3717        // true || null is always true
3718        assert_eq!(simplify(lit(true).or(lit_bool_null())), lit(true),);
3719
3720        // null || true is always true
3721        assert_eq!(simplify(lit_bool_null().or(lit(true))), lit(true),);
3722
3723        // false || null is always null
3724        assert_eq!(simplify(lit(false).or(lit_bool_null())), lit_bool_null(),);
3725
3726        // null || false is always null
3727        assert_eq!(simplify(lit_bool_null().or(lit(false))), lit_bool_null(),);
3728
3729        // ( c1 BETWEEN Int32(0) AND Int32(10) ) OR Boolean(NULL)
3730        // it can be either NULL or  TRUE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
3731        // and should not be rewritten
3732        let expr = col("c1").between(lit(0), lit(10));
3733        let expr = expr.or(lit_bool_null());
3734        let result = simplify(expr);
3735
3736        let expected_expr = or(
3737            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
3738            lit_bool_null(),
3739        );
3740        assert_eq!(expected_expr, result);
3741    }
3742
    #[test]
    fn simplify_inlist() {
        // Exercises simplification of `[NOT] IN (list)` expressions: empty and
        // short lists, lists containing scalar subqueries, and AND/OR
        // combinations of several IN lists over the same column.

        // an empty list: IN () --> false, NOT IN () --> true
        assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false));
        assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true));

        // null in (...)  --> null
        assert_eq!(
            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], false)),
            lit_bool_null()
        );

        // null not in (...)  --> null
        assert_eq!(
            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], true)),
            lit_bool_null()
        );

        // a single-element list becomes an (in)equality
        assert_eq!(
            simplify(in_list(col("c1"), vec![lit(1)], false)),
            col("c1").eq(lit(1))
        );
        assert_eq!(
            simplify(in_list(col("c1"), vec![lit(1)], true)),
            col("c1").not_eq(lit(1))
        );

        // more complex expressions can be simplified if list contains
        // one element only
        assert_eq!(
            simplify(in_list(col("c1") * lit(10), vec![lit(2)], false)),
            (col("c1") * lit(10)).eq(lit(2))
        );

        // a two-element list becomes an OR of equalities (AND of inequalities
        // when negated)
        assert_eq!(
            simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)),
            col("c1").eq(lit(1)).or(col("c1").eq(lit(2)))
        );
        assert_eq!(
            simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)),
            col("c1").not_eq(lit(1)).and(col("c1").not_eq(lit(2)))
        );

        // a list holding a single scalar subquery becomes [NOT] IN <subquery>
        let subquery = Arc::new(test_table_scan_with_name("test").unwrap());
        assert_eq!(
            simplify(in_list(
                col("c1"),
                vec![scalar_subquery(Arc::clone(&subquery))],
                false
            )),
            in_subquery(col("c1"), Arc::clone(&subquery))
        );
        assert_eq!(
            simplify(in_list(
                col("c1"),
                vec![scalar_subquery(Arc::clone(&subquery))],
                true
            )),
            not_in_subquery(col("c1"), subquery)
        );

        let subquery1 =
            scalar_subquery(Arc::new(test_table_scan_with_name("test1").unwrap()));
        let subquery2 =
            scalar_subquery(Arc::new(test_table_scan_with_name("test2").unwrap()));

        // c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery1> AND c1 != <subquery2>
        assert_eq!(
            simplify(in_list(
                col("c1"),
                vec![subquery1.clone(), subquery2.clone()],
                true
            )),
            col("c1")
                .not_eq(subquery1.clone())
                .and(col("c1").not_eq(subquery2.clone()))
        );

        // c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery1> OR c1 == <subquery2>
        assert_eq!(
            simplify(in_list(
                col("c1"),
                vec![subquery1.clone(), subquery2.clone()],
                false
            )),
            col("c1").eq(subquery1).or(col("c1").eq(subquery2))
        );

        // 1. c1 IN (1,2,3,4) AND c1 IN (5,6,7,8) -> false
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false),
        );
        assert_eq!(simplify(expr), lit(false));

        // 2. c1 IN (1,2,3,4) AND c1 IN (4,5,6,7) -> c1 = 4
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], false),
        );
        assert_eq!(simplify(expr), col("c1").eq(lit(4)));

        // 3. c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) -> true
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
        );
        assert_eq!(simplify(expr), lit(true));

        // 3.5 c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (4, 5, 6, 7) -> c1 != 4 (4 overlaps)
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
        );
        assert_eq!(simplify(expr), col("c1").not_eq(lit(4)));

        // 4. c1 NOT IN (1,2,3,4) AND c1 NOT IN (4,5,6,7) -> c1 NOT IN (1,2,3,4,5,6,7)
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
        );
        assert_eq!(
            simplify(expr),
            in_list(
                col("c1"),
                vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6), lit(7)],
                true
            )
        );

        // 5. c1 IN (1,2,3,4) OR c1 IN (2,3,4,5) -> c1 IN (1,2,3,4,5)
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).or(
            in_list(col("c1"), vec![lit(2), lit(3), lit(4), lit(5)], false),
        );
        assert_eq!(
            simplify(expr),
            in_list(
                col("c1"),
                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
                false
            )
        );

        // 6. c1 IN (1,2,3) AND c1 NOT IN (1,2,3,4,5) -> false
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3)], false).and(in_list(
            col("c1"),
            vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
            true,
        ));
        assert_eq!(simplify(expr), lit(false));

        // 7. c1 NOT IN (1,2,3,4) AND c1 IN (1,2,3,4,5) -> c1 = 5
        let expr =
            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(in_list(
                col("c1"),
                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
                false,
            ));
        assert_eq!(simplify(expr), col("c1").eq(lit(5)));

        // 8. c1 IN (1,2,3,4) AND c1 NOT IN (5,6,7,8) -> c1 IN (1,2,3,4)
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
        );
        assert_eq!(
            simplify(expr),
            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false)
        );

        // inlist with more than two expressions
        // c1 IN (1,2,3,4,5,6) AND c1 IN (1,3,5,6) AND c1 IN (3,6) -> c1 = 3 OR c1 = 6
        let expr = in_list(
            col("c1"),
            vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6)],
            false,
        )
        .and(in_list(
            col("c1"),
            vec![lit(1), lit(3), lit(5), lit(6)],
            false,
        ))
        .and(in_list(col("c1"), vec![lit(3), lit(6)], false));
        assert_eq!(
            simplify(expr),
            col("c1").eq(lit(3)).or(col("c1").eq(lit(6)))
        );

        // c1 NOT IN (1,2,3,4) AND c1 IN (5,6,7,8) AND c1 NOT IN (3,4,5,6) AND c1 IN (8,9,10) -> c1 = 8
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false)
                .and(in_list(
                    col("c1"),
                    vec![lit(3), lit(4), lit(5), lit(6)],
                    true,
                ))
                .and(in_list(col("c1"), vec![lit(8), lit(9), lit(10)], false)),
        );
        assert_eq!(simplify(expr), col("c1").eq(lit(8)));

        // Contains non-InList expression
        // c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9) -> c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9)
        let expr =
            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(col("c1")
                .not_eq(lit(5))
                .or(in_list(
                    col("c1"),
                    vec![lit(6), lit(7), lit(8), lit(9)],
                    true,
                )));
        // TODO: Further simplify this expression
        // https://github.com/apache/datafusion/issues/8970
        // assert_eq!(simplify(expr.clone()), lit(true));
        assert_eq!(simplify(expr.clone()), expr);
    }
3951
3952    #[test]
3953    fn simplify_large_or() {
3954        let expr = (0..5)
3955            .map(|i| col("c1").eq(lit(i)))
3956            .fold(lit(false), |acc, e| acc.or(e));
3957        assert_eq!(
3958            simplify(expr),
3959            in_list(col("c1"), (0..5).map(lit).collect(), false),
3960        );
3961    }
3962
3963    #[test]
3964    fn simplify_expr_bool_and() {
3965        // col & true is always col
3966        assert_eq!(simplify(col("c2").and(lit(true))), col("c2"),);
3967        // col & false is always false
3968        assert_eq!(simplify(col("c2").and(lit(false))), lit(false),);
3969
3970        // true && null is always null
3971        assert_eq!(simplify(lit(true).and(lit_bool_null())), lit_bool_null(),);
3972
3973        // null && true is always null
3974        assert_eq!(simplify(lit_bool_null().and(lit(true))), lit_bool_null(),);
3975
3976        // false && null is always false
3977        assert_eq!(simplify(lit(false).and(lit_bool_null())), lit(false),);
3978
3979        // null && false is always false
3980        assert_eq!(simplify(lit_bool_null().and(lit(false))), lit(false),);
3981
3982        // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
3983        // it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
3984        // and the Boolean(NULL) should remain
3985        let expr = col("c1").between(lit(0), lit(10));
3986        let expr = expr.and(lit_bool_null());
3987        let result = simplify(expr);
3988
3989        let expected_expr = and(
3990            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
3991            lit_bool_null(),
3992        );
3993        assert_eq!(expected_expr, result);
3994    }
3995
3996    #[test]
3997    fn simplify_expr_between() {
3998        // c2 between 3 and 4 is c2 >= 3 and c2 <= 4
3999        let expr = col("c2").between(lit(3), lit(4));
4000        assert_eq!(
4001            simplify(expr),
4002            and(col("c2").gt_eq(lit(3)), col("c2").lt_eq(lit(4)))
4003        );
4004
4005        // c2 not between 3 and 4 is c2 < 3 or c2 > 4
4006        let expr = col("c2").not_between(lit(3), lit(4));
4007        assert_eq!(
4008            simplify(expr),
4009            or(col("c2").lt(lit(3)), col("c2").gt(lit(4)))
4010        );
4011    }
4012
4013    #[test]
4014    fn test_like_and_ilike() {
4015        let null = lit(ScalarValue::Utf8(None));
4016
4017        // expr [NOT] [I]LIKE NULL
4018        let expr = col("c1").like(null.clone());
4019        assert_eq!(simplify(expr), lit_bool_null());
4020
4021        let expr = col("c1").not_like(null.clone());
4022        assert_eq!(simplify(expr), lit_bool_null());
4023
4024        let expr = col("c1").ilike(null.clone());
4025        assert_eq!(simplify(expr), lit_bool_null());
4026
4027        let expr = col("c1").not_ilike(null.clone());
4028        assert_eq!(simplify(expr), lit_bool_null());
4029
4030        // expr [NOT] [I]LIKE '%'
4031        let expr = col("c1").like(lit("%"));
4032        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4033
4034        let expr = col("c1").not_like(lit("%"));
4035        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4036
4037        let expr = col("c1").ilike(lit("%"));
4038        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4039
4040        let expr = col("c1").not_ilike(lit("%"));
4041        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4042
4043        // expr [NOT] [I]LIKE '%%'
4044        let expr = col("c1").like(lit("%%"));
4045        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4046
4047        let expr = col("c1").not_like(lit("%%"));
4048        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4049
4050        let expr = col("c1").ilike(lit("%%"));
4051        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4052
4053        let expr = col("c1").not_ilike(lit("%%"));
4054        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4055
4056        // not_null_expr [NOT] [I]LIKE '%'
4057        let expr = col("c1_non_null").like(lit("%"));
4058        assert_eq!(simplify(expr), lit(true));
4059
4060        let expr = col("c1_non_null").not_like(lit("%"));
4061        assert_eq!(simplify(expr), lit(false));
4062
4063        let expr = col("c1_non_null").ilike(lit("%"));
4064        assert_eq!(simplify(expr), lit(true));
4065
4066        let expr = col("c1_non_null").not_ilike(lit("%"));
4067        assert_eq!(simplify(expr), lit(false));
4068
4069        // not_null_expr [NOT] [I]LIKE '%%'
4070        let expr = col("c1_non_null").like(lit("%%"));
4071        assert_eq!(simplify(expr), lit(true));
4072
4073        let expr = col("c1_non_null").not_like(lit("%%"));
4074        assert_eq!(simplify(expr), lit(false));
4075
4076        let expr = col("c1_non_null").ilike(lit("%%"));
4077        assert_eq!(simplify(expr), lit(true));
4078
4079        let expr = col("c1_non_null").not_ilike(lit("%%"));
4080        assert_eq!(simplify(expr), lit(false));
4081
4082        // null_constant [NOT] [I]LIKE '%'
4083        let expr = null.clone().like(lit("%"));
4084        assert_eq!(simplify(expr), lit_bool_null());
4085
4086        let expr = null.clone().not_like(lit("%"));
4087        assert_eq!(simplify(expr), lit_bool_null());
4088
4089        let expr = null.clone().ilike(lit("%"));
4090        assert_eq!(simplify(expr), lit_bool_null());
4091
4092        let expr = null.clone().not_ilike(lit("%"));
4093        assert_eq!(simplify(expr), lit_bool_null());
4094
4095        // null_constant [NOT] [I]LIKE '%%'
4096        let expr = null.clone().like(lit("%%"));
4097        assert_eq!(simplify(expr), lit_bool_null());
4098
4099        let expr = null.clone().not_like(lit("%%"));
4100        assert_eq!(simplify(expr), lit_bool_null());
4101
4102        let expr = null.clone().ilike(lit("%%"));
4103        assert_eq!(simplify(expr), lit_bool_null());
4104
4105        let expr = null.clone().not_ilike(lit("%%"));
4106        assert_eq!(simplify(expr), lit_bool_null());
4107
4108        // null_constant [NOT] [I]LIKE 'a%'
4109        let expr = null.clone().like(lit("a%"));
4110        assert_eq!(simplify(expr), lit_bool_null());
4111
4112        let expr = null.clone().not_like(lit("a%"));
4113        assert_eq!(simplify(expr), lit_bool_null());
4114
4115        let expr = null.clone().ilike(lit("a%"));
4116        assert_eq!(simplify(expr), lit_bool_null());
4117
4118        let expr = null.clone().not_ilike(lit("a%"));
4119        assert_eq!(simplify(expr), lit_bool_null());
4120
4121        // expr [NOT] [I]LIKE with pattern without wildcards
4122        let expr = col("c1").like(lit("a"));
4123        assert_eq!(simplify(expr), col("c1").eq(lit("a")));
4124        let expr = col("c1").not_like(lit("a"));
4125        assert_eq!(simplify(expr), col("c1").not_eq(lit("a")));
4126        let expr = col("c1").like(lit("a_"));
4127        assert_eq!(simplify(expr), col("c1").like(lit("a_")));
4128        let expr = col("c1").not_like(lit("a_"));
4129        assert_eq!(simplify(expr), col("c1").not_like(lit("a_")));
4130
4131        let expr = col("c1").ilike(lit("a"));
4132        assert_eq!(simplify(expr), col("c1").ilike(lit("a")));
4133        let expr = col("c1").not_ilike(lit("a"));
4134        assert_eq!(simplify(expr), col("c1").not_ilike(lit("a")));
4135    }
4136
4137    #[test]
4138    fn test_simplify_with_guarantee() {
4139        // (c3 >= 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b")))
4140        let expr_x = col("c3").gt(lit(3_i64));
4141        let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32));
4142        let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true);
4143        let expr = expr_x.clone().and(expr_y.or(expr_z));
4144
4145        // All guaranteed null
4146        let guarantees = vec![
4147            (col("c3"), NullableInterval::from(ScalarValue::Int64(None))),
4148            (col("c4"), NullableInterval::from(ScalarValue::UInt32(None))),
4149            (col("c1"), NullableInterval::from(ScalarValue::Utf8(None))),
4150        ];
4151
4152        let output = simplify_with_guarantee(expr.clone(), guarantees);
4153        assert_eq!(output, lit_bool_null());
4154
4155        // All guaranteed false
4156        let guarantees = vec![
4157            (
4158                col("c3"),
4159                NullableInterval::NotNull {
4160                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
4161                },
4162            ),
4163            (
4164                col("c4"),
4165                NullableInterval::from(ScalarValue::UInt32(Some(9))),
4166            ),
4167            (col("c1"), NullableInterval::from(ScalarValue::from("a"))),
4168        ];
4169        let output = simplify_with_guarantee(expr.clone(), guarantees);
4170        assert_eq!(output, lit(false));
4171
4172        // Guaranteed false or null -> no change.
4173        let guarantees = vec![
4174            (
4175                col("c3"),
4176                NullableInterval::MaybeNull {
4177                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
4178                },
4179            ),
4180            (
4181                col("c4"),
4182                NullableInterval::MaybeNull {
4183                    values: Interval::make(Some(9_u32), Some(9_u32)).unwrap(),
4184                },
4185            ),
4186            (
4187                col("c1"),
4188                NullableInterval::NotNull {
4189                    values: Interval::try_new(
4190                        ScalarValue::from("d"),
4191                        ScalarValue::from("f"),
4192                    )
4193                    .unwrap(),
4194                },
4195            ),
4196        ];
4197        let output = simplify_with_guarantee(expr.clone(), guarantees);
4198        assert_eq!(&output, &expr_x);
4199
4200        // Sufficient true guarantees
4201        let guarantees = vec![
4202            (
4203                col("c3"),
4204                NullableInterval::from(ScalarValue::Int64(Some(9))),
4205            ),
4206            (
4207                col("c4"),
4208                NullableInterval::from(ScalarValue::UInt32(Some(3))),
4209            ),
4210        ];
4211        let output = simplify_with_guarantee(expr.clone(), guarantees);
4212        assert_eq!(output, lit(true));
4213
4214        // Only partially simplify
4215        let guarantees = vec![(
4216            col("c4"),
4217            NullableInterval::from(ScalarValue::UInt32(Some(3))),
4218        )];
4219        let output = simplify_with_guarantee(expr, guarantees);
4220        assert_eq!(&output, &expr_x);
4221    }
4222
4223    #[test]
4224    fn test_expression_partial_simplify_1() {
4225        // (1 + 2) + (4 / 0) -> 3 + (4 / 0)
4226        let expr = (lit(1) + lit(2)) + (lit(4) / lit(0));
4227        let expected = (lit(3)) + (lit(4) / lit(0));
4228
4229        assert_eq!(simplify(expr), expected);
4230    }
4231
4232    #[test]
4233    fn test_expression_partial_simplify_2() {
4234        // (1 > 2) and (4 / 0) -> false
4235        let expr = (lit(1).gt(lit(2))).and(lit(4) / lit(0));
4236        let expected = lit(false);
4237
4238        assert_eq!(simplify(expr), expected);
4239    }
4240
4241    #[test]
4242    fn test_simplify_cycles() {
4243        // TRUE
4244        let expr = lit(true);
4245        let expected = lit(true);
4246        let (expr, num_iter) = simplify_with_cycle_count(expr);
4247        assert_eq!(expr, expected);
4248        assert_eq!(num_iter, 1);
4249
4250        // (true != NULL) OR (5 > 10)
4251        let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
4252        let expected = lit_bool_null();
4253        let (expr, num_iter) = simplify_with_cycle_count(expr);
4254        assert_eq!(expr, expected);
4255        assert_eq!(num_iter, 2);
4256
4257        // NOTE: this currently does not simplify
4258        // (((c4 - 10) + 10) *100) / 100
4259        let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
4260        let expected = expr.clone();
4261        let (expr, num_iter) = simplify_with_cycle_count(expr);
4262        assert_eq!(expr, expected);
4263        assert_eq!(num_iter, 1);
4264
4265        // ((c4<1 or c3<2) and c3_non_null<3) and false
4266        let expr = col("c4")
4267            .lt(lit(1))
4268            .or(col("c3").lt(lit(2)))
4269            .and(col("c3_non_null").lt(lit(3)))
4270            .and(lit(false));
4271        let expected = lit(false);
4272        let (expr, num_iter) = simplify_with_cycle_count(expr);
4273        assert_eq!(expr, expected);
4274        assert_eq!(num_iter, 2);
4275    }
4276
4277    fn boolean_test_schema() -> DFSchemaRef {
4278        Schema::new(vec![
4279            Field::new("A", DataType::Boolean, false),
4280            Field::new("B", DataType::Boolean, false),
4281            Field::new("C", DataType::Boolean, false),
4282            Field::new("D", DataType::Boolean, false),
4283        ])
4284        .to_dfschema_ref()
4285        .unwrap()
4286    }
4287
4288    #[test]
4289    fn simplify_common_factor_conjunction_in_disjunction() {
4290        let props = ExecutionProps::new();
4291        let schema = boolean_test_schema();
4292        let simplifier =
4293            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));
4294
4295        let a = || col("A");
4296        let b = || col("B");
4297        let c = || col("C");
4298        let d = || col("D");
4299
4300        // (A AND B) OR (A AND C) -> A AND (B OR C)
4301        let expr = a().and(b()).or(a().and(c()));
4302        let expected = a().and(b().or(c()));
4303
4304        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4305
4306        // (A AND B) OR (A AND C) OR (A AND D) -> A AND (B OR C OR D)
4307        let expr = a().and(b()).or(a().and(c())).or(a().and(d()));
4308        let expected = a().and(b().or(c()).or(d()));
4309        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4310
4311        // A OR (B AND C AND A) -> A
4312        let expr = a().or(b().and(c().and(a())));
4313        let expected = a();
4314        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4315    }
4316
4317    #[test]
4318    fn test_simplify_udaf() {
4319        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());
4320        let aggregate_function_expr =
4321            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4322                udaf.into(),
4323                vec![],
4324                false,
4325                None,
4326                vec![],
4327                None,
4328            ));
4329
4330        let expected = col("result_column");
4331        assert_eq!(simplify(aggregate_function_expr), expected);
4332
4333        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify());
4334        let aggregate_function_expr =
4335            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4336                udaf.into(),
4337                vec![],
4338                false,
4339                None,
4340                vec![],
4341                None,
4342            ));
4343
4344        let expected = aggregate_function_expr.clone();
4345        assert_eq!(simplify(aggregate_function_expr), expected);
4346    }
4347
4348    /// A Mock UDAF which defines `simplify` to be used in tests
4349    /// related to UDAF simplification
4350    #[derive(Debug, Clone)]
4351    struct SimplifyMockUdaf {
4352        simplify: bool,
4353    }
4354
4355    impl SimplifyMockUdaf {
4356        /// make simplify method return new expression
4357        fn new_with_simplify() -> Self {
4358            Self { simplify: true }
4359        }
4360        /// make simplify method return no change
4361        fn new_without_simplify() -> Self {
4362            Self { simplify: false }
4363        }
4364    }
4365
4366    impl AggregateUDFImpl for SimplifyMockUdaf {
4367        fn as_any(&self) -> &dyn std::any::Any {
4368            self
4369        }
4370
4371        fn name(&self) -> &str {
4372            "mock_simplify"
4373        }
4374
4375        fn signature(&self) -> &Signature {
4376            unimplemented!()
4377        }
4378
4379        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4380            unimplemented!("not needed for tests")
4381        }
4382
4383        fn accumulator(
4384            &self,
4385            _acc_args: AccumulatorArgs,
4386        ) -> Result<Box<dyn Accumulator>> {
4387            unimplemented!("not needed for tests")
4388        }
4389
4390        fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
4391            unimplemented!("not needed for testing")
4392        }
4393
4394        fn create_groups_accumulator(
4395            &self,
4396            _args: AccumulatorArgs,
4397        ) -> Result<Box<dyn GroupsAccumulator>> {
4398            unimplemented!("not needed for testing")
4399        }
4400
4401        fn simplify(&self) -> Option<AggregateFunctionSimplification> {
4402            if self.simplify {
4403                Some(Box::new(|_, _| Ok(col("result_column"))))
4404            } else {
4405                None
4406            }
4407        }
4408
4409        fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
4410            let Some(other) = other.as_any().downcast_ref::<Self>() else {
4411                return false;
4412            };
4413            let Self { simplify } = self;
4414            simplify == &other.simplify
4415        }
4416
4417        fn hash_value(&self) -> u64 {
4418            let Self { simplify } = self;
4419            let mut hasher = DefaultHasher::new();
4420            simplify.hash(&mut hasher);
4421            hasher.finish()
4422        }
4423    }
4424
4425    #[test]
4426    fn test_simplify_udwf() {
4427        let udwf = WindowFunctionDefinition::WindowUDF(
4428            WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(),
4429        );
4430        let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![]));
4431
4432        let expected = col("result_column");
4433        assert_eq!(simplify(window_function_expr), expected);
4434
4435        let udwf = WindowFunctionDefinition::WindowUDF(
4436            WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(),
4437        );
4438        let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![]));
4439
4440        let expected = window_function_expr.clone();
4441        assert_eq!(simplify(window_function_expr), expected);
4442    }
4443
4444    /// A Mock UDWF which defines `simplify` to be used in tests
4445    /// related to UDWF simplification
4446    #[derive(Debug, Clone)]
4447    struct SimplifyMockUdwf {
4448        simplify: bool,
4449    }
4450
4451    impl SimplifyMockUdwf {
4452        /// make simplify method return new expression
4453        fn new_with_simplify() -> Self {
4454            Self { simplify: true }
4455        }
4456        /// make simplify method return no change
4457        fn new_without_simplify() -> Self {
4458            Self { simplify: false }
4459        }
4460    }
4461
4462    impl WindowUDFImpl for SimplifyMockUdwf {
4463        fn as_any(&self) -> &dyn std::any::Any {
4464            self
4465        }
4466
4467        fn name(&self) -> &str {
4468            "mock_simplify"
4469        }
4470
4471        fn signature(&self) -> &Signature {
4472            unimplemented!()
4473        }
4474
4475        fn simplify(&self) -> Option<WindowFunctionSimplification> {
4476            if self.simplify {
4477                Some(Box::new(|_, _| Ok(col("result_column"))))
4478            } else {
4479                None
4480            }
4481        }
4482
4483        fn partition_evaluator(
4484            &self,
4485            _partition_evaluator_args: PartitionEvaluatorArgs,
4486        ) -> Result<Box<dyn PartitionEvaluator>> {
4487            unimplemented!("not needed for tests")
4488        }
4489
4490        fn field(&self, _field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
4491            unimplemented!("not needed for tests")
4492        }
4493
4494        fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
4495            let Some(other) = other.as_any().downcast_ref::<Self>() else {
4496                return false;
4497            };
4498            let Self { simplify } = self;
4499            simplify == &other.simplify
4500        }
4501
4502        fn hash_value(&self) -> u64 {
4503            let Self { simplify } = self;
4504            let mut hasher = DefaultHasher::new();
4505            std::any::type_name::<Self>().hash(&mut hasher);
4506            simplify.hash(&mut hasher);
4507            hasher.finish()
4508        }
4509    }
4510    #[derive(Debug)]
4511    struct VolatileUdf {
4512        signature: Signature,
4513    }
4514
4515    impl VolatileUdf {
4516        pub fn new() -> Self {
4517            Self {
4518                signature: Signature::exact(vec![], Volatility::Volatile),
4519            }
4520        }
4521    }
4522    impl ScalarUDFImpl for VolatileUdf {
4523        fn as_any(&self) -> &dyn std::any::Any {
4524            self
4525        }
4526
4527        fn name(&self) -> &str {
4528            "VolatileUdf"
4529        }
4530
4531        fn signature(&self) -> &Signature {
4532            &self.signature
4533        }
4534
4535        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4536            Ok(DataType::Int16)
4537        }
4538
4539        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
4540            panic!("dummy - not implemented")
4541        }
4542    }
4543
4544    #[test]
4545    fn test_optimize_volatile_conditions() {
4546        let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new()));
4547        let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
4548        {
4549            let expr = rand
4550                .clone()
4551                .eq(lit(0))
4552                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4553
4554            assert_eq!(simplify(expr.clone()), expr);
4555        }
4556
4557        {
4558            let expr = col("column1")
4559                .eq(lit(2))
4560                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4561
4562            assert_eq!(simplify(expr), col("column1").eq(lit(2)));
4563        }
4564
4565        {
4566            let expr = (col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))).or(col(
4567                "column1",
4568            )
4569            .eq(lit(2))
4570            .and(rand.clone().eq(lit(0))));
4571
4572            assert_eq!(
4573                simplify(expr),
4574                col("column1")
4575                    .eq(lit(2))
4576                    .and((rand.clone().eq(lit(0))).or(rand.clone().eq(lit(0))))
4577            );
4578        }
4579    }
4580
4581    #[test]
4582    fn simplify_fixed_size_binary_eq_lit() {
4583        let bytes = [1u8, 2, 3].as_slice();
4584
4585        // The expression starts simple.
4586        let expr = col("c5").eq(lit(bytes));
4587
4588        // The type coercer introduces a cast.
4589        let coerced = coerce(expr.clone());
4590        let schema = expr_test_schema();
4591        assert_eq!(
4592            coerced,
4593            col("c5")
4594                .cast_to(&DataType::Binary, schema.as_ref())
4595                .unwrap()
4596                .eq(lit(bytes))
4597        );
4598
4599        // The simplifier removes the cast.
4600        assert_eq!(
4601            simplify(coerced),
4602            col("c5").eq(Expr::Literal(
4603                ScalarValue::FixedSizeBinary(3, Some(bytes.to_vec()),),
4604                None
4605            ))
4606        );
4607    }
4608
4609    fn if_not_null(expr: Expr, then: bool) -> Expr {
4610        Expr::Case(Case {
4611            expr: Some(expr.is_not_null().into()),
4612            when_then_expr: vec![(lit(true).into(), lit(then).into())],
4613            else_expr: None,
4614        })
4615    }
4616}