1use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSE, CSEController, FoundCommonNodes};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, qualified_name};
33use datafusion_expr::expr::{Alias, HigherOrderFunction, ScalarFunction};
34use datafusion_expr::logical_plan::{
35 Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{
38 BinaryExpr, Case, Expr, ExpressionPlacement, Operator, SortExpr, col,
39};
40
/// Prefix passed to the alias generator whenever this file needs a fresh name
/// for an extracted common expression (see the `next(CSE_PREFIX)` calls).
const CSE_PREFIX: &str = "__common_expr";
42
/// Optimizer rule that performs common sub-expression elimination on a
/// [`LogicalPlan`].
///
/// For `Projection`, `Sort`, `Filter`, `Window` and `Aggregate` nodes the rule
/// asks the CSE machinery for expressions shared between the node's
/// expressions. When some are found, they are computed in an additional
/// projection placed below the node and referenced through generated
/// `__common_expr`-prefixed aliases.
///
/// The struct has no fields; all state lives in the [`OptimizerConfig`] passed
/// to the rule.
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
71
impl CommonSubexprEliminate {
    /// Creates the rule. The struct has no fields, so there is nothing to
    /// configure.
    pub fn new() -> Self {
        Self {}
    }

    /// Eliminates common sub-expressions among the expressions of a
    /// [`Projection`].
    ///
    /// The projection's expressions and its (unwrapped) input are handed to
    /// [`Self::try_unary_plan`]; the node is then rebuilt from the returned
    /// expressions and input while reusing the projection's original `schema`.
    ///
    /// # Errors
    /// Propagates errors from `try_unary_plan` and from
    /// `Projection::try_new_with_schema`.
    fn try_optimize_proj(
        &self,
        projection: Projection,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let Projection {
            expr,
            input,
            schema,
            ..
        } = projection;
        let input = Arc::unwrap_or_clone(input);
        self.try_unary_plan(expr, input, config)?
            .map_data(|(new_expr, new_input)| {
                Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
                    .map(LogicalPlan::Projection)
            })
    }

    /// Eliminates common sub-expressions among the sort keys of a [`Sort`].
    ///
    /// Each `SortExpr` is split into its expression and its
    /// `(asc, nulls_first)` flags. Only the expressions go through
    /// [`Self::try_unary_plan`]; afterwards they are zipped back together with
    /// the saved flags, in the same order. `fetch` is carried over unchanged.
    fn try_optimize_sort(
        &self,
        sort: Sort,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let Sort { expr, input, fetch } = sort;
        let input = Arc::unwrap_or_clone(input);
        // Separate the expressions from the ordering flags so that only the
        // expressions are rewritten.
        let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
            .into_iter()
            .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
            .unzip();
        let new_sort = self
            .try_unary_plan(sort_expressions, input, config)?
            .update_data(|(new_expr, new_input)| {
                LogicalPlan::Sort(Sort {
                    // Re-attach the flags to the (possibly rewritten) expressions.
                    expr: new_expr
                        .into_iter()
                        .zip(sort_params)
                        .map(|(expr, (asc, nulls_first))| SortExpr {
                            expr,
                            asc,
                            nulls_first,
                        })
                        .collect(),
                    input: Arc::new(new_input),
                    fetch,
                })
            });
        Ok(new_sort)
    }

    /// Eliminates common sub-expressions inside the predicate of a [`Filter`].
    ///
    /// The predicate is wrapped in a one-element vector so it can go through
    /// [`Self::try_unary_plan`]; the single returned expression becomes the new
    /// predicate of a filter built with `Filter::try_new`.
    ///
    /// # Panics
    /// Panics if `try_unary_plan` does not return exactly one expression.
    fn try_optimize_filter(
        &self,
        filter: Filter,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let Filter {
            predicate, input, ..
        } = filter;
        let input = Arc::unwrap_or_clone(input);
        let expr = vec![predicate];
        self.try_unary_plan(expr, input, config)?
            .map_data(|(mut new_expr, new_input)| {
                // Exactly one expression went in (`vec![predicate]`), so exactly
                // one is expected back.
                assert_eq!(new_expr.len(), 1);
                let new_predicate = new_expr.pop().unwrap();
                Filter::try_new(new_predicate, Arc::new(new_input))
                    .map(LogicalPlan::Filter)
            })
    }

    /// Eliminates common sub-expressions across a chain of directly nested
    /// [`Window`] nodes.
    ///
    /// Steps:
    /// 1. `get_consecutive_window_exprs` collects the expression list and
    ///    schema of every window in the chain (outermost first) together with
    ///    the first non-window input.
    /// 2. Common nodes are extracted over all collected expression lists with
    ///    `ExprMask::Normal`; if any are found, a projection computing them is
    ///    placed on top of the input.
    /// 3. The rule recurses into the (possibly new) input.
    /// 4. The window chain is rebuilt on top of that input.
    fn try_optimize_window(
        &self,
        window: Window,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let (window_expr_list, window_schemas, input) =
            get_consecutive_window_exprs(window);

        match CSE::new(ExprCSEController::new(
            config.alias_generator().as_ref(),
            ExprMask::Normal,
        ))
        .extract_common_nodes(window_expr_list)?
        {
            // Common expressions exist: compute them in a projection below the
            // windows and remember the original expressions so their names can
            // be restored later.
            FoundCommonNodes::Yes {
                common_nodes: common_exprs,
                new_nodes_list: new_exprs_list,
                original_nodes_list: original_exprs_list,
            } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
                Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
            }),
            // Nothing to extract: keep the expressions and input as they are.
            // `None` signals that no name restoration is required.
            FoundCommonNodes::No {
                original_nodes_list: original_exprs_list,
            } => Ok(Transformed::no((original_exprs_list, input, None))),
        }?
        .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
            // Recurse into the input below the window chain.
            self.rewrite(new_input, config)?.map_data(|new_input| {
                Ok((new_window_expr_list, new_input, window_expr_list))
            })
        })?
        .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
            if let Some(window_expr_list) = window_expr_list {
                // Expressions were rewritten: save the names of the original
                // expressions and restore them on the rewritten ones.
                let name_preserver = NamePreserver::new_for_projection();
                let saved_names = window_expr_list
                    .iter()
                    .map(|exprs| {
                        exprs
                            .iter()
                            .map(|expr| name_preserver.save(expr))
                            .collect::<Vec<_>>()
                    })
                    .collect::<Vec<_>>();
                // The lists are ordered outermost-first, so folding from the
                // back rebuilds the innermost window first. `Window::try_new`
                // is used here, i.e. the saved schemas are not reused on this
                // path.
                new_window_expr_list.into_iter().zip(saved_names).try_rfold(
                    new_input,
                    |plan, (new_window_expr, saved_names)| {
                        let new_window_expr = new_window_expr
                            .into_iter()
                            .zip(saved_names)
                            .map(|(new_window_expr, saved_name)| {
                                saved_name.restore(new_window_expr)
                            })
                            .collect::<Vec<_>>();
                        Window::try_new(new_window_expr, Arc::new(plan))
                            .map(LogicalPlan::Window)
                    },
                )
            } else {
                // Expressions are unchanged: rebuild each window with the
                // schema it had originally, again innermost first.
                new_window_expr_list
                    .into_iter()
                    .zip(window_schemas)
                    .try_rfold(new_input, |plan, (new_window_expr, schema)| {
                        Window::try_new_with_schema(
                            new_window_expr,
                            Arc::new(plan),
                            schema,
                        )
                        .map(LogicalPlan::Window)
                    })
            }
        })
    }

    /// Eliminates common sub-expressions in an [`Aggregate`] node.
    ///
    /// The rewrite runs in three stages:
    /// 1. Extract common nodes shared by the group and aggregate expressions
    ///    using `ExprMask::Normal`; if found, compute them in a projection
    ///    below the aggregate.
    /// 2. Recurse into the (possibly new) input.
    /// 3. Run a second extraction over the aggregate expressions only, using
    ///    `ExprMask::NormalAndAggregates`. If common nodes are found, the
    ///    result is a new `Aggregate` with a `Projection` on top of it;
    ///    otherwise the aggregate is simply rebuilt.
    ///
    /// # Panics
    /// The `pop().unwrap()` calls rely on `extract_common_nodes` returning
    /// lists with as many entries as were passed in.
    fn try_optimize_aggregate(
        &self,
        aggregate: Aggregate,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let Aggregate {
            group_expr,
            aggr_expr,
            input,
            schema,
            ..
        } = aggregate;
        let input = Arc::unwrap_or_clone(input);
        // Stage 1: extraction over both the group and aggregate expressions.
        match CSE::new(ExprCSEController::new(
            config.alias_generator().as_ref(),
            ExprMask::Normal,
        ))
        .extract_common_nodes(vec![group_expr, aggr_expr])?
        {
            FoundCommonNodes::Yes {
                common_nodes: common_exprs,
                new_nodes_list: mut new_exprs_list,
                original_nodes_list: mut original_exprs_list,
            } => {
                // The input was `[group_expr, aggr_expr]`, so popping yields the
                // aggregate expressions first and the group expressions second.
                let new_aggr_expr = new_exprs_list.pop().unwrap();
                let new_group_expr = new_exprs_list.pop().unwrap();

                build_common_expr_project_plan(input, common_exprs).map(|new_input| {
                    // Keep the original aggregate expressions so their names
                    // can be restored in stage 3.
                    let aggr_expr = original_exprs_list.pop().unwrap();
                    Transformed::yes((
                        new_aggr_expr,
                        new_group_expr,
                        new_input,
                        Some(aggr_expr),
                    ))
                })
            }

            FoundCommonNodes::No {
                original_nodes_list: mut original_exprs_list,
            } => {
                let new_aggr_expr = original_exprs_list.pop().unwrap();
                let new_group_expr = original_exprs_list.pop().unwrap();

                // `None` records that stage 1 did not rewrite anything.
                Ok(Transformed::no((
                    new_aggr_expr,
                    new_group_expr,
                    input,
                    None,
                )))
            }
        }?
        .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
            // Stage 2: recurse into the input. Note that the tuple order
            // changes here (`aggr_expr` moves before the input, which is now
            // wrapped in an `Arc`).
            self.rewrite(new_input, config)?.map_data(|new_input| {
                Ok((
                    new_aggr_expr,
                    new_group_expr,
                    aggr_expr,
                    Arc::new(new_input),
                ))
            })
        })?
        .transform_data(
            // Stage 3: extraction over the aggregate expressions only, this
            // time allowing aggregate functions themselves to be extracted.
            |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
                match CSE::new(ExprCSEController::new(
                    config.alias_generator().as_ref(),
                    ExprMask::NormalAndAggregates,
                ))
                .extract_common_nodes(vec![new_aggr_expr])?
                {
                    FoundCommonNodes::Yes {
                        common_nodes: common_exprs,
                        new_nodes_list: mut new_exprs_list,
                        original_nodes_list: mut original_exprs_list,
                    } => {
                        // `rewritten_aggr_expr` is the stage-3 output;
                        // `new_aggr_expr` is what went into stage 3.
                        let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
                        let new_aggr_expr = original_exprs_list.pop().unwrap();
                        // Names are saved from the pre-stage-1 expressions when
                        // stage 1 rewrote them; otherwise there is nothing to
                        // restore and a `None` placeholder is kept per expression.
                        let saved_names = if let Some(aggr_expr) = aggr_expr {
                            let name_preserver = NamePreserver::new_for_projection();
                            aggr_expr
                                .iter()
                                .map(|expr| Some(name_preserver.save(expr)))
                                .collect::<Vec<_>>()
                        } else {
                            (0..new_aggr_expr.len()).map(|_| None).collect()
                        };

                        // The new aggregate starts with the extracted common
                        // expressions, each under its generated alias.
                        let mut agg_exprs = common_exprs
                            .into_iter()
                            .map(|(expr, expr_alias)| expr.alias(expr_alias))
                            .collect::<Vec<_>>();

                        // The projection on top first passes through the group
                        // expressions as column references.
                        let mut proj_exprs = vec![];
                        for expr in &new_group_expr {
                            extract_expressions(expr, &mut proj_exprs)
                        }
                        for ((expr_rewritten, expr_orig), saved_name) in
                            rewritten_aggr_expr
                                .into_iter()
                                .zip(new_aggr_expr)
                                .zip(saved_names)
                        {
                            if expr_rewritten == expr_orig {
                                // Stage 3 left this expression untouched, so it
                                // must still be evaluated by the aggregate and
                                // then referenced from the projection.
                                let expr_rewritten = if let Some(saved_name) = saved_name
                                {
                                    saved_name.restore(expr_rewritten)
                                } else {
                                    expr_rewritten
                                };
                                if let Expr::Alias(Alias { expr, name, .. }) =
                                    expr_rewritten
                                {
                                    // Already aliased: keep the alias on the
                                    // aggregate and project it by name.
                                    agg_exprs.push(expr.alias(&name));
                                    proj_exprs
                                        .push(Expr::Column(Column::from_name(name)));
                                } else {
                                    // Not aliased: compute it under a fresh
                                    // generated alias and project that column
                                    // under the expression's qualified name.
                                    let expr_alias =
                                        config.alias_generator().next(CSE_PREFIX);
                                    let (qualifier, field_name) =
                                        expr_rewritten.qualified_name();
                                    let out_name =
                                        qualified_name(qualifier.as_ref(), &field_name);

                                    agg_exprs.push(expr_rewritten.alias(&expr_alias));
                                    proj_exprs.push(
                                        Expr::Column(Column::from_name(expr_alias))
                                            .alias(out_name),
                                    );
                                }
                            } else {
                                // Stage 3 rewrote this expression; it goes into
                                // the projection as-is.
                                proj_exprs.push(expr_rewritten);
                            }
                        }

                        let agg = LogicalPlan::Aggregate(Aggregate::try_new(
                            new_input,
                            new_group_expr,
                            agg_exprs,
                        )?);
                        Projection::try_new(proj_exprs, Arc::new(agg))
                            .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
                    }

                    // No common aggregate sub-expressions: just rebuild the
                    // aggregate node.
                    FoundCommonNodes::No {
                        original_nodes_list: mut original_exprs_list,
                    } => {
                        let rewritten_aggr_expr = original_exprs_list.pop().unwrap();

                        if let Some(aggr_expr) = aggr_expr {
                            // Stage 1 rewrote the expressions, so restore the
                            // names of the original aggregate expressions.
                            let name_preserver = NamePreserver::new_for_projection();
                            let saved_names = aggr_expr
                                .iter()
                                .map(|expr| name_preserver.save(expr))
                                .collect::<Vec<_>>();
                            let new_aggr_expr = rewritten_aggr_expr
                                .into_iter()
                                .zip(saved_names)
                                .map(|(new_expr, saved_name)| {
                                    saved_name.restore(new_expr)
                                })
                                .collect::<Vec<Expr>>();

                            // Built with `try_new`, i.e. the original `schema`
                            // is not reused on this path.
                            Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
                                .map(LogicalPlan::Aggregate)
                                .map(Transformed::no)
                        } else {
                            // Nothing was rewritten in stage 1 or stage 3, so
                            // the original schema is reused.
                            Aggregate::try_new_with_schema(
                                new_input,
                                new_group_expr,
                                rewritten_aggr_expr,
                                schema,
                            )
                            .map(LogicalPlan::Aggregate)
                            .map(Transformed::no)
                        }
                    }
                }
            },
        )
    }

    /// Shared implementation for plan nodes with one expression list and one
    /// input (used by the projection, sort and filter rewrites).
    ///
    /// Common nodes are extracted from `exprs` with `ExprMask::Normal`. If any
    /// are found, `input` is wrapped in a projection that computes them;
    /// otherwise `exprs` and `input` are kept. In both cases the rule then
    /// recurses into the resulting input.
    ///
    /// # Returns
    /// The (possibly rewritten) expressions paired with the (possibly
    /// rewritten) input.
    fn try_unary_plan(
        &self,
        exprs: Vec<Expr>,
        input: LogicalPlan,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
        match CSE::new(ExprCSEController::new(
            config.alias_generator().as_ref(),
            ExprMask::Normal,
        ))
        .extract_common_nodes(vec![exprs])?
        {
            FoundCommonNodes::Yes {
                common_nodes: common_exprs,
                new_nodes_list: mut new_exprs_list,
                original_nodes_list: _,
            } => {
                // Only one list was passed in, so only one is popped.
                let new_exprs = new_exprs_list.pop().unwrap();
                build_common_expr_project_plan(input, common_exprs)
                    .map(|new_input| Transformed::yes((new_exprs, new_input)))
            }
            FoundCommonNodes::No {
                original_nodes_list: mut original_exprs_list,
            } => {
                let new_exprs = original_exprs_list.pop().unwrap();
                Ok(Transformed::no((new_exprs, input)))
            }
        }?
        .transform_data(|(new_exprs, new_input)| {
            // Recurse into the input of the node being rewritten.
            self.rewrite(new_input, config)?
                .map_data(|new_input| Ok((new_exprs, new_input)))
        })
    }
}
490
491fn get_consecutive_window_exprs(
523 window: Window,
524) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
525 let mut window_expr_list = vec![];
526 let mut window_schemas = vec![];
527 let mut plan = LogicalPlan::Window(window);
528 while let LogicalPlan::Window(Window {
529 input,
530 window_expr,
531 schema,
532 }) = plan
533 {
534 window_expr_list.push(window_expr);
535 window_schemas.push(schema);
536
537 plan = Arc::unwrap_or_clone(input);
538 }
539 (window_expr_list, window_schemas, plan)
540}
541
542impl OptimizerRule for CommonSubexprEliminate {
543 fn supports_rewrite(&self) -> bool {
544 true
545 }
546
547 fn apply_order(&self) -> Option<ApplyOrder> {
548 None
552 }
553
554 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
555 fn rewrite(
556 &self,
557 plan: LogicalPlan,
558 config: &dyn OptimizerConfig,
559 ) -> Result<Transformed<LogicalPlan>> {
560 let original_schema = Arc::clone(plan.schema());
561
562 let optimized_plan = match plan {
563 LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
564 LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
565 LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
566 LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
567 LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
568 LogicalPlan::Join(_)
569 | LogicalPlan::Repartition(_)
570 | LogicalPlan::Union(_)
571 | LogicalPlan::TableScan(_)
572 | LogicalPlan::Values(_)
573 | LogicalPlan::EmptyRelation(_)
574 | LogicalPlan::Subquery(_)
575 | LogicalPlan::SubqueryAlias(_)
576 | LogicalPlan::Limit(_)
577 | LogicalPlan::Ddl(_)
578 | LogicalPlan::Explain(_)
579 | LogicalPlan::Analyze(_)
580 | LogicalPlan::Statement(_)
581 | LogicalPlan::DescribeTable(_)
582 | LogicalPlan::Distinct(_)
583 | LogicalPlan::Extension(_)
584 | LogicalPlan::Dml(_)
585 | LogicalPlan::Copy(_)
586 | LogicalPlan::Unnest(_)
587 | LogicalPlan::RecursiveQuery(_) => {
588 plan.map_uncorrelated_subqueries(|c| self.rewrite(c, config))?
592 .transform_sibling(|plan| {
593 plan.map_children(|c| self.rewrite(c, config))
594 })?
595 }
596 };
597
598 if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
600 {
601 optimized_plan.map_data(|optimized_plan| {
602 build_recover_project_plan(&original_schema, optimized_plan)
603 })
604 } else {
605 Ok(optimized_plan)
606 }
607 }
608
609 fn name(&self) -> &str {
610 "common_sub_expression_eliminate"
611 }
612}
613
/// Selects which expression nodes `ExprCSEController::is_ignored` reports as
/// ignored (in addition to nodes whose placement is `MoveTowardsLeafNodes`,
/// which are always ignored).
#[derive(Debug, Clone, Copy)]
enum ExprMask {
    /// Ignores `Literal`, `Column`, `ScalarVariable`, `Alias`, `Wildcard`,
    /// `Lambda` and `LambdaVariable` nodes, and additionally
    /// `AggregateFunction` nodes.
    Normal,

    /// Ignores the same nodes as [`ExprMask::Normal`] except that
    /// `AggregateFunction` nodes are not ignored.
    NormalAndAggregates,
}
630
/// [`CSEController`] implementation for [`Expr`] trees.
struct ExprCSEController<'a> {
    /// Source of fresh alias names (queried with the `CSE_PREFIX` prefix).
    alias_generator: &'a AliasGenerator,
    /// Decides which node kinds are ignored; see `is_ignored`.
    mask: ExprMask,

    /// Bookkeeping for `rewrite`: incremented when entering an `Alias` node
    /// (`rewrite_f_down`) and decremented when leaving it (`rewrite_f_up`).
    /// `rewrite` also increments it when it attaches an alias itself. While it
    /// is non-zero, `rewrite` emits a bare column without an extra alias.
    alias_counter: usize,
}
638
639impl<'a> ExprCSEController<'a> {
640 fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
641 Self {
642 alias_generator,
643 mask,
644 alias_counter: 0,
645 }
646 }
647}
648
649impl CSEController for ExprCSEController<'_> {
650 type Node = Expr;
651
652 fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
653 match node {
654 Expr::ScalarFunction(ScalarFunction { func, args }) => {
658 func.conditional_arguments(args)
659 }
660 Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
661 func.conditional_arguments(args)
662 }
663
664 Expr::BinaryExpr(BinaryExpr {
667 left,
668 op: Operator::And | Operator::Or,
669 right,
670 }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
671
672 Expr::Case(Case {
676 expr,
677 when_then_expr,
678 else_expr,
679 }) => Some((
680 expr.iter()
681 .map(|e| e.as_ref())
682 .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
683 .collect(),
684 when_then_expr
685 .iter()
686 .take(1)
687 .map(|(_, then)| then.as_ref())
688 .chain(
689 when_then_expr
690 .iter()
691 .skip(1)
692 .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
693 )
694 .chain(else_expr.iter().map(|e| e.as_ref()))
695 .collect(),
696 )),
697 _ => None,
698 }
699 }
700
701 fn is_valid(node: &Expr) -> bool {
702 !node.is_volatile_node()
703 && !matches!(node, Expr::Lambda(_) | Expr::LambdaVariable(_))
704 }
705
706 fn is_ignored(&self, node: &Expr) -> bool {
707 if node.placement() == ExpressionPlacement::MoveTowardsLeafNodes {
716 return true;
717 }
718
719 #[expect(deprecated)]
721 let is_normal_minus_aggregates = matches!(
722 node,
723 Expr::Literal(..)
729 | Expr::Column(..)
730 | Expr::ScalarVariable(..)
731 | Expr::Alias(..)
732 | Expr::Wildcard { .. }
733 | Expr::Lambda(_)
734 | Expr::LambdaVariable(_)
735 );
736
737 let is_aggr = matches!(node, Expr::AggregateFunction(..));
738
739 match self.mask {
740 ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
741 ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
742 }
743 }
744
745 fn generate_alias(&self) -> String {
746 self.alias_generator.next(CSE_PREFIX)
747 }
748
749 fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
750 if self.alias_counter > 0 {
752 col(alias)
753 } else {
754 self.alias_counter += 1;
755 col(alias).alias(node.schema_name().to_string())
756 }
757 }
758
759 fn rewrite_f_down(&mut self, node: &Expr) {
760 if matches!(node, Expr::Alias(_)) {
761 self.alias_counter += 1;
762 }
763 }
764 fn rewrite_f_up(&mut self, node: &Expr) {
765 if matches!(node, Expr::Alias(_)) {
766 self.alias_counter -= 1
767 }
768 }
769}
770
771impl Default for CommonSubexprEliminate {
772 fn default() -> Self {
773 Self::new()
774 }
775}
776
777fn build_common_expr_project_plan(
788 input: LogicalPlan,
789 common_exprs: Vec<(Expr, String)>,
790) -> Result<LogicalPlan> {
791 let mut fields_set = BTreeSet::new();
792 let mut project_exprs = common_exprs
793 .into_iter()
794 .map(|(expr, expr_alias)| {
795 fields_set.insert(expr_alias.clone());
796 Ok(expr.alias(expr_alias))
797 })
798 .collect::<Result<Vec<_>>>()?;
799
800 for (qualifier, field) in input.schema().iter() {
801 if fields_set.insert(qualified_name(qualifier, field.name())) {
802 project_exprs.push(Expr::from((qualifier, field)));
803 }
804 }
805
806 Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
807}
808
809fn build_recover_project_plan(
815 schema: &DFSchema,
816 input: LogicalPlan,
817) -> Result<LogicalPlan> {
818 let col_exprs = schema.iter().map(Expr::from).collect();
819 Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
820}
821
822fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
823 if let Expr::GroupingSet(groupings) = expr {
824 for e in groupings.distinct_expr() {
825 let (qualifier, field_name) = e.qualified_name();
826 let col = Column::new(qualifier, field_name);
827 result.push(Expr::Column(col))
828 }
829 } else {
830 let (qualifier, field_name) = expr.qualified_name();
831 let col = Column::new(qualifier, field_name);
832 result.push(Expr::Column(col));
833 }
834}
835
836#[cfg(test)]
837mod test {
838
839 use std::iter;
840
841 use arrow::datatypes::{DataType, Field, Schema};
842 use datafusion_expr::logical_plan::{JoinType, table_scan};
843 use datafusion_expr::{
844 AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarFunctionArgs,
845 ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility,
846 grouping_set, is_null, not,
847 };
848 use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
849
850 use super::*;
851 use crate::assert_optimized_plan_eq_snapshot;
852 use crate::optimizer::OptimizerContext;
853 use crate::test::udfs::leaf_udf_expr;
854 use crate::test::*;
855 use datafusion_expr::test::function_stub::{avg, sum};
856
// Runs `CommonSubexprEliminate` as the only optimizer rule over a plan and
// compares the result with an inline snapshot literal via
// `assert_optimized_plan_eq_snapshot!`.
//
// Two forms are accepted:
// * `(config, plan, @ expected)` uses the caller-supplied optimizer config;
// * `(plan, @ expected)` creates a fresh `OptimizerContext` itself.
macro_rules! assert_optimized_plan_equal {
    // Form with an explicit optimizer config.
    (
        $config:expr,
        $plan:expr,
        @ $expected:literal $(,)?
    ) => {{
        let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
        assert_optimized_plan_eq_snapshot!(
            $config,
            rules,
            $plan,
            @ $expected,
        )
    }};

    // Form that builds a default `OptimizerContext`.
    (
        $plan:expr,
        @ $expected:literal $(,)?
    ) => {{
        let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
        let optimizer_ctx = OptimizerContext::new();
        assert_optimized_plan_eq_snapshot!(
            optimizer_ctx,
            rules,
            $plan,
            @ $expected,
        )
    }};
}
886
// Two `sum` arguments share the sub-expression `a * (1 - b)`. The expected
// plan computes it once as `__common_expr_1` in a projection below the
// aggregate and references it from both aggregate expressions.
// NOTE(review): the snapshot lines below carry no nesting indentation —
// confirm the literal was not flattened.
#[test]
fn tpch_q1_simplified() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(
            iter::empty::<Expr>(),
            vec![
                sum(col("a") * (lit(1) - col("b"))),
                sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
 Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )
}
918
// `a + b` occurs three times, once of them underneath the user alias
// `alias1`. In the expected plan the occurrence under `alias1` is replaced by
// a bare `__common_expr_1`, while the other two are re-aliased to
// `test.a + test.b`.
// NOTE(review): the snapshot lines below carry no nesting indentation —
// confirm the literal was not flattened.
#[test]
fn nested_aliases() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
            col("a") + col("b"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b
 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )
}
939
// Exercises the aggregate rewrite with the built-in `avg` stub and a
// user-defined aggregate (`my_agg`) across five scenarios; each scenario is
// introduced by a comment below.
// NOTE(review): the snapshot lines below carry no nesting indentation —
// confirm the literals were not flattened.
#[test]
fn aggregate() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Helper building a `my_agg(inner)` aggregate call. The accumulator
    // factory is never expected to run (`unimplemented!()`); only planning is
    // tested here.
    let return_type = DataType::UInt32;
    let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
    let udf_agg = |inner: Expr| {
        Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
            Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
                "my_agg",
                Signature::exact(vec![DataType::UInt32], Volatility::Stable),
                return_type.clone(),
                Arc::clone(&accumulator),
                vec![Field::new("value", DataType::UInt32, true).into()],
            ))),
            vec![inner],
            false,
            None,
            vec![],
            None,
        ))
    };

    // Scenario 1: whole aggregate calls repeated under different aliases.
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .aggregate(
            iter::empty::<Expr>(),
            vec![
                avg(col("a")).alias("col1"),
                avg(col("a")).alias("col2"),
                avg(col("b")).alias("col3"),
                avg(col("c")),
                udf_agg(col("a")).alias("col4"),
                udf_agg(col("a")).alias("col5"),
                udf_agg(col("b")).alias("col6"),
                udf_agg(col("c")),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)
 Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]
 TableScan: test
 "
    )?;

    // Scenario 2: the same aggregate call nested inside different outer
    // expressions.
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .aggregate(
            iter::empty::<Expr>(),
            vec![
                lit(1) + avg(col("a")),
                lit(1) - avg(col("a")),
                lit(1) + udf_agg(col("a")),
                lit(1) - udf_agg(col("a")),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)
 Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]
 TableScan: test
 "
    )?;

    // Scenario 3: a common argument shared by two different aggregate calls.
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .aggregate(
            iter::empty::<Expr>(),
            vec![
                avg(lit(1u32) + col("a")).alias("col1"),
                udf_agg(lit(1u32) + col("a")).alias("col2"),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )?;

    // Scenario 4: the common expression also appears as the group-by key.
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .aggregate(
            vec![lit(1u32) + col("a")],
            vec![
                avg(lit(1u32) + col("a")).alias("col1"),
                udf_agg(lit(1u32) + col("a")).alias("col2"),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )?;

    // Scenario 5: common group-by key, common aggregate arguments and common
    // aggregate calls all at once.
    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(
            vec![lit(1u32) + col("a")],
            vec![
                (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
                (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
                avg(lit(1u32) + col("a")),
                (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
                (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
                udf_agg(lit(1u32) + col("a")),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)
 Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]
 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )
}
1080
// Same idea as the aggregate scenarios, but with a table name (`table.test`)
// and a column name (`col.a`) that both contain dots, to check that the
// qualified names in the rewritten plan stay intact.
// NOTE(review): the snapshot lines below carry no nesting indentation —
// confirm the literal was not flattened.
#[test]
fn aggregate_with_relations_and_dots() -> Result<()> {
    let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
    let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;

    let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));

    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(
            vec![col_a.clone()],
            vec![
                (lit(1u32) + avg(lit(1u32) + col_a.clone())),
                avg(lit(1u32) + col_a),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)
 Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]
 Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a
 TableScan: table.test
 "
    )
}
1108
// The identical expression `1 + a` is projected twice under different
// aliases; the expected plan computes it once and keeps both aliases.
// NOTE(review): the snapshot lines below carry no nesting indentation —
// confirm the literal was not flattened.
#[test]
fn subexpr_in_same_order() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            (lit(1) + col("a")).alias("first"),
            (lit(1) + col("a")).alias("second"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: __common_expr_1 AS first, __common_expr_1 AS second
 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )
}
1129
// `1 + a` and `a + 1` differ only in operand order. The expected plan treats
// them as the same common expression while each output keeps its own
// original name.
// NOTE(review): the snapshot lines below carry no nesting indentation —
// confirm the literal was not flattened.
#[test]
fn subexpr_in_different_order() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)
 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )
}
1147
// `1 + a` appears once in each of two stacked projections, never twice within
// the same node. The expected plan is unchanged: nothing is extracted across
// plan nodes.
// NOTE(review): the snapshot lines below carry no nesting indentation —
// confirm the literal was not flattened.
#[test]
fn cross_plans_subexpr() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![lit(1) + col("a"), col("a")])?
        .project(vec![lit(1) + col("a")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: Int32(1) + test.a
 Projection: Int32(1) + test.a, test.a
 TableScan: test
 "
    )
}
1166
// Stacks two common-expression projections on a table scan and checks that
// the resulting schema contains no duplicate field names.
#[test]
fn redundant_project_fields() {
    let scan = test_table_scan().unwrap();

    // First layer: two real expressions under generated-style aliases.
    let first_layer = vec![
        (col("c") + col("a"), format!("{CSE_PREFIX}_1")),
        (col("b") + col("a"), format!("{CSE_PREFIX}_2")),
    ];
    // Second layer: re-project the first layer's aliases under new aliases.
    let second_layer = vec![
        (col(format!("{CSE_PREFIX}_1")), format!("{CSE_PREFIX}_3")),
        (col(format!("{CSE_PREFIX}_2")), format!("{CSE_PREFIX}_4")),
    ];

    let inner = build_common_expr_project_plan(scan, first_layer).unwrap();
    let outer = build_common_expr_project_plan(inner, second_layer).unwrap();

    // `insert` returns false on a repeated name, failing the assertion.
    let mut field_set = BTreeSet::new();
    for name in outer.schema().field_names() {
        assert!(field_set.insert(name));
    }
}
1190
// Same duplicate-name check as for a plain scan, but with an inner join of
// `test1` and `test2` as the input, so the input schema carries equally named
// columns under different qualifiers.
#[test]
fn redundant_project_fields_join_input() {
    let left = test_table_scan_with_name("test1").unwrap();
    let right = test_table_scan_with_name("test2").unwrap();
    let joined = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, (vec!["a"], vec!["a"]), None)
        .unwrap()
        .build()
        .unwrap();

    // First layer: two real expressions over `test1` columns.
    let first_layer = vec![
        (col("test1.c") + col("test1.a"), format!("{CSE_PREFIX}_1")),
        (col("test1.b") + col("test1.a"), format!("{CSE_PREFIX}_2")),
    ];
    // Second layer: re-project the first layer's aliases under new aliases.
    let second_layer = vec![
        (col(format!("{CSE_PREFIX}_1")), format!("{CSE_PREFIX}_3")),
        (col(format!("{CSE_PREFIX}_2")), format!("{CSE_PREFIX}_4")),
    ];

    let inner = build_common_expr_project_plan(joined, first_layer).unwrap();
    let outer = build_common_expr_project_plan(inner, second_layer).unwrap();

    // `insert` returns false on a repeated name, failing the assertion.
    let mut field_set = BTreeSet::new();
    for name in outer.schema().field_names() {
        assert!(field_set.insert(name));
    }
}
1220
// A filter uses `CAST(a AS Int64)` twice. After the rewrite the plan must be
// reported as transformed, and its output schema must still list the original
// `a`, `b`, `c` columns with their `UInt64` types.
// NOTE(review): the `expected` literal below carries one-space indentation,
// whereas `{:#?}` output is normally indented by four spaces per level —
// confirm the literal was not flattened.
#[test]
fn eliminated_subexpr_datatype() {
    use datafusion_expr::cast;

    let schema = Schema::new(vec![
        Field::new("a", DataType::UInt64, false),
        Field::new("b", DataType::UInt64, false),
        Field::new("c", DataType::UInt64, false),
    ]);

    let plan = table_scan(Some("table"), &schema, None)
        .unwrap()
        .filter(
            cast(col("a"), DataType::Int64)
                .lt(lit(1_i64))
                .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
        )
        .unwrap()
        .build()
        .unwrap();
    // The rule is invoked directly rather than through the snapshot macro so
    // that the `transformed` flag and the schema can be inspected.
    let rule = CommonSubexprEliminate::new();
    let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
    assert!(optimized_plan.transformed);
    let optimized_plan = optimized_plan.data;

    let schema = optimized_plan.schema();
    let fields_with_datatypes: Vec<_> = schema
        .fields()
        .iter()
        .map(|field| (field.name(), field.data_type()))
        .collect();
    // Compare the pretty-printed `(name, type)` pairs against the expectation.
    let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
    let expected = r#"[
 (
 "a",
 UInt64,
 ),
 (
 "b",
 UInt64,
 ),
 (
 "c",
 UInt64,
 ),
]"#;
    assert_eq!(expected, formatted_fields_with_datatype);
}
1269
// The filter predicate uses `1 + a` on both sides of the comparison.
// Extracting it adds `__common_expr_1` to the filter's input, so the expected
// plan has an extra projection on top that restores the original
// `a, b, c` schema.
// NOTE(review): the snapshot lines below carry no nesting indentation —
// confirm the literal was not flattened.
#[test]
fn filter_schema_changed() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
 Projection: test.a, test.b, test.c
 Filter: __common_expr_1 - Int32(10) > __common_expr_1
 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
 TableScan: test
 "
    )
}
1288
// A grouping set over `(a, b)` and `(c)` has three distinct expressions, so
// three columns are expected.
#[test]
fn test_extract_expressions_from_grouping_set() -> Result<()> {
    let first_set = vec![col("a"), col("b")];
    let second_set = vec![col("c")];
    let grouping = grouping_set(vec![first_set, second_set]);

    let mut result = Vec::with_capacity(3);
    extract_expressions(&grouping, &mut result);

    assert!(result.len() == 3);
    Ok(())
}
1298
// `a` appears in both sets of the grouping set; only the two distinct
// expressions (`a`, `b`) are expected as columns.
#[test]
fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
    let first_set = vec![col("a"), col("b")];
    let second_set = vec![col("a")];
    let grouping = grouping_set(vec![first_set, second_set]);

    let mut result = Vec::with_capacity(2);
    extract_expressions(&grouping, &mut result);

    assert!(result.len() == 2);
    Ok(())
}
1307
// Checks that aliases generated by the rule do not collide with columns that
// already carry a `__common_expr_N` name. The name is obtained up front from
// the same alias generator the rule later uses, so the rule has to continue
// with the next number.
// NOTE(review): the snapshot lines below carry no nesting indentation —
// confirm the literals were not flattened.
#[test]
fn test_alias_collision() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Case 1: the plan already uses the first generated name.
    let config = OptimizerContext::new();
    let common_expr_1 = config.alias_generator().next(CSE_PREFIX);
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .project(vec![
            (col("a") + col("b")).alias(common_expr_1.clone()),
            col("c"),
        ])?
        .project(vec![
            col(common_expr_1.clone()).alias("c1"),
            col(common_expr_1).alias("c2"),
            (col("c") + lit(2)).alias("c3"),
            (col("c") + lit(2)).alias("c4"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        config,
        plan,
        @ r"
 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4
 Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c
 Projection: test.a + test.b AS __common_expr_1, test.c
 TableScan: test
 "
    )?;

    // Case 2: the first generated name is consumed but unused, and the plan
    // uses the second one.
    let config = OptimizerContext::new();
    let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
    let common_expr_2 = config.alias_generator().next(CSE_PREFIX);
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            (col("a") + col("b")).alias(common_expr_2.clone()),
            col("c"),
        ])?
        .project(vec![
            col(common_expr_2.clone()).alias("c1"),
            col(common_expr_2).alias("c2"),
            (col("c") + lit(2)).alias("c3"),
            (col("c") + lit(2)).alias("c4"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        config,
        plan,
        @ r"
 Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4
 Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c
 Projection: test.a + test.b AS __common_expr_2, test.c
 TableScan: test
 "
    )?;

    Ok(())
}
1367
1368 #[test]
1369 fn test_extract_expressions_from_col() -> Result<()> {
1370 let mut result = Vec::with_capacity(1);
1371 extract_expressions(&col("a"), &mut result);
1372 assert!(result.len() == 1);
1373 Ok(())
1374 }
1375
1376 #[test]
1377 fn test_short_circuits() -> Result<()> {
1378 let table_scan = test_table_scan()?;
1379
1380 let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
1381 let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
1382 let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
1383 let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
1384 let plan = LogicalPlanBuilder::from(table_scan)
1385 .project(vec![
1386 extracted_short_circuit.clone().alias("c1"),
1387 extracted_short_circuit.alias("c2"),
1388 extracted_short_circuit_leg_1
1389 .clone()
1390 .or(not_extracted_short_circuit_leg_2.clone())
1391 .alias("c3"),
1392 extracted_short_circuit_leg_1
1393 .and(not_extracted_short_circuit_leg_2)
1394 .alias("c4"),
1395 extracted_short_circuit_leg_3
1396 .clone()
1397 .or(extracted_short_circuit_leg_3)
1398 .alias("c5"),
1399 ])?
1400 .build()?;
1401
1402 assert_optimized_plan_equal!(
1403 plan,
1404 @ r"
1405 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5
1406 Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c
1407 TableScan: test
1408 "
1409 )
1410 }
1411
1412 #[test]
1413 fn test_volatile() -> Result<()> {
1414 let table_scan = test_table_scan()?;
1415
1416 let extracted_child = col("a") + col("b");
1417 let rand = rand_func().call(vec![]);
1418 let not_extracted_volatile = extracted_child + rand;
1419 let plan = LogicalPlanBuilder::from(table_scan)
1420 .project(vec![
1421 not_extracted_volatile.clone().alias("c1"),
1422 not_extracted_volatile.alias("c2"),
1423 ])?
1424 .build()?;
1425
1426 assert_optimized_plan_equal!(
1427 plan,
1428 @ r"
1429 Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2
1430 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1431 TableScan: test
1432 "
1433 )
1434 }
1435
1436 #[test]
1437 fn test_volatile_short_circuits() -> Result<()> {
1438 let table_scan = test_table_scan()?;
1439
1440 let rand = rand_func().call(vec![]);
1441 let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1442 let not_extracted_volatile_short_circuit_1 =
1443 extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1444 let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1445 let not_extracted_volatile_short_circuit_2 =
1446 rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1447 let plan = LogicalPlanBuilder::from(table_scan)
1448 .project(vec![
1449 not_extracted_volatile_short_circuit_1.clone().alias("c1"),
1450 not_extracted_volatile_short_circuit_1.alias("c2"),
1451 not_extracted_volatile_short_circuit_2.clone().alias("c3"),
1452 not_extracted_volatile_short_circuit_2.alias("c4"),
1453 ])?
1454 .build()?;
1455
1456 assert_optimized_plan_equal!(
1457 plan,
1458 @ r"
1459 Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4
1460 Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c
1461 TableScan: test
1462 "
1463 )
1464 }
1465
1466 #[test]
1467 fn test_non_top_level_common_expression() -> Result<()> {
1468 let table_scan = test_table_scan()?;
1469
1470 let common_expr = col("a") + col("b");
1471 let plan = LogicalPlanBuilder::from(table_scan)
1472 .project(vec![
1473 common_expr.clone().alias("c1"),
1474 common_expr.alias("c2"),
1475 ])?
1476 .project(vec![col("c1"), col("c2")])?
1477 .build()?;
1478
1479 assert_optimized_plan_equal!(
1480 plan,
1481 @ r"
1482 Projection: c1, c2
1483 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1484 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1485 TableScan: test
1486 "
1487 )
1488 }
1489
1490 #[test]
1491 fn test_nested_common_expression() -> Result<()> {
1492 let table_scan = test_table_scan()?;
1493
1494 let nested_common_expr = col("a") + col("b");
1495 let common_expr = nested_common_expr.clone() * nested_common_expr;
1496 let plan = LogicalPlanBuilder::from(table_scan)
1497 .project(vec![
1498 common_expr.clone().alias("c1"),
1499 common_expr.alias("c2"),
1500 ])?
1501 .build()?;
1502
1503 assert_optimized_plan_equal!(
1504 plan,
1505 @ r"
1506 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1507 Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c
1508 Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c
1509 TableScan: test
1510 "
1511 )
1512 }
1513
1514 #[test]
1515 fn test_normalize_add_expression() -> Result<()> {
1516 let table_scan = test_table_scan()?;
1518 let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30));
1519 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1520
1521 assert_optimized_plan_equal!(
1522 plan,
1523 @ r"
1524 Projection: test.a, test.b, test.c
1525 Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1526 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1527 TableScan: test
1528 "
1529 )
1530 }
1531
1532 #[test]
1533 fn test_normalize_multi_expression() -> Result<()> {
1534 let table_scan = test_table_scan()?;
1536 let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30));
1537 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1538
1539 assert_optimized_plan_equal!(
1540 plan,
1541 @ r"
1542 Projection: test.a, test.b, test.c
1543 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1544 Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c
1545 TableScan: test
1546 "
1547 )
1548 }
1549
1550 #[test]
1551 fn test_normalize_bitset_and_expression() -> Result<()> {
1552 let table_scan = test_table_scan()?;
1554 let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30));
1555 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1556
1557 assert_optimized_plan_equal!(
1558 plan,
1559 @ r"
1560 Projection: test.a, test.b, test.c
1561 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1562 Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c
1563 TableScan: test
1564 "
1565 )
1566 }
1567
1568 #[test]
1569 fn test_normalize_bitset_or_expression() -> Result<()> {
1570 let table_scan = test_table_scan()?;
1572 let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30));
1573 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1574
1575 assert_optimized_plan_equal!(
1576 plan,
1577 @ r"
1578 Projection: test.a, test.b, test.c
1579 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1580 Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c
1581 TableScan: test
1582 "
1583 )
1584 }
1585
1586 #[test]
1587 fn test_normalize_bitset_xor_expression() -> Result<()> {
1588 let table_scan = test_table_scan()?;
1590 let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30));
1591 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1592
1593 assert_optimized_plan_equal!(
1594 plan,
1595 @ r"
1596 Projection: test.a, test.b, test.c
1597 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1598 Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c
1599 TableScan: test
1600 "
1601 )
1602 }
1603
1604 #[test]
1605 fn test_normalize_eq_expression() -> Result<()> {
1606 let table_scan = test_table_scan()?;
1608 let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a")));
1609 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1610
1611 assert_optimized_plan_equal!(
1612 plan,
1613 @ r"
1614 Projection: test.a, test.b, test.c
1615 Filter: __common_expr_1 AND __common_expr_1
1616 Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1617 TableScan: test
1618 "
1619 )
1620 }
1621
1622 #[test]
1623 fn test_normalize_ne_expression() -> Result<()> {
1624 let table_scan = test_table_scan()?;
1626 let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a")));
1627 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1628
1629 assert_optimized_plan_equal!(
1630 plan,
1631 @ r"
1632 Projection: test.a, test.b, test.c
1633 Filter: __common_expr_1 AND __common_expr_1
1634 Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c
1635 TableScan: test
1636 "
1637 )
1638 }
1639
1640 #[test]
1641 fn test_normalize_complex_expression() -> Result<()> {
1642 let table_scan = test_table_scan()?;
1644 let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a")))
1645 .eq(lit(30));
1646 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1647
1648 assert_optimized_plan_equal!(
1649 plan,
1650 @ r"
1651 Projection: test.a, test.b, test.c
1652 Filter: __common_expr_1 - __common_expr_1 = Int32(30)
1653 Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c
1654 TableScan: test
1655 "
1656 )?;
1657
1658 let table_scan = test_table_scan()?;
1660 let expr = (((col("a") + col("b") / col("c")) * col("c"))
1661 / (col("c") * (col("b") / col("c") + col("a")))
1662 + col("a"))
1663 .eq(lit(30));
1664 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1665
1666 assert_optimized_plan_equal!(
1667 plan,
1668 @ r"
1669 Projection: test.a, test.b, test.c
1670 Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)
1671 Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c
1672 TableScan: test
1673 "
1674 )?;
1675
1676 let table_scan = test_table_scan()?;
1678 let expr = ((col("b") / (col("a") + col("c")))
1679 * (col("b") / (col("c") + col("a"))))
1680 .eq(lit(30));
1681 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1682 assert_optimized_plan_equal!(
1683 plan,
1684 @ r"
1685 Projection: test.a, test.b, test.c
1686 Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1687 Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c
1688 TableScan: test
1689 "
1690 )?;
1691
1692 Ok(())
1693 }
1694
1695 #[derive(Debug, PartialEq, Eq, Hash)]
1696 pub struct TestUdf {
1697 signature: Signature,
1698 }
1699
1700 impl TestUdf {
1701 pub fn new() -> Self {
1702 Self {
1703 signature: Signature::numeric(1, Volatility::Immutable),
1704 }
1705 }
1706 }
1707
1708 impl ScalarUDFImpl for TestUdf {
1709 fn name(&self) -> &str {
1710 "my_udf"
1711 }
1712
1713 fn signature(&self) -> &Signature {
1714 &self.signature
1715 }
1716
1717 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1718 Ok(DataType::Int32)
1719 }
1720
1721 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1722 panic!("not implemented")
1723 }
1724 }
1725
1726 #[test]
1727 fn test_normalize_inner_binary_expression() -> Result<()> {
1728 let table_scan = test_table_scan()?;
1730 let expr1 = not(col("a").eq(col("b")));
1731 let expr2 = not(col("b").eq(col("a")));
1732 let plan = LogicalPlanBuilder::from(table_scan)
1733 .project(vec![expr1, expr2])?
1734 .build()?;
1735 assert_optimized_plan_equal!(
1736 plan,
1737 @ r"
1738 Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a
1739 Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1740 TableScan: test
1741 "
1742 )?;
1743
1744 let table_scan = test_table_scan()?;
1746 let expr1 = is_null(col("a").eq(col("b")));
1747 let expr2 = is_null(col("b").eq(col("a")));
1748 let plan = LogicalPlanBuilder::from(table_scan)
1749 .project(vec![expr1, expr2])?
1750 .build()?;
1751 assert_optimized_plan_equal!(
1752 plan,
1753 @ r"
1754 Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL
1755 Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c
1756 TableScan: test
1757 "
1758 )?;
1759
1760 let table_scan = test_table_scan()?;
1762 let expr1 = (col("a") + col("b")).between(lit(0), lit(10));
1763 let expr2 = (col("b") + col("a")).between(lit(0), lit(10));
1764 let plan = LogicalPlanBuilder::from(table_scan)
1765 .project(vec![expr1, expr2])?
1766 .build()?;
1767 assert_optimized_plan_equal!(
1768 plan,
1769 @ r"
1770 Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)
1771 Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1772 TableScan: test
1773 "
1774 )?;
1775
1776 let table_scan = test_table_scan()?;
1778 let expr1 = col("c").between(col("a") + col("b"), lit(10));
1779 let expr2 = col("c").between(col("b") + col("a"), lit(10));
1780 let plan = LogicalPlanBuilder::from(table_scan)
1781 .project(vec![expr1, expr2])?
1782 .build()?;
1783 assert_optimized_plan_equal!(
1784 plan,
1785 @ r"
1786 Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)
1787 Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1788 TableScan: test
1789 "
1790 )?;
1791
1792 let udf = ScalarUDF::from(TestUdf::new());
1794 let table_scan = test_table_scan()?;
1795 let expr1 = udf.call(vec![col("a") + col("b")]);
1796 let expr2 = udf.call(vec![col("b") + col("a")]);
1797 let plan = LogicalPlanBuilder::from(table_scan)
1798 .project(vec![expr1, expr2])?
1799 .build()?;
1800 assert_optimized_plan_equal!(
1801 plan,
1802 @ r"
1803 Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)
1804 Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c
1805 TableScan: test
1806 "
1807 )
1808 }
1809
1810 fn rand_func() -> ScalarUDF {
1816 ScalarUDF::new_from_impl(RandomStub::new())
1817 }
1818
1819 #[derive(Debug, PartialEq, Eq, Hash)]
1820 struct RandomStub {
1821 signature: Signature,
1822 }
1823
1824 impl RandomStub {
1825 fn new() -> Self {
1826 Self {
1827 signature: Signature::exact(vec![], Volatility::Volatile),
1828 }
1829 }
1830 }
1831 impl ScalarUDFImpl for RandomStub {
1832 fn name(&self) -> &str {
1833 "random"
1834 }
1835
1836 fn signature(&self) -> &Signature {
1837 &self.signature
1838 }
1839
1840 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1841 Ok(DataType::Float64)
1842 }
1843
1844 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1845 panic!("dummy - not implemented")
1846 }
1847 }
1848
1849 #[test]
1856 fn test_leaf_expression_not_extracted() -> Result<()> {
1857 let table_scan = test_table_scan()?;
1858
1859 let leaf = leaf_udf_expr(col("a"));
1860 let plan = LogicalPlanBuilder::from(table_scan)
1861 .project(vec![leaf.clone().alias("c1"), leaf.alias("c2")])?
1862 .build()?;
1863
1864 assert_optimized_plan_equal!(
1866 plan,
1867 @r"
1868 Projection: leaf_udf(test.a) AS c1, leaf_udf(test.a) AS c2
1869 TableScan: test
1870 "
1871 )
1872 }
1873
1874 #[test]
1878 fn test_leaf_subexpression_not_extracted() -> Result<()> {
1879 let table_scan = test_table_scan()?;
1880
1881 let common = leaf_udf_expr(col("a")) + col("b");
1885 let plan = LogicalPlanBuilder::from(table_scan)
1886 .project(vec![common.clone().alias("c1"), common.alias("c2")])?
1887 .build()?;
1888
1889 assert_optimized_plan_equal!(
1892 plan,
1893 @r"
1894 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1895 Projection: leaf_udf(test.a) + test.b AS __common_expr_1, test.a, test.b, test.c
1896 TableScan: test
1897 "
1898 )
1899 }
1900}