1use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result};
33use datafusion_expr::expr::{Alias, ScalarFunction};
34use datafusion_expr::logical_plan::{
35 Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator, SortExpr};
38
39const CSE_PREFIX: &str = "__common_expr";
40
/// Logical optimizer rule that performs Common Sub-expression Elimination (CSE).
///
/// Sub-expressions that occur more than once in the expressions of a
/// `Projection`, `Sort`, `Filter`, `Window` (including a chain of consecutive
/// `Window` nodes) or `Aggregate` node are computed once, in a `Projection`
/// inserted below that node, under a generated `__common_expr_<n>` alias. The
/// original expressions are rewritten to reference that alias. For `Aggregate`
/// nodes, repeated aggregate functions are additionally computed only once.
/// Volatile expressions are never eliminated.
///
/// For example, `SELECT 1 + a AS first, 1 + a AS second FROM test` becomes:
///
/// ```text
/// Projection: __common_expr_1 AS first, __common_expr_1 AS second
///   Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
///     TableScan: test
/// ```
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
69
70impl CommonSubexprEliminate {
71 pub fn new() -> Self {
72 Self {}
73 }
74
75 fn try_optimize_proj(
76 &self,
77 projection: Projection,
78 config: &dyn OptimizerConfig,
79 ) -> Result<Transformed<LogicalPlan>> {
80 let Projection {
81 expr,
82 input,
83 schema,
84 ..
85 } = projection;
86 let input = Arc::unwrap_or_clone(input);
87 self.try_unary_plan(expr, input, config)?
88 .map_data(|(new_expr, new_input)| {
89 Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
90 .map(LogicalPlan::Projection)
91 })
92 }
93
94 fn try_optimize_sort(
95 &self,
96 sort: Sort,
97 config: &dyn OptimizerConfig,
98 ) -> Result<Transformed<LogicalPlan>> {
99 let Sort { expr, input, fetch } = sort;
100 let input = Arc::unwrap_or_clone(input);
101 let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
102 .into_iter()
103 .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
104 .unzip();
105 let new_sort = self
106 .try_unary_plan(sort_expressions, input, config)?
107 .update_data(|(new_expr, new_input)| {
108 LogicalPlan::Sort(Sort {
109 expr: new_expr
110 .into_iter()
111 .zip(sort_params)
112 .map(|(expr, (asc, nulls_first))| SortExpr {
113 expr,
114 asc,
115 nulls_first,
116 })
117 .collect(),
118 input: Arc::new(new_input),
119 fetch,
120 })
121 });
122 Ok(new_sort)
123 }
124
125 fn try_optimize_filter(
126 &self,
127 filter: Filter,
128 config: &dyn OptimizerConfig,
129 ) -> Result<Transformed<LogicalPlan>> {
130 let Filter {
131 predicate, input, ..
132 } = filter;
133 let input = Arc::unwrap_or_clone(input);
134 let expr = vec![predicate];
135 self.try_unary_plan(expr, input, config)?
136 .map_data(|(mut new_expr, new_input)| {
137 assert_eq!(new_expr.len(), 1); let new_predicate = new_expr.pop().unwrap();
139 Filter::try_new(new_predicate, Arc::new(new_input))
140 .map(LogicalPlan::Filter)
141 })
142 }
143
144 fn try_optimize_window(
145 &self,
146 window: Window,
147 config: &dyn OptimizerConfig,
148 ) -> Result<Transformed<LogicalPlan>> {
149 let (window_expr_list, window_schemas, input) =
152 get_consecutive_window_exprs(window);
153
154 match CSE::new(ExprCSEController::new(
157 config.alias_generator().as_ref(),
158 ExprMask::Normal,
159 ))
160 .extract_common_nodes(window_expr_list)?
161 {
162 FoundCommonNodes::Yes {
166 common_nodes: common_exprs,
167 new_nodes_list: new_exprs_list,
168 original_nodes_list: original_exprs_list,
169 } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
170 Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
171 }),
172 FoundCommonNodes::No {
173 original_nodes_list: original_exprs_list,
174 } => Ok(Transformed::no((original_exprs_list, input, None))),
175 }?
176 .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
179 self.rewrite(new_input, config)?.map_data(|new_input| {
180 Ok((new_window_expr_list, new_input, window_expr_list))
181 })
182 })?
183 .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
185 if let Some(window_expr_list) = window_expr_list {
194 let name_preserver = NamePreserver::new_for_projection();
195 let saved_names = window_expr_list
196 .iter()
197 .map(|exprs| {
198 exprs
199 .iter()
200 .map(|expr| name_preserver.save(expr))
201 .collect::<Vec<_>>()
202 })
203 .collect::<Vec<_>>();
204 new_window_expr_list.into_iter().zip(saved_names).try_rfold(
205 new_input,
206 |plan, (new_window_expr, saved_names)| {
207 let new_window_expr = new_window_expr
208 .into_iter()
209 .zip(saved_names)
210 .map(|(new_window_expr, saved_name)| {
211 saved_name.restore(new_window_expr)
212 })
213 .collect::<Vec<_>>();
214 Window::try_new(new_window_expr, Arc::new(plan))
215 .map(LogicalPlan::Window)
216 },
217 )
218 } else {
219 new_window_expr_list
220 .into_iter()
221 .zip(window_schemas)
222 .try_rfold(new_input, |plan, (new_window_expr, schema)| {
223 Window::try_new_with_schema(
224 new_window_expr,
225 Arc::new(plan),
226 schema,
227 )
228 .map(LogicalPlan::Window)
229 })
230 }
231 })
232 }
233
234 fn try_optimize_aggregate(
235 &self,
236 aggregate: Aggregate,
237 config: &dyn OptimizerConfig,
238 ) -> Result<Transformed<LogicalPlan>> {
239 let Aggregate {
240 group_expr,
241 aggr_expr,
242 input,
243 schema,
244 ..
245 } = aggregate;
246 let input = Arc::unwrap_or_clone(input);
247 match CSE::new(ExprCSEController::new(
249 config.alias_generator().as_ref(),
250 ExprMask::Normal,
251 ))
252 .extract_common_nodes(vec![group_expr, aggr_expr])?
253 {
254 FoundCommonNodes::Yes {
258 common_nodes: common_exprs,
259 new_nodes_list: mut new_exprs_list,
260 original_nodes_list: mut original_exprs_list,
261 } => {
262 let new_aggr_expr = new_exprs_list.pop().unwrap();
263 let new_group_expr = new_exprs_list.pop().unwrap();
264
265 build_common_expr_project_plan(input, common_exprs).map(|new_input| {
266 let aggr_expr = original_exprs_list.pop().unwrap();
267 Transformed::yes((
268 new_aggr_expr,
269 new_group_expr,
270 new_input,
271 Some(aggr_expr),
272 ))
273 })
274 }
275
276 FoundCommonNodes::No {
277 original_nodes_list: mut original_exprs_list,
278 } => {
279 let new_aggr_expr = original_exprs_list.pop().unwrap();
280 let new_group_expr = original_exprs_list.pop().unwrap();
281
282 Ok(Transformed::no((
283 new_aggr_expr,
284 new_group_expr,
285 input,
286 None,
287 )))
288 }
289 }?
290 .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
293 self.rewrite(new_input, config)?.map_data(|new_input| {
294 Ok((
295 new_aggr_expr,
296 new_group_expr,
297 aggr_expr,
298 Arc::new(new_input),
299 ))
300 })
301 })?
302 .transform_data(
304 |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
305 match CSE::new(ExprCSEController::new(
307 config.alias_generator().as_ref(),
308 ExprMask::NormalAndAggregates,
309 ))
310 .extract_common_nodes(vec![new_aggr_expr])?
311 {
312 FoundCommonNodes::Yes {
313 common_nodes: common_exprs,
314 new_nodes_list: mut new_exprs_list,
315 original_nodes_list: mut original_exprs_list,
316 } => {
317 let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
318 let new_aggr_expr = original_exprs_list.pop().unwrap();
319
320 let mut agg_exprs = common_exprs
321 .into_iter()
322 .map(|(expr, expr_alias)| expr.alias(expr_alias))
323 .collect::<Vec<_>>();
324
325 let mut proj_exprs = vec![];
326 for expr in &new_group_expr {
327 extract_expressions(expr, &mut proj_exprs)
328 }
329 for (expr_rewritten, expr_orig) in
330 rewritten_aggr_expr.into_iter().zip(new_aggr_expr)
331 {
332 if expr_rewritten == expr_orig {
333 if let Expr::Alias(Alias { expr, name, .. }) =
334 expr_rewritten
335 {
336 agg_exprs.push(expr.alias(&name));
337 proj_exprs
338 .push(Expr::Column(Column::from_name(name)));
339 } else {
340 let expr_alias =
341 config.alias_generator().next(CSE_PREFIX);
342 let (qualifier, field_name) =
343 expr_rewritten.qualified_name();
344 let out_name =
345 qualified_name(qualifier.as_ref(), &field_name);
346
347 agg_exprs.push(expr_rewritten.alias(&expr_alias));
348 proj_exprs.push(
349 Expr::Column(Column::from_name(expr_alias))
350 .alias(out_name),
351 );
352 }
353 } else {
354 proj_exprs.push(expr_rewritten);
355 }
356 }
357
358 let agg = LogicalPlan::Aggregate(Aggregate::try_new(
359 new_input,
360 new_group_expr,
361 agg_exprs,
362 )?);
363 Projection::try_new(proj_exprs, Arc::new(agg))
364 .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
365 }
366
367 FoundCommonNodes::No {
370 original_nodes_list: mut original_exprs_list,
371 } => {
372 let rewritten_aggr_expr = original_exprs_list.pop().unwrap();
373
374 if let Some(aggr_expr) = aggr_expr {
385 let name_preserver = NamePreserver::new_for_projection();
386 let saved_names = aggr_expr
387 .iter()
388 .map(|expr| name_preserver.save(expr))
389 .collect::<Vec<_>>();
390 let new_aggr_expr = rewritten_aggr_expr
391 .into_iter()
392 .zip(saved_names)
393 .map(|(new_expr, saved_name)| {
394 saved_name.restore(new_expr)
395 })
396 .collect::<Vec<Expr>>();
397
398 Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
401 .map(LogicalPlan::Aggregate)
402 .map(Transformed::no)
403 } else {
404 Aggregate::try_new_with_schema(
405 new_input,
406 new_group_expr,
407 rewritten_aggr_expr,
408 schema,
409 )
410 .map(LogicalPlan::Aggregate)
411 .map(Transformed::no)
412 }
413 }
414 }
415 },
416 )
417 }
418
419 fn try_unary_plan(
434 &self,
435 exprs: Vec<Expr>,
436 input: LogicalPlan,
437 config: &dyn OptimizerConfig,
438 ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
439 match CSE::new(ExprCSEController::new(
441 config.alias_generator().as_ref(),
442 ExprMask::Normal,
443 ))
444 .extract_common_nodes(vec![exprs])?
445 {
446 FoundCommonNodes::Yes {
447 common_nodes: common_exprs,
448 new_nodes_list: mut new_exprs_list,
449 original_nodes_list: _,
450 } => {
451 let new_exprs = new_exprs_list.pop().unwrap();
452 build_common_expr_project_plan(input, common_exprs)
453 .map(|new_input| Transformed::yes((new_exprs, new_input)))
454 }
455 FoundCommonNodes::No {
456 original_nodes_list: mut original_exprs_list,
457 } => {
458 let new_exprs = original_exprs_list.pop().unwrap();
459 Ok(Transformed::no((new_exprs, input)))
460 }
461 }?
462 .transform_data(|(new_exprs, new_input)| {
465 self.rewrite(new_input, config)?
466 .map_data(|new_input| Ok((new_exprs, new_input)))
467 })
468 }
469}
470
471fn get_consecutive_window_exprs(
503 window: Window,
504) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
505 let mut window_expr_list = vec![];
506 let mut window_schemas = vec![];
507 let mut plan = LogicalPlan::Window(window);
508 while let LogicalPlan::Window(Window {
509 input,
510 window_expr,
511 schema,
512 }) = plan
513 {
514 window_expr_list.push(window_expr);
515 window_schemas.push(schema);
516
517 plan = Arc::unwrap_or_clone(input);
518 }
519 (window_expr_list, window_schemas, plan)
520}
521
522impl OptimizerRule for CommonSubexprEliminate {
523 fn supports_rewrite(&self) -> bool {
524 true
525 }
526
527 fn apply_order(&self) -> Option<ApplyOrder> {
528 None
532 }
533
534 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
535 fn rewrite(
536 &self,
537 plan: LogicalPlan,
538 config: &dyn OptimizerConfig,
539 ) -> Result<Transformed<LogicalPlan>> {
540 let original_schema = Arc::clone(plan.schema());
541
542 let optimized_plan = match plan {
543 LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
544 LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
545 LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
546 LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
547 LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
548 LogicalPlan::Join(_)
549 | LogicalPlan::Repartition(_)
550 | LogicalPlan::Union(_)
551 | LogicalPlan::TableScan(_)
552 | LogicalPlan::Values(_)
553 | LogicalPlan::EmptyRelation(_)
554 | LogicalPlan::Subquery(_)
555 | LogicalPlan::SubqueryAlias(_)
556 | LogicalPlan::Limit(_)
557 | LogicalPlan::Ddl(_)
558 | LogicalPlan::Explain(_)
559 | LogicalPlan::Analyze(_)
560 | LogicalPlan::Statement(_)
561 | LogicalPlan::DescribeTable(_)
562 | LogicalPlan::Distinct(_)
563 | LogicalPlan::Extension(_)
564 | LogicalPlan::Dml(_)
565 | LogicalPlan::Copy(_)
566 | LogicalPlan::Unnest(_)
567 | LogicalPlan::RecursiveQuery(_) => {
568 plan.map_children(|c| self.rewrite(c, config))?
571 }
572 };
573
574 if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
576 {
577 optimized_plan.map_data(|optimized_plan| {
578 build_recover_project_plan(&original_schema, optimized_plan)
579 })
580 } else {
581 Ok(optimized_plan)
582 }
583 }
584
585 fn name(&self) -> &str {
586 "common_sub_expression_eliminate"
587 }
588}
589
/// Selects which kinds of expressions `ExprCSEController::is_ignored` skips,
/// i.e. which expressions are never extracted as common sub-expressions.
#[derive(Debug, Clone, Copy)]
enum ExprMask {
    /// Ignores literals, columns, scalar variables, aliases, wildcards and
    /// aggregate functions.
    Normal,

    /// Ignores the same expressions as `Normal` except aggregate functions,
    /// which can therefore be extracted as common sub-expressions.
    NormalAndAggregates,
}
606
607struct ExprCSEController<'a> {
608 alias_generator: &'a AliasGenerator,
609 mask: ExprMask,
610
611 alias_counter: usize,
613}
614
615impl<'a> ExprCSEController<'a> {
616 fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
617 Self {
618 alias_generator,
619 mask,
620 alias_counter: 0,
621 }
622 }
623}
624
625impl CSEController for ExprCSEController<'_> {
626 type Node = Expr;
627
628 fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
629 match node {
630 Expr::ScalarFunction(ScalarFunction { func, args })
634 if func.short_circuits() =>
635 {
636 Some((vec![], args.iter().collect()))
637 }
638
639 Expr::BinaryExpr(BinaryExpr {
642 left,
643 op: Operator::And | Operator::Or,
644 right,
645 }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
646
647 Expr::Case(Case {
651 expr,
652 when_then_expr,
653 else_expr,
654 }) => Some((
655 expr.iter()
656 .map(|e| e.as_ref())
657 .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
658 .collect(),
659 when_then_expr
660 .iter()
661 .take(1)
662 .map(|(_, then)| then.as_ref())
663 .chain(
664 when_then_expr
665 .iter()
666 .skip(1)
667 .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
668 )
669 .chain(else_expr.iter().map(|e| e.as_ref()))
670 .collect(),
671 )),
672 _ => None,
673 }
674 }
675
676 fn is_valid(node: &Expr) -> bool {
677 !node.is_volatile_node()
678 }
679
680 fn is_ignored(&self, node: &Expr) -> bool {
681 #[expect(deprecated)]
683 let is_normal_minus_aggregates = matches!(
684 node,
685 Expr::Literal(..)
686 | Expr::Column(..)
687 | Expr::ScalarVariable(..)
688 | Expr::Alias(..)
689 | Expr::Wildcard { .. }
690 );
691
692 let is_aggr = matches!(node, Expr::AggregateFunction(..));
693
694 match self.mask {
695 ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
696 ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
697 }
698 }
699
700 fn generate_alias(&self) -> String {
701 self.alias_generator.next(CSE_PREFIX)
702 }
703
704 fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
705 if self.alias_counter > 0 {
707 col(alias)
708 } else {
709 self.alias_counter += 1;
710 col(alias).alias(node.schema_name().to_string())
711 }
712 }
713
714 fn rewrite_f_down(&mut self, node: &Expr) {
715 if matches!(node, Expr::Alias(_)) {
716 self.alias_counter += 1;
717 }
718 }
719 fn rewrite_f_up(&mut self, node: &Expr) {
720 if matches!(node, Expr::Alias(_)) {
721 self.alias_counter -= 1
722 }
723 }
724}
725
726impl Default for CommonSubexprEliminate {
727 fn default() -> Self {
728 Self::new()
729 }
730}
731
732fn build_common_expr_project_plan(
743 input: LogicalPlan,
744 common_exprs: Vec<(Expr, String)>,
745) -> Result<LogicalPlan> {
746 let mut fields_set = BTreeSet::new();
747 let mut project_exprs = common_exprs
748 .into_iter()
749 .map(|(expr, expr_alias)| {
750 fields_set.insert(expr_alias.clone());
751 Ok(expr.alias(expr_alias))
752 })
753 .collect::<Result<Vec<_>>>()?;
754
755 for (qualifier, field) in input.schema().iter() {
756 if fields_set.insert(qualified_name(qualifier, field.name())) {
757 project_exprs.push(Expr::from((qualifier, field)));
758 }
759 }
760
761 Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
762}
763
764fn build_recover_project_plan(
770 schema: &DFSchema,
771 input: LogicalPlan,
772) -> Result<LogicalPlan> {
773 let col_exprs = schema.iter().map(Expr::from).collect();
774 Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
775}
776
777fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
778 if let Expr::GroupingSet(groupings) = expr {
779 for e in groupings.distinct_expr() {
780 let (qualifier, field_name) = e.qualified_name();
781 let col = Column::new(qualifier, field_name);
782 result.push(Expr::Column(col))
783 }
784 } else {
785 let (qualifier, field_name) = expr.qualified_name();
786 let col = Column::new(qualifier, field_name);
787 result.push(Expr::Column(col));
788 }
789}
790
791#[cfg(test)]
792mod test {
793 use std::any::Any;
794 use std::iter;
795
796 use arrow::datatypes::{DataType, Field, Schema};
797 use datafusion_expr::logical_plan::{table_scan, JoinType};
798 use datafusion_expr::{
799 grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF,
800 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
801 SimpleAggregateUDF, Volatility,
802 };
803 use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
804
805 use super::*;
806 use crate::optimizer::OptimizerContext;
807 use crate::test::*;
808 use crate::Optimizer;
809 use datafusion_expr::test::function_stub::{avg, sum};
810
811 fn assert_optimized_plan_eq(
812 expected: &str,
813 plan: LogicalPlan,
814 config: Option<&dyn OptimizerConfig>,
815 ) {
816 let optimizer =
817 Optimizer::with_rules(vec![Arc::new(CommonSubexprEliminate::new())]);
818 let default_config = OptimizerContext::new();
819 let config = config.unwrap_or(&default_config);
820 let optimized_plan = optimizer.optimize(plan, config, |_, _| ()).unwrap();
821 let formatted_plan = format!("{optimized_plan}");
822 assert_eq!(expected, formatted_plan);
823 }
824
825 #[test]
826 fn tpch_q1_simplified() -> Result<()> {
827 let table_scan = test_table_scan()?;
836
837 let plan = LogicalPlanBuilder::from(table_scan)
838 .aggregate(
839 iter::empty::<Expr>(),
840 vec![
841 sum(col("a") * (lit(1) - col("b"))),
842 sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
843 ],
844 )?
845 .build()?;
846
847 let expected = "Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\
848 \n Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c\
849 \n TableScan: test";
850
851 assert_optimized_plan_eq(expected, plan, None);
852
853 Ok(())
854 }
855
856 #[test]
857 fn nested_aliases() -> Result<()> {
858 let table_scan = test_table_scan()?;
859
860 let plan = LogicalPlanBuilder::from(table_scan)
861 .project(vec![
862 (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
863 col("a") + col("b"),
864 ])?
865 .build()?;
866
867 let expected = "Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b\
868 \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\
869 \n TableScan: test";
870
871 assert_optimized_plan_eq(expected, plan, None);
872
873 Ok(())
874 }
875
876 #[test]
877 fn aggregate() -> Result<()> {
878 let table_scan = test_table_scan()?;
879
880 let return_type = DataType::UInt32;
881 let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
882 let udf_agg = |inner: Expr| {
883 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
884 Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
885 "my_agg",
886 Signature::exact(vec![DataType::UInt32], Volatility::Stable),
887 return_type.clone(),
888 Arc::clone(&accumulator),
889 vec![Field::new("value", DataType::UInt32, true)],
890 ))),
891 vec![inner],
892 false,
893 None,
894 None,
895 None,
896 ))
897 };
898
899 let plan = LogicalPlanBuilder::from(table_scan.clone())
901 .aggregate(
902 iter::empty::<Expr>(),
903 vec![
904 avg(col("a")).alias("col1"),
906 avg(col("a")).alias("col2"),
907 avg(col("b")).alias("col3"),
909 avg(col("c")),
910 udf_agg(col("a")).alias("col4"),
912 udf_agg(col("a")).alias("col5"),
913 udf_agg(col("b")).alias("col6"),
915 udf_agg(col("c")),
916 ],
917 )?
918 .build()?;
919
920 let expected = "Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)\
921 \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\
922 \n TableScan: test";
923
924 assert_optimized_plan_eq(expected, plan, None);
925
926 let plan = LogicalPlanBuilder::from(table_scan.clone())
928 .aggregate(
929 iter::empty::<Expr>(),
930 vec![
931 lit(1) + avg(col("a")),
932 lit(1) - avg(col("a")),
933 lit(1) + udf_agg(col("a")),
934 lit(1) - udf_agg(col("a")),
935 ],
936 )?
937 .build()?;
938
939 let expected = "Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)\
940 \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]\
941 \n TableScan: test";
942
943 assert_optimized_plan_eq(expected, plan, None);
944
945 let plan = LogicalPlanBuilder::from(table_scan.clone())
947 .aggregate(
948 iter::empty::<Expr>(),
949 vec![
950 avg(lit(1u32) + col("a")).alias("col1"),
951 udf_agg(lit(1u32) + col("a")).alias("col2"),
952 ],
953 )?
954 .build()?;
955
956 let expected = "Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\
957 \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
958 \n TableScan: test";
959
960 assert_optimized_plan_eq(expected, plan, None);
961
962 let plan = LogicalPlanBuilder::from(table_scan.clone())
964 .aggregate(
965 vec![lit(1u32) + col("a")],
966 vec![
967 avg(lit(1u32) + col("a")).alias("col1"),
968 udf_agg(lit(1u32) + col("a")).alias("col2"),
969 ],
970 )?
971 .build()?;
972
973 let expected = "Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\
974 \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
975 \n TableScan: test";
976
977 assert_optimized_plan_eq(expected, plan, None);
978
979 let plan = LogicalPlanBuilder::from(table_scan)
981 .aggregate(
982 vec![lit(1u32) + col("a")],
983 vec![
984 (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
985 (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
986 avg(lit(1u32) + col("a")),
987 (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
988 (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
989 udf_agg(lit(1u32) + col("a")),
990 ],
991 )?
992 .build()?;
993
994 let expected = "Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)\
995 \n Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]\
996 \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
997 \n TableScan: test";
998
999 assert_optimized_plan_eq(expected, plan, None);
1000
1001 Ok(())
1002 }
1003
1004 #[test]
1005 fn aggregate_with_relations_and_dots() -> Result<()> {
1006 let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
1007 let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;
1008
1009 let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));
1010
1011 let plan = LogicalPlanBuilder::from(table_scan)
1012 .aggregate(
1013 vec![col_a.clone()],
1014 vec![
1015 (lit(1u32) + avg(lit(1u32) + col_a.clone())),
1016 avg(lit(1u32) + col_a),
1017 ],
1018 )?
1019 .build()?;
1020
1021 let expected = "Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)\
1022 \n Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]\
1023 \n Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a\
1024 \n TableScan: table.test";
1025
1026 assert_optimized_plan_eq(expected, plan, None);
1027
1028 Ok(())
1029 }
1030
1031 #[test]
1032 fn subexpr_in_same_order() -> Result<()> {
1033 let table_scan = test_table_scan()?;
1034
1035 let plan = LogicalPlanBuilder::from(table_scan)
1036 .project(vec![
1037 (lit(1) + col("a")).alias("first"),
1038 (lit(1) + col("a")).alias("second"),
1039 ])?
1040 .build()?;
1041
1042 let expected = "Projection: __common_expr_1 AS first, __common_expr_1 AS second\
1043 \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
1044 \n TableScan: test";
1045
1046 assert_optimized_plan_eq(expected, plan, None);
1047
1048 Ok(())
1049 }
1050
1051 #[test]
1052 fn subexpr_in_different_order() -> Result<()> {
1053 let table_scan = test_table_scan()?;
1054
1055 let plan = LogicalPlanBuilder::from(table_scan)
1056 .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1057 .build()?;
1058
1059 let expected = "Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)\
1060 \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
1061 \n TableScan: test";
1062
1063 assert_optimized_plan_eq(expected, plan, None);
1064
1065 Ok(())
1066 }
1067
1068 #[test]
1069 fn cross_plans_subexpr() -> Result<()> {
1070 let table_scan = test_table_scan()?;
1071
1072 let plan = LogicalPlanBuilder::from(table_scan)
1073 .project(vec![lit(1) + col("a"), col("a")])?
1074 .project(vec![lit(1) + col("a")])?
1075 .build()?;
1076
1077 let expected = "Projection: Int32(1) + test.a\
1078 \n Projection: Int32(1) + test.a, test.a\
1079 \n TableScan: test";
1080
1081 assert_optimized_plan_eq(expected, plan, None);
1082 Ok(())
1083 }
1084
1085 #[test]
1086 fn redundant_project_fields() {
1087 let table_scan = test_table_scan().unwrap();
1088 let c_plus_a = col("c") + col("a");
1089 let b_plus_a = col("b") + col("a");
1090 let common_exprs_1 = vec![
1091 (c_plus_a, format!("{CSE_PREFIX}_1")),
1092 (b_plus_a, format!("{CSE_PREFIX}_2")),
1093 ];
1094 let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1095 let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1096 let common_exprs_2 = vec![
1097 (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1098 (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1099 ];
1100 let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap();
1101 let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1102
1103 let mut field_set = BTreeSet::new();
1104 for name in project_2.schema().field_names() {
1105 assert!(field_set.insert(name));
1106 }
1107 }
1108
1109 #[test]
1110 fn redundant_project_fields_join_input() {
1111 let table_scan_1 = test_table_scan_with_name("test1").unwrap();
1112 let table_scan_2 = test_table_scan_with_name("test2").unwrap();
1113 let join = LogicalPlanBuilder::from(table_scan_1)
1114 .join(table_scan_2, JoinType::Inner, (vec!["a"], vec!["a"]), None)
1115 .unwrap()
1116 .build()
1117 .unwrap();
1118 let c_plus_a = col("test1.c") + col("test1.a");
1119 let b_plus_a = col("test1.b") + col("test1.a");
1120 let common_exprs_1 = vec![
1121 (c_plus_a, format!("{CSE_PREFIX}_1")),
1122 (b_plus_a, format!("{CSE_PREFIX}_2")),
1123 ];
1124 let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1125 let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1126 let common_exprs_2 = vec![
1127 (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1128 (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1129 ];
1130 let project = build_common_expr_project_plan(join, common_exprs_1).unwrap();
1131 let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1132
1133 let mut field_set = BTreeSet::new();
1134 for name in project_2.schema().field_names() {
1135 assert!(field_set.insert(name));
1136 }
1137 }
1138
1139 #[test]
1140 fn eliminated_subexpr_datatype() {
1141 use datafusion_expr::cast;
1142
1143 let schema = Schema::new(vec![
1144 Field::new("a", DataType::UInt64, false),
1145 Field::new("b", DataType::UInt64, false),
1146 Field::new("c", DataType::UInt64, false),
1147 ]);
1148
1149 let plan = table_scan(Some("table"), &schema, None)
1150 .unwrap()
1151 .filter(
1152 cast(col("a"), DataType::Int64)
1153 .lt(lit(1_i64))
1154 .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
1155 )
1156 .unwrap()
1157 .build()
1158 .unwrap();
1159 let rule = CommonSubexprEliminate::new();
1160 let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
1161 assert!(optimized_plan.transformed);
1162 let optimized_plan = optimized_plan.data;
1163
1164 let schema = optimized_plan.schema();
1165 let fields_with_datatypes: Vec<_> = schema
1166 .fields()
1167 .iter()
1168 .map(|field| (field.name(), field.data_type()))
1169 .collect();
1170 let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
1171 let expected = r#"[
1172 (
1173 "a",
1174 UInt64,
1175 ),
1176 (
1177 "b",
1178 UInt64,
1179 ),
1180 (
1181 "c",
1182 UInt64,
1183 ),
1184]"#;
1185 assert_eq!(expected, formatted_fields_with_datatype);
1186 }
1187
1188 #[test]
1189 fn filter_schema_changed() -> Result<()> {
1190 let table_scan = test_table_scan()?;
1191
1192 let plan = LogicalPlanBuilder::from(table_scan)
1193 .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
1194 .build()?;
1195
1196 let expected = "Projection: test.a, test.b, test.c\
1197 \n Filter: __common_expr_1 - Int32(10) > __common_expr_1\
1198 \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
1199 \n TableScan: test";
1200
1201 assert_optimized_plan_eq(expected, plan, None);
1202
1203 Ok(())
1204 }
1205
1206 #[test]
1207 fn test_extract_expressions_from_grouping_set() -> Result<()> {
1208 let mut result = Vec::with_capacity(3);
1209 let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
1210 extract_expressions(&grouping, &mut result);
1211
1212 assert!(result.len() == 3);
1213 Ok(())
1214 }
1215
1216 #[test]
1217 fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
1218 let mut result = Vec::with_capacity(2);
1219 let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
1220 extract_expressions(&grouping, &mut result);
1221 assert!(result.len() == 2);
1222 Ok(())
1223 }
1224
#[test]
fn test_alias_collision() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Builds a plan whose lower projection already exposes a column named
    // `taken_alias`, with a repeated `c + 2` above it that CSE will extract.
    let build_plan = |taken_alias: String| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(table_scan.clone())
            .project(vec![
                (col("a") + col("b")).alias(taken_alias.clone()),
                col("c"),
            ])?
            .project(vec![
                col(taken_alias.clone()).alias("c1"),
                col(taken_alias).alias("c2"),
                (col("c") + lit(2)).alias("c3"),
                (col("c") + lit(2)).alias("c4"),
            ])?
            .build()
    };

    // Case 1: the generator has handed out __common_expr_1 (used by the input),
    // so the new common expression must be named __common_expr_2.
    let ctx = &OptimizerContext::new();
    let first_alias = ctx.alias_generator().next(CSE_PREFIX);
    let first_plan = build_plan(first_alias)?;
    let first_expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4\
        \n  Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c\
        \n    Projection: test.a + test.b AS __common_expr_1, test.c\
        \n      TableScan: test";
    assert_optimized_plan_eq(first_expected, first_plan, Some(ctx));

    // Case 2: the generator is advanced twice and the input uses the second
    // alias, so the optimizer continues at __common_expr_3.
    let ctx = &OptimizerContext::new();
    let _skipped_alias = ctx.alias_generator().next(CSE_PREFIX);
    let second_alias = ctx.alias_generator().next(CSE_PREFIX);
    let second_plan = build_plan(second_alias)?;
    let second_expected = "Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4\
        \n  Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c\
        \n    Projection: test.a + test.b AS __common_expr_2, test.c\
        \n      TableScan: test";
    assert_optimized_plan_eq(second_expected, second_plan, Some(ctx));

    Ok(())
}
1276
#[test]
fn test_extract_expressions_from_col() -> Result<()> {
    // A plain column (not a grouping set) is extracted as exactly one expression.
    let mut result = Vec::with_capacity(1);
    extract_expressions(&col("a"), &mut result);

    // `assert_eq!` shows the actual length on failure.
    assert_eq!(result.len(), 1);
    Ok(())
}
1284
#[test]
fn test_short_circuits() -> Result<()> {
    let is_zero = |operand: Expr| operand.eq(lit(0));

    // A whole OR expression repeated verbatim across projections (c1/c2).
    let whole_disjunction = is_zero(col("a")).or(is_zero(col("b")));
    // Used as the left leg of OR (c3) and AND (c4): expected to be extracted.
    let sum_is_zero = is_zero(col("a") + col("b"));
    // Used as the right leg of OR (c3) and AND (c4): expected NOT to be extracted.
    let difference_is_zero = is_zero(col("a") - col("b"));
    // Appears on both sides of the same OR (c5): expected to be extracted.
    let product_is_zero = is_zero(col("a") * col("b"));

    let projections = vec![
        whole_disjunction.clone().alias("c1"),
        whole_disjunction.alias("c2"),
        sum_is_zero
            .clone()
            .or(difference_is_zero.clone())
            .alias("c3"),
        sum_is_zero.and(difference_is_zero).alias("c4"),
        product_is_zero.clone().or(product_is_zero).alias("c5"),
    ];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(projections)?
        .build()?;

    let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5\
        \n  Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c\
        \n    TableScan: test";
    assert_optimized_plan_eq(expected, plan, None);

    Ok(())
}
1319
#[test]
fn test_volatile() -> Result<()> {
    // `a + b + random()` is volatile and stays duplicated, while its
    // deterministic child `a + b` is still extracted.
    let shared_sum = col("a") + col("b");
    let volatile_expr = shared_sum + rand_func().call(vec![]);

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![
            volatile_expr.clone().alias("c1"),
            volatile_expr.alias("c2"),
        ])?
        .build()?;

    let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\
        \n  Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\
        \n    TableScan: test";
    assert_optimized_plan_eq(expected, plan, None);

    Ok(())
}
1342
#[test]
fn test_volatile_short_circuits() -> Result<()> {
    let random_is_zero = rand_func().call(vec![]).eq(lit(0));

    // Deterministic left leg, volatile right leg: only `a = 0` is extracted.
    let deterministic_first = col("a").eq(lit(0)).or(random_is_zero.clone());
    // Volatile left leg: nothing in this expression is extracted, not even `b = 0`.
    let volatile_first = random_is_zero.or(col("b").eq(lit(0)));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![
            deterministic_first.clone().alias("c1"),
            deterministic_first.alias("c2"),
            volatile_first.clone().alias("c3"),
            volatile_first.alias("c4"),
        ])?
        .build()?;

    let expected = "Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\
        \n  Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c\
        \n    TableScan: test";
    assert_optimized_plan_eq(expected, plan, None);

    Ok(())
}
1371
#[test]
fn test_non_top_level_common_expression() -> Result<()> {
    // The duplicated `a + b` lives in an inner projection; the outer projection
    // only forwards the aliases and is left untouched.
    let sum = col("a") + col("b");

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![sum.clone().alias("c1"), sum.alias("c2")])?
        .project(vec![col("c1"), col("c2")])?
        .build()?;

    let expected = "Projection: c1, c2\
        \n  Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\
        \n    Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\
        \n      TableScan: test";
    assert_optimized_plan_eq(expected, plan, None);

    Ok(())
}
1394
#[test]
fn test_nested_common_expression() -> Result<()> {
    // `(a + b) * (a + b)` is repeated, and `a + b` is itself repeated inside it,
    // so extraction happens at two nested levels.
    let inner = col("a") + col("b");
    let outer = inner.clone() * inner;

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![outer.clone().alias("c1"), outer.alias("c2")])?
        .build()?;

    let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\
        \n  Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c\
        \n    Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c\
        \n      TableScan: test";
    assert_optimized_plan_eq(expected, plan, None);

    Ok(())
}
1417
#[test]
fn test_normalize_add_expression() -> Result<()> {
    // `a + b` and `b + a` are recognised as the same common expression.
    let predicate = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30));
    let filtered = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_eq(
        "Projection: test.a, test.b, test.c\
        \n  Filter: __common_expr_1 * __common_expr_1 = Int32(30)\
        \n    Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\
        \n      TableScan: test",
        filtered,
        None,
    );

    Ok(())
}
1433
#[test]
fn test_normalize_multi_expression() -> Result<()> {
    // `a * b` and `b * a` are recognised as the same common expression.
    let predicate = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30));
    let filtered = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_eq(
        "Projection: test.a, test.b, test.c\
        \n  Filter: __common_expr_1 + __common_expr_1 = Int32(30)\
        \n    Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c\
        \n      TableScan: test",
        filtered,
        None,
    );

    Ok(())
}
1449
#[test]
fn test_normalize_bitset_and_expression() -> Result<()> {
    // `a & b` and `b & a` are recognised as the same common expression.
    let predicate = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30));
    let filtered = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_eq(
        "Projection: test.a, test.b, test.c\
        \n  Filter: __common_expr_1 + __common_expr_1 = Int32(30)\
        \n    Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c\
        \n      TableScan: test",
        filtered,
        None,
    );

    Ok(())
}
1465
#[test]
fn test_normalize_bitset_or_expression() -> Result<()> {
    // `a | b` and `b | a` are recognised as the same common expression.
    let predicate = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30));
    let filtered = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_eq(
        "Projection: test.a, test.b, test.c\
        \n  Filter: __common_expr_1 + __common_expr_1 = Int32(30)\
        \n    Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c\
        \n      TableScan: test",
        filtered,
        None,
    );

    Ok(())
}
1481
#[test]
fn test_normalize_bitset_xor_expression() -> Result<()> {
    // `a ^ b` and `b ^ a` are recognised as the same common expression
    // (displayed as BIT_XOR in the plan).
    let predicate = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30));
    let filtered = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_eq(
        "Projection: test.a, test.b, test.c\
        \n  Filter: __common_expr_1 + __common_expr_1 = Int32(30)\
        \n    Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c\
        \n      TableScan: test",
        filtered,
        None,
    );

    Ok(())
}
1497
#[test]
fn test_normalize_eq_expression() -> Result<()> {
    // `a = b` and `b = a` are recognised as the same common expression.
    let a_eq_b = col("a").eq(col("b"));
    let b_eq_a = col("b").eq(col("a"));
    let filtered = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(a_eq_b.and(b_eq_a))?
        .build()?;

    assert_optimized_plan_eq(
        "Projection: test.a, test.b, test.c\
        \n  Filter: __common_expr_1 AND __common_expr_1\
        \n    Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c\
        \n      TableScan: test",
        filtered,
        None,
    );

    Ok(())
}
1513
#[test]
fn test_normalize_ne_expression() -> Result<()> {
    // `a != b` and `b != a` are recognised as the same common expression.
    let a_ne_b = col("a").not_eq(col("b"));
    let b_ne_a = col("b").not_eq(col("a"));
    let filtered = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(a_ne_b.and(b_ne_a))?
        .build()?;

    assert_optimized_plan_eq(
        "Projection: test.a, test.b, test.c\
        \n  Filter: __common_expr_1 AND __common_expr_1\
        \n    Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c\
        \n      TableScan: test",
        filtered,
        None,
    );

    Ok(())
}
1529
#[test]
fn test_normalize_complex_expression() -> Result<()> {
    // Each case pairs a filter predicate whose two operands differ only in the
    // order of commutative operands with the plan expected after CSE.
    let cases = [
        (
            ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a")))
                .eq(lit(30)),
            "Projection: test.a, test.b, test.c\
            \n  Filter: __common_expr_1 - __common_expr_1 = Int32(30)\
            \n    Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c\
            \n      TableScan: test",
        ),
        (
            (((col("a") + col("b") / col("c")) * col("c"))
                / (col("c") * (col("b") / col("c") + col("a")))
                + col("a"))
            .eq(lit(30)),
            "Projection: test.a, test.b, test.c\
            \n  Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)\
            \n    Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c\
            \n      TableScan: test",
        ),
        (
            ((col("b") / (col("a") + col("c")))
                * (col("b") / (col("c") + col("a"))))
            .eq(lit(30)),
            "Projection: test.a, test.b, test.c\
            \n  Filter: __common_expr_1 * __common_expr_1 = Int32(30)\
            \n    Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c\
            \n      TableScan: test",
        ),
    ];

    for (predicate, expected) in cases {
        let plan = LogicalPlanBuilder::from(test_table_scan()?)
            .filter(predicate)?
            .build()?;
        assert_optimized_plan_eq(expected, plan, None);
    }

    Ok(())
}
1571
1572 #[derive(Debug)]
1573 pub struct TestUdf {
1574 signature: Signature,
1575 }
1576
1577 impl TestUdf {
1578 pub fn new() -> Self {
1579 Self {
1580 signature: Signature::numeric(1, Volatility::Immutable),
1581 }
1582 }
1583 }
1584
1585 impl ScalarUDFImpl for TestUdf {
1586 fn as_any(&self) -> &dyn Any {
1587 self
1588 }
1589 fn name(&self) -> &str {
1590 "my_udf"
1591 }
1592
1593 fn signature(&self) -> &Signature {
1594 &self.signature
1595 }
1596
1597 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1598 Ok(DataType::Int32)
1599 }
1600
1601 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1602 panic!("not implemented")
1603 }
1604 }
1605
#[test]
fn test_normalize_inner_binary_expression() -> Result<()> {
    let udf = ScalarUDF::from(TestUdf::new());

    // Each case projects two expressions that differ only in the operand order of
    // a commutative binary expression nested inside another expression, plus the
    // plan expected after CSE.
    let cases = [
        (
            not(col("a").eq(col("b"))),
            not(col("b").eq(col("a"))),
            "Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a\
            \n  Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c\
            \n    TableScan: test",
        ),
        (
            is_null(col("a").eq(col("b"))),
            is_null(col("b").eq(col("a"))),
            "Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL\
            \n  Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c\
            \n    TableScan: test",
        ),
        (
            (col("a") + col("b")).between(lit(0), lit(10)),
            (col("b") + col("a")).between(lit(0), lit(10)),
            "Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)\
            \n  Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\
            \n    TableScan: test",
        ),
        (
            col("c").between(col("a") + col("b"), lit(10)),
            col("c").between(col("b") + col("a"), lit(10)),
            "Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)\
            \n  Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\
            \n    TableScan: test",
        ),
        (
            udf.call(vec![col("a") + col("b")]),
            udf.call(vec![col("b") + col("a")]),
            "Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)\
            \n  Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c\
            \n    TableScan: test",
        ),
    ];

    for (first, second, expected) in cases {
        let plan = LogicalPlanBuilder::from(test_table_scan()?)
            .project(vec![first, second])?
            .build()?;
        assert_optimized_plan_eq(expected, plan, None);
    }

    Ok(())
}
1670
1671 fn rand_func() -> ScalarUDF {
1677 ScalarUDF::new_from_impl(RandomStub::new())
1678 }
1679
1680 #[derive(Debug)]
1681 struct RandomStub {
1682 signature: Signature,
1683 }
1684
1685 impl RandomStub {
1686 fn new() -> Self {
1687 Self {
1688 signature: Signature::exact(vec![], Volatility::Volatile),
1689 }
1690 }
1691 }
1692 impl ScalarUDFImpl for RandomStub {
1693 fn as_any(&self) -> &dyn Any {
1694 self
1695 }
1696
1697 fn name(&self) -> &str {
1698 "random"
1699 }
1700
1701 fn signature(&self) -> &Signature {
1702 &self.signature
1703 }
1704
1705 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1706 Ok(DataType::Float64)
1707 }
1708
1709 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1710 panic!("dummy - not implemented")
1711 }
1712 }
1713}