datafusion_optimizer/simplify_expressions/
expr_simplifier.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression simplification API
19
20use std::borrow::Cow;
21use std::collections::HashSet;
22use std::ops::Not;
23
24use arrow::{
25    array::{new_null_array, AsArray},
26    datatypes::{DataType, Field, Schema},
27    record_batch::RecordBatch,
28};
29
30use datafusion_common::{
31    cast::{as_large_list_array, as_list_array},
32    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
33};
34use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
35use datafusion_expr::simplify::ExprSimplifyResult;
36use datafusion_expr::{
37    and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
38    WindowFunctionDefinition,
39};
40use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
41use datafusion_expr::{
42    expr::{InList, InSubquery, WindowFunction},
43    utils::{iter_conjunction, iter_conjunction_owned},
44};
45use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
46
47use super::inlist_simplifier::ShortenInListSimplifier;
48use super::utils::*;
49use crate::analyzer::type_coercion::TypeCoercionRewriter;
50use crate::simplify_expressions::guarantees::GuaranteeRewriter;
51use crate::simplify_expressions::regex::simplify_regex_expr;
52use crate::simplify_expressions::SimplifyInfo;
53use indexmap::IndexSet;
54use regex::Regex;
55
/// Entry point of the expression simplification API.
///
/// Provides simplification information based on DFSchema and
/// [`ExecutionProps`]. This is the default implementation used by DataFusion
///
/// For example:
/// ```
/// use arrow::datatypes::{Schema, Field, DataType};
/// use datafusion_expr::{col, lit};
/// use datafusion_common::{DataFusionError, ToDFSchema};
/// use datafusion_expr::execution_props::ExecutionProps;
/// use datafusion_expr::simplify::SimplifyContext;
/// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
///
/// // Create the schema
/// let schema = Schema::new(vec![
///     Field::new("b", DataType::Int64, false),
///   ])
///   .to_dfschema_ref().unwrap();
///
/// // Create the simplifier
/// let props = ExecutionProps::new();
/// let context = SimplifyContext::new(&props)
///    .with_schema(schema);
/// let simplifier = ExprSimplifier::new(context);
///
/// // Use the simplifier
///
/// // b < 2 or (1 > 3)
/// let expr = col("b").lt(lit(2)).or(lit(1).gt(lit(3)));
///
/// // b < 2
/// let simplified = simplifier.simplify(expr).unwrap();
/// assert_eq!(simplified, col("b").lt(lit(2)));
/// ```
pub struct ExprSimplifier<S> {
    /// Source of the information the simplifier needs about expressions
    /// (the `SimplifyInfo` implementation supplied to [ExprSimplifier::new()]).
    info: S,
    /// Guarantees about the values of columns. This is provided by the user
    /// in [ExprSimplifier::with_guarantees()].
    guarantees: Vec<(Expr, NullableInterval)>,
    /// Should expressions be canonicalized before simplification? Defaults to
    /// true. Changed with [ExprSimplifier::with_canonicalize()].
    canonicalize: bool,
    /// Maximum number of simplifier cycles. Defaults to
    /// [DEFAULT_MAX_SIMPLIFIER_CYCLES]; changed with
    /// [ExprSimplifier::with_max_cycles()].
    max_simplifier_cycles: u32,
}
102
/// Size threshold related to inlining `IN` lists.
///
/// NOTE(review): this constant is not referenced in this part of the file;
/// presumably it is the largest `IN` list the inlist rules will expand into
/// individual comparisons — confirm against its users.
pub const THRESHOLD_INLINE_INLIST: usize = 3;
/// Default upper bound on the number of simplification cycles.
///
/// [`ExprSimplifier::new`] starts with this value; it can be overridden with
/// [`ExprSimplifier::with_max_cycles`].
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3;
105
106impl<S: SimplifyInfo> ExprSimplifier<S> {
107    /// Create a new `ExprSimplifier` with the given `info` such as an
108    /// instance of [`SimplifyContext`]. See
109    /// [`simplify`](Self::simplify) for an example.
110    ///
111    /// [`SimplifyContext`]: datafusion_expr::simplify::SimplifyContext
112    pub fn new(info: S) -> Self {
113        Self {
114            info,
115            guarantees: vec![],
116            canonicalize: true,
117            max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES,
118        }
119    }
120
121    /// Simplifies this [`Expr`] as much as possible, evaluating
122    /// constants and applying algebraic simplifications.
123    ///
124    /// The types of the expression must match what operators expect,
125    /// or else an error may occur trying to evaluate. See
126    /// [`coerce`](Self::coerce) for a function to help.
127    ///
128    /// # Example:
129    ///
130    /// `b > 2 AND b > 2`
131    ///
132    /// can be written to
133    ///
134    /// `b > 2`
135    ///
136    /// ```
137    /// use arrow::datatypes::DataType;
138    /// use datafusion_expr::{col, lit, Expr};
139    /// use datafusion_common::Result;
140    /// use datafusion_expr::execution_props::ExecutionProps;
141    /// use datafusion_expr::simplify::SimplifyContext;
142    /// use datafusion_expr::simplify::SimplifyInfo;
143    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
144    /// use datafusion_common::DFSchema;
145    /// use std::sync::Arc;
146    ///
147    /// /// Simple implementation that provides `Simplifier` the information it needs
148    /// /// See SimplifyContext for a structure that does this.
149    /// #[derive(Default)]
150    /// struct Info {
151    ///   execution_props: ExecutionProps,
152    /// };
153    ///
154    /// impl SimplifyInfo for Info {
155    ///   fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
156    ///     Ok(false)
157    ///   }
158    ///   fn nullable(&self, expr: &Expr) -> Result<bool> {
159    ///     Ok(true)
160    ///   }
161    ///   fn execution_props(&self) -> &ExecutionProps {
162    ///     &self.execution_props
163    ///   }
164    ///   fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
165    ///     Ok(DataType::Int32)
166    ///   }
167    /// }
168    ///
169    /// // Create the simplifier
170    /// let simplifier = ExprSimplifier::new(Info::default());
171    ///
172    /// // b < 2
173    /// let b_lt_2 = col("b").gt(lit(2));
174    ///
175    /// // (b < 2) OR (b < 2)
176    /// let expr = b_lt_2.clone().or(b_lt_2.clone());
177    ///
178    /// // (b < 2) OR (b < 2) --> (b < 2)
179    /// let expr = simplifier.simplify(expr).unwrap();
180    /// assert_eq!(expr, b_lt_2);
181    /// ```
182    pub fn simplify(&self, expr: Expr) -> Result<Expr> {
183        Ok(self.simplify_with_cycle_count(expr)?.0)
184    }
185
186    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
187    /// constants and applying algebraic simplifications. Additionally returns a `u32`
188    /// representing the number of simplification cycles performed, which can be useful for testing
189    /// optimizations.
190    ///
191    /// See [Self::simplify] for details and usage examples.
192    ///
193    pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
194        let mut simplifier = Simplifier::new(&self.info);
195        let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
196        let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
197        let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);
198
199        if self.canonicalize {
200            expr = expr.rewrite(&mut Canonicalizer::new()).data()?
201        }
202
203        // Evaluating constants can enable new simplifications and
204        // simplifications can enable new constant evaluation
205        // see `Self::with_max_cycles`
206        let mut num_cycles = 0;
207        loop {
208            let Transformed {
209                data, transformed, ..
210            } = expr
211                .rewrite(&mut const_evaluator)?
212                .transform_data(|expr| expr.rewrite(&mut simplifier))?
213                .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
214            expr = data;
215            num_cycles += 1;
216            if !transformed || num_cycles >= self.max_simplifier_cycles {
217                break;
218            }
219        }
220        // shorten inlist should be started after other inlist rules are applied
221        expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
222        Ok((expr, num_cycles))
223    }
224
225    /// Apply type coercion to an [`Expr`] so that it can be
226    /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr).
227    ///
228    /// See the [type coercion module](datafusion_expr::type_coercion)
229    /// documentation for more details on type coercion
230    pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> {
231        let mut expr_rewrite = TypeCoercionRewriter { schema };
232        expr.rewrite(&mut expr_rewrite).data()
233    }
234
235    /// Input guarantees about the values of columns.
236    ///
237    /// The guarantees can simplify expressions. For example, if a column `x` is
238    /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the
239    /// literal `true`.
240    ///
241    /// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`,
242    /// where the [Expr] is a column reference and the [NullableInterval]
243    /// is an interval representing the known possible values of that column.
244    ///
245    /// ```rust
246    /// use arrow::datatypes::{DataType, Field, Schema};
247    /// use datafusion_expr::{col, lit, Expr};
248    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
249    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
250    /// use datafusion_expr::execution_props::ExecutionProps;
251    /// use datafusion_expr::simplify::SimplifyContext;
252    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
253    ///
254    /// let schema = Schema::new(vec![
255    ///   Field::new("x", DataType::Int64, false),
256    ///   Field::new("y", DataType::UInt32, false),
257    ///   Field::new("z", DataType::Int64, false),
258    ///   ])
259    ///   .to_dfschema_ref().unwrap();
260    ///
261    /// // Create the simplifier
262    /// let props = ExecutionProps::new();
263    /// let context = SimplifyContext::new(&props)
264    ///    .with_schema(schema);
265    ///
266    /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5)
267    /// let expr_x = col("x").gt_eq(lit(3_i64));
268    /// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32));
269    /// let expr_z = col("z").gt(lit(5_i64));
270    /// let expr = expr_x.and(expr_y).and(expr_z.clone());
271    ///
272    /// let guarantees = vec![
273    ///    // x ∈ [3, 5]
274    ///    (
275    ///        col("x"),
276    ///        NullableInterval::NotNull {
277    ///            values: Interval::make(Some(3_i64), Some(5_i64)).unwrap()
278    ///        }
279    ///    ),
280    ///    // y = 3
281    ///    (col("y"), NullableInterval::from(ScalarValue::UInt32(Some(3)))),
282    /// ];
283    /// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees);
284    /// let output = simplifier.simplify(expr).unwrap();
285    /// // Expression becomes: true AND true AND (z > 5), which simplifies to
286    /// // z > 5.
287    /// assert_eq!(output, expr_z);
288    /// ```
289    pub fn with_guarantees(mut self, guarantees: Vec<(Expr, NullableInterval)>) -> Self {
290        self.guarantees = guarantees;
291        self
292    }
293
294    /// Should `Canonicalizer` be applied before simplification?
295    ///
296    /// If true (the default), the expression will be rewritten to canonical
297    /// form before simplification. This is useful to ensure that the simplifier
298    /// can apply all possible simplifications.
299    ///
300    /// Some expressions, such as those in some Joins, can not be canonicalized
301    /// without changing their meaning. In these cases, canonicalization should
302    /// be disabled.
303    ///
304    /// ```rust
305    /// use arrow::datatypes::{DataType, Field, Schema};
306    /// use datafusion_expr::{col, lit, Expr};
307    /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
308    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
309    /// use datafusion_expr::execution_props::ExecutionProps;
310    /// use datafusion_expr::simplify::SimplifyContext;
311    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
312    ///
313    /// let schema = Schema::new(vec![
314    ///   Field::new("a", DataType::Int64, false),
315    ///   Field::new("b", DataType::Int64, false),
316    ///   Field::new("c", DataType::Int64, false),
317    ///   ])
318    ///   .to_dfschema_ref().unwrap();
319    ///
320    /// // Create the simplifier
321    /// let props = ExecutionProps::new();
322    /// let context = SimplifyContext::new(&props)
323    ///    .with_schema(schema);
324    /// let simplifier = ExprSimplifier::new(context);
325    ///
326    /// // Expression: a = c AND 1 = b
327    /// let expr = col("a").eq(col("c")).and(lit(1).eq(col("b")));
328    ///
329    /// // With canonicalization, the expression is rewritten to canonical form
330    /// // (though it is no simpler in this case):
331    /// let canonical = simplifier.simplify(expr.clone()).unwrap();
332    /// // Expression has been rewritten to: (c = a AND b = 1)
333    /// assert_eq!(canonical, col("c").eq(col("a")).and(col("b").eq(lit(1))));
334    ///
335    /// // If canonicalization is disabled, the expression is not changed
336    /// let non_canonicalized = simplifier
337    ///   .with_canonicalize(false)
338    ///   .simplify(expr.clone())
339    ///   .unwrap();
340    ///
341    /// assert_eq!(non_canonicalized, expr);
342    /// ```
343    pub fn with_canonicalize(mut self, canonicalize: bool) -> Self {
344        self.canonicalize = canonicalize;
345        self
346    }
347
348    /// Specifies the maximum number of simplification cycles to run.
349    ///
350    /// The simplifier can perform multiple passes of simplification. This is
351    /// because the output of one simplification step can allow more optimizations
352    /// in another simplification step. For example, constant evaluation can allow more
353    /// expression simplifications, and expression simplifications can allow more constant
354    /// evaluations.
355    ///
356    /// This method specifies the maximum number of allowed iteration cycles before the simplifier
357    /// returns an [Expr] output. However, it does not always perform the maximum number of cycles.
358    /// The simplifier will attempt to detect when an [Expr] is unchanged by all the simplification
359    /// passes, and return early. This avoids wasting time on unnecessary [Expr] tree traversals.
360    ///
361    /// If no maximum is specified, the value of [DEFAULT_MAX_SIMPLIFIER_CYCLES] is used
362    /// instead.
363    ///
364    /// ```rust
365    /// use arrow::datatypes::{DataType, Field, Schema};
366    /// use datafusion_expr::{col, lit, Expr};
367    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
368    /// use datafusion_expr::execution_props::ExecutionProps;
369    /// use datafusion_expr::simplify::SimplifyContext;
370    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
371    ///
372    /// let schema = Schema::new(vec![
373    ///   Field::new("a", DataType::Int64, false),
374    ///   ])
375    ///   .to_dfschema_ref().unwrap();
376    ///
377    /// // Create the simplifier
378    /// let props = ExecutionProps::new();
379    /// let context = SimplifyContext::new(&props)
380    ///    .with_schema(schema);
381    /// let simplifier = ExprSimplifier::new(context);
382    ///
383    /// // Expression: a IS NOT NULL
384    /// let expr = col("a").is_not_null();
385    ///
386    /// // When using default maximum cycles, 2 cycles will be performed.
387    /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap();
388    /// assert_eq!(simplified_expr, lit(true));
389    /// // 2 cycles were executed, but only 1 was needed
390    /// assert_eq!(count, 2);
391    ///
392    /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
393    /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap();
394    /// // Expression has been rewritten to: (c = a AND b = 1)
395    /// assert_eq!(simplified_expr, lit(true));
396    /// // Only 1 cycle was executed
397    /// assert_eq!(count, 1);
398    ///
399    /// ```
400    pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
401        self.max_simplifier_cycles = max_simplifier_cycles;
402        self
403    }
404}
405
/// Rewriter that puts `BinaryExpr`s into canonical form.
///
/// * `<literal> <op> <col>` becomes `<col> <op> <literal>`
/// * `<col1> <op> <col2>` is reordered so that the column that sorts higher
///   ends up on the left (`a > b` would be canonicalized to `b < a`)
///
/// Only operators that have a swapped counterpart are rewritten.
struct Canonicalizer {}

impl Canonicalizer {
    /// Creates the (stateless) canonicalizer.
    fn new() -> Self {
        Canonicalizer {}
    }
}
419
420impl TreeNodeRewriter for Canonicalizer {
421    type Node = Expr;
422
423    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
424        let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
425            return Ok(Transformed::no(expr));
426        };
427        match (left.as_ref(), right.as_ref(), op.swap()) {
428            // <col1> <op> <col2>
429            (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op))
430                if right_col > left_col =>
431            {
432                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
433                    left: right,
434                    op: swapped_op,
435                    right: left,
436                })))
437            }
438            // <literal> <op> <col>
439            (Expr::Literal(_a), Expr::Column(_b), Some(swapped_op)) => {
440                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
441                    left: right,
442                    op: swapped_op,
443                    right: left,
444                })))
445            }
446            _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr {
447                left,
448                op,
449                right,
450            }))),
451        }
452    }
453}
454
#[allow(rustdoc::private_intra_doc_links)]
/// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time.
///
/// Note it does not handle algebraic rewrites such as `(a or false)`
/// --> `a`, which is handled by [`Simplifier`]
struct ConstEvaluator<'a> {
    /// `can_evaluate` is used during the depth-first-search of the
    /// `Expr` tree to track if any siblings (or their descendants) were
    /// non evaluatable (e.g. had a column reference or volatile
    /// function)
    ///
    /// Specifically, `can_evaluate[N]` represents the state of
    /// traversal when we are N levels deep in the tree, one entry for
    /// this Expr and each of its parents.
    ///
    /// After visiting all siblings if `can_evaluate.top()` is true, that
    /// means there were no non evaluatable siblings (or their
    /// descendants) so this `Expr` can be evaluated
    can_evaluate: Vec<bool>,

    /// Execution properties passed to `create_physical_expr` when a constant
    /// subtree is turned into a physical expression.
    execution_props: &'a ExecutionProps,
    /// Dummy schema (a single nullable `Null` column) used to plan the
    /// physical expression; built in `try_new`.
    input_schema: DFSchema,
    /// One-row batch with a single null column that the physical expression
    /// is evaluated against, so that evaluation yields one output row.
    input_batch: RecordBatch,
}
479
// NOTE(review): `dead_code` is presumably allowed because the error carried by
// `SimplifyRuntimeError` is not read where the result is consumed — confirm.
#[allow(dead_code)]
/// The simplify result of ConstEvaluator
enum ConstSimplifyResult {
    /// Expr was simplified and contains the new expression
    Simplified(ScalarValue),
    /// Expr was not simplified and original value is returned
    /// (the input was already a literal)
    NotSimplified(ScalarValue),
    /// Evaluation encountered an error, contains the error together with
    /// the original expression
    SimplifyRuntimeError(DataFusionError, Expr),
}
490
491impl TreeNodeRewriter for ConstEvaluator<'_> {
492    type Node = Expr;
493
494    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
495        // Default to being able to evaluate this node
496        self.can_evaluate.push(true);
497
498        // if this expr is not ok to evaluate, mark entire parent
499        // stack as not ok (as all parents have at least one child or
500        // descendant that can not be evaluated
501
502        if !Self::can_evaluate(&expr) {
503            // walk back up stack, marking first parent that is not mutable
504            let parent_iter = self.can_evaluate.iter_mut().rev();
505            for p in parent_iter {
506                if !*p {
507                    // optimization: if we find an element on the
508                    // stack already marked, know all elements above are also marked
509                    break;
510                }
511                *p = false;
512            }
513        }
514
515        // NB: do not short circuit recursion even if we find a non
516        // evaluatable node (so we can fold other children, args to
517        // functions, etc.)
518        Ok(Transformed::no(expr))
519    }
520
521    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
522        match self.can_evaluate.pop() {
523            // Certain expressions such as `CASE` and `COALESCE` are short-circuiting
524            // and may not evaluate all their sub expressions. Thus, if
525            // any error is countered during simplification, return the original
526            // so that normal evaluation can occur
527            Some(true) => match self.evaluate_to_scalar(expr) {
528                ConstSimplifyResult::Simplified(s) => {
529                    Ok(Transformed::yes(Expr::Literal(s)))
530                }
531                ConstSimplifyResult::NotSimplified(s) => {
532                    Ok(Transformed::no(Expr::Literal(s)))
533                }
534                ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
535                    Ok(Transformed::yes(expr))
536                }
537            },
538            Some(false) => Ok(Transformed::no(expr)),
539            _ => internal_err!("Failed to pop can_evaluate"),
540        }
541    }
542}
543
544impl<'a> ConstEvaluator<'a> {
545    /// Create a new `ConstantEvaluator`. Session constants (such as
546    /// the time for `now()` are taken from the passed
547    /// `execution_props`.
548    pub fn try_new(execution_props: &'a ExecutionProps) -> Result<Self> {
549        // The dummy column name is unused and doesn't matter as only
550        // expressions without column references can be evaluated
551        static DUMMY_COL_NAME: &str = ".";
552        let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]);
553        let input_schema = DFSchema::try_from(schema.clone())?;
554        // Need a single "input" row to produce a single output row
555        let col = new_null_array(&DataType::Null, 1);
556        let input_batch = RecordBatch::try_new(std::sync::Arc::new(schema), vec![col])?;
557
558        Ok(Self {
559            can_evaluate: vec![],
560            execution_props,
561            input_schema,
562            input_batch,
563        })
564    }
565
566    /// Can a function of the specified volatility be evaluated?
567    fn volatility_ok(volatility: Volatility) -> bool {
568        match volatility {
569            Volatility::Immutable => true,
570            // Values for functions such as now() are taken from ExecutionProps
571            Volatility::Stable => true,
572            Volatility::Volatile => false,
573        }
574    }
575
    /// Can the expression be evaluated at plan time, (assuming all of
    /// its children can also be evaluated)?
    ///
    /// Only the node itself is inspected; children are accounted for
    /// separately through the `can_evaluate` stack maintained in `f_down`.
    ///
    /// Returns `false` for nodes that reference data or context that is not
    /// available here (columns, subqueries, aggregates, window functions,
    /// placeholders, ...), defers to [`Self::volatility_ok`] for scalar
    /// functions, and returns `true` for every other expression type.
    fn can_evaluate(expr: &Expr) -> bool {
        // check for reasons we can't evaluate this node
        //
        // NOTE all expr types are listed here so when new ones are
        // added they can be checked for their ability to be evaluated
        // at plan time
        match expr {
            // TODO: remove the next line after `Expr::Wildcard` is removed
            #[expect(deprecated)]
            Expr::AggregateFunction { .. }
            | Expr::ScalarVariable(_, _)
            | Expr::Column(_)
            | Expr::OuterReferenceColumn(_, _)
            | Expr::Exists { .. }
            | Expr::InSubquery(_)
            | Expr::ScalarSubquery(_)
            | Expr::WindowFunction { .. }
            | Expr::GroupingSet(_)
            | Expr::Wildcard { .. }
            | Expr::Placeholder(_) => false,
            // Scalar functions can be folded unless they are volatile
            Expr::ScalarFunction(ScalarFunction { func, .. }) => {
                Self::volatility_ok(func.signature().volatility)
            }
            Expr::Literal(_)
            | Expr::Alias(..)
            | Expr::Unnest(_)
            | Expr::BinaryExpr { .. }
            | Expr::Not(_)
            | Expr::IsNotNull(_)
            | Expr::IsNull(_)
            | Expr::IsTrue(_)
            | Expr::IsFalse(_)
            | Expr::IsUnknown(_)
            | Expr::IsNotTrue(_)
            | Expr::IsNotFalse(_)
            | Expr::IsNotUnknown(_)
            | Expr::Negative(_)
            | Expr::Between { .. }
            | Expr::Like { .. }
            | Expr::SimilarTo { .. }
            | Expr::Case(_)
            | Expr::Cast { .. }
            | Expr::TryCast { .. }
            | Expr::InList { .. } => true,
        }
    }
624
625    /// Internal helper to evaluates an Expr
626    pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult {
627        if let Expr::Literal(s) = expr {
628            return ConstSimplifyResult::NotSimplified(s);
629        }
630
631        let phys_expr =
632            match create_physical_expr(&expr, &self.input_schema, self.execution_props) {
633                Ok(e) => e,
634                Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
635            };
636        let col_val = match phys_expr.evaluate(&self.input_batch) {
637            Ok(v) => v,
638            Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
639        };
640        match col_val {
641            ColumnarValue::Array(a) => {
642                if a.len() != 1 {
643                    ConstSimplifyResult::SimplifyRuntimeError(
644                        DataFusionError::Execution(format!("Could not evaluate the expression, found a result of length {}", a.len())),
645                        expr,
646                    )
647                } else if as_list_array(&a).is_ok() {
648                    ConstSimplifyResult::Simplified(ScalarValue::List(
649                        a.as_list::<i32>().to_owned().into(),
650                    ))
651                } else if as_large_list_array(&a).is_ok() {
652                    ConstSimplifyResult::Simplified(ScalarValue::LargeList(
653                        a.as_list::<i64>().to_owned().into(),
654                    ))
655                } else {
656                    // Non-ListArray
657                    match ScalarValue::try_from_array(&a, 0) {
658                        Ok(s) => {
659                            // TODO: support the optimization for `Map` type after support impl hash for it
660                            if matches!(&s, ScalarValue::Map(_)) {
661                                ConstSimplifyResult::SimplifyRuntimeError(
662                                    DataFusionError::NotImplemented("Const evaluate for Map type is still not supported".to_string()),
663                                    expr,
664                                )
665                            } else {
666                                ConstSimplifyResult::Simplified(s)
667                            }
668                        }
669                        Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr),
670                    }
671                }
672            }
673            ColumnarValue::Scalar(s) => {
674                // TODO: support the optimization for `Map` type after support impl hash for it
675                if matches!(&s, ScalarValue::Map(_)) {
676                    ConstSimplifyResult::SimplifyRuntimeError(
677                        DataFusionError::NotImplemented(
678                            "Const evaluate for Map type is still not supported"
679                                .to_string(),
680                        ),
681                        expr,
682                    )
683                } else {
684                    ConstSimplifyResult::Simplified(s)
685                }
686            }
687        }
688    }
689}
690
/// Rewriter that simplifies [`Expr`]s with algebraic transformation rules
///
/// Example transformations that are applied:
/// * `expr = true` and `expr != false` to `expr` when `expr` is of boolean type
/// * `expr = false` and `expr != true` to `!expr` when `expr` is of boolean type
/// * `true = true` and `false = false` to `true`
/// * `false = true` and `true = false` to `false`
/// * `!!expr` to `expr`
/// * `expr = null` and `expr != null` to `null`
struct Simplifier<'a, S> {
    /// Information about expressions that the rules consult while rewriting
    info: &'a S,
}

impl<'a, S> Simplifier<'a, S> {
    /// Creates a simplifier that borrows `info` for its lifetime.
    pub fn new(info: &'a S) -> Self {
        Simplifier { info }
    }
}
709
710impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
711    type Node = Expr;
712
713    /// rewrite the expression simplifying any constant expressions
714    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
715        use datafusion_expr::Operator::{
716            And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor,
717            Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch,
718            RegexNotIMatch, RegexNotMatch,
719        };
720
721        let info = self.info;
722        Ok(match expr {
723            //
724            // Rules for Eq
725            //
726
727            // true = A  --> A
728            // false = A --> !A
729            // null = A --> null
730            Expr::BinaryExpr(BinaryExpr {
731                left,
732                op: Eq,
733                right,
734            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
735                Transformed::yes(match as_bool_lit(&left)? {
736                    Some(true) => *right,
737                    Some(false) => Expr::Not(right),
738                    None => lit_bool_null(),
739                })
740            }
741            // A = true  --> A
742            // A = false --> !A
743            // A = null --> null
744            Expr::BinaryExpr(BinaryExpr {
745                left,
746                op: Eq,
747                right,
748            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
749                Transformed::yes(match as_bool_lit(&right)? {
750                    Some(true) => *left,
751                    Some(false) => Expr::Not(left),
752                    None => lit_bool_null(),
753                })
754            }
755            // Rules for NotEq
756            //
757
758            // true != A  --> !A
759            // false != A --> A
760            // null != A --> null
761            Expr::BinaryExpr(BinaryExpr {
762                left,
763                op: NotEq,
764                right,
765            }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
766                Transformed::yes(match as_bool_lit(&left)? {
767                    Some(true) => Expr::Not(right),
768                    Some(false) => *right,
769                    None => lit_bool_null(),
770                })
771            }
772            // A != true  --> !A
773            // A != false --> A
774            // A != null --> null
775            Expr::BinaryExpr(BinaryExpr {
776                left,
777                op: NotEq,
778                right,
779            }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
780                Transformed::yes(match as_bool_lit(&right)? {
781                    Some(true) => Expr::Not(left),
782                    Some(false) => *left,
783                    None => lit_bool_null(),
784                })
785            }
786
787            //
788            // Rules for OR
789            //
790
791            // true OR A --> true (even if A is null)
792            Expr::BinaryExpr(BinaryExpr {
793                left,
794                op: Or,
795                right: _,
796            }) if is_true(&left) => Transformed::yes(*left),
797            // false OR A --> A
798            Expr::BinaryExpr(BinaryExpr {
799                left,
800                op: Or,
801                right,
802            }) if is_false(&left) => Transformed::yes(*right),
803            // A OR true --> true (even if A is null)
804            Expr::BinaryExpr(BinaryExpr {
805                left: _,
806                op: Or,
807                right,
808            }) if is_true(&right) => Transformed::yes(*right),
809            // A OR false --> A
810            Expr::BinaryExpr(BinaryExpr {
811                left,
812                op: Or,
813                right,
814            }) if is_false(&right) => Transformed::yes(*left),
815            // A OR !A ---> true (if A not nullable)
816            Expr::BinaryExpr(BinaryExpr {
817                left,
818                op: Or,
819                right,
820            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
821                Transformed::yes(lit(true))
822            }
823            // !A OR A ---> true (if A not nullable)
824            Expr::BinaryExpr(BinaryExpr {
825                left,
826                op: Or,
827                right,
828            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
829                Transformed::yes(lit(true))
830            }
831            // (..A..) OR A --> (..A..)
832            Expr::BinaryExpr(BinaryExpr {
833                left,
834                op: Or,
835                right,
836            }) if expr_contains(&left, &right, Or) => Transformed::yes(*left),
837            // A OR (..A..) --> (..A..)
838            Expr::BinaryExpr(BinaryExpr {
839                left,
840                op: Or,
841                right,
842            }) if expr_contains(&right, &left, Or) => Transformed::yes(*right),
843            // A OR (A AND B) --> A
844            Expr::BinaryExpr(BinaryExpr {
845                left,
846                op: Or,
847                right,
848            }) if is_op_with(And, &right, &left) => Transformed::yes(*left),
849            // (A AND B) OR A --> A
850            Expr::BinaryExpr(BinaryExpr {
851                left,
852                op: Or,
853                right,
854            }) if is_op_with(And, &left, &right) => Transformed::yes(*right),
855            // Eliminate common factors in conjunctions e.g
856            // (A AND B) OR (A AND C) -> A AND (B OR C)
857            Expr::BinaryExpr(BinaryExpr {
858                left,
859                op: Or,
860                right,
861            }) if has_common_conjunction(&left, &right) => {
862                let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect();
863                let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right)
864                    .partition(|e| lhs.contains(e) && !e.is_volatile());
865
866                let new_rhs = rhs.into_iter().reduce(and);
867                let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and);
868                let common_conjunction = common.into_iter().reduce(and).unwrap();
869
870                let new_expr = match (new_lhs, new_rhs) {
871                    (Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)),
872                    (_, _) => common_conjunction,
873                };
874                Transformed::yes(new_expr)
875            }
876
877            //
878            // Rules for AND
879            //
880
881            // true AND A --> A
882            Expr::BinaryExpr(BinaryExpr {
883                left,
884                op: And,
885                right,
886            }) if is_true(&left) => Transformed::yes(*right),
887            // false AND A --> false (even if A is null)
888            Expr::BinaryExpr(BinaryExpr {
889                left,
890                op: And,
891                right: _,
892            }) if is_false(&left) => Transformed::yes(*left),
893            // A AND true --> A
894            Expr::BinaryExpr(BinaryExpr {
895                left,
896                op: And,
897                right,
898            }) if is_true(&right) => Transformed::yes(*left),
899            // A AND false --> false (even if A is null)
900            Expr::BinaryExpr(BinaryExpr {
901                left: _,
902                op: And,
903                right,
904            }) if is_false(&right) => Transformed::yes(*right),
905            // A AND !A ---> false (if A not nullable)
906            Expr::BinaryExpr(BinaryExpr {
907                left,
908                op: And,
909                right,
910            }) if is_not_of(&right, &left) && !info.nullable(&left)? => {
911                Transformed::yes(lit(false))
912            }
913            // !A AND A ---> false (if A not nullable)
914            Expr::BinaryExpr(BinaryExpr {
915                left,
916                op: And,
917                right,
918            }) if is_not_of(&left, &right) && !info.nullable(&right)? => {
919                Transformed::yes(lit(false))
920            }
921            // (..A..) AND A --> (..A..)
922            Expr::BinaryExpr(BinaryExpr {
923                left,
924                op: And,
925                right,
926            }) if expr_contains(&left, &right, And) => Transformed::yes(*left),
927            // A AND (..A..) --> (..A..)
928            Expr::BinaryExpr(BinaryExpr {
929                left,
930                op: And,
931                right,
932            }) if expr_contains(&right, &left, And) => Transformed::yes(*right),
933            // A AND (A OR B) --> A
934            Expr::BinaryExpr(BinaryExpr {
935                left,
936                op: And,
937                right,
938            }) if is_op_with(Or, &right, &left) => Transformed::yes(*left),
939            // (A OR B) AND A --> A
940            Expr::BinaryExpr(BinaryExpr {
941                left,
942                op: And,
943                right,
944            }) if is_op_with(Or, &left, &right) => Transformed::yes(*right),
945
946            //
947            // Rules for Multiply
948            //
949
950            // A * 1 --> A
951            Expr::BinaryExpr(BinaryExpr {
952                left,
953                op: Multiply,
954                right,
955            }) if is_one(&right) => Transformed::yes(*left),
956            // 1 * A --> A
957            Expr::BinaryExpr(BinaryExpr {
958                left,
959                op: Multiply,
960                right,
961            }) if is_one(&left) => Transformed::yes(*right),
962            // A * null --> null
963            Expr::BinaryExpr(BinaryExpr {
964                left: _,
965                op: Multiply,
966                right,
967            }) if is_null(&right) => Transformed::yes(*right),
968            // null * A --> null
969            Expr::BinaryExpr(BinaryExpr {
970                left,
971                op: Multiply,
972                right: _,
973            }) if is_null(&left) => Transformed::yes(*left),
974
975            // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN)
976            Expr::BinaryExpr(BinaryExpr {
977                left,
978                op: Multiply,
979                right,
980            }) if !info.nullable(&left)?
981                && !info.get_data_type(&left)?.is_floating()
982                && is_zero(&right) =>
983            {
984                Transformed::yes(*right)
985            }
986            // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN)
987            Expr::BinaryExpr(BinaryExpr {
988                left,
989                op: Multiply,
990                right,
991            }) if !info.nullable(&right)?
992                && !info.get_data_type(&right)?.is_floating()
993                && is_zero(&left) =>
994            {
995                Transformed::yes(*left)
996            }
997
998            //
999            // Rules for Divide
1000            //
1001
1002            // A / 1 --> A
1003            Expr::BinaryExpr(BinaryExpr {
1004                left,
1005                op: Divide,
1006                right,
1007            }) if is_one(&right) => Transformed::yes(*left),
1008            // null / A --> null
1009            Expr::BinaryExpr(BinaryExpr {
1010                left,
1011                op: Divide,
1012                right: _,
1013            }) if is_null(&left) => Transformed::yes(*left),
1014            // A / null --> null
1015            Expr::BinaryExpr(BinaryExpr {
1016                left: _,
1017                op: Divide,
1018                right,
1019            }) if is_null(&right) => Transformed::yes(*right),
1020
1021            //
1022            // Rules for Modulo
1023            //
1024
1025            // A % null --> null
1026            Expr::BinaryExpr(BinaryExpr {
1027                left: _,
1028                op: Modulo,
1029                right,
1030            }) if is_null(&right) => Transformed::yes(*right),
1031            // null % A --> null
1032            Expr::BinaryExpr(BinaryExpr {
1033                left,
1034                op: Modulo,
1035                right: _,
1036            }) if is_null(&left) => Transformed::yes(*left),
1037            // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN)
1038            Expr::BinaryExpr(BinaryExpr {
1039                left,
1040                op: Modulo,
1041                right,
1042            }) if !info.nullable(&left)?
1043                && !info.get_data_type(&left)?.is_floating()
1044                && is_one(&right) =>
1045            {
1046                Transformed::yes(Expr::Literal(ScalarValue::new_zero(
1047                    &info.get_data_type(&left)?,
1048                )?))
1049            }
1050
1051            //
1052            // Rules for BitwiseAnd
1053            //
1054
1055            // A & null -> null
1056            Expr::BinaryExpr(BinaryExpr {
1057                left: _,
1058                op: BitwiseAnd,
1059                right,
1060            }) if is_null(&right) => Transformed::yes(*right),
1061
1062            // null & A -> null
1063            Expr::BinaryExpr(BinaryExpr {
1064                left,
1065                op: BitwiseAnd,
1066                right: _,
1067            }) if is_null(&left) => Transformed::yes(*left),
1068
1069            // A & 0 -> 0 (if A not nullable)
1070            Expr::BinaryExpr(BinaryExpr {
1071                left,
1072                op: BitwiseAnd,
1073                right,
1074            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right),
1075
1076            // 0 & A -> 0 (if A not nullable)
1077            Expr::BinaryExpr(BinaryExpr {
1078                left,
1079                op: BitwiseAnd,
1080                right,
1081            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left),
1082
1083            // !A & A -> 0 (if A not nullable)
1084            Expr::BinaryExpr(BinaryExpr {
1085                left,
1086                op: BitwiseAnd,
1087                right,
1088            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1089                Transformed::yes(Expr::Literal(ScalarValue::new_zero(
1090                    &info.get_data_type(&left)?,
1091                )?))
1092            }
1093
1094            // A & !A -> 0 (if A not nullable)
1095            Expr::BinaryExpr(BinaryExpr {
1096                left,
1097                op: BitwiseAnd,
1098                right,
1099            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1100                Transformed::yes(Expr::Literal(ScalarValue::new_zero(
1101                    &info.get_data_type(&left)?,
1102                )?))
1103            }
1104
1105            // (..A..) & A --> (..A..)
1106            Expr::BinaryExpr(BinaryExpr {
1107                left,
1108                op: BitwiseAnd,
1109                right,
1110            }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left),
1111
1112            // A & (..A..) --> (..A..)
1113            Expr::BinaryExpr(BinaryExpr {
1114                left,
1115                op: BitwiseAnd,
1116                right,
1117            }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right),
1118
1119            // A & (A | B) --> A (if B not null)
1120            Expr::BinaryExpr(BinaryExpr {
1121                left,
1122                op: BitwiseAnd,
1123                right,
1124            }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => {
1125                Transformed::yes(*left)
1126            }
1127
1128            // (A | B) & A --> A (if B not null)
1129            Expr::BinaryExpr(BinaryExpr {
1130                left,
1131                op: BitwiseAnd,
1132                right,
1133            }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => {
1134                Transformed::yes(*right)
1135            }
1136
1137            //
1138            // Rules for BitwiseOr
1139            //
1140
1141            // A | null -> null
1142            Expr::BinaryExpr(BinaryExpr {
1143                left: _,
1144                op: BitwiseOr,
1145                right,
1146            }) if is_null(&right) => Transformed::yes(*right),
1147
1148            // null | A -> null
1149            Expr::BinaryExpr(BinaryExpr {
1150                left,
1151                op: BitwiseOr,
1152                right: _,
1153            }) if is_null(&left) => Transformed::yes(*left),
1154
1155            // A | 0 -> A (even if A is null)
1156            Expr::BinaryExpr(BinaryExpr {
1157                left,
1158                op: BitwiseOr,
1159                right,
1160            }) if is_zero(&right) => Transformed::yes(*left),
1161
1162            // 0 | A -> A (even if A is null)
1163            Expr::BinaryExpr(BinaryExpr {
1164                left,
1165                op: BitwiseOr,
1166                right,
1167            }) if is_zero(&left) => Transformed::yes(*right),
1168
1169            // !A | A -> -1 (if A not nullable)
1170            Expr::BinaryExpr(BinaryExpr {
1171                left,
1172                op: BitwiseOr,
1173                right,
1174            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1175                Transformed::yes(Expr::Literal(ScalarValue::new_negative_one(
1176                    &info.get_data_type(&left)?,
1177                )?))
1178            }
1179
1180            // A | !A -> -1 (if A not nullable)
1181            Expr::BinaryExpr(BinaryExpr {
1182                left,
1183                op: BitwiseOr,
1184                right,
1185            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1186                Transformed::yes(Expr::Literal(ScalarValue::new_negative_one(
1187                    &info.get_data_type(&left)?,
1188                )?))
1189            }
1190
1191            // (..A..) | A --> (..A..)
1192            Expr::BinaryExpr(BinaryExpr {
1193                left,
1194                op: BitwiseOr,
1195                right,
1196            }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left),
1197
1198            // A | (..A..) --> (..A..)
1199            Expr::BinaryExpr(BinaryExpr {
1200                left,
1201                op: BitwiseOr,
1202                right,
1203            }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right),
1204
1205            // A | (A & B) --> A (if B not null)
1206            Expr::BinaryExpr(BinaryExpr {
1207                left,
1208                op: BitwiseOr,
1209                right,
1210            }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => {
1211                Transformed::yes(*left)
1212            }
1213
1214            // (A & B) | A --> A (if B not null)
1215            Expr::BinaryExpr(BinaryExpr {
1216                left,
1217                op: BitwiseOr,
1218                right,
1219            }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => {
1220                Transformed::yes(*right)
1221            }
1222
1223            //
1224            // Rules for BitwiseXor
1225            //
1226
1227            // A ^ null -> null
1228            Expr::BinaryExpr(BinaryExpr {
1229                left: _,
1230                op: BitwiseXor,
1231                right,
1232            }) if is_null(&right) => Transformed::yes(*right),
1233
1234            // null ^ A -> null
1235            Expr::BinaryExpr(BinaryExpr {
1236                left,
1237                op: BitwiseXor,
1238                right: _,
1239            }) if is_null(&left) => Transformed::yes(*left),
1240
1241            // A ^ 0 -> A (if A not nullable)
1242            Expr::BinaryExpr(BinaryExpr {
1243                left,
1244                op: BitwiseXor,
1245                right,
1246            }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left),
1247
1248            // 0 ^ A -> A (if A not nullable)
1249            Expr::BinaryExpr(BinaryExpr {
1250                left,
1251                op: BitwiseXor,
1252                right,
1253            }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right),
1254
1255            // !A ^ A -> -1 (if A not nullable)
1256            Expr::BinaryExpr(BinaryExpr {
1257                left,
1258                op: BitwiseXor,
1259                right,
1260            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
1261                Transformed::yes(Expr::Literal(ScalarValue::new_negative_one(
1262                    &info.get_data_type(&left)?,
1263                )?))
1264            }
1265
1266            // A ^ !A -> -1 (if A not nullable)
1267            Expr::BinaryExpr(BinaryExpr {
1268                left,
1269                op: BitwiseXor,
1270                right,
1271            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
1272                Transformed::yes(Expr::Literal(ScalarValue::new_negative_one(
1273                    &info.get_data_type(&left)?,
1274                )?))
1275            }
1276
1277            // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A)
1278            Expr::BinaryExpr(BinaryExpr {
1279                left,
1280                op: BitwiseXor,
1281                right,
1282            }) if expr_contains(&left, &right, BitwiseXor) => {
1283                let expr = delete_xor_in_complex_expr(&left, &right, false);
1284                Transformed::yes(if expr == *right {
1285                    Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?)
1286                } else {
1287                    expr
1288                })
1289            }
1290
1291            // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A)
1292            Expr::BinaryExpr(BinaryExpr {
1293                left,
1294                op: BitwiseXor,
1295                right,
1296            }) if expr_contains(&right, &left, BitwiseXor) => {
1297                let expr = delete_xor_in_complex_expr(&right, &left, true);
1298                Transformed::yes(if expr == *left {
1299                    Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?)
1300                } else {
1301                    expr
1302                })
1303            }
1304
1305            //
1306            // Rules for BitwiseShiftRight
1307            //
1308
1309            // A >> null -> null
1310            Expr::BinaryExpr(BinaryExpr {
1311                left: _,
1312                op: BitwiseShiftRight,
1313                right,
1314            }) if is_null(&right) => Transformed::yes(*right),
1315
1316            // null >> A -> null
1317            Expr::BinaryExpr(BinaryExpr {
1318                left,
1319                op: BitwiseShiftRight,
1320                right: _,
1321            }) if is_null(&left) => Transformed::yes(*left),
1322
1323            // A >> 0 -> A (even if A is null)
1324            Expr::BinaryExpr(BinaryExpr {
1325                left,
1326                op: BitwiseShiftRight,
1327                right,
1328            }) if is_zero(&right) => Transformed::yes(*left),
1329
1330            //
1331            // Rules for BitwiseShiftLeft
1332            //
1333
1334            // A << null -> null
1335            Expr::BinaryExpr(BinaryExpr {
1336                left: _,
1337                op: BitwiseShiftLeft,
1338                right,
1339            }) if is_null(&right) => Transformed::yes(*right),
1340
1341            // null << A -> null
1342            Expr::BinaryExpr(BinaryExpr {
1343                left,
1344                op: BitwiseShiftLeft,
1345                right: _,
1346            }) if is_null(&left) => Transformed::yes(*left),
1347
1348            // A << 0 -> A (even if A is null)
1349            Expr::BinaryExpr(BinaryExpr {
1350                left,
1351                op: BitwiseShiftLeft,
1352                right,
1353            }) if is_zero(&right) => Transformed::yes(*left),
1354
1355            //
1356            // Rules for Not
1357            //
1358            Expr::Not(inner) => Transformed::yes(negate_clause(*inner)),
1359
1360            //
1361            // Rules for Negative
1362            //
1363            Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)),
1364
1365            //
1366            // Rules for Case
1367            //
1368
1369            // CASE
1370            //   WHEN X THEN A
1371            //   WHEN Y THEN B
1372            //   ...
1373            //   ELSE Q
1374            // END
1375            //
1376            // ---> (X AND A) OR (Y AND B AND NOT X) OR ... (NOT (X OR Y) AND Q)
1377            //
1378            // Note: the rationale for this rewrite is that the expr can then be further
1379            // simplified using the existing rules for AND/OR
1380            Expr::Case(Case {
1381                expr: None,
1382                when_then_expr,
1383                else_expr,
1384            }) if !when_then_expr.is_empty()
1385                && when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number
1386                && info.is_boolean_type(&when_then_expr[0].1)? =>
1387            {
1388                // The disjunction of all the when predicates encountered so far. Not nullable.
1389                let mut filter_expr = lit(false);
1390                // The disjunction of all the cases
1391                let mut out_expr = lit(false);
1392
1393                for (when, then) in when_then_expr {
1394                    let when = is_exactly_true(*when, info)?;
1395                    let case_expr =
1396                        when.clone().and(filter_expr.clone().not()).and(*then);
1397
1398                    out_expr = out_expr.or(case_expr);
1399                    filter_expr = filter_expr.or(when);
1400                }
1401
1402                let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
1403                let case_expr = filter_expr.not().and(else_expr);
1404                out_expr = out_expr.or(case_expr);
1405
1406                // Do a first pass at simplification
1407                out_expr.rewrite(self)?
1408            }
1409            Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
1410                match udf.simplify(args, info)? {
1411                    ExprSimplifyResult::Original(args) => {
1412                        Transformed::no(Expr::ScalarFunction(ScalarFunction {
1413                            func: udf,
1414                            args,
1415                        }))
1416                    }
1417                    ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
1418                }
1419            }
1420
1421            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
1422                ref func,
1423                ..
1424            }) => match (func.simplify(), expr) {
1425                (Some(simplify_function), Expr::AggregateFunction(af)) => {
1426                    Transformed::yes(simplify_function(af, info)?)
1427                }
1428                (_, expr) => Transformed::no(expr),
1429            },
1430
1431            Expr::WindowFunction(WindowFunction {
1432                fun: WindowFunctionDefinition::WindowUDF(ref udwf),
1433                ..
1434            }) => match (udwf.simplify(), expr) {
1435                (Some(simplify_function), Expr::WindowFunction(wf)) => {
1436                    Transformed::yes(simplify_function(wf, info)?)
1437                }
1438                (_, expr) => Transformed::no(expr),
1439            },
1440
1441            //
1442            // Rules for Between
1443            //
1444
1445            // a between 3 and 5  -->  a >= 3 AND a <= 5
1446            // a not between 3 and 5  -->  a < 3 OR a > 5
1447            Expr::Between(between) => Transformed::yes(if between.negated {
1448                let l = *between.expr.clone();
1449                let r = *between.expr;
1450                or(l.lt(*between.low), r.gt(*between.high))
1451            } else {
1452                and(
1453                    between.expr.clone().gt_eq(*between.low),
1454                    between.expr.lt_eq(*between.high),
1455                )
1456            }),
1457
1458            //
1459            // Rules for regexes
1460            //
1461            Expr::BinaryExpr(BinaryExpr {
1462                left,
1463                op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch),
1464                right,
1465            }) => Transformed::yes(simplify_regex_expr(left, op, right)?),
1466
1467            // Rules for Like
1468            Expr::Like(like) => {
1469                // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291
1470                let escape_char = like.escape_char.unwrap_or('\\');
1471                match as_string_scalar(&like.pattern) {
1472                    Some((data_type, pattern_str)) => {
1473                        match pattern_str {
1474                            None => return Ok(Transformed::yes(lit_bool_null())),
1475                            Some(pattern_str) if pattern_str == "%" => {
1476                                // exp LIKE '%' is
1477                                //   - when exp is not NULL, it's true
1478                                //   - when exp is NULL, it's NULL
1479                                // exp NOT LIKE '%' is
1480                                //   - when exp is not NULL, it's false
1481                                //   - when exp is NULL, it's NULL
1482                                let result_for_non_null = lit(!like.negated);
1483                                Transformed::yes(if !info.nullable(&like.expr)? {
1484                                    result_for_non_null
1485                                } else {
1486                                    Expr::Case(Case {
1487                                        expr: Some(Box::new(Expr::IsNotNull(like.expr))),
1488                                        when_then_expr: vec![(
1489                                            Box::new(lit(true)),
1490                                            Box::new(result_for_non_null),
1491                                        )],
1492                                        else_expr: None,
1493                                    })
1494                                })
1495                            }
1496                            Some(pattern_str)
1497                                if pattern_str.contains("%%")
1498                                    && !pattern_str.contains(escape_char) =>
1499                            {
1500                                // Repeated occurrences of wildcard are redundant so remove them
1501                                // exp LIKE '%%'  --> exp LIKE '%'
1502                                let simplified_pattern = Regex::new("%%+")
1503                                    .unwrap()
1504                                    .replace_all(pattern_str, "%")
1505                                    .to_string();
1506                                Transformed::yes(Expr::Like(Like {
1507                                    pattern: Box::new(to_string_scalar(
1508                                        data_type,
1509                                        Some(simplified_pattern),
1510                                    )),
1511                                    ..like
1512                                }))
1513                            }
1514                            Some(pattern_str)
1515                                if !pattern_str
1516                                    .contains(['%', '_', escape_char].as_ref()) =>
1517                            {
1518                                // If the pattern does not contain any wildcards, we can simplify the like expression to an equality expression
1519                                // TODO: handle escape characters
1520                                Transformed::yes(Expr::BinaryExpr(BinaryExpr {
1521                                    left: like.expr.clone(),
1522                                    op: if like.negated { NotEq } else { Eq },
1523                                    right: like.pattern.clone(),
1524                                }))
1525                            }
1526
1527                            Some(_pattern_str) => Transformed::no(Expr::Like(like)),
1528                        }
1529                    }
1530                    None => Transformed::no(Expr::Like(like)),
1531                }
1532            }
1533
1534            // a is not null/unknown --> true (if a is not nullable)
1535            Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr)
1536                if !info.nullable(&expr)? =>
1537            {
1538                Transformed::yes(lit(true))
1539            }
1540
1541            // a is null/unknown --> false (if a is not nullable)
1542            Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => {
1543                Transformed::yes(lit(false))
1544            }
1545
1546            // expr IN () --> false
1547            // expr NOT IN () --> true
1548            Expr::InList(InList {
1549                expr,
1550                list,
1551                negated,
1552            }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => {
1553                Transformed::yes(lit(negated))
1554            }
1555
1556            // null in (x, y, z) --> null
1557            // null not in (x, y, z) --> null
1558            Expr::InList(InList {
1559                expr,
1560                list: _,
1561                negated: _,
1562            }) if is_null(expr.as_ref()) => Transformed::yes(lit_bool_null()),
1563
            // expr IN ((subquery)) -> expr IN (subquery), see #5529
1565            Expr::InList(InList {
1566                expr,
1567                mut list,
1568                negated,
1569            }) if list.len() == 1
1570                && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) =>
1571            {
1572                let Expr::ScalarSubquery(subquery) = list.remove(0) else {
1573                    unreachable!()
1574                };
1575
1576                Transformed::yes(Expr::InSubquery(InSubquery::new(
1577                    expr, subquery, negated,
1578                )))
1579            }
1580
1581            // Combine multiple OR expressions into a single IN list expression if possible
1582            //
1583            // i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)`
1584            Expr::BinaryExpr(BinaryExpr {
1585                left,
1586                op: Or,
1587                right,
1588            }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => {
1589                let lhs = to_inlist(*left).unwrap();
1590                let rhs = to_inlist(*right).unwrap();
1591                let mut seen: HashSet<Expr> = HashSet::new();
1592                let list = lhs
1593                    .list
1594                    .into_iter()
1595                    .chain(rhs.list)
1596                    .filter(|e| seen.insert(e.to_owned()))
1597                    .collect::<Vec<_>>();
1598
1599                let merged_inlist = InList {
1600                    expr: lhs.expr,
1601                    list,
1602                    negated: false,
1603                };
1604
1605                Transformed::yes(Expr::InList(merged_inlist))
1606            }
1607
            // Simplify expressions that are guaranteed to be true or false to a literal boolean expression
1609            //
1610            // Rules:
1611            // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists
1612            //   Intersection:
1613            //     1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false`
1614            //     2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)`
1615            //     3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)`
1616            //   Union:
            //     4. `a not in (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)`
1618            //     # This rule is handled by `or_in_list_simplifier.rs`
1619            //     5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)`
1620            // If one of the expressions is `IN` and another one is `NOT IN`, then we apply exception on `In` expression
1621            //     6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false`
1622            //     7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5`
1623            //     8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)`
1624            Expr::BinaryExpr(BinaryExpr {
1625                left,
1626                op: And,
1627                right,
1628            }) if are_inlist_and_eq_and_match_neg(
1629                left.as_ref(),
1630                right.as_ref(),
1631                false,
1632                false,
1633            ) =>
1634            {
1635                match (*left, *right) {
1636                    (Expr::InList(l1), Expr::InList(l2)) => {
1637                        return inlist_intersection(l1, &l2, false).map(Transformed::yes);
1638                    }
1639                    // Matched previously once
1640                    _ => unreachable!(),
1641                }
1642            }
1643
1644            Expr::BinaryExpr(BinaryExpr {
1645                left,
1646                op: And,
1647                right,
1648            }) if are_inlist_and_eq_and_match_neg(
1649                left.as_ref(),
1650                right.as_ref(),
1651                true,
1652                true,
1653            ) =>
1654            {
1655                match (*left, *right) {
1656                    (Expr::InList(l1), Expr::InList(l2)) => {
1657                        return inlist_union(l1, l2, true).map(Transformed::yes);
1658                    }
1659                    // Matched previously once
1660                    _ => unreachable!(),
1661                }
1662            }
1663
1664            Expr::BinaryExpr(BinaryExpr {
1665                left,
1666                op: And,
1667                right,
1668            }) if are_inlist_and_eq_and_match_neg(
1669                left.as_ref(),
1670                right.as_ref(),
1671                false,
1672                true,
1673            ) =>
1674            {
1675                match (*left, *right) {
1676                    (Expr::InList(l1), Expr::InList(l2)) => {
1677                        return inlist_except(l1, &l2).map(Transformed::yes);
1678                    }
1679                    // Matched previously once
1680                    _ => unreachable!(),
1681                }
1682            }
1683
1684            Expr::BinaryExpr(BinaryExpr {
1685                left,
1686                op: And,
1687                right,
1688            }) if are_inlist_and_eq_and_match_neg(
1689                left.as_ref(),
1690                right.as_ref(),
1691                true,
1692                false,
1693            ) =>
1694            {
1695                match (*left, *right) {
1696                    (Expr::InList(l1), Expr::InList(l2)) => {
1697                        return inlist_except(l2, &l1).map(Transformed::yes);
1698                    }
1699                    // Matched previously once
1700                    _ => unreachable!(),
1701                }
1702            }
1703
1704            Expr::BinaryExpr(BinaryExpr {
1705                left,
1706                op: Or,
1707                right,
1708            }) if are_inlist_and_eq_and_match_neg(
1709                left.as_ref(),
1710                right.as_ref(),
1711                true,
1712                true,
1713            ) =>
1714            {
1715                match (*left, *right) {
1716                    (Expr::InList(l1), Expr::InList(l2)) => {
1717                        return inlist_intersection(l1, &l2, true).map(Transformed::yes);
1718                    }
1719                    // Matched previously once
1720                    _ => unreachable!(),
1721                }
1722            }
1723
1724            // no additional rewrites possible
1725            expr => Transformed::no(expr),
1726        })
1727    }
1728}
1729
1730fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option<String>)> {
1731    match expr {
1732        Expr::Literal(ScalarValue::Utf8(s)) => Some((DataType::Utf8, s)),
1733        Expr::Literal(ScalarValue::LargeUtf8(s)) => Some((DataType::LargeUtf8, s)),
1734        Expr::Literal(ScalarValue::Utf8View(s)) => Some((DataType::Utf8View, s)),
1735        _ => None,
1736    }
1737}
1738
1739fn to_string_scalar(data_type: DataType, value: Option<String>) -> Expr {
1740    match data_type {
1741        DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value)),
1742        DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value)),
1743        DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value)),
1744        _ => unreachable!(),
1745    }
1746}
1747
1748fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
1749    let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect();
1750    iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
1751}
1752
1753// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
1754fn are_inlist_and_eq_and_match_neg(
1755    left: &Expr,
1756    right: &Expr,
1757    is_left_neg: bool,
1758    is_right_neg: bool,
1759) -> bool {
1760    match (left, right) {
1761        (Expr::InList(l), Expr::InList(r)) => {
1762            l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg
1763        }
1764        _ => false,
1765    }
1766}
1767
1768// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
1769fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool {
1770    let left = as_inlist(left);
1771    let right = as_inlist(right);
1772    if let (Some(lhs), Some(rhs)) = (left, right) {
1773        matches!(lhs.expr.as_ref(), Expr::Column(_))
1774            && matches!(rhs.expr.as_ref(), Expr::Column(_))
1775            && lhs.expr == rhs.expr
1776            && !lhs.negated
1777            && !rhs.negated
1778    } else {
1779        false
1780    }
1781}
1782
1783/// Try to convert an expression to an in-list expression
1784fn as_inlist(expr: &Expr) -> Option<Cow<InList>> {
1785    match expr {
1786        Expr::InList(inlist) => Some(Cow::Borrowed(inlist)),
1787        Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => {
1788            match (left.as_ref(), right.as_ref()) {
1789                (Expr::Column(_), Expr::Literal(_)) => Some(Cow::Owned(InList {
1790                    expr: left.clone(),
1791                    list: vec![*right.clone()],
1792                    negated: false,
1793                })),
1794                (Expr::Literal(_), Expr::Column(_)) => Some(Cow::Owned(InList {
1795                    expr: right.clone(),
1796                    list: vec![*left.clone()],
1797                    negated: false,
1798                })),
1799                _ => None,
1800            }
1801        }
1802        _ => None,
1803    }
1804}
1805
1806fn to_inlist(expr: Expr) -> Option<InList> {
1807    match expr {
1808        Expr::InList(inlist) => Some(inlist),
1809        Expr::BinaryExpr(BinaryExpr {
1810            left,
1811            op: Operator::Eq,
1812            right,
1813        }) => match (left.as_ref(), right.as_ref()) {
1814            (Expr::Column(_), Expr::Literal(_)) => Some(InList {
1815                expr: left,
1816                list: vec![*right],
1817                negated: false,
1818            }),
1819            (Expr::Literal(_), Expr::Column(_)) => Some(InList {
1820                expr: right,
1821                list: vec![*left],
1822                negated: false,
1823            }),
1824            _ => None,
1825        },
1826        _ => None,
1827    }
1828}
1829
1830/// Return the union of two inlist expressions
1831/// maintaining the order of the elements in the two lists
1832fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
1833    // extend the list in l1 with the elements in l2 that are not already in l1
1834    let l1_items: HashSet<_> = l1.list.iter().collect();
1835
1836    // keep all l2 items that do not also appear in l1
1837    let keep_l2: Vec<_> = l2
1838        .list
1839        .into_iter()
1840        .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) })
1841        .collect();
1842
1843    l1.list.extend(keep_l2);
1844    l1.negated = negated;
1845    Ok(Expr::InList(l1))
1846}
1847
1848/// Return the intersection of two inlist expressions
1849/// maintaining the order of the elements in the two lists
1850fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result<Expr> {
1851    let l2_items = l2.list.iter().collect::<HashSet<_>>();
1852
1853    // remove all items from l1 that are not in l2
1854    l1.list.retain(|e| l2_items.contains(e));
1855
1856    // e in () is always false
1857    // e not in () is always true
1858    if l1.list.is_empty() {
1859        return Ok(lit(negated));
1860    }
1861    Ok(Expr::InList(l1))
1862}
1863
1864/// Return the all items in l1 that are not in l2
1865/// maintaining the order of the elements in the two lists
1866fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
1867    let l2_items = l2.list.iter().collect::<HashSet<_>>();
1868
1869    // keep only items from l1 that are not in l2
1870    l1.list.retain(|e| !l2_items.contains(e));
1871
1872    if l1.list.is_empty() {
1873        return Ok(lit(false));
1874    }
1875    Ok(Expr::InList(l1))
1876}
1877
1878/// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL).
1879fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
1880    if !info.nullable(&expr)? {
1881        Ok(expr)
1882    } else {
1883        Ok(Expr::BinaryExpr(BinaryExpr {
1884            left: Box::new(expr),
1885            op: Operator::IsNotDistinctFrom,
1886            right: Box::new(lit(true)),
1887        }))
1888    }
1889}
1890
1891#[cfg(test)]
1892mod tests {
1893    use crate::simplify_expressions::SimplifyContext;
1894    use crate::test::test_table_scan_with_name;
1895    use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
1896    use datafusion_expr::{
1897        function::{
1898            AccumulatorArgs, AggregateFunctionSimplification,
1899            WindowFunctionSimplification,
1900        },
1901        interval_arithmetic::Interval,
1902        *,
1903    };
1904    use datafusion_functions_window_common::field::WindowUDFFieldArgs;
1905    use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
1906    use std::{
1907        collections::HashMap,
1908        ops::{BitAnd, BitOr, BitXor},
1909        sync::Arc,
1910    };
1911
1912    use super::*;
1913
1914    // ------------------------------
1915    // --- ExprSimplifier tests -----
1916    // ------------------------------
1917    #[test]
1918    fn api_basic() {
1919        let props = ExecutionProps::new();
1920        let simplifier =
1921            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
1922
1923        let expr = lit(1) + lit(2);
1924        let expected = lit(3);
1925        assert_eq!(expected, simplifier.simplify(expr).unwrap());
1926    }
1927
1928    #[test]
1929    fn basic_coercion() {
1930        let schema = test_schema();
1931        let props = ExecutionProps::new();
1932        let simplifier = ExprSimplifier::new(
1933            SimplifyContext::new(&props).with_schema(Arc::clone(&schema)),
1934        );
1935
1936        // Note expr type is int32 (not int64)
1937        // (1i64 + 2i32) < i
1938        let expr = (lit(1i64) + lit(2i32)).lt(col("i"));
1939        // should fully simplify to 3 < i (though i has been coerced to i64)
1940        let expected = lit(3i64).lt(col("i"));
1941
1942        let expr = simplifier.coerce(expr, &schema).unwrap();
1943
1944        assert_eq!(expected, simplifier.simplify(expr).unwrap());
1945    }
1946
1947    fn test_schema() -> DFSchemaRef {
1948        Schema::new(vec![
1949            Field::new("i", DataType::Int64, false),
1950            Field::new("b", DataType::Boolean, true),
1951        ])
1952        .to_dfschema_ref()
1953        .unwrap()
1954    }
1955
1956    #[test]
1957    fn simplify_and_constant_prop() {
1958        let props = ExecutionProps::new();
1959        let simplifier =
1960            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
1961
1962        // should be able to simplify to false
1963        // (i * (1 - 2)) > 0
1964        let expr = (col("i") * (lit(1) - lit(1))).gt(lit(0));
1965        let expected = lit(false);
1966        assert_eq!(expected, simplifier.simplify(expr).unwrap());
1967    }
1968
1969    #[test]
1970    fn simplify_and_constant_prop_with_case() {
1971        let props = ExecutionProps::new();
1972        let simplifier =
1973            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema()));
1974
1975        //   CASE
1976        //     WHEN i>5 AND false THEN i > 5
1977        //     WHEN i<5 AND true THEN i < 5
1978        //     ELSE false
1979        //   END
1980        //
1981        // Can be simplified to `i < 5`
1982        let expr = when(col("i").gt(lit(5)).and(lit(false)), col("i").gt(lit(5)))
1983            .when(col("i").lt(lit(5)).and(lit(true)), col("i").lt(lit(5)))
1984            .otherwise(lit(false))
1985            .unwrap();
1986        let expected = col("i").lt(lit(5));
1987        assert_eq!(expected, simplifier.simplify(expr).unwrap());
1988    }
1989
1990    // ------------------------------
1991    // --- Simplifier tests -----
1992    // ------------------------------
1993
1994    #[test]
1995    fn test_simplify_canonicalize() {
1996        {
1997            let expr = lit(1).lt(col("c2")).and(col("c2").gt(lit(1)));
1998            let expected = col("c2").gt(lit(1));
1999            assert_eq!(simplify(expr), expected);
2000        }
2001        {
2002            let expr = col("c1").lt(col("c2")).and(col("c2").gt(col("c1")));
2003            let expected = col("c2").gt(col("c1"));
2004            assert_eq!(simplify(expr), expected);
2005        }
2006        {
2007            let expr = col("c1")
2008                .eq(lit(1))
2009                .and(lit(1).eq(col("c1")))
2010                .and(col("c1").eq(lit(3)));
2011            let expected = col("c1").eq(lit(1)).and(col("c1").eq(lit(3)));
2012            assert_eq!(simplify(expr), expected);
2013        }
2014        {
2015            let expr = col("c1")
2016                .eq(col("c2"))
2017                .and(col("c1").gt(lit(5)))
2018                .and(col("c2").eq(col("c1")));
2019            let expected = col("c2").eq(col("c1")).and(col("c1").gt(lit(5)));
2020            assert_eq!(simplify(expr), expected);
2021        }
2022        {
2023            let expr = col("c1")
2024                .eq(lit(1))
2025                .and(col("c2").gt(lit(3)).or(lit(3).lt(col("c2"))));
2026            let expected = col("c1").eq(lit(1)).and(col("c2").gt(lit(3)));
2027            assert_eq!(simplify(expr), expected);
2028        }
2029        {
2030            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2031            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2032            assert_eq!(simplify(expr), expected);
2033        }
2034        {
2035            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2036            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
2037            assert_eq!(simplify(expr), expected);
2038        }
2039        {
2040            let expr = col("c1").gt(col("c2")).and(col("c1").gt(col("c2")));
2041            let expected = col("c2").lt(col("c1"));
2042            assert_eq!(simplify(expr), expected);
2043        }
2044    }
2045
2046    #[test]
2047    fn test_simplify_or_true() {
2048        let expr_a = col("c2").or(lit(true));
2049        let expr_b = lit(true).or(col("c2"));
2050        let expected = lit(true);
2051
2052        assert_eq!(simplify(expr_a), expected);
2053        assert_eq!(simplify(expr_b), expected);
2054    }
2055
2056    #[test]
2057    fn test_simplify_or_false() {
2058        let expr_a = lit(false).or(col("c2"));
2059        let expr_b = col("c2").or(lit(false));
2060        let expected = col("c2");
2061
2062        assert_eq!(simplify(expr_a), expected);
2063        assert_eq!(simplify(expr_b), expected);
2064    }
2065
2066    #[test]
2067    fn test_simplify_or_same() {
2068        let expr = col("c2").or(col("c2"));
2069        let expected = col("c2");
2070
2071        assert_eq!(simplify(expr), expected);
2072    }
2073
2074    #[test]
2075    fn test_simplify_or_not_self() {
2076        // A OR !A if A is not nullable --> true
2077        // !A OR A if A is not nullable --> true
2078        let expr_a = col("c2_non_null").or(col("c2_non_null").not());
2079        let expr_b = col("c2_non_null").not().or(col("c2_non_null"));
2080        let expected = lit(true);
2081
2082        assert_eq!(simplify(expr_a), expected);
2083        assert_eq!(simplify(expr_b), expected);
2084    }
2085
2086    #[test]
2087    fn test_simplify_and_false() {
2088        let expr_a = lit(false).and(col("c2"));
2089        let expr_b = col("c2").and(lit(false));
2090        let expected = lit(false);
2091
2092        assert_eq!(simplify(expr_a), expected);
2093        assert_eq!(simplify(expr_b), expected);
2094    }
2095
2096    #[test]
2097    fn test_simplify_and_same() {
2098        let expr = col("c2").and(col("c2"));
2099        let expected = col("c2");
2100
2101        assert_eq!(simplify(expr), expected);
2102    }
2103
2104    #[test]
2105    fn test_simplify_and_true() {
2106        let expr_a = lit(true).and(col("c2"));
2107        let expr_b = col("c2").and(lit(true));
2108        let expected = col("c2");
2109
2110        assert_eq!(simplify(expr_a), expected);
2111        assert_eq!(simplify(expr_b), expected);
2112    }
2113
2114    #[test]
2115    fn test_simplify_and_not_self() {
2116        // A AND !A if A is not nullable --> false
2117        // !A AND A if A is not nullable --> false
2118        let expr_a = col("c2_non_null").and(col("c2_non_null").not());
2119        let expr_b = col("c2_non_null").not().and(col("c2_non_null"));
2120        let expected = lit(false);
2121
2122        assert_eq!(simplify(expr_a), expected);
2123        assert_eq!(simplify(expr_b), expected);
2124    }
2125
2126    #[test]
2127    fn test_simplify_multiply_by_one() {
2128        let expr_a = col("c2") * lit(1);
2129        let expr_b = lit(1) * col("c2");
2130        let expected = col("c2");
2131
2132        assert_eq!(simplify(expr_a), expected);
2133        assert_eq!(simplify(expr_b), expected);
2134
2135        let expr = col("c2") * lit(ScalarValue::Decimal128(Some(10000000000), 38, 10));
2136        assert_eq!(simplify(expr), expected);
2137
2138        let expr = lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)) * col("c2");
2139        assert_eq!(simplify(expr), expected);
2140    }
2141
2142    #[test]
2143    fn test_simplify_multiply_by_null() {
2144        let null = Expr::Literal(ScalarValue::Null);
2145        // A * null --> null
2146        {
2147            let expr = col("c2") * null.clone();
2148            assert_eq!(simplify(expr), null);
2149        }
2150        // null * A --> null
2151        {
2152            let expr = null.clone() * col("c2");
2153            assert_eq!(simplify(expr), null);
2154        }
2155    }
2156
2157    #[test]
2158    fn test_simplify_multiply_by_zero() {
2159        // cannot optimize A * null (null * A) if A is nullable
2160        {
2161            let expr_a = col("c2") * lit(0);
2162            let expr_b = lit(0) * col("c2");
2163
2164            assert_eq!(simplify(expr_a.clone()), expr_a);
2165            assert_eq!(simplify(expr_b.clone()), expr_b);
2166        }
2167        // 0 * A --> 0 if A is not nullable
2168        {
2169            let expr = lit(0) * col("c2_non_null");
2170            assert_eq!(simplify(expr), lit(0));
2171        }
2172        // A * 0 --> 0 if A is not nullable
2173        {
2174            let expr = col("c2_non_null") * lit(0);
2175            assert_eq!(simplify(expr), lit(0));
2176        }
2177        // A * Decimal128(0) --> 0 if A is not nullable
2178        {
2179            let expr = col("c2_non_null") * lit(ScalarValue::Decimal128(Some(0), 31, 10));
2180            assert_eq!(
2181                simplify(expr),
2182                lit(ScalarValue::Decimal128(Some(0), 31, 10))
2183            );
2184            let expr = binary_expr(
2185                lit(ScalarValue::Decimal128(Some(0), 31, 10)),
2186                Operator::Multiply,
2187                col("c2_non_null"),
2188            );
2189            assert_eq!(
2190                simplify(expr),
2191                lit(ScalarValue::Decimal128(Some(0), 31, 10))
2192            );
2193        }
2194    }
2195
2196    #[test]
2197    fn test_simplify_divide_by_one() {
2198        let expr = binary_expr(col("c2"), Operator::Divide, lit(1));
2199        let expected = col("c2");
2200        assert_eq!(simplify(expr), expected);
2201        let expr = col("c2") / lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2202        assert_eq!(simplify(expr), expected);
2203    }
2204
2205    #[test]
2206    fn test_simplify_divide_null() {
2207        // A / null --> null
2208        let null = lit(ScalarValue::Null);
2209        {
2210            let expr = col("c") / null.clone();
2211            assert_eq!(simplify(expr), null);
2212        }
2213        // null / A --> null
2214        {
2215            let expr = null.clone() / col("c");
2216            assert_eq!(simplify(expr), null);
2217        }
2218    }
2219
2220    #[test]
2221    fn test_simplify_divide_by_same() {
2222        let expr = col("c2") / col("c2");
2223        // if c2 is null, c2 / c2 = null, so can't simplify
2224        let expected = expr.clone();
2225
2226        assert_eq!(simplify(expr), expected);
2227    }
2228
2229    #[test]
2230    fn test_simplify_modulo_by_null() {
2231        let null = lit(ScalarValue::Null);
2232        // A % null --> null
2233        {
2234            let expr = col("c2") % null.clone();
2235            assert_eq!(simplify(expr), null);
2236        }
2237        // null % A --> null
2238        {
2239            let expr = null.clone() % col("c2");
2240            assert_eq!(simplify(expr), null);
2241        }
2242    }
2243
2244    #[test]
2245    fn test_simplify_modulo_by_one() {
2246        let expr = col("c2") % lit(1);
2247        // if c2 is null, c2 % 1 = null, so can't simplify
2248        let expected = expr.clone();
2249
2250        assert_eq!(simplify(expr), expected);
2251    }
2252
2253    #[test]
2254    fn test_simplify_divide_zero_by_zero() {
2255        // because divide by 0 maybe occur in short-circuit expression
2256        // so we should not simplify this, and throw error in runtime
2257        let expr = lit(0) / lit(0);
2258        let expected = expr.clone();
2259
2260        assert_eq!(simplify(expr), expected);
2261    }
2262
2263    #[test]
2264    fn test_simplify_divide_by_zero() {
2265        // because divide by 0 maybe occur in short-circuit expression
2266        // so we should not simplify this, and throw error in runtime
2267        let expr = col("c2_non_null") / lit(0);
2268        let expected = expr.clone();
2269
2270        assert_eq!(simplify(expr), expected);
2271    }
2272
2273    #[test]
2274    fn test_simplify_modulo_by_one_non_null() {
2275        let expr = col("c3_non_null") % lit(1);
2276        let expected = lit(0_i64);
2277        assert_eq!(simplify(expr), expected);
2278        let expr =
2279            col("c3_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10));
2280        assert_eq!(simplify(expr), expected);
2281    }
2282
2283    #[test]
2284    fn test_simplify_bitwise_xor_by_null() {
2285        let null = lit(ScalarValue::Null);
2286        // A ^ null --> null
2287        {
2288            let expr = col("c2") ^ null.clone();
2289            assert_eq!(simplify(expr), null);
2290        }
2291        // null ^ A --> null
2292        {
2293            let expr = null.clone() ^ col("c2");
2294            assert_eq!(simplify(expr), null);
2295        }
2296    }
2297
2298    #[test]
2299    fn test_simplify_bitwise_shift_right_by_null() {
2300        let null = lit(ScalarValue::Null);
2301        // A >> null --> null
2302        {
2303            let expr = col("c2") >> null.clone();
2304            assert_eq!(simplify(expr), null);
2305        }
2306        // null >> A --> null
2307        {
2308            let expr = null.clone() >> col("c2");
2309            assert_eq!(simplify(expr), null);
2310        }
2311    }
2312
2313    #[test]
2314    fn test_simplify_bitwise_shift_left_by_null() {
2315        let null = lit(ScalarValue::Null);
2316        // A << null --> null
2317        {
2318            let expr = col("c2") << null.clone();
2319            assert_eq!(simplify(expr), null);
2320        }
2321        // null << A --> null
2322        {
2323            let expr = null.clone() << col("c2");
2324            assert_eq!(simplify(expr), null);
2325        }
2326    }
2327
2328    #[test]
2329    fn test_simplify_bitwise_and_by_zero() {
2330        // A & 0 --> 0
2331        {
2332            let expr = col("c2_non_null") & lit(0);
2333            assert_eq!(simplify(expr), lit(0));
2334        }
2335        // 0 & A --> 0
2336        {
2337            let expr = lit(0) & col("c2_non_null");
2338            assert_eq!(simplify(expr), lit(0));
2339        }
2340    }
2341
2342    #[test]
2343    fn test_simplify_bitwise_or_by_zero() {
2344        // A | 0 --> A
2345        {
2346            let expr = col("c2_non_null") | lit(0);
2347            assert_eq!(simplify(expr), col("c2_non_null"));
2348        }
2349        // 0 | A --> A
2350        {
2351            let expr = lit(0) | col("c2_non_null");
2352            assert_eq!(simplify(expr), col("c2_non_null"));
2353        }
2354    }
2355
2356    #[test]
2357    fn test_simplify_bitwise_xor_by_zero() {
2358        // A ^ 0 --> A
2359        {
2360            let expr = col("c2_non_null") ^ lit(0);
2361            assert_eq!(simplify(expr), col("c2_non_null"));
2362        }
2363        // 0 ^ A --> A
2364        {
2365            let expr = lit(0) ^ col("c2_non_null");
2366            assert_eq!(simplify(expr), col("c2_non_null"));
2367        }
2368    }
2369
2370    #[test]
2371    fn test_simplify_bitwise_bitwise_shift_right_by_zero() {
2372        // A >> 0 --> A
2373        {
2374            let expr = col("c2_non_null") >> lit(0);
2375            assert_eq!(simplify(expr), col("c2_non_null"));
2376        }
2377    }
2378
2379    #[test]
2380    fn test_simplify_bitwise_bitwise_shift_left_by_zero() {
2381        // A << 0 --> A
2382        {
2383            let expr = col("c2_non_null") << lit(0);
2384            assert_eq!(simplify(expr), col("c2_non_null"));
2385        }
2386    }
2387
2388    #[test]
2389    fn test_simplify_bitwise_and_by_null() {
2390        let null = lit(ScalarValue::Null);
2391        // A & null --> null
2392        {
2393            let expr = col("c2") & null.clone();
2394            assert_eq!(simplify(expr), null);
2395        }
2396        // null & A --> null
2397        {
2398            let expr = null.clone() & col("c2");
2399            assert_eq!(simplify(expr), null);
2400        }
2401    }
2402
2403    #[test]
2404    fn test_simplify_composed_bitwise_and() {
2405        // ((c2 > 5) & (c1 < 6)) & (c2 > 5) --> (c2 > 5) & (c1 < 6)
2406
2407        let expr = bitwise_and(
2408            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2409            col("c2").gt(lit(5)),
2410        );
2411        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2412
2413        assert_eq!(simplify(expr), expected);
2414
2415        // (c2 > 5) & ((c2 > 5) & (c1 < 6)) --> (c2 > 5) & (c1 < 6)
2416
2417        let expr = bitwise_and(
2418            col("c2").gt(lit(5)),
2419            bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2420        );
2421        let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2422        assert_eq!(simplify(expr), expected);
2423    }
2424
2425    #[test]
2426    fn test_simplify_composed_bitwise_or() {
2427        // ((c2 > 5) | (c1 < 6)) | (c2 > 5) --> (c2 > 5) | (c1 < 6)
2428
2429        let expr = bitwise_or(
2430            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2431            col("c2").gt(lit(5)),
2432        );
2433        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2434
2435        assert_eq!(simplify(expr), expected);
2436
2437        // (c2 > 5) | ((c2 > 5) | (c1 < 6)) --> (c2 > 5) | (c1 < 6)
2438
2439        let expr = bitwise_or(
2440            col("c2").gt(lit(5)),
2441            bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2442        );
2443        let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2444
2445        assert_eq!(simplify(expr), expected);
2446    }
2447
2448    #[test]
2449    fn test_simplify_composed_bitwise_xor() {
2450        // with an even number of the column "c2"
2451        // c2 ^ ((c2 ^ (c2 | c1)) ^ (c1 & c2)) --> (c2 | c1) ^ (c1 & c2)
2452
2453        let expr = bitwise_xor(
2454            col("c2"),
2455            bitwise_xor(
2456                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2457                bitwise_and(col("c1"), col("c2")),
2458            ),
2459        );
2460
2461        let expected = bitwise_xor(
2462            bitwise_or(col("c2"), col("c1")),
2463            bitwise_and(col("c1"), col("c2")),
2464        );
2465
2466        assert_eq!(simplify(expr), expected);
2467
2468        // with an odd number of the column "c2"
2469        // c2 ^ (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) --> c2 ^ ((c2 | c1) ^ (c1 & c2))
2470
2471        let expr = bitwise_xor(
2472            col("c2"),
2473            bitwise_xor(
2474                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2475                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2476            ),
2477        );
2478
2479        let expected = bitwise_xor(
2480            col("c2"),
2481            bitwise_xor(
2482                bitwise_or(col("c2"), col("c1")),
2483                bitwise_and(col("c1"), col("c2")),
2484            ),
2485        );
2486
2487        assert_eq!(simplify(expr), expected);
2488
2489        // with an even number of the column "c2"
2490        // ((c2 ^ (c2 | c1)) ^ (c1 & c2)) ^ c2 --> (c2 | c1) ^ (c1 & c2)
2491
2492        let expr = bitwise_xor(
2493            bitwise_xor(
2494                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2495                bitwise_and(col("c1"), col("c2")),
2496            ),
2497            col("c2"),
2498        );
2499
2500        let expected = bitwise_xor(
2501            bitwise_or(col("c2"), col("c1")),
2502            bitwise_and(col("c1"), col("c2")),
2503        );
2504
2505        assert_eq!(simplify(expr), expected);
2506
2507        // with an odd number of the column "c2"
2508        // (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) ^ c2 --> ((c2 | c1) ^ (c1 & c2)) ^ c2
2509
2510        let expr = bitwise_xor(
2511            bitwise_xor(
2512                bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))),
2513                bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")),
2514            ),
2515            col("c2"),
2516        );
2517
2518        let expected = bitwise_xor(
2519            bitwise_xor(
2520                bitwise_or(col("c2"), col("c1")),
2521                bitwise_and(col("c1"), col("c2")),
2522            ),
2523            col("c2"),
2524        );
2525
2526        assert_eq!(simplify(expr), expected);
2527    }
2528
2529    #[test]
2530    fn test_simplify_negated_bitwise_and() {
2531        // !c4 & c4 --> 0
2532        let expr = (-col("c4_non_null")) & col("c4_non_null");
2533        let expected = lit(0u32);
2534
2535        assert_eq!(simplify(expr), expected);
2536        // c4 & !c4 --> 0
2537        let expr = col("c4_non_null") & (-col("c4_non_null"));
2538        let expected = lit(0u32);
2539
2540        assert_eq!(simplify(expr), expected);
2541
2542        // !c3 & c3 --> 0
2543        let expr = (-col("c3_non_null")) & col("c3_non_null");
2544        let expected = lit(0i64);
2545
2546        assert_eq!(simplify(expr), expected);
2547        // c3 & !c3 --> 0
2548        let expr = col("c3_non_null") & (-col("c3_non_null"));
2549        let expected = lit(0i64);
2550
2551        assert_eq!(simplify(expr), expected);
2552    }
2553
2554    #[test]
2555    fn test_simplify_negated_bitwise_or() {
2556        // !c4 | c4 --> -1
2557        let expr = (-col("c4_non_null")) | col("c4_non_null");
2558        let expected = lit(-1i32);
2559
2560        assert_eq!(simplify(expr), expected);
2561
2562        // c4 | !c4 --> -1
2563        let expr = col("c4_non_null") | (-col("c4_non_null"));
2564        let expected = lit(-1i32);
2565
2566        assert_eq!(simplify(expr), expected);
2567
2568        // !c3 | c3 --> -1
2569        let expr = (-col("c3_non_null")) | col("c3_non_null");
2570        let expected = lit(-1i64);
2571
2572        assert_eq!(simplify(expr), expected);
2573
2574        // c3 | !c3 --> -1
2575        let expr = col("c3_non_null") | (-col("c3_non_null"));
2576        let expected = lit(-1i64);
2577
2578        assert_eq!(simplify(expr), expected);
2579    }
2580
2581    #[test]
2582    fn test_simplify_negated_bitwise_xor() {
2583        // !c4 ^ c4 --> -1
2584        let expr = (-col("c4_non_null")) ^ col("c4_non_null");
2585        let expected = lit(-1i32);
2586
2587        assert_eq!(simplify(expr), expected);
2588
2589        // c4 ^ !c4 --> -1
2590        let expr = col("c4_non_null") ^ (-col("c4_non_null"));
2591        let expected = lit(-1i32);
2592
2593        assert_eq!(simplify(expr), expected);
2594
2595        // !c3 ^ c3 --> -1
2596        let expr = (-col("c3_non_null")) ^ col("c3_non_null");
2597        let expected = lit(-1i64);
2598
2599        assert_eq!(simplify(expr), expected);
2600
2601        // c3 ^ !c3 --> -1
2602        let expr = col("c3_non_null") ^ (-col("c3_non_null"));
2603        let expected = lit(-1i64);
2604
2605        assert_eq!(simplify(expr), expected);
2606    }
2607
2608    #[test]
2609    fn test_simplify_bitwise_and_or() {
2610        // (c2 < 3) & ((c2 < 3) | c1) -> (c2 < 3)
2611        let expr = bitwise_and(
2612            col("c2_non_null").lt(lit(3)),
2613            bitwise_or(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2614        );
2615        let expected = col("c2_non_null").lt(lit(3));
2616
2617        assert_eq!(simplify(expr), expected);
2618    }
2619
2620    #[test]
2621    fn test_simplify_bitwise_or_and() {
2622        // (c2 < 3) | ((c2 < 3) & c1) -> (c2 < 3)
2623        let expr = bitwise_or(
2624            col("c2_non_null").lt(lit(3)),
2625            bitwise_and(col("c2_non_null").lt(lit(3)), col("c1_non_null")),
2626        );
2627        let expected = col("c2_non_null").lt(lit(3));
2628
2629        assert_eq!(simplify(expr), expected);
2630    }
2631
2632    #[test]
2633    fn test_simplify_simple_bitwise_and() {
2634        // (c2 > 5) & (c2 > 5) -> (c2 > 5)
2635        let expr = (col("c2").gt(lit(5))).bitand(col("c2").gt(lit(5)));
2636        let expected = col("c2").gt(lit(5));
2637
2638        assert_eq!(simplify(expr), expected);
2639    }
2640
2641    #[test]
2642    fn test_simplify_simple_bitwise_or() {
2643        // (c2 > 5) | (c2 > 5) -> (c2 > 5)
2644        let expr = (col("c2").gt(lit(5))).bitor(col("c2").gt(lit(5)));
2645        let expected = col("c2").gt(lit(5));
2646
2647        assert_eq!(simplify(expr), expected);
2648    }
2649
2650    #[test]
2651    fn test_simplify_simple_bitwise_xor() {
2652        // c4 ^ c4 -> 0
2653        let expr = (col("c4")).bitxor(col("c4"));
2654        let expected = lit(0u32);
2655
2656        assert_eq!(simplify(expr), expected);
2657
2658        // c3 ^ c3 -> 0
2659        let expr = col("c3").bitxor(col("c3"));
2660        let expected = lit(0i64);
2661
2662        assert_eq!(simplify(expr), expected);
2663    }
2664
2665    #[test]
2666    fn test_simplify_modulo_by_zero_non_null() {
2667        // because modulo by 0 maybe occur in short-circuit expression
2668        // so we should not simplify this, and throw error in runtime.
2669        let expr = col("c2_non_null") % lit(0);
2670        let expected = expr.clone();
2671
2672        assert_eq!(simplify(expr), expected);
2673    }
2674
2675    #[test]
2676    fn test_simplify_simple_and() {
2677        // (c2 > 5) AND (c2 > 5) -> (c2 > 5)
2678        let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5)));
2679        let expected = col("c2").gt(lit(5));
2680
2681        assert_eq!(simplify(expr), expected);
2682    }
2683
2684    #[test]
2685    fn test_simplify_composed_and() {
2686        // ((c2 > 5) AND (c1 < 6)) AND (c2 > 5)
2687        let expr = and(
2688            and(col("c2").gt(lit(5)), col("c1").lt(lit(6))),
2689            col("c2").gt(lit(5)),
2690        );
2691        let expected = and(col("c2").gt(lit(5)), col("c1").lt(lit(6)));
2692
2693        assert_eq!(simplify(expr), expected);
2694    }
2695
2696    #[test]
2697    fn test_simplify_negated_and() {
2698        // (c2 > 5) AND !(c2 > 5) --> (c2 > 5) AND (c2 <= 5)
2699        let expr = and(col("c2").gt(lit(5)), Expr::not(col("c2").gt(lit(5))));
2700        let expected = col("c2").gt(lit(5)).and(col("c2").lt_eq(lit(5)));
2701
2702        assert_eq!(simplify(expr), expected);
2703    }
2704
2705    #[test]
2706    fn test_simplify_or_and() {
2707        let l = col("c2").gt(lit(5));
2708        let r = and(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
2709
2710        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5))
2711        let expr = or(l.clone(), r.clone());
2712
2713        let expected = l.clone();
2714        assert_eq!(simplify(expr), expected);
2715
2716        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5)
2717        let expr = or(r, l);
2718        assert_eq!(simplify(expr), expected);
2719    }
2720
2721    #[test]
2722    fn test_simplify_or_and_non_null() {
2723        let l = col("c2_non_null").gt(lit(5));
2724        let r = and(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
2725
2726        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) --> c2 > 5
2727        let expr = or(l.clone(), r.clone());
2728
2729        // This is only true if `c1 < 6` is not nullable / can not be null.
2730        let expected = col("c2_non_null").gt(lit(5));
2731
2732        assert_eq!(simplify(expr), expected);
2733
2734        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) --> c2 > 5
2735        let expr = or(l, r);
2736
2737        assert_eq!(simplify(expr), expected);
2738    }
2739
2740    #[test]
2741    fn test_simplify_and_or() {
2742        let l = col("c2").gt(lit(5));
2743        let r = or(col("c1").lt(lit(6)), col("c2").gt(lit(5)));
2744
2745        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
2746        let expr = and(l.clone(), r.clone());
2747
2748        let expected = l.clone();
2749        assert_eq!(simplify(expr), expected);
2750
2751        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
2752        let expr = and(r, l);
2753        assert_eq!(simplify(expr), expected);
2754    }
2755
2756    #[test]
2757    fn test_simplify_and_or_non_null() {
2758        let l = col("c2_non_null").gt(lit(5));
2759        let r = or(col("c1_non_null").lt(lit(6)), col("c2_non_null").gt(lit(5)));
2760
2761        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
2762        let expr = and(l.clone(), r.clone());
2763
2764        // This is only true if `c1 < 6` is not nullable / can not be null.
2765        let expected = col("c2_non_null").gt(lit(5));
2766
2767        assert_eq!(simplify(expr), expected);
2768
2769        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
2770        let expr = and(l, r);
2771
2772        assert_eq!(simplify(expr), expected);
2773    }
2774
2775    #[test]
2776    fn test_simplify_by_de_morgan_laws() {
2777        // Laws with logical operations
2778        // !(c3 AND c4) --> !c3 OR !c4
2779        let expr = and(col("c3"), col("c4")).not();
2780        let expected = or(col("c3").not(), col("c4").not());
2781        assert_eq!(simplify(expr), expected);
2782        // !(c3 OR c4) --> !c3 AND !c4
2783        let expr = or(col("c3"), col("c4")).not();
2784        let expected = and(col("c3").not(), col("c4").not());
2785        assert_eq!(simplify(expr), expected);
2786        // !(!c3) --> c3
2787        let expr = col("c3").not().not();
2788        let expected = col("c3");
2789        assert_eq!(simplify(expr), expected);
2790
2791        // Laws with bitwise operations
2792        // !(c3 & c4) --> !c3 | !c4
2793        let expr = -bitwise_and(col("c3"), col("c4"));
2794        let expected = bitwise_or(-col("c3"), -col("c4"));
2795        assert_eq!(simplify(expr), expected);
2796        // !(c3 | c4) --> !c3 & !c4
2797        let expr = -bitwise_or(col("c3"), col("c4"));
2798        let expected = bitwise_and(-col("c3"), -col("c4"));
2799        assert_eq!(simplify(expr), expected);
2800        // !(!c3) --> c3
2801        let expr = -(-col("c3"));
2802        let expected = col("c3");
2803        assert_eq!(simplify(expr), expected);
2804    }
2805
2806    #[test]
2807    fn test_simplify_null_and_false() {
2808        let expr = and(lit_bool_null(), lit(false));
2809        let expr_eq = lit(false);
2810
2811        assert_eq!(simplify(expr), expr_eq);
2812    }
2813
2814    #[test]
2815    fn test_simplify_divide_null_by_null() {
2816        let null = lit(ScalarValue::Int32(None));
2817        let expr_plus = null.clone() / null.clone();
2818        let expr_eq = null;
2819
2820        assert_eq!(simplify(expr_plus), expr_eq);
2821    }
2822
2823    #[test]
2824    fn test_simplify_simplify_arithmetic_expr() {
2825        let expr_plus = lit(1) + lit(1);
2826
2827        assert_eq!(simplify(expr_plus), lit(2));
2828    }
2829
2830    #[test]
2831    fn test_simplify_simplify_eq_expr() {
2832        let expr_eq = binary_expr(lit(1), Operator::Eq, lit(1));
2833
2834        assert_eq!(simplify(expr_eq), lit(true));
2835    }
2836
2837    #[test]
2838    fn test_simplify_regex() {
2839        // malformed regex
2840        assert_contains!(
2841            try_simplify(regex_match(col("c1"), lit("foo{")))
2842                .unwrap_err()
2843                .to_string(),
2844            "regex parse error"
2845        );
2846
2847        // unsupported cases
2848        assert_no_change(regex_match(col("c1"), lit("foo.*")));
2849        assert_no_change(regex_match(col("c1"), lit("(foo)")));
2850        assert_no_change(regex_match(col("c1"), lit("%")));
2851        assert_no_change(regex_match(col("c1"), lit("_")));
2852        assert_no_change(regex_match(col("c1"), lit("f%o")));
2853        assert_no_change(regex_match(col("c1"), lit("^f%o")));
2854        assert_no_change(regex_match(col("c1"), lit("f_o")));
2855
2856        // empty cases
2857        assert_change(
2858            regex_match(col("c1"), lit("")),
2859            if_not_null(col("c1"), true),
2860        );
2861        assert_change(
2862            regex_not_match(col("c1"), lit("")),
2863            if_not_null(col("c1"), false),
2864        );
2865        assert_change(
2866            regex_imatch(col("c1"), lit("")),
2867            if_not_null(col("c1"), true),
2868        );
2869        assert_change(
2870            regex_not_imatch(col("c1"), lit("")),
2871            if_not_null(col("c1"), false),
2872        );
2873
2874        // single character
2875        assert_change(regex_match(col("c1"), lit("x")), col("c1").like(lit("%x%")));
2876
2877        // single word
2878        assert_change(
2879            regex_match(col("c1"), lit("foo")),
2880            col("c1").like(lit("%foo%")),
2881        );
2882
2883        // regular expressions that match an exact literal
2884        assert_change(regex_match(col("c1"), lit("^$")), col("c1").eq(lit("")));
2885        assert_change(
2886            regex_not_match(col("c1"), lit("^$")),
2887            col("c1").not_eq(lit("")),
2888        );
2889        assert_change(
2890            regex_match(col("c1"), lit("^foo$")),
2891            col("c1").eq(lit("foo")),
2892        );
2893        assert_change(
2894            regex_not_match(col("c1"), lit("^foo$")),
2895            col("c1").not_eq(lit("foo")),
2896        );
2897
2898        // regular expressions that match exact captured literals
2899        assert_change(
2900            regex_match(col("c1"), lit("^(foo|bar)$")),
2901            col("c1").eq(lit("foo")).or(col("c1").eq(lit("bar"))),
2902        );
2903        assert_change(
2904            regex_not_match(col("c1"), lit("^(foo|bar)$")),
2905            col("c1")
2906                .not_eq(lit("foo"))
2907                .and(col("c1").not_eq(lit("bar"))),
2908        );
2909        assert_change(
2910            regex_match(col("c1"), lit("^(foo)$")),
2911            col("c1").eq(lit("foo")),
2912        );
2913        assert_change(
2914            regex_match(col("c1"), lit("^(foo|bar|baz)$")),
2915            ((col("c1").eq(lit("foo"))).or(col("c1").eq(lit("bar"))))
2916                .or(col("c1").eq(lit("baz"))),
2917        );
2918        assert_change(
2919            regex_match(col("c1"), lit("^(foo|bar|baz|qux)$")),
2920            col("c1")
2921                .in_list(vec![lit("foo"), lit("bar"), lit("baz"), lit("qux")], false),
2922        );
2923        assert_change(
2924            regex_match(col("c1"), lit("^(fo_o)$")),
2925            col("c1").eq(lit("fo_o")),
2926        );
2927        assert_change(
2928            regex_match(col("c1"), lit("^(fo_o)$")),
2929            col("c1").eq(lit("fo_o")),
2930        );
2931        assert_change(
2932            regex_match(col("c1"), lit("^(fo_o|ba_r)$")),
2933            col("c1").eq(lit("fo_o")).or(col("c1").eq(lit("ba_r"))),
2934        );
2935        assert_change(
2936            regex_not_match(col("c1"), lit("^(fo_o|ba_r)$")),
2937            col("c1")
2938                .not_eq(lit("fo_o"))
2939                .and(col("c1").not_eq(lit("ba_r"))),
2940        );
2941        assert_change(
2942            regex_match(col("c1"), lit("^(fo_o|ba_r|ba_z)$")),
2943            ((col("c1").eq(lit("fo_o"))).or(col("c1").eq(lit("ba_r"))))
2944                .or(col("c1").eq(lit("ba_z"))),
2945        );
2946        assert_change(
2947            regex_match(col("c1"), lit("^(fo_o|ba_r|baz|qu_x)$")),
2948            col("c1").in_list(
2949                vec![lit("fo_o"), lit("ba_r"), lit("baz"), lit("qu_x")],
2950                false,
2951            ),
2952        );
2953
2954        // regular expressions that mismatch captured literals
2955        assert_no_change(regex_match(col("c1"), lit("(foo|bar)")));
2956        assert_no_change(regex_match(col("c1"), lit("(foo|bar)*")));
2957        assert_no_change(regex_match(col("c1"), lit("(fo_o|b_ar)")));
2958        assert_no_change(regex_match(col("c1"), lit("(foo|ba_r)*")));
2959        assert_no_change(regex_match(col("c1"), lit("(fo_o|ba_r)*")));
2960        assert_no_change(regex_match(col("c1"), lit("^(foo|bar)*")));
2961        assert_no_change(regex_match(col("c1"), lit("^(foo)(bar)$")));
2962        assert_no_change(regex_match(col("c1"), lit("^")));
2963        assert_no_change(regex_match(col("c1"), lit("$")));
2964        assert_no_change(regex_match(col("c1"), lit("$^")));
2965        assert_no_change(regex_match(col("c1"), lit("$foo^")));
2966
2967        // regular expressions that match a partial literal
2968        assert_change(
2969            regex_match(col("c1"), lit("^foo")),
2970            col("c1").like(lit("foo%")),
2971        );
2972        assert_change(
2973            regex_match(col("c1"), lit("foo$")),
2974            col("c1").like(lit("%foo")),
2975        );
2976        assert_change(
2977            regex_match(col("c1"), lit("^foo|bar$")),
2978            col("c1").like(lit("foo%")).or(col("c1").like(lit("%bar"))),
2979        );
2980
2981        // OR-chain
2982        assert_change(
2983            regex_match(col("c1"), lit("foo|bar|baz")),
2984            col("c1")
2985                .like(lit("%foo%"))
2986                .or(col("c1").like(lit("%bar%")))
2987                .or(col("c1").like(lit("%baz%"))),
2988        );
2989        assert_change(
2990            regex_match(col("c1"), lit("foo|x|baz")),
2991            col("c1")
2992                .like(lit("%foo%"))
2993                .or(col("c1").like(lit("%x%")))
2994                .or(col("c1").like(lit("%baz%"))),
2995        );
2996        assert_change(
2997            regex_not_match(col("c1"), lit("foo|bar|baz")),
2998            col("c1")
2999                .not_like(lit("%foo%"))
3000                .and(col("c1").not_like(lit("%bar%")))
3001                .and(col("c1").not_like(lit("%baz%"))),
3002        );
3003        // both anchored expressions (translated to equality) and unanchored
3004        assert_change(
3005            regex_match(col("c1"), lit("foo|^x$|baz")),
3006            col("c1")
3007                .like(lit("%foo%"))
3008                .or(col("c1").eq(lit("x")))
3009                .or(col("c1").like(lit("%baz%"))),
3010        );
3011        assert_change(
3012            regex_not_match(col("c1"), lit("foo|^bar$|baz")),
3013            col("c1")
3014                .not_like(lit("%foo%"))
3015                .and(col("c1").not_eq(lit("bar")))
3016                .and(col("c1").not_like(lit("%baz%"))),
3017        );
3018        // Too many patterns (MAX_REGEX_ALTERNATIONS_EXPANSION)
3019        assert_no_change(regex_match(col("c1"), lit("foo|bar|baz|blarg|bozo|etc")));
3020    }
3021
3022    #[track_caller]
3023    fn assert_no_change(expr: Expr) {
3024        let optimized = simplify(expr.clone());
3025        assert_eq!(expr, optimized);
3026    }
3027
3028    #[track_caller]
3029    fn assert_change(expr: Expr, expected: Expr) {
3030        let optimized = simplify(expr);
3031        assert_eq!(optimized, expected);
3032    }
3033
3034    fn regex_match(left: Expr, right: Expr) -> Expr {
3035        Expr::BinaryExpr(BinaryExpr {
3036            left: Box::new(left),
3037            op: Operator::RegexMatch,
3038            right: Box::new(right),
3039        })
3040    }
3041
3042    fn regex_not_match(left: Expr, right: Expr) -> Expr {
3043        Expr::BinaryExpr(BinaryExpr {
3044            left: Box::new(left),
3045            op: Operator::RegexNotMatch,
3046            right: Box::new(right),
3047        })
3048    }
3049
3050    fn regex_imatch(left: Expr, right: Expr) -> Expr {
3051        Expr::BinaryExpr(BinaryExpr {
3052            left: Box::new(left),
3053            op: Operator::RegexIMatch,
3054            right: Box::new(right),
3055        })
3056    }
3057
3058    fn regex_not_imatch(left: Expr, right: Expr) -> Expr {
3059        Expr::BinaryExpr(BinaryExpr {
3060            left: Box::new(left),
3061            op: Operator::RegexNotIMatch,
3062            right: Box::new(right),
3063        })
3064    }
3065
3066    // ------------------------------
3067    // ----- Simplifier tests -------
3068    // ------------------------------
3069
3070    fn try_simplify(expr: Expr) -> Result<Expr> {
3071        let schema = expr_test_schema();
3072        let execution_props = ExecutionProps::new();
3073        let simplifier = ExprSimplifier::new(
3074            SimplifyContext::new(&execution_props).with_schema(schema),
3075        );
3076        simplifier.simplify(expr)
3077    }
3078
3079    fn simplify(expr: Expr) -> Expr {
3080        try_simplify(expr).unwrap()
3081    }
3082
3083    fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
3084        let schema = expr_test_schema();
3085        let execution_props = ExecutionProps::new();
3086        let simplifier = ExprSimplifier::new(
3087            SimplifyContext::new(&execution_props).with_schema(schema),
3088        );
3089        simplifier.simplify_with_cycle_count(expr)
3090    }
3091
3092    fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
3093        try_simplify_with_cycle_count(expr).unwrap()
3094    }
3095
3096    fn simplify_with_guarantee(
3097        expr: Expr,
3098        guarantees: Vec<(Expr, NullableInterval)>,
3099    ) -> Expr {
3100        let schema = expr_test_schema();
3101        let execution_props = ExecutionProps::new();
3102        let simplifier = ExprSimplifier::new(
3103            SimplifyContext::new(&execution_props).with_schema(schema),
3104        )
3105        .with_guarantees(guarantees);
3106        simplifier.simplify(expr).unwrap()
3107    }
3108
3109    fn expr_test_schema() -> DFSchemaRef {
3110        Arc::new(
3111            DFSchema::from_unqualified_fields(
3112                vec![
3113                    Field::new("c1", DataType::Utf8, true),
3114                    Field::new("c2", DataType::Boolean, true),
3115                    Field::new("c3", DataType::Int64, true),
3116                    Field::new("c4", DataType::UInt32, true),
3117                    Field::new("c1_non_null", DataType::Utf8, false),
3118                    Field::new("c2_non_null", DataType::Boolean, false),
3119                    Field::new("c3_non_null", DataType::Int64, false),
3120                    Field::new("c4_non_null", DataType::UInt32, false),
3121                ]
3122                .into(),
3123                HashMap::new(),
3124            )
3125            .unwrap(),
3126        )
3127    }
3128
3129    #[test]
3130    fn simplify_expr_null_comparison() {
3131        // x = null is always null
3132        assert_eq!(
3133            simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))),
3134            lit(ScalarValue::Boolean(None)),
3135        );
3136
3137        // null != null is always null
3138        assert_eq!(
3139            simplify(
3140                lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))
3141            ),
3142            lit(ScalarValue::Boolean(None)),
3143        );
3144
3145        // x != null is always null
3146        assert_eq!(
3147            simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))),
3148            lit(ScalarValue::Boolean(None)),
3149        );
3150
3151        // null = x is always null
3152        assert_eq!(
3153            simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))),
3154            lit(ScalarValue::Boolean(None)),
3155        );
3156    }
3157
3158    #[test]
3159    fn simplify_expr_is_not_null() {
3160        assert_eq!(
3161            simplify(Expr::IsNotNull(Box::new(col("c1")))),
3162            Expr::IsNotNull(Box::new(col("c1")))
3163        );
3164
3165        // 'c1_non_null IS NOT NULL' is always true
3166        assert_eq!(
3167            simplify(Expr::IsNotNull(Box::new(col("c1_non_null")))),
3168            lit(true)
3169        );
3170    }
3171
3172    #[test]
3173    fn simplify_expr_is_null() {
3174        assert_eq!(
3175            simplify(Expr::IsNull(Box::new(col("c1")))),
3176            Expr::IsNull(Box::new(col("c1")))
3177        );
3178
3179        // 'c1_non_null IS NULL' is always false
3180        assert_eq!(
3181            simplify(Expr::IsNull(Box::new(col("c1_non_null")))),
3182            lit(false)
3183        );
3184    }
3185
3186    #[test]
3187    fn simplify_expr_is_unknown() {
3188        assert_eq!(simplify(col("c2").is_unknown()), col("c2").is_unknown(),);
3189
3190        // 'c2_non_null is unknown' is always false
3191        assert_eq!(simplify(col("c2_non_null").is_unknown()), lit(false));
3192    }
3193
3194    #[test]
3195    fn simplify_expr_is_not_known() {
3196        assert_eq!(
3197            simplify(col("c2").is_not_unknown()),
3198            col("c2").is_not_unknown()
3199        );
3200
3201        // 'c2_non_null is not unknown' is always true
3202        assert_eq!(simplify(col("c2_non_null").is_not_unknown()), lit(true));
3203    }
3204
3205    #[test]
3206    fn simplify_expr_eq() {
3207        let schema = expr_test_schema();
3208        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3209
3210        // true = true -> true
3211        assert_eq!(simplify(lit(true).eq(lit(true))), lit(true));
3212
3213        // true = false -> false
3214        assert_eq!(simplify(lit(true).eq(lit(false))), lit(false),);
3215
3216        // c2 = true -> c2
3217        assert_eq!(simplify(col("c2").eq(lit(true))), col("c2"));
3218
3219        // c2 = false => !c2
3220        assert_eq!(simplify(col("c2").eq(lit(false))), col("c2").not(),);
3221    }
3222
3223    #[test]
3224    fn simplify_expr_eq_skip_nonboolean_type() {
3225        let schema = expr_test_schema();
3226
3227        // When one of the operand is not of boolean type, folding the
3228        // other boolean constant will change return type of
3229        // expression to non-boolean.
3230        //
3231        // Make sure c1 column to be used in tests is not boolean type
3232        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3233
3234        // don't fold c1 = foo
3235        assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),);
3236    }
3237
3238    #[test]
3239    fn simplify_expr_not_eq() {
3240        let schema = expr_test_schema();
3241
3242        assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
3243
3244        // c2 != true -> !c2
3245        assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),);
3246
3247        // c2 != false -> c2
3248        assert_eq!(simplify(col("c2").not_eq(lit(false))), col("c2"),);
3249
3250        // test constant
3251        assert_eq!(simplify(lit(true).not_eq(lit(true))), lit(false),);
3252
3253        assert_eq!(simplify(lit(true).not_eq(lit(false))), lit(true),);
3254    }
3255
3256    #[test]
3257    fn simplify_expr_not_eq_skip_nonboolean_type() {
3258        let schema = expr_test_schema();
3259
3260        // when one of the operand is not of boolean type, folding the
3261        // other boolean constant will change return type of
3262        // expression to non-boolean.
3263        assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
3264
3265        assert_eq!(
3266            simplify(col("c1").not_eq(lit("foo"))),
3267            col("c1").not_eq(lit("foo")),
3268        );
3269    }
3270
3271    #[test]
3272    fn simplify_expr_case_when_then_else() {
3273        // CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true
3274        // -->
3275        // CASE WHEN c2 THEN false ELSE c2
3276        // -->
3277        // false
3278        assert_eq!(
3279            simplify(Expr::Case(Case::new(
3280                None,
3281                vec![(
3282                    Box::new(col("c2_non_null").not_eq(lit(false))),
3283                    Box::new(lit("ok").eq(lit("not_ok"))),
3284                )],
3285                Some(Box::new(col("c2_non_null").eq(lit(true)))),
3286            ))),
3287            lit(false) // #1716
3288        );
3289
3290        // CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2
3291        // -->
3292        // CASE WHEN c2 THEN true ELSE c2
3293        // -->
3294        // c2
3295        //
3296        // Need to call simplify 2x due to
3297        // https://github.com/apache/datafusion/issues/1160
3298        assert_eq!(
3299            simplify(simplify(Expr::Case(Case::new(
3300                None,
3301                vec![(
3302                    Box::new(col("c2_non_null").not_eq(lit(false))),
3303                    Box::new(lit("ok").eq(lit("ok"))),
3304                )],
3305                Some(Box::new(col("c2_non_null").eq(lit(true)))),
3306            )))),
3307            col("c2_non_null")
3308        );
3309
3310        // CASE WHEN ISNULL(c2) THEN true ELSE c2
3311        // -->
3312        // ISNULL(c2) OR c2
3313        //
3314        // Need to call simplify 2x due to
3315        // https://github.com/apache/datafusion/issues/1160
3316        assert_eq!(
3317            simplify(simplify(Expr::Case(Case::new(
3318                None,
3319                vec![(Box::new(col("c2").is_null()), Box::new(lit(true)),)],
3320                Some(Box::new(col("c2"))),
3321            )))),
3322            col("c2")
3323                .is_null()
3324                .or(col("c2").is_not_null().and(col("c2")))
3325        );
3326
3327        // CASE WHEN c1 then true WHEN c2 then false ELSE true
3328        // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE)
3329        // --> c1 OR (NOT(c1) AND NOT(c2))
3330        // --> c1 OR NOT(c2)
3331        //
3332        // Need to call simplify 2x due to
3333        // https://github.com/apache/datafusion/issues/1160
3334        assert_eq!(
3335            simplify(simplify(Expr::Case(Case::new(
3336                None,
3337                vec![
3338                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
3339                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
3340                ],
3341                Some(Box::new(lit(true))),
3342            )))),
3343            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
3344        );
3345
3346        // CASE WHEN c1 then true WHEN c2 then true ELSE false
3347        // --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE)
3348        // --> c1 OR (NOT(c1) AND c2)
3349        // --> c1 OR c2
3350        //
3351        // Need to call simplify 2x due to
3352        // https://github.com/apache/datafusion/issues/1160
3353        assert_eq!(
3354            simplify(simplify(Expr::Case(Case::new(
3355                None,
3356                vec![
3357                    (Box::new(col("c1_non_null")), Box::new(lit(true)),),
3358                    (Box::new(col("c2_non_null")), Box::new(lit(false)),),
3359                ],
3360                Some(Box::new(lit(true))),
3361            )))),
3362            col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
3363        );
3364
3365        // CASE WHEN c > 0 THEN true END AS c1
3366        assert_eq!(
3367            simplify(simplify(Expr::Case(Case::new(
3368                None,
3369                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3370                None,
3371            )))),
3372            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
3373                col("c3").gt(lit(0_i64)),
3374                lit(true)
3375            )
3376            .and(lit_bool_null()))
3377        );
3378
3379        // CASE WHEN c > 0 THEN true ELSE false END AS c1
3380        assert_eq!(
3381            simplify(simplify(Expr::Case(Case::new(
3382                None,
3383                vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3384                Some(Box::new(lit(false))),
3385            )))),
3386            not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
3387        );
3388    }
3389
3390    fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3391        Expr::BinaryExpr(BinaryExpr {
3392            left: Box::new(left.into()),
3393            op: Operator::IsDistinctFrom,
3394            right: Box::new(right.into()),
3395        })
3396    }
3397
3398    fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3399        Expr::BinaryExpr(BinaryExpr {
3400            left: Box::new(left.into()),
3401            op: Operator::IsNotDistinctFrom,
3402            right: Box::new(right.into()),
3403        })
3404    }
3405
3406    #[test]
3407    fn simplify_expr_bool_or() {
3408        // col || true is always true
3409        assert_eq!(simplify(col("c2").or(lit(true))), lit(true),);
3410
3411        // col || false is always col
3412        assert_eq!(simplify(col("c2").or(lit(false))), col("c2"),);
3413
3414        // true || null is always true
3415        assert_eq!(simplify(lit(true).or(lit_bool_null())), lit(true),);
3416
3417        // null || true is always true
3418        assert_eq!(simplify(lit_bool_null().or(lit(true))), lit(true),);
3419
3420        // false || null is always null
3421        assert_eq!(simplify(lit(false).or(lit_bool_null())), lit_bool_null(),);
3422
3423        // null || false is always null
3424        assert_eq!(simplify(lit_bool_null().or(lit(false))), lit_bool_null(),);
3425
3426        // ( c1 BETWEEN Int32(0) AND Int32(10) ) OR Boolean(NULL)
3427        // it can be either NULL or  TRUE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
3428        // and should not be rewritten
3429        let expr = col("c1").between(lit(0), lit(10));
3430        let expr = expr.or(lit_bool_null());
3431        let result = simplify(expr);
3432
3433        let expected_expr = or(
3434            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
3435            lit_bool_null(),
3436        );
3437        assert_eq!(expected_expr, result);
3438    }
3439
    #[test]
    fn simplify_inlist() {
        // Checks the rewrites applied to `IN` / `NOT IN` lists: empty lists,
        // NULL probes, short lists, scalar subqueries, and combinations of
        // several lists over the same column joined by AND / OR.

        // empty list: IN () --> false, NOT IN () --> true
        assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false));
        assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true));

        // null in (...)  --> null
        assert_eq!(
            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], false)),
            lit_bool_null()
        );

        // null not in (...)  --> null
        assert_eq!(
            simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], true)),
            lit_bool_null()
        );

        // single-element list becomes a plain (in)equality
        assert_eq!(
            simplify(in_list(col("c1"), vec![lit(1)], false)),
            col("c1").eq(lit(1))
        );
        assert_eq!(
            simplify(in_list(col("c1"), vec![lit(1)], true)),
            col("c1").not_eq(lit(1))
        );

        // more complex expressions can be simplified if list contains
        // one element only
        assert_eq!(
            simplify(in_list(col("c1") * lit(10), vec![lit(2)], false)),
            (col("c1") * lit(10)).eq(lit(2))
        );

        // two-element list: IN becomes an OR of equalities,
        // NOT IN becomes an AND of inequalities
        assert_eq!(
            simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)),
            col("c1").eq(lit(1)).or(col("c1").eq(lit(2)))
        );
        assert_eq!(
            simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)),
            col("c1").not_eq(lit(1)).and(col("c1").not_eq(lit(2)))
        );

        // a list holding a single scalar subquery becomes [NOT] IN <subquery>
        let subquery = Arc::new(test_table_scan_with_name("test").unwrap());
        assert_eq!(
            simplify(in_list(
                col("c1"),
                vec![scalar_subquery(Arc::clone(&subquery))],
                false
            )),
            in_subquery(col("c1"), Arc::clone(&subquery))
        );
        assert_eq!(
            simplify(in_list(
                col("c1"),
                vec![scalar_subquery(Arc::clone(&subquery))],
                true
            )),
            not_in_subquery(col("c1"), subquery)
        );

        let subquery1 =
            scalar_subquery(Arc::new(test_table_scan_with_name("test1").unwrap()));
        let subquery2 =
            scalar_subquery(Arc::new(test_table_scan_with_name("test2").unwrap()));

        // c1 NOT IN (<subquery1>, <subquery2>) -> c1 != <subquery1> AND c1 != <subquery2>
        assert_eq!(
            simplify(in_list(
                col("c1"),
                vec![subquery1.clone(), subquery2.clone()],
                true
            )),
            col("c1")
                .not_eq(subquery1.clone())
                .and(col("c1").not_eq(subquery2.clone()))
        );

        // c1 IN (<subquery1>, <subquery2>) -> c1 == <subquery1> OR c1 == <subquery2>
        assert_eq!(
            simplify(in_list(
                col("c1"),
                vec![subquery1.clone(), subquery2.clone()],
                false
            )),
            col("c1").eq(subquery1).or(col("c1").eq(subquery2))
        );

        // 1. c1 IN (1,2,3,4) AND c1 IN (5,6,7,8) -> false
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false),
        );
        assert_eq!(simplify(expr), lit(false));

        // 2. c1 IN (1,2,3,4) AND c1 IN (4,5,6,7) -> c1 = 4
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], false),
        );
        assert_eq!(simplify(expr), col("c1").eq(lit(4)));

        // 3. c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) -> true
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
        );
        assert_eq!(simplify(expr), lit(true));

        // 3.5 c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (4, 5, 6, 7) -> c1 != 4 (4 overlaps)
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(
            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
        );
        assert_eq!(simplify(expr), col("c1").not_eq(lit(4)));

        // 4. c1 NOT IN (1,2,3,4) AND c1 NOT IN (4,5,6,7) -> c1 NOT IN (1,2,3,4,5,6,7)
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
        );
        assert_eq!(
            simplify(expr),
            in_list(
                col("c1"),
                vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6), lit(7)],
                true
            )
        );

        // 5. c1 IN (1,2,3,4) OR c1 IN (2,3,4,5) -> c1 IN (1,2,3,4,5)
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).or(
            in_list(col("c1"), vec![lit(2), lit(3), lit(4), lit(5)], false),
        );
        assert_eq!(
            simplify(expr),
            in_list(
                col("c1"),
                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
                false
            )
        );

        // 6. c1 IN (1,2,3) AND c1 NOT IN (1,2,3,4,5) -> false
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3)], false).and(in_list(
            col("c1"),
            vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
            true,
        ));
        assert_eq!(simplify(expr), lit(false));

        // 7. c1 NOT IN (1,2,3,4) AND c1 IN (1,2,3,4,5) -> c1 = 5
        let expr =
            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(in_list(
                col("c1"),
                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
                false,
            ));
        assert_eq!(simplify(expr), col("c1").eq(lit(5)));

        // 8. c1 IN (1,2,3,4) AND c1 NOT IN (5,6,7,8) -> c1 IN (1,2,3,4)
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and(
            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
        );
        assert_eq!(
            simplify(expr),
            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false)
        );

        // inlist with more than two expressions
        // c1 IN (1,2,3,4,5,6) AND c1 IN (1,3,5,6) AND c1 IN (3,6) -> c1 = 3 OR c1 = 6
        let expr = in_list(
            col("c1"),
            vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6)],
            false,
        )
        .and(in_list(
            col("c1"),
            vec![lit(1), lit(3), lit(5), lit(6)],
            false,
        ))
        .and(in_list(col("c1"), vec![lit(3), lit(6)], false));
        assert_eq!(
            simplify(expr),
            col("c1").eq(lit(3)).or(col("c1").eq(lit(6)))
        );

        // c1 NOT IN (1,2,3,4) AND c1 IN (5,6,7,8) AND c1 NOT IN (3,4,5,6) AND c1 IN (8,9,10) -> c1 = 8
        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(
            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false)
                .and(in_list(
                    col("c1"),
                    vec![lit(3), lit(4), lit(5), lit(6)],
                    true,
                ))
                .and(in_list(col("c1"), vec![lit(8), lit(9), lit(10)], false)),
        );
        assert_eq!(simplify(expr), col("c1").eq(lit(8)));

        // Contains non-InList expression
        // c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9) -> c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9)
        let expr =
            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(col("c1")
                .not_eq(lit(5))
                .or(in_list(
                    col("c1"),
                    vec![lit(6), lit(7), lit(8), lit(9)],
                    true,
                )));
        // TODO: Further simplify this expression
        // https://github.com/apache/datafusion/issues/8970
        // assert_eq!(simplify(expr.clone()), lit(true));
        assert_eq!(simplify(expr.clone()), expr);
    }
3648
3649    #[test]
3650    fn simplify_large_or() {
3651        let expr = (0..5)
3652            .map(|i| col("c1").eq(lit(i)))
3653            .fold(lit(false), |acc, e| acc.or(e));
3654        assert_eq!(
3655            simplify(expr),
3656            in_list(col("c1"), (0..5).map(lit).collect(), false),
3657        );
3658    }
3659
3660    #[test]
3661    fn simplify_expr_bool_and() {
3662        // col & true is always col
3663        assert_eq!(simplify(col("c2").and(lit(true))), col("c2"),);
3664        // col & false is always false
3665        assert_eq!(simplify(col("c2").and(lit(false))), lit(false),);
3666
3667        // true && null is always null
3668        assert_eq!(simplify(lit(true).and(lit_bool_null())), lit_bool_null(),);
3669
3670        // null && true is always null
3671        assert_eq!(simplify(lit_bool_null().and(lit(true))), lit_bool_null(),);
3672
3673        // false && null is always false
3674        assert_eq!(simplify(lit(false).and(lit_bool_null())), lit(false),);
3675
3676        // null && false is always false
3677        assert_eq!(simplify(lit_bool_null().and(lit(false))), lit(false),);
3678
3679        // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
3680        // it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
3681        // and the Boolean(NULL) should remain
3682        let expr = col("c1").between(lit(0), lit(10));
3683        let expr = expr.and(lit_bool_null());
3684        let result = simplify(expr);
3685
3686        let expected_expr = and(
3687            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
3688            lit_bool_null(),
3689        );
3690        assert_eq!(expected_expr, result);
3691    }
3692
3693    #[test]
3694    fn simplify_expr_between() {
3695        // c2 between 3 and 4 is c2 >= 3 and c2 <= 4
3696        let expr = col("c2").between(lit(3), lit(4));
3697        assert_eq!(
3698            simplify(expr),
3699            and(col("c2").gt_eq(lit(3)), col("c2").lt_eq(lit(4)))
3700        );
3701
3702        // c2 not between 3 and 4 is c2 < 3 or c2 > 4
3703        let expr = col("c2").not_between(lit(3), lit(4));
3704        assert_eq!(
3705            simplify(expr),
3706            or(col("c2").lt(lit(3)), col("c2").gt(lit(4)))
3707        );
3708    }
3709
3710    #[test]
3711    fn test_like_and_ilike() {
3712        let null = lit(ScalarValue::Utf8(None));
3713
3714        // expr [NOT] [I]LIKE NULL
3715        let expr = col("c1").like(null.clone());
3716        assert_eq!(simplify(expr), lit_bool_null());
3717
3718        let expr = col("c1").not_like(null.clone());
3719        assert_eq!(simplify(expr), lit_bool_null());
3720
3721        let expr = col("c1").ilike(null.clone());
3722        assert_eq!(simplify(expr), lit_bool_null());
3723
3724        let expr = col("c1").not_ilike(null.clone());
3725        assert_eq!(simplify(expr), lit_bool_null());
3726
3727        // expr [NOT] [I]LIKE '%'
3728        let expr = col("c1").like(lit("%"));
3729        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
3730
3731        let expr = col("c1").not_like(lit("%"));
3732        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
3733
3734        let expr = col("c1").ilike(lit("%"));
3735        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
3736
3737        let expr = col("c1").not_ilike(lit("%"));
3738        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
3739
3740        // expr [NOT] [I]LIKE '%%'
3741        let expr = col("c1").like(lit("%%"));
3742        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
3743
3744        let expr = col("c1").not_like(lit("%%"));
3745        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
3746
3747        let expr = col("c1").ilike(lit("%%"));
3748        assert_eq!(simplify(expr), if_not_null(col("c1"), true));
3749
3750        let expr = col("c1").not_ilike(lit("%%"));
3751        assert_eq!(simplify(expr), if_not_null(col("c1"), false));
3752
3753        // not_null_expr [NOT] [I]LIKE '%'
3754        let expr = col("c1_non_null").like(lit("%"));
3755        assert_eq!(simplify(expr), lit(true));
3756
3757        let expr = col("c1_non_null").not_like(lit("%"));
3758        assert_eq!(simplify(expr), lit(false));
3759
3760        let expr = col("c1_non_null").ilike(lit("%"));
3761        assert_eq!(simplify(expr), lit(true));
3762
3763        let expr = col("c1_non_null").not_ilike(lit("%"));
3764        assert_eq!(simplify(expr), lit(false));
3765
3766        // not_null_expr [NOT] [I]LIKE '%%'
3767        let expr = col("c1_non_null").like(lit("%%"));
3768        assert_eq!(simplify(expr), lit(true));
3769
3770        let expr = col("c1_non_null").not_like(lit("%%"));
3771        assert_eq!(simplify(expr), lit(false));
3772
3773        let expr = col("c1_non_null").ilike(lit("%%"));
3774        assert_eq!(simplify(expr), lit(true));
3775
3776        let expr = col("c1_non_null").not_ilike(lit("%%"));
3777        assert_eq!(simplify(expr), lit(false));
3778
3779        // null_constant [NOT] [I]LIKE '%'
3780        let expr = null.clone().like(lit("%"));
3781        assert_eq!(simplify(expr), lit_bool_null());
3782
3783        let expr = null.clone().not_like(lit("%"));
3784        assert_eq!(simplify(expr), lit_bool_null());
3785
3786        let expr = null.clone().ilike(lit("%"));
3787        assert_eq!(simplify(expr), lit_bool_null());
3788
3789        let expr = null.clone().not_ilike(lit("%"));
3790        assert_eq!(simplify(expr), lit_bool_null());
3791
3792        // null_constant [NOT] [I]LIKE '%%'
3793        let expr = null.clone().like(lit("%%"));
3794        assert_eq!(simplify(expr), lit_bool_null());
3795
3796        let expr = null.clone().not_like(lit("%%"));
3797        assert_eq!(simplify(expr), lit_bool_null());
3798
3799        let expr = null.clone().ilike(lit("%%"));
3800        assert_eq!(simplify(expr), lit_bool_null());
3801
3802        let expr = null.clone().not_ilike(lit("%%"));
3803        assert_eq!(simplify(expr), lit_bool_null());
3804
3805        // null_constant [NOT] [I]LIKE 'a%'
3806        let expr = null.clone().like(lit("a%"));
3807        assert_eq!(simplify(expr), lit_bool_null());
3808
3809        let expr = null.clone().not_like(lit("a%"));
3810        assert_eq!(simplify(expr), lit_bool_null());
3811
3812        let expr = null.clone().ilike(lit("a%"));
3813        assert_eq!(simplify(expr), lit_bool_null());
3814
3815        let expr = null.clone().not_ilike(lit("a%"));
3816        assert_eq!(simplify(expr), lit_bool_null());
3817
3818        // expr [NOT] [I]LIKE with pattern without wildcards
3819        let expr = col("c1").like(lit("a"));
3820        assert_eq!(simplify(expr), col("c1").eq(lit("a")));
3821        let expr = col("c1").not_like(lit("a"));
3822        assert_eq!(simplify(expr), col("c1").not_eq(lit("a")));
3823        let expr = col("c1").like(lit("a_"));
3824        assert_eq!(simplify(expr), col("c1").like(lit("a_")));
3825        let expr = col("c1").not_like(lit("a_"));
3826        assert_eq!(simplify(expr), col("c1").not_like(lit("a_")));
3827    }
3828
    #[test]
    fn test_simplify_with_guarantee() {
        // Simplifies one fixed expression under several sets of column
        // guarantees and checks how much of it can be folded away.
        //
        // (c3 > 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b")))
        let expr_x = col("c3").gt(lit(3_i64));
        let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32));
        let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true);
        let expr = expr_x.clone().and(expr_y.or(expr_z));

        // All guaranteed null
        let guarantees = vec![
            (col("c3"), NullableInterval::from(ScalarValue::Int64(None))),
            (col("c4"), NullableInterval::from(ScalarValue::UInt32(None))),
            (col("c1"), NullableInterval::from(ScalarValue::Utf8(None))),
        ];

        let output = simplify_with_guarantee(expr.clone(), guarantees);
        assert_eq!(output, lit_bool_null());

        // All guaranteed false
        let guarantees = vec![
            (
                col("c3"),
                NullableInterval::NotNull {
                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
                },
            ),
            (
                col("c4"),
                NullableInterval::from(ScalarValue::UInt32(Some(9))),
            ),
            (col("c1"), NullableInterval::from(ScalarValue::from("a"))),
        ];
        let output = simplify_with_guarantee(expr.clone(), guarantees);
        assert_eq!(output, lit(false));

        // c3 and c4 are guaranteed false-or-null (MaybeNull), so `expr_x` is
        // left as written; the assertion expects the whole expression to
        // reduce to just `expr_x`.
        let guarantees = vec![
            (
                col("c3"),
                NullableInterval::MaybeNull {
                    values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(),
                },
            ),
            (
                col("c4"),
                NullableInterval::MaybeNull {
                    values: Interval::make(Some(9_u32), Some(9_u32)).unwrap(),
                },
            ),
            (
                col("c1"),
                NullableInterval::NotNull {
                    values: Interval::try_new(
                        ScalarValue::from("d"),
                        ScalarValue::from("f"),
                    )
                    .unwrap(),
                },
            ),
        ];
        let output = simplify_with_guarantee(expr.clone(), guarantees);
        assert_eq!(&output, &expr_x);

        // Sufficient true guarantees
        let guarantees = vec![
            (
                col("c3"),
                NullableInterval::from(ScalarValue::Int64(Some(9))),
            ),
            (
                col("c4"),
                NullableInterval::from(ScalarValue::UInt32(Some(3))),
            ),
        ];
        let output = simplify_with_guarantee(expr.clone(), guarantees);
        assert_eq!(output, lit(true));

        // Only partially simplify: with just c4 known, only `expr_x` remains
        let guarantees = vec![(
            col("c4"),
            NullableInterval::from(ScalarValue::UInt32(Some(3))),
        )];
        let output = simplify_with_guarantee(expr, guarantees);
        assert_eq!(&output, &expr_x);
    }
3914
3915    #[test]
3916    fn test_expression_partial_simplify_1() {
3917        // (1 + 2) + (4 / 0) -> 3 + (4 / 0)
3918        let expr = (lit(1) + lit(2)) + (lit(4) / lit(0));
3919        let expected = (lit(3)) + (lit(4) / lit(0));
3920
3921        assert_eq!(simplify(expr), expected);
3922    }
3923
3924    #[test]
3925    fn test_expression_partial_simplify_2() {
3926        // (1 > 2) and (4 / 0) -> false
3927        let expr = (lit(1).gt(lit(2))).and(lit(4) / lit(0));
3928        let expected = lit(false);
3929
3930        assert_eq!(simplify(expr), expected);
3931    }
3932
3933    #[test]
3934    fn test_simplify_cycles() {
3935        // TRUE
3936        let expr = lit(true);
3937        let expected = lit(true);
3938        let (expr, num_iter) = simplify_with_cycle_count(expr);
3939        assert_eq!(expr, expected);
3940        assert_eq!(num_iter, 1);
3941
3942        // (true != NULL) OR (5 > 10)
3943        let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
3944        let expected = lit_bool_null();
3945        let (expr, num_iter) = simplify_with_cycle_count(expr);
3946        assert_eq!(expr, expected);
3947        assert_eq!(num_iter, 2);
3948
3949        // NOTE: this currently does not simplify
3950        // (((c4 - 10) + 10) *100) / 100
3951        let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
3952        let expected = expr.clone();
3953        let (expr, num_iter) = simplify_with_cycle_count(expr);
3954        assert_eq!(expr, expected);
3955        assert_eq!(num_iter, 1);
3956
3957        // ((c4<1 or c3<2) and c3_non_null<3) and false
3958        let expr = col("c4")
3959            .lt(lit(1))
3960            .or(col("c3").lt(lit(2)))
3961            .and(col("c3_non_null").lt(lit(3)))
3962            .and(lit(false));
3963        let expected = lit(false);
3964        let (expr, num_iter) = simplify_with_cycle_count(expr);
3965        assert_eq!(expr, expected);
3966        assert_eq!(num_iter, 2);
3967    }
3968
3969    fn boolean_test_schema() -> DFSchemaRef {
3970        Schema::new(vec![
3971            Field::new("A", DataType::Boolean, false),
3972            Field::new("B", DataType::Boolean, false),
3973            Field::new("C", DataType::Boolean, false),
3974            Field::new("D", DataType::Boolean, false),
3975        ])
3976        .to_dfschema_ref()
3977        .unwrap()
3978    }
3979
3980    #[test]
3981    fn simplify_common_factor_conjunction_in_disjunction() {
3982        let props = ExecutionProps::new();
3983        let schema = boolean_test_schema();
3984        let simplifier =
3985            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));
3986
3987        let a = || col("A");
3988        let b = || col("B");
3989        let c = || col("C");
3990        let d = || col("D");
3991
3992        // (A AND B) OR (A AND C) -> A AND (B OR C)
3993        let expr = a().and(b()).or(a().and(c()));
3994        let expected = a().and(b().or(c()));
3995
3996        assert_eq!(expected, simplifier.simplify(expr).unwrap());
3997
3998        // (A AND B) OR (A AND C) OR (A AND D) -> A AND (B OR C OR D)
3999        let expr = a().and(b()).or(a().and(c())).or(a().and(d()));
4000        let expected = a().and(b().or(c()).or(d()));
4001        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4002
4003        // A OR (B AND C AND A) -> A
4004        let expr = a().or(b().and(c().and(a())));
4005        let expected = a();
4006        assert_eq!(expected, simplifier.simplify(expr).unwrap());
4007    }
4008
4009    #[test]
4010    fn test_simplify_udaf() {
4011        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());
4012        let aggregate_function_expr =
4013            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4014                udaf.into(),
4015                vec![],
4016                false,
4017                None,
4018                None,
4019                None,
4020            ));
4021
4022        let expected = col("result_column");
4023        assert_eq!(simplify(aggregate_function_expr), expected);
4024
4025        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify());
4026        let aggregate_function_expr =
4027            Expr::AggregateFunction(expr::AggregateFunction::new_udf(
4028                udaf.into(),
4029                vec![],
4030                false,
4031                None,
4032                None,
4033                None,
4034            ));
4035
4036        let expected = aggregate_function_expr.clone();
4037        assert_eq!(simplify(aggregate_function_expr), expected);
4038    }
4039
4040    /// A Mock UDAF which defines `simplify` to be used in tests
4041    /// related to UDAF simplification
4042    #[derive(Debug, Clone)]
4043    struct SimplifyMockUdaf {
4044        simplify: bool,
4045    }
4046
4047    impl SimplifyMockUdaf {
4048        /// make simplify method return new expression
4049        fn new_with_simplify() -> Self {
4050            Self { simplify: true }
4051        }
4052        /// make simplify method return no change
4053        fn new_without_simplify() -> Self {
4054            Self { simplify: false }
4055        }
4056    }
4057
4058    impl AggregateUDFImpl for SimplifyMockUdaf {
4059        fn as_any(&self) -> &dyn std::any::Any {
4060            self
4061        }
4062
4063        fn name(&self) -> &str {
4064            "mock_simplify"
4065        }
4066
4067        fn signature(&self) -> &Signature {
4068            unimplemented!()
4069        }
4070
4071        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4072            unimplemented!("not needed for tests")
4073        }
4074
4075        fn accumulator(
4076            &self,
4077            _acc_args: AccumulatorArgs,
4078        ) -> Result<Box<dyn Accumulator>> {
4079            unimplemented!("not needed for tests")
4080        }
4081
4082        fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
4083            unimplemented!("not needed for testing")
4084        }
4085
4086        fn create_groups_accumulator(
4087            &self,
4088            _args: AccumulatorArgs,
4089        ) -> Result<Box<dyn GroupsAccumulator>> {
4090            unimplemented!("not needed for testing")
4091        }
4092
4093        fn simplify(&self) -> Option<AggregateFunctionSimplification> {
4094            if self.simplify {
4095                Some(Box::new(|_, _| Ok(col("result_column"))))
4096            } else {
4097                None
4098            }
4099        }
4100    }
4101
4102    #[test]
4103    fn test_simplify_udwf() {
4104        let udwf = WindowFunctionDefinition::WindowUDF(
4105            WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(),
4106        );
4107        let window_function_expr =
4108            Expr::WindowFunction(WindowFunction::new(udwf, vec![]));
4109
4110        let expected = col("result_column");
4111        assert_eq!(simplify(window_function_expr), expected);
4112
4113        let udwf = WindowFunctionDefinition::WindowUDF(
4114            WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(),
4115        );
4116        let window_function_expr =
4117            Expr::WindowFunction(WindowFunction::new(udwf, vec![]));
4118
4119        let expected = window_function_expr.clone();
4120        assert_eq!(simplify(window_function_expr), expected);
4121    }
4122
4123    /// A Mock UDWF which defines `simplify` to be used in tests
4124    /// related to UDWF simplification
4125    #[derive(Debug, Clone)]
4126    struct SimplifyMockUdwf {
4127        simplify: bool,
4128    }
4129
4130    impl SimplifyMockUdwf {
4131        /// make simplify method return new expression
4132        fn new_with_simplify() -> Self {
4133            Self { simplify: true }
4134        }
4135        /// make simplify method return no change
4136        fn new_without_simplify() -> Self {
4137            Self { simplify: false }
4138        }
4139    }
4140
4141    impl WindowUDFImpl for SimplifyMockUdwf {
4142        fn as_any(&self) -> &dyn std::any::Any {
4143            self
4144        }
4145
4146        fn name(&self) -> &str {
4147            "mock_simplify"
4148        }
4149
4150        fn signature(&self) -> &Signature {
4151            unimplemented!()
4152        }
4153
4154        fn simplify(&self) -> Option<WindowFunctionSimplification> {
4155            if self.simplify {
4156                Some(Box::new(|_, _| Ok(col("result_column"))))
4157            } else {
4158                None
4159            }
4160        }
4161
4162        fn partition_evaluator(
4163            &self,
4164            _partition_evaluator_args: PartitionEvaluatorArgs,
4165        ) -> Result<Box<dyn PartitionEvaluator>> {
4166            unimplemented!("not needed for tests")
4167        }
4168
4169        fn field(&self, _field_args: WindowUDFFieldArgs) -> Result<Field> {
4170            unimplemented!("not needed for tests")
4171        }
4172    }
    /// Test-only scalar UDF. `VolatileUdf::new` gives it an exact,
    /// zero-argument signature declared with `Volatility::Volatile`.
    #[derive(Debug)]
    struct VolatileUdf {
        /// Signature handed out by `ScalarUDFImpl::signature`.
        signature: Signature,
    }
4177
4178    impl VolatileUdf {
4179        pub fn new() -> Self {
4180            Self {
4181                signature: Signature::exact(vec![], Volatility::Volatile),
4182            }
4183        }
4184    }
4185    impl ScalarUDFImpl for VolatileUdf {
4186        fn as_any(&self) -> &dyn std::any::Any {
4187            self
4188        }
4189
4190        fn name(&self) -> &str {
4191            "VolatileUdf"
4192        }
4193
4194        fn signature(&self) -> &Signature {
4195            &self.signature
4196        }
4197
4198        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4199            Ok(DataType::Int16)
4200        }
4201    }
4202
4203    #[test]
4204    fn test_optimize_volatile_conditions() {
4205        let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new()));
4206        let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
4207        {
4208            let expr = rand
4209                .clone()
4210                .eq(lit(0))
4211                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4212
4213            assert_eq!(simplify(expr.clone()), expr);
4214        }
4215
4216        {
4217            let expr = col("column1")
4218                .eq(lit(2))
4219                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4220
4221            assert_eq!(simplify(expr), col("column1").eq(lit(2)));
4222        }
4223
4224        {
4225            let expr = (col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))).or(col(
4226                "column1",
4227            )
4228            .eq(lit(2))
4229            .and(rand.clone().eq(lit(0))));
4230
4231            assert_eq!(
4232                simplify(expr),
4233                col("column1")
4234                    .eq(lit(2))
4235                    .and((rand.clone().eq(lit(0))).or(rand.clone().eq(lit(0))))
4236            );
4237        }
4238    }
4239
4240    fn if_not_null(expr: Expr, then: bool) -> Expr {
4241        Expr::Case(Case {
4242            expr: Some(expr.is_not_null().into()),
4243            when_then_expr: vec![(lit(true).into(), lit(then).into())],
4244            else_expr: None,
4245        })
4246    }
4247}