datafusion_optimizer/simplify_expressions/
simplify_exprs.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Simplify expressions optimizer rule and implementation
19
20use std::sync::Arc;
21
22use datafusion_common::tree_node::{Transformed, TreeNode};
23use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
24use datafusion_expr::execution_props::ExecutionProps;
25use datafusion_expr::logical_plan::LogicalPlan;
26use datafusion_expr::simplify::SimplifyContext;
27use datafusion_expr::utils::merge_schema;
28use datafusion_expr::Expr;
29
30use crate::optimizer::ApplyOrder;
31use crate::utils::NamePreserver;
32use crate::{OptimizerConfig, OptimizerRule};
33
34use super::ExprSimplifier;
35
/// Optimizer pass that simplifies [`LogicalPlan`]s by rewriting their
/// [`Expr`]s, evaluating constants and applying algebraic
/// simplifications.
///
/// # Introduction
/// It uses boolean algebra laws to simplify or reduce the number of terms in
/// expressions.
///
/// # Example
/// `Filter: b > 2 AND b > 2`
/// is optimized to
/// `Filter: b > 2`
///
/// [`Expr`]: datafusion_expr::Expr
#[derive(Default, Debug)]
pub struct SimplifyExpressions {}
51
52impl OptimizerRule for SimplifyExpressions {
53    fn name(&self) -> &str {
54        "simplify_expressions"
55    }
56
57    fn apply_order(&self) -> Option<ApplyOrder> {
58        Some(ApplyOrder::BottomUp)
59    }
60
61    fn supports_rewrite(&self) -> bool {
62        true
63    }
64
65    /// if supports_owned returns true, the Optimizer calls
66    /// [`Self::rewrite`] instead of [`Self::try_optimize`]
67    fn rewrite(
68        &self,
69        plan: LogicalPlan,
70        config: &dyn OptimizerConfig,
71    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
72        let mut execution_props = ExecutionProps::new();
73        execution_props.query_execution_start_time = config.query_execution_start_time();
74        Self::optimize_internal(plan, &execution_props)
75    }
76}
77
78impl SimplifyExpressions {
79    fn optimize_internal(
80        plan: LogicalPlan,
81        execution_props: &ExecutionProps,
82    ) -> Result<Transformed<LogicalPlan>> {
83        let schema = if !plan.inputs().is_empty() {
84            DFSchemaRef::new(merge_schema(&plan.inputs()))
85        } else if let LogicalPlan::TableScan(scan) = &plan {
86            // When predicates are pushed into a table scan, there is no input
87            // schema to resolve predicates against, so it must be handled specially
88            //
89            // Note that this is not `plan.schema()` which is the *output*
90            // schema, and reflects any pushed down projection. The output schema
91            // will not contain columns that *only* appear in pushed down predicates
92            // (and no where else) in the plan.
93            //
94            // Thus, use the full schema of the inner provider without any
95            // projection applied for simplification
96            Arc::new(DFSchema::try_from_qualified_schema(
97                scan.table_name.clone(),
98                &scan.source.schema(),
99            )?)
100        } else {
101            Arc::new(DFSchema::empty())
102        };
103
104        let info = SimplifyContext::new(execution_props).with_schema(schema);
105
106        // Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer)
107        // Just need to rewrite our own expressions
108
109        let simplifier = ExprSimplifier::new(info);
110
111        // The left and right expressions in a Join on clause are not
112        // commutative, for reasons that are not entirely clear. Thus, do not
113        // reorder expressions in Join while simplifying.
114        //
115        // This is likely related to the fact that order of the columns must
116        // match the order of the children. see
117        // https://github.com/apache/datafusion/pull/8780 for more details
118        let simplifier = if let LogicalPlan::Join(_) = plan {
119            simplifier.with_canonicalize(false)
120        } else {
121            simplifier
122        };
123
124        // Preserve expression names to avoid changing the schema of the plan.
125        let name_preserver = NamePreserver::new(&plan);
126        let mut rewrite_expr = |expr: Expr| {
127            let name = name_preserver.save(&expr);
128            let expr = simplifier.simplify(expr)?;
129            // TODO it would be nice to have a way to know if the expression was simplified
130            // or not. For now conservatively return Transformed::yes
131            Ok(Transformed::yes(name.restore(expr)))
132        };
133
134        plan.map_expressions(|expr| {
135            // Preserve the aliasing of grouping sets.
136            if let Expr::GroupingSet(_) = &expr {
137                expr.map_children(&mut rewrite_expr)
138            } else {
139                rewrite_expr(expr)
140            }
141        })
142    }
143}
144
145impl SimplifyExpressions {
146    #[allow(missing_docs)]
147    pub fn new() -> Self {
148        Self {}
149    }
150}
151
#[cfg(test)]
mod tests {
    use std::ops::Not;

    use arrow::datatypes::{DataType, Field, Schema};
    use chrono::{DateTime, Utc};

    use crate::optimizer::Optimizer;
    use datafusion_expr::logical_plan::builder::table_scan_with_filters;
    use datafusion_expr::logical_plan::table_scan;
    use datafusion_expr::*;
    use datafusion_functions_aggregate::expr_fn::{max, min};

    use crate::test::{assert_fields_eq, test_table_scan_with_name};
    use crate::OptimizerContext;

    use super::*;

    /// Builds a scan of table `test` with three non-nullable boolean columns
    /// (`a`, `b`, `c`), a non-nullable `UInt32` column `d` and a nullable
    /// `UInt32` column `e`.
    fn test_table_scan() -> LogicalPlan {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Boolean, false),
            Field::new("b", DataType::Boolean, false),
            Field::new("c", DataType::Boolean, false),
            Field::new("d", DataType::UInt32, false),
            Field::new("e", DataType::UInt32, true),
        ]);
        table_scan(Some("test"), &schema, None)
            .expect("creating scan")
            .build()
            .expect("building plan")
    }

    /// Builds a scan of table `test` with two non-nullable `Utf8` columns,
    /// `a` and `b`, shared by the `LIKE` / `ILIKE` tests.
    fn test_utf8_table_scan() -> LogicalPlan {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ]);
        table_scan(Some("test"), &schema, None)
            .expect("creating scan")
            .build()
            .expect("building plan")
    }

    /// Runs `SimplifyExpressions` over `plan` through the [`Optimizer`] (so
    /// that plan traversal is exercised too) and asserts that the formatted
    /// result equals `expected`.
    fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
        // Use Optimizer to do plan traversal
        fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
        let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]);
        let optimized_plan =
            optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
        let formatted_plan = format!("{optimized_plan}");
        assert_eq!(formatted_plan, expected);
        Ok(())
    }

    #[test]
    fn test_simplify_table_full_filter_in_scan() -> Result<()> {
        let fields = vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::UInt32, false),
            Field::new("c", DataType::UInt32, false),
        ];

        let schema = Schema::new(fields);

        // The pushed-down filter references `b`, which is not part of the
        // projected output schema (only `a` is projected).
        let table_scan = table_scan_with_filters(
            Some("test"),
            &schema,
            Some(vec![0]),
            vec![col("b").is_not_null()],
        )?
        .build()?;
        assert_eq!(1, table_scan.schema().fields().len());
        assert_fields_eq(&table_scan, vec!["a"]);

        let expected = "TableScan: test projection=[a], full_filters=[Boolean(true)]";

        assert_optimized_plan_eq(table_scan, expected)
    }

    // NOTE(review): this test is identical to `test_simplify_optimized_plan`
    // below; both are kept so the set of test names stays unchanged.
    #[test]
    fn test_simplify_filter_pushdown() -> Result<()> {
        let table_scan = test_table_scan();
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a")])?
            .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))?
            .build()?;

        assert_optimized_plan_eq(
            plan,
            "\
            Filter: test.b > Int32(1)\
            \n  Projection: test.a\
            \n    TableScan: test",
        )
    }

    #[test]
    fn test_simplify_optimized_plan() -> Result<()> {
        let table_scan = test_table_scan();
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a")])?
            .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))?
            .build()?;

        assert_optimized_plan_eq(
            plan,
            "\
            Filter: test.b > Int32(1)\
            \n  Projection: test.a\
            \n    TableScan: test",
        )
    }

    #[test]
    fn test_simplify_optimized_plan_with_or() -> Result<()> {
        let table_scan = test_table_scan();
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a")])?
            .filter(or(col("b").gt(lit(1)), col("b").gt(lit(1))))?
            .build()?;

        assert_optimized_plan_eq(
            plan,
            "\
            Filter: test.b > Int32(1)\
            \n  Projection: test.a\
            \n    TableScan: test",
        )
    }

    #[test]
    fn test_simplify_optimized_plan_with_composed_and() -> Result<()> {
        let table_scan = test_table_scan();
        // ((a > 5) AND (b < 6)) AND (a > 5) --> (a > 5) AND (b < 6)
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b")])?
            .filter(and(
                and(col("a").gt(lit(5)), col("b").lt(lit(6))),
                col("a").gt(lit(5)),
            ))?
            .build()?;

        assert_optimized_plan_eq(
            plan,
            "\
            Filter: test.a > Int32(5) AND test.b < Int32(6)\
            \n  Projection: test.a, test.b\
            \n    TableScan: test",
        )
    }

    #[test]
    fn test_simplify_optimized_plan_eq_expr() -> Result<()> {
        let table_scan = test_table_scan();
        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("b").eq(lit(true)))?
            .filter(col("c").eq(lit(false)))?
            .project(vec![col("a")])?
            .build()?;

        let expected = "\
        Projection: test.a\
        \n  Filter: NOT test.c\
        \n    Filter: test.b\
        \n      TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn test_simplify_optimized_plan_not_eq_expr() -> Result<()> {
        let table_scan = test_table_scan();
        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("b").not_eq(lit(true)))?
            .filter(col("c").not_eq(lit(false)))?
            .limit(0, Some(1))?
            .project(vec![col("a")])?
            .build()?;

        let expected = "\
        Projection: test.a\
        \n  Limit: skip=0, fetch=1\
        \n    Filter: test.c\
        \n      Filter: NOT test.b\
        \n        TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn test_simplify_optimized_plan_and_expr() -> Result<()> {
        let table_scan = test_table_scan();
        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("b").not_eq(lit(true)).and(col("c").eq(lit(true))))?
            .project(vec![col("a")])?
            .build()?;

        let expected = "\
        Projection: test.a\
        \n  Filter: NOT test.b AND test.c\
        \n    TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn test_simplify_optimized_plan_or_expr() -> Result<()> {
        let table_scan = test_table_scan();
        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("b").not_eq(lit(true)).or(col("c").eq(lit(false))))?
            .project(vec![col("a")])?
            .build()?;

        let expected = "\
        Projection: test.a\
        \n  Filter: NOT test.b OR NOT test.c\
        \n    TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn test_simplify_optimized_plan_not_expr() -> Result<()> {
        let table_scan = test_table_scan();
        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("b").eq(lit(false)).not())?
            .project(vec![col("a")])?
            .build()?;

        let expected = "\
        Projection: test.a\
        \n  Filter: test.b\
        \n    TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn test_simplify_optimized_plan_support_projection() -> Result<()> {
        let table_scan = test_table_scan();
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("d"), col("b").eq(lit(false))])?
            .build()?;

        // The simplified expression is aliased back to its original name.
        let expected = "\
        Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false)\
        \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn test_simplify_optimized_plan_support_aggregate() -> Result<()> {
        let table_scan = test_table_scan();
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("c"), col("b")])?
            .aggregate(
                vec![col("a"), col("c")],
                vec![max(col("b").eq(lit(true))), min(col("b"))],
            )?
            .build()?;

        let expected = "\
        Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b) AS max(test.b = Boolean(true)), min(test.b)]]\
        \n  Projection: test.a, test.c, test.b\
        \n    TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn test_simplify_optimized_plan_support_values() -> Result<()> {
        let expr1 = Expr::BinaryExpr(BinaryExpr::new(
            Box::new(lit(1)),
            Operator::Plus,
            Box::new(lit(2)),
        ));
        let expr2 = Expr::BinaryExpr(BinaryExpr::new(
            Box::new(lit(2)),
            Operator::Minus,
            Box::new(lit(1)),
        ));
        let values = vec![vec![expr1, expr2]];
        let plan = LogicalPlanBuilder::values(values)?.build()?;

        let expected = "\
        Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))";

        assert_optimized_plan_eq(plan, expected)
    }

    /// Applies the rule directly to the root of `plan` (no Optimizer
    /// traversal) using `date_time` as the query execution start time, and
    /// returns the formatted result.
    fn get_optimized_plan_formatted(
        plan: LogicalPlan,
        date_time: &DateTime<Utc>,
    ) -> String {
        let config = OptimizerContext::new().with_query_execution_start_time(*date_time);
        let rule = SimplifyExpressions::new();

        let optimized_plan = rule.rewrite(plan, &config).unwrap().data;
        format!("{optimized_plan}")
    }

    #[test]
    fn cast_expr() -> Result<()> {
        let table_scan = test_table_scan();
        let proj = vec![Expr::Cast(Cast::new(Box::new(lit("0")), DataType::Int32))];
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(proj)?
            .build()?;

        let expected = "Projection: Int32(0) AS Utf8(\"0\")\
            \n  TableScan: test";
        let actual = get_optimized_plan_formatted(plan, &Utc::now());
        assert_eq!(expected, actual);
        Ok(())
    }

    #[test]
    fn simplify_and_eval() -> Result<()> {
        // demonstrate a case where the evaluation needs to run prior
        // to the simplifier for it to work
        let table_scan = test_table_scan();
        let time = Utc::now();
        // (true or false) != col --> !col
        let proj = vec![lit(true).or(lit(false)).not_eq(col("a"))];
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(proj)?
            .build()?;

        let actual = get_optimized_plan_formatted(plan, &time);
        let expected =
            "Projection: NOT test.a AS Boolean(true) OR Boolean(false) != test.a\
            \n  TableScan: test";

        assert_eq!(expected, actual);
        Ok(())
    }

    #[test]
    fn simplify_not_binary() -> Result<()> {
        let table_scan = test_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("d").gt(lit(10)).not())?
            .build()?;
        let expected = "Filter: test.d <= Int32(10)\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_bool_and() -> Result<()> {
        let table_scan = test_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("d").gt(lit(10)).and(col("d").lt(lit(100))).not())?
            .build()?;
        let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100)\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_bool_or() -> Result<()> {
        let table_scan = test_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("d").gt(lit(10)).or(col("d").lt(lit(100))).not())?
            .build()?;
        let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100)\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_not() -> Result<()> {
        let table_scan = test_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("d").gt(lit(10)).not().not())?
            .build()?;
        let expected = "Filter: test.d > Int32(10)\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_null() -> Result<()> {
        let table_scan = test_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("e").is_null().not())?
            .build()?;
        let expected = "Filter: test.e IS NOT NULL\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_not_null() -> Result<()> {
        let table_scan = test_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("e").is_not_null().not())?
            .build()?;
        let expected = "Filter: test.e IS NULL\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_in() -> Result<()> {
        let table_scan = test_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], false).not())?
            .build()?;
        let expected =
            "Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3)\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_not_in() -> Result<()> {
        let table_scan = test_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], true).not())?
            .build()?;
        let expected =
            "Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3)\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_between() -> Result<()> {
        let table_scan = test_table_scan();
        let qual = col("d").between(lit(1), lit(10));

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(qual.not())?
            .build()?;
        let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10)\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_not_between() -> Result<()> {
        let table_scan = test_table_scan();
        let qual = col("d").not_between(lit(1), lit(10));

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(qual.not())?
            .build()?;
        let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10)\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_like() -> Result<()> {
        let table_scan = test_utf8_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("a").like(col("b")).not())?
            .build()?;
        let expected = "Filter: test.a NOT LIKE test.b\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_not_like() -> Result<()> {
        let table_scan = test_utf8_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("a").not_like(col("b")).not())?
            .build()?;
        let expected = "Filter: test.a LIKE test.b\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_ilike() -> Result<()> {
        let table_scan = test_utf8_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("a").ilike(col("b")).not())?
            .build()?;
        let expected = "Filter: test.a NOT ILIKE test.b\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_distinct_from() -> Result<()> {
        let table_scan = test_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(binary_expr(col("d"), Operator::IsDistinctFrom, lit(10)).not())?
            .build()?;
        let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10)\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_not_not_distinct_from() -> Result<()> {
        let table_scan = test_table_scan();

        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(binary_expr(col("d"), Operator::IsNotDistinctFrom, lit(10)).not())?
            .build()?;
        let expected = "Filter: test.d IS DISTINCT FROM Int32(10)\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_equijoin_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let left_key = col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, t1.schema())?;
        let right_key =
            col("t2.a") + lit(2i64).cast_to(&DataType::UInt32, t2.schema())?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Inner,
                (vec![left_key], vec![right_key]),
                None,
            )?
            .build()?;

        // before simplify: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32)
        // after simplify: t1.a + UInt32(1) = t2.a + UInt32(2) AS t1.a + Int64(1) = t2.a + Int64(2)
        let expected = "Inner Join: t1.a + UInt32(1) = t2.a + UInt32(2)\
            \n  TableScan: t1\
            \n  TableScan: t2";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_is_not_null() -> Result<()> {
        let table_scan = test_table_scan();

        // `d` is declared non-nullable, so the predicate folds to a constant.
        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("d").is_not_null())?
            .build()?;
        let expected = "Filter: Boolean(true)\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_is_null() -> Result<()> {
        let table_scan = test_table_scan();

        // `d` is declared non-nullable, so the predicate folds to a constant.
        let plan = LogicalPlanBuilder::from(table_scan)
            .filter(col("d").is_null())?
            .build()?;
        let expected = "Filter: Boolean(false)\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn simplify_grouping_sets() -> Result<()> {
        let table_scan = test_table_scan();
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                [grouping_set(vec![
                    vec![(lit(42).alias("prev") + lit(1)).alias("age"), col("a")],
                    vec![col("a").or(col("b")).and(lit(1).lt(lit(0))).alias("cond")],
                    vec![col("d").alias("e"), (lit(1) + lit(2))],
                ])],
                [] as [Expr; 0],
            )?
            .build()?;

        // Aliases inside the grouping sets survive simplification.
        let expected = "Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]]\
            \n  TableScan: test";

        assert_optimized_plan_eq(plan, expected)
    }
}