Skip to main content

datafusion_optimizer/simplify_expressions/
simplify_exprs.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! Simplify expressions optimizer rule and implementation
19
20use std::sync::Arc;
21
22use datafusion_common::tree_node::{Transformed, TreeNode};
23use datafusion_common::{Column, DFSchema, DFSchemaRef, DataFusionError, Result};
24use datafusion_expr::Expr;
25use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection};
26use datafusion_expr::simplify::SimplifyContext;
27use datafusion_expr::utils::{
28    columnize_expr, find_aggregate_exprs, grouping_set_to_exprlist, merge_schema,
29};
30
31use super::ExprSimplifier;
32use crate::optimizer::ApplyOrder;
33use crate::simplify_expressions::linear_aggregates::rewrite_multiple_linear_aggregates;
34use crate::utils::NamePreserver;
35use crate::{OptimizerConfig, OptimizerRule};
36
/// Optimizer pass that simplifies [`LogicalPlan`]s by rewriting
/// [`Expr`]s, evaluating constants and applying algebraic
/// simplifications.
///
/// # Introduction
/// It uses boolean algebra laws to simplify or reduce the number of terms in expressions.
///
/// # Example:
/// `Filter: b > 2 AND b > 2`
/// is optimized to
/// `Filter: b > 2`
///
/// [`Expr`]: datafusion_expr::Expr
#[derive(Default, Debug)]
pub struct SimplifyExpressions {}
52
53impl OptimizerRule for SimplifyExpressions {
54    fn name(&self) -> &str {
55        "simplify_expressions"
56    }
57
58    fn apply_order(&self) -> Option<ApplyOrder> {
59        Some(ApplyOrder::BottomUp)
60    }
61
62    fn supports_rewrite(&self) -> bool {
63        true
64    }
65
66    fn rewrite(
67        &self,
68        plan: LogicalPlan,
69        config: &dyn OptimizerConfig,
70    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
71        Self::optimize_internal(plan, config)
72    }
73}
74
75impl SimplifyExpressions {
76    fn optimize_internal(
77        plan: LogicalPlan,
78        config: &dyn OptimizerConfig,
79    ) -> Result<Transformed<LogicalPlan>> {
80        let schema = if !plan.inputs().is_empty() {
81            DFSchemaRef::new(merge_schema(&plan.inputs()))
82        } else if let LogicalPlan::TableScan(scan) = &plan {
83            // When predicates are pushed into a table scan, there is no input
84            // schema to resolve predicates against, so it must be handled specially
85            //
86            // Note that this is not `plan.schema()` which is the *output*
87            // schema, and reflects any pushed down projection. The output schema
88            // will not contain columns that *only* appear in pushed down predicates
89            // (and no where else) in the plan.
90            //
91            // Thus, use the full schema of the inner provider without any
92            // projection applied for simplification
93            Arc::new(DFSchema::try_from_qualified_schema(
94                scan.table_name.clone(),
95                &scan.source.schema(),
96            )?)
97        } else {
98            Arc::new(DFSchema::empty())
99        };
100
101        let info = SimplifyContext::builder()
102            .with_schema(schema)
103            .with_config_options(config.options())
104            .with_query_execution_start_time(config.query_execution_start_time())
105            .build();
106
107        // Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer)
108        // Just need to rewrite our own expressions
109
110        let simplifier = ExprSimplifier::new(info);
111
112        // The left and right expressions in a Join on clause are not
113        // commutative, for reasons that are not entirely clear. Thus, do not
114        // reorder expressions in Join while simplifying.
115        //
116        // This is likely related to the fact that order of the columns must
117        // match the order of the children. see
118        // https://github.com/apache/datafusion/pull/8780 for more details
119        let simplifier = if let LogicalPlan::Join(_) = plan {
120            simplifier.with_canonicalize(false)
121        } else {
122            simplifier
123        };
124
125        // Preserve expression names to avoid changing the schema of the plan.
126        let name_preserver = NamePreserver::new(&plan);
127        let mut rewrite_expr = |expr: Expr| {
128            let name = name_preserver.save(&expr);
129            let expr = simplifier.simplify_with_cycle_count_transformed(expr)?.0;
130            Ok(Transformed::new_transformed(
131                name.restore(expr.data),
132                expr.transformed,
133            ))
134        };
135
136        plan.map_expressions(|expr| {
137            // Preserve the aliasing of grouping sets.
138            if let Expr::GroupingSet(_) = &expr {
139                expr.map_children(&mut rewrite_expr)
140            } else {
141                rewrite_expr(expr)
142            }
143        })?
144        .transform_data(rewrite_aggregate_non_aggregate_aggr_expr)
145    }
146}
147
148impl SimplifyExpressions {
149    #[expect(missing_docs)]
150    pub fn new() -> Self {
151        Self {}
152    }
153}
154
155/// Ensures that `LogicalPlan::Aggregate` is well formed after rewrites
156/// by potentially introducing an extra `Projection`.
157///
158/// Also applies the [`rewrite_multiple_linear_aggregates`] special case
159///
160/// # Rationale:
161///
162/// [`LogicalPlan::Aggregate`] requires agg expressions to be (possibly aliased)
163/// [`Expr::AggregateFunction`]. Some UDAF simplifiers may return other [`Expr`]
164/// variants.
165///
166/// # Operation
167///
168/// Rewrites things like this (note that `exp1` is not an aggregate):
169/// * `Aggregate(group_expr, aggr_expr=[exp1 + agg(exp2)])`
170///
171/// into:
172/// * `Projection(exp1 + _X)`
173/// * `  Aggregate(group_expr, aggr_expr=[agg(exp2) AS _X])`
174fn rewrite_aggregate_non_aggregate_aggr_expr(
175    plan: LogicalPlan,
176) -> Result<Transformed<LogicalPlan>> {
177    let LogicalPlan::Aggregate(Aggregate {
178        input,
179        group_expr,
180        mut aggr_expr,
181        schema,
182        ..
183    }) = plan
184    else {
185        return Ok(Transformed::no(plan));
186    };
187
188    let rewrote_aggs = rewrite_multiple_linear_aggregates(&mut aggr_expr)?;
189
190    // Ensure that all Aggregate arguments are AggregateExpr
191    if aggr_expr.iter().all(is_top_level_aggregate_expr) {
192        let new_plan = LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
193            input, group_expr, aggr_expr, schema,
194        )?);
195        return if !rewrote_aggs {
196            Ok(Transformed::no(new_plan))
197        } else {
198            Ok(Transformed::yes(new_plan))
199        };
200    }
201
202    // Otherwise we need to add a Projection above Aggregate to calculate
203    // the final output expressions.
204
205    let inner_aggr_expr = find_aggregate_exprs(aggr_expr.iter());
206    let inner_aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
207        Arc::clone(&input),
208        group_expr.clone(),
209        inner_aggr_expr,
210    )?);
211    let inner_aggregate = Arc::new(inner_aggregate);
212
213    let mut projection_exprs = aggregate_output_exprs(&group_expr)?;
214    projection_exprs.extend(aggr_expr);
215    let projection_exprs = projection_exprs
216        .into_iter()
217        .map(|expr| columnize_expr(expr, inner_aggregate.as_ref()))
218        .collect::<Result<Vec<_>>>()?;
219
220    Ok(Transformed::yes(LogicalPlan::Projection(
221        Projection::try_new(projection_exprs, inner_aggregate)?,
222    )))
223}
224
225fn is_top_level_aggregate_expr(expr: &Expr) -> bool {
226    matches!(
227        expr.clone().unalias_nested().data,
228        Expr::AggregateFunction(_)
229    )
230}
231
232fn aggregate_output_exprs(group_expr: &[Expr]) -> Result<Vec<Expr>> {
233    let mut output_exprs = grouping_set_to_exprlist(group_expr)?
234        .into_iter()
235        .cloned()
236        .collect::<Vec<_>>();
237
238    if matches!(group_expr, [Expr::GroupingSet(_)]) {
239        output_exprs.push(Expr::Column(Column::from_name(
240            Aggregate::INTERNAL_GROUPING_ID,
241        )));
242    }
243
244    Ok(output_exprs)
245}
246
247#[cfg(test)]
248mod tests {
249    use std::ops::Not;
250
251    use arrow::datatypes::{DataType, Field, Schema};
252    use chrono::{DateTime, Utc};
253
254    use datafusion_common::ScalarValue;
255    use datafusion_expr::logical_plan::builder::table_scan_with_filters;
256    use datafusion_expr::logical_plan::table_scan;
257    use datafusion_expr::*;
258    use datafusion_functions_aggregate::expr_fn::{max, min, sum};
259
260    use crate::OptimizerContext;
261    use crate::assert_optimized_plan_eq_snapshot;
262    use crate::test::{assert_fields_eq, test_table_scan_with_name};
263
264    use super::*;
265
266    fn test_table_scan() -> LogicalPlan {
267        let schema = Schema::new(vec![
268            Field::new("a", DataType::Boolean, false),
269            Field::new("b", DataType::Boolean, false),
270            Field::new("c", DataType::Boolean, false),
271            Field::new("d", DataType::UInt32, false),
272            Field::new("e", DataType::UInt32, true),
273        ]);
274        table_scan(Some("test"), &schema, None)
275            .expect("creating scan")
276            .build()
277            .expect("building plan")
278    }
279
280    macro_rules! assert_optimized_plan_equal {
281        (
282            $plan:expr,
283            @ $expected:literal $(,)?
284        ) => {{
285            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(SimplifyExpressions::new())];
286            let optimizer_ctx = OptimizerContext::new();
287            assert_optimized_plan_eq_snapshot!(
288                optimizer_ctx,
289                rules,
290                $plan,
291                @ $expected,
292            )
293        }};
294    }
295
296    #[test]
297    fn test_simplify_table_full_filter_in_scan() -> Result<()> {
298        let fields = vec![
299            Field::new("a", DataType::UInt32, false),
300            Field::new("b", DataType::UInt32, false),
301            Field::new("c", DataType::UInt32, false),
302        ];
303
304        let schema = Schema::new(fields);
305
306        let table_scan = table_scan_with_filters(
307            Some("test"),
308            &schema,
309            Some(vec![0]),
310            vec![col("b").is_not_null()],
311        )?
312        .build()?;
313        assert_eq!(1, table_scan.schema().fields().len());
314        assert_fields_eq(&table_scan, vec!["a"]);
315
316        assert_optimized_plan_equal!(
317            table_scan,
318            @ "TableScan: test projection=[a], full_filters=[Boolean(true)]"
319        )
320    }
321
322    #[test]
323    fn test_simplify_filter_pushdown() -> Result<()> {
324        let table_scan = test_table_scan();
325        let plan = LogicalPlanBuilder::from(table_scan)
326            .project(vec![col("a")])?
327            .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))?
328            .build()?;
329
330        assert_optimized_plan_equal!(
331            plan,
332            @ r"
333        Filter: test.b > Int32(1)
334          Projection: test.a
335            TableScan: test
336        "
337        )
338    }
339
340    #[test]
341    fn test_simplify_optimized_plan() -> Result<()> {
342        let table_scan = test_table_scan();
343        let plan = LogicalPlanBuilder::from(table_scan)
344            .project(vec![col("a")])?
345            .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))?
346            .build()?;
347
348        assert_optimized_plan_equal!(
349            plan,
350            @ r"
351        Filter: test.b > Int32(1)
352          Projection: test.a
353            TableScan: test
354        "
355        )
356    }
357
358    #[test]
359    fn test_simplify_udaf_to_non_aggregate_expr() -> Result<()> {
360        let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
361        let table_scan = table_scan(Some("test"), &schema, None)?
362            .build()
363            .expect("building scan");
364
365        let plan = LogicalPlanBuilder::from(table_scan)
366            .aggregate(Vec::<Expr>::new(), vec![sum(col("a") + lit(2i64))])?
367            .build()?;
368
369        assert_optimized_plan_equal!(
370            plan,
371            @r"
372        Aggregate: groupBy=[[]], aggr=[[sum(test.a + Int64(2))]]
373          TableScan: test
374        "
375        )?;
376        Ok(())
377    }
378
379    #[test]
380    fn test_simplify_common_sum_arg() -> Result<()> {
381        let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
382        let table_scan = table_scan(Some("test"), &schema, None)?
383            .build()
384            .expect("building scan");
385
386        let plan = LogicalPlanBuilder::from(table_scan)
387            .aggregate(
388                Vec::<Expr>::new(),
389                vec![sum(col("a") + lit(2i64)), sum(col("a") + lit(3i64))],
390            )?
391            .build()?;
392
393        assert_optimized_plan_equal!(
394            plan,
395            @r"
396        Projection: sum(test.a) + Int64(2) * CAST(count(test.a) AS Int64) AS sum(test.a + Int64(2)), sum(test.a) + Int64(3) * CAST(count(test.a) AS Int64) AS sum(test.a + Int64(3))
397          Aggregate: groupBy=[[]], aggr=[[sum(test.a), count(test.a)]]
398            TableScan: test
399        "
400        )?;
401        Ok(())
402    }
403
404    #[test]
405    fn test_simplify_optimized_plan_with_or() -> Result<()> {
406        let table_scan = test_table_scan();
407        let plan = LogicalPlanBuilder::from(table_scan)
408            .project(vec![col("a")])?
409            .filter(or(col("b").gt(lit(1)), col("b").gt(lit(1))))?
410            .build()?;
411
412        assert_optimized_plan_equal!(
413            plan,
414            @ r"
415        Filter: test.b > Int32(1)
416          Projection: test.a
417            TableScan: test
418        "
419        )
420    }
421
422    #[test]
423    fn test_simplify_optimized_plan_with_composed_and() -> Result<()> {
424        let table_scan = test_table_scan();
425        // ((c > 5) AND (d < 6)) AND (c > 5) --> (c > 5) AND (d < 6)
426        let plan = LogicalPlanBuilder::from(table_scan)
427            .project(vec![col("a"), col("b")])?
428            .filter(and(
429                and(col("a").gt(lit(5)), col("b").lt(lit(6))),
430                col("a").gt(lit(5)),
431            ))?
432            .build()?;
433
434        assert_optimized_plan_equal!(
435            plan,
436            @ r"
437        Filter: test.a > Int32(5) AND test.b < Int32(6)
438          Projection: test.a, test.b
439            TableScan: test
440        "
441        )
442    }
443
444    #[test]
445    fn test_simplify_optimized_plan_eq_expr() -> Result<()> {
446        let table_scan = test_table_scan();
447        let plan = LogicalPlanBuilder::from(table_scan)
448            .filter(col("b").eq(lit(true)))?
449            .filter(col("c").eq(lit(false)))?
450            .project(vec![col("a")])?
451            .build()?;
452
453        assert_optimized_plan_equal!(
454            plan,
455            @ r"
456        Projection: test.a
457          Filter: NOT test.c
458            Filter: test.b
459              TableScan: test
460        "
461        )
462    }
463
464    #[test]
465    fn test_simplify_optimized_plan_not_eq_expr() -> Result<()> {
466        let table_scan = test_table_scan();
467        let plan = LogicalPlanBuilder::from(table_scan)
468            .filter(col("b").not_eq(lit(true)))?
469            .filter(col("c").not_eq(lit(false)))?
470            .limit(0, Some(1))?
471            .project(vec![col("a")])?
472            .build()?;
473
474        assert_optimized_plan_equal!(
475            plan,
476            @ r"
477        Projection: test.a
478          Limit: skip=0, fetch=1
479            Filter: test.c
480              Filter: NOT test.b
481                TableScan: test
482        "
483        )
484    }
485
486    #[test]
487    fn test_simplify_optimized_plan_and_expr() -> Result<()> {
488        let table_scan = test_table_scan();
489        let plan = LogicalPlanBuilder::from(table_scan)
490            .filter(col("b").not_eq(lit(true)).and(col("c").eq(lit(true))))?
491            .project(vec![col("a")])?
492            .build()?;
493
494        assert_optimized_plan_equal!(
495            plan,
496            @ r"
497        Projection: test.a
498          Filter: NOT test.b AND test.c
499            TableScan: test
500        "
501        )
502    }
503
504    #[test]
505    fn test_simplify_optimized_plan_or_expr() -> Result<()> {
506        let table_scan = test_table_scan();
507        let plan = LogicalPlanBuilder::from(table_scan)
508            .filter(col("b").not_eq(lit(true)).or(col("c").eq(lit(false))))?
509            .project(vec![col("a")])?
510            .build()?;
511
512        assert_optimized_plan_equal!(
513            plan,
514            @ r"
515        Projection: test.a
516          Filter: NOT test.b OR NOT test.c
517            TableScan: test
518        "
519        )
520    }
521
522    #[test]
523    fn test_simplify_optimized_plan_not_expr() -> Result<()> {
524        let table_scan = test_table_scan();
525        let plan = LogicalPlanBuilder::from(table_scan)
526            .filter(col("b").eq(lit(false)).not())?
527            .project(vec![col("a")])?
528            .build()?;
529
530        assert_optimized_plan_equal!(
531            plan,
532            @ r"
533        Projection: test.a
534          Filter: test.b
535            TableScan: test
536        "
537        )
538    }
539
540    #[test]
541    fn test_simplify_optimized_plan_support_projection() -> Result<()> {
542        let table_scan = test_table_scan();
543        let plan = LogicalPlanBuilder::from(table_scan)
544            .project(vec![col("a"), col("d"), col("b").eq(lit(false))])?
545            .build()?;
546
547        assert_optimized_plan_equal!(
548            plan,
549            @ r"
550        Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false)
551          TableScan: test
552        "
553        )
554    }
555
556    #[test]
557    fn test_simplify_optimized_plan_support_aggregate() -> Result<()> {
558        let table_scan = test_table_scan();
559        let plan = LogicalPlanBuilder::from(table_scan)
560            .project(vec![col("a"), col("c"), col("b")])?
561            .aggregate(
562                vec![col("a"), col("c")],
563                vec![max(col("b").eq(lit(true))), min(col("b"))],
564            )?
565            .build()?;
566
567        assert_optimized_plan_equal!(
568            plan,
569            @ r"
570        Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b) AS max(test.b = Boolean(true)), min(test.b)]]
571          Projection: test.a, test.c, test.b
572            TableScan: test
573        "
574        )
575    }
576
577    #[test]
578    fn test_simplify_optimized_plan_support_values() -> Result<()> {
579        let expr1 = Expr::BinaryExpr(BinaryExpr::new(
580            Box::new(lit(1)),
581            Operator::Plus,
582            Box::new(lit(2)),
583        ));
584        let expr2 = Expr::BinaryExpr(BinaryExpr::new(
585            Box::new(lit(2)),
586            Operator::Minus,
587            Box::new(lit(1)),
588        ));
589        let values = vec![vec![expr1, expr2]];
590        let plan = LogicalPlanBuilder::values(values)?.build()?;
591
592        assert_optimized_plan_equal!(
593            plan,
594            @ "Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))"
595        )
596    }
597
598    fn get_optimized_plan_formatted(
599        plan: LogicalPlan,
600        date_time: &DateTime<Utc>,
601    ) -> String {
602        let config = OptimizerContext::new().with_query_execution_start_time(*date_time);
603        let rule = SimplifyExpressions::new();
604
605        let optimized_plan = rule.rewrite(plan, &config).unwrap().data;
606        format!("{optimized_plan}")
607    }
608
609    #[test]
610    fn cast_expr() -> Result<()> {
611        let table_scan = test_table_scan();
612        let proj = vec![Expr::Cast(Cast::new(Box::new(lit("0")), DataType::Int32))];
613        let plan = LogicalPlanBuilder::from(table_scan)
614            .project(proj)?
615            .build()?;
616
617        let expected = "Projection: Int32(0) AS Utf8(\"0\")\
618            \n  TableScan: test";
619        let actual = get_optimized_plan_formatted(plan, &Utc::now());
620        assert_eq!(expected, actual);
621        Ok(())
622    }
623
624    #[test]
625    fn simplify_and_eval() -> Result<()> {
626        // demonstrate a case where the evaluation needs to run prior
627        // to the simplifier for it to work
628        let table_scan = test_table_scan();
629        let time = Utc::now();
630        // (true or false) != col --> !col
631        let proj = vec![lit(true).or(lit(false)).not_eq(col("a"))];
632        let plan = LogicalPlanBuilder::from(table_scan)
633            .project(proj)?
634            .build()?;
635
636        let actual = get_optimized_plan_formatted(plan, &time);
637        let expected = "Projection: NOT test.a AS Boolean(true) OR Boolean(false) != test.a\
638                        \n  TableScan: test";
639
640        assert_eq!(expected, actual);
641        Ok(())
642    }
643
644    #[test]
645    fn simplify_not_binary() -> Result<()> {
646        let table_scan = test_table_scan();
647
648        let plan = LogicalPlanBuilder::from(table_scan)
649            .filter(col("d").gt(lit(10)).not())?
650            .build()?;
651
652        assert_optimized_plan_equal!(
653            plan,
654            @ r"
655        Filter: test.d <= Int32(10)
656          TableScan: test
657        "
658        )
659    }
660
661    #[test]
662    fn simplify_not_bool_and() -> Result<()> {
663        let table_scan = test_table_scan();
664
665        let plan = LogicalPlanBuilder::from(table_scan)
666            .filter(col("d").gt(lit(10)).and(col("d").lt(lit(100))).not())?
667            .build()?;
668
669        assert_optimized_plan_equal!(
670            plan,
671            @ r"
672        Filter: test.d <= Int32(10) OR test.d >= Int32(100)
673          TableScan: test
674        "
675        )
676    }
677
678    #[test]
679    fn simplify_not_bool_or() -> Result<()> {
680        let table_scan = test_table_scan();
681
682        let plan = LogicalPlanBuilder::from(table_scan)
683            .filter(col("d").gt(lit(10)).or(col("d").lt(lit(100))).not())?
684            .build()?;
685
686        assert_optimized_plan_equal!(
687            plan,
688            @ r"
689        Filter: test.d <= Int32(10) AND test.d >= Int32(100)
690          TableScan: test
691        "
692        )
693    }
694
695    #[test]
696    fn simplify_not_not() -> Result<()> {
697        let table_scan = test_table_scan();
698
699        let plan = LogicalPlanBuilder::from(table_scan)
700            .filter(col("d").gt(lit(10)).not().not())?
701            .build()?;
702
703        assert_optimized_plan_equal!(
704            plan,
705            @ r"
706        Filter: test.d > Int32(10)
707          TableScan: test
708        "
709        )
710    }
711
712    #[test]
713    fn simplify_not_null() -> Result<()> {
714        let table_scan = test_table_scan();
715
716        let plan = LogicalPlanBuilder::from(table_scan)
717            .filter(col("e").is_null().not())?
718            .build()?;
719
720        assert_optimized_plan_equal!(
721            plan,
722            @ r"
723        Filter: test.e IS NOT NULL
724          TableScan: test
725        "
726        )
727    }
728
729    #[test]
730    fn simplify_not_not_null() -> Result<()> {
731        let table_scan = test_table_scan();
732
733        let plan = LogicalPlanBuilder::from(table_scan)
734            .filter(col("e").is_not_null().not())?
735            .build()?;
736
737        assert_optimized_plan_equal!(
738            plan,
739            @ r"
740        Filter: test.e IS NULL
741          TableScan: test
742        "
743        )
744    }
745
746    #[test]
747    fn simplify_not_in() -> Result<()> {
748        let table_scan = test_table_scan();
749
750        let plan = LogicalPlanBuilder::from(table_scan)
751            .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], false).not())?
752            .build()?;
753
754        assert_optimized_plan_equal!(
755            plan,
756            @ r"
757        Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3)
758          TableScan: test
759        "
760        )
761    }
762
763    #[test]
764    fn simplify_not_not_in() -> Result<()> {
765        let table_scan = test_table_scan();
766
767        let plan = LogicalPlanBuilder::from(table_scan)
768            .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], true).not())?
769            .build()?;
770
771        assert_optimized_plan_equal!(
772            plan,
773            @ r"
774        Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3)
775          TableScan: test
776        "
777        )
778    }
779
780    #[test]
781    fn simplify_not_between() -> Result<()> {
782        let table_scan = test_table_scan();
783        let qual = col("d").between(lit(1), lit(10));
784
785        let plan = LogicalPlanBuilder::from(table_scan)
786            .filter(qual.not())?
787            .build()?;
788
789        assert_optimized_plan_equal!(
790            plan,
791            @ r"
792        Filter: test.d < Int32(1) OR test.d > Int32(10)
793          TableScan: test
794        "
795        )
796    }
797
798    #[test]
799    fn simplify_not_not_between() -> Result<()> {
800        let table_scan = test_table_scan();
801        let qual = col("d").not_between(lit(1), lit(10));
802
803        let plan = LogicalPlanBuilder::from(table_scan)
804            .filter(qual.not())?
805            .build()?;
806
807        assert_optimized_plan_equal!(
808            plan,
809            @ r"
810        Filter: test.d >= Int32(1) AND test.d <= Int32(10)
811          TableScan: test
812        "
813        )
814    }
815
816    #[test]
817    fn simplify_not_like() -> Result<()> {
818        let schema = Schema::new(vec![
819            Field::new("a", DataType::Utf8, false),
820            Field::new("b", DataType::Utf8, false),
821        ]);
822        let table_scan = table_scan(Some("test"), &schema, None)
823            .expect("creating scan")
824            .build()
825            .expect("building plan");
826
827        let plan = LogicalPlanBuilder::from(table_scan)
828            .filter(col("a").like(col("b")).not())?
829            .build()?;
830
831        assert_optimized_plan_equal!(
832            plan,
833            @ r"
834        Filter: test.a NOT LIKE test.b
835          TableScan: test
836        "
837        )
838    }
839
840    #[test]
841    fn simplify_not_not_like() -> Result<()> {
842        let schema = Schema::new(vec![
843            Field::new("a", DataType::Utf8, false),
844            Field::new("b", DataType::Utf8, false),
845        ]);
846        let table_scan = table_scan(Some("test"), &schema, None)
847            .expect("creating scan")
848            .build()
849            .expect("building plan");
850
851        let plan = LogicalPlanBuilder::from(table_scan)
852            .filter(col("a").not_like(col("b")).not())?
853            .build()?;
854
855        assert_optimized_plan_equal!(
856            plan,
857            @ r"
858        Filter: test.a LIKE test.b
859          TableScan: test
860        "
861        )
862    }
863
864    #[test]
865    fn simplify_not_ilike() -> Result<()> {
866        let schema = Schema::new(vec![
867            Field::new("a", DataType::Utf8, false),
868            Field::new("b", DataType::Utf8, false),
869        ]);
870        let table_scan = table_scan(Some("test"), &schema, None)
871            .expect("creating scan")
872            .build()
873            .expect("building plan");
874
875        let plan = LogicalPlanBuilder::from(table_scan)
876            .filter(col("a").ilike(col("b")).not())?
877            .build()?;
878
879        assert_optimized_plan_equal!(
880            plan,
881            @ r"
882        Filter: test.a NOT ILIKE test.b
883          TableScan: test
884        "
885        )
886    }
887
888    #[test]
889    fn simplify_not_distinct_from() -> Result<()> {
890        let table_scan = test_table_scan();
891
892        let plan = LogicalPlanBuilder::from(table_scan)
893            .filter(binary_expr(col("d"), Operator::IsDistinctFrom, lit(10)).not())?
894            .build()?;
895
896        assert_optimized_plan_equal!(
897            plan,
898            @ r"
899        Filter: test.d IS NOT DISTINCT FROM Int32(10)
900          TableScan: test
901        "
902        )
903    }
904
905    #[test]
906    fn simplify_not_not_distinct_from() -> Result<()> {
907        let table_scan = test_table_scan();
908
909        let plan = LogicalPlanBuilder::from(table_scan)
910            .filter(binary_expr(col("d"), Operator::IsNotDistinctFrom, lit(10)).not())?
911            .build()?;
912
913        assert_optimized_plan_equal!(
914            plan,
915            @ r"
916        Filter: test.d IS DISTINCT FROM Int32(10)
917          TableScan: test
918        "
919        )
920    }
921
922    #[test]
923    fn simplify_equijoin_predicate() -> Result<()> {
924        let t1 = test_table_scan_with_name("t1")?;
925        let t2 = test_table_scan_with_name("t2")?;
926
927        let left_key = col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, t1.schema())?;
928        let right_key =
929            col("t2.a") + lit(2i64).cast_to(&DataType::UInt32, t2.schema())?;
930        let plan = LogicalPlanBuilder::from(t1)
931            .join_with_expr_keys(
932                t2,
933                JoinType::Inner,
934                (vec![left_key], vec![right_key]),
935                None,
936            )?
937            .build()?;
938
939        // before simplify: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32)
940        // after simplify: t1.a + UInt32(1) = t2.a + UInt32(2) AS t1.a + Int64(1) = t2.a + Int64(2)
941        assert_optimized_plan_equal!(
942            plan,
943            @ r"
944        Inner Join: t1.a + UInt32(1) = t2.a + UInt32(2)
945          TableScan: t1
946          TableScan: t2
947        "
948        )
949    }
950
951    #[test]
952    fn simplify_is_not_null() -> Result<()> {
953        let table_scan = test_table_scan();
954
955        let plan = LogicalPlanBuilder::from(table_scan)
956            .filter(col("d").is_not_null())?
957            .build()?;
958
959        assert_optimized_plan_equal!(
960            plan,
961            @ r"
962        Filter: Boolean(true)
963          TableScan: test
964        "
965        )
966    }
967
968    #[test]
969    fn simplify_is_null() -> Result<()> {
970        let table_scan = test_table_scan();
971
972        let plan = LogicalPlanBuilder::from(table_scan)
973            .filter(col("d").is_null())?
974            .build()?;
975
976        assert_optimized_plan_equal!(
977            plan,
978            @ r"
979        Filter: Boolean(false)
980          TableScan: test
981        "
982        )
983    }
984
985    #[test]
986    fn simplify_grouping_sets() -> Result<()> {
987        let table_scan = test_table_scan();
988        let plan = LogicalPlanBuilder::from(table_scan)
989            .aggregate(
990                [grouping_set(vec![
991                    vec![(lit(42).alias("prev") + lit(1)).alias("age"), col("a")],
992                    vec![col("a").or(col("b")).and(lit(1).lt(lit(0))).alias("cond")],
993                    vec![col("d").alias("e"), (lit(1) + lit(2))],
994                ])],
995                [] as [Expr; 0],
996            )?
997            .build()?;
998
999        assert_optimized_plan_equal!(
1000            plan,
1001            @ r"
1002        Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]]
1003          TableScan: test
1004        "
1005        )
1006    }
1007
1008    #[test]
1009    fn test_simplify_regex_special_cases() -> Result<()> {
1010        let schema = Schema::new(vec![
1011            Field::new("a", DataType::Utf8, true),
1012            Field::new("b", DataType::Utf8, false),
1013        ]);
1014        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1015
1016        // Test `~ ".*"` transforms to true for any non-NULL string
1017        let plan = LogicalPlanBuilder::from(table_scan.clone())
1018            .filter(binary_expr(col("a"), Operator::RegexMatch, lit(".*")))?
1019            .build()?;
1020
1021        assert_optimized_plan_equal!(
1022            plan,
1023            @ r"
1024        Filter: test.a IS NOT NULL
1025          TableScan: test
1026        "
1027        )?;
1028
1029        // Test `!~ ".*"` preserves NULL semantics while remaining false for non-NULL strings
1030        let plan = LogicalPlanBuilder::from(table_scan.clone())
1031            .filter(binary_expr(col("a"), Operator::RegexNotMatch, lit(".*")))?
1032            .build()?;
1033
1034        assert_optimized_plan_equal!(
1035            plan,
1036            @ r"
1037        Filter: test.a IS NULL AND Boolean(NULL)
1038          TableScan: test
1039        "
1040        )?;
1041
1042        // Test case-insensitive versions
1043
1044        // Test `~* ".*"` transforms to true for any non-NULL string
1045        let plan = LogicalPlanBuilder::from(table_scan.clone())
1046            .filter(binary_expr(col("b"), Operator::RegexIMatch, lit(".*")))?
1047            .build()?;
1048
1049        assert_optimized_plan_equal!(
1050            plan,
1051            @ r"
1052        Filter: Boolean(true)
1053          TableScan: test
1054        "
1055        )?;
1056
1057        // Test NULL `!~ ".*"` transforms to Boolean(NULL)
1058        let plan = LogicalPlanBuilder::from(table_scan.clone())
1059            .filter(binary_expr(
1060                lit(ScalarValue::Utf8(None)),
1061                Operator::RegexNotMatch,
1062                lit(".*"),
1063            ))?
1064            .build()?;
1065
1066        assert_optimized_plan_equal!(
1067            plan,
1068            @ r"
1069        Filter: Boolean(NULL)
1070          TableScan: test
1071        "
1072        )?;
1073
1074        // Test `!~* ".*"` preserves NULL semantics while remaining false for non-NULL strings
1075        let plan = LogicalPlanBuilder::from(table_scan.clone())
1076            .filter(binary_expr(col("a"), Operator::RegexNotIMatch, lit(".*")))?
1077            .build()?;
1078
1079        assert_optimized_plan_equal!(
1080            plan,
1081            @ r"
1082        Filter: test.a IS NULL AND Boolean(NULL)
1083          TableScan: test
1084        "
1085        )?;
1086
1087        // Test NULL `!~* ".*"` transforms to Boolean(NULL)
1088        let plan = LogicalPlanBuilder::from(table_scan.clone())
1089            .filter(binary_expr(
1090                lit(ScalarValue::Utf8(None)),
1091                Operator::RegexNotIMatch,
1092                lit(".*"),
1093            ))?
1094            .build()?;
1095
1096        assert_optimized_plan_equal!(
1097            plan,
1098            @ r"
1099        Filter: Boolean(NULL)
1100          TableScan: test
1101        "
1102        )
1103    }
1104
1105    #[test]
1106    fn simplify_not_in_list() -> Result<()> {
1107        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
1108        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1109
1110        let plan = LogicalPlanBuilder::from(table_scan)
1111            .filter(col("a").in_list(vec![lit("a"), lit("b")], false).not())?
1112            .build()?;
1113
1114        assert_optimized_plan_equal!(
1115            plan,
1116            @ r#"
1117        Filter: test.a != Utf8("a") AND test.a != Utf8("b")
1118          TableScan: test
1119        "#
1120        )
1121    }
1122
1123    #[test]
1124    fn simplify_not_not_in_list() -> Result<()> {
1125        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
1126        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1127
1128        let plan = LogicalPlanBuilder::from(table_scan)
1129            .filter(
1130                col("a")
1131                    .in_list(vec![lit("a"), lit("b")], false)
1132                    .not()
1133                    .not(),
1134            )?
1135            .build()?;
1136
1137        assert_optimized_plan_equal!(
1138            plan,
1139            @ r#"
1140        Filter: test.a = Utf8("a") OR test.a = Utf8("b")
1141          TableScan: test
1142        "#
1143        )
1144    }
1145
1146    #[test]
1147    fn simplify_not_exists() -> Result<()> {
1148        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
1149        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1150        let table_scan2 =
1151            datafusion_expr::table_scan(Some("test2"), &schema, None)?.build()?;
1152
1153        let plan = LogicalPlanBuilder::from(table_scan)
1154            .filter(
1155                exists(Arc::new(LogicalPlanBuilder::from(table_scan2).build()?)).not(),
1156            )?
1157            .build()?;
1158
1159        assert_optimized_plan_equal!(
1160            plan,
1161            @ r"
1162        Filter: NOT EXISTS (<subquery>)
1163          Subquery:
1164            TableScan: test2
1165          TableScan: test
1166        "
1167        )
1168    }
1169
1170    #[test]
1171    fn simplify_not_not_exists() -> Result<()> {
1172        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
1173        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1174        let table_scan2 =
1175            datafusion_expr::table_scan(Some("test2"), &schema, None)?.build()?;
1176
1177        let plan = LogicalPlanBuilder::from(table_scan)
1178            .filter(
1179                exists(Arc::new(LogicalPlanBuilder::from(table_scan2).build()?))
1180                    .not()
1181                    .not(),
1182            )?
1183            .build()?;
1184
1185        assert_optimized_plan_equal!(
1186            plan,
1187            @ r"
1188        Filter: EXISTS (<subquery>)
1189          Subquery:
1190            TableScan: test2
1191          TableScan: test
1192        "
1193        )
1194    }
1195
1196    #[test]
1197    fn simplify_not_in_subquery() -> Result<()> {
1198        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
1199        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1200        let table_scan2 =
1201            datafusion_expr::table_scan(Some("test2"), &schema, None)?.build()?;
1202
1203        let plan = LogicalPlanBuilder::from(table_scan)
1204            .filter(
1205                in_subquery(
1206                    col("a"),
1207                    Arc::new(LogicalPlanBuilder::from(table_scan2).build()?),
1208                )
1209                .not(),
1210            )?
1211            .build()?;
1212
1213        assert_optimized_plan_equal!(
1214            plan,
1215            @ r"
1216        Filter: test.a NOT IN (<subquery>)
1217          Subquery:
1218            TableScan: test2
1219          TableScan: test
1220        "
1221        )
1222    }
1223
1224    #[test]
1225    fn simplify_not_not_in_subquery() -> Result<()> {
1226        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
1227        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1228        let table_scan2 =
1229            datafusion_expr::table_scan(Some("test2"), &schema, None)?.build()?;
1230
1231        let plan = LogicalPlanBuilder::from(table_scan)
1232            .filter(
1233                in_subquery(
1234                    col("a"),
1235                    Arc::new(LogicalPlanBuilder::from(table_scan2).build()?),
1236                )
1237                .not()
1238                .not(),
1239            )?
1240            .build()?;
1241
1242        assert_optimized_plan_equal!(
1243            plan,
1244            @ r"
1245        Filter: test.a IN (<subquery>)
1246          Subquery:
1247            TableScan: test2
1248          TableScan: test
1249        "
1250        )
1251    }
1252}