datafusion_optimizer/simplify_expressions/
simplify_exprs.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Simplify expressions optimizer rule and implementation
19
20use std::sync::Arc;
21
22use datafusion_common::tree_node::{Transformed, TreeNode};
23use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
24use datafusion_expr::execution_props::ExecutionProps;
25use datafusion_expr::logical_plan::LogicalPlan;
26use datafusion_expr::simplify::SimplifyContext;
27use datafusion_expr::utils::merge_schema;
28use datafusion_expr::Expr;
29
30use crate::optimizer::ApplyOrder;
31use crate::utils::NamePreserver;
32use crate::{OptimizerConfig, OptimizerRule};
33
34use super::ExprSimplifier;
35
/// Optimizer pass that simplifies a [`LogicalPlan`] by rewriting the
/// [`Expr`]s it contains: constants are evaluated and algebraic
/// simplifications are applied.
///
/// # Introduction
/// Boolean algebra laws are used to simplify expressions or to reduce the
/// number of terms they contain.
///
/// # Example
/// `Filter: b > 2 AND b > 2`
/// is optimized to
/// `Filter: b > 2`
///
/// [`Expr`]: datafusion_expr::Expr
#[derive(Debug, Default)]
pub struct SimplifyExpressions {}
51
52impl OptimizerRule for SimplifyExpressions {
53    fn name(&self) -> &str {
54        "simplify_expressions"
55    }
56
57    fn apply_order(&self) -> Option<ApplyOrder> {
58        Some(ApplyOrder::BottomUp)
59    }
60
61    fn supports_rewrite(&self) -> bool {
62        true
63    }
64
65    fn rewrite(
66        &self,
67        plan: LogicalPlan,
68        config: &dyn OptimizerConfig,
69    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
70        let mut execution_props = ExecutionProps::new();
71        execution_props.query_execution_start_time = config.query_execution_start_time();
72        Self::optimize_internal(plan, &execution_props)
73    }
74}
75
76impl SimplifyExpressions {
77    fn optimize_internal(
78        plan: LogicalPlan,
79        execution_props: &ExecutionProps,
80    ) -> Result<Transformed<LogicalPlan>> {
81        let schema = if !plan.inputs().is_empty() {
82            DFSchemaRef::new(merge_schema(&plan.inputs()))
83        } else if let LogicalPlan::TableScan(scan) = &plan {
84            // When predicates are pushed into a table scan, there is no input
85            // schema to resolve predicates against, so it must be handled specially
86            //
87            // Note that this is not `plan.schema()` which is the *output*
88            // schema, and reflects any pushed down projection. The output schema
89            // will not contain columns that *only* appear in pushed down predicates
90            // (and no where else) in the plan.
91            //
92            // Thus, use the full schema of the inner provider without any
93            // projection applied for simplification
94            Arc::new(DFSchema::try_from_qualified_schema(
95                scan.table_name.clone(),
96                &scan.source.schema(),
97            )?)
98        } else {
99            Arc::new(DFSchema::empty())
100        };
101
102        let info = SimplifyContext::new(execution_props).with_schema(schema);
103
104        // Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer)
105        // Just need to rewrite our own expressions
106
107        let simplifier = ExprSimplifier::new(info);
108
109        // The left and right expressions in a Join on clause are not
110        // commutative, for reasons that are not entirely clear. Thus, do not
111        // reorder expressions in Join while simplifying.
112        //
113        // This is likely related to the fact that order of the columns must
114        // match the order of the children. see
115        // https://github.com/apache/datafusion/pull/8780 for more details
116        let simplifier = if let LogicalPlan::Join(_) = plan {
117            simplifier.with_canonicalize(false)
118        } else {
119            simplifier
120        };
121
122        // Preserve expression names to avoid changing the schema of the plan.
123        let name_preserver = NamePreserver::new(&plan);
124        let mut rewrite_expr = |expr: Expr| {
125            let name = name_preserver.save(&expr);
126            let expr = simplifier.simplify(expr)?;
127            // TODO it would be nice to have a way to know if the expression was simplified
128            // or not. For now conservatively return Transformed::yes
129            Ok(Transformed::yes(name.restore(expr)))
130        };
131
132        plan.map_expressions(|expr| {
133            // Preserve the aliasing of grouping sets.
134            if let Expr::GroupingSet(_) = &expr {
135                expr.map_children(&mut rewrite_expr)
136            } else {
137                rewrite_expr(expr)
138            }
139        })
140    }
141}
142
143impl SimplifyExpressions {
144    #[allow(missing_docs)]
145    pub fn new() -> Self {
146        Self {}
147    }
148}
149
150#[cfg(test)]
151mod tests {
152    use std::ops::Not;
153
154    use arrow::datatypes::{DataType, Field, Schema};
155    use chrono::{DateTime, Utc};
156
157    use crate::optimizer::Optimizer;
158    use datafusion_expr::logical_plan::builder::table_scan_with_filters;
159    use datafusion_expr::logical_plan::table_scan;
160    use datafusion_expr::*;
161    use datafusion_functions_aggregate::expr_fn::{max, min};
162
163    use crate::test::{assert_fields_eq, test_table_scan_with_name};
164    use crate::OptimizerContext;
165
166    use super::*;
167
168    fn test_table_scan() -> LogicalPlan {
169        let schema = Schema::new(vec![
170            Field::new("a", DataType::Boolean, false),
171            Field::new("b", DataType::Boolean, false),
172            Field::new("c", DataType::Boolean, false),
173            Field::new("d", DataType::UInt32, false),
174            Field::new("e", DataType::UInt32, true),
175        ]);
176        table_scan(Some("test"), &schema, None)
177            .expect("creating scan")
178            .build()
179            .expect("building plan")
180    }
181
182    fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
183        // Use Optimizer to do plan traversal
184        fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
185        let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]);
186        let optimized_plan =
187            optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
188        let formatted_plan = format!("{optimized_plan}");
189        assert_eq!(formatted_plan, expected);
190        Ok(())
191    }
192
193    #[test]
194    fn test_simplify_table_full_filter_in_scan() -> Result<()> {
195        let fields = vec![
196            Field::new("a", DataType::UInt32, false),
197            Field::new("b", DataType::UInt32, false),
198            Field::new("c", DataType::UInt32, false),
199        ];
200
201        let schema = Schema::new(fields);
202
203        let table_scan = table_scan_with_filters(
204            Some("test"),
205            &schema,
206            Some(vec![0]),
207            vec![col("b").is_not_null()],
208        )?
209        .build()?;
210        assert_eq!(1, table_scan.schema().fields().len());
211        assert_fields_eq(&table_scan, vec!["a"]);
212
213        let expected = "TableScan: test projection=[a], full_filters=[Boolean(true)]";
214
215        assert_optimized_plan_eq(table_scan, expected)
216    }
217
218    #[test]
219    fn test_simplify_filter_pushdown() -> Result<()> {
220        let table_scan = test_table_scan();
221        let plan = LogicalPlanBuilder::from(table_scan)
222            .project(vec![col("a")])?
223            .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))?
224            .build()?;
225
226        assert_optimized_plan_eq(
227            plan,
228            "\
229	        Filter: test.b > Int32(1)\
230            \n  Projection: test.a\
231            \n    TableScan: test",
232        )
233    }
234
235    #[test]
236    fn test_simplify_optimized_plan() -> Result<()> {
237        let table_scan = test_table_scan();
238        let plan = LogicalPlanBuilder::from(table_scan)
239            .project(vec![col("a")])?
240            .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))?
241            .build()?;
242
243        assert_optimized_plan_eq(
244            plan,
245            "\
246	        Filter: test.b > Int32(1)\
247            \n  Projection: test.a\
248            \n    TableScan: test",
249        )
250    }
251
252    #[test]
253    fn test_simplify_optimized_plan_with_or() -> Result<()> {
254        let table_scan = test_table_scan();
255        let plan = LogicalPlanBuilder::from(table_scan)
256            .project(vec![col("a")])?
257            .filter(or(col("b").gt(lit(1)), col("b").gt(lit(1))))?
258            .build()?;
259
260        assert_optimized_plan_eq(
261            plan,
262            "\
263            Filter: test.b > Int32(1)\
264            \n  Projection: test.a\
265            \n    TableScan: test",
266        )
267    }
268
269    #[test]
270    fn test_simplify_optimized_plan_with_composed_and() -> Result<()> {
271        let table_scan = test_table_scan();
272        // ((c > 5) AND (d < 6)) AND (c > 5) --> (c > 5) AND (d < 6)
273        let plan = LogicalPlanBuilder::from(table_scan)
274            .project(vec![col("a"), col("b")])?
275            .filter(and(
276                and(col("a").gt(lit(5)), col("b").lt(lit(6))),
277                col("a").gt(lit(5)),
278            ))?
279            .build()?;
280
281        assert_optimized_plan_eq(
282            plan,
283            "\
284            Filter: test.a > Int32(5) AND test.b < Int32(6)\
285            \n  Projection: test.a, test.b\
286	        \n    TableScan: test",
287        )
288    }
289
290    #[test]
291    fn test_simplify_optimized_plan_eq_expr() -> Result<()> {
292        let table_scan = test_table_scan();
293        let plan = LogicalPlanBuilder::from(table_scan)
294            .filter(col("b").eq(lit(true)))?
295            .filter(col("c").eq(lit(false)))?
296            .project(vec![col("a")])?
297            .build()?;
298
299        let expected = "\
300        Projection: test.a\
301        \n  Filter: NOT test.c\
302        \n    Filter: test.b\
303        \n      TableScan: test";
304
305        assert_optimized_plan_eq(plan, expected)
306    }
307
308    #[test]
309    fn test_simplify_optimized_plan_not_eq_expr() -> Result<()> {
310        let table_scan = test_table_scan();
311        let plan = LogicalPlanBuilder::from(table_scan)
312            .filter(col("b").not_eq(lit(true)))?
313            .filter(col("c").not_eq(lit(false)))?
314            .limit(0, Some(1))?
315            .project(vec![col("a")])?
316            .build()?;
317
318        let expected = "\
319        Projection: test.a\
320        \n  Limit: skip=0, fetch=1\
321        \n    Filter: test.c\
322        \n      Filter: NOT test.b\
323        \n        TableScan: test";
324
325        assert_optimized_plan_eq(plan, expected)
326    }
327
328    #[test]
329    fn test_simplify_optimized_plan_and_expr() -> Result<()> {
330        let table_scan = test_table_scan();
331        let plan = LogicalPlanBuilder::from(table_scan)
332            .filter(col("b").not_eq(lit(true)).and(col("c").eq(lit(true))))?
333            .project(vec![col("a")])?
334            .build()?;
335
336        let expected = "\
337        Projection: test.a\
338        \n  Filter: NOT test.b AND test.c\
339        \n    TableScan: test";
340
341        assert_optimized_plan_eq(plan, expected)
342    }
343
344    #[test]
345    fn test_simplify_optimized_plan_or_expr() -> Result<()> {
346        let table_scan = test_table_scan();
347        let plan = LogicalPlanBuilder::from(table_scan)
348            .filter(col("b").not_eq(lit(true)).or(col("c").eq(lit(false))))?
349            .project(vec![col("a")])?
350            .build()?;
351
352        let expected = "\
353        Projection: test.a\
354        \n  Filter: NOT test.b OR NOT test.c\
355        \n    TableScan: test";
356
357        assert_optimized_plan_eq(plan, expected)
358    }
359
360    #[test]
361    fn test_simplify_optimized_plan_not_expr() -> Result<()> {
362        let table_scan = test_table_scan();
363        let plan = LogicalPlanBuilder::from(table_scan)
364            .filter(col("b").eq(lit(false)).not())?
365            .project(vec![col("a")])?
366            .build()?;
367
368        let expected = "\
369        Projection: test.a\
370        \n  Filter: test.b\
371        \n    TableScan: test";
372
373        assert_optimized_plan_eq(plan, expected)
374    }
375
376    #[test]
377    fn test_simplify_optimized_plan_support_projection() -> Result<()> {
378        let table_scan = test_table_scan();
379        let plan = LogicalPlanBuilder::from(table_scan)
380            .project(vec![col("a"), col("d"), col("b").eq(lit(false))])?
381            .build()?;
382
383        let expected = "\
384        Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false)\
385        \n  TableScan: test";
386
387        assert_optimized_plan_eq(plan, expected)
388    }
389
390    #[test]
391    fn test_simplify_optimized_plan_support_aggregate() -> Result<()> {
392        let table_scan = test_table_scan();
393        let plan = LogicalPlanBuilder::from(table_scan)
394            .project(vec![col("a"), col("c"), col("b")])?
395            .aggregate(
396                vec![col("a"), col("c")],
397                vec![max(col("b").eq(lit(true))), min(col("b"))],
398            )?
399            .build()?;
400
401        let expected = "\
402        Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b) AS max(test.b = Boolean(true)), min(test.b)]]\
403        \n  Projection: test.a, test.c, test.b\
404        \n    TableScan: test";
405
406        assert_optimized_plan_eq(plan, expected)
407    }
408
409    #[test]
410    fn test_simplify_optimized_plan_support_values() -> Result<()> {
411        let expr1 = Expr::BinaryExpr(BinaryExpr::new(
412            Box::new(lit(1)),
413            Operator::Plus,
414            Box::new(lit(2)),
415        ));
416        let expr2 = Expr::BinaryExpr(BinaryExpr::new(
417            Box::new(lit(2)),
418            Operator::Minus,
419            Box::new(lit(1)),
420        ));
421        let values = vec![vec![expr1, expr2]];
422        let plan = LogicalPlanBuilder::values(values)?.build()?;
423
424        let expected = "\
425        Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))";
426
427        assert_optimized_plan_eq(plan, expected)
428    }
429
430    fn get_optimized_plan_formatted(
431        plan: LogicalPlan,
432        date_time: &DateTime<Utc>,
433    ) -> String {
434        let config = OptimizerContext::new().with_query_execution_start_time(*date_time);
435        let rule = SimplifyExpressions::new();
436
437        let optimized_plan = rule.rewrite(plan, &config).unwrap().data;
438        format!("{optimized_plan}")
439    }
440
441    #[test]
442    fn cast_expr() -> Result<()> {
443        let table_scan = test_table_scan();
444        let proj = vec![Expr::Cast(Cast::new(Box::new(lit("0")), DataType::Int32))];
445        let plan = LogicalPlanBuilder::from(table_scan)
446            .project(proj)?
447            .build()?;
448
449        let expected = "Projection: Int32(0) AS Utf8(\"0\")\
450            \n  TableScan: test";
451        let actual = get_optimized_plan_formatted(plan, &Utc::now());
452        assert_eq!(expected, actual);
453        Ok(())
454    }
455
456    #[test]
457    fn simplify_and_eval() -> Result<()> {
458        // demonstrate a case where the evaluation needs to run prior
459        // to the simplifier for it to work
460        let table_scan = test_table_scan();
461        let time = Utc::now();
462        // (true or false) != col --> !col
463        let proj = vec![lit(true).or(lit(false)).not_eq(col("a"))];
464        let plan = LogicalPlanBuilder::from(table_scan)
465            .project(proj)?
466            .build()?;
467
468        let actual = get_optimized_plan_formatted(plan, &time);
469        let expected =
470            "Projection: NOT test.a AS Boolean(true) OR Boolean(false) != test.a\
471                        \n  TableScan: test";
472
473        assert_eq!(expected, actual);
474        Ok(())
475    }
476
477    #[test]
478    fn simplify_not_binary() -> Result<()> {
479        let table_scan = test_table_scan();
480
481        let plan = LogicalPlanBuilder::from(table_scan)
482            .filter(col("d").gt(lit(10)).not())?
483            .build()?;
484        let expected = "Filter: test.d <= Int32(10)\
485            \n  TableScan: test";
486
487        assert_optimized_plan_eq(plan, expected)
488    }
489
490    #[test]
491    fn simplify_not_bool_and() -> Result<()> {
492        let table_scan = test_table_scan();
493
494        let plan = LogicalPlanBuilder::from(table_scan)
495            .filter(col("d").gt(lit(10)).and(col("d").lt(lit(100))).not())?
496            .build()?;
497        let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100)\
498        \n  TableScan: test";
499
500        assert_optimized_plan_eq(plan, expected)
501    }
502
503    #[test]
504    fn simplify_not_bool_or() -> Result<()> {
505        let table_scan = test_table_scan();
506
507        let plan = LogicalPlanBuilder::from(table_scan)
508            .filter(col("d").gt(lit(10)).or(col("d").lt(lit(100))).not())?
509            .build()?;
510        let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100)\
511        \n  TableScan: test";
512
513        assert_optimized_plan_eq(plan, expected)
514    }
515
516    #[test]
517    fn simplify_not_not() -> Result<()> {
518        let table_scan = test_table_scan();
519
520        let plan = LogicalPlanBuilder::from(table_scan)
521            .filter(col("d").gt(lit(10)).not().not())?
522            .build()?;
523        let expected = "Filter: test.d > Int32(10)\
524        \n  TableScan: test";
525
526        assert_optimized_plan_eq(plan, expected)
527    }
528
529    #[test]
530    fn simplify_not_null() -> Result<()> {
531        let table_scan = test_table_scan();
532
533        let plan = LogicalPlanBuilder::from(table_scan)
534            .filter(col("e").is_null().not())?
535            .build()?;
536        let expected = "Filter: test.e IS NOT NULL\
537        \n  TableScan: test";
538
539        assert_optimized_plan_eq(plan, expected)
540    }
541
542    #[test]
543    fn simplify_not_not_null() -> Result<()> {
544        let table_scan = test_table_scan();
545
546        let plan = LogicalPlanBuilder::from(table_scan)
547            .filter(col("e").is_not_null().not())?
548            .build()?;
549        let expected = "Filter: test.e IS NULL\
550        \n  TableScan: test";
551
552        assert_optimized_plan_eq(plan, expected)
553    }
554
555    #[test]
556    fn simplify_not_in() -> Result<()> {
557        let table_scan = test_table_scan();
558
559        let plan = LogicalPlanBuilder::from(table_scan)
560            .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], false).not())?
561            .build()?;
562        let expected =
563            "Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3)\
564        \n  TableScan: test";
565
566        assert_optimized_plan_eq(plan, expected)
567    }
568
569    #[test]
570    fn simplify_not_not_in() -> Result<()> {
571        let table_scan = test_table_scan();
572
573        let plan = LogicalPlanBuilder::from(table_scan)
574            .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], true).not())?
575            .build()?;
576        let expected =
577            "Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3)\
578        \n  TableScan: test";
579
580        assert_optimized_plan_eq(plan, expected)
581    }
582
583    #[test]
584    fn simplify_not_between() -> Result<()> {
585        let table_scan = test_table_scan();
586        let qual = col("d").between(lit(1), lit(10));
587
588        let plan = LogicalPlanBuilder::from(table_scan)
589            .filter(qual.not())?
590            .build()?;
591        let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10)\
592        \n  TableScan: test";
593
594        assert_optimized_plan_eq(plan, expected)
595    }
596
597    #[test]
598    fn simplify_not_not_between() -> Result<()> {
599        let table_scan = test_table_scan();
600        let qual = col("d").not_between(lit(1), lit(10));
601
602        let plan = LogicalPlanBuilder::from(table_scan)
603            .filter(qual.not())?
604            .build()?;
605        let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10)\
606        \n  TableScan: test";
607
608        assert_optimized_plan_eq(plan, expected)
609    }
610
611    #[test]
612    fn simplify_not_like() -> Result<()> {
613        let schema = Schema::new(vec![
614            Field::new("a", DataType::Utf8, false),
615            Field::new("b", DataType::Utf8, false),
616        ]);
617        let table_scan = table_scan(Some("test"), &schema, None)
618            .expect("creating scan")
619            .build()
620            .expect("building plan");
621
622        let plan = LogicalPlanBuilder::from(table_scan)
623            .filter(col("a").like(col("b")).not())?
624            .build()?;
625        let expected = "Filter: test.a NOT LIKE test.b\
626        \n  TableScan: test";
627
628        assert_optimized_plan_eq(plan, expected)
629    }
630
631    #[test]
632    fn simplify_not_not_like() -> Result<()> {
633        let schema = Schema::new(vec![
634            Field::new("a", DataType::Utf8, false),
635            Field::new("b", DataType::Utf8, false),
636        ]);
637        let table_scan = table_scan(Some("test"), &schema, None)
638            .expect("creating scan")
639            .build()
640            .expect("building plan");
641
642        let plan = LogicalPlanBuilder::from(table_scan)
643            .filter(col("a").not_like(col("b")).not())?
644            .build()?;
645        let expected = "Filter: test.a LIKE test.b\
646        \n  TableScan: test";
647
648        assert_optimized_plan_eq(plan, expected)
649    }
650
651    #[test]
652    fn simplify_not_ilike() -> Result<()> {
653        let schema = Schema::new(vec![
654            Field::new("a", DataType::Utf8, false),
655            Field::new("b", DataType::Utf8, false),
656        ]);
657        let table_scan = table_scan(Some("test"), &schema, None)
658            .expect("creating scan")
659            .build()
660            .expect("building plan");
661
662        let plan = LogicalPlanBuilder::from(table_scan)
663            .filter(col("a").ilike(col("b")).not())?
664            .build()?;
665        let expected = "Filter: test.a NOT ILIKE test.b\
666        \n  TableScan: test";
667
668        assert_optimized_plan_eq(plan, expected)
669    }
670
671    #[test]
672    fn simplify_not_distinct_from() -> Result<()> {
673        let table_scan = test_table_scan();
674
675        let plan = LogicalPlanBuilder::from(table_scan)
676            .filter(binary_expr(col("d"), Operator::IsDistinctFrom, lit(10)).not())?
677            .build()?;
678        let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10)\
679        \n  TableScan: test";
680
681        assert_optimized_plan_eq(plan, expected)
682    }
683
684    #[test]
685    fn simplify_not_not_distinct_from() -> Result<()> {
686        let table_scan = test_table_scan();
687
688        let plan = LogicalPlanBuilder::from(table_scan)
689            .filter(binary_expr(col("d"), Operator::IsNotDistinctFrom, lit(10)).not())?
690            .build()?;
691        let expected = "Filter: test.d IS DISTINCT FROM Int32(10)\
692        \n  TableScan: test";
693
694        assert_optimized_plan_eq(plan, expected)
695    }
696
697    #[test]
698    fn simplify_equijoin_predicate() -> Result<()> {
699        let t1 = test_table_scan_with_name("t1")?;
700        let t2 = test_table_scan_with_name("t2")?;
701
702        let left_key = col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, t1.schema())?;
703        let right_key =
704            col("t2.a") + lit(2i64).cast_to(&DataType::UInt32, t2.schema())?;
705        let plan = LogicalPlanBuilder::from(t1)
706            .join_with_expr_keys(
707                t2,
708                JoinType::Inner,
709                (vec![left_key], vec![right_key]),
710                None,
711            )?
712            .build()?;
713
714        // before simplify: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32)
715        // after simplify: t1.a + UInt32(1) = t2.a + UInt32(2) AS t1.a + Int64(1) = t2.a + Int64(2)
716        let expected = "Inner Join: t1.a + UInt32(1) = t2.a + UInt32(2)\
717            \n  TableScan: t1\
718            \n  TableScan: t2";
719
720        assert_optimized_plan_eq(plan, expected)
721    }
722
723    #[test]
724    fn simplify_is_not_null() -> Result<()> {
725        let table_scan = test_table_scan();
726
727        let plan = LogicalPlanBuilder::from(table_scan)
728            .filter(col("d").is_not_null())?
729            .build()?;
730        let expected = "Filter: Boolean(true)\
731        \n  TableScan: test";
732
733        assert_optimized_plan_eq(plan, expected)
734    }
735
736    #[test]
737    fn simplify_is_null() -> Result<()> {
738        let table_scan = test_table_scan();
739
740        let plan = LogicalPlanBuilder::from(table_scan)
741            .filter(col("d").is_null())?
742            .build()?;
743        let expected = "Filter: Boolean(false)\
744        \n  TableScan: test";
745
746        assert_optimized_plan_eq(plan, expected)
747    }
748
749    #[test]
750    fn simplify_grouping_sets() -> Result<()> {
751        let table_scan = test_table_scan();
752        let plan = LogicalPlanBuilder::from(table_scan)
753            .aggregate(
754                [grouping_set(vec![
755                    vec![(lit(42).alias("prev") + lit(1)).alias("age"), col("a")],
756                    vec![col("a").or(col("b")).and(lit(1).lt(lit(0))).alias("cond")],
757                    vec![col("d").alias("e"), (lit(1) + lit(2))],
758                ])],
759                [] as [Expr; 0],
760            )?
761            .build()?;
762
763        let expected = "Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]]\
764        \n  TableScan: test";
765
766        assert_optimized_plan_eq(plan, expected)
767    }
768
769    #[test]
770    fn test_simplify_regex_special_cases() -> Result<()> {
771        let schema = Schema::new(vec![
772            Field::new("a", DataType::Utf8, true),
773            Field::new("b", DataType::Utf8, false),
774        ]);
775        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
776
777        // Test `= ".*"` transforms to true (except for empty strings)
778        let plan = LogicalPlanBuilder::from(table_scan.clone())
779            .filter(binary_expr(col("a"), Operator::RegexMatch, lit(".*")))?
780            .build()?;
781        let expected = "Filter: test.a IS NOT NULL\
782        \n  TableScan: test";
783
784        assert_optimized_plan_eq(plan, expected)?;
785
786        // Test `!= ".*"` transforms to checking if the column is empty
787        let plan = LogicalPlanBuilder::from(table_scan.clone())
788            .filter(binary_expr(col("a"), Operator::RegexNotMatch, lit(".*")))?
789            .build()?;
790        let expected = "Filter: test.a = Utf8(\"\")\
791        \n  TableScan: test";
792
793        assert_optimized_plan_eq(plan, expected)?;
794
795        // Test case-insensitive versions
796
797        // Test `=~ ".*"` (case-insensitive) transforms to true (except for empty strings)
798        let plan = LogicalPlanBuilder::from(table_scan.clone())
799            .filter(binary_expr(col("b"), Operator::RegexIMatch, lit(".*")))?
800            .build()?;
801        let expected = "Filter: Boolean(true)\
802        \n  TableScan: test";
803
804        assert_optimized_plan_eq(plan, expected)?;
805
806        // Test `!~ ".*"` (case-insensitive) transforms to checking if the column is empty
807        let plan = LogicalPlanBuilder::from(table_scan.clone())
808            .filter(binary_expr(col("a"), Operator::RegexNotIMatch, lit(".*")))?
809            .build()?;
810        let expected = "Filter: test.a = Utf8(\"\")\
811        \n  TableScan: test";
812
813        assert_optimized_plan_eq(plan, expected)
814    }
815}