datafusion_optimizer/simplify_expressions/
expr_simplifier.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression simplification API
19
20use std::borrow::Cow;
21use std::collections::{BTreeMap, HashSet};
22use std::ops::Not;
23
24use arrow::{
25    array::{new_null_array, AsArray},
26    datatypes::{DataType, Field, Schema},
27    record_batch::RecordBatch,
28};
29
30use datafusion_common::{
31    cast::{as_large_list_array, as_list_array},
32    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
33};
34use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
35use datafusion_expr::{
36    and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like,
37    Operator, Volatility,
38};
39use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
40use datafusion_expr::{
41    expr::{InList, InSubquery},
42    utils::{iter_conjunction, iter_conjunction_owned},
43};
44use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast};
45use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
46
47use super::inlist_simplifier::ShortenInListSimplifier;
48use super::utils::*;
49use crate::simplify_expressions::guarantees::GuaranteeRewriter;
50use crate::simplify_expressions::regex::simplify_regex_expr;
51use crate::simplify_expressions::unwrap_cast::{
52    is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
53    is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
54    unwrap_cast_in_comparison_for_binary,
55};
56use crate::simplify_expressions::SimplifyInfo;
57use crate::{
58    analyzer::type_coercion::TypeCoercionRewriter,
59    simplify_expressions::unwrap_cast::try_cast_literal_to_type,
60};
61use indexmap::IndexSet;
62use regex::Regex;
63
64/// This structure handles API for expression simplification
65///
66/// Provides simplification information based on DFSchema and
67/// [`ExecutionProps`]. This is the default implementation used by DataFusion
68///
69/// For example:
70/// ```
71/// use arrow::datatypes::{Schema, Field, DataType};
72/// use datafusion_expr::{col, lit};
73/// use datafusion_common::{DataFusionError, ToDFSchema};
74/// use datafusion_expr::execution_props::ExecutionProps;
75/// use datafusion_expr::simplify::SimplifyContext;
76/// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
77///
78/// // Create the schema
79/// let schema = Schema::new(vec![
80///     Field::new("i", DataType::Int64, false),
81///   ])
82///   .to_dfschema_ref().unwrap();
83///
84/// // Create the simplifier
85/// let props = ExecutionProps::new();
86/// let context = SimplifyContext::new(&props)
87///    .with_schema(schema);
88/// let simplifier = ExprSimplifier::new(context);
89///
90/// // Use the simplifier
91///
92/// // b < 2 or (1 > 3)
93/// let expr = col("b").lt(lit(2)).or(lit(1).gt(lit(3)));
94///
95/// // b < 2
96/// let simplified = simplifier.simplify(expr).unwrap();
97/// assert_eq!(simplified, col("b").lt(lit(2)));
98/// ```
pub struct ExprSimplifier<S> {
    /// Provider of the information the simplification passes consult; the
    /// methods of this type use it through the `SimplifyInfo` trait (for
    /// example to obtain the `ExecutionProps` for constant evaluation).
    info: S,
    /// Guarantees about the values of columns. This is provided by the user
    /// in [ExprSimplifier::with_guarantees()]. Empty by default.
    guarantees: Vec<(Expr, NullableInterval)>,
    /// Should expressions be canonicalized before simplification? Defaults to
    /// true; changed through [ExprSimplifier::with_canonicalize()].
    canonicalize: bool,
    /// Maximum number of simplifier cycles. Defaults to
    /// [DEFAULT_MAX_SIMPLIFIER_CYCLES]; changed through
    /// [ExprSimplifier::with_max_cycles()].
    max_simplifier_cycles: u32,
}
110
/// NOTE(review): not referenced in the visible part of this file; presumably
/// the list-length threshold below which `IN` lists are inlined — confirm at
/// the use site.
pub const THRESHOLD_INLINE_INLIST: usize = 3;
/// Default upper bound on the number of simplification cycles; this is the
/// value [`ExprSimplifier::new`] installs.
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3;
113
114impl<S: SimplifyInfo> ExprSimplifier<S> {
115    /// Create a new `ExprSimplifier` with the given `info` such as an
116    /// instance of [`SimplifyContext`]. See
117    /// [`simplify`](Self::simplify) for an example.
118    ///
119    /// [`SimplifyContext`]: datafusion_expr::simplify::SimplifyContext
120    pub fn new(info: S) -> Self {
121        Self {
122            info,
123            guarantees: vec![],
124            canonicalize: true,
125            max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES,
126        }
127    }
128
129    /// Simplifies this [`Expr`] as much as possible, evaluating
130    /// constants and applying algebraic simplifications.
131    ///
132    /// The types of the expression must match what operators expect,
133    /// or else an error may occur trying to evaluate. See
134    /// [`coerce`](Self::coerce) for a function to help.
135    ///
136    /// # Example:
137    ///
    /// `b > 2 OR b > 2`
    ///
    /// can be rewritten to
    ///
    /// `b > 2`
143    ///
144    /// ```
145    /// use arrow::datatypes::DataType;
146    /// use datafusion_expr::{col, lit, Expr};
147    /// use datafusion_common::Result;
148    /// use datafusion_expr::execution_props::ExecutionProps;
149    /// use datafusion_expr::simplify::SimplifyContext;
150    /// use datafusion_expr::simplify::SimplifyInfo;
151    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
152    /// use datafusion_common::DFSchema;
153    /// use std::sync::Arc;
154    ///
155    /// /// Simple implementation that provides `Simplifier` the information it needs
156    /// /// See SimplifyContext for a structure that does this.
157    /// #[derive(Default)]
158    /// struct Info {
159    ///   execution_props: ExecutionProps,
160    /// };
161    ///
162    /// impl SimplifyInfo for Info {
163    ///   fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
164    ///     Ok(false)
165    ///   }
166    ///   fn nullable(&self, expr: &Expr) -> Result<bool> {
167    ///     Ok(true)
168    ///   }
169    ///   fn execution_props(&self) -> &ExecutionProps {
170    ///     &self.execution_props
171    ///   }
172    ///   fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
173    ///     Ok(DataType::Int32)
174    ///   }
175    /// }
176    ///
177    /// // Create the simplifier
178    /// let simplifier = ExprSimplifier::new(Info::default());
179    ///
    /// // b > 2
    /// let b_gt_2 = col("b").gt(lit(2));
    ///
    /// // (b > 2) OR (b > 2)
    /// let expr = b_gt_2.clone().or(b_gt_2.clone());
    ///
    /// // (b > 2) OR (b > 2) --> (b > 2)
    /// let expr = simplifier.simplify(expr).unwrap();
    /// assert_eq!(expr, b_gt_2);
189    /// ```
190    pub fn simplify(&self, expr: Expr) -> Result<Expr> {
191        Ok(self.simplify_with_cycle_count_transformed(expr)?.0.data)
192    }
193
194    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
195    /// constants and applying algebraic simplifications. Additionally returns a `u32`
196    /// representing the number of simplification cycles performed, which can be useful for testing
197    /// optimizations.
198    ///
199    /// See [Self::simplify] for details and usage examples.
200    ///
201    #[deprecated(
202        since = "48.0.0",
203        note = "Use `simplify_with_cycle_count_transformed` instead"
204    )]
205    #[allow(unused_mut)]
206    pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
207        let (transformed, cycle_count) =
208            self.simplify_with_cycle_count_transformed(expr)?;
209        Ok((transformed.data, cycle_count))
210    }
211
212    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
213    /// constants and applying algebraic simplifications. Additionally returns a `u32`
214    /// representing the number of simplification cycles performed, which can be useful for testing
215    /// optimizations.
216    ///
217    /// # Returns
218    ///
219    /// A tuple containing:
220    /// - The simplified expression wrapped in a `Transformed<Expr>` indicating if changes were made
221    /// - The number of simplification cycles that were performed
222    ///
223    /// See [Self::simplify] for details and usage examples.
224    ///
225    pub fn simplify_with_cycle_count_transformed(
226        &self,
227        mut expr: Expr,
228    ) -> Result<(Transformed<Expr>, u32)> {
229        let mut simplifier = Simplifier::new(&self.info);
230        let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
231        let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
232        let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);
233
234        if self.canonicalize {
235            expr = expr.rewrite(&mut Canonicalizer::new()).data()?
236        }
237
238        // Evaluating constants can enable new simplifications and
239        // simplifications can enable new constant evaluation
240        // see `Self::with_max_cycles`
241        let mut num_cycles = 0;
242        let mut has_transformed = false;
243        loop {
244            let Transformed {
245                data, transformed, ..
246            } = expr
247                .rewrite(&mut const_evaluator)?
248                .transform_data(|expr| expr.rewrite(&mut simplifier))?
249                .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
250            expr = data;
251            num_cycles += 1;
252            // Track if any transformation occurred
253            has_transformed = has_transformed || transformed;
254            if !transformed || num_cycles >= self.max_simplifier_cycles {
255                break;
256            }
257        }
258        // shorten inlist should be started after other inlist rules are applied
259        expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
260        Ok((
261            Transformed::new_transformed(expr, has_transformed),
262            num_cycles,
263        ))
264    }
265
266    /// Apply type coercion to an [`Expr`] so that it can be
267    /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr).
268    ///
269    /// See the [type coercion module](datafusion_expr::type_coercion)
270    /// documentation for more details on type coercion
271    pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> {
272        let mut expr_rewrite = TypeCoercionRewriter { schema };
273        expr.rewrite(&mut expr_rewrite).data()
274    }
275
276    /// Input guarantees about the values of columns.
277    ///
278    /// The guarantees can simplify expressions. For example, if a column `x` is
279    /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the
280    /// literal `true`.
281    ///
282    /// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`,
283    /// where the [Expr] is a column reference and the [NullableInterval]
284    /// is an interval representing the known possible values of that column.
285    ///
286    /// ```rust
287    /// use arrow::datatypes::{DataType, Field, Schema};
288    /// use datafusion_expr::{col, lit, Expr};
289    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
290    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
291    /// use datafusion_expr::execution_props::ExecutionProps;
292    /// use datafusion_expr::simplify::SimplifyContext;
293    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
294    ///
295    /// let schema = Schema::new(vec![
296    ///   Field::new("x", DataType::Int64, false),
297    ///   Field::new("y", DataType::UInt32, false),
298    ///   Field::new("z", DataType::Int64, false),
299    ///   ])
300    ///   .to_dfschema_ref().unwrap();
301    ///
302    /// // Create the simplifier
303    /// let props = ExecutionProps::new();
304    /// let context = SimplifyContext::new(&props)
305    ///    .with_schema(schema);
306    ///
307    /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5)
308    /// let expr_x = col("x").gt_eq(lit(3_i64));
309    /// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32));
310    /// let expr_z = col("z").gt(lit(5_i64));
311    /// let expr = expr_x.and(expr_y).and(expr_z.clone());
312    ///
313    /// let guarantees = vec![
314    ///    // x ∈ [3, 5]
315    ///    (
316    ///        col("x"),
317    ///        NullableInterval::NotNull {
318    ///            values: Interval::make(Some(3_i64), Some(5_i64)).unwrap()
319    ///        }
320    ///    ),
321    ///    // y = 3
322    ///    (col("y"), NullableInterval::from(ScalarValue::UInt32(Some(3)))),
323    /// ];
324    /// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees);
325    /// let output = simplifier.simplify(expr).unwrap();
326    /// // Expression becomes: true AND true AND (z > 5), which simplifies to
327    /// // z > 5.
328    /// assert_eq!(output, expr_z);
329    /// ```
330    pub fn with_guarantees(mut self, guarantees: Vec<(Expr, NullableInterval)>) -> Self {
331        self.guarantees = guarantees;
332        self
333    }
334
335    /// Should `Canonicalizer` be applied before simplification?
336    ///
337    /// If true (the default), the expression will be rewritten to canonical
338    /// form before simplification. This is useful to ensure that the simplifier
339    /// can apply all possible simplifications.
340    ///
341    /// Some expressions, such as those in some Joins, can not be canonicalized
342    /// without changing their meaning. In these cases, canonicalization should
343    /// be disabled.
344    ///
345    /// ```rust
346    /// use arrow::datatypes::{DataType, Field, Schema};
347    /// use datafusion_expr::{col, lit, Expr};
348    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
349    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
350    /// use datafusion_expr::execution_props::ExecutionProps;
351    /// use datafusion_expr::simplify::SimplifyContext;
352    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
353    ///
354    /// let schema = Schema::new(vec![
355    ///   Field::new("a", DataType::Int64, false),
356    ///   Field::new("b", DataType::Int64, false),
357    ///   Field::new("c", DataType::Int64, false),
358    ///   ])
359    ///   .to_dfschema_ref().unwrap();
360    ///
361    /// // Create the simplifier
362    /// let props = ExecutionProps::new();
363    /// let context = SimplifyContext::new(&props)
364    ///    .with_schema(schema);
365    /// let simplifier = ExprSimplifier::new(context);
366    ///
367    /// // Expression: a = c AND 1 = b
368    /// let expr = col("a").eq(col("c")).and(lit(1).eq(col("b")));
369    ///
370    /// // With canonicalization, the expression is rewritten to canonical form
371    /// // (though it is no simpler in this case):
372    /// let canonical = simplifier.simplify(expr.clone()).unwrap();
373    /// // Expression has been rewritten to: (c = a AND b = 1)
374    /// assert_eq!(canonical, col("c").eq(col("a")).and(col("b").eq(lit(1))));
375    ///
376    /// // If canonicalization is disabled, the expression is not changed
377    /// let non_canonicalized = simplifier
378    ///   .with_canonicalize(false)
379    ///   .simplify(expr.clone())
380    ///   .unwrap();
381    ///
382    /// assert_eq!(non_canonicalized, expr);
383    /// ```
384    pub fn with_canonicalize(mut self, canonicalize: bool) -> Self {
385        self.canonicalize = canonicalize;
386        self
387    }
388
389    /// Specifies the maximum number of simplification cycles to run.
390    ///
391    /// The simplifier can perform multiple passes of simplification. This is
392    /// because the output of one simplification step can allow more optimizations
393    /// in another simplification step. For example, constant evaluation can allow more
394    /// expression simplifications, and expression simplifications can allow more constant
395    /// evaluations.
396    ///
397    /// This method specifies the maximum number of allowed iteration cycles before the simplifier
398    /// returns an [Expr] output. However, it does not always perform the maximum number of cycles.
399    /// The simplifier will attempt to detect when an [Expr] is unchanged by all the simplification
400    /// passes, and return early. This avoids wasting time on unnecessary [Expr] tree traversals.
401    ///
402    /// If no maximum is specified, the value of [DEFAULT_MAX_SIMPLIFIER_CYCLES] is used
403    /// instead.
404    ///
405    /// ```rust
406    /// use arrow::datatypes::{DataType, Field, Schema};
407    /// use datafusion_expr::{col, lit, Expr};
408    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
409    /// use datafusion_expr::execution_props::ExecutionProps;
410    /// use datafusion_expr::simplify::SimplifyContext;
411    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
412    ///
413    /// let schema = Schema::new(vec![
414    ///   Field::new("a", DataType::Int64, false),
415    ///   ])
416    ///   .to_dfschema_ref().unwrap();
417    ///
418    /// // Create the simplifier
419    /// let props = ExecutionProps::new();
420    /// let context = SimplifyContext::new(&props)
421    ///    .with_schema(schema);
422    /// let simplifier = ExprSimplifier::new(context);
423    ///
424    /// // Expression: a IS NOT NULL
425    /// let expr = col("a").is_not_null();
426    ///
427    /// // When using default maximum cycles, 2 cycles will be performed.
428    /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count_transformed(expr.clone()).unwrap();
429    /// assert_eq!(simplified_expr.data, lit(true));
430    /// // 2 cycles were executed, but only 1 was needed
431    /// assert_eq!(count, 2);
432    ///
433    /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
434    /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count_transformed(expr.clone()).unwrap();
    /// // Expression has been simplified to: true
436    /// assert_eq!(simplified_expr.data, lit(true));
437    /// // Only 1 cycle was executed
438    /// assert_eq!(count, 1);
439    ///
440    /// ```
441    pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
442        self.max_simplifier_cycles = max_simplifier_cycles;
443        self
444    }
445}
446
/// Rewriter that puts `BinaryExpr`s with a swappable operator into a
/// canonical operand order.
///
/// `<literal> <op> <col>` is rewritten to `<col> <op> <literal>`
///
/// `<col1> <op> <col2>` is rewritten so that the column that compares greater
/// ends up on the left (`a > b` would be canonicalized to `b < a`)
struct Canonicalizer {}

impl Canonicalizer {
    /// Creates the (stateless) canonicalizer.
    fn new() -> Self {
        Canonicalizer {}
    }
}
460
461impl TreeNodeRewriter for Canonicalizer {
462    type Node = Expr;
463
464    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
465        let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
466            return Ok(Transformed::no(expr));
467        };
468        match (left.as_ref(), right.as_ref(), op.swap()) {
469            // <col1> <op> <col2>
470            (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op))
471                if right_col > left_col =>
472            {
473                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
474                    left: right,
475                    op: swapped_op,
476                    right: left,
477                })))
478            }
479            // <literal> <op> <col>
480            (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => {
481                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
482                    left: right,
483                    op: swapped_op,
484                    right: left,
485                })))
486            }
487            _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr {
488                left,
489                op,
490                right,
491            }))),
492        }
493    }
494}
495
#[allow(rustdoc::private_intra_doc_links)]
/// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time.
///
/// Note it does not handle algebraic rewrites such as `(a or false)`
/// --> `a`, which is handled by [`Simplifier`]
struct ConstEvaluator<'a> {
    /// `can_evaluate` is used during the depth-first-search of the
    /// `Expr` tree to track if any siblings (or their descendants) were
    /// non evaluatable (e.g. had a column reference or volatile
    /// function)
    ///
    /// Specifically, `can_evaluate[N]` represents the state of
    /// traversal when we are N levels deep in the tree, one entry for
    /// this Expr and each of its parents.
    ///
    /// After visiting all siblings if `can_evaluate.top()` is true, that
    /// means there were no non evaluatable siblings (or their
    /// descendants) so this `Expr` can be evaluated
    can_evaluate: Vec<bool>,

    /// Execution properties handed to `create_physical_expr` when a constant
    /// subtree is turned into a physical expression for evaluation.
    execution_props: &'a ExecutionProps,
    /// Schema used to plan the physical expression; built in `try_new` from a
    /// single dummy column of type `Null`.
    input_schema: DFSchema,
    /// One-row batch (a single null value in the dummy column) that the
    /// physical expression is evaluated against, so that evaluation yields
    /// exactly one output value.
    input_batch: RecordBatch,
}
520
// NOTE(review): presumably `dead_code` is allowed because the error payload
// of `SimplifyRuntimeError` is not read by the visible callers — confirm.
#[allow(dead_code)]
/// The simplify result of ConstEvaluator
#[allow(clippy::large_enum_variant)]
enum ConstSimplifyResult {
    /// Expr was simplified and contains the new value, together with the
    /// optional field metadata to attach to the resulting literal
    Simplified(ScalarValue, Option<BTreeMap<String, String>>),
    /// Expr was not simplified (it already was a literal) and the original
    /// value and metadata are returned
    NotSimplified(ScalarValue, Option<BTreeMap<String, String>>),
    /// Evaluation encountered an error, contains the error and the original
    /// expression
    SimplifyRuntimeError(DataFusionError, Expr),
}
532
533impl TreeNodeRewriter for ConstEvaluator<'_> {
534    type Node = Expr;
535
536    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
537        // Default to being able to evaluate this node
538        self.can_evaluate.push(true);
539
540        // if this expr is not ok to evaluate, mark entire parent
541        // stack as not ok (as all parents have at least one child or
542        // descendant that can not be evaluated
543
544        if !Self::can_evaluate(&expr) {
545            // walk back up stack, marking first parent that is not mutable
546            let parent_iter = self.can_evaluate.iter_mut().rev();
547            for p in parent_iter {
548                if !*p {
549                    // optimization: if we find an element on the
550                    // stack already marked, know all elements above are also marked
551                    break;
552                }
553                *p = false;
554            }
555        }
556
557        // NB: do not short circuit recursion even if we find a non
558        // evaluatable node (so we can fold other children, args to
559        // functions, etc.)
560        Ok(Transformed::no(expr))
561    }
562
563    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
564        match self.can_evaluate.pop() {
565            // Certain expressions such as `CASE` and `COALESCE` are short-circuiting
566            // and may not evaluate all their sub expressions. Thus, if
567            // any error is countered during simplification, return the original
568            // so that normal evaluation can occur
569            Some(true) => match self.evaluate_to_scalar(expr) {
570                ConstSimplifyResult::Simplified(s, m) => {
571                    Ok(Transformed::yes(Expr::Literal(s, m)))
572                }
573                ConstSimplifyResult::NotSimplified(s, m) => {
574                    Ok(Transformed::no(Expr::Literal(s, m)))
575                }
576                ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
577                    Ok(Transformed::yes(expr))
578                }
579            },
580            Some(false) => Ok(Transformed::no(expr)),
581            _ => internal_err!("Failed to pop can_evaluate"),
582        }
583    }
584}
585
586impl<'a> ConstEvaluator<'a> {
587    /// Create a new `ConstantEvaluator`. Session constants (such as
588    /// the time for `now()` are taken from the passed
589    /// `execution_props`.
590    pub fn try_new(execution_props: &'a ExecutionProps) -> Result<Self> {
591        // The dummy column name is unused and doesn't matter as only
592        // expressions without column references can be evaluated
593        static DUMMY_COL_NAME: &str = ".";
594        let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]);
595        let input_schema = DFSchema::try_from(schema.clone())?;
596        // Need a single "input" row to produce a single output row
597        let col = new_null_array(&DataType::Null, 1);
598        let input_batch = RecordBatch::try_new(std::sync::Arc::new(schema), vec![col])?;
599
600        Ok(Self {
601            can_evaluate: vec![],
602            execution_props,
603            input_schema,
604            input_batch,
605        })
606    }
607
608    /// Can a function of the specified volatility be evaluated?
609    fn volatility_ok(volatility: Volatility) -> bool {
610        match volatility {
611            Volatility::Immutable => true,
612            // Values for functions such as now() are taken from ExecutionProps
613            Volatility::Stable => true,
614            Volatility::Volatile => false,
615        }
616    }
617
618    /// Can the expression be evaluated at plan time, (assuming all of
619    /// its children can also be evaluated)?
620    fn can_evaluate(expr: &Expr) -> bool {
621        // check for reasons we can't evaluate this node
622        //
623        // NOTE all expr types are listed here so when new ones are
624        // added they can be checked for their ability to be evaluated
625        // at plan time
626        match expr {
627            // TODO: remove the next line after `Expr::Wildcard` is removed
628            #[expect(deprecated)]
629            Expr::AggregateFunction { .. }
630            | Expr::ScalarVariable(_, _)
631            | Expr::Column(_)
632            | Expr::OuterReferenceColumn(_, _)
633            | Expr::Exists { .. }
634            | Expr::InSubquery(_)
635            | Expr::ScalarSubquery(_)
636            | Expr::WindowFunction { .. }
637            | Expr::GroupingSet(_)
638            | Expr::Wildcard { .. }
639            | Expr::Placeholder(_) => false,
640            Expr::ScalarFunction(ScalarFunction { func, .. }) => {
641                Self::volatility_ok(func.signature().volatility)
642            }
643            Expr::Literal(_, _)
644            | Expr::Alias(..)
645            | Expr::Unnest(_)
646            | Expr::BinaryExpr { .. }
647            | Expr::Not(_)
648            | Expr::IsNotNull(_)
649            | Expr::IsNull(_)
650            | Expr::IsTrue(_)
651            | Expr::IsFalse(_)
652            | Expr::IsUnknown(_)
653            | Expr::IsNotTrue(_)
654            | Expr::IsNotFalse(_)
655            | Expr::IsNotUnknown(_)
656            | Expr::Negative(_)
657            | Expr::Between { .. }
658            | Expr::Like { .. }
659            | Expr::SimilarTo { .. }
660            | Expr::Case(_)
661            | Expr::Cast { .. }
662            | Expr::TryCast { .. }
663            | Expr::InList { .. } => true,
664        }
665    }
666
667    /// Internal helper to evaluates an Expr
668    pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult {
669        if let Expr::Literal(s, m) = expr {
670            return ConstSimplifyResult::NotSimplified(s, m);
671        }
672
673        let phys_expr =
674            match create_physical_expr(&expr, &self.input_schema, self.execution_props) {
675                Ok(e) => e,
676                Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
677            };
678        let metadata = phys_expr
679            .return_field(self.input_batch.schema_ref())
680            .ok()
681            .and_then(|f| {
682                let m = f.metadata();
683                match m.is_empty() {
684                    true => None,
685                    false => {
686                        Some(m.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
687                    }
688                }
689            });
690        let col_val = match phys_expr.evaluate(&self.input_batch) {
691            Ok(v) => v,
692            Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
693        };
694        match col_val {
695            ColumnarValue::Array(a) => {
696                if a.len() != 1 {
697                    ConstSimplifyResult::SimplifyRuntimeError(
698                        DataFusionError::Execution(format!("Could not evaluate the expression, found a result of length {}", a.len())),
699                        expr,
700                    )
701                } else if as_list_array(&a).is_ok() {
702                    ConstSimplifyResult::Simplified(
703                        ScalarValue::List(a.as_list::<i32>().to_owned().into()),
704                        metadata,
705                    )
706                } else if as_large_list_array(&a).is_ok() {
707                    ConstSimplifyResult::Simplified(
708                        ScalarValue::LargeList(a.as_list::<i64>().to_owned().into()),
709                        metadata,
710                    )
711                } else {
712                    // Non-ListArray
713                    match ScalarValue::try_from_array(&a, 0) {
714                        Ok(s) => {
715                            // TODO: support the optimization for `Map` type after support impl hash for it
716                            if matches!(&s, ScalarValue::Map(_)) {
717                                ConstSimplifyResult::SimplifyRuntimeError(
718                                    DataFusionError::NotImplemented("Const evaluate for Map type is still not supported".to_string()),
719                                    expr,
720                                )
721                            } else {
722                                ConstSimplifyResult::Simplified(s, metadata)
723                            }
724                        }
725                        Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr),
726                    }
727                }
728            }
729            ColumnarValue::Scalar(s) => {
730                // TODO: support the optimization for `Map` type after support impl hash for it
731                if matches!(&s, ScalarValue::Map(_)) {
732                    ConstSimplifyResult::SimplifyRuntimeError(
733                        DataFusionError::NotImplemented(
734                            "Const evaluate for Map type is still not supported"
735                                .to_string(),
736                        ),
737                        expr,
738                    )
739                } else {
740                    ConstSimplifyResult::Simplified(s, metadata)
741                }
742            }
743        }
744    }
745}
746
/// Rewriter that simplifies [`Expr`]s by applying algebraic transformation
/// rules.
///
/// Example transformations that are applied:
/// * `expr = true` and `expr != false` to `expr` when `expr` is of boolean type
/// * `expr = false` and `expr != true` to `!expr` when `expr` is of boolean type
/// * `true = true` and `false = false` to `true`
/// * `false = true` and `true = false` to `false`
/// * `!!expr` to `expr`
/// * `expr = null` and `expr != null` to `null`
struct Simplifier<'a, S> {
    /// Borrowed information provider consulted by the rewrite rules.
    info: &'a S,
}

impl<'a, S> Simplifier<'a, S> {
    /// Creates a simplifier that borrows `info` for its lifetime.
    pub fn new(info: &'a S) -> Self {
        Simplifier { info }
    }
}
765
766impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
767    type Node = Expr;
768
769    /// rewrite the expression simplifying any constant expressions
770    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
771        use datafusion_expr::Operator::{
772            And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor,
773            Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch,
774            RegexNotIMatch, RegexNotMatch,
775        };
776
777        let info = self.info;
778        Ok(match expr {
779            //
780            // Rules for Eq
781            //
782
783            // true = A  --> A
784            // false = A --> !A
785            // null = A --> null
786            Expr::BinaryExpr(BinaryExpr {
787                left,
788                op: Eq,
789                right,
790            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
791                Transformed::yes(match as_bool_lit(&left)? {
792                    Some(true) => *right,
793                    Some(false) => Expr::Not(right),
794                    None => lit_bool_null(),
795                })
796            }
797            // A = true  --> A
798            // A = false --> !A
799            // A = null --> null
800            Expr::BinaryExpr(BinaryExpr {
801                left,
802                op: Eq,
803                right,
804            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
805                Transformed::yes(match as_bool_lit(&right)? {
806                    Some(true) => *left,
807                    Some(false) => Expr::Not(left),
808                    None => lit_bool_null(),
809                })
810            }
811            // According to SQL's null semantics, NULL = NULL evaluates to NULL
812            // Both sides are the same expression (A = A) and A is non-volatile expression
813            // A = A --> A IS NOT NULL OR NULL
814            // A = A --> true (if A not nullable)
815            Expr::BinaryExpr(BinaryExpr {
816                left,
817                op: Eq,
818                right,
819            }) if (left == right) & !left.is_volatile() => {
820                Transformed::yes(match !info.nullable(&left)? {
821                    true => lit(true),
822                    false => Expr::BinaryExpr(BinaryExpr {
823                        left: Box::new(Expr::IsNotNull(left)),
824                        op: Or,
825                        right: Box::new(lit_bool_null()),
826                    }),
827                })
828            }
829
830            // Rules for NotEq
831            //
832
833            // true != A  --> !A
834            // false != A --> A
835            // null != A --> null
836            Expr::BinaryExpr(BinaryExpr {
837                left,
838                op: NotEq,
839                right,
840            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
841                Transformed::yes(match as_bool_lit(&left)? {
842                    Some(true) => Expr::Not(right),
843                    Some(false) => *right,
844                    None => lit_bool_null(),
845                })
846            }
847            // A != true  --> !A
848            // A != false --> A
849            // A != null --> null,
850            Expr::BinaryExpr(BinaryExpr {
851                left,
852                op: NotEq,
853                right,
854            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
855                Transformed::yes(match as_bool_lit(&right)? {
856                    Some(true) => Expr::Not(left),
857                    Some(false) => *left,
858                    None => lit_bool_null(),
859                })
860            }
861
862            //
863            // Rules for OR
864            //
865
866            // true OR A --> true (even if A is null)
867            Expr::BinaryExpr(BinaryExpr {
868                left,
869                op: Or,
870                right: _,
871            }) if is_true(&left) => Transformed::yes(*left),
872            // false OR A --> A
873            Expr::BinaryExpr(BinaryExpr {
874                left,
875                op: Or,
876                right,
877            }) if is_false(&left) => Transformed::yes(*right),
878            // A OR true --> true (even if A is null)
879            Expr::BinaryExpr(BinaryExpr {
880                left: _,
881                op: Or,
882                right,
883            }) if is_true(&right) => Transformed::yes(*right),
884            // A OR false --> A
885            Expr::BinaryExpr(BinaryExpr {
886                left,
887                op: Or,
888                right,
889            }) if is_false(&right) => Transformed::yes(*left),
890            // A OR !A ---> true (if A not nullable)
891            Expr::BinaryExpr(BinaryExpr {
892                left,
893                op: Or,
894                right,
895            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
896                Transformed::yes(lit(true))
897            }
898            // !A OR A ---> true (if A not nullable)
899            Expr::BinaryExpr(BinaryExpr {
900                left,
901                op: Or,
902                right,
903            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
904                Transformed::yes(lit(true))
905            }
906            // (..A..) OR A --> (..A..)
907            Expr::BinaryExpr(BinaryExpr {
908                left,
909                op: Or,
910                right,
911            }) if expr_contains(&left, &right, Or) => Transformed::yes(*left),
912            // A OR (..A..) --> (..A..)
913            Expr::BinaryExpr(BinaryExpr {
914                left,
915                op: Or,
916                right,
917            }) if expr_contains(&right, &left, Or) => Transformed::yes(*right),
918            // A OR (A AND B) --> A
919            Expr::BinaryExpr(BinaryExpr {
920                left,
921                op: Or,
922                right,
923            }) if is_op_with(And, &right, &left) => Transformed::yes(*left),
924            // (A AND B) OR A --> A
925            Expr::BinaryExpr(BinaryExpr {
926                left,
927                op: Or,
928                right,
929            }) if is_op_with(And, &left, &right) => Transformed::yes(*right),
930            // Eliminate common factors in conjunctions e.g
931            // (A AND B) OR (A AND C) -> A AND (B OR C)
932            Expr::BinaryExpr(BinaryExpr {
933                left,
934                op: Or,
935                right,
936            }) if has_common_conjunction(&left, &right) => {
937                let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect();
938                let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right)
939                    .partition(|e| lhs.contains(e) && !e.is_volatile());
940
941                let new_rhs = rhs.into_iter().reduce(and);
942                let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and);
943                let common_conjunction = common.into_iter().reduce(and).unwrap();
944
945                let new_expr = match (new_lhs, new_rhs) {
946                    (Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)),
947                    (_, _) => common_conjunction,
948                };
949                Transformed::yes(new_expr)
950            }
951
952            //
953            // Rules for AND
954            //
955
956            // true AND A --> A
957            Expr::BinaryExpr(BinaryExpr {
958                left,
959                op: And,
960                right,
961            }) if is_true(&left) => Transformed::yes(*right),
962            // false AND A --> false (even if A is null)
963            Expr::BinaryExpr(BinaryExpr {
964                left,
965                op: And,
966                right: _,
967            }) if is_false(&left) => Transformed::yes(*left),
968            // A AND true --> A
969            Expr::BinaryExpr(BinaryExpr {
970                left,
971                op: And,
972                right,
973            }) if is_true(&right) => Transformed::yes(*left),
974            // A AND false --> false (even if A is null)
975            Expr::BinaryExpr(BinaryExpr {
976                left: _,
977                op: And,
978                right,
979            }) if is_false(&right) => Transformed::yes(*right),
980            // A AND !A ---> false (if A not nullable)
981            Expr::BinaryExpr(BinaryExpr {
982                left,
983                op: And,
984                right,
985            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
986                Transformed::yes(lit(false))
987            }
988            // !A AND A ---> false (if A not nullable)
989            Expr::BinaryExpr(BinaryExpr {
990                left,
991                op: And,
992                right,
993            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
994                Transformed::yes(lit(false))
995            }
996            // (..A..) AND A --> (..A..)
997            Expr::BinaryExpr(BinaryExpr {
998                left,
999                op: And,
1000                right,
1001            }) if expr_contains(&left, &right, And) => Transformed::yes(*left),
1002            // A AND (..A..) --> (..A..)
1003            Expr::BinaryExpr(BinaryExpr {
1004                left,
1005                op: And,
1006                right,
1007            }) if expr_contains(&right, &left, And) => Transformed::yes(*right),
1008            // A AND (A OR B) --> A
1009            Expr::BinaryExpr(BinaryExpr {
1010                left,
1011                op: And,
1012                right,
1013            }) if is_op_with(Or, &right, &left) => Transformed::yes(*left),
1014            // (A OR B) AND A --> A
1015            Expr::BinaryExpr(BinaryExpr {
1016                left,
1017                op: And,
1018                right,
1019            }) if is_op_with(Or, &left, &right) => Transformed::yes(*right),
            // A >= constant AND A <= constant --> A = constant
1021            Expr::BinaryExpr(BinaryExpr {
1022                left,
1023                op: And,
1024                right,
1025            }) if can_reduce_to_equal_statement(&left, &right) => {
1026                if let Expr::BinaryExpr(BinaryExpr {
1027                    left: left_left,
1028                    right: left_right,
1029                    ..
1030                }) = *left
1031                {
1032                    Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1033                        left: left_left,
1034                        op: Eq,
1035                        right: left_right,
1036                    }))
1037                } else {
1038                    return internal_err!("can_reduce_to_equal_statement should only be called with a BinaryExpr");
1039                }
1040            }
1041
1042            //
1043            // Rules for Multiply
1044            //
1045
1046            // A * 1 --> A (with type coercion if needed)
1047            Expr::BinaryExpr(BinaryExpr {
1048                left,
1049                op: Multiply,
1050                right,
1051            }) if is_one(&right) => {
1052                simplify_right_is_one_case(info, left, &Multiply, &right)?
1053            }
1054            // A * null --> null
1055            Expr::BinaryExpr(BinaryExpr {
1056                left,
1057                op: Multiply,
1058                right,
1059            }) if is_null(&right) => {
1060                simplify_right_is_null_case(info, &left, &Multiply, right)?
1061            }
1062            // 1 * A --> A
1063            Expr::BinaryExpr(BinaryExpr {
1064                left,
1065                op: Multiply,
1066                right,
1067            }) if is_one(&left) => {
1068                // 1 * A is equivalent to A * 1
1069                simplify_right_is_one_case(info, right, &Multiply, &left)?
1070            }
1071            // null * A --> null
1072            Expr::BinaryExpr(BinaryExpr {
1073                left,
1074                op: Multiply,
1075                right,
1076            }) if is_null(&left) => {
1077                simplify_right_is_null_case(info, &right, &Multiply, left)?
1078            }
1079
1080            // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN)
1081            Expr::BinaryExpr(BinaryExpr {
1082                left,
1083                op: Multiply,
1084                right,
1085            }) if !info.nullable(&left)?
1086                && !info.get_data_type(&left)?.is_floating()
1087                && is_zero(&right) =>
1088            {
1089                Transformed::yes(*right)
1090            }
1091            // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN)
1092            Expr::BinaryExpr(BinaryExpr {
1093                left,
1094                op: Multiply,
1095                right,
1096            }) if !info.nullable(&right)?
1097                && !info.get_data_type(&right)?.is_floating()
1098                && is_zero(&left) =>
1099            {
1100                Transformed::yes(*left)
1101            }
1102
1103            //
1104            // Rules for Divide
1105            //
1106
1107            // A / 1 --> A
1108            Expr::BinaryExpr(BinaryExpr {
1109                left,
1110                op: Divide,
1111                right,
1112            }) if is_one(&right) => {
1113                simplify_right_is_one_case(info, left, &Divide, &right)?
1114            }
1115            // A / null --> null
1116            Expr::BinaryExpr(BinaryExpr {
1117                left,
1118                op: Divide,
1119                right,
1120            }) if is_null(&right) => {
1121                simplify_right_is_null_case(info, &left, &Divide, right)?
1122            }
1123            // null / A --> null
1124            Expr::BinaryExpr(BinaryExpr {
1125                left,
1126                op: Divide,
1127                right,
1128            }) if is_null(&left) => simplify_null_div_other_case(info, left, &right)?,
1129
1130            //
1131            // Rules for Modulo
1132            //
1133
1134            // A % null --> null
1135            Expr::BinaryExpr(BinaryExpr {
1136                left: _,
1137                op: Modulo,
1138                right,
1139            }) if is_null(&right) => Transformed::yes(*right),
1140            // null % A --> null
1141            Expr::BinaryExpr(BinaryExpr {
1142                left,
1143                op: Modulo,
1144                right: _,
1145            }) if is_null(&left) => Transformed::yes(*left),
1146            // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN)
1147            Expr::BinaryExpr(BinaryExpr {
1148                left,
1149                op: Modulo,
1150                right,
1151            }) if !info.nullable(&left)?
1152                && !info.get_data_type(&left)?.is_floating()
1153                && is_one(&right) =>
1154            {
1155                Transformed::yes(Expr::Literal(
1156                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1157                    None,
1158                ))
1159            }
1160
1161            //
1162            // Rules for BitwiseAnd
1163            //
1164
1165            // A & null -> null
1166            Expr::BinaryExpr(BinaryExpr {
1167                left: _,
1168                op: BitwiseAnd,
1169                right,
1170            }) if is_null(&right) => Transformed::yes(*right),
1171
1172            // null & A -> null
1173            Expr::BinaryExpr(BinaryExpr {
1174                left,
1175                op: BitwiseAnd,
1176                right: _,
1177            }) if is_null(&left) => Transformed::yes(*left),
1178
1179            // A & 0 -> 0 (if A not nullable)
1180            Expr::BinaryExpr(BinaryExpr {
1181                left,
1182                op: BitwiseAnd,
1183                right,
1184            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right),
1185
1186            // 0 & A -> 0 (if A not nullable)
1187            Expr::BinaryExpr(BinaryExpr {
1188                left,
1189                op: BitwiseAnd,
1190                right,
1191            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left),
1192
1193            // !A & A -> 0 (if A not nullable)
1194            Expr::BinaryExpr(BinaryExpr {
1195                left,
1196                op: BitwiseAnd,
1197                right,
1198            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1199                Transformed::yes(Expr::Literal(
1200                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1201                    None,
1202                ))
1203            }
1204
1205            // A & !A -> 0 (if A not nullable)
1206            Expr::BinaryExpr(BinaryExpr {
1207                left,
1208                op: BitwiseAnd,
1209                right,
1210            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1211                Transformed::yes(Expr::Literal(
1212                    ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1213                    None,
1214                ))
1215            }
1216
1217            // (..A..) & A --> (..A..)
1218            Expr::BinaryExpr(BinaryExpr {
1219                left,
1220                op: BitwiseAnd,
1221                right,
1222            }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left),
1223
1224            // A & (..A..) --> (..A..)
1225            Expr::BinaryExpr(BinaryExpr {
1226                left,
1227                op: BitwiseAnd,
1228                right,
1229            }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right),
1230
1231            // A & (A | B) --> A (if B not null)
1232            Expr::BinaryExpr(BinaryExpr {
1233                left,
1234                op: BitwiseAnd,
1235                right,
1236            }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => {
1237                Transformed::yes(*left)
1238            }
1239
1240            // (A | B) & A --> A (if B not null)
1241            Expr::BinaryExpr(BinaryExpr {
1242                left,
1243                op: BitwiseAnd,
1244                right,
1245            }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => {
1246                Transformed::yes(*right)
1247            }
1248
1249            //
1250            // Rules for BitwiseOr
1251            //
1252
1253            // A | null -> null
1254            Expr::BinaryExpr(BinaryExpr {
1255                left: _,
1256                op: BitwiseOr,
1257                right,
1258            }) if is_null(&right) => Transformed::yes(*right),
1259
1260            // null | A -> null
1261            Expr::BinaryExpr(BinaryExpr {
1262                left,
1263                op: BitwiseOr,
1264                right: _,
1265            }) if is_null(&left) => Transformed::yes(*left),
1266
1267            // A | 0 -> A (even if A is null)
1268            Expr::BinaryExpr(BinaryExpr {
1269                left,
1270                op: BitwiseOr,
1271                right,
1272            }) if is_zero(&right) => Transformed::yes(*left),
1273
1274            // 0 | A -> A (even if A is null)
1275            Expr::BinaryExpr(BinaryExpr {
1276                left,
1277                op: BitwiseOr,
1278                right,
1279            }) if is_zero(&left) => Transformed::yes(*right),
1280
1281            // !A | A -> -1 (if A not nullable)
1282            Expr::BinaryExpr(BinaryExpr {
1283                left,
1284                op: BitwiseOr,
1285                right,
1286            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1287                Transformed::yes(Expr::Literal(
1288                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1289                    None,
1290                ))
1291            }
1292
1293            // A | !A -> -1 (if A not nullable)
1294            Expr::BinaryExpr(BinaryExpr {
1295                left,
1296                op: BitwiseOr,
1297                right,
1298            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1299                Transformed::yes(Expr::Literal(
1300                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1301                    None,
1302                ))
1303            }
1304
1305            // (..A..) | A --> (..A..)
1306            Expr::BinaryExpr(BinaryExpr {
1307                left,
1308                op: BitwiseOr,
1309                right,
1310            }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left),
1311
1312            // A | (..A..) --> (..A..)
1313            Expr::BinaryExpr(BinaryExpr {
1314                left,
1315                op: BitwiseOr,
1316                right,
1317            }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right),
1318
1319            // A | (A & B) --> A (if B not null)
1320            Expr::BinaryExpr(BinaryExpr {
1321                left,
1322                op: BitwiseOr,
1323                right,
1324            }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => {
1325                Transformed::yes(*left)
1326            }
1327
1328            // (A & B) | A --> A (if B not null)
1329            Expr::BinaryExpr(BinaryExpr {
1330                left,
1331                op: BitwiseOr,
1332                right,
1333            }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => {
1334                Transformed::yes(*right)
1335            }
1336
1337            //
1338            // Rules for BitwiseXor
1339            //
1340
1341            // A ^ null -> null
1342            Expr::BinaryExpr(BinaryExpr {
1343                left: _,
1344                op: BitwiseXor,
1345                right,
1346            }) if is_null(&right) => Transformed::yes(*right),
1347
1348            // null ^ A -> null
1349            Expr::BinaryExpr(BinaryExpr {
1350                left,
1351                op: BitwiseXor,
1352                right: _,
1353            }) if is_null(&left) => Transformed::yes(*left),
1354
1355            // A ^ 0 -> A (if A not nullable)
1356            Expr::BinaryExpr(BinaryExpr {
1357                left,
1358                op: BitwiseXor,
1359                right,
1360            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left),
1361
1362            // 0 ^ A -> A (if A not nullable)
1363            Expr::BinaryExpr(BinaryExpr {
1364                left,
1365                op: BitwiseXor,
1366                right,
1367            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right),
1368
1369            // !A ^ A -> -1 (if A not nullable)
1370            Expr::BinaryExpr(BinaryExpr {
1371                left,
1372                op: BitwiseXor,
1373                right,
1374            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1375                Transformed::yes(Expr::Literal(
1376                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1377                    None,
1378                ))
1379            }
1380
1381            // A ^ !A -> -1 (if A not nullable)
1382            Expr::BinaryExpr(BinaryExpr {
1383                left,
1384                op: BitwiseXor,
1385                right,
1386            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1387                Transformed::yes(Expr::Literal(
1388                    ScalarValue::new_negative_one(&info.get_data_type(&left)?)?,
1389                    None,
1390                ))
1391            }
1392
1393            // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A)
1394            Expr::BinaryExpr(BinaryExpr {
1395                left,
1396                op: BitwiseXor,
1397                right,
1398            }) if expr_contains(&left, &right, BitwiseXor) => {
1399                let expr = delete_xor_in_complex_expr(&left, &right, false);
1400                Transformed::yes(if expr == *right {
1401                    Expr::Literal(
1402                        ScalarValue::new_zero(&info.get_data_type(&right)?)?,
1403                        None,
1404                    )
1405                } else {
1406                    expr
1407                })
1408            }
1409
1410            // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A)
1411            Expr::BinaryExpr(BinaryExpr {
1412                left,
1413                op: BitwiseXor,
1414                right,
1415            }) if expr_contains(&right, &left, BitwiseXor) => {
1416                let expr = delete_xor_in_complex_expr(&right, &left, true);
1417                Transformed::yes(if expr == *left {
1418                    Expr::Literal(
1419                        ScalarValue::new_zero(&info.get_data_type(&left)?)?,
1420                        None,
1421                    )
1422                } else {
1423                    expr
1424                })
1425            }
1426
1427            //
1428            // Rules for BitwiseShiftRight
1429            //
1430
1431            // A >> null -> null
1432            Expr::BinaryExpr(BinaryExpr {
1433                left: _,
1434                op: BitwiseShiftRight,
1435                right,
1436            }) if is_null(&right) => Transformed::yes(*right),
1437
1438            // null >> A -> null
1439            Expr::BinaryExpr(BinaryExpr {
1440                left,
1441                op: BitwiseShiftRight,
1442                right: _,
1443            }) if is_null(&left) => Transformed::yes(*left),
1444
1445            // A >> 0 -> A (even if A is null)
1446            Expr::BinaryExpr(BinaryExpr {
1447                left,
1448                op: BitwiseShiftRight,
1449                right,
1450            }) if is_zero(&right) => Transformed::yes(*left),
1451
1452            //
            // Rules for BitwiseShiftLeft
1454            //
1455
1456            // A << null -> null
1457            Expr::BinaryExpr(BinaryExpr {
1458                left: _,
1459                op: BitwiseShiftLeft,
1460                right,
1461            }) if is_null(&right) => Transformed::yes(*right),
1462
1463            // null << A -> null
1464            Expr::BinaryExpr(BinaryExpr {
1465                left,
1466                op: BitwiseShiftLeft,
1467                right: _,
1468            }) if is_null(&left) => Transformed::yes(*left),
1469
1470            // A << 0 -> A (even if A is null)
1471            Expr::BinaryExpr(BinaryExpr {
1472                left,
1473                op: BitwiseShiftLeft,
1474                right,
1475            }) if is_zero(&right) => Transformed::yes(*left),
1476
1477            //
1478            // Rules for Not
1479            //
1480            Expr::Not(inner) => Transformed::yes(negate_clause(*inner)),
1481
1482            //
1483            // Rules for Negative
1484            //
1485            Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)),
1486
1487            //
1488            // Rules for Case
1489            //
1490
1491            // CASE
1492            //   WHEN X THEN A
1493            //   WHEN Y THEN B
1494            //   ...
1495            //   ELSE Q
1496            // END
1497            //
1498            // ---> (X AND A) OR (Y AND B AND NOT X) OR ... (NOT (X OR Y) AND Q)
1499            //
1500            // Note: the rationale for this rewrite is that the expr can then be further
1501            // simplified using the existing rules for AND/OR
1502            Expr::Case(Case {
1503                expr: None,
1504                when_then_expr,
1505                else_expr,
1506            }) if !when_then_expr.is_empty()
1507                && when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number
1508                && info.is_boolean_type(&when_then_expr[0].1)? =>
1509            {
                // Disjunction of all the `when` predicates encountered so far. Not nullable.
1511                let mut filter_expr = lit(false);
1512                // The disjunction of all the cases
1513                let mut out_expr = lit(false);
1514
1515                for (when, then) in when_then_expr {
1516                    let when = is_exactly_true(*when, info)?;
1517                    let case_expr =
1518                        when.clone().and(filter_expr.clone().not()).and(*then);
1519
1520                    out_expr = out_expr.or(case_expr);
1521                    filter_expr = filter_expr.or(when);
1522                }
1523
1524                let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
1525                let case_expr = filter_expr.not().and(else_expr);
1526                out_expr = out_expr.or(case_expr);
1527
1528                // Do a first pass at simplification
1529                out_expr.rewrite(self)?
1530            }
1531            Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
1532                match udf.simplify(args, info)? {
1533                    ExprSimplifyResult::Original(args) => {
1534                        Transformed::no(Expr::ScalarFunction(ScalarFunction {
1535                            func: udf,
1536                            args,
1537                        }))
1538                    }
1539                    ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
1540                }
1541            }
1542
1543            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
1544                ref func,
1545                ..
1546            }) => match (func.simplify(), expr) {
1547                (Some(simplify_function), Expr::AggregateFunction(af)) => {
1548                    Transformed::yes(simplify_function(af, info)?)
1549                }
1550                (_, expr) => Transformed::no(expr),
1551            },
1552
1553            Expr::WindowFunction(ref window_fun) => match (window_fun.simplify(), expr) {
1554                (Some(simplify_function), Expr::WindowFunction(wf)) => {
1555                    Transformed::yes(simplify_function(*wf, info)?)
1556                }
1557                (_, expr) => Transformed::no(expr),
1558            },
1559
1560            //
1561            // Rules for Between
1562            //
1563
1564            // a between 3 and 5  -->  a >= 3 AND a <=5
1565            // a not between 3 and 5  -->  a < 3 OR a > 5
1566            Expr::Between(between) => Transformed::yes(if between.negated {
1567                let l = *between.expr.clone();
1568                let r = *between.expr;
1569                or(l.lt(*between.low), r.gt(*between.high))
1570            } else {
1571                and(
1572                    between.expr.clone().gt_eq(*between.low),
1573                    between.expr.lt_eq(*between.high),
1574                )
1575            }),
1576
1577            //
1578            // Rules for regexes
1579            //
1580            Expr::BinaryExpr(BinaryExpr {
1581                left,
1582                op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch),
1583                right,
1584            }) => Transformed::yes(simplify_regex_expr(left, op, right)?),
1585
1586            // Rules for Like
1587            Expr::Like(like) => {
1588                // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291
1589                let escape_char = like.escape_char.unwrap_or('\\');
1590                match as_string_scalar(&like.pattern) {
1591                    Some((data_type, pattern_str)) => {
1592                        match pattern_str {
1593                            None => return Ok(Transformed::yes(lit_bool_null())),
1594                            Some(pattern_str) if pattern_str == "%" => {
1595                                // exp LIKE '%' is
1596                                //   - when exp is not NULL, it's true
1597                                //   - when exp is NULL, it's NULL
1598                                // exp NOT LIKE '%' is
1599                                //   - when exp is not NULL, it's false
1600                                //   - when exp is NULL, it's NULL
1601                                let result_for_non_null = lit(!like.negated);
1602                                Transformed::yes(if !info.nullable(&like.expr)? {
1603                                    result_for_non_null
1604                                } else {
1605                                    Expr::Case(Case {
1606                                        expr: Some(Box::new(Expr::IsNotNull(like.expr))),
1607                                        when_then_expr: vec![(
1608                                            Box::new(lit(true)),
1609                                            Box::new(result_for_non_null),
1610                                        )],
1611                                        else_expr: None,
1612                                    })
1613                                })
1614                            }
1615                            Some(pattern_str)
1616                                if pattern_str.contains("%%")
1617                                    && !pattern_str.contains(escape_char) =>
1618                            {
1619                                // Repeated occurrences of wildcard are redundant so remove them
1620                                // exp LIKE '%%'  --> exp LIKE '%'
1621                                let simplified_pattern = Regex::new("%%+")
1622                                    .unwrap()
1623                                    .replace_all(pattern_str, "%")
1624                                    .to_string();
1625                                Transformed::yes(Expr::Like(Like {
1626                                    pattern: Box::new(to_string_scalar(
1627                                        data_type,
1628                                        Some(simplified_pattern),
1629                                    )),
1630                                    ..like
1631                                }))
1632                            }
1633                            Some(pattern_str)
1634                                if !like.case_insensitive
1635                                    && !pattern_str
1636                                        .contains(['%', '_', escape_char].as_ref()) =>
1637                            {
1638                                // If the pattern does not contain any wildcards, we can simplify the like expression to an equality expression
1639                                // TODO: handle escape characters
1640                                Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1641                                    left: like.expr.clone(),
1642                                    op: if like.negated { NotEq } else { Eq },
1643                                    right: like.pattern.clone(),
1644                                }))
1645                            }
1646
1647                            Some(_pattern_str) => Transformed::no(Expr::Like(like)),
1648                        }
1649                    }
1650                    None => Transformed::no(Expr::Like(like)),
1651                }
1652            }
1653
1654            // a is not null/unknown --> true (if a is not nullable)
1655            Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr)
1656                if !info.nullable(&expr)? =>
1657            {
1658                Transformed::yes(lit(true))
1659            }
1660
1661            // a is null/unknown --> false (if a is not nullable)
1662            Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => {
1663                Transformed::yes(lit(false))
1664            }
1665
1666            // expr IN () --> false
1667            // expr NOT IN () --> true
1668            Expr::InList(InList {
1669                expr,
1670                list,
1671                negated,
1672            }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null, None) => {
1673                Transformed::yes(lit(negated))
1674            }
1675
1676            // null in (x, y, z) --> null
1677            // null not in (x, y, z) --> null
1678            Expr::InList(InList {
1679                expr,
1680                list: _,
1681                negated: _,
1682            }) if is_null(expr.as_ref()) => Transformed::yes(lit_bool_null()),
1683
1684            // expr IN ((subquery)) -> expr IN (subquery), see ##5529
1685            Expr::InList(InList {
1686                expr,
1687                mut list,
1688                negated,
1689            }) if list.len() == 1
1690                && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) =>
1691            {
1692                let Expr::ScalarSubquery(subquery) = list.remove(0) else {
1693                    unreachable!()
1694                };
1695
1696                Transformed::yes(Expr::InSubquery(InSubquery::new(
1697                    expr, subquery, negated,
1698                )))
1699            }
1700
1701            // Combine multiple OR expressions into a single IN list expression if possible
1702            //
1703            // i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)`
1704            Expr::BinaryExpr(BinaryExpr {
1705                left,
1706                op: Or,
1707                right,
1708            }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => {
1709                let lhs = to_inlist(*left).unwrap();
1710                let rhs = to_inlist(*right).unwrap();
1711                let mut seen: HashSet<Expr> = HashSet::new();
1712                let list = lhs
1713                    .list
1714                    .into_iter()
1715                    .chain(rhs.list)
1716                    .filter(|e| seen.insert(e.to_owned()))
1717                    .collect::<Vec<_>>();
1718
1719                let merged_inlist = InList {
1720                    expr: lhs.expr,
1721                    list,
1722                    negated: false,
1723                };
1724
1725                Transformed::yes(Expr::InList(merged_inlist))
1726            }
1727
1728            // Simplify expressions that is guaranteed to be true or false to a literal boolean expression
1729            //
1730            // Rules:
1731            // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists
1732            //   Intersection:
1733            //     1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false`
1734            //     2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)`
1735            //     3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)`
1736            //   Union:
1737            //     4. `a not int (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)`
1738            //     # This rule is handled by `or_in_list_simplifier.rs`
1739            //     5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)`
1740            // If one of the expressions is `IN` and another one is `NOT IN`, then we apply exception on `In` expression
1741            //     6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false`
1742            //     7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5`
1743            //     8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)`
1744            Expr::BinaryExpr(BinaryExpr {
1745                left,
1746                op: And,
1747                right,
1748            }) if are_inlist_and_eq_and_match_neg(
1749                left.as_ref(),
1750                right.as_ref(),
1751                false,
1752                false,
1753            ) =>
1754            {
1755                match (*left, *right) {
1756                    (Expr::InList(l1), Expr::InList(l2)) => {
1757                        return inlist_intersection(l1, &l2, false).map(Transformed::yes);
1758                    }
1759                    // Matched previously once
1760                    _ => unreachable!(),
1761                }
1762            }
1763
1764            Expr::BinaryExpr(BinaryExpr {
1765                left,
1766                op: And,
1767                right,
1768            }) if are_inlist_and_eq_and_match_neg(
1769                left.as_ref(),
1770                right.as_ref(),
1771                true,
1772                true,
1773            ) =>
1774            {
1775                match (*left, *right) {
1776                    (Expr::InList(l1), Expr::InList(l2)) => {
1777                        return inlist_union(l1, l2, true).map(Transformed::yes);
1778                    }
1779                    // Matched previously once
1780                    _ => unreachable!(),
1781                }
1782            }
1783
1784            Expr::BinaryExpr(BinaryExpr {
1785                left,
1786                op: And,
1787                right,
1788            }) if are_inlist_and_eq_and_match_neg(
1789                left.as_ref(),
1790                right.as_ref(),
1791                false,
1792                true,
1793            ) =>
1794            {
1795                match (*left, *right) {
1796                    (Expr::InList(l1), Expr::InList(l2)) => {
1797                        return inlist_except(l1, &l2).map(Transformed::yes);
1798                    }
1799                    // Matched previously once
1800                    _ => unreachable!(),
1801                }
1802            }
1803
1804            Expr::BinaryExpr(BinaryExpr {
1805                left,
1806                op: And,
1807                right,
1808            }) if are_inlist_and_eq_and_match_neg(
1809                left.as_ref(),
1810                right.as_ref(),
1811                true,
1812                false,
1813            ) =>
1814            {
1815                match (*left, *right) {
1816                    (Expr::InList(l1), Expr::InList(l2)) => {
1817                        return inlist_except(l2, &l1).map(Transformed::yes);
1818                    }
1819                    // Matched previously once
1820                    _ => unreachable!(),
1821                }
1822            }
1823
1824            Expr::BinaryExpr(BinaryExpr {
1825                left,
1826                op: Or,
1827                right,
1828            }) if are_inlist_and_eq_and_match_neg(
1829                left.as_ref(),
1830                right.as_ref(),
1831                true,
1832                true,
1833            ) =>
1834            {
1835                match (*left, *right) {
1836                    (Expr::InList(l1), Expr::InList(l2)) => {
1837                        return inlist_intersection(l1, &l2, true).map(Transformed::yes);
1838                    }
1839                    // Matched previously once
1840                    _ => unreachable!(),
1841                }
1842            }
1843
1844            // =======================================
1845            // unwrap_cast_in_comparison
1846            // =======================================
1847            //
1848            // For case:
1849            // try_cast/cast(expr as data_type) op literal
1850            Expr::BinaryExpr(BinaryExpr { left, op, right })
1851                if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1852                    info, &left, op, &right,
1853                ) && op.supports_propagation() =>
1854            {
1855                unwrap_cast_in_comparison_for_binary(info, *left, *right, op)?
1856            }
1857            // literal op try_cast/cast(expr as data_type)
1858            // -->
1859            // try_cast/cast(expr as data_type) op_swap literal
1860            Expr::BinaryExpr(BinaryExpr { left, op, right })
1861                if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1862                    info, &right, op, &left,
1863                ) && op.supports_propagation()
1864                    && op.swap().is_some() =>
1865            {
1866                unwrap_cast_in_comparison_for_binary(
1867                    info,
1868                    *right,
1869                    *left,
1870                    op.swap().unwrap(),
1871                )?
1872            }
1873            // For case:
1874            // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
1875            Expr::InList(InList {
1876                expr: mut left,
1877                list,
1878                negated,
1879            }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
1880                info, &left, &list,
1881            ) =>
1882            {
1883                let (Expr::TryCast(TryCast {
1884                    expr: left_expr, ..
1885                })
1886                | Expr::Cast(Cast {
1887                    expr: left_expr, ..
1888                })) = left.as_mut()
1889                else {
1890                    return internal_err!("Expect cast expr, but got {:?}", left)?;
1891                };
1892
1893                let expr_type = info.get_data_type(left_expr)?;
1894                let right_exprs = list
1895                    .into_iter()
1896                    .map(|right| {
1897                        match right {
1898                            Expr::Literal(right_lit_value, _) => {
1899                                // if the right_lit_value can be casted to the type of internal_left_expr
1900                                // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
1901                                let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else {
1902                                    internal_err!(
1903                                        "Can't cast the list expr {:?} to type {:?}",
1904                                        right_lit_value, &expr_type
1905                                    )?
1906                                };
1907                                Ok(lit(value))
1908                            }
1909                            other_expr => internal_err!(
1910                                "Only support literal expr to optimize, but the expr is {:?}",
1911                                &other_expr
1912                            ),
1913                        }
1914                    })
1915                    .collect::<Result<Vec<_>>>()?;
1916
1917                Transformed::yes(Expr::InList(InList {
1918                    expr: std::mem::take(left_expr),
1919                    list: right_exprs,
1920                    negated,
1921                }))
1922            }
1923
1924            // no additional rewrites possible
1925            expr => Transformed::no(expr),
1926        })
1927    }
1928}
1929
1930fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option<String>)> {
1931    match expr {
1932        Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)),
1933        Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)),
1934        Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)),
1935        _ => None,
1936    }
1937}
1938
1939fn to_string_scalar(data_type: DataType, value: Option<String>) -> Expr {
1940    match data_type {
1941        DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None),
1942        DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None),
1943        DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None),
1944        _ => unreachable!(),
1945    }
1946}
1947
1948fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
1949    let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect();
1950    iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
1951}
1952
1953// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
1954fn are_inlist_and_eq_and_match_neg(
1955    left: &Expr,
1956    right: &Expr,
1957    is_left_neg: bool,
1958    is_right_neg: bool,
1959) -> bool {
1960    match (left, right) {
1961        (Expr::InList(l), Expr::InList(r)) => {
1962            l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg
1963        }
1964        _ => false,
1965    }
1966}
1967
1968// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
1969fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool {
1970    let left = as_inlist(left);
1971    let right = as_inlist(right);
1972    if let (Some(lhs), Some(rhs)) = (left, right) {
1973        matches!(lhs.expr.as_ref(), Expr::Column(_))
1974            && matches!(rhs.expr.as_ref(), Expr::Column(_))
1975            && lhs.expr == rhs.expr
1976            && !lhs.negated
1977            && !rhs.negated
1978    } else {
1979        false
1980    }
1981}
1982
1983/// Try to convert an expression to an in-list expression
1984fn as_inlist(expr: &Expr) -> Option<Cow<InList>> {
1985    match expr {
1986        Expr::InList(inlist) => Some(Cow::Borrowed(inlist)),
1987        Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => {
1988            match (left.as_ref(), right.as_ref()) {
1989                (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList {
1990                    expr: left.clone(),
1991                    list: vec![*right.clone()],
1992                    negated: false,
1993                })),
1994                (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList {
1995                    expr: right.clone(),
1996                    list: vec![*left.clone()],
1997                    negated: false,
1998                })),
1999                _ => None,
2000            }
2001        }
2002        _ => None,
2003    }
2004}
2005
2006fn to_inlist(expr: Expr) -> Option<InList> {
2007    match expr {
2008        Expr::InList(inlist) => Some(inlist),
2009        Expr::BinaryExpr(BinaryExpr {
2010            left,
2011            op: Operator::Eq,
2012            right,
2013        }) => match (left.as_ref(), right.as_ref()) {
2014            (Expr::Column(_), Expr::Literal(_, _)) => Some(InList {
2015                expr: left,
2016                list: vec![*right],
2017                negated: false,
2018            }),
2019            (Expr::Literal(_, _), Expr::Column(_)) => Some(InList {
2020                expr: right,
2021                list: vec![*left],
2022                negated: false,
2023            }),
2024            _ => None,
2025        },
2026        _ => None,
2027    }
2028}
2029
2030/// Return the union of two inlist expressions
2031/// maintaining the order of the elements in the two lists
2032fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
2033    // extend the list in l1 with the elements in l2 that are not already in l1
2034    let l1_items: HashSet<_> = l1.list.iter().collect();
2035
2036    // keep all l2 items that do not also appear in l1
2037    let keep_l2: Vec<_> = l2
2038        .list
2039        .into_iter()
2040        .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) })
2041        .collect();
2042
2043    l1.list.extend(keep_l2);
2044    l1.negated = negated;
2045    Ok(Expr::InList(l1))
2046}
2047
2048/// Return the intersection of two inlist expressions
2049/// maintaining the order of the elements in the two lists
2050fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result<Expr> {
2051    let l2_items = l2.list.iter().collect::<HashSet<_>>();
2052
2053    // remove all items from l1 that are not in l2
2054    l1.list.retain(|e| l2_items.contains(e));
2055
2056    // e in () is always false
2057    // e not in () is always true
2058    if l1.list.is_empty() {
2059        return Ok(lit(negated));
2060    }
2061    Ok(Expr::InList(l1))
2062}
2063
2064/// Return the all items in l1 that are not in l2
2065/// maintaining the order of the elements in the two lists
2066fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
2067    let l2_items = l2.list.iter().collect::<HashSet<_>>();
2068
2069    // keep only items from l1 that are not in l2
2070    l1.list.retain(|e| !l2_items.contains(e));
2071
2072    if l1.list.is_empty() {
2073        return Ok(lit(false));
2074    }
2075    Ok(Expr::InList(l1))
2076}
2077
2078/// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL).
2079fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
2080    if !info.nullable(&expr)? {
2081        Ok(expr)
2082    } else {
2083        Ok(Expr::BinaryExpr(BinaryExpr {
2084            left: Box::new(expr),
2085            op: Operator::IsNotDistinctFrom,
2086            right: Box::new(lit(true)),
2087        }))
2088    }
2089}
2090
2091// A * 1 -> A
2092// A / 1 -> A
2093//
2094// Move this function body out of the large match branch avoid stack overflow
2095fn simplify_right_is_one_case<S: SimplifyInfo>(
2096    info: &S,
2097    left: Box<Expr>,
2098    op: &Operator,
2099    right: &Expr,
2100) -> Result<Transformed<Expr>> {
2101    // Check if resulting type would be different due to coercion
2102    let left_type = info.get_data_type(&left)?;
2103    let right_type = info.get_data_type(right)?;
2104    match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() {
2105        Ok(result_type) => {
2106            // Only cast if the types differ
2107            if left_type != result_type {
2108                Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type))))
2109            } else {
2110                Ok(Transformed::yes(*left))
2111            }
2112        }
2113        Err(_) => Ok(Transformed::yes(*left)),
2114    }
2115}
2116
2117// A * null -> null
2118// A / null -> null
2119//
2120// Move this function body out of the large match branch avoid stack overflow
2121fn simplify_right_is_null_case<S: SimplifyInfo>(
2122    info: &S,
2123    left: &Expr,
2124    op: &Operator,
2125    right: Box<Expr>,
2126) -> Result<Transformed<Expr>> {
2127    // Check if resulting type would be different due to coercion
2128    let left_type = info.get_data_type(left)?;
2129    let right_type = info.get_data_type(&right)?;
2130    match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() {
2131        Ok(result_type) => {
2132            // Only cast if the types differ
2133            if right_type != result_type {
2134                Ok(Transformed::yes(Expr::Cast(Cast::new(right, result_type))))
2135            } else {
2136                Ok(Transformed::yes(*right))
2137            }
2138        }
2139        Err(_) => Ok(Transformed::yes(*right)),
2140    }
2141}
2142
2143// null / A --> null
2144//
2145// Move this function body out of the large match branch avoid stack overflow
2146fn simplify_null_div_other_case<S: SimplifyInfo>(
2147    info: &S,
2148    left: Box<Expr>,
2149    right: &Expr,
2150) -> Result<Transformed<Expr>> {
2151    // Check if resulting type would be different due to coercion
2152    let left_type = info.get_data_type(&left)?;
2153    let right_type = info.get_data_type(right)?;
2154    match BinaryTypeCoercer::new(&left_type, &Operator::Divide, &right_type)
2155        .get_result_type()
2156    {
2157        Ok(result_type) => {
2158            // Only cast if the types differ
2159            if left_type != result_type {
2160                Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type))))
2161            } else {
2162                Ok(Transformed::yes(*left))
2163            }
2164        }
2165        Err(_) => Ok(Transformed::yes(*left)),
2166    }
2167}
2168
2169#[cfg(test)]
2170mod tests {
2171    use super::*;
2172    use crate::simplify_expressions::SimplifyContext;
2173    use crate::test::test_table_scan_with_name;
2174    use arrow::datatypes::FieldRef;
2175    use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
2176    use datafusion_expr::{
2177        expr::WindowFunction,
2178        function::{
2179            AccumulatorArgs, AggregateFunctionSimplification,
2180            WindowFunctionSimplification,
2181        },
2182        interval_arithmetic::Interval,
2183        *,
2184    };
2185    use datafusion_functions_window_common::field::WindowUDFFieldArgs;
2186    use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
2187    use std::{
2188        collections::HashMap,
2189        ops::{BitAnd, BitOr, BitXor},
2190        sync::Arc,
2191    };
2192
2193    // ------------------------------
2194    // --- ExprSimplifier tests -----
2195    // ------------------------------
2196    #[test]
2197    fn api_basic() {
2198        let props = ExecutionProps::new();
2199        let simplifier =
2200            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2201
2202        let expr = lit(1) + lit(2);
2203        let expected = lit(3);
2204        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2205    }
2206
2207    #[test]
2208    fn basic_coercion() {
2209        let schema = test_schema();
2210        let props = ExecutionProps::new();
2211        let simplifier = ExprSimplifier::new(
2212            SimplifyContext::new(&props).with_schema(Arc::clone(&schema)),
2213        );
2214
2215        // Note expr type is int32 (not int64)
2216        // (1i64 + 2i32) < i
2217        let expr = (lit(1i64) + lit(2i32)).lt(col("i"));
2218        // should fully simplify to 3 < i (though i has been coerced to i64)
2219        let expected = lit(3i64).lt(col("i"));
2220
2221        let expr = simplifier.coerce(expr, &schema).unwrap();
2222
2223        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2224    }
2225
2226    fn test_schema() -> DFSchemaRef {
2227        Schema::new(vec![
2228            Field::new("i", DataType::Int64, false),
2229            Field::new("b", DataType::Boolean, true),
2230        ])
2231        .to_dfschema_ref()
2232        .unwrap()
2233    }
2234
2235    #[test]
2236    fn simplify_and_constant_prop() {
2237        let props = ExecutionProps::new();
2238        let simplifier =
2239            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2240
2241        // should be able to simplify to false
2242        // (i * (1 - 2)) > 0
2243        let expr = (col("i") * (lit(1) - lit(1))).gt(lit(0));
2244        let expected = lit(false);
2245        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2246    }
2247
2248    #[test]
2249    fn simplify_and_constant_prop_with_case() {
2250        let props = ExecutionProps::new();
2251        let simplifier =
2252            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
2253
2254        //   CASE
2255        //     WHEN i>5 AND false THEN i > 5
2256        //     WHEN i<5 AND true THEN i < 5
2257        //     ELSE false
2258        //   END
2259        //
2260        // Can be simplified to `i < 5`
2261        let expr = when(col("i").gt(lit(5)).and(lit(false)), col("i").gt(lit(5)))
2262            .when(col("i").lt(lit(5)).and(lit(true)), col("i").lt(lit(5)))
2263            .otherwise(lit(false))
2264            .unwrap();
2265        let expected = col("i").lt(lit(5));
2266        assert_eq!(expected, simplifier.simplify(expr).unwrap());
2267    }
2268
2269    // ------------------------------
2270    // --- Simplifier tests -----
2271    // ------------------------------
2272
2273    #[test]
2274    fn test_simplify_canonicalize() {
2275        {
2276            let expr = lit(1).lt(col("c2")).and(col("c2").gt(lit(1)));
2277            let expected = col("c2").gt(lit(1));
2278            assert_eq!(simplify(expr), expected);
2279        }
2280        {
2281            let expr = col("c1").lt(col("c2")).and(col("c2").gt(col("c1")));
2282            let expected = col("c2").gt(col("c1"));
2283            assert_eq!(simplify(expr), expected);
2284        }
2285        {
2286            let expr = col("c1")
2287                .eq(lit(1))
2288                .and(lit(1).eq(col("c1")))
2289                .and(col("c1").eq(lit(3)));
2290            let expected = col("c1").eq(lit(1)).and(col("c1").eq(lit(3)));
2291            assert_eq!(simplify(expr), expected);
2292        }
2293        {
2294            let expr = col("c1")
2295                .eq(col("c2"))
2296                .and(col("c1").gt(lit(5)))
2297                .and(col("c2").eq(col("c1")));
2298            let expected = col("c2").eq(col("c1")).and(col("c1").gt(lit(5)));
2299            assert_eq!(simplify(expr), expected);
2300        }
2301        {
2302            let expr = col("c1")
2303                .eq(lit(1))
2304                .and(col("c2").gt(lit(3)).or(lit(3).lt(col("c2"))));
2305            let expected = col("c1").eq(lit(1)).and(col("c2").gt(lit(3)));
2306            assert_eq!(simplify(expr), expected);
2307        }
2308        {
2309            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2310            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2311            assert_eq!(simplify(expr), expected);
2312        }
2313        {
2314            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2315            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2316            assert_eq!(simplify(expr), expected);
2317        }
2318        {
2319            let expr = col("c1").gt(col("c2")).and(col("c1").gt(col("c2")));
2320            let expected = col("c2").lt(col("c1"));
2321            assert_eq!(simplify(expr), expected);
2322        }
2323    }
2324
2325    #[test]
2326    fn test_simplify_eq_not_self() {
2327        // `expr_a`: column `c2` is nullable, so `c2 = c2` simplifies to `c2 IS NOT NULL OR NULL`
2328        // This ensures the expression is only true when `c2` is not NULL, accounting for SQL's NULL semantics.
2329        let expr_a = col("c2").eq(col("c2"));
2330        let expected_a = col("c2").is_not_null().or(lit_bool_null());
2331
2332        // `expr_b`: column `c2_non_null` is explicitly non-nullable, so `c2_non_null = c2_non_null` is always true
2333        let expr_b = col("c2_non_null").eq(col("c2_non_null"));
2334        let expected_b = lit(true);
2335
2336        assert_eq!(simplify(expr_a), expected_a);
2337        assert_eq!(simplify(expr_b), expected_b);
2338    }
2339
2340    #[test]
2341    fn test_simplify_or_true() {
2342        let expr_a = col("c2").or(lit(true));
2343        let expr_b = lit(true).or(col("c2"));
2344        let expected = lit(true);
2345
2346        assert_eq!(simplify(expr_a), expected);
2347        assert_eq!(simplify(expr_b), expected);
2348    }
2349
2350    #[test]
2351    fn test_simplify_or_false() {
2352        let expr_a = lit(false).or(col("c2"));
2353        let expr_b = col("c2").or(lit(false));
2354        let expected = col("c2");
2355
2356        assert_eq!(simplify(expr_a), expected);
2357        assert_eq!(simplify(expr_b), expected);
2358    }
2359
2360    #[test]
2361    fn test_simplify_or_same() {
2362        let expr = col("c2").or(col("c2"));
2363        let expected = col("c2");
2364
2365        assert_eq!(simplify(expr), expected);
2366    }
2367
2368    #[test]
2369    fn test_simplify_or_not_self() {
2370        // A OR !A if A is not nullable --> true
2371        // !A OR A if A is not nullable --> true
2372        let expr_a = col("c2_non_null").or(col("c2_non_null").not());
2373        let expr_b = col("c2_non_null").not().or(col("c2_non_null"));
2374        let expected = lit(true);
2375
2376        assert_eq!(simplify(expr_a), expected);
2377        assert_eq!(simplify(expr_b), expected);
2378    }
2379
2380    #[test]
2381    fn test_simplify_and_false() {
2382        let expr_a = lit(false).and(col("c2"));
2383        let expr_b = col("c2").and(lit(false));
2384        let expected = lit(false);
2385
2386        assert_eq!(simplify(expr_a), expected);
2387        assert_eq!(simplify(expr_b), expected);
2388    }
2389
2390    #[test]
2391    fn test_simplify_and_same() {
2392        let expr = col("c2").and(col("c2"));
2393        let expected = col("c2");
2394
2395        assert_eq!(simplify(expr), expected);
2396    }
2397
2398    #[test]
2399    fn test_simplify_and_true() {
2400        let expr_a = lit(true).and(col("c2"));
2401        let expr_b = col("c2").and(lit(true));
2402        let expected = col("c2");
2403
2404        assert_eq!(simplify(expr_a), expected);
2405        assert_eq!(simplify(expr_b), expected);
2406    }
2407
2408    #[test]
2409    fn test_simplify_and_not_self() {
2410        // A AND !A if A is not nullable --> false
2411        // !A AND A if A is not nullable --> false
2412        let expr_a = col("c2_non_null").and(col("c2_non_null").not());
2413        let expr_b = col("c2_non_null").not().and(col("c2_non_null"));
2414        let expected = lit(false);
2415
2416        assert_eq!(simplify(expr_a), expected);
2417        assert_eq!(simplify(expr_b), expected);
2418    }
2419
2420    #[test]
2421    fn test_simplify_multiply_by_one() {
2422        let expr_a = col("c2") * lit(1);
2423        let expr_b = lit(1) * col("c2");
2424        let expected = col("c2");
2425
2426        assert_eq!(simplify(expr_a), expected);
2427        assert_eq!(simplify(expr_b), expected);
2428
2429        let expr = col("c2") * lit(ScalarValue::Decimal128(Some(10000000000), 38, 10));
2430        assert_eq!(simplify(expr), expected);
2431
2432        let expr = lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)) * col("c2");
2433        assert_eq!(simplify(expr), expected);
2434    }
2435
2436    #[test]
2437    fn test_simplify_multiply_by_null() {
2438        let null = Expr::Literal(ScalarValue::Null, None);
2439        // A * null --> null
2440        {
2441            let expr = col("c2") * null.clone();
2442            assert_eq!(simplify(expr), null);
2443        }
2444        // null * A --> null
2445        {
2446            let expr = null.clone() * col("c2");
2447            assert_eq!(simplify(expr), null);
2448        }
2449    }
2450
2451    #[test]
2452    fn test_simplify_multiply_by_zero() {
2453        // cannot optimize A * null (null * A) if A is nullable
2454        {
2455            let expr_a = col("c2") * lit(0);
2456            let expr_b = lit(0) * col("c2");
2457
2458            assert_eq!(simplify(expr_a.clone()), expr_a);
2459            assert_eq!(simplify(expr_b.clone()), expr_b);
2460        }
2461        // 0 * A --> 0 if A is not nullable
2462        {
2463            let expr = lit(0) * col("c2_non_null");
2464            assert_eq!(simplify(expr), lit(0));
2465        }
2466        // A * 0 --> 0 if A is not nullable
2467        {
2468            let expr = col("c2_non_null") * lit(0);
2469            assert_eq!(simplify(expr), lit(0));
2470        }
2471        // A * Decimal128(0) --> 0 if A is not nullable
2472        {
2473            let expr = col("c2_non_null") * lit(ScalarValue::Decimal128(Some(0), 31, 10));
2474            assert_eq!(
2475                simplify(expr),
2476                lit(ScalarValue::Decimal128(Some(0), 31, 10))
2477            );
2478            let expr = binary_expr(
2479                lit(ScalarValue::Decimal128(Some(0), 31, 10)),
2480                Operator::Multiply,
2481                col("c2_non_null"),
2482            );
2483            assert_eq!(
2484                simplify(expr),
2485                lit(ScalarValue::Decimal128(Some(0), 31, 10))
2486            );
2487        }
2488    }
2489
2490    #[test]
2491    fn test_simplify_divide_by_one() {
2492        let expr = binary_expr(col("c2"), Operator::Divide, lit(1));
2493        let expected = col("c2");
2494        assert_eq!(simplify(expr), expected);
2495        let expr = col("c2") / lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2496        assert_eq!(simplify(expr), expected);
2497    }
2498
2499    #[test]
2500    fn test_simplify_divide_null() {
2501        // A / null --> null
2502        let null = lit(ScalarValue::Null);
2503        {
2504            let expr = col("c1") / null.clone();
2505            assert_eq!(simplify(expr), null);
2506        }
2507        // null / A --> null
2508        {
2509            let expr = null.clone() / col("c1");
2510            assert_eq!(simplify(expr), null);
2511        }
2512    }
2513
2514    #[test]
2515    fn test_simplify_divide_by_same() {
2516        let expr = col("c2") / col("c2");
2517        // if c2 is null, c2 / c2 = null, so can't simplify
2518        let expected = expr.clone();
2519
2520        assert_eq!(simplify(expr), expected);
2521    }
2522
2523    #[test]
2524    fn test_simplify_modulo_by_null() {
2525        let null = lit(ScalarValue::Null);
2526        // A % null --> null
2527        {
2528            let expr = col("c2") % null.clone();
2529            assert_eq!(simplify(expr), null);
2530        }
2531        // null % A --> null
2532        {
2533            let expr = null.clone() % col("c2");
2534            assert_eq!(simplify(expr), null);
2535        }
2536    }
2537
2538    #[test]
2539    fn test_simplify_modulo_by_one() {
2540        let expr = col("c2") % lit(1);
2541        // if c2 is null, c2 % 1 = null, so can't simplify
2542        let expected = expr.clone();
2543
2544        assert_eq!(simplify(expr), expected);
2545    }
2546
2547    #[test]
2548    fn test_simplify_divide_zero_by_zero() {
2549        // because divide by 0 maybe occur in short-circuit expression
2550        // so we should not simplify this, and throw error in runtime
2551        let expr = lit(0) / lit(0);
2552        let expected = expr.clone();
2553
2554        assert_eq!(simplify(expr), expected);
2555    }
2556
2557    #[test]
2558    fn test_simplify_divide_by_zero() {
2559        // because divide by 0 maybe occur in short-circuit expression
2560        // so we should not simplify this, and throw error in runtime
2561        let expr = col("c2_non_null") / lit(0);
2562        let expected = expr.clone();
2563
2564        assert_eq!(simplify(expr), expected);
2565    }
2566
2567    #[test]
2568    fn test_simplify_modulo_by_one_non_null() {
2569        let expr = col("c3_non_null") % lit(1);
2570        let expected = lit(0_i64);
2571        assert_eq!(simplify(expr), expected);
2572        let expr =
2573            col("c3_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2574        assert_eq!(simplify(expr), expected);
2575    }
2576
2577    #[test]
2578    fn test_simplify_bitwise_xor_by_null() {
2579        let null = lit(ScalarValue::Null);
2580        // A ^ null --> null
2581        {
2582            let expr = col("c2") ^ null.clone();
2583            assert_eq!(simplify(expr), null);
2584        }
2585        // null ^ A --> null
2586        {
2587            let expr = null.clone() ^ col("c2");
2588            assert_eq!(simplify(expr), null);
2589        }
2590    }
2591
2592    #[test]
2593    fn test_simplify_bitwise_shift_right_by_null() {
2594        let null = lit(ScalarValue::Null);
2595        // A >> null --> null
2596        {
2597            let expr = col("c2") >> null.clone();
2598            assert_eq!(simplify(expr), null);
2599        }
2600        // null >> A --> null
2601        {
2602            let expr = null.clone() >> col("c2");
2603            assert_eq!(simplify(expr), null);
2604        }
2605    }
2606
2607    #[test]
2608    fn test_simplify_bitwise_shift_left_by_null() {
2609        let null = lit(ScalarValue::Null);
2610        // A << null --> null
2611        {
2612            let expr = col("c2") << null.clone();
2613            assert_eq!(simplify(expr), null);
2614        }
2615        // null << A --> null
2616        {
2617            let expr = null.clone() << col("c2");
2618            assert_eq!(simplify(expr), null);
2619        }
2620    }
2621
2622    #[test]
2623    fn test_simplify_bitwise_and_by_zero() {
2624        // A & 0 --> 0
2625        {
2626            let expr = col("c2_non_null") & lit(0);
2627            assert_eq!(simplify(expr), lit(0));
2628        }
2629        // 0 & A --> 0
2630        {
2631            let expr = lit(0) & col("c2_non_null");
2632            assert_eq!(simplify(expr), lit(0));
2633        }
2634    }
2635
2636    #[test]
2637    fn test_simplify_bitwise_or_by_zero() {
2638        // A | 0 --> A
2639        {
2640            let expr = col("c2_non_null") | lit(0);
2641            assert_eq!(simplify(expr), col("c2_non_null"));
2642        }
2643        // 0 | A --> A
2644        {
2645            let expr = lit(0) | col("c2_non_null");
2646            assert_eq!(simplify(expr), col("c2_non_null"));
2647        }
2648    }
2649
2650    #[test]
2651    fn test_simplify_bitwise_xor_by_zero() {
2652        // A ^ 0 --> A
2653        {
2654            let expr = col("c2_non_null") ^ lit(0);
2655            assert_eq!(simplify(expr), col("c2_non_null"));
2656        }
2657        // 0 ^ A --> A
2658        {
2659            let expr = lit(0) ^ col("c2_non_null");
2660            assert_eq!(simplify(expr), col("c2_non_null"));
2661        }
2662    }
2663
2664    #[test]
2665    fn test_simplify_bitwise_bitwise_shift_right_by_zero() {
2666        // A >> 0 --> A
2667        {
2668            let expr = col("c2_non_null") >> lit(0);
2669            assert_eq!(simplify(expr), col("c2_non_null"));
2670        }
2671    }
2672
2673    #[test]
2674    fn test_simplify_bitwise_bitwise_shift_left_by_zero() {
2675        // A << 0 --> A
2676        {
2677            let expr = col("c2_non_null") << lit(0);
2678            assert_eq!(simplify(expr), col("c2_non_null"));
2679        }
2680    }
2681
2682    #[test]
2683    fn test_simplify_bitwise_and_by_null() {
2684        let null = lit(ScalarValue::Null);
2685        // A & null --> null
2686        {
2687            let expr = col("c2") & null.clone();
2688            assert_eq!(simplify(expr), null);
2689        }
2690        // null & A --> null
2691        {
2692            let expr = null.clone() & col("c2");
2693            assert_eq!(simplify(expr), null);
2694        }
2695    }
2696
2697    #[test]
2698    fn test_simplify_composed_bitwise_and() {
2699        // ((c2 > 5) & (c1 < 6)) & (c2 > 5) --> (c2 > 5) & (c1 < 6)
2700
2701        let expr = bitwise_and(
2702            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2703            col("c2").gt(lit(5)),
2704        );
2705        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2706
2707        assert_eq!(simplify(expr), expected);
2708
2709        // (c2 > 5) & ((c2 > 5) & (c1 < 6)) --> (c2 > 5) & (c1 < 6)
2710
2711        let expr = bitwise_and(
2712            col("c2").gt(lit(5)),
2713            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2714        );
2715        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2716        assert_eq!(simplify(expr), expected);
2717    }
2718
2719    #[test]
2720    fn test_simplify_composed_bitwise_or() {
2721        // ((c2 > 5) | (c1 < 6)) | (c2 > 5) --> (c2 > 5) | (c1 < 6)
2722
2723        let expr = bitwise_or(
2724            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2725            col("c2").gt(lit(5)),
2726        );
2727        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2728
2729        assert_eq!(simplify(expr), expected);
2730
2731        // (c2 > 5) | ((c2 > 5) | (c1 < 6)) --> (c2 > 5) | (c1 < 6)
2732
2733        let expr = bitwise_or(
2734            col("c2").gt(lit(5)),
2735            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2736        );
2737        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2738
2739        assert_eq!(simplify(expr), expected);
2740    }
2741
2742    #[test]
2743    fn test_simplify_composed_bitwise_xor() {
2744        // with an even number of the column "c2"
2745        // c2 ^ ((c2 ^ (c2 | c1)) ^ (c1 & c2)) --> (c2 | c1) ^ (c1 & c2)
2746
2747        let expr = bitwise_xor(
2748            col("c2"),
2749            bitwise_xor(
2750                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2751                bitwise_and(col("c1"), col("c2")),
2752            ),
2753        );
2754
2755        let expected = bitwise_xor(
2756            bitwise_or(col("c2"), col("c1")),
2757            bitwise_and(col("c1"), col("c2")),
2758        );
2759
2760        assert_eq!(simplify(expr), expected);
2761
2762        // with an odd number of the column "c2"
2763        // c2 ^ (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) --> c2 ^ ((c2 | c1) ^ (c1 & c2))
2764
2765        let expr = bitwise_xor(
2766            col("c2"),
2767            bitwise_xor(
2768                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2769                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2770            ),
2771        );
2772
2773        let expected = bitwise_xor(
2774            col("c2"),
2775            bitwise_xor(
2776                bitwise_or(col("c2"), col("c1")),
2777                bitwise_and(col("c1"), col("c2")),
2778            ),
2779        );
2780
2781        assert_eq!(simplify(expr), expected);
2782
2783        // with an even number of the column "c2"
2784        // ((c2 ^ (c2 | c1)) ^ (c1 & c2)) ^ c2 --> (c2 | c1) ^ (c1 & c2)
2785
2786        let expr = bitwise_xor(
2787            bitwise_xor(
2788                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2789                bitwise_and(col("c1"), col("c2")),
2790            ),
2791            col("c2"),
2792        );
2793
2794        let expected = bitwise_xor(
2795            bitwise_or(col("c2"), col("c1")),
2796            bitwise_and(col("c1"), col("c2")),
2797        );
2798
2799        assert_eq!(simplify(expr), expected);
2800
2801        // with an odd number of the column "c2"
2802        // (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) ^ c2 --> ((c2 | c1) ^ (c1 & c2)) ^ c2
2803
2804        let expr = bitwise_xor(
2805            bitwise_xor(
2806                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2807                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2808            ),
2809            col("c2"),
2810        );
2811
2812        let expected = bitwise_xor(
2813            bitwise_xor(
2814                bitwise_or(col("c2"), col("c1")),
2815                bitwise_and(col("c1"), col("c2")),
2816            ),
2817            col("c2"),
2818        );
2819
2820        assert_eq!(simplify(expr), expected);
2821    }
2822
2823    #[test]
2824    fn test_simplify_negated_bitwise_and() {
2825        // !c4 & c4 --> 0
2826        let expr = (-col("c4_non_null")) & col("c4_non_null");
2827        let expected = lit(0u32);
2828
2829        assert_eq!(simplify(expr), expected);
2830        // c4 & !c4 --> 0
2831        let expr = col("c4_non_null") & (-col("c4_non_null"));
2832        let expected = lit(0u32);
2833
2834        assert_eq!(simplify(expr), expected);
2835
2836        // !c3 & c3 --> 0
2837        let expr = (-col("c3_non_null")) & col("c3_non_null");
2838        let expected = lit(0i64);
2839
2840        assert_eq!(simplify(expr), expected);
2841        // c3 & !c3 --> 0
2842        let expr = col("c3_non_null") & (-col("c3_non_null"));
2843        let expected = lit(0i64);
2844
2845        assert_eq!(simplify(expr), expected);
2846    }
2847
2848    #[test]
2849    fn test_simplify_negated_bitwise_or() {
2850        // !c4 | c4 --> -1
2851        let expr = (-col("c4_non_null")) | col("c4_non_null");
2852        let expected = lit(-1i32);
2853
2854        assert_eq!(simplify(expr), expected);
2855
2856        // c4 | !c4 --> -1
2857        let expr = col("c4_non_null") | (-col("c4_non_null"));
2858        let expected = lit(-1i32);
2859
2860        assert_eq!(simplify(expr), expected);
2861
2862        // !c3 | c3 --> -1
2863        let expr = (-col("c3_non_null")) | col("c3_non_null");
2864        let expected = lit(-1i64);
2865
2866        assert_eq!(simplify(expr), expected);
2867
2868        // c3 | !c3 --> -1
2869        let expr = col("c3_non_null") | (-col("c3_non_null"));
2870        let expected = lit(-1i64);
2871
2872        assert_eq!(simplify(expr), expected);
2873    }
2874
2875    #[test]
2876    fn test_simplify_negated_bitwise_xor() {
2877        // !c4 ^ c4 --> -1
2878        let expr = (-col("c4_non_null")) ^ col("c4_non_null");
2879        let expected = lit(-1i32);
2880
2881        assert_eq!(simplify(expr), expected);
2882
2883        // c4 ^ !c4 --> -1
2884        let expr = col("c4_non_null") ^ (-col("c4_non_null"));
2885        let expected = lit(-1i32);
2886
2887        assert_eq!(simplify(expr), expected);
2888
2889        // !c3 ^ c3 --> -1
2890        let expr = (-col("c3_non_null")) ^ col("c3_non_null");
2891        let expected = lit(-1i64);
2892
2893        assert_eq!(simplify(expr), expected);
2894
2895        // c3 ^ !c3 --> -1
2896        let expr = col("c3_non_null") ^ (-col("c3_non_null"));
2897        let expected = lit(-1i64);
2898
2899        assert_eq!(simplify(expr), expected);
2900    }
2901
2902    #[test]
2903    fn test_simplify_bitwise_and_or() {
2904        // (c2 < 3) & ((c2 < 3) | c1) -> (c2 < 3)
2905        let expr = bitwise_and(
2906            col("c2_non_null").lt(lit(3)),
2907            bitwise_or(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2908        );
2909        let expected = col("c2_non_null").lt(lit(3));
2910
2911        assert_eq!(simplify(expr), expected);
2912    }
2913
2914    #[test]
2915    fn test_simplify_bitwise_or_and() {
2916        // (c2 < 3) | ((c2 < 3) & c1) -> (c2 < 3)
2917        let expr = bitwise_or(
2918            col("c2_non_null").lt(lit(3)),
2919            bitwise_and(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2920        );
2921        let expected = col("c2_non_null").lt(lit(3));
2922
2923        assert_eq!(simplify(expr), expected);
2924    }
2925
2926    #[test]
2927    fn test_simplify_simple_bitwise_and() {
2928        // (c2 > 5) & (c2 > 5) -> (c2 > 5)
2929        let expr = (col("c2").gt(lit(5))).bitand(col("c2").gt(lit(5)));
2930        let expected = col("c2").gt(lit(5));
2931
2932        assert_eq!(simplify(expr), expected);
2933    }
2934
2935    #[test]
2936    fn test_simplify_simple_bitwise_or() {
2937        // (c2 > 5) | (c2 > 5) -> (c2 > 5)
2938        let expr = (col("c2").gt(lit(5))).bitor(col("c2").gt(lit(5)));
2939        let expected = col("c2").gt(lit(5));
2940
2941        assert_eq!(simplify(expr), expected);
2942    }
2943
2944    #[test]
2945    fn test_simplify_simple_bitwise_xor() {
2946        // c4 ^ c4 -> 0
2947        let expr = (col("c4")).bitxor(col("c4"));
2948        let expected = lit(0u32);
2949
2950        assert_eq!(simplify(expr), expected);
2951
2952        // c3 ^ c3 -> 0
2953        let expr = col("c3").bitxor(col("c3"));
2954        let expected = lit(0i64);
2955
2956        assert_eq!(simplify(expr), expected);
2957    }
2958
2959    #[test]
2960    fn test_simplify_modulo_by_zero_non_null() {
2961        // because modulo by 0 maybe occur in short-circuit expression
2962        // so we should not simplify this, and throw error in runtime.
2963        let expr = col("c2_non_null") % lit(0);
2964        let expected = expr.clone();
2965
2966        assert_eq!(simplify(expr), expected);
2967    }
2968
2969    #[test]
2970    fn test_simplify_simple_and() {
2971        // (c2 > 5) AND (c2 > 5) -> (c2 > 5)
2972        let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5)));
2973        let expected = col("c2").gt(lit(5));
2974
2975        assert_eq!(simplify(expr), expected);
2976    }
2977
2978    #[test]
2979    fn test_simplify_composed_and() {
2980        // ((c2 > 5) AND (c1 < 6)) AND (c2 > 5)
2981        let expr = and(
2982            and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2983            col("c2").gt(lit(5)),
2984        );
2985        let expected = and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2986
2987        assert_eq!(simplify(expr), expected);
2988    }
2989
2990    #[test]
2991    fn test_simplify_negated_and() {
2992        // (c2 > 5) AND !(c2 > 5) --> (c2 > 5) AND (c2 <= 5)
2993        let expr = and(col("c2").gt(lit(5)), Expr::not(col("c2").gt(lit(5))));
2994        let expected = col("c2").gt(lit(5)).and(col("c2").lt_eq(lit(5)));
2995
2996        assert_eq!(simplify(expr), expected);
2997    }
2998
2999    #[test]
3000    fn test_simplify_or_and() {
3001        let l = col("c2").gt(lit(5));
3002        let r = and(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
3003
3004        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5))
3005        let expr = or(l.clone(), r.clone());
3006
3007        let expected = l.clone();
3008        assert_eq!(simplify(expr), expected);
3009
3010        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5)
3011        let expr = or(r, l);
3012        assert_eq!(simplify(expr), expected);
3013    }
3014
3015    #[test]
3016    fn test_simplify_or_and_non_null() {
3017        let l = col("c2_non_null").gt(lit(5));
3018        let r = and(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
3019
3020        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) --> c2 > 5
3021        let expr = or(l.clone(), r.clone());
3022
3023        // This is only true if `c1 < 6` is not nullable / can not be null.
3024        let expected = col("c2_non_null").gt(lit(5));
3025
3026        assert_eq!(simplify(expr), expected);
3027
3028        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) --> c2 > 5
3029        let expr = or(l, r);
3030
3031        assert_eq!(simplify(expr), expected);
3032    }
3033
3034    #[test]
3035    fn test_simplify_and_or() {
3036        let l = col("c2").gt(lit(5));
3037        let r = or(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
3038
3039        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
3040        let expr = and(l.clone(), r.clone());
3041
3042        let expected = l.clone();
3043        assert_eq!(simplify(expr), expected);
3044
3045        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
3046        let expr = and(r, l);
3047        assert_eq!(simplify(expr), expected);
3048    }
3049
3050    #[test]
3051    fn test_simplify_and_or_non_null() {
3052        let l = col("c2_non_null").gt(lit(5));
3053        let r = or(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
3054
3055        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
3056        let expr = and(l.clone(), r.clone());
3057
3058        // This is only true if `c1 < 6` is not nullable / can not be null.
3059        let expected = col("c2_non_null").gt(lit(5));
3060
3061        assert_eq!(simplify(expr), expected);
3062
3063        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
3064        let expr = and(l, r);
3065
3066        assert_eq!(simplify(expr), expected);
3067    }
3068
3069    #[test]
3070    fn test_simplify_by_de_morgan_laws() {
3071        // Laws with logical operations
3072        // !(c3 AND c4) --> !c3 OR !c4
3073        let expr = and(col("c3"), col("c4")).not();
3074        let expected = or(col("c3").not(), col("c4").not());
3075        assert_eq!(simplify(expr), expected);
3076        // !(c3 OR c4) --> !c3 AND !c4
3077        let expr = or(col("c3"), col("c4")).not();
3078        let expected = and(col("c3").not(), col("c4").not());
3079        assert_eq!(simplify(expr), expected);
3080        // !(!c3) --> c3
3081        let expr = col("c3").not().not();
3082        let expected = col("c3");
3083        assert_eq!(simplify(expr), expected);
3084
3085        // Laws with bitwise operations
3086        // !(c3 & c4) --> !c3 | !c4
3087        let expr = -bitwise_and(col("c3"), col("c4"));
3088        let expected = bitwise_or(-col("c3"), -col("c4"));
3089        assert_eq!(simplify(expr), expected);
3090        // !(c3 | c4) --> !c3 & !c4
3091        let expr = -bitwise_or(col("c3"), col("c4"));
3092        let expected = bitwise_and(-col("c3"), -col("c4"));
3093        assert_eq!(simplify(expr), expected);
3094        // !(!c3) --> c3
3095        let expr = -(-col("c3"));
3096        let expected = col("c3");
3097        assert_eq!(simplify(expr), expected);
3098    }
3099
3100    #[test]
3101    fn test_simplify_null_and_false() {
3102        let expr = and(lit_bool_null(), lit(false));
3103        let expr_eq = lit(false);
3104
3105        assert_eq!(simplify(expr), expr_eq);
3106    }
3107
3108    #[test]
3109    fn test_simplify_divide_null_by_null() {
3110        let null = lit(ScalarValue::Int32(None));
3111        let expr_plus = null.clone() / null.clone();
3112        let expr_eq = null;
3113
3114        assert_eq!(simplify(expr_plus), expr_eq);
3115    }
3116
3117    #[test]
3118    fn test_simplify_simplify_arithmetic_expr() {
3119        let expr_plus = lit(1) + lit(1);
3120
3121        assert_eq!(simplify(expr_plus), lit(2));
3122    }
3123
3124    #[test]
3125    fn test_simplify_simplify_eq_expr() {
3126        let expr_eq = binary_expr(lit(1), Operator::Eq, lit(1));
3127
3128        assert_eq!(simplify(expr_eq), lit(true));
3129    }
3130
3131    #[test]
3132    fn test_simplify_regex() {
3133        // malformed regex
3134        assert_contains!(
3135            try_simplify(regex_match(col("c1"), lit("foo{")))
3136                .unwrap_err()
3137                .to_string(),
3138            "regex parse error"
3139        );
3140
3141        // unsupported cases
3142        assert_no_change(regex_match(col("c1"), lit("foo.*")));
3143        assert_no_change(regex_match(col("c1"), lit("(foo)")));
3144        assert_no_change(regex_match(col("c1"), lit("%")));
3145        assert_no_change(regex_match(col("c1"), lit("_")));
3146        assert_no_change(regex_match(col("c1"), lit("f%o")));
3147        assert_no_change(regex_match(col("c1"), lit("^f%o")));
3148        assert_no_change(regex_match(col("c1"), lit("f_o")));
3149
3150        // empty cases
3151        assert_change(
3152            regex_match(col("c1"), lit("")),
3153            if_not_null(col("c1"), true),
3154        );
3155        assert_change(
3156            regex_not_match(col("c1"), lit("")),
3157            if_not_null(col("c1"), false),
3158        );
3159        assert_change(
3160            regex_imatch(col("c1"), lit("")),
3161            if_not_null(col("c1"), true),
3162        );
3163        assert_change(
3164            regex_not_imatch(col("c1"), lit("")),
3165            if_not_null(col("c1"), false),
3166        );
3167
3168        // single character
3169        assert_change(regex_match(col("c1"), lit("x")), col("c1").like(lit("%x%")));
3170
3171        // single word
3172        assert_change(
3173            regex_match(col("c1"), lit("foo")),
3174            col("c1").like(lit("%foo%")),
3175        );
3176
3177        // regular expressions that match an exact literal
3178        assert_change(regex_match(col("c1"), lit("^$")), col("c1").eq(lit("")));
3179        assert_change(
3180            regex_not_match(col("c1"), lit("^$")),
3181            col("c1").not_eq(lit("")),
3182        );
3183        assert_change(
3184            regex_match(col("c1"), lit("^foo$")),
3185            col("c1").eq(lit("foo")),
3186        );
3187        assert_change(
3188            regex_not_match(col("c1"), lit("^foo$")),
3189            col("c1").not_eq(lit("foo")),
3190        );
3191
3192        // regular expressions that match exact captured literals
3193        assert_change(
3194            regex_match(col("c1"), lit("^(foo|bar)$")),
3195            col("c1").eq(lit("foo")).or(col("c1").eq(lit("bar"))),
3196        );
3197        assert_change(
3198            regex_not_match(col("c1"), lit("^(foo|bar)$")),
3199            col("c1")
3200                .not_eq(lit("foo"))
3201                .and(col("c1").not_eq(lit("bar"))),
3202        );
3203        assert_change(
3204            regex_match(col("c1"), lit("^(foo)$")),
3205            col("c1").eq(lit("foo")),
3206        );
3207        assert_change(
3208            regex_match(col("c1"), lit("^(foo|bar|baz)$")),
3209            ((col("c1").eq(lit("foo"))).or(col("c1").eq(lit("bar"))))
3210                .or(col("c1").eq(lit("baz"))),
3211        );
3212        assert_change(
3213            regex_match(col("c1"), lit("^(foo|bar|baz|qux)$")),
3214            col("c1")
3215                .in_list(vec![lit("foo"), lit("bar"), lit("baz"), lit("qux")], false),
3216        );
3217        assert_change(
3218            regex_match(col("c1"), lit("^(fo_o)$")),
3219            col("c1").eq(lit("fo_o")),
3220        );
3221        assert_change(
3222            regex_match(col("c1"), lit("^(fo_o)$")),
3223            col("c1").eq(lit("fo_o")),
3224        );
3225        assert_change(
3226            regex_match(col("c1"), lit("^(fo_o|ba_r)$")),
3227            col("c1").eq(lit("fo_o")).or(col("c1").eq(lit("ba_r"))),
3228        );
3229        assert_change(
3230            regex_not_match(col("c1"), lit("^(fo_o|ba_r)$")),
3231            col("c1")
3232                .not_eq(lit("fo_o"))
3233                .and(col("c1").not_eq(lit("ba_r"))),
3234        );
3235        assert_change(
3236            regex_match(col("c1"), lit("^(fo_o|ba_r|ba_z)$")),
3237            ((col("c1").eq(lit("fo_o"))).or(col("c1").eq(lit("ba_r"))))
3238                .or(col("c1").eq(lit("ba_z"))),
3239        );
3240        assert_change(
3241            regex_match(col("c1"), lit("^(fo_o|ba_r|baz|qu_x)$")),
3242            col("c1").in_list(
3243                vec![lit("fo_o"), lit("ba_r"), lit("baz"), lit("qu_x")],
3244                false,
3245            ),
3246        );
3247
3248        // regular expressions that mismatch captured literals
3249        assert_no_change(regex_match(col("c1"), lit("(foo|bar)")));
3250        assert_no_change(regex_match(col("c1"), lit("(foo|bar)*")));
3251        assert_no_change(regex_match(col("c1"), lit("(fo_o|b_ar)")));
3252        assert_no_change(regex_match(col("c1"), lit("(foo|ba_r)*")));
3253        assert_no_change(regex_match(col("c1"), lit("(fo_o|ba_r)*")));
3254        assert_no_change(regex_match(col("c1"), lit("^(foo|bar)*")));
3255        assert_no_change(regex_match(col("c1"), lit("^(foo)(bar)$")));
3256        assert_no_change(regex_match(col("c1"), lit("^")));
3257        assert_no_change(regex_match(col("c1"), lit("$")));
3258        assert_no_change(regex_match(col("c1"), lit("$^")));
3259        assert_no_change(regex_match(col("c1"), lit("$foo^")));
3260
3261        // regular expressions that match a partial literal
3262        assert_change(
3263            regex_match(col("c1"), lit("^foo")),
3264            col("c1").like(lit("foo%")),
3265        );
3266        assert_change(
3267            regex_match(col("c1"), lit("foo$")),
3268            col("c1").like(lit("%foo")),
3269        );
3270        assert_change(
3271            regex_match(col("c1"), lit("^foo|bar$")),
3272            col("c1").like(lit("foo%")).or(col("c1").like(lit("%bar"))),
3273        );
3274
3275        // OR-chain
3276        assert_change(
3277            regex_match(col("c1"), lit("foo|bar|baz")),
3278            col("c1")
3279                .like(lit("%foo%"))
3280                .or(col("c1").like(lit("%bar%")))
3281                .or(col("c1").like(lit("%baz%"))),
3282        );
3283        assert_change(
3284            regex_match(col("c1"), lit("foo|x|baz")),
3285            col("c1")
3286                .like(lit("%foo%"))
3287                .or(col("c1").like(lit("%x%")))
3288                .or(col("c1").like(lit("%baz%"))),
3289        );
3290        assert_change(
3291            regex_not_match(col("c1"), lit("foo|bar|baz")),
3292            col("c1")
3293                .not_like(lit("%foo%"))
3294                .and(col("c1").not_like(lit("%bar%")))
3295                .and(col("c1").not_like(lit("%baz%"))),
3296        );
3297        // both anchored expressions (translated to equality) and unanchored
3298        assert_change(
3299            regex_match(col("c1"), lit("foo|^x$|baz")),
3300            col("c1")
3301                .like(lit("%foo%"))
3302                .or(col("c1").eq(lit("x")))
3303                .or(col("c1").like(lit("%baz%"))),
3304        );
3305        assert_change(
3306            regex_not_match(col("c1"), lit("foo|^bar$|baz")),
3307            col("c1")
3308                .not_like(lit("%foo%"))
3309                .and(col("c1").not_eq(lit("bar")))
3310                .and(col("c1").not_like(lit("%baz%"))),
3311        );
3312        // Too many patterns (MAX_REGEX_ALTERNATIONS_EXPANSION)
3313        assert_no_change(regex_match(col("c1"), lit("foo|bar|baz|blarg|bozo|etc")));
3314    }
3315
3316    #[track_caller]
3317    fn assert_no_change(expr: Expr) {
3318        let optimized = simplify(expr.clone());
3319        assert_eq!(expr, optimized);
3320    }
3321
3322    #[track_caller]
3323    fn assert_change(expr: Expr, expected: Expr) {
3324        let optimized = simplify(expr);
3325        assert_eq!(optimized, expected);
3326    }
3327
3328    fn regex_match(left: Expr, right: Expr) -> Expr {
3329        Expr::BinaryExpr(BinaryExpr {
3330            left: Box::new(left),
3331            op: Operator::RegexMatch,
3332            right: Box::new(right),
3333        })
3334    }
3335
3336    fn regex_not_match(left: Expr, right: Expr) -> Expr {
3337        Expr::BinaryExpr(BinaryExpr {
3338            left: Box::new(left),
3339            op: Operator::RegexNotMatch,
3340            right: Box::new(right),
3341        })
3342    }
3343
3344    fn regex_imatch(left: Expr, right: Expr) -> Expr {
3345        Expr::BinaryExpr(BinaryExpr {
3346            left: Box::new(left),
3347            op: Operator::RegexIMatch,
3348            right: Box::new(right),
3349        })
3350    }
3351
3352    fn regex_not_imatch(left: Expr, right: Expr) -> Expr {
3353        Expr::BinaryExpr(BinaryExpr {
3354            left: Box::new(left),
3355            op: Operator::RegexNotIMatch,
3356            right: Box::new(right),
3357        })
3358    }
3359
3360    // ------------------------------
3361    // ----- Simplifier tests -------
3362    // ------------------------------
3363
3364    fn try_simplify(expr: Expr) -> Result<Expr> {
3365        let schema = expr_test_schema();
3366        let execution_props = ExecutionProps::new();
3367        let simplifier = ExprSimplifier::new(
3368            SimplifyContext::new(&execution_props).with_schema(schema),
3369        );
3370        simplifier.simplify(expr)
3371    }
3372
3373    fn coerce(expr: Expr) -> Expr {
3374        let schema = expr_test_schema();
3375        let execution_props = ExecutionProps::new();
3376        let simplifier = ExprSimplifier::new(
3377            SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)),
3378        );
3379        simplifier.coerce(expr, schema.as_ref()).unwrap()
3380    }
3381
3382    fn simplify(expr: Expr) -> Expr {
3383        try_simplify(expr).unwrap()
3384    }
3385
3386    fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
3387        let schema = expr_test_schema();
3388        let execution_props = ExecutionProps::new();
3389        let simplifier = ExprSimplifier::new(
3390            SimplifyContext::new(&execution_props).with_schema(schema),
3391        );
3392        let (expr, count) = simplifier.simplify_with_cycle_count_transformed(expr)?;
3393        Ok((expr.data, count))
3394    }
3395
3396    fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
3397        try_simplify_with_cycle_count(expr).unwrap()
3398    }
3399
3400    fn simplify_with_guarantee(
3401        expr: Expr,
3402        guarantees: Vec<(Expr, NullableInterval)>,
3403    ) -> Expr {
3404        let schema = expr_test_schema();
3405        let execution_props = ExecutionProps::new();
3406        let simplifier = ExprSimplifier::new(
3407            SimplifyContext::new(&execution_props).with_schema(schema),
3408        )
3409        .with_guarantees(guarantees);
3410        simplifier.simplify(expr).unwrap()
3411    }
3412
3413    fn expr_test_schema() -> DFSchemaRef {
3414        Arc::new(
3415            DFSchema::from_unqualified_fields(
3416                vec![
3417                    Field::new("c1", DataType::Utf8, true),
3418                    Field::new("c2", DataType::Boolean, true),
3419                    Field::new("c3", DataType::Int64, true),
3420                    Field::new("c4", DataType::UInt32, true),
3421                    Field::new("c1_non_null", DataType::Utf8, false),
3422                    Field::new("c2_non_null", DataType::Boolean, false),
3423                    Field::new("c3_non_null", DataType::Int64, false),
3424                    Field::new("c4_non_null", DataType::UInt32, false),
3425                    Field::new("c5", DataType::FixedSizeBinary(3), true),
3426                ]
3427                .into(),
3428                HashMap::new(),
3429            )
3430            .unwrap(),
3431        )
3432    }
3433
3434    #[test]
3435    fn simplify_expr_null_comparison() {
3436        // x = null is always null
3437        assert_eq!(
3438            simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))),
3439            lit(ScalarValue::Boolean(None)),
3440        );
3441
3442        // null != null is always null
3443        assert_eq!(
3444            simplify(
3445                lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))
3446            ),
3447            lit(ScalarValue::Boolean(None)),
3448        );
3449
3450        // x != null is always null
3451        assert_eq!(
3452            simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))),
3453            lit(ScalarValue::Boolean(None)),
3454        );
3455
3456        // null = x is always null
3457        assert_eq!(
3458            simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))),
3459            lit(ScalarValue::Boolean(None)),
3460        );
3461    }
3462
3463    #[test]
3464    fn simplify_expr_is_not_null() {
3465        assert_eq!(
3466            simplify(Expr::IsNotNull(Box::new(col("c1")))),
3467            Expr::IsNotNull(Box::new(col("c1")))
3468        );
3469
3470        // 'c1_non_null IS NOT NULL' is always true
3471        assert_eq!(
3472            simplify(Expr::IsNotNull(Box::new(col("c1_non_null")))),
3473            lit(true)
3474        );
3475    }
3476
3477    #[test]
3478    fn simplify_expr_is_null() {
3479        assert_eq!(
3480            simplify(Expr::IsNull(Box::new(col("c1")))),
3481            Expr::IsNull(Box::new(col("c1")))
3482        );
3483
3484        // 'c1_non_null IS NULL' is always false
3485        assert_eq!(
3486            simplify(Expr::IsNull(Box::new(col("c1_non_null")))),
3487            lit(false)
3488        );
3489    }
3490
3491    #[test]
3492    fn simplify_expr_is_unknown() {
3493        assert_eq!(simplify(col("c2").is_unknown()), col("c2").is_unknown(),);
3494
3495        // 'c2_non_null is unknown' is always false
3496        assert_eq!(simplify(col("c2_non_null").is_unknown()), lit(false));
3497    }
3498
3499    #[test]
3500    fn simplify_expr_is_not_known() {
3501        assert_eq!(
3502            simplify(col("c2").is_not_unknown()),
3503            col("c2").is_not_unknown()
3504        );
3505
3506        // 'c2_non_null is not unknown' is always true
3507        assert_eq!(simplify(col("c2_non_null").is_not_unknown()), lit(true));
3508    }
3509
3510    #[test]
3511    fn simplify_expr_eq() {
3512        let schema = expr_test_schema();
3513        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3514
3515        // true = true -> true
3516        assert_eq!(simplify(lit(true).eq(lit(true))), lit(true));
3517
3518        // true = false -> false
3519        assert_eq!(simplify(lit(true).eq(lit(false))), lit(false),);
3520
3521        // c2 = true -> c2
3522        assert_eq!(simplify(col("c2").eq(lit(true))), col("c2"));
3523
3524        // c2 = false => !c2
3525        assert_eq!(simplify(col("c2").eq(lit(false))), col("c2").not(),);
3526    }
3527
3528    #[test]
3529    fn simplify_expr_eq_skip_nonboolean_type() {
3530        let schema = expr_test_schema();
3531
3532        // When one of the operand is not of boolean type, folding the
3533        // other boolean constant will change return type of
3534        // expression to non-boolean.
3535        //
3536        // Make sure c1 column to be used in tests is not boolean type
3537        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3538
3539        // don't fold c1 = foo
3540        assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),);
3541    }
3542
3543    #[test]
3544    fn simplify_expr_not_eq() {
3545        let schema = expr_test_schema();
3546
3547        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3548
3549        // c2 != true -> !c2
3550        assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),);
3551
3552        // c2 != false -> c2
3553        assert_eq!(simplify(col("c2").not_eq(lit(false))), col("c2"),);
3554
3555        // test constant
3556        assert_eq!(simplify(lit(true).not_eq(lit(true))), lit(false),);
3557
3558        assert_eq!(simplify(lit(true).not_eq(lit(false))), lit(true),);
3559    }
3560
3561    #[test]
3562    fn simplify_expr_not_eq_skip_nonboolean_type() {
3563        let schema = expr_test_schema();
3564
3565        // when one of the operand is not of boolean type, folding the
3566        // other boolean constant will change return type of
3567        // expression to non-boolean.
3568        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3569
3570        assert_eq!(
3571            simplify(col("c1").not_eq(lit("foo"))),
3572            col("c1").not_eq(lit("foo")),
3573        );
3574    }
3575
3576    #[test]
3577    fn simplify_expr_case_when_then_else() {
3578        // CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true
3579        // -->
3580        // CASE WHEN c2 THEN false ELSE c2
3581        // -->
3582        // false
3583        assert_eq!(
3584            simplify(Expr::Case(Case::new(
3585                None,
3586                vec![(
3587                    Box::new(col("c2_non_null").not_eq(lit(false))),
3588                    Box::new(lit("ok").eq(lit("not_ok"))),
3589                )],
3590                Some(Box::new(col("c2_non_null").eq(lit(true)))),
3591            ))),
3592            lit(false) // #1716
3593        );
3594
3595        // CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2
3596        // -->
3597        // CASE WHEN c2 THEN true ELSE c2
3598        // -->
3599        // c2
3600        //
3601        // Need to call simplify 2x due to
3602        // https://github.com/apache/datafusion/issues/1160
3603        assert_eq!(
3604            simplify(simplify(Expr::Case(Case::new(
3605                None,
3606                vec![(
3607                    Box::new(col("c2_non_null").not_eq(lit(false))),
3608                    Box::new(lit("ok").eq(lit("ok"))),
3609                )],
3610                Some(Box::new(col("c2_non_null").eq(lit(true)))),
3611            )))),
3612            col("c2_non_null")
3613        );
3614
3615        // CASE WHEN ISNULL(c2) THEN true ELSE c2
3616        // -->
3617        // ISNULL(c2) OR c2
3618        //
3619        // Need to call simplify 2x due to
3620        // https://github.com/apache/datafusion/issues/1160
3621        assert_eq!(
3622            simplify(simplify(Expr::Case(Case::new(
3623                None,
3624                vec![(Box::new(col("c2").is_null()), Box::new(lit(true)),)],
3625                Some(Box::new(col("c2"))),
3626            )))),
3627            col("c2")
3628                .is_null()
3629                .or(col("c2").is_not_null().and(col("c2")))
3630        );
3631
3632        // CASE WHEN c1 then true WHEN c2 then false ELSE true
3633        // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE)
3634        // --> c1 OR (NOT(c1) AND NOT(c2))
3635        // --> c1 OR NOT(c2)
3636        //
3637        // Need to call simplify 2x due to
3638        // https://github.com/apache/datafusion/issues/1160
3639        assert_eq!(
3640            simplify(simplify(Expr::Case(Case::new(
3641                None,
3642                vec![
3643                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
3644                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
3645                ],
3646                Some(Box::new(lit(true))),
3647            )))),
3648            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
3649        );
3650
3651        // CASE WHEN c1 then true WHEN c2 then true ELSE false
3652        // --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE)
3653        // --> c1 OR (NOT(c1) AND c2)
3654        // --> c1 OR c2
3655        //
3656        // Need to call simplify 2x due to
3657        // https://github.com/apache/datafusion/issues/1160
3658        assert_eq!(
3659            simplify(simplify(Expr::Case(Case::new(
3660                None,
3661                vec![
3662                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
3663                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
3664                ],
3665                Some(Box::new(lit(true))),
3666            )))),
3667            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
3668        );
3669
3670        // CASE WHEN c > 0 THEN true END AS c1
3671        assert_eq!(
3672            simplify(simplify(Expr::Case(Case::new(
3673                None,
3674                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3675                None,
3676            )))),
3677            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
3678                col("c3").gt(lit(0_i64)),
3679                lit(true)
3680            )
3681            .and(lit_bool_null()))
3682        );
3683
3684        // CASE WHEN c > 0 THEN true ELSE false END AS c1
3685        assert_eq!(
3686            simplify(simplify(Expr::Case(Case::new(
3687                None,
3688                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3689                Some(Box::new(lit(false))),
3690            )))),
3691            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
3692        );
3693    }
3694
3695    fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3696        Expr::BinaryExpr(BinaryExpr {
3697            left: Box::new(left.into()),
3698            op: Operator::IsDistinctFrom,
3699            right: Box::new(right.into()),
3700        })
3701    }
3702
3703    fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3704        Expr::BinaryExpr(BinaryExpr {
3705            left: Box::new(left.into()),
3706            op: Operator::IsNotDistinctFrom,
3707            right: Box::new(right.into()),
3708        })
3709    }
3710
3711    #[test]
3712    fn simplify_expr_bool_or() {
3713        // col || true is always true
3714        assert_eq!(simplify(col("c2").or(lit(true))), lit(true),);
3715
3716        // col || false is always col
3717        assert_eq!(simplify(col("c2").or(lit(false))), col("c2"),);
3718
3719        // true || null is always true
3720        assert_eq!(simplify(lit(true).or(lit_bool_null())), lit(true),);
3721
3722        // null || true is always true
3723        assert_eq!(simplify(lit_bool_null().or(lit(true))), lit(true),);
3724
3725        // false || null is always null
3726        assert_eq!(simplify(lit(false).or(lit_bool_null())), lit_bool_null(),);
3727
3728        // null || false is always null
3729        assert_eq!(simplify(lit_bool_null().or(lit(false))), lit_bool_null(),);
3730
3731        // ( c1 BETWEEN Int32(0) AND Int32(10) ) OR Boolean(NULL)
3732        // it can be either NULL or  TRUE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
3733        // and should not be rewritten
3734        let expr = col("c1").between(lit(0), lit(10));
3735        let expr = expr.or(lit_bool_null());
3736        let result = simplify(expr);
3737
3738        let expected_expr = or(
3739            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
3740            lit_bool_null(),
3741        );
3742        assert_eq!(expected_expr, result);
3743    }
3744
3745    #[test]
3746    fn simplify_inlist() {
3747        assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false));
3748        assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true));
3749
3750        // null in (...)  --> null
3751        assert_eq!(
3752            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], false)),
3753            lit_bool_null()
3754        );
3755
3756        // null not in (...)  --> null
3757        assert_eq!(
3758            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], true)),
3759            lit_bool_null()
3760        );
3761
3762        assert_eq!(
3763            simplify(in_list(col("c1"), vec![lit(1)], false)),
3764            col("c1").eq(lit(1))
3765        );
3766        assert_eq!(
3767            simplify(in_list(col("c1"), vec![lit(1)], true)),
3768            col("c1").not_eq(lit(1))
3769        );
3770
3771        // more complex expressions can be simplified if list contains
3772        // one element only
3773        assert_eq!(
3774            simplify(in_list(col("c1") * lit(10), vec![lit(2)], false)),
3775            (col("c1") * lit(10)).eq(lit(2))
3776        );
3777
3778        assert_eq!(
3779            simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)),
3780            col("c1").eq(lit(1)).or(col("c1").eq(lit(2)))
3781        );
3782        assert_eq!(
3783            simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)),
3784            col("c1").not_eq(lit(1)).and(col("c1").not_eq(lit(2)))
3785        );
3786
3787        let subquery = Arc::new(test_table_scan_with_name("test").unwrap());
3788        assert_eq!(
3789            simplify(in_list(
3790                col("c1"),
3791                vec![scalar_subquery(Arc::clone(&subquery))],
3792                false
3793            )),
3794            in_subquery(col("c1"), Arc::clone(&subquery))
3795        );
3796        assert_eq!(
3797            simplify(in_list(
3798                col("c1"),
3799                vec![scalar_subquery(Arc::clone(&subquery))],
3800                true
3801            )),
3802            not_in_subquery(col("c1"), subquery)
3803        );
3804
3805        let subquery1 =
3806            scalar_subquery(Arc::new(test_table_scan_with_name("test1").unwrap()));
3807        let subquery2 =
3808            scalar_subquery(Arc::new(test_table_scan_with_name("test2").unwrap()));
3809
3810        // c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery1> AND c1 != <subquery2>
3811        assert_eq!(
3812            simplify(in_list(
3813                col("c1"),
3814                vec![subquery1.clone(), subquery2.clone()],
3815                true
3816            )),
3817            col("c1")
3818                .not_eq(subquery1.clone())
3819                .and(col("c1").not_eq(subquery2.clone()))
3820        );
3821
3822        // c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery1> OR c1 == <subquery2>
3823        assert_eq!(
3824            simplify(in_list(
3825                col("c1"),
3826                vec![subquery1.clone(), subquery2.clone()],
3827                false
3828            )),
3829            col("c1").eq(subquery1).or(col("c1").eq(subquery2))
3830        );
3831
3832        // 1. c1 IN (1,2,3,4) AND c1 IN (5,6,7,8) -> false
3833        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
3834            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false),
3835        );
3836        assert_eq!(simplify(expr), lit(false));
3837
3838        // 2. c1 IN (1,2,3,4) AND c1 IN (4,5,6,7) -> c1 = 4
3839        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
3840            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], false),
3841        );
3842        assert_eq!(simplify(expr), col("c1").eq(lit(4)));
3843
3844        // 3. c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) -> true
3845        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
3846            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
3847        );
3848        assert_eq!(simplify(expr), lit(true));
3849
3850        // 3.5 c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (4, 5, 6, 7) -> c1 != 4 (4 overlaps)
3851        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
3852            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
3853        );
3854        assert_eq!(simplify(expr), col("c1").not_eq(lit(4)));
3855
3856        // 4. c1 NOT IN (1,2,3,4) AND c1 NOT IN (4,5,6,7) -> c1 NOT IN (1,2,3,4,5,6,7)
3857        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
3858            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
3859        );
3860        assert_eq!(
3861            simplify(expr),
3862            in_list(
3863                col("c1"),
3864                vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6), lit(7)],
3865                true
3866            )
3867        );
3868
3869        // 5. c1 IN (1,2,3,4) OR c1 IN (2,3,4,5) -> c1 IN (1,2,3,4,5)
3870        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).or(
3871            in_list(col("c1"), vec![lit(2), lit(3), lit(4), lit(5)], false),
3872        );
3873        assert_eq!(
3874            simplify(expr),
3875            in_list(
3876                col("c1"),
3877                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
3878                false
3879            )
3880        );
3881
3882        // 6. c1 IN (1,2,3) AND c1 NOT INT (1,2,3,4,5) -> false
3883        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3)], false).and(in_list(
3884            col("c1"),
3885            vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
3886            true,
3887        ));
3888        assert_eq!(simplify(expr), lit(false));
3889
3890        // 7. c1 NOT IN (1,2,3,4) AND c1 IN (1,2,3,4,5) -> c1 = 5
3891        let expr =
3892            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(in_list(
3893                col("c1"),
3894                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
3895                false,
3896            ));
3897        assert_eq!(simplify(expr), col("c1").eq(lit(5)));
3898
3899        // 8. c1 IN (1,2,3,4) AND c1 NOT IN (5,6,7,8) -> c1 IN (1,2,3,4)
3900        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
3901            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
3902        );
3903        assert_eq!(
3904            simplify(expr),
3905            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false)
3906        );
3907
3908        // inlist with more than two expressions
3909        // c1 IN (1,2,3,4,5,6) AND c1 IN (1,3,5,6) AND c1 IN (3,6) -> c1 = 3 OR c1 = 6
3910        let expr = in_list(
3911            col("c1"),
3912            vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6)],
3913            false,
3914        )
3915        .and(in_list(
3916            col("c1"),
3917            vec![lit(1), lit(3), lit(5), lit(6)],
3918            false,
3919        ))
3920        .and(in_list(col("c1"), vec![lit(3), lit(6)], false));
3921        assert_eq!(
3922            simplify(expr),
3923            col("c1").eq(lit(3)).or(col("c1").eq(lit(6)))
3924        );
3925
3926        // c1 NOT IN (1,2,3,4) AND c1 IN (5,6,7,8) AND c1 NOT IN (3,4,5,6) AND c1 IN (8,9,10) -> c1 = 8
3927        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
3928            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false)
3929                .and(in_list(
3930                    col("c1"),
3931                    vec![lit(3), lit(4), lit(5), lit(6)],
3932                    true,
3933                ))
3934                .and(in_list(col("c1"), vec![lit(8), lit(9), lit(10)], false)),
3935        );
3936        assert_eq!(simplify(expr), col("c1").eq(lit(8)));
3937
3938        // Contains non-InList expression
3939        // c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9) -> c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9)
3940        let expr =
3941            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(col("c1")
3942                .not_eq(lit(5))
3943                .or(in_list(
3944                    col("c1"),
3945                    vec![lit(6), lit(7), lit(8), lit(9)],
3946                    true,
3947                )));
3948        // TODO: Further simplify this expression
3949        // https://github.com/apache/datafusion/issues/8970
3950        // assert_eq!(simplify(expr.clone()), lit(true));
3951        assert_eq!(simplify(expr.clone()), expr);
3952    }
3953
3954    #[test]
3955    fn simplify_large_or() {
3956        let expr = (0..5)
3957            .map(|i| col("c1").eq(lit(i)))
3958            .fold(lit(false), |acc, e| acc.or(e));
3959        assert_eq!(
3960            simplify(expr),
3961            in_list(col("c1"), (0..5).map(lit).collect(), false),
3962        );
3963    }
3964
3965    #[test]
3966    fn simplify_expr_bool_and() {
3967        // col & true is always col
3968        assert_eq!(simplify(col("c2").and(lit(true))), col("c2"),);
3969        // col & false is always false
3970        assert_eq!(simplify(col("c2").and(lit(false))), lit(false),);
3971
3972        // true && null is always null
3973        assert_eq!(simplify(lit(true).and(lit_bool_null())), lit_bool_null(),);
3974
3975        // null && true is always null
3976        assert_eq!(simplify(lit_bool_null().and(lit(true))), lit_bool_null(),);
3977
3978        // false && null is always false
3979        assert_eq!(simplify(lit(false).and(lit_bool_null())), lit(false),);
3980
3981        // null && false is always false
3982        assert_eq!(simplify(lit_bool_null().and(lit(false))), lit(false),);
3983
3984        // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
3985        // it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
3986        // and the Boolean(NULL) should remain
3987        let expr = col("c1").between(lit(0), lit(10));
3988        let expr = expr.and(lit_bool_null());
3989        let result = simplify(expr);
3990
3991        let expected_expr = and(
3992            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
3993            lit_bool_null(),
3994        );
3995        assert_eq!(expected_expr, result);
3996    }
3997
3998    #[test]
3999    fn simplify_expr_between() {
4000        // c2 between 3 and 4 is c2 >= 3 and c2 <= 4
4001        let expr = col("c2").between(lit(3), lit(4));
4002        assert_eq!(
4003            simplify(expr),
4004            and(col("c2").gt_eq(lit(3)), col("c2").lt_eq(lit(4)))
4005        );
4006
4007        // c2 not between 3 and 4 is c2 < 3 or c2 > 4
4008        let expr = col("c2").not_between(lit(3), lit(4));
4009        assert_eq!(
4010            simplify(expr),
4011            or(col("c2").lt(lit(3)), col("c2").gt(lit(4)))
4012        );
4013    }
4014
4015    #[test]
4016    fn test_like_and_ilike() {
4017        let null = lit(ScalarValue::Utf8(None));
4018
4019        // expr [NOT] [I]LIKE NULL
4020        let expr = col("c1").like(null.clone());
4021        assert_eq!(simplify(expr), lit_bool_null());
4022
4023        let expr = col("c1").not_like(null.clone());
4024        assert_eq!(simplify(expr), lit_bool_null());
4025
4026        let expr = col("c1").ilike(null.clone());
4027        assert_eq!(simplify(expr), lit_bool_null());
4028
4029        let expr = col("c1").not_ilike(null.clone());
4030        assert_eq!(simplify(expr), lit_bool_null());
4031
4032        // expr [NOT] [I]LIKE '%'
4033        let expr = col("c1").like(lit("%"));
4034        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4035
4036        let expr = col("c1").not_like(lit("%"));
4037        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4038
4039        let expr = col("c1").ilike(lit("%"));
4040        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4041
4042        let expr = col("c1").not_ilike(lit("%"));
4043        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4044
4045        // expr [NOT] [I]LIKE '%%'
4046        let expr = col("c1").like(lit("%%"));
4047        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4048
4049        let expr = col("c1").not_like(lit("%%"));
4050        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4051
4052        let expr = col("c1").ilike(lit("%%"));
4053        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
4054
4055        let expr = col("c1").not_ilike(lit("%%"));
4056        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
4057
4058        // not_null_expr [NOT] [I]LIKE '%'
4059        let expr = col("c1_non_null").like(lit("%"));
4060        assert_eq!(simplify(expr), lit(true));
4061
4062        let expr = col("c1_non_null").not_like(lit("%"));
4063        assert_eq!(simplify(expr), lit(false));
4064
4065        let expr = col("c1_non_null").ilike(lit("%"));
4066        assert_eq!(simplify(expr), lit(true));
4067
4068        let expr = col("c1_non_null").not_ilike(lit("%"));
4069        assert_eq!(simplify(expr), lit(false));
4070
4071        // not_null_expr [NOT] [I]LIKE '%%'
4072        let expr = col("c1_non_null").like(lit("%%"));
4073        assert_eq!(simplify(expr), lit(true));
4074
4075        let expr = col("c1_non_null").not_like(lit("%%"));
4076        assert_eq!(simplify(expr), lit(false));
4077
4078        let expr = col("c1_non_null").ilike(lit("%%"));
4079        assert_eq!(simplify(expr), lit(true));
4080
4081        let expr = col("c1_non_null").not_ilike(lit("%%"));
4082        assert_eq!(simplify(expr), lit(false));
4083
4084        // null_constant [NOT] [I]LIKE '%'
4085        let expr = null.clone().like(lit("%"));
4086        assert_eq!(simplify(expr), lit_bool_null());
4087
4088        let expr = null.clone().not_like(lit("%"));
4089        assert_eq!(simplify(expr), lit_bool_null());
4090
4091        let expr = null.clone().ilike(lit("%"));
4092        assert_eq!(simplify(expr), lit_bool_null());
4093
4094        let expr = null.clone().not_ilike(lit("%"));
4095        assert_eq!(simplify(expr), lit_bool_null());
4096
4097        // null_constant [NOT] [I]LIKE '%%'
4098        let expr = null.clone().like(lit("%%"));
4099        assert_eq!(simplify(expr), lit_bool_null());
4100
4101        let expr = null.clone().not_like(lit("%%"));
4102        assert_eq!(simplify(expr), lit_bool_null());
4103
4104        let expr = null.clone().ilike(lit("%%"));
4105        assert_eq!(simplify(expr), lit_bool_null());
4106
4107        let expr = null.clone().not_ilike(lit("%%"));
4108        assert_eq!(simplify(expr), lit_bool_null());
4109
4110        // null_constant [NOT] [I]LIKE 'a%'
4111        let expr = null.clone().like(lit("a%"));
4112        assert_eq!(simplify(expr), lit_bool_null());
4113
4114        let expr = null.clone().not_like(lit("a%"));
4115        assert_eq!(simplify(expr), lit_bool_null());
4116
4117        let expr = null.clone().ilike(lit("a%"));
4118        assert_eq!(simplify(expr), lit_bool_null());
4119
4120        let expr = null.clone().not_ilike(lit("a%"));
4121        assert_eq!(simplify(expr), lit_bool_null());
4122
4123        // expr [NOT] [I]LIKE with pattern without wildcards
4124        let expr = col("c1").like(lit("a"));
4125        assert_eq!(simplify(expr), col("c1").eq(lit("a")));
4126        let expr = col("c1").not_like(lit("a"));
4127        assert_eq!(simplify(expr), col("c1").not_eq(lit("a")));
4128        let expr = col("c1").like(lit("a_"));
4129        assert_eq!(simplify(expr), col("c1").like(lit("a_")));
4130        let expr = col("c1").not_like(lit("a_"));
4131        assert_eq!(simplify(expr), col("c1").not_like(lit("a_")));
4132
4133        let expr = col("c1").ilike(lit("a"));
4134        assert_eq!(simplify(expr), col("c1").ilike(lit("a")));
4135        let expr = col("c1").not_ilike(lit("a"));
4136        assert_eq!(simplify(expr), col("c1").not_ilike(lit("a")));
4137    }
4138
4139    #[test]
4140    fn test_simplify_with_guarantee() {
4141        // (c3 >= 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b")))
4142        let expr_x = col("c3").gt(lit(3_i64));
4143        let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32));
4144        let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true);
4145        let expr = expr_x.clone().and(expr_y.or(expr_z));
4146
4147        // All guaranteed null
4148        let guarantees = vec![
4149            (col("c3"), NullableInterval::from(ScalarValue::Int64(None))),
4150            (col("c4"), NullableInterval::from(ScalarValue::UInt32(None))),
4151            (col("c1"), NullableInterval::from(ScalarValue::Utf8(None))),
4152        ];
4153
4154        let output = simplify_with_guarantee(expr.clone(), guarantees);
4155        assert_eq!(output, lit_bool_null());
4156
4157        // All guaranteed false
4158        let guarantees = vec![
4159            (
4160                col("c3"),
4161                NullableInterval::NotNull {
4162                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
4163                },
4164            ),
4165            (
4166                col("c4"),
4167                NullableInterval::from(ScalarValue::UInt32(Some(9))),
4168            ),
4169            (col("c1"), NullableInterval::from(ScalarValue::from("a"))),
4170        ];
4171        let output = simplify_with_guarantee(expr.clone(), guarantees);
4172        assert_eq!(output, lit(false));
4173
4174        // Guaranteed false or null -> no change.
4175        let guarantees = vec![
4176            (
4177                col("c3"),
4178                NullableInterval::MaybeNull {
4179                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
4180                },
4181            ),
4182            (
4183                col("c4"),
4184                NullableInterval::MaybeNull {
4185                    values: Interval::make(Some(9_u32), Some(9_u32)).unwrap(),
4186                },
4187            ),
4188            (
4189                col("c1"),
4190                NullableInterval::NotNull {
4191                    values: Interval::try_new(
4192                        ScalarValue::from("d"),
4193                        ScalarValue::from("f"),
4194                    )
4195                    .unwrap(),
4196                },
4197            ),
4198        ];
4199        let output = simplify_with_guarantee(expr.clone(), guarantees);
4200        assert_eq!(&output, &expr_x);
4201
4202        // Sufficient true guarantees
4203        let guarantees = vec![
4204            (
4205                col("c3"),
4206                NullableInterval::from(ScalarValue::Int64(Some(9))),
4207            ),
4208            (
4209                col("c4"),
4210                NullableInterval::from(ScalarValue::UInt32(Some(3))),
4211            ),
4212        ];
4213        let output = simplify_with_guarantee(expr.clone(), guarantees);
4214        assert_eq!(output, lit(true));
4215
4216        // Only partially simplify
4217        let guarantees = vec![(
4218            col("c4"),
4219            NullableInterval::from(ScalarValue::UInt32(Some(3))),
4220        )];
4221        let output = simplify_with_guarantee(expr, guarantees);
4222        assert_eq!(&output, &expr_x);
4223    }
4224
4225    #[test]
4226    fn test_expression_partial_simplify_1() {
4227        // (1 + 2) + (4 / 0) -> 3 + (4 / 0)
4228        let expr = (lit(1) + lit(2)) + (lit(4) / lit(0));
4229        let expected = (lit(3)) + (lit(4) / lit(0));
4230
4231        assert_eq!(simplify(expr), expected);
4232    }
4233
4234    #[test]
4235    fn test_expression_partial_simplify_2() {
4236        // (1 > 2) and (4 / 0) -> false
4237        let expr = (lit(1).gt(lit(2))).and(lit(4) / lit(0));
4238        let expected = lit(false);
4239
4240        assert_eq!(simplify(expr), expected);
4241    }
4242
4243    #[test]
4244    fn test_simplify_cycles() {
4245        // TRUE
4246        let expr = lit(true);
4247        let expected = lit(true);
4248        let (expr, num_iter) = simplify_with_cycle_count(expr);
4249        assert_eq!(expr, expected);
4250        assert_eq!(num_iter, 1);
4251
4252        // (true != NULL) OR (5 > 10)
4253        let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
4254        let expected = lit_bool_null();
4255        let (expr, num_iter) = simplify_with_cycle_count(expr);
4256        assert_eq!(expr, expected);
4257        assert_eq!(num_iter, 2);
4258
4259        // NOTE: this currently does not simplify
4260        // (((c4 - 10) + 10) *100) / 100
4261        let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
4262        let expected = expr.clone();
4263        let (expr, num_iter) = simplify_with_cycle_count(expr);
4264        assert_eq!(expr, expected);
4265        assert_eq!(num_iter, 1);
4266
4267        // ((c4<1 or c3<2) and c3_non_null<3) and false
4268        let expr = col("c4")
4269            .lt(lit(1))
4270            .or(col("c3").lt(lit(2)))
4271            .and(col("c3_non_null").lt(lit(3)))
4272            .and(lit(false));
4273        let expected = lit(false);
4274        let (expr, num_iter) = simplify_with_cycle_count(expr);
4275        assert_eq!(expr, expected);
4276        assert_eq!(num_iter, 2);
4277    }
4278
4279    fn boolean_test_schema() -> DFSchemaRef {
4280        Schema::new(vec![
4281            Field::new("A", DataType::Boolean, false),
4282            Field::new("B", DataType::Boolean, false),
4283            Field::new("C", DataType::Boolean, false),
4284            Field::new("D", DataType::Boolean, false),
4285        ])
4286        .to_dfschema_ref()
4287        .unwrap()
4288    }
4289
4290    #[test]
4291    fn simplify_common_factor_conjunction_in_disjunction() {
4292        let props = ExecutionProps::new();
4293        let schema = boolean_test_schema();
4294        let simplifier =
4295            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));
4296
4297        let a = || col("A");
4298        let b = || col("B");
4299        let c = || col("C");
4300        let d = || col("D");
4301
4302        // (A AND B) OR (A AND C) -> A AND (B OR C)
4303        let expr = a().and(b()).or(a().and(c()));
4304        let expected = a().and(b().or(c()));
4305
4306        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4307
4308        // (A AND B) OR (A AND C) OR (A AND D) -> A AND (B OR C OR D)
4309        let expr = a().and(b()).or(a().and(c())).or(a().and(d()));
4310        let expected = a().and(b().or(c()).or(d()));
4311        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4312
4313        // A OR (B AND C AND A) -> A
4314        let expr = a().or(b().and(c().and(a())));
4315        let expected = a();
4316        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4317    }
4318
4319    #[test]
4320    fn test_simplify_udaf() {
4321        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());
4322        let aggregate_function_expr =
4323            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4324                udaf.into(),
4325                vec![],
4326                false,
4327                None,
4328                None,
4329                None,
4330            ));
4331
4332        let expected = col("result_column");
4333        assert_eq!(simplify(aggregate_function_expr), expected);
4334
4335        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify());
4336        let aggregate_function_expr =
4337            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4338                udaf.into(),
4339                vec![],
4340                false,
4341                None,
4342                None,
4343                None,
4344            ));
4345
4346        let expected = aggregate_function_expr.clone();
4347        assert_eq!(simplify(aggregate_function_expr), expected);
4348    }
4349
    /// A Mock UDAF which defines `simplify` to be used in tests
    /// related to UDAF simplification
    #[derive(Debug, Clone)]
    struct SimplifyMockUdaf {
        /// When `true`, `AggregateUDFImpl::simplify` returns a rewrite
        /// closure; when `false`, it returns `None`.
        simplify: bool,
    }
4356
4357    impl SimplifyMockUdaf {
4358        /// make simplify method return new expression
4359        fn new_with_simplify() -> Self {
4360            Self { simplify: true }
4361        }
4362        /// make simplify method return no change
4363        fn new_without_simplify() -> Self {
4364            Self { simplify: false }
4365        }
4366    }
4367
4368    impl AggregateUDFImpl for SimplifyMockUdaf {
4369        fn as_any(&self) -> &dyn std::any::Any {
4370            self
4371        }
4372
4373        fn name(&self) -> &str {
4374            "mock_simplify"
4375        }
4376
4377        fn signature(&self) -> &Signature {
4378            unimplemented!()
4379        }
4380
4381        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4382            unimplemented!("not needed for tests")
4383        }
4384
4385        fn accumulator(
4386            &self,
4387            _acc_args: AccumulatorArgs,
4388        ) -> Result<Box<dyn Accumulator>> {
4389            unimplemented!("not needed for tests")
4390        }
4391
4392        fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
4393            unimplemented!("not needed for testing")
4394        }
4395
4396        fn create_groups_accumulator(
4397            &self,
4398            _args: AccumulatorArgs,
4399        ) -> Result<Box<dyn GroupsAccumulator>> {
4400            unimplemented!("not needed for testing")
4401        }
4402
4403        fn simplify(&self) -> Option<AggregateFunctionSimplification> {
4404            if self.simplify {
4405                Some(Box::new(|_, _| Ok(col("result_column"))))
4406            } else {
4407                None
4408            }
4409        }
4410    }
4411
4412    #[test]
4413    fn test_simplify_udwf() {
4414        let udwf = WindowFunctionDefinition::WindowUDF(
4415            WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(),
4416        );
4417        let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![]));
4418
4419        let expected = col("result_column");
4420        assert_eq!(simplify(window_function_expr), expected);
4421
4422        let udwf = WindowFunctionDefinition::WindowUDF(
4423            WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(),
4424        );
4425        let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![]));
4426
4427        let expected = window_function_expr.clone();
4428        assert_eq!(simplify(window_function_expr), expected);
4429    }
4430
    /// A Mock UDWF which defines `simplify` to be used in tests
    /// related to UDWF simplification
    #[derive(Debug, Clone)]
    struct SimplifyMockUdwf {
        /// When `true`, `WindowUDFImpl::simplify` returns a rewrite closure;
        /// when `false`, it returns `None`.
        simplify: bool,
    }
4437
4438    impl SimplifyMockUdwf {
4439        /// make simplify method return new expression
4440        fn new_with_simplify() -> Self {
4441            Self { simplify: true }
4442        }
4443        /// make simplify method return no change
4444        fn new_without_simplify() -> Self {
4445            Self { simplify: false }
4446        }
4447    }
4448
4449    impl WindowUDFImpl for SimplifyMockUdwf {
4450        fn as_any(&self) -> &dyn std::any::Any {
4451            self
4452        }
4453
4454        fn name(&self) -> &str {
4455            "mock_simplify"
4456        }
4457
4458        fn signature(&self) -> &Signature {
4459            unimplemented!()
4460        }
4461
4462        fn simplify(&self) -> Option<WindowFunctionSimplification> {
4463            if self.simplify {
4464                Some(Box::new(|_, _| Ok(col("result_column"))))
4465            } else {
4466                None
4467            }
4468        }
4469
4470        fn partition_evaluator(
4471            &self,
4472            _partition_evaluator_args: PartitionEvaluatorArgs,
4473        ) -> Result<Box<dyn PartitionEvaluator>> {
4474            unimplemented!("not needed for tests")
4475        }
4476
4477        fn field(&self, _field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
4478            unimplemented!("not needed for tests")
4479        }
4480    }
    /// A mock scalar UDF used by `test_optimize_volatile_conditions`; its
    /// signature is built with `Volatility::Volatile` in `VolatileUdf::new`.
    #[derive(Debug)]
    struct VolatileUdf {
        /// Signature handed out by `ScalarUDFImpl::signature`.
        signature: Signature,
    }
4485
4486    impl VolatileUdf {
4487        pub fn new() -> Self {
4488            Self {
4489                signature: Signature::exact(vec![], Volatility::Volatile),
4490            }
4491        }
4492    }
    /// `ScalarUDFImpl` for the volatile mock. `invoke_with_args` always
    /// panics, so this type is only usable where the function is never
    /// actually executed.
    impl ScalarUDFImpl for VolatileUdf {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn name(&self) -> &str {
            "VolatileUdf"
        }

        /// Returns the signature stored at construction time.
        fn signature(&self) -> &Signature {
            &self.signature
        }

        /// Always reports `Int16`, regardless of the argument types.
        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int16)
        }

        /// Not implemented: panics if the function is ever invoked.
        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            panic!("dummy - not implemented")
        }
    }
4514
4515    #[test]
4516    fn test_optimize_volatile_conditions() {
4517        let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new()));
4518        let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
4519        {
4520            let expr = rand
4521                .clone()
4522                .eq(lit(0))
4523                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4524
4525            assert_eq!(simplify(expr.clone()), expr);
4526        }
4527
4528        {
4529            let expr = col("column1")
4530                .eq(lit(2))
4531                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4532
4533            assert_eq!(simplify(expr), col("column1").eq(lit(2)));
4534        }
4535
4536        {
4537            let expr = (col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))).or(col(
4538                "column1",
4539            )
4540            .eq(lit(2))
4541            .and(rand.clone().eq(lit(0))));
4542
4543            assert_eq!(
4544                simplify(expr),
4545                col("column1")
4546                    .eq(lit(2))
4547                    .and((rand.clone().eq(lit(0))).or(rand.clone().eq(lit(0))))
4548            );
4549        }
4550    }
4551
4552    #[test]
4553    fn simplify_fixed_size_binary_eq_lit() {
4554        let bytes = [1u8, 2, 3].as_slice();
4555
4556        // The expression starts simple.
4557        let expr = col("c5").eq(lit(bytes));
4558
4559        // The type coercer introduces a cast.
4560        let coerced = coerce(expr.clone());
4561        let schema = expr_test_schema();
4562        assert_eq!(
4563            coerced,
4564            col("c5")
4565                .cast_to(&DataType::Binary, schema.as_ref())
4566                .unwrap()
4567                .eq(lit(bytes))
4568        );
4569
4570        // The simplifier removes the cast.
4571        assert_eq!(
4572            simplify(coerced),
4573            col("c5").eq(Expr::Literal(
4574                ScalarValue::FixedSizeBinary(3, Some(bytes.to_vec()),),
4575                None
4576            ))
4577        );
4578    }
4579
4580    fn if_not_null(expr: Expr, then: bool) -> Expr {
4581        Expr::Case(Case {
4582            expr: Some(expr.is_not_null().into()),
4583            when_then_expr: vec![(lit(true).into(), lit(then).into())],
4584            else_expr: None,
4585        })
4586    }
4587}