1use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSE, CSEController, FoundCommonNodes};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, qualified_name};
33use datafusion_expr::expr::{Alias, ScalarFunction};
34use datafusion_expr::logical_plan::{
35 Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{
38 BinaryExpr, Case, Expr, ExpressionPlacement, Operator, SortExpr, col,
39};
40
41const CSE_PREFIX: &str = "__common_expr";
42
/// Optimizer rule that performs common sub-expression elimination.
///
/// For `Projection`, `Sort`, `Filter`, `Window` and `Aggregate` nodes,
/// expressions that are found to be common are computed once under a
/// generated alias starting with `__common_expr` and the node's expressions
/// are rewritten to reference that alias.
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
71
72impl CommonSubexprEliminate {
73 pub fn new() -> Self {
74 Self {}
75 }
76
77 fn try_optimize_proj(
78 &self,
79 projection: Projection,
80 config: &dyn OptimizerConfig,
81 ) -> Result<Transformed<LogicalPlan>> {
82 let Projection {
83 expr,
84 input,
85 schema,
86 ..
87 } = projection;
88 let input = Arc::unwrap_or_clone(input);
89 self.try_unary_plan(expr, input, config)?
90 .map_data(|(new_expr, new_input)| {
91 Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
92 .map(LogicalPlan::Projection)
93 })
94 }
95
96 fn try_optimize_sort(
97 &self,
98 sort: Sort,
99 config: &dyn OptimizerConfig,
100 ) -> Result<Transformed<LogicalPlan>> {
101 let Sort { expr, input, fetch } = sort;
102 let input = Arc::unwrap_or_clone(input);
103 let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
104 .into_iter()
105 .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
106 .unzip();
107 let new_sort = self
108 .try_unary_plan(sort_expressions, input, config)?
109 .update_data(|(new_expr, new_input)| {
110 LogicalPlan::Sort(Sort {
111 expr: new_expr
112 .into_iter()
113 .zip(sort_params)
114 .map(|(expr, (asc, nulls_first))| SortExpr {
115 expr,
116 asc,
117 nulls_first,
118 })
119 .collect(),
120 input: Arc::new(new_input),
121 fetch,
122 })
123 });
124 Ok(new_sort)
125 }
126
127 fn try_optimize_filter(
128 &self,
129 filter: Filter,
130 config: &dyn OptimizerConfig,
131 ) -> Result<Transformed<LogicalPlan>> {
132 let Filter {
133 predicate, input, ..
134 } = filter;
135 let input = Arc::unwrap_or_clone(input);
136 let expr = vec![predicate];
137 self.try_unary_plan(expr, input, config)?
138 .map_data(|(mut new_expr, new_input)| {
139 assert_eq!(new_expr.len(), 1); let new_predicate = new_expr.pop().unwrap();
141 Filter::try_new(new_predicate, Arc::new(new_input))
142 .map(LogicalPlan::Filter)
143 })
144 }
145
146 fn try_optimize_window(
147 &self,
148 window: Window,
149 config: &dyn OptimizerConfig,
150 ) -> Result<Transformed<LogicalPlan>> {
151 let (window_expr_list, window_schemas, input) =
154 get_consecutive_window_exprs(window);
155
156 match CSE::new(ExprCSEController::new(
159 config.alias_generator().as_ref(),
160 ExprMask::Normal,
161 ))
162 .extract_common_nodes(window_expr_list)?
163 {
164 FoundCommonNodes::Yes {
168 common_nodes: common_exprs,
169 new_nodes_list: new_exprs_list,
170 original_nodes_list: original_exprs_list,
171 } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
172 Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
173 }),
174 FoundCommonNodes::No {
175 original_nodes_list: original_exprs_list,
176 } => Ok(Transformed::no((original_exprs_list, input, None))),
177 }?
178 .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
181 self.rewrite(new_input, config)?.map_data(|new_input| {
182 Ok((new_window_expr_list, new_input, window_expr_list))
183 })
184 })?
185 .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
187 if let Some(window_expr_list) = window_expr_list {
196 let name_preserver = NamePreserver::new_for_projection();
197 let saved_names = window_expr_list
198 .iter()
199 .map(|exprs| {
200 exprs
201 .iter()
202 .map(|expr| name_preserver.save(expr))
203 .collect::<Vec<_>>()
204 })
205 .collect::<Vec<_>>();
206 new_window_expr_list.into_iter().zip(saved_names).try_rfold(
207 new_input,
208 |plan, (new_window_expr, saved_names)| {
209 let new_window_expr = new_window_expr
210 .into_iter()
211 .zip(saved_names)
212 .map(|(new_window_expr, saved_name)| {
213 saved_name.restore(new_window_expr)
214 })
215 .collect::<Vec<_>>();
216 Window::try_new(new_window_expr, Arc::new(plan))
217 .map(LogicalPlan::Window)
218 },
219 )
220 } else {
221 new_window_expr_list
222 .into_iter()
223 .zip(window_schemas)
224 .try_rfold(new_input, |plan, (new_window_expr, schema)| {
225 Window::try_new_with_schema(
226 new_window_expr,
227 Arc::new(plan),
228 schema,
229 )
230 .map(LogicalPlan::Window)
231 })
232 }
233 })
234 }
235
236 fn try_optimize_aggregate(
237 &self,
238 aggregate: Aggregate,
239 config: &dyn OptimizerConfig,
240 ) -> Result<Transformed<LogicalPlan>> {
241 let Aggregate {
242 group_expr,
243 aggr_expr,
244 input,
245 schema,
246 ..
247 } = aggregate;
248 let input = Arc::unwrap_or_clone(input);
249 match CSE::new(ExprCSEController::new(
251 config.alias_generator().as_ref(),
252 ExprMask::Normal,
253 ))
254 .extract_common_nodes(vec![group_expr, aggr_expr])?
255 {
256 FoundCommonNodes::Yes {
260 common_nodes: common_exprs,
261 new_nodes_list: mut new_exprs_list,
262 original_nodes_list: mut original_exprs_list,
263 } => {
264 let new_aggr_expr = new_exprs_list.pop().unwrap();
265 let new_group_expr = new_exprs_list.pop().unwrap();
266
267 build_common_expr_project_plan(input, common_exprs).map(|new_input| {
268 let aggr_expr = original_exprs_list.pop().unwrap();
269 Transformed::yes((
270 new_aggr_expr,
271 new_group_expr,
272 new_input,
273 Some(aggr_expr),
274 ))
275 })
276 }
277
278 FoundCommonNodes::No {
279 original_nodes_list: mut original_exprs_list,
280 } => {
281 let new_aggr_expr = original_exprs_list.pop().unwrap();
282 let new_group_expr = original_exprs_list.pop().unwrap();
283
284 Ok(Transformed::no((
285 new_aggr_expr,
286 new_group_expr,
287 input,
288 None,
289 )))
290 }
291 }?
292 .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
295 self.rewrite(new_input, config)?.map_data(|new_input| {
296 Ok((
297 new_aggr_expr,
298 new_group_expr,
299 aggr_expr,
300 Arc::new(new_input),
301 ))
302 })
303 })?
304 .transform_data(
306 |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
307 match CSE::new(ExprCSEController::new(
309 config.alias_generator().as_ref(),
310 ExprMask::NormalAndAggregates,
311 ))
312 .extract_common_nodes(vec![new_aggr_expr])?
313 {
314 FoundCommonNodes::Yes {
315 common_nodes: common_exprs,
316 new_nodes_list: mut new_exprs_list,
317 original_nodes_list: mut original_exprs_list,
318 } => {
319 let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
320 let new_aggr_expr = original_exprs_list.pop().unwrap();
321 let saved_names = if let Some(aggr_expr) = aggr_expr {
322 let name_preserver = NamePreserver::new_for_projection();
323 aggr_expr
324 .iter()
325 .map(|expr| Some(name_preserver.save(expr)))
326 .collect::<Vec<_>>()
327 } else {
328 new_aggr_expr
329 .clone()
330 .into_iter()
331 .map(|_| None)
332 .collect::<Vec<_>>()
333 };
334
335 let mut agg_exprs = common_exprs
336 .into_iter()
337 .map(|(expr, expr_alias)| expr.alias(expr_alias))
338 .collect::<Vec<_>>();
339
340 let mut proj_exprs = vec![];
341 for expr in &new_group_expr {
342 extract_expressions(expr, &mut proj_exprs)
343 }
344 for ((expr_rewritten, expr_orig), saved_name) in
345 rewritten_aggr_expr
346 .into_iter()
347 .zip(new_aggr_expr)
348 .zip(saved_names)
349 {
350 if expr_rewritten == expr_orig {
351 let expr_rewritten = if let Some(saved_name) = saved_name
352 {
353 saved_name.restore(expr_rewritten)
354 } else {
355 expr_rewritten
356 };
357 if let Expr::Alias(Alias { expr, name, .. }) =
358 expr_rewritten
359 {
360 agg_exprs.push(expr.alias(&name));
361 proj_exprs
362 .push(Expr::Column(Column::from_name(name)));
363 } else {
364 let expr_alias =
365 config.alias_generator().next(CSE_PREFIX);
366 let (qualifier, field_name) =
367 expr_rewritten.qualified_name();
368 let out_name =
369 qualified_name(qualifier.as_ref(), &field_name);
370
371 agg_exprs.push(expr_rewritten.alias(&expr_alias));
372 proj_exprs.push(
373 Expr::Column(Column::from_name(expr_alias))
374 .alias(out_name),
375 );
376 }
377 } else {
378 proj_exprs.push(expr_rewritten);
379 }
380 }
381
382 let agg = LogicalPlan::Aggregate(Aggregate::try_new(
383 new_input,
384 new_group_expr,
385 agg_exprs,
386 )?);
387 Projection::try_new(proj_exprs, Arc::new(agg))
388 .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
389 }
390
391 FoundCommonNodes::No {
394 original_nodes_list: mut original_exprs_list,
395 } => {
396 let rewritten_aggr_expr = original_exprs_list.pop().unwrap();
397
398 if let Some(aggr_expr) = aggr_expr {
409 let name_preserver = NamePreserver::new_for_projection();
410 let saved_names = aggr_expr
411 .iter()
412 .map(|expr| name_preserver.save(expr))
413 .collect::<Vec<_>>();
414 let new_aggr_expr = rewritten_aggr_expr
415 .into_iter()
416 .zip(saved_names)
417 .map(|(new_expr, saved_name)| {
418 saved_name.restore(new_expr)
419 })
420 .collect::<Vec<Expr>>();
421
422 Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
425 .map(LogicalPlan::Aggregate)
426 .map(Transformed::no)
427 } else {
428 Aggregate::try_new_with_schema(
429 new_input,
430 new_group_expr,
431 rewritten_aggr_expr,
432 schema,
433 )
434 .map(LogicalPlan::Aggregate)
435 .map(Transformed::no)
436 }
437 }
438 }
439 },
440 )
441 }
442
443 fn try_unary_plan(
458 &self,
459 exprs: Vec<Expr>,
460 input: LogicalPlan,
461 config: &dyn OptimizerConfig,
462 ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
463 match CSE::new(ExprCSEController::new(
465 config.alias_generator().as_ref(),
466 ExprMask::Normal,
467 ))
468 .extract_common_nodes(vec![exprs])?
469 {
470 FoundCommonNodes::Yes {
471 common_nodes: common_exprs,
472 new_nodes_list: mut new_exprs_list,
473 original_nodes_list: _,
474 } => {
475 let new_exprs = new_exprs_list.pop().unwrap();
476 build_common_expr_project_plan(input, common_exprs)
477 .map(|new_input| Transformed::yes((new_exprs, new_input)))
478 }
479 FoundCommonNodes::No {
480 original_nodes_list: mut original_exprs_list,
481 } => {
482 let new_exprs = original_exprs_list.pop().unwrap();
483 Ok(Transformed::no((new_exprs, input)))
484 }
485 }?
486 .transform_data(|(new_exprs, new_input)| {
489 self.rewrite(new_input, config)?
490 .map_data(|new_input| Ok((new_exprs, new_input)))
491 })
492 }
493}
494
495fn get_consecutive_window_exprs(
527 window: Window,
528) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
529 let mut window_expr_list = vec![];
530 let mut window_schemas = vec![];
531 let mut plan = LogicalPlan::Window(window);
532 while let LogicalPlan::Window(Window {
533 input,
534 window_expr,
535 schema,
536 }) = plan
537 {
538 window_expr_list.push(window_expr);
539 window_schemas.push(schema);
540
541 plan = Arc::unwrap_or_clone(input);
542 }
543 (window_expr_list, window_schemas, plan)
544}
545
546impl OptimizerRule for CommonSubexprEliminate {
547 fn supports_rewrite(&self) -> bool {
548 true
549 }
550
551 fn apply_order(&self) -> Option<ApplyOrder> {
552 None
556 }
557
558 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
559 fn rewrite(
560 &self,
561 plan: LogicalPlan,
562 config: &dyn OptimizerConfig,
563 ) -> Result<Transformed<LogicalPlan>> {
564 let original_schema = Arc::clone(plan.schema());
565
566 let optimized_plan = match plan {
567 LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
568 LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
569 LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
570 LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
571 LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
572 LogicalPlan::Join(_)
573 | LogicalPlan::Repartition(_)
574 | LogicalPlan::Union(_)
575 | LogicalPlan::TableScan(_)
576 | LogicalPlan::Values(_)
577 | LogicalPlan::EmptyRelation(_)
578 | LogicalPlan::Subquery(_)
579 | LogicalPlan::SubqueryAlias(_)
580 | LogicalPlan::Limit(_)
581 | LogicalPlan::Ddl(_)
582 | LogicalPlan::Explain(_)
583 | LogicalPlan::Analyze(_)
584 | LogicalPlan::Statement(_)
585 | LogicalPlan::DescribeTable(_)
586 | LogicalPlan::Distinct(_)
587 | LogicalPlan::Extension(_)
588 | LogicalPlan::Dml(_)
589 | LogicalPlan::Copy(_)
590 | LogicalPlan::Unnest(_)
591 | LogicalPlan::RecursiveQuery(_) => {
592 plan.map_children(|c| self.rewrite(c, config))?
595 }
596 };
597
598 if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
600 {
601 optimized_plan.map_data(|optimized_plan| {
602 build_recover_project_plan(&original_schema, optimized_plan)
603 })
604 } else {
605 Ok(optimized_plan)
606 }
607 }
608
609 fn name(&self) -> &str {
610 "common_sub_expression_eliminate"
611 }
612}
613
/// Selects which kinds of expressions `ExprCSEController::is_ignored`
/// excludes from common sub-expression extraction.
#[derive(Debug, Clone, Copy)]
enum ExprMask {
    /// Ignores literals, columns, scalar variables, aliases, wildcards and
    /// aggregate functions.
    Normal,

    /// Ignores the same expressions as `Normal` except aggregate functions.
    NormalAndAggregates,
}
630
/// `CSEController` implementation for `Expr` trees.
struct ExprCSEController<'a> {
    /// Generator used to produce the aliases of extracted expressions.
    alias_generator: &'a AliasGenerator,
    /// Which kinds of expressions are ignored by the analysis.
    mask: ExprMask,

    /// Incremented when the rewrite enters an `Expr::Alias` and decremented
    /// when it leaves one; also incremented by the first aliased `rewrite`.
    /// While it is above zero, `rewrite` returns a plain column.
    alias_counter: usize,
}
638
639impl<'a> ExprCSEController<'a> {
640 fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
641 Self {
642 alias_generator,
643 mask,
644 alias_counter: 0,
645 }
646 }
647}
648
649impl CSEController for ExprCSEController<'_> {
650 type Node = Expr;
651
652 fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
653 match node {
654 Expr::ScalarFunction(ScalarFunction { func, args }) => {
658 func.conditional_arguments(args)
659 }
660
661 Expr::BinaryExpr(BinaryExpr {
664 left,
665 op: Operator::And | Operator::Or,
666 right,
667 }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
668
669 Expr::Case(Case {
673 expr,
674 when_then_expr,
675 else_expr,
676 }) => Some((
677 expr.iter()
678 .map(|e| e.as_ref())
679 .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
680 .collect(),
681 when_then_expr
682 .iter()
683 .take(1)
684 .map(|(_, then)| then.as_ref())
685 .chain(
686 when_then_expr
687 .iter()
688 .skip(1)
689 .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
690 )
691 .chain(else_expr.iter().map(|e| e.as_ref()))
692 .collect(),
693 )),
694 _ => None,
695 }
696 }
697
698 fn is_valid(node: &Expr) -> bool {
699 !node.is_volatile_node()
700 }
701
702 fn is_ignored(&self, node: &Expr) -> bool {
703 if node.placement() == ExpressionPlacement::MoveTowardsLeafNodes {
712 return true;
713 }
714
715 #[expect(deprecated)]
717 let is_normal_minus_aggregates = matches!(
718 node,
719 Expr::Literal(..)
725 | Expr::Column(..)
726 | Expr::ScalarVariable(..)
727 | Expr::Alias(..)
728 | Expr::Wildcard { .. }
729 );
730
731 let is_aggr = matches!(node, Expr::AggregateFunction(..));
732
733 match self.mask {
734 ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
735 ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
736 }
737 }
738
739 fn generate_alias(&self) -> String {
740 self.alias_generator.next(CSE_PREFIX)
741 }
742
743 fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
744 if self.alias_counter > 0 {
746 col(alias)
747 } else {
748 self.alias_counter += 1;
749 col(alias).alias(node.schema_name().to_string())
750 }
751 }
752
753 fn rewrite_f_down(&mut self, node: &Expr) {
754 if matches!(node, Expr::Alias(_)) {
755 self.alias_counter += 1;
756 }
757 }
758 fn rewrite_f_up(&mut self, node: &Expr) {
759 if matches!(node, Expr::Alias(_)) {
760 self.alias_counter -= 1
761 }
762 }
763}
764
765impl Default for CommonSubexprEliminate {
766 fn default() -> Self {
767 Self::new()
768 }
769}
770
771fn build_common_expr_project_plan(
782 input: LogicalPlan,
783 common_exprs: Vec<(Expr, String)>,
784) -> Result<LogicalPlan> {
785 let mut fields_set = BTreeSet::new();
786 let mut project_exprs = common_exprs
787 .into_iter()
788 .map(|(expr, expr_alias)| {
789 fields_set.insert(expr_alias.clone());
790 Ok(expr.alias(expr_alias))
791 })
792 .collect::<Result<Vec<_>>>()?;
793
794 for (qualifier, field) in input.schema().iter() {
795 if fields_set.insert(qualified_name(qualifier, field.name())) {
796 project_exprs.push(Expr::from((qualifier, field)));
797 }
798 }
799
800 Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
801}
802
803fn build_recover_project_plan(
809 schema: &DFSchema,
810 input: LogicalPlan,
811) -> Result<LogicalPlan> {
812 let col_exprs = schema.iter().map(Expr::from).collect();
813 Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
814}
815
816fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
817 if let Expr::GroupingSet(groupings) = expr {
818 for e in groupings.distinct_expr() {
819 let (qualifier, field_name) = e.qualified_name();
820 let col = Column::new(qualifier, field_name);
821 result.push(Expr::Column(col))
822 }
823 } else {
824 let (qualifier, field_name) = expr.qualified_name();
825 let col = Column::new(qualifier, field_name);
826 result.push(Expr::Column(col));
827 }
828}
829
830#[cfg(test)]
831mod test {
832 use std::any::Any;
833 use std::iter;
834
835 use arrow::datatypes::{DataType, Field, Schema};
836 use datafusion_expr::logical_plan::{JoinType, table_scan};
837 use datafusion_expr::{
838 AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarFunctionArgs,
839 ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility,
840 grouping_set, is_null, not,
841 };
842 use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
843
844 use super::*;
845 use crate::assert_optimized_plan_eq_snapshot;
846 use crate::optimizer::OptimizerContext;
847 use crate::test::udfs::leaf_udf_expr;
848 use crate::test::*;
849 use datafusion_expr::test::function_stub::{avg, sum};
850
/// Optimizes `$plan` with `CommonSubexprEliminate` as the only rule and
/// checks the result against the inline snapshot `$expected` by delegating
/// to `assert_optimized_plan_eq_snapshot!`.
///
/// The first arm takes the optimizer config explicitly; the second arm
/// creates a fresh `OptimizerContext`.
macro_rules! assert_optimized_plan_equal {
    // Arm with an explicit optimizer config.
    (
        $config:expr,
        $plan:expr,
        @ $expected:literal $(,)?
    ) => {{
        let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
        assert_optimized_plan_eq_snapshot!(
            $config,
            rules,
            $plan,
            @ $expected,
        )
    }};

    // Arm that uses a new `OptimizerContext`.
    (
        $plan:expr,
        @ $expected:literal $(,)?
    ) => {{
        let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
        let optimizer_ctx = OptimizerContext::new();
        assert_optimized_plan_eq_snapshot!(
            optimizer_ctx,
            rules,
            $plan,
            @ $expected,
        )
    }};
}
880
/// Two `sum` aggregates share the argument sub-expression `a * (1 - b)`;
/// the expected plan computes it once in a projection below the aggregate.
#[test]
fn tpch_q1_simplified() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(
            iter::empty::<Expr>(),
            vec![
                sum(col("a") * (lit(1) - col("b"))),
                sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
 Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )
}
912
/// The common expression `a + b` occurs inside an aliased sub-expression,
/// next to it, and as a projection expression of its own.
#[test]
fn nested_aliases() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
            col("a") + col("b"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b
 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )
}
933
/// Exercises aggregate handling with a built-in aggregate (`avg`) and a
/// user-defined aggregate (`my_agg`) in five scenarios.
#[test]
fn aggregate() -> Result<()> {
    let table_scan = test_table_scan()?;

    // User-defined aggregate `my_agg`; its accumulator is never invoked.
    let return_type = DataType::UInt32;
    let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
    let udf_agg = |inner: Expr| {
        Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
            Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
                "my_agg",
                Signature::exact(vec![DataType::UInt32], Volatility::Stable),
                return_type.clone(),
                Arc::clone(&accumulator),
                vec![Field::new("value", DataType::UInt32, true).into()],
            ))),
            vec![inner],
            false,
            None,
            vec![],
            None,
        ))
    };

    // Scenario 1: identical aggregate functions under different aliases.
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .aggregate(
            iter::empty::<Expr>(),
            vec![
                avg(col("a")).alias("col1"),
                avg(col("a")).alias("col2"),
                avg(col("b")).alias("col3"),
                avg(col("c")),
                udf_agg(col("a")).alias("col4"),
                udf_agg(col("a")).alias("col5"),
                udf_agg(col("b")).alias("col6"),
                udf_agg(col("c")),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)
 Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]
 TableScan: test
 "
    )?;

    // Scenario 2: the same aggregate function used inside different
    // surrounding expressions.
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .aggregate(
            iter::empty::<Expr>(),
            vec![
                lit(1) + avg(col("a")),
                lit(1) - avg(col("a")),
                lit(1) + udf_agg(col("a")),
                lit(1) - udf_agg(col("a")),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)
 Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]
 TableScan: test
 "
    )?;

    // Scenario 3: different aggregate functions over the same argument.
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .aggregate(
            iter::empty::<Expr>(),
            vec![
                avg(lit(1u32) + col("a")).alias("col1"),
                udf_agg(lit(1u32) + col("a")).alias("col2"),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )?;

    // Scenario 4: an expression shared by the group-by list and the
    // aggregate arguments.
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .aggregate(
            vec![lit(1u32) + col("a")],
            vec![
                avg(lit(1u32) + col("a")).alias("col1"),
                udf_agg(lit(1u32) + col("a")).alias("col2"),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )?;

    // Scenario 5: all of the above combined.
    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(
            vec![lit(1u32) + col("a")],
            vec![
                (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
                (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
                avg(lit(1u32) + col("a")),
                (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
                (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
                udf_agg(lit(1u32) + col("a")),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)
 Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]
 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )
}
1074
/// Aggregate rewrite where both the table name (`table.test`) and the
/// column name (`col.a`) contain dots.
#[test]
fn aggregate_with_relations_and_dots() -> Result<()> {
    let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
    let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;

    let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));

    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(
            vec![col_a.clone()],
            vec![
                (lit(1u32) + avg(lit(1u32) + col_a.clone())),
                avg(lit(1u32) + col_a),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)
 Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]
 Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a
 TableScan: table.test
 "
    )
}
1102
/// The same expression `1 + a` is projected twice under different aliases.
#[test]
fn subexpr_in_same_order() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            (lit(1) + col("a")).alias("first"),
            (lit(1) + col("a")).alias("second"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: __common_expr_1 AS first, __common_expr_1 AS second
 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )
}
1123
/// `1 + a` and `a + 1` are expected to be recognised as the same common
/// expression.
#[test]
fn subexpr_in_different_order() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)
 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )
}
1141
/// The same expression appears in two stacked projections, once per node;
/// the expected plan is unchanged.
#[test]
fn cross_plans_subexpr() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![lit(1) + col("a"), col("a")])?
        .project(vec![lit(1) + col("a")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: Int32(1) + test.a
 Projection: Int32(1) + test.a, test.a
 TableScan: test
 "
    )
}
1160
/// Stacking two common-expression projections on a table scan must not
/// produce duplicate field names in the resulting schema.
#[test]
fn redundant_project_fields() {
    let scan = test_table_scan().unwrap();

    let first_round = vec![
        (col("c") + col("a"), format!("{CSE_PREFIX}_1")),
        (col("b") + col("a"), format!("{CSE_PREFIX}_2")),
    ];
    let second_round = vec![
        (col(format!("{CSE_PREFIX}_1")), format!("{CSE_PREFIX}_3")),
        (col(format!("{CSE_PREFIX}_2")), format!("{CSE_PREFIX}_4")),
    ];

    let inner = build_common_expr_project_plan(scan, first_round).unwrap();
    let outer = build_common_expr_project_plan(inner, second_round).unwrap();

    // `insert` returns `false` for a name that was already seen.
    let mut seen = BTreeSet::new();
    for name in outer.schema().field_names() {
        assert!(seen.insert(name));
    }
}
1184
/// Same check as `redundant_project_fields`, but on top of an inner join of
/// two tables that have identically named columns.
#[test]
fn redundant_project_fields_join_input() {
    let left = test_table_scan_with_name("test1").unwrap();
    let right = test_table_scan_with_name("test2").unwrap();
    let join = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, (vec!["a"], vec!["a"]), None)
        .unwrap()
        .build()
        .unwrap();

    let first_round = vec![
        (col("test1.c") + col("test1.a"), format!("{CSE_PREFIX}_1")),
        (col("test1.b") + col("test1.a"), format!("{CSE_PREFIX}_2")),
    ];
    let second_round = vec![
        (col(format!("{CSE_PREFIX}_1")), format!("{CSE_PREFIX}_3")),
        (col(format!("{CSE_PREFIX}_2")), format!("{CSE_PREFIX}_4")),
    ];

    let inner = build_common_expr_project_plan(join, first_round).unwrap();
    let outer = build_common_expr_project_plan(inner, second_round).unwrap();

    // `insert` returns `false` for a name that was already seen.
    let mut seen = BTreeSet::new();
    for name in outer.schema().field_names() {
        assert!(seen.insert(name));
    }
}
1214
/// A filter whose predicate contains `CAST(a AS Int64)` twice must be
/// transformed, and the optimized plan must still expose the original
/// `UInt64` columns `a`, `b` and `c`.
#[test]
fn eliminated_subexpr_datatype() {
    use datafusion_expr::cast;

    let schema = Schema::new(vec![
        Field::new("a", DataType::UInt64, false),
        Field::new("b", DataType::UInt64, false),
        Field::new("c", DataType::UInt64, false),
    ]);

    let plan = table_scan(Some("table"), &schema, None)
        .unwrap()
        .filter(
            cast(col("a"), DataType::Int64)
                .lt(lit(1_i64))
                .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
        )
        .unwrap()
        .build()
        .unwrap();
    let rule = CommonSubexprEliminate::new();
    let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
    assert!(optimized_plan.transformed);
    let optimized_plan = optimized_plan.data;

    // Compare the pretty-printed (name, data type) pairs of the output schema.
    let schema = optimized_plan.schema();
    let fields_with_datatypes: Vec<_> = schema
        .fields()
        .iter()
        .map(|field| (field.name(), field.data_type()))
        .collect();
    let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
    let expected = r#"[
 (
 "a",
 UInt64,
 ),
 (
 "b",
 UInt64,
 ),
 (
 "c",
 UInt64,
 ),
]"#;
    assert_eq!(expected, formatted_fields_with_datatype);
}
1263
/// Extracting a common expression below a filter adds a column to the
/// filter's input; the expected plan has a projection on top that selects
/// the original columns again.
#[test]
fn filter_schema_changed() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: test.a, test.b, test.c
 Filter: __common_expr_1 - Int32(10) > __common_expr_1
 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )
}
1282
/// A grouping set over `(a, b)` and `(c)` has three distinct expressions,
/// so three columns are expected.
#[test]
fn test_extract_expressions_from_grouping_set() -> Result<()> {
    let mut result = Vec::with_capacity(3);
    let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
    extract_expressions(&grouping, &mut result);

    // `assert_eq!` reports the actual length when the check fails.
    assert_eq!(result.len(), 3);
    Ok(())
}
1292
/// A grouping set over `(a, b)` and `(a)` repeats `a`; only the two
/// distinct expressions are expected to produce columns.
#[test]
fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
    let mut result = Vec::with_capacity(2);
    let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
    extract_expressions(&grouping, &mut result);

    // `assert_eq!` reports the actual length when the check fails.
    assert_eq!(result.len(), 2);
    Ok(())
}
1301
/// The input plan already contains a column named like a generated alias.
/// Because the alias is taken from the same generator the rule uses, the
/// expected plans continue with the next free number.
#[test]
fn test_alias_collision() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Case 1: the first generated alias is already used by the plan.
    let config = OptimizerContext::new();
    let common_expr_1 = config.alias_generator().next(CSE_PREFIX);
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .project(vec![
            (col("a") + col("b")).alias(common_expr_1.clone()),
            col("c"),
        ])?
        .project(vec![
            col(common_expr_1.clone()).alias("c1"),
            col(common_expr_1).alias("c2"),
            (col("c") + lit(2)).alias("c3"),
            (col("c") + lit(2)).alias("c4"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        config,
        plan,
        @ r"
 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4
 Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c
 Projection: test.a + test.b AS __common_expr_1, test.c
 TableScan: test
 "
    )?;

    // Case 2: the first alias is generated but unused; the plan uses the
    // second one.
    let config = OptimizerContext::new();
    let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
    let common_expr_2 = config.alias_generator().next(CSE_PREFIX);
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            (col("a") + col("b")).alias(common_expr_2.clone()),
            col("c"),
        ])?
        .project(vec![
            col(common_expr_2.clone()).alias("c1"),
            col(common_expr_2).alias("c2"),
            (col("c") + lit(2)).alias("c3"),
            (col("c") + lit(2)).alias("c4"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        config,
        plan,
        @ r"
 Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4
 Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c
 Projection: test.a + test.b AS __common_expr_2, test.c
 TableScan: test
 "
    )?;

    Ok(())
}
1361
1362 #[test]
1363 fn test_extract_expressions_from_col() -> Result<()> {
1364 let mut result = Vec::with_capacity(1);
1365 extract_expressions(&col("a"), &mut result);
1366 assert!(result.len() == 1);
1367 Ok(())
1368 }
1369
1370 #[test]
1371 fn test_short_circuits() -> Result<()> {
1372 let table_scan = test_table_scan()?;
1373
1374 let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
1375 let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
1376 let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
1377 let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
1378 let plan = LogicalPlanBuilder::from(table_scan)
1379 .project(vec![
1380 extracted_short_circuit.clone().alias("c1"),
1381 extracted_short_circuit.alias("c2"),
1382 extracted_short_circuit_leg_1
1383 .clone()
1384 .or(not_extracted_short_circuit_leg_2.clone())
1385 .alias("c3"),
1386 extracted_short_circuit_leg_1
1387 .and(not_extracted_short_circuit_leg_2)
1388 .alias("c4"),
1389 extracted_short_circuit_leg_3
1390 .clone()
1391 .or(extracted_short_circuit_leg_3)
1392 .alias("c5"),
1393 ])?
1394 .build()?;
1395
1396 assert_optimized_plan_equal!(
1397 plan,
1398 @ r"
1399 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5
1400 Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c
1401 TableScan: test
1402 "
1403 )
1404 }
1405
1406 #[test]
1407 fn test_volatile() -> Result<()> {
1408 let table_scan = test_table_scan()?;
1409
1410 let extracted_child = col("a") + col("b");
1411 let rand = rand_func().call(vec![]);
1412 let not_extracted_volatile = extracted_child + rand;
1413 let plan = LogicalPlanBuilder::from(table_scan)
1414 .project(vec![
1415 not_extracted_volatile.clone().alias("c1"),
1416 not_extracted_volatile.alias("c2"),
1417 ])?
1418 .build()?;
1419
1420 assert_optimized_plan_equal!(
1421 plan,
1422 @ r"
1423 Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2
1424 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1425 TableScan: test
1426 "
1427 )
1428 }
1429
1430 #[test]
1431 fn test_volatile_short_circuits() -> Result<()> {
1432 let table_scan = test_table_scan()?;
1433
1434 let rand = rand_func().call(vec![]);
1435 let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1436 let not_extracted_volatile_short_circuit_1 =
1437 extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1438 let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1439 let not_extracted_volatile_short_circuit_2 =
1440 rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1441 let plan = LogicalPlanBuilder::from(table_scan)
1442 .project(vec![
1443 not_extracted_volatile_short_circuit_1.clone().alias("c1"),
1444 not_extracted_volatile_short_circuit_1.alias("c2"),
1445 not_extracted_volatile_short_circuit_2.clone().alias("c3"),
1446 not_extracted_volatile_short_circuit_2.alias("c4"),
1447 ])?
1448 .build()?;
1449
1450 assert_optimized_plan_equal!(
1451 plan,
1452 @ r"
1453 Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4
1454 Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c
1455 TableScan: test
1456 "
1457 )
1458 }
1459
1460 #[test]
1461 fn test_non_top_level_common_expression() -> Result<()> {
1462 let table_scan = test_table_scan()?;
1463
1464 let common_expr = col("a") + col("b");
1465 let plan = LogicalPlanBuilder::from(table_scan)
1466 .project(vec![
1467 common_expr.clone().alias("c1"),
1468 common_expr.alias("c2"),
1469 ])?
1470 .project(vec![col("c1"), col("c2")])?
1471 .build()?;
1472
1473 assert_optimized_plan_equal!(
1474 plan,
1475 @ r"
1476 Projection: c1, c2
1477 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1478 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1479 TableScan: test
1480 "
1481 )
1482 }
1483
1484 #[test]
1485 fn test_nested_common_expression() -> Result<()> {
1486 let table_scan = test_table_scan()?;
1487
1488 let nested_common_expr = col("a") + col("b");
1489 let common_expr = nested_common_expr.clone() * nested_common_expr;
1490 let plan = LogicalPlanBuilder::from(table_scan)
1491 .project(vec![
1492 common_expr.clone().alias("c1"),
1493 common_expr.alias("c2"),
1494 ])?
1495 .build()?;
1496
1497 assert_optimized_plan_equal!(
1498 plan,
1499 @ r"
1500 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1501 Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c
1502 Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c
1503 TableScan: test
1504 "
1505 )
1506 }
1507
1508 #[test]
1509 fn test_normalize_add_expression() -> Result<()> {
1510 let table_scan = test_table_scan()?;
1512 let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30));
1513 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1514
1515 assert_optimized_plan_equal!(
1516 plan,
1517 @ r"
1518 Projection: test.a, test.b, test.c
1519 Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1520 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1521 TableScan: test
1522 "
1523 )
1524 }
1525
1526 #[test]
1527 fn test_normalize_multi_expression() -> Result<()> {
1528 let table_scan = test_table_scan()?;
1530 let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30));
1531 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1532
1533 assert_optimized_plan_equal!(
1534 plan,
1535 @ r"
1536 Projection: test.a, test.b, test.c
1537 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1538 Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c
1539 TableScan: test
1540 "
1541 )
1542 }
1543
1544 #[test]
1545 fn test_normalize_bitset_and_expression() -> Result<()> {
1546 let table_scan = test_table_scan()?;
1548 let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30));
1549 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1550
1551 assert_optimized_plan_equal!(
1552 plan,
1553 @ r"
1554 Projection: test.a, test.b, test.c
1555 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1556 Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c
1557 TableScan: test
1558 "
1559 )
1560 }
1561
1562 #[test]
1563 fn test_normalize_bitset_or_expression() -> Result<()> {
1564 let table_scan = test_table_scan()?;
1566 let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30));
1567 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1568
1569 assert_optimized_plan_equal!(
1570 plan,
1571 @ r"
1572 Projection: test.a, test.b, test.c
1573 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1574 Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c
1575 TableScan: test
1576 "
1577 )
1578 }
1579
1580 #[test]
1581 fn test_normalize_bitset_xor_expression() -> Result<()> {
1582 let table_scan = test_table_scan()?;
1584 let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30));
1585 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1586
1587 assert_optimized_plan_equal!(
1588 plan,
1589 @ r"
1590 Projection: test.a, test.b, test.c
1591 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1592 Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c
1593 TableScan: test
1594 "
1595 )
1596 }
1597
1598 #[test]
1599 fn test_normalize_eq_expression() -> Result<()> {
1600 let table_scan = test_table_scan()?;
1602 let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a")));
1603 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1604
1605 assert_optimized_plan_equal!(
1606 plan,
1607 @ r"
1608 Projection: test.a, test.b, test.c
1609 Filter: __common_expr_1 AND __common_expr_1
1610 Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1611 TableScan: test
1612 "
1613 )
1614 }
1615
1616 #[test]
1617 fn test_normalize_ne_expression() -> Result<()> {
1618 let table_scan = test_table_scan()?;
1620 let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a")));
1621 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1622
1623 assert_optimized_plan_equal!(
1624 plan,
1625 @ r"
1626 Projection: test.a, test.b, test.c
1627 Filter: __common_expr_1 AND __common_expr_1
1628 Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c
1629 TableScan: test
1630 "
1631 )
1632 }
1633
1634 #[test]
1635 fn test_normalize_complex_expression() -> Result<()> {
1636 let table_scan = test_table_scan()?;
1638 let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a")))
1639 .eq(lit(30));
1640 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1641
1642 assert_optimized_plan_equal!(
1643 plan,
1644 @ r"
1645 Projection: test.a, test.b, test.c
1646 Filter: __common_expr_1 - __common_expr_1 = Int32(30)
1647 Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c
1648 TableScan: test
1649 "
1650 )?;
1651
1652 let table_scan = test_table_scan()?;
1654 let expr = (((col("a") + col("b") / col("c")) * col("c"))
1655 / (col("c") * (col("b") / col("c") + col("a")))
1656 + col("a"))
1657 .eq(lit(30));
1658 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1659
1660 assert_optimized_plan_equal!(
1661 plan,
1662 @ r"
1663 Projection: test.a, test.b, test.c
1664 Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)
1665 Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c
1666 TableScan: test
1667 "
1668 )?;
1669
1670 let table_scan = test_table_scan()?;
1672 let expr = ((col("b") / (col("a") + col("c")))
1673 * (col("b") / (col("c") + col("a"))))
1674 .eq(lit(30));
1675 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1676 assert_optimized_plan_equal!(
1677 plan,
1678 @ r"
1679 Projection: test.a, test.b, test.c
1680 Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1681 Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c
1682 TableScan: test
1683 "
1684 )?;
1685
1686 Ok(())
1687 }
1688
1689 #[derive(Debug, PartialEq, Eq, Hash)]
1690 pub struct TestUdf {
1691 signature: Signature,
1692 }
1693
1694 impl TestUdf {
1695 pub fn new() -> Self {
1696 Self {
1697 signature: Signature::numeric(1, Volatility::Immutable),
1698 }
1699 }
1700 }
1701
1702 impl ScalarUDFImpl for TestUdf {
1703 fn as_any(&self) -> &dyn Any {
1704 self
1705 }
1706 fn name(&self) -> &str {
1707 "my_udf"
1708 }
1709
1710 fn signature(&self) -> &Signature {
1711 &self.signature
1712 }
1713
1714 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1715 Ok(DataType::Int32)
1716 }
1717
1718 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1719 panic!("not implemented")
1720 }
1721 }
1722
1723 #[test]
1724 fn test_normalize_inner_binary_expression() -> Result<()> {
1725 let table_scan = test_table_scan()?;
1727 let expr1 = not(col("a").eq(col("b")));
1728 let expr2 = not(col("b").eq(col("a")));
1729 let plan = LogicalPlanBuilder::from(table_scan)
1730 .project(vec![expr1, expr2])?
1731 .build()?;
1732 assert_optimized_plan_equal!(
1733 plan,
1734 @ r"
1735 Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a
1736 Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1737 TableScan: test
1738 "
1739 )?;
1740
1741 let table_scan = test_table_scan()?;
1743 let expr1 = is_null(col("a").eq(col("b")));
1744 let expr2 = is_null(col("b").eq(col("a")));
1745 let plan = LogicalPlanBuilder::from(table_scan)
1746 .project(vec![expr1, expr2])?
1747 .build()?;
1748 assert_optimized_plan_equal!(
1749 plan,
1750 @ r"
1751 Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL
1752 Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c
1753 TableScan: test
1754 "
1755 )?;
1756
1757 let table_scan = test_table_scan()?;
1759 let expr1 = (col("a") + col("b")).between(lit(0), lit(10));
1760 let expr2 = (col("b") + col("a")).between(lit(0), lit(10));
1761 let plan = LogicalPlanBuilder::from(table_scan)
1762 .project(vec![expr1, expr2])?
1763 .build()?;
1764 assert_optimized_plan_equal!(
1765 plan,
1766 @ r"
1767 Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)
1768 Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1769 TableScan: test
1770 "
1771 )?;
1772
1773 let table_scan = test_table_scan()?;
1775 let expr1 = col("c").between(col("a") + col("b"), lit(10));
1776 let expr2 = col("c").between(col("b") + col("a"), lit(10));
1777 let plan = LogicalPlanBuilder::from(table_scan)
1778 .project(vec![expr1, expr2])?
1779 .build()?;
1780 assert_optimized_plan_equal!(
1781 plan,
1782 @ r"
1783 Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)
1784 Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1785 TableScan: test
1786 "
1787 )?;
1788
1789 let udf = ScalarUDF::from(TestUdf::new());
1791 let table_scan = test_table_scan()?;
1792 let expr1 = udf.call(vec![col("a") + col("b")]);
1793 let expr2 = udf.call(vec![col("b") + col("a")]);
1794 let plan = LogicalPlanBuilder::from(table_scan)
1795 .project(vec![expr1, expr2])?
1796 .build()?;
1797 assert_optimized_plan_equal!(
1798 plan,
1799 @ r"
1800 Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)
1801 Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c
1802 TableScan: test
1803 "
1804 )
1805 }
1806
1807 fn rand_func() -> ScalarUDF {
1813 ScalarUDF::new_from_impl(RandomStub::new())
1814 }
1815
1816 #[derive(Debug, PartialEq, Eq, Hash)]
1817 struct RandomStub {
1818 signature: Signature,
1819 }
1820
1821 impl RandomStub {
1822 fn new() -> Self {
1823 Self {
1824 signature: Signature::exact(vec![], Volatility::Volatile),
1825 }
1826 }
1827 }
1828 impl ScalarUDFImpl for RandomStub {
1829 fn as_any(&self) -> &dyn Any {
1830 self
1831 }
1832
1833 fn name(&self) -> &str {
1834 "random"
1835 }
1836
1837 fn signature(&self) -> &Signature {
1838 &self.signature
1839 }
1840
1841 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1842 Ok(DataType::Float64)
1843 }
1844
1845 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1846 panic!("dummy - not implemented")
1847 }
1848 }
1849
1850 #[test]
1857 fn test_leaf_expression_not_extracted() -> Result<()> {
1858 let table_scan = test_table_scan()?;
1859
1860 let leaf = leaf_udf_expr(col("a"));
1861 let plan = LogicalPlanBuilder::from(table_scan)
1862 .project(vec![leaf.clone().alias("c1"), leaf.alias("c2")])?
1863 .build()?;
1864
1865 assert_optimized_plan_equal!(
1867 plan,
1868 @r"
1869 Projection: leaf_udf(test.a) AS c1, leaf_udf(test.a) AS c2
1870 TableScan: test
1871 "
1872 )
1873 }
1874
1875 #[test]
1879 fn test_leaf_subexpression_not_extracted() -> Result<()> {
1880 let table_scan = test_table_scan()?;
1881
1882 let common = leaf_udf_expr(col("a")) + col("b");
1886 let plan = LogicalPlanBuilder::from(table_scan)
1887 .project(vec![common.clone().alias("c1"), common.alias("c2")])?
1888 .build()?;
1889
1890 assert_optimized_plan_equal!(
1893 plan,
1894 @r"
1895 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1896 Projection: leaf_udf(test.a) + test.b AS __common_expr_1, test.a, test.b, test.c
1897 TableScan: test
1898 "
1899 )
1900 }
1901}