1use std::sync::Arc;
21
22use datafusion_common::tree_node::{Transformed, TreeNode};
23use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
24use datafusion_expr::execution_props::ExecutionProps;
25use datafusion_expr::logical_plan::LogicalPlan;
26use datafusion_expr::simplify::SimplifyContext;
27use datafusion_expr::utils::merge_schema;
28use datafusion_expr::Expr;
29
30use crate::optimizer::ApplyOrder;
31use crate::utils::NamePreserver;
32use crate::{OptimizerConfig, OptimizerRule};
33
34use super::ExprSimplifier;
35
/// Optimizer rule that simplifies the expressions held by a `LogicalPlan`
/// node by running them through an `ExprSimplifier`.
///
/// The rule carries no state of its own.
#[derive(Debug, Default)]
pub struct SimplifyExpressions {}
51
52impl OptimizerRule for SimplifyExpressions {
53 fn name(&self) -> &str {
54 "simplify_expressions"
55 }
56
57 fn apply_order(&self) -> Option<ApplyOrder> {
58 Some(ApplyOrder::BottomUp)
59 }
60
61 fn supports_rewrite(&self) -> bool {
62 true
63 }
64
65 fn rewrite(
68 &self,
69 plan: LogicalPlan,
70 config: &dyn OptimizerConfig,
71 ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
72 let mut execution_props = ExecutionProps::new();
73 execution_props.query_execution_start_time = config.query_execution_start_time();
74 Self::optimize_internal(plan, &execution_props)
75 }
76}
77
78impl SimplifyExpressions {
79 fn optimize_internal(
80 plan: LogicalPlan,
81 execution_props: &ExecutionProps,
82 ) -> Result<Transformed<LogicalPlan>> {
83 let schema = if !plan.inputs().is_empty() {
84 DFSchemaRef::new(merge_schema(&plan.inputs()))
85 } else if let LogicalPlan::TableScan(scan) = &plan {
86 Arc::new(DFSchema::try_from_qualified_schema(
97 scan.table_name.clone(),
98 &scan.source.schema(),
99 )?)
100 } else {
101 Arc::new(DFSchema::empty())
102 };
103
104 let info = SimplifyContext::new(execution_props).with_schema(schema);
105
106 let simplifier = ExprSimplifier::new(info);
110
111 let simplifier = if let LogicalPlan::Join(_) = plan {
119 simplifier.with_canonicalize(false)
120 } else {
121 simplifier
122 };
123
124 let name_preserver = NamePreserver::new(&plan);
126 let mut rewrite_expr = |expr: Expr| {
127 let name = name_preserver.save(&expr);
128 let expr = simplifier.simplify(expr)?;
129 Ok(Transformed::yes(name.restore(expr)))
132 };
133
134 plan.map_expressions(|expr| {
135 if let Expr::GroupingSet(_) = &expr {
137 expr.map_children(&mut rewrite_expr)
138 } else {
139 rewrite_expr(expr)
140 }
141 })
142 }
143}
144
145impl SimplifyExpressions {
146 #[allow(missing_docs)]
147 pub fn new() -> Self {
148 Self {}
149 }
150}
151
152#[cfg(test)]
153mod tests {
154 use std::ops::Not;
155
156 use arrow::datatypes::{DataType, Field, Schema};
157 use chrono::{DateTime, Utc};
158
159 use crate::optimizer::Optimizer;
160 use datafusion_expr::logical_plan::builder::table_scan_with_filters;
161 use datafusion_expr::logical_plan::table_scan;
162 use datafusion_expr::*;
163 use datafusion_functions_aggregate::expr_fn::{max, min};
164
165 use crate::test::{assert_fields_eq, test_table_scan_with_name};
166 use crate::OptimizerContext;
167
168 use super::*;
169
170 fn test_table_scan() -> LogicalPlan {
171 let schema = Schema::new(vec![
172 Field::new("a", DataType::Boolean, false),
173 Field::new("b", DataType::Boolean, false),
174 Field::new("c", DataType::Boolean, false),
175 Field::new("d", DataType::UInt32, false),
176 Field::new("e", DataType::UInt32, true),
177 ]);
178 table_scan(Some("test"), &schema, None)
179 .expect("creating scan")
180 .build()
181 .expect("building plan")
182 }
183
184 fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
185 fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
187 let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]);
188 let optimized_plan =
189 optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
190 let formatted_plan = format!("{optimized_plan}");
191 assert_eq!(formatted_plan, expected);
192 Ok(())
193 }
194
195 #[test]
196 fn test_simplify_table_full_filter_in_scan() -> Result<()> {
197 let fields = vec![
198 Field::new("a", DataType::UInt32, false),
199 Field::new("b", DataType::UInt32, false),
200 Field::new("c", DataType::UInt32, false),
201 ];
202
203 let schema = Schema::new(fields);
204
205 let table_scan = table_scan_with_filters(
206 Some("test"),
207 &schema,
208 Some(vec![0]),
209 vec![col("b").is_not_null()],
210 )?
211 .build()?;
212 assert_eq!(1, table_scan.schema().fields().len());
213 assert_fields_eq(&table_scan, vec!["a"]);
214
215 let expected = "TableScan: test projection=[a], full_filters=[Boolean(true)]";
216
217 assert_optimized_plan_eq(table_scan, expected)
218 }
219
220 #[test]
221 fn test_simplify_filter_pushdown() -> Result<()> {
222 let table_scan = test_table_scan();
223 let plan = LogicalPlanBuilder::from(table_scan)
224 .project(vec![col("a")])?
225 .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))?
226 .build()?;
227
228 assert_optimized_plan_eq(
229 plan,
230 "\
231 Filter: test.b > Int32(1)\
232 \n Projection: test.a\
233 \n TableScan: test",
234 )
235 }
236
237 #[test]
238 fn test_simplify_optimized_plan() -> Result<()> {
239 let table_scan = test_table_scan();
240 let plan = LogicalPlanBuilder::from(table_scan)
241 .project(vec![col("a")])?
242 .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))?
243 .build()?;
244
245 assert_optimized_plan_eq(
246 plan,
247 "\
248 Filter: test.b > Int32(1)\
249 \n Projection: test.a\
250 \n TableScan: test",
251 )
252 }
253
254 #[test]
255 fn test_simplify_optimized_plan_with_or() -> Result<()> {
256 let table_scan = test_table_scan();
257 let plan = LogicalPlanBuilder::from(table_scan)
258 .project(vec![col("a")])?
259 .filter(or(col("b").gt(lit(1)), col("b").gt(lit(1))))?
260 .build()?;
261
262 assert_optimized_plan_eq(
263 plan,
264 "\
265 Filter: test.b > Int32(1)\
266 \n Projection: test.a\
267 \n TableScan: test",
268 )
269 }
270
271 #[test]
272 fn test_simplify_optimized_plan_with_composed_and() -> Result<()> {
273 let table_scan = test_table_scan();
274 let plan = LogicalPlanBuilder::from(table_scan)
276 .project(vec![col("a"), col("b")])?
277 .filter(and(
278 and(col("a").gt(lit(5)), col("b").lt(lit(6))),
279 col("a").gt(lit(5)),
280 ))?
281 .build()?;
282
283 assert_optimized_plan_eq(
284 plan,
285 "\
286 Filter: test.a > Int32(5) AND test.b < Int32(6)\
287 \n Projection: test.a, test.b\
288 \n TableScan: test",
289 )
290 }
291
292 #[test]
293 fn test_simplify_optimized_plan_eq_expr() -> Result<()> {
294 let table_scan = test_table_scan();
295 let plan = LogicalPlanBuilder::from(table_scan)
296 .filter(col("b").eq(lit(true)))?
297 .filter(col("c").eq(lit(false)))?
298 .project(vec![col("a")])?
299 .build()?;
300
301 let expected = "\
302 Projection: test.a\
303 \n Filter: NOT test.c\
304 \n Filter: test.b\
305 \n TableScan: test";
306
307 assert_optimized_plan_eq(plan, expected)
308 }
309
310 #[test]
311 fn test_simplify_optimized_plan_not_eq_expr() -> Result<()> {
312 let table_scan = test_table_scan();
313 let plan = LogicalPlanBuilder::from(table_scan)
314 .filter(col("b").not_eq(lit(true)))?
315 .filter(col("c").not_eq(lit(false)))?
316 .limit(0, Some(1))?
317 .project(vec![col("a")])?
318 .build()?;
319
320 let expected = "\
321 Projection: test.a\
322 \n Limit: skip=0, fetch=1\
323 \n Filter: test.c\
324 \n Filter: NOT test.b\
325 \n TableScan: test";
326
327 assert_optimized_plan_eq(plan, expected)
328 }
329
330 #[test]
331 fn test_simplify_optimized_plan_and_expr() -> Result<()> {
332 let table_scan = test_table_scan();
333 let plan = LogicalPlanBuilder::from(table_scan)
334 .filter(col("b").not_eq(lit(true)).and(col("c").eq(lit(true))))?
335 .project(vec![col("a")])?
336 .build()?;
337
338 let expected = "\
339 Projection: test.a\
340 \n Filter: NOT test.b AND test.c\
341 \n TableScan: test";
342
343 assert_optimized_plan_eq(plan, expected)
344 }
345
346 #[test]
347 fn test_simplify_optimized_plan_or_expr() -> Result<()> {
348 let table_scan = test_table_scan();
349 let plan = LogicalPlanBuilder::from(table_scan)
350 .filter(col("b").not_eq(lit(true)).or(col("c").eq(lit(false))))?
351 .project(vec![col("a")])?
352 .build()?;
353
354 let expected = "\
355 Projection: test.a\
356 \n Filter: NOT test.b OR NOT test.c\
357 \n TableScan: test";
358
359 assert_optimized_plan_eq(plan, expected)
360 }
361
362 #[test]
363 fn test_simplify_optimized_plan_not_expr() -> Result<()> {
364 let table_scan = test_table_scan();
365 let plan = LogicalPlanBuilder::from(table_scan)
366 .filter(col("b").eq(lit(false)).not())?
367 .project(vec![col("a")])?
368 .build()?;
369
370 let expected = "\
371 Projection: test.a\
372 \n Filter: test.b\
373 \n TableScan: test";
374
375 assert_optimized_plan_eq(plan, expected)
376 }
377
378 #[test]
379 fn test_simplify_optimized_plan_support_projection() -> Result<()> {
380 let table_scan = test_table_scan();
381 let plan = LogicalPlanBuilder::from(table_scan)
382 .project(vec![col("a"), col("d"), col("b").eq(lit(false))])?
383 .build()?;
384
385 let expected = "\
386 Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false)\
387 \n TableScan: test";
388
389 assert_optimized_plan_eq(plan, expected)
390 }
391
392 #[test]
393 fn test_simplify_optimized_plan_support_aggregate() -> Result<()> {
394 let table_scan = test_table_scan();
395 let plan = LogicalPlanBuilder::from(table_scan)
396 .project(vec![col("a"), col("c"), col("b")])?
397 .aggregate(
398 vec![col("a"), col("c")],
399 vec![max(col("b").eq(lit(true))), min(col("b"))],
400 )?
401 .build()?;
402
403 let expected = "\
404 Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b) AS max(test.b = Boolean(true)), min(test.b)]]\
405 \n Projection: test.a, test.c, test.b\
406 \n TableScan: test";
407
408 assert_optimized_plan_eq(plan, expected)
409 }
410
411 #[test]
412 fn test_simplify_optimized_plan_support_values() -> Result<()> {
413 let expr1 = Expr::BinaryExpr(BinaryExpr::new(
414 Box::new(lit(1)),
415 Operator::Plus,
416 Box::new(lit(2)),
417 ));
418 let expr2 = Expr::BinaryExpr(BinaryExpr::new(
419 Box::new(lit(2)),
420 Operator::Minus,
421 Box::new(lit(1)),
422 ));
423 let values = vec![vec![expr1, expr2]];
424 let plan = LogicalPlanBuilder::values(values)?.build()?;
425
426 let expected = "\
427 Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))";
428
429 assert_optimized_plan_eq(plan, expected)
430 }
431
432 fn get_optimized_plan_formatted(
433 plan: LogicalPlan,
434 date_time: &DateTime<Utc>,
435 ) -> String {
436 let config = OptimizerContext::new().with_query_execution_start_time(*date_time);
437 let rule = SimplifyExpressions::new();
438
439 let optimized_plan = rule.rewrite(plan, &config).unwrap().data;
440 format!("{optimized_plan}")
441 }
442
443 #[test]
444 fn cast_expr() -> Result<()> {
445 let table_scan = test_table_scan();
446 let proj = vec![Expr::Cast(Cast::new(Box::new(lit("0")), DataType::Int32))];
447 let plan = LogicalPlanBuilder::from(table_scan)
448 .project(proj)?
449 .build()?;
450
451 let expected = "Projection: Int32(0) AS Utf8(\"0\")\
452 \n TableScan: test";
453 let actual = get_optimized_plan_formatted(plan, &Utc::now());
454 assert_eq!(expected, actual);
455 Ok(())
456 }
457
458 #[test]
459 fn simplify_and_eval() -> Result<()> {
460 let table_scan = test_table_scan();
463 let time = Utc::now();
464 let proj = vec![lit(true).or(lit(false)).not_eq(col("a"))];
466 let plan = LogicalPlanBuilder::from(table_scan)
467 .project(proj)?
468 .build()?;
469
470 let actual = get_optimized_plan_formatted(plan, &time);
471 let expected =
472 "Projection: NOT test.a AS Boolean(true) OR Boolean(false) != test.a\
473 \n TableScan: test";
474
475 assert_eq!(expected, actual);
476 Ok(())
477 }
478
479 #[test]
480 fn simplify_not_binary() -> Result<()> {
481 let table_scan = test_table_scan();
482
483 let plan = LogicalPlanBuilder::from(table_scan)
484 .filter(col("d").gt(lit(10)).not())?
485 .build()?;
486 let expected = "Filter: test.d <= Int32(10)\
487 \n TableScan: test";
488
489 assert_optimized_plan_eq(plan, expected)
490 }
491
492 #[test]
493 fn simplify_not_bool_and() -> Result<()> {
494 let table_scan = test_table_scan();
495
496 let plan = LogicalPlanBuilder::from(table_scan)
497 .filter(col("d").gt(lit(10)).and(col("d").lt(lit(100))).not())?
498 .build()?;
499 let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100)\
500 \n TableScan: test";
501
502 assert_optimized_plan_eq(plan, expected)
503 }
504
505 #[test]
506 fn simplify_not_bool_or() -> Result<()> {
507 let table_scan = test_table_scan();
508
509 let plan = LogicalPlanBuilder::from(table_scan)
510 .filter(col("d").gt(lit(10)).or(col("d").lt(lit(100))).not())?
511 .build()?;
512 let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100)\
513 \n TableScan: test";
514
515 assert_optimized_plan_eq(plan, expected)
516 }
517
518 #[test]
519 fn simplify_not_not() -> Result<()> {
520 let table_scan = test_table_scan();
521
522 let plan = LogicalPlanBuilder::from(table_scan)
523 .filter(col("d").gt(lit(10)).not().not())?
524 .build()?;
525 let expected = "Filter: test.d > Int32(10)\
526 \n TableScan: test";
527
528 assert_optimized_plan_eq(plan, expected)
529 }
530
531 #[test]
532 fn simplify_not_null() -> Result<()> {
533 let table_scan = test_table_scan();
534
535 let plan = LogicalPlanBuilder::from(table_scan)
536 .filter(col("e").is_null().not())?
537 .build()?;
538 let expected = "Filter: test.e IS NOT NULL\
539 \n TableScan: test";
540
541 assert_optimized_plan_eq(plan, expected)
542 }
543
544 #[test]
545 fn simplify_not_not_null() -> Result<()> {
546 let table_scan = test_table_scan();
547
548 let plan = LogicalPlanBuilder::from(table_scan)
549 .filter(col("e").is_not_null().not())?
550 .build()?;
551 let expected = "Filter: test.e IS NULL\
552 \n TableScan: test";
553
554 assert_optimized_plan_eq(plan, expected)
555 }
556
557 #[test]
558 fn simplify_not_in() -> Result<()> {
559 let table_scan = test_table_scan();
560
561 let plan = LogicalPlanBuilder::from(table_scan)
562 .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], false).not())?
563 .build()?;
564 let expected =
565 "Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3)\
566 \n TableScan: test";
567
568 assert_optimized_plan_eq(plan, expected)
569 }
570
571 #[test]
572 fn simplify_not_not_in() -> Result<()> {
573 let table_scan = test_table_scan();
574
575 let plan = LogicalPlanBuilder::from(table_scan)
576 .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], true).not())?
577 .build()?;
578 let expected =
579 "Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3)\
580 \n TableScan: test";
581
582 assert_optimized_plan_eq(plan, expected)
583 }
584
585 #[test]
586 fn simplify_not_between() -> Result<()> {
587 let table_scan = test_table_scan();
588 let qual = col("d").between(lit(1), lit(10));
589
590 let plan = LogicalPlanBuilder::from(table_scan)
591 .filter(qual.not())?
592 .build()?;
593 let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10)\
594 \n TableScan: test";
595
596 assert_optimized_plan_eq(plan, expected)
597 }
598
599 #[test]
600 fn simplify_not_not_between() -> Result<()> {
601 let table_scan = test_table_scan();
602 let qual = col("d").not_between(lit(1), lit(10));
603
604 let plan = LogicalPlanBuilder::from(table_scan)
605 .filter(qual.not())?
606 .build()?;
607 let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10)\
608 \n TableScan: test";
609
610 assert_optimized_plan_eq(plan, expected)
611 }
612
613 #[test]
614 fn simplify_not_like() -> Result<()> {
615 let schema = Schema::new(vec![
616 Field::new("a", DataType::Utf8, false),
617 Field::new("b", DataType::Utf8, false),
618 ]);
619 let table_scan = table_scan(Some("test"), &schema, None)
620 .expect("creating scan")
621 .build()
622 .expect("building plan");
623
624 let plan = LogicalPlanBuilder::from(table_scan)
625 .filter(col("a").like(col("b")).not())?
626 .build()?;
627 let expected = "Filter: test.a NOT LIKE test.b\
628 \n TableScan: test";
629
630 assert_optimized_plan_eq(plan, expected)
631 }
632
633 #[test]
634 fn simplify_not_not_like() -> Result<()> {
635 let schema = Schema::new(vec![
636 Field::new("a", DataType::Utf8, false),
637 Field::new("b", DataType::Utf8, false),
638 ]);
639 let table_scan = table_scan(Some("test"), &schema, None)
640 .expect("creating scan")
641 .build()
642 .expect("building plan");
643
644 let plan = LogicalPlanBuilder::from(table_scan)
645 .filter(col("a").not_like(col("b")).not())?
646 .build()?;
647 let expected = "Filter: test.a LIKE test.b\
648 \n TableScan: test";
649
650 assert_optimized_plan_eq(plan, expected)
651 }
652
653 #[test]
654 fn simplify_not_ilike() -> Result<()> {
655 let schema = Schema::new(vec![
656 Field::new("a", DataType::Utf8, false),
657 Field::new("b", DataType::Utf8, false),
658 ]);
659 let table_scan = table_scan(Some("test"), &schema, None)
660 .expect("creating scan")
661 .build()
662 .expect("building plan");
663
664 let plan = LogicalPlanBuilder::from(table_scan)
665 .filter(col("a").ilike(col("b")).not())?
666 .build()?;
667 let expected = "Filter: test.a NOT ILIKE test.b\
668 \n TableScan: test";
669
670 assert_optimized_plan_eq(plan, expected)
671 }
672
673 #[test]
674 fn simplify_not_distinct_from() -> Result<()> {
675 let table_scan = test_table_scan();
676
677 let plan = LogicalPlanBuilder::from(table_scan)
678 .filter(binary_expr(col("d"), Operator::IsDistinctFrom, lit(10)).not())?
679 .build()?;
680 let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10)\
681 \n TableScan: test";
682
683 assert_optimized_plan_eq(plan, expected)
684 }
685
686 #[test]
687 fn simplify_not_not_distinct_from() -> Result<()> {
688 let table_scan = test_table_scan();
689
690 let plan = LogicalPlanBuilder::from(table_scan)
691 .filter(binary_expr(col("d"), Operator::IsNotDistinctFrom, lit(10)).not())?
692 .build()?;
693 let expected = "Filter: test.d IS DISTINCT FROM Int32(10)\
694 \n TableScan: test";
695
696 assert_optimized_plan_eq(plan, expected)
697 }
698
699 #[test]
700 fn simplify_equijoin_predicate() -> Result<()> {
701 let t1 = test_table_scan_with_name("t1")?;
702 let t2 = test_table_scan_with_name("t2")?;
703
704 let left_key = col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, t1.schema())?;
705 let right_key =
706 col("t2.a") + lit(2i64).cast_to(&DataType::UInt32, t2.schema())?;
707 let plan = LogicalPlanBuilder::from(t1)
708 .join_with_expr_keys(
709 t2,
710 JoinType::Inner,
711 (vec![left_key], vec![right_key]),
712 None,
713 )?
714 .build()?;
715
716 let expected = "Inner Join: t1.a + UInt32(1) = t2.a + UInt32(2)\
719 \n TableScan: t1\
720 \n TableScan: t2";
721
722 assert_optimized_plan_eq(plan, expected)
723 }
724
725 #[test]
726 fn simplify_is_not_null() -> Result<()> {
727 let table_scan = test_table_scan();
728
729 let plan = LogicalPlanBuilder::from(table_scan)
730 .filter(col("d").is_not_null())?
731 .build()?;
732 let expected = "Filter: Boolean(true)\
733 \n TableScan: test";
734
735 assert_optimized_plan_eq(plan, expected)
736 }
737
738 #[test]
739 fn simplify_is_null() -> Result<()> {
740 let table_scan = test_table_scan();
741
742 let plan = LogicalPlanBuilder::from(table_scan)
743 .filter(col("d").is_null())?
744 .build()?;
745 let expected = "Filter: Boolean(false)\
746 \n TableScan: test";
747
748 assert_optimized_plan_eq(plan, expected)
749 }
750
751 #[test]
752 fn simplify_grouping_sets() -> Result<()> {
753 let table_scan = test_table_scan();
754 let plan = LogicalPlanBuilder::from(table_scan)
755 .aggregate(
756 [grouping_set(vec![
757 vec![(lit(42).alias("prev") + lit(1)).alias("age"), col("a")],
758 vec![col("a").or(col("b")).and(lit(1).lt(lit(0))).alias("cond")],
759 vec![col("d").alias("e"), (lit(1) + lit(2))],
760 ])],
761 [] as [Expr; 0],
762 )?
763 .build()?;
764
765 let expected = "Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]]\
766 \n TableScan: test";
767
768 assert_optimized_plan_eq(plan, expected)
769 }
770}